diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 215d2832..7deeef74 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -329,6 +329,7 @@ set(TRAIN_SRC_LIST "graph/passes/stop_gradient_pass.cc" "graph/passes/subgraph_pass.cc" "graph/passes/data_pass.cc" + "graph/passes/branch_logical_remove_pass.cc" "graph/passes/switch_data_edges_bypass.cc" "graph/passes/switch_logic_remove_pass.cc" "graph/passes/merge_to_stream_merge_pass.cc" @@ -626,6 +627,7 @@ set(INFER_SRC_LIST "graph/passes/useless_control_out_remove_pass.cc" "graph/passes/transop_symmetry_elimination_pass.cc" "graph/passes/save_pass.cc" + "graph/passes/branch_logical_remove_pass.cc" "graph/passes/switch_dead_branch_elimination.cc" "graph/passes/switch_logic_remove_pass.cc" "graph/passes/switch_data_edges_bypass.cc" diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index a56eaadf..c0db74a9 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -203,6 +203,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/common_subexpression_elimination_pass.cc \ graph/passes/transop_symmetry_elimination_pass.cc \ graph/passes/save_pass.cc \ + graph/passes/branch_logical_remove_pass.cc \ graph/passes/switch_dead_branch_elimination.cc \ graph/passes/switch_logic_remove_pass.cc \ graph/passes/switch_data_edges_bypass.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 8ca8572c..b0348a6c 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -218,6 +218,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/stop_gradient_pass.cc \ graph/passes/subgraph_pass.cc \ graph/passes/data_pass.cc \ + graph/passes/branch_logical_remove_pass.cc \ graph/passes/switch_data_edges_bypass.cc \ graph/passes/switch_logic_remove_pass.cc \ graph/passes/merge_to_stream_merge_pass.cc \ diff --git a/ge/graph/passes/branch_logical_remove_pass.cc b/ge/graph/passes/branch_logical_remove_pass.cc new file mode 100644 index 00000000..a735613b --- /dev/null +++ b/ge/graph/passes/branch_logical_remove_pass.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/branch_logical_remove_pass.h" +#include "graph/utils/node_utils.h" + +namespace ge { +Status BranchLogicalRemovePass::Run(ComputeGraphPtr graph) { + BranchExecCondCalculator calculator(graph); + if (calculator.Calculate() != GRAPH_SUCCESS) { + GELOGE(FAILED, "Calculate branch exec cond for removing logical dead branch failed."); + return FAILED; + } + const auto &node_exec_cond = calculator.GetBranchExecCond(); + if (node_exec_cond.empty()) { + GELOGI("No branch in graph %s, skip.", graph->GetName().c_str()); + return SUCCESS; + } + + for (const auto &node : graph->GetDirectNode()) { + const std::string &type = NodeUtils::GetNodeType(node); + if ((type == SWITCH) || (type == REFSWITCH)) { + if (RemoveRedundantSwitch(node_exec_cond, node) != SUCCESS) { + GELOGE(FAILED, "Remove redundant switch %s failed.", node->GetName().c_str()); + return FAILED; + } + } else if ((type == MERGE) || (type == REFMERGE)) { + if (RemoveDeadInputForMerge(node_exec_cond, node) != SUCCESS) { + GELOGE(FAILED, "Remove dead input for merge %s failed.", node->GetName().c_str()); + return FAILED; + } + } + } + + return SUCCESS; +} + +Status BranchLogicalRemovePass::RemoveRedundantSwitch(const std::map &node_exec_cond, + const NodePtr &switch_node) { + const auto &iter = node_exec_cond.find(switch_node); + if (iter == node_exec_cond.end()) { + GELOGE(FAILED, "Find for exec cond for node %s.", switch_node->GetName().c_str()); + return FAILED; + } + if (!iter->second.IsValid()) { + return SUCCESS; + } + for (const auto &pair : switch_node->GetOutDataNodesAndAnchors()) { + const auto &out_node = pair.first; + const auto &out_iter = node_exec_cond.find(out_node); + if (out_iter == node_exec_cond.end()) { + GELOGE(FAILED, "Find for exec cond for node %s.", out_node->GetName().c_str()); + return FAILED; + } + if (!out_iter->second.IsValid()) { + GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(), + pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); + if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(), + pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); + return FAILED; + } + } else if (iter->second.String() == out_iter->second.String()) { + GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(), + pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); + if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(), + pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); + return FAILED; + } + + const auto data_input_anchor = switch_node->GetInDataAnchor(SWITCH_DATA_INPUT); + GE_CHECK_NOTNULL(data_input_anchor); + const auto &peer_out_anchor = data_input_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + GELOGI("Add data edge %s:%d->%s:%d.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), + pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); + if (GraphUtils::AddEdge(peer_out_anchor, pair.second) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add data edge %s:%d->%s:%d failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), + pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); + return FAILED; + } + } + } + + return SUCCESS; +} + +Status BranchLogicalRemovePass::RemoveDeadInputForMerge(const std::map &node_exec_cond, + const NodePtr &merge_node) { + const auto &iter = node_exec_cond.find(merge_node); + if (iter == node_exec_cond.end()) { + GELOGE(FAILED, "Find for exec cond for node %s.", merge_node->GetName().c_str()); + return FAILED; + } + if (!iter->second.IsValid()) { + return SUCCESS; + } + for (const auto &in_data_anchor : merge_node->GetAllInDataAnchors()) { + const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { continue; } + const auto &in_node = peer_out_anchor->GetOwnerNode(); + const auto &in_iter = node_exec_cond.find(in_node); + if (in_iter == node_exec_cond.end()) { + GELOGE(FAILED, "Find for exec cond for node %s.", in_node->GetName().c_str()); + return FAILED; + } + if (!in_iter->second.IsValid()) { + GELOGI("Remove data edge %s:%d->%s:%d.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(), + merge_node->GetName().c_str(), in_data_anchor->GetIdx()); + if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(), + merge_node->GetName().c_str(), in_data_anchor->GetIdx()); + return FAILED; + } + } + } + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/branch_logical_remove_pass.h b/ge/graph/passes/branch_logical_remove_pass.h new file mode 100644 index 00000000..497eef13 --- /dev/null +++ b/ge/graph/passes/branch_logical_remove_pass.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_ +#define GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_ + +#include "inc/graph_pass.h" + +namespace ge { + +class BranchLogicalRemovePass : public GraphPass { + public: + Status Run(ge::ComputeGraphPtr graph) override; + + private: + static Status RemoveRedundantSwitch(const std::map &node_exec_cond, + const NodePtr &switch_node); + + static Status RemoveDeadInputForMerge(const std::map &node_exec_cond, + const NodePtr &merge_node); +}; + +} // namespace ge +#endif // GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_ diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 0c4adeea..da8e58be 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -75,6 +75,7 @@ #include "graph/passes/var_is_initialized_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/mark_force_unknown_for_cond_pass.h" +#include "graph/passes/branch_logical_remove_pass.h" #include "graph/preprocess/insert_op/util_insert_aipp_op.h" #include "graph/utils/type_utils.h" #include "inc/pass_manager.h" @@ -1725,6 +1726,7 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std:: PP_RUN_AND_DUMP("CheckAndUpdateInput", CheckAndUpdateInput, user_input, graph_node->GetOptions()); PP_RUN_AND_DUMP("GraphEquivalentTransformation", GraphEquivalentTransformation); PP_RUN_AND_DUMP("ProcessOutput", ProcessNetOutput); + PP_RUN_AND_DUMP("RemoveLogicalDeadBranch", RemoveLogicalDeadBranch); PP_RUN_AND_DUMP("ProcessMultiBatch", multibatch::ProcessMultiBatch, compute_graph_); PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); @@ -2181,6 +2183,37 @@ Status GraphPrepare::GraphEquivalentTransformation() { return GEPass(compute_graph_).Run(names_to_pass); } +Status GraphPrepare::RemoveLogicalDeadBranch() { + GE_TIMESTAMP_START(BranchLogicalRemovePass); + PassManager logical_dead_branch_remove_pass; + GE_CHK_STATUS_RET(logical_dead_branch_remove_pass.AddPass("RemoveLogicalDeadBranch::BranchLogicalRemovePass", + new (std::nothrow) BranchLogicalRemovePass)) + auto ret = logical_dead_branch_remove_pass.Run(compute_graph_); + if ((ret != SUCCESS) && (ret != NOT_CHANGED)) { + GELOGE(ret, "Run logical_dead_branch_remove_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret); + return ret; + } + + NamesToPass names_to_passes; + MergePass merge_pass; + names_to_passes.emplace_back("MergePass", &merge_pass); + ret = GEPass(compute_graph_).Run(names_to_passes); + if ((ret != SUCCESS) && (ret != NOT_CHANGED)) { + GELOGE(ret, "Run merge_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret); + return ret; + } + + PassManager prune_pass; + GE_CHK_STATUS_RET(prune_pass.AddPass("RemoveLogicalDeadBranch::PrunePass", new (std::nothrow) PrunePass)) + if ((ret != SUCCESS) && (ret != NOT_CHANGED)) { + GELOGE(ret, "Run prune_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret); + return ret; + } + GE_TIMESTAMP_END(BranchLogicalRemovePass, "GraphPrepare::RemoveLogicalDeadBranch"); + + return SUCCESS; +} + Status GraphPrepare::ProcessBeforeInfershape() { NamesToPass names_to_passes; CondRemovePass condition_remove_pass; diff --git a/ge/graph/preprocess/graph_preprocess.h b/ge/graph/preprocess/graph_preprocess.h index 584f4d16..87075e18 100755 --- a/ge/graph/preprocess/graph_preprocess.h +++ b/ge/graph/preprocess/graph_preprocess.h @@ -88,6 +88,8 @@ class GraphPrepare { void TypeConversionOfConstant(); bool IsDynamicDims(const NodePtr &input_node); + Status RemoveLogicalDeadBranch(); + ge::ComputeGraphPtr compute_graph_; GraphManagerOptions options_; uint64_t session_id_ = 0; diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 63579109..1654654e 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -258,6 +258,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/common_subexpression_elimination_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_symmetry_elimination_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/branch_logical_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_dead_branch_elimination.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" @@ -508,6 +509,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/addn_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/branch_logical_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" @@ -671,6 +673,7 @@ set(PASS_TEST_FILES "graph/passes/addn_pass_unittest.cc" "graph/passes/save_pass_unittest.cc" "graph/passes/merge_pass_unittest.cc" + "graph/passes/branch_logical_remove_pass_unittest.cc" "graph/passes/switch_logic_remove_pass_unittest.cc" "graph/passes/cond_branch_v1_unittest.cc" "graph/passes/loop_branch_v1_unittest.cc"