diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 317ff00a..17e8e80a 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -157,7 +157,9 @@ set(TRAIN_SRC_LIST "graph/passes/compile_nodes_pass.cc" "graph/passes/constant_folding_pass.cc" "graph/passes/constant_fuse_same_pass.cc" + "graph/passes/fuse_data_nodes_with_common_input_pass.cc" "graph/passes/remove_same_const_pass.cc" + "graph/passes/no_data_out_const_elimination_pass.cc" "graph/passes/useless_control_out_remove_pass.cc" "graph/passes/control_trigger_pass.cc" "graph/passes/dimension_adjust_pass.cc" @@ -439,6 +441,7 @@ set(INFER_SRC_LIST "graph/passes/net_output_pass.cc" "graph/passes/replace_transshape_pass.cc" "graph/passes/constant_fuse_same_pass.cc" + "graph/passes/fuse_data_nodes_with_common_input_pass.cc" "graph/passes/print_op_pass.cc" "graph/passes/no_use_reshape_remove_pass.cc" "graph/passes/iterator_op_pass.cc" @@ -535,6 +538,7 @@ set(INFER_SRC_LIST "graph/passes/addn_pass.cc" "graph/passes/common_subexpression_elimination_pass.cc" "graph/passes/remove_same_const_pass.cc" + "graph/passes/no_data_out_const_elimination_pass.cc" "graph/passes/useless_control_out_remove_pass.cc" "graph/passes/transop_symmetry_elimination_pass.cc" "graph/passes/save_pass.cc" diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index 74d09404..1830e847 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -103,6 +103,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/net_output_pass.cc \ graph/passes/replace_transshape_pass.cc \ graph/passes/constant_fuse_same_pass.cc \ + graph/passes/fuse_data_nodes_with_common_input_pass.cc \ graph/passes/print_op_pass.cc \ graph/passes/no_use_reshape_remove_pass.cc \ graph/passes/iterator_op_pass.cc \ @@ -193,6 +194,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/cond_pass.cc \ graph/passes/cond_remove_pass.cc \ graph/passes/remove_same_const_pass.cc \ + graph/passes/no_data_out_const_elimination_pass.cc \ graph/passes/useless_control_out_remove_pass.cc \ graph/passes/for_pass.cc \ graph/passes/enter_pass.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 5a99dc8c..9dcac211 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -127,7 +127,9 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/compile_nodes_pass.cc \ graph/passes/constant_folding_pass.cc \ graph/passes/constant_fuse_same_pass.cc \ + graph/passes/fuse_data_nodes_with_common_input_pass.cc \ graph/passes/remove_same_const_pass.cc \ + graph/passes/no_data_out_const_elimination_pass.cc \ graph/passes/useless_control_out_remove_pass.cc \ graph/passes/control_trigger_pass.cc \ graph/passes/dimension_adjust_pass.cc \ diff --git a/ge/graph/load/new_model_manager/zero_copy_offset.h b/ge/graph/load/new_model_manager/zero_copy_offset.h index 8ead742d..66fcd887 100644 --- a/ge/graph/load/new_model_manager/zero_copy_offset.h +++ b/ge/graph/load/new_model_manager/zero_copy_offset.h @@ -65,7 +65,7 @@ class ZeroCopyOffset { // data_size of Data/Netoutput int64_t GetDataSize() const { return data_size_; } // value of *outside_addrs_ from davinci_model - std::vector>> &GetOutsideAddrs() { return outside_addrs_; } + const std::vector>> &GetOutsideAddrs() { return outside_addrs_; } // name of op std::string GetOpName() const { return op_name_; } diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index aec811e4..ae516a8f 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -53,6 +53,7 @@ #include "graph/passes/dimension_adjust_pass.h" #include "graph/passes/dimension_compute_pass.h" #include "graph/passes/flow_ctrl_pass.h" +#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" #include "graph/passes/identity_pass.h" #include "graph/passes/input_output_connection_identify_pass.h" #include "graph/passes/iterator_op_pass.h" @@ -70,6 +71,7 @@ #include "graph/passes/remove_same_const_pass.h" #include "graph/passes/reshape_recovery_pass.h" #include "graph/passes/reshape_remove_pass.h" +#include "graph/passes/no_data_out_const_elimination_pass.h" #include "graph/passes/same_transdata_breadth_fusion_pass.h" #include "graph/passes/subgraph_pass.h" #include "graph/passes/switch_data_edges_bypass.h" @@ -2104,6 +2106,24 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass)); GE_CHK_STATUS_RET( after_merge_passes.AddPass("OptimizeStage1_1::ConstantFuseSamePass", new (std::nothrow) ConstantFuseSamePass)); + /* + * Do CSE before FuseDataNodesWithCommonInputPass to resolve the scene in bertlarge as following: + * const + * / | \ + * cast1 cast2 cast3 + * \ | / + * case + * the node `const` is the fused const node after ConstantFuseSamePass + * the nodes `cast1`, `cast2` and 'cast3' will be fused by CSE. + * in order to eliminate hard code in FuseDataNodesWithCommonInputPass, + * we do CSE before FuseDataNodesWithCommonInputPass + * But it is a temp solution, this CSE will be deleted after change pass from graph pass to node pass + */ + GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CSEBeforeFuseDataNodesWithCommonInputPass", + new (std::nothrow) CommonSubexpressionEliminationPass)); + // FuseDataNodesWithCommonInputPass: fuse same data with common input in same graph + GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::FuseDataNodesWithCommonInputPass", + new (std::nothrow) FuseDataNodesWithCommonInputPass)); GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CommonSubexpressionEliminationPass", new (std::nothrow) CommonSubexpressionEliminationPass)); GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::PermutePass", new (std::nothrow) PermutePass)) @@ -2226,12 +2246,14 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { GELOGE(ret, "Run passes when OptimizeStage1_3 failed, ret:%u.", ret); return ret; } - NamesToPass identity_remove_pass; - GE_TIMESTAMP_START(identity_remove_pass); + NamesToPass node_pass; + GE_TIMESTAMP_START(node_pass); IdentityPass identity_force_pass(false); // after SwitchToStreamSwitchPass - identity_remove_pass.emplace_back("IdentityPass", &identity_force_pass); - ret = GEPass(compute_graph).Run(identity_remove_pass); - GE_TIMESTAMP_END(identity_remove_pass, "GraphPrepare::IdentityRemovePass"); + NoDataOutConstEliminationPass no_data_out_const_elimination_pass; + node_pass.emplace_back("IdentityPass", &identity_force_pass); + node_pass.emplace_back("NoDataOutConstEliminationPass", &no_data_out_const_elimination_pass); + ret = GEPass(compute_graph).Run(node_pass); + GE_TIMESTAMP_END(node_pass, "GraphPrepare::node_pass"); if (ret != SUCCESS) { GELOGE(ret, "Run identity remove pass for preprocess failed, ret:%u.", ret); return ret; diff --git a/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc b/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc new file mode 100644 index 00000000..ab8fc39b --- /dev/null +++ b/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" + +#include +#include +#include +#include +#include +#include "common/ge_inner_error_codes.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/type_utils.h" +#include "graph/utils/node_utils.h" + +using std::map; +using std::vector; +using std::set; +using std::string; + +namespace ge { +Status FuseDataNodesWithCommonInputPass::Run(ge::ComputeGraphPtr graph) { + if (graph == nullptr) { + GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null."); + return GE_GRAPH_PARAM_NULLPTR; + } + GELOGD("FuseDataNodesWithCommonInputPass in."); + // key: subgraph, value:--key: peer out anchor to parent node, --value: parent indexes to parent node + map>> subgraphs_to_need_fuse_nodes_info; + if (InitNeedFuseNodesInfo(graph, subgraphs_to_need_fuse_nodes_info) != SUCCESS) { + GELOGE(FAILED, "InitNeedFuseNodesInfo failed."); + return FAILED; + } + return FuseDataNodes(subgraphs_to_need_fuse_nodes_info); +} + +Status FuseDataNodesWithCommonInputPass::InitNeedFuseNodesInfo(ComputeGraphPtr &graph, + map>> &subgraphs_to_need_fuse_nodes_info) { + for (const auto &subgraph : graph->GetAllSubgraphs()) { + GE_CHECK_NOTNULL(subgraph); + auto parent_node = subgraph->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + if (parent_node->GetType() == CASE || parent_node->GetType() == IF) { + auto &peer_out_anchors_to_parent_indexes = subgraphs_to_need_fuse_nodes_info[subgraph]; + for (const auto &in_data_anchor : parent_node->GetAllInDataAnchors()) { + GE_CHECK_NOTNULL(in_data_anchor); + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + uint32_t parent_index = static_cast(in_data_anchor->GetIdx()); + GE_CHECK_NOTNULL(peer_out_anchor); + peer_out_anchors_to_parent_indexes[peer_out_anchor].insert(parent_index); + GELOGD("Peer node %s is the %d input of parent node %s in %s.", + peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_index, parent_node->GetName().c_str(), + subgraph->GetName().c_str()); + } + } + } + return SUCCESS; +} + +Status FuseDataNodesWithCommonInputPass::FuseDataNodes( + const map>> &subgraphs_to_need_fuse_nodes_info) { + for (const auto &subgraph_to_need_fuse_nodes_info : subgraphs_to_need_fuse_nodes_info) { + auto subgraph = subgraph_to_need_fuse_nodes_info.first; + for (const auto &peer_out_anchors_to_parent_indexes : subgraph_to_need_fuse_nodes_info.second) { + if (peer_out_anchors_to_parent_indexes.second.size() <= 1) { + continue; + } + // key: out anchor, value: data nodes with common input will be fused + map> peer_out_anchors_to_need_fuse_nodes; + for (const auto &node : subgraph->GetDirectNode()) { + if (node->GetType() != DATA) { + continue; + } + GE_CHECK_NOTNULL(node->GetOpDesc()); + uint32_t parent_index = 0; + if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + if (peer_out_anchors_to_parent_indexes.second.count(parent_index) > 0) { + peer_out_anchors_to_need_fuse_nodes[peer_out_anchors_to_parent_indexes.first].emplace_back(node); + } + } + } + for (const auto &peer_out_anchor_to_need_fuse_nodes : peer_out_anchors_to_need_fuse_nodes) { + auto need_fuse_data_nodes = peer_out_anchor_to_need_fuse_nodes.second; + auto first_node = need_fuse_data_nodes.at(0); + for (size_t i = 1; i < need_fuse_data_nodes.size(); ++i) { + auto node = need_fuse_data_nodes.at(i); + GELOGI("Replace redundant data node %s by %s exist in graph: %s.", node->GetName().c_str(), + first_node->GetName().c_str(), subgraph->GetName().c_str()); + // the data node which can be fused has none input(both data and control in) + if (GraphUtils::MoveOutCtrlEdges(node, first_node) != SUCCESS) { + return FAILED; + } + if (GraphUtils::ReplaceNodeDataAnchors(first_node, node, {}, {0}) != SUCCESS) { + return FAILED; + } + if (GraphUtils::RemoveNodeWithoutRelink(subgraph, node) != SUCCESS) { + GELOGE(FAILED, "[%s] RemoveNodeWithoutRelink failed.", node->GetName().c_str()); + return FAILED; + } + } + } + } + } + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/fuse_data_nodes_with_common_input_pass.h b/ge/graph/passes/fuse_data_nodes_with_common_input_pass.h new file mode 100755 index 00000000..9ff6ab89 --- /dev/null +++ b/ge/graph/passes/fuse_data_nodes_with_common_input_pass.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ +#define GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ + +#include +#include +#include +#include "graph/types.h" +#include "inc/graph_pass.h" + +namespace ge { +class FuseDataNodesWithCommonInputPass : public GraphPass { + public: + Status Run(ge::ComputeGraphPtr graph) override; + + private: + Status InitNeedFuseNodesInfo(ComputeGraphPtr &graph, + map>> &subgraphs_to_need_fuse_nodes_info); + Status FuseDataNodes( + const map>> &subgraphs_to_need_fuse_nodes_info); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ diff --git a/ge/graph/passes/no_data_out_const_elimination_pass.cc b/ge/graph/passes/no_data_out_const_elimination_pass.cc new file mode 100644 index 00000000..c55148bd --- /dev/null +++ b/ge/graph/passes/no_data_out_const_elimination_pass.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/no_data_out_const_elimination_pass.h" + +namespace ge { +Status NoDataOutConstEliminationPass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + GELOGD("RemoveConstWithoutDataPass running of %s.", node->GetName().c_str()); + if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + // delete const which has no input and no output of data + if (node->GetOpDesc()->GetInputsSize() == 0 && node->GetOutDataNodes().size() == 0) { + GELOGI("Remove const %s.", node->GetName().c_str()); + if (IsolateAndDeleteNode(node, {}) != SUCCESS) { + GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", node->GetName().c_str()); + return FAILED; + } + } + } + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/no_data_out_const_elimination_pass.h b/ge/graph/passes/no_data_out_const_elimination_pass.h new file mode 100644 index 00000000..112c4867 --- /dev/null +++ b/ge/graph/passes/no_data_out_const_elimination_pass.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_ +#define GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_ + +#include "graph/passes/base_pass.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" + +namespace ge { +class NoDataOutConstEliminationPass : public BaseNodePass { + public: + Status Run(ge::NodePtr &node) override; +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_ diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 2ebe9fc9..0d4f6a66 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -178,6 +178,8 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/replace_transshape_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/constant_fuse_same_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/no_data_out_const_elimination_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/print_op_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/iterator_op_pass.cc" @@ -616,6 +618,8 @@ set(PASS_TEST_FILES "graph/passes/trans_op_depth_fusion_pass_unittest.cc" "graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc" "graph/passes/constant_folding_pass_unittest.cc" + "graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc" + "graph/passes/no_data_out_const_elimination_pass_unittest.cc" "graph/passes/stop_gradient_pass_unittest.cc" "graph/passes/prevent_gradient_pass_unittest.cc" "graph/passes/identity_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc b/tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc new file mode 100644 index 00000000..1660b3c6 --- /dev/null +++ b/tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" + +#include +#include +#include +#include + +#include "inc/pass_manager.h" +#include "common/ge_inner_error_codes.h" +#include "graph_builder_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/type_utils.h" +#include "graph/utils/node_utils.h" + +namespace ge { + +class UtestFuseDataNodesWithCommonInputPass : public testing::Test { +protected: + void SetUp() {} + void TearDown() {} + +public: + NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { + GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); + auto op_desc = std::make_shared(name, type); + for (auto i = 0; i < in_num; ++i) { + op_desc->AddInputDesc(test_desc); + } + for (auto i = 0; i < out_num; ++i) { + op_desc->AddOutputDesc(test_desc); + } + return graph->AddNode(op_desc); + } +}; + +/// graph with subgraph +/// const +/// | | | +/// case +/// | +/// netoutput +/// ... +/// data0 data1 data2 +/// | \ / +/// conv add +TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { + PassManager pass_manager; + pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass); + ComputeGraphPtr parent_graph = std::make_shared("parent_graph"); + auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const"); + auto parent_case = MakeNode(parent_graph, 3, 1, "parent_case", "Case"); + auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput"); + + GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT); + + parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc); + parent_case->GetOpDesc()->UpdateInputDesc(2, tensor_desc); + parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + + GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0)); + GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1)); + GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(2)); + GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); + + ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); + auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); + data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data"); + data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + auto data2 = MakeNode(parent_graph, 1, 1, "data2", "Data"); + data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + (void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); + (void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); + (void)AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); + + sub_graph->SetParentNode(parent_case); + sub_graph->SetParentGraph(parent_graph); + EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); +} + +/// graph with subgraph +/// const +/// / \ +/// cast1 cast2 +/// \ / +/// case +/// | +/// netoutput +/// ... +/// data1 data2 +/// \ / +/// add +TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) { + PassManager pass_manager; + pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass); + ComputeGraphPtr parent_graph = std::make_shared("parent_graph"); + auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const"); + auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast"); + auto parent_cast2 = MakeNode(parent_graph, 1, 1, "parent_cast2", "Cast"); + auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case"); + auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput"); + + GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT); + + parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + parent_cast2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + parent_cast2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc); + parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + + GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0)); + GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0)); + GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast2->GetInDataAnchor(0)); + GraphUtils::AddEdge(parent_cast2->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1)); + GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); + + ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); + auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); + data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data"); + data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + (void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); + (void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); + + sub_graph->SetParentNode(parent_case); + sub_graph->SetParentGraph(parent_graph); + EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); +} +} // namespace ge diff --git a/tests/ut/ge/graph/passes/no_data_out_const_elimination_pass_unittest.cc b/tests/ut/ge/graph/passes/no_data_out_const_elimination_pass_unittest.cc new file mode 100644 index 00000000..c102f5c2 --- /dev/null +++ b/tests/ut/ge/graph/passes/no_data_out_const_elimination_pass_unittest.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/no_data_out_const_elimination_pass.h" + +#include +#include +#include +#include + +#include "common/ge_inner_error_codes.h" +#include "graph/utils/graph_utils.h" + +namespace ge { + +class UtestNoDataOutConstEliminationPass : public testing::Test { +protected: + void SetUp() {} + void TearDown() {} + +public: + NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { + GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); + auto op_desc = std::make_shared(name, type); + for (auto i = 0; i < in_num; ++i) { + op_desc->AddInputDesc(test_desc); + } + for (auto i = 0; i < out_num; ++i) { + op_desc->AddOutputDesc(test_desc); + } + return graph->AddNode(op_desc); + } +}; + +/// graph with subgraph +/// const1 +/// |(control) +/// const2 +/// | +/// output +TEST_F(UtestNoDataOutConstEliminationPass, succ_graph1) { + ComputeGraphPtr graph = std::make_shared("test"); + auto const_node1 = MakeNode(graph, 0, 1, "const_node1", "Const"); + auto const_node2 = MakeNode(graph, 1, 1, "const_node2", "Const"); + auto output_node = MakeNode(graph, 1, 0, "output_node", "NetOutput"); + GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT); + + const_node1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + const_node2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + const_node2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); + output_node->GetOpDesc()->UpdateInputDesc(0, tensor_desc); + + GraphUtils::AddEdge(const_node1->GetOutControlAnchor(), const_node2->GetInControlAnchor()); + GraphUtils::AddEdge(const_node2->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + + GEPass pass(graph); + NamesToPass node_pass; + NoDataOutConstEliminationPass no_data_out_const_elimination_pass; + node_pass.emplace_back("NoDataOutConstEliminationPass", &no_data_out_const_elimination_pass); + EXPECT_EQ(pass.Run(node_pass), SUCCESS); +} +} // namespace ge