| @@ -73,7 +73,7 @@ template <typename T> | |||
| void FinalizeAllocatorMap(std::map<rtMemType_t, T *> &allocate_map) { | |||
| for (auto &allocator : allocate_map) { | |||
| if (allocator.second != nullptr) { | |||
| allocator.second->Finalize(); | |||
| allocator.second->Finalize(ge::GetContext().DeviceId()); | |||
| delete allocator.second; | |||
| allocator.second = nullptr; | |||
| } | |||
| @@ -26,9 +26,9 @@ | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| #include "init/gelib.h" | |||
| namespace { | |||
| const int kNoTransOp = 1; | |||
| const uint32_t kIndexZero = 0; | |||
| } // namespace | |||
| namespace ge { | |||
| @@ -895,6 +895,63 @@ graphStatus SameTransdataBreadthFusionPass::AddCastNode(const ComputeGraphPtr &g | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus SameTransdataBreadthFusionPass::CheckAccuracySupported(const OpDescPtr &op_desc, bool &is_supported) { | |||
| // is_supported is set to be false as default value. | |||
| auto instance = GELib::GetInstance(); | |||
| if ((instance == nullptr) || (!instance->InitFlag())) { | |||
| REPORT_INNER_ERROR("E19999", "GELib is not initialized!"); | |||
| GELOGE(GRAPH_FAILED, "GELib is not initialized!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj(); | |||
| vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); | |||
| if (op_infos.empty()) { | |||
| GELOGI("Can not get op info by op type:%s", op_desc->GetType().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| std::string unsupported_reason; | |||
| for (const auto &it : op_infos) { | |||
| auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores(); | |||
| auto &kernel_name = it.opKernelLib; | |||
| auto kernel_info_store = kernel_map.find(kernel_name); | |||
| if (kernel_info_store != kernel_map.end()) { | |||
| if (kernel_info_store->second != nullptr && | |||
| kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) { | |||
| GELOGI("OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), it.engine.c_str(), | |||
| op_desc->GetName().c_str()); | |||
| is_supported = true; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| } | |||
| } | |||
| GELOGI("op:%s CheckAccuracySupported failed!reason:%s", op_desc->GetName().c_str(), unsupported_reason.c_str()); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| // avoid scene: A->Cast->TransData while A's DataType is not supported by TransData | |||
| graphStatus SameTransdataBreadthFusionPass::CheckTransDataSupported(const NodePtr &node, bool &is_supported) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| auto input_desc = op_desc->GetInputDescPtr(kIndexZero); | |||
| GE_CHECK_NOTNULL(input_desc); | |||
| auto in_nodes = node->GetInDataNodes(); | |||
| for (const auto &in_node : in_nodes) { | |||
| if (in_node->GetType() != TRANSDATA) { | |||
| continue; | |||
| } | |||
| auto transdata_op_desc = std::make_shared<ge::OpDesc>(TRANSDATA, TRANSDATA); | |||
| GE_CHECK_NOTNULL(transdata_op_desc); | |||
| transdata_op_desc->AddInputDesc(*input_desc); | |||
| if (CheckAccuracySupported(transdata_op_desc, is_supported) != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "[Check][AccuracySupported] failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus SameTransdataBreadthFusionPass::GetSubGraphsBetweenNormalAndTransdataNode( | |||
| OutDataAnchorPtr &out_anchor, | |||
| std::vector<std::vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out, | |||
| @@ -925,6 +982,18 @@ graphStatus SameTransdataBreadthFusionPass::GetSubGraphsBetweenNormalAndTransdat | |||
| continue; | |||
| } | |||
| } | |||
| // avoid transdata receiving unsupported datatype input after deleting cast node. | |||
| // peer_in_node is cast op. | |||
| bool is_supported = false; | |||
| if (CheckTransDataSupported(peer_in_node, is_supported) != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "[Check][Param] CheckTransDataSupported failed!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (!is_supported) { | |||
| GELOGD("CheckAccuracySupported return unsupported for transdata constructed from node [%s]'s output, skip it.", | |||
| peer_in_node->GetName().c_str()); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| for (auto &peer_out_anchor : peer_in_node->GetAllOutDataAnchors()) { | |||
| ret = GetSubGraphsBetweenNormalAndTransdataNode(peer_out_anchor, sub_graphs_out, nodes_list); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| @@ -107,6 +107,10 @@ class SameTransdataBreadthFusionPass : public GraphPass { | |||
| static bool IsHandleOp(const NodePtr &node); | |||
| static graphStatus CheckTransDataSupported(const NodePtr &node, bool &is_supported); | |||
| static graphStatus CheckAccuracySupported(const OpDescPtr &op_desc, bool &is_supported); | |||
| vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors_; | |||
| vector<vector<NodePtr>> before_transdata_nodes_; | |||
| vector<pair<int, InDataAnchorPtr>> all_transdata_nodes_; | |||
| @@ -422,6 +422,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/buffer_pool_memory_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/mark_node_unknown_shape_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/same_transdata_breadth_fusion_pass.cc" | |||
| ) | |||
| set(KERNEL_SRC_FILES | |||
| @@ -603,6 +604,7 @@ set(PASS_TEST_FILES | |||
| "graph/passes/memcpy_addr_async_unittest.cc" | |||
| "graph/passes/hccl_continuous_pass_unittest.cc" | |||
| "graph/passes/hccl_memcpy_pass_unittest.cc" | |||
| "graph/passes/same_transdata_breadth_fusion_pass_unittest.cc" | |||
| ) | |||
| set(KERNEL_TEST_FILES | |||
| @@ -0,0 +1,121 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/passes/same_transdata_breadth_fusion_pass.cc" | |||
| #include <gtest/gtest.h> | |||
| #include <string> | |||
| using namespace ge; | |||
| class UtestGraphPassesSameTransdataBreadthFusionPass : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| }; | |||
| class NodeBuilder { | |||
| public: | |||
| NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); } | |||
| NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW, | |||
| ge::DataType data_type = DT_FLOAT) { | |||
| op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); | |||
| return *this; | |||
| } | |||
| NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW, | |||
| ge::DataType data_type = DT_FLOAT) { | |||
| op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); | |||
| return *this; | |||
| } | |||
| ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); } | |||
| private: | |||
| ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW, | |||
| ge::DataType data_type = DT_FLOAT) { | |||
| GeShape ge_shape{std::vector<int64_t>(shape)}; | |||
| ge::GeTensorDescPtr tensor_desc = std::make_shared<ge::GeTensorDesc>(); | |||
| tensor_desc->SetShape(ge_shape); | |||
| tensor_desc->SetFormat(format); | |||
| tensor_desc->SetDataType(data_type); | |||
| return tensor_desc; | |||
| } | |||
| ge::OpDescPtr op_desc_; | |||
| }; | |||
| TEST_F(UtestGraphPassesSameTransdataBreadthFusionPass, test_unsupported_transdata_succ) { | |||
| // Node4D(NCHW)->cast1(DT_BOOL->FP16)->transdata1(NCHW->NC1HWC0)->sinh1 | |||
| // / | |||
| // --->cast2(DT_BOOL->FP16)->transdata2(NCHW->NC1HWC0)->sinh2 | |||
| ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| // Node4D | |||
| ge::NodePtr node_data = NodeBuilder("Data4D", DATA).AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL).Build(graph); | |||
| // cast1 | |||
| ge::NodePtr node_cast_1 = NodeBuilder("node_cast_1", CAST) | |||
| .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL) | |||
| .AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16) | |||
| .Build(graph); | |||
| auto src_name = node_data->GetName(); | |||
| node_cast_1->GetOpDesc()->SetSrcName({src_name}); | |||
| node_cast_1->GetOpDesc()->SetInputName({src_name}); | |||
| AttrUtils::SetInt(node_cast_1->GetOpDesc(), CAST_ATTR_SRCT, DT_FLOAT); | |||
| // trandata1 | |||
| ge::NodePtr node_transdata_1 = NodeBuilder("node_transdata_1", TRANSDATA) | |||
| .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16) | |||
| .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) | |||
| .Build(graph); | |||
| // sinh1 | |||
| ge::NodePtr node_sinh_1 = NodeBuilder("node_sinh_1", SINH) | |||
| .AddInputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) | |||
| .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) | |||
| .Build(graph); | |||
| // cast2 | |||
| ge::NodePtr node_cast_2 = NodeBuilder("node_cast_2", CAST) | |||
| .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL) | |||
| .AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16) | |||
| .Build(graph); | |||
| node_cast_2->GetOpDesc()->SetSrcName({src_name}); | |||
| node_cast_2->GetOpDesc()->SetInputName({src_name}); | |||
| // transdata2 | |||
| ge::NodePtr node_transdata_2 = NodeBuilder("node_transdata_2", TRANSDATA) | |||
| .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16) | |||
| .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) | |||
| .Build(graph); | |||
| // sinh2 | |||
| ge::NodePtr node_sinh_2 = NodeBuilder("node_sinh_2", SINH) | |||
| .AddInputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) | |||
| .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) | |||
| .Build(graph); | |||
| // add edge | |||
| ge::GraphUtils::AddEdge(node_data->GetOutDataAnchor(0), node_cast_1->GetInDataAnchor(0)); | |||
| ge::GraphUtils::AddEdge(node_cast_1->GetOutDataAnchor(0), node_transdata_1->GetInDataAnchor(0)); | |||
| ge::GraphUtils::AddEdge(node_transdata_1->GetOutDataAnchor(0), node_sinh_1->GetInDataAnchor(0)); | |||
| ge::GraphUtils::AddEdge(node_data->GetOutDataAnchor(0), node_cast_2->GetInDataAnchor(0)); | |||
| ge::GraphUtils::AddEdge(node_cast_2->GetOutDataAnchor(0), node_transdata_2->GetInDataAnchor(0)); | |||
| ge::GraphUtils::AddEdge(node_transdata_2->GetOutDataAnchor(0), node_sinh_2->GetInDataAnchor(0)); | |||
| ge::SameTransdataBreadthFusionPass pass; | |||
| ge::graphStatus status = pass.Run(graph); | |||
| EXPECT_EQ(ge::GRAPH_SUCCESS, status); | |||
| } | |||