| @@ -157,7 +157,9 @@ set(TRAIN_SRC_LIST | |||
| "graph/passes/compile_nodes_pass.cc" | |||
| "graph/passes/constant_folding_pass.cc" | |||
| "graph/passes/constant_fuse_same_pass.cc" | |||
| "graph/passes/fuse_data_nodes_with_common_input_pass.cc" | |||
| "graph/passes/remove_same_const_pass.cc" | |||
| "graph/passes/no_data_out_const_elimination_pass.cc" | |||
| "graph/passes/useless_control_out_remove_pass.cc" | |||
| "graph/passes/control_trigger_pass.cc" | |||
| "graph/passes/dimension_adjust_pass.cc" | |||
| @@ -439,6 +441,7 @@ set(INFER_SRC_LIST | |||
| "graph/passes/net_output_pass.cc" | |||
| "graph/passes/replace_transshape_pass.cc" | |||
| "graph/passes/constant_fuse_same_pass.cc" | |||
| "graph/passes/fuse_data_nodes_with_common_input_pass.cc" | |||
| "graph/passes/print_op_pass.cc" | |||
| "graph/passes/no_use_reshape_remove_pass.cc" | |||
| "graph/passes/iterator_op_pass.cc" | |||
| @@ -535,6 +538,7 @@ set(INFER_SRC_LIST | |||
| "graph/passes/addn_pass.cc" | |||
| "graph/passes/common_subexpression_elimination_pass.cc" | |||
| "graph/passes/remove_same_const_pass.cc" | |||
| "graph/passes/no_data_out_const_elimination_pass.cc" | |||
| "graph/passes/useless_control_out_remove_pass.cc" | |||
| "graph/passes/transop_symmetry_elimination_pass.cc" | |||
| "graph/passes/save_pass.cc" | |||
| @@ -103,6 +103,7 @@ OMG_HOST_SRC_FILES := \ | |||
| graph/passes/net_output_pass.cc \ | |||
| graph/passes/replace_transshape_pass.cc \ | |||
| graph/passes/constant_fuse_same_pass.cc \ | |||
| graph/passes/fuse_data_nodes_with_common_input_pass.cc \ | |||
| graph/passes/print_op_pass.cc \ | |||
| graph/passes/no_use_reshape_remove_pass.cc \ | |||
| graph/passes/iterator_op_pass.cc \ | |||
| @@ -193,6 +194,7 @@ OMG_HOST_SRC_FILES := \ | |||
| graph/passes/cond_pass.cc \ | |||
| graph/passes/cond_remove_pass.cc \ | |||
| graph/passes/remove_same_const_pass.cc \ | |||
| graph/passes/no_data_out_const_elimination_pass.cc \ | |||
| graph/passes/useless_control_out_remove_pass.cc \ | |||
| graph/passes/for_pass.cc \ | |||
| graph/passes/enter_pass.cc \ | |||
| @@ -127,7 +127,9 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| graph/passes/compile_nodes_pass.cc \ | |||
| graph/passes/constant_folding_pass.cc \ | |||
| graph/passes/constant_fuse_same_pass.cc \ | |||
| graph/passes/fuse_data_nodes_with_common_input_pass.cc \ | |||
| graph/passes/remove_same_const_pass.cc \ | |||
| graph/passes/no_data_out_const_elimination_pass.cc \ | |||
| graph/passes/useless_control_out_remove_pass.cc \ | |||
| graph/passes/control_trigger_pass.cc \ | |||
| graph/passes/dimension_adjust_pass.cc \ | |||
| @@ -65,7 +65,7 @@ class ZeroCopyOffset { | |||
| // data_size of Data/Netoutput | |||
| int64_t GetDataSize() const { return data_size_; } | |||
| // value of *outside_addrs_ from davinci_model | |||
| std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; } | |||
| const std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; } | |||
| // name of op | |||
| std::string GetOpName() const { return op_name_; } | |||
| @@ -53,6 +53,7 @@ | |||
| #include "graph/passes/dimension_adjust_pass.h" | |||
| #include "graph/passes/dimension_compute_pass.h" | |||
| #include "graph/passes/flow_ctrl_pass.h" | |||
| #include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||
| #include "graph/passes/identity_pass.h" | |||
| #include "graph/passes/input_output_connection_identify_pass.h" | |||
| #include "graph/passes/iterator_op_pass.h" | |||
| @@ -70,6 +71,7 @@ | |||
| #include "graph/passes/remove_same_const_pass.h" | |||
| #include "graph/passes/reshape_recovery_pass.h" | |||
| #include "graph/passes/reshape_remove_pass.h" | |||
| #include "graph/passes/no_data_out_const_elimination_pass.h" | |||
| #include "graph/passes/same_transdata_breadth_fusion_pass.h" | |||
| #include "graph/passes/subgraph_pass.h" | |||
| #include "graph/passes/switch_data_edges_bypass.h" | |||
| @@ -2104,6 +2106,24 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||
| after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass)); | |||
| GE_CHK_STATUS_RET( | |||
| after_merge_passes.AddPass("OptimizeStage1_1::ConstantFuseSamePass", new (std::nothrow) ConstantFuseSamePass)); | |||
| /* | |||
| * Do CSE before FuseDataNodesWithCommonInputPass to resolve the scene in bertlarge as following: | |||
| * const | |||
| * / | \ | |||
| * cast1 cast2 cast3 | |||
| * \ | / | |||
| * case | |||
| * the node `const` is the fused const node after ConstantFuseSamePass | |||
| * the nodes `cast1`, `cast2` and 'cast3' will be fused by CSE. | |||
| * in order to eliminate hard code in FuseDataNodesWithCommonInputPass, | |||
| * we do CSE before FuseDataNodesWithCommonInputPass | |||
| * But it is a temp solution, this CSE will be deleted after change pass from graph pass to node pass | |||
| */ | |||
| GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CSEBeforeFuseDataNodesWithCommonInputPass", | |||
| new (std::nothrow) CommonSubexpressionEliminationPass)); | |||
| // FuseDataNodesWithCommonInputPass: fuse same data with common input in same graph | |||
| GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::FuseDataNodesWithCommonInputPass", | |||
| new (std::nothrow) FuseDataNodesWithCommonInputPass)); | |||
| GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CommonSubexpressionEliminationPass", | |||
| new (std::nothrow) CommonSubexpressionEliminationPass)); | |||
| GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::PermutePass", new (std::nothrow) PermutePass)) | |||
| @@ -2226,12 +2246,14 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||
| GELOGE(ret, "Run passes when OptimizeStage1_3 failed, ret:%u.", ret); | |||
| return ret; | |||
| } | |||
| NamesToPass identity_remove_pass; | |||
| GE_TIMESTAMP_START(identity_remove_pass); | |||
| NamesToPass node_pass; | |||
| GE_TIMESTAMP_START(node_pass); | |||
| IdentityPass identity_force_pass(false); // after SwitchToStreamSwitchPass | |||
| identity_remove_pass.emplace_back("IdentityPass", &identity_force_pass); | |||
| ret = GEPass(compute_graph).Run(identity_remove_pass); | |||
| GE_TIMESTAMP_END(identity_remove_pass, "GraphPrepare::IdentityRemovePass"); | |||
| NoDataOutConstEliminationPass no_data_out_const_elimination_pass; | |||
| node_pass.emplace_back("IdentityPass", &identity_force_pass); | |||
| node_pass.emplace_back("NoDataOutConstEliminationPass", &no_data_out_const_elimination_pass); | |||
| ret = GEPass(compute_graph).Run(node_pass); | |||
| GE_TIMESTAMP_END(node_pass, "GraphPrepare::node_pass"); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Run identity remove pass for preprocess failed, ret:%u.", ret); | |||
| return ret; | |||
| @@ -0,0 +1,119 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <set> | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| #include "graph/utils/type_utils.h" | |||
| #include "graph/utils/node_utils.h" | |||
| using std::map; | |||
| using std::vector; | |||
| using std::set; | |||
| using std::string; | |||
| namespace ge { | |||
| Status FuseDataNodesWithCommonInputPass::Run(ge::ComputeGraphPtr graph) { | |||
| if (graph == nullptr) { | |||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null."); | |||
| return GE_GRAPH_PARAM_NULLPTR; | |||
| } | |||
| GELOGD("FuseDataNodesWithCommonInputPass in."); | |||
| // key: subgraph, value:--key: peer out anchor to parent node, --value: parent indexes to parent node | |||
| map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> subgraphs_to_need_fuse_nodes_info; | |||
| if (InitNeedFuseNodesInfo(graph, subgraphs_to_need_fuse_nodes_info) != SUCCESS) { | |||
| GELOGE(FAILED, "InitNeedFuseNodesInfo failed."); | |||
| return FAILED; | |||
| } | |||
| return FuseDataNodes(subgraphs_to_need_fuse_nodes_info); | |||
| } | |||
| Status FuseDataNodesWithCommonInputPass::InitNeedFuseNodesInfo(ComputeGraphPtr &graph, | |||
| map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info) { | |||
| for (const auto &subgraph : graph->GetAllSubgraphs()) { | |||
| GE_CHECK_NOTNULL(subgraph); | |||
| auto parent_node = subgraph->GetParentNode(); | |||
| GE_CHECK_NOTNULL(parent_node); | |||
| if (parent_node->GetType() == CASE || parent_node->GetType() == IF) { | |||
| auto &peer_out_anchors_to_parent_indexes = subgraphs_to_need_fuse_nodes_info[subgraph]; | |||
| for (const auto &in_data_anchor : parent_node->GetAllInDataAnchors()) { | |||
| GE_CHECK_NOTNULL(in_data_anchor); | |||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| uint32_t parent_index = static_cast<uint32_t>(in_data_anchor->GetIdx()); | |||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||
| peer_out_anchors_to_parent_indexes[peer_out_anchor].insert(parent_index); | |||
| GELOGD("Peer node %s is the %d input of parent node %s in %s.", | |||
| peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_index, parent_node->GetName().c_str(), | |||
| subgraph->GetName().c_str()); | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status FuseDataNodesWithCommonInputPass::FuseDataNodes( | |||
| const map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info) { | |||
| for (const auto &subgraph_to_need_fuse_nodes_info : subgraphs_to_need_fuse_nodes_info) { | |||
| auto subgraph = subgraph_to_need_fuse_nodes_info.first; | |||
| for (const auto &peer_out_anchors_to_parent_indexes : subgraph_to_need_fuse_nodes_info.second) { | |||
| if (peer_out_anchors_to_parent_indexes.second.size() <= 1) { | |||
| continue; | |||
| } | |||
| // key: out anchor, value: data nodes with common input will be fused | |||
| map<OutDataAnchorPtr, vector<NodePtr>> peer_out_anchors_to_need_fuse_nodes; | |||
| for (const auto &node : subgraph->GetDirectNode()) { | |||
| if (node->GetType() != DATA) { | |||
| continue; | |||
| } | |||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
| uint32_t parent_index = 0; | |||
| if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||
| if (peer_out_anchors_to_parent_indexes.second.count(parent_index) > 0) { | |||
| peer_out_anchors_to_need_fuse_nodes[peer_out_anchors_to_parent_indexes.first].emplace_back(node); | |||
| } | |||
| } | |||
| } | |||
| for (const auto &peer_out_anchor_to_need_fuse_nodes : peer_out_anchors_to_need_fuse_nodes) { | |||
| auto need_fuse_data_nodes = peer_out_anchor_to_need_fuse_nodes.second; | |||
| auto first_node = need_fuse_data_nodes.at(0); | |||
| for (size_t i = 1; i < need_fuse_data_nodes.size(); ++i) { | |||
| auto node = need_fuse_data_nodes.at(i); | |||
| GELOGI("Replace redundant data node %s by %s exist in graph: %s.", node->GetName().c_str(), | |||
| first_node->GetName().c_str(), subgraph->GetName().c_str()); | |||
| // the data node which can be fused has none input(both data and control in) | |||
| if (GraphUtils::MoveOutCtrlEdges(node, first_node) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| if (GraphUtils::ReplaceNodeDataAnchors(first_node, node, {}, {0}) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| if (GraphUtils::RemoveNodeWithoutRelink(subgraph, node) != SUCCESS) { | |||
| GELOGE(FAILED, "[%s] RemoveNodeWithoutRelink failed.", node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ | |||
| #define GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ | |||
| #include <set> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "graph/types.h" | |||
| #include "inc/graph_pass.h" | |||
| namespace ge { | |||
| class FuseDataNodesWithCommonInputPass : public GraphPass { | |||
| public: | |||
| Status Run(ge::ComputeGraphPtr graph) override; | |||
| private: | |||
| Status InitNeedFuseNodesInfo(ComputeGraphPtr &graph, | |||
| map<ComputeGraphPtr, map<OutDataAnchorPtr, std::set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info); | |||
| Status FuseDataNodes( | |||
| const map<ComputeGraphPtr, map<OutDataAnchorPtr, std::set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/passes/no_data_out_const_elimination_pass.h" | |||
| namespace ge { | |||
| Status NoDataOutConstEliminationPass::Run(NodePtr &node) { | |||
| GE_CHECK_NOTNULL(node); | |||
| GELOGD("RemoveConstWithoutDataPass running of %s.", node->GetName().c_str()); | |||
| if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) { | |||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
| // delete const which has no input and no output of data | |||
| if (node->GetOpDesc()->GetInputsSize() == 0 && node->GetOutDataNodes().size() == 0) { | |||
| GELOGI("Remove const %s.", node->GetName().c_str()); | |||
| if (IsolateAndDeleteNode(node, {}) != SUCCESS) { | |||
| GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_ | |||
| #define GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_ | |||
| #include "graph/passes/base_pass.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "framework/common/util.h" | |||
| namespace ge { | |||
| class NoDataOutConstEliminationPass : public BaseNodePass { | |||
| public: | |||
| Status Run(ge::NodePtr &node) override; | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_ | |||
| @@ -178,6 +178,8 @@ set(COMMON_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/replace_transshape_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/constant_fuse_same_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/no_data_out_const_elimination_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/print_op_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/iterator_op_pass.cc" | |||
| @@ -616,6 +618,8 @@ set(PASS_TEST_FILES | |||
| "graph/passes/trans_op_depth_fusion_pass_unittest.cc" | |||
| "graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc" | |||
| "graph/passes/constant_folding_pass_unittest.cc" | |||
| "graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc" | |||
| "graph/passes/no_data_out_const_elimination_pass_unittest.cc" | |||
| "graph/passes/stop_gradient_pass_unittest.cc" | |||
| "graph/passes/prevent_gradient_pass_unittest.cc" | |||
| "graph/passes/identity_pass_unittest.cc" | |||
| @@ -0,0 +1,156 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||
| #include <gtest/gtest.h> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "inc/pass_manager.h" | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "graph_builder_utils.h" | |||
| #include "graph/utils/tensor_utils.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| #include "graph/utils/type_utils.h" | |||
| #include "graph/utils/node_utils.h" | |||
| namespace ge { | |||
| class UtestFuseDataNodesWithCommonInputPass : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| public: | |||
| NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||
| GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||
| auto op_desc = std::make_shared<OpDesc>(name, type); | |||
| for (auto i = 0; i < in_num; ++i) { | |||
| op_desc->AddInputDesc(test_desc); | |||
| } | |||
| for (auto i = 0; i < out_num; ++i) { | |||
| op_desc->AddOutputDesc(test_desc); | |||
| } | |||
| return graph->AddNode(op_desc); | |||
| } | |||
| }; | |||
| /// graph with subgraph | |||
| /// const | |||
| /// | | | | |||
| /// case | |||
| /// | | |||
| /// netoutput | |||
| /// ... | |||
| /// data0 data1 data2 | |||
| /// | \ / | |||
| /// conv add | |||
| TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { | |||
| PassManager pass_manager; | |||
| pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass); | |||
| ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph"); | |||
| auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const"); | |||
| auto parent_case = MakeNode(parent_graph, 3, 1, "parent_case", "Case"); | |||
| auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput"); | |||
| GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT); | |||
| parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||
| parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc); | |||
| parent_case->GetOpDesc()->UpdateInputDesc(2, tensor_desc); | |||
| parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(2)); | |||
| GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); | |||
| ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); | |||
| auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); | |||
| data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||
| data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data"); | |||
| data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||
| data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| auto data2 = MakeNode(parent_graph, 1, 1, "data2", "Data"); | |||
| data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||
| data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| (void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); | |||
| (void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); | |||
| (void)AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); | |||
| sub_graph->SetParentNode(parent_case); | |||
| sub_graph->SetParentGraph(parent_graph); | |||
| EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); | |||
| } | |||
| /// graph with subgraph | |||
| /// const | |||
| /// / \ | |||
| /// cast1 cast2 | |||
| /// \ / | |||
| /// case | |||
| /// | | |||
| /// netoutput | |||
| /// ... | |||
| /// data1 data2 | |||
| /// \ / | |||
| /// add | |||
| TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) { | |||
| PassManager pass_manager; | |||
| pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass); | |||
| ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph"); | |||
| auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const"); | |||
| auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast"); | |||
| auto parent_cast2 = MakeNode(parent_graph, 1, 1, "parent_cast2", "Cast"); | |||
| auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case"); | |||
| auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput"); | |||
| GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT); | |||
| parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||
| parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| parent_cast2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||
| parent_cast2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||
| parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc); | |||
| parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast2->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(parent_cast2->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); | |||
| ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); | |||
| auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); | |||
| data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||
| data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data"); | |||
| data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||
| data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| (void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); | |||
| (void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); | |||
| sub_graph->SetParentNode(parent_case); | |||
| sub_graph->SetParentGraph(parent_graph); | |||
| EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/passes/no_data_out_const_elimination_pass.h" | |||
| #include <gtest/gtest.h> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| namespace ge { | |||
| class UtestNoDataOutConstEliminationPass : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| public: | |||
| NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||
| GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||
| auto op_desc = std::make_shared<OpDesc>(name, type); | |||
| for (auto i = 0; i < in_num; ++i) { | |||
| op_desc->AddInputDesc(test_desc); | |||
| } | |||
| for (auto i = 0; i < out_num; ++i) { | |||
| op_desc->AddOutputDesc(test_desc); | |||
| } | |||
| return graph->AddNode(op_desc); | |||
| } | |||
| }; | |||
| /// graph with subgraph | |||
| /// const1 | |||
| /// |(control) | |||
| /// const2 | |||
| /// | | |||
| /// output | |||
| TEST_F(UtestNoDataOutConstEliminationPass, succ_graph1) { | |||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| auto const_node1 = MakeNode(graph, 0, 1, "const_node1", "Const"); | |||
| auto const_node2 = MakeNode(graph, 1, 1, "const_node2", "Const"); | |||
| auto output_node = MakeNode(graph, 1, 0, "output_node", "NetOutput"); | |||
| GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT); | |||
| const_node1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| const_node2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||
| const_node2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||
| output_node->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||
| GraphUtils::AddEdge(const_node1->GetOutControlAnchor(), const_node2->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(const_node2->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); | |||
| GEPass pass(graph); | |||
| NamesToPass node_pass; | |||
| NoDataOutConstEliminationPass no_data_out_const_elimination_pass; | |||
| node_pass.emplace_back("NoDataOutConstEliminationPass", &no_data_out_const_elimination_pass); | |||
| EXPECT_EQ(pass.Run(node_pass), SUCCESS); | |||
| } | |||
| } // namespace ge | |||