| @@ -329,6 +329,7 @@ set(TRAIN_SRC_LIST | |||
| "graph/passes/stop_gradient_pass.cc" | |||
| "graph/passes/subgraph_pass.cc" | |||
| "graph/passes/data_pass.cc" | |||
| "graph/passes/branch_logical_remove_pass.cc" | |||
| "graph/passes/switch_data_edges_bypass.cc" | |||
| "graph/passes/switch_logic_remove_pass.cc" | |||
| "graph/passes/merge_to_stream_merge_pass.cc" | |||
| @@ -626,6 +627,7 @@ set(INFER_SRC_LIST | |||
| "graph/passes/useless_control_out_remove_pass.cc" | |||
| "graph/passes/transop_symmetry_elimination_pass.cc" | |||
| "graph/passes/save_pass.cc" | |||
| "graph/passes/branch_logical_remove_pass.cc" | |||
| "graph/passes/switch_dead_branch_elimination.cc" | |||
| "graph/passes/switch_logic_remove_pass.cc" | |||
| "graph/passes/switch_data_edges_bypass.cc" | |||
| @@ -203,6 +203,7 @@ OMG_HOST_SRC_FILES := \ | |||
| graph/passes/common_subexpression_elimination_pass.cc \ | |||
| graph/passes/transop_symmetry_elimination_pass.cc \ | |||
| graph/passes/save_pass.cc \ | |||
| graph/passes/branch_logical_remove_pass.cc \ | |||
| graph/passes/switch_dead_branch_elimination.cc \ | |||
| graph/passes/switch_logic_remove_pass.cc \ | |||
| graph/passes/switch_data_edges_bypass.cc \ | |||
| @@ -218,6 +218,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| graph/passes/stop_gradient_pass.cc \ | |||
| graph/passes/subgraph_pass.cc \ | |||
| graph/passes/data_pass.cc \ | |||
| graph/passes/branch_logical_remove_pass.cc \ | |||
| graph/passes/switch_data_edges_bypass.cc \ | |||
| graph/passes/switch_logic_remove_pass.cc \ | |||
| graph/passes/merge_to_stream_merge_pass.cc \ | |||
| @@ -0,0 +1,133 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/passes/branch_logical_remove_pass.h" | |||
| #include "graph/utils/node_utils.h" | |||
| namespace ge { | |||
| Status BranchLogicalRemovePass::Run(ComputeGraphPtr graph) { | |||
| BranchExecCondCalculator calculator(graph); | |||
| if (calculator.Calculate() != GRAPH_SUCCESS) { | |||
| GELOGE(FAILED, "Calculate branch exec cond for removing logical dead branch failed."); | |||
| return FAILED; | |||
| } | |||
| const auto &node_exec_cond = calculator.GetBranchExecCond(); | |||
| if (node_exec_cond.empty()) { | |||
| GELOGI("No branch in graph %s, skip.", graph->GetName().c_str()); | |||
| return SUCCESS; | |||
| } | |||
| for (const auto &node : graph->GetDirectNode()) { | |||
| const std::string &type = NodeUtils::GetNodeType(node); | |||
| if ((type == SWITCH) || (type == REFSWITCH)) { | |||
| if (RemoveRedundantSwitch(node_exec_cond, node) != SUCCESS) { | |||
| GELOGE(FAILED, "Remove redundant switch %s failed.", node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| } else if ((type == MERGE) || (type == REFMERGE)) { | |||
| if (RemoveDeadInputForMerge(node_exec_cond, node) != SUCCESS) { | |||
| GELOGE(FAILED, "Remove dead input for merge %s failed.", node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status BranchLogicalRemovePass::RemoveRedundantSwitch(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond, | |||
| const NodePtr &switch_node) { | |||
| const auto &iter = node_exec_cond.find(switch_node); | |||
| if (iter == node_exec_cond.end()) { | |||
| GELOGE(FAILED, "Find for exec cond for node %s.", switch_node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| if (!iter->second.IsValid()) { | |||
| return SUCCESS; | |||
| } | |||
| for (const auto &pair : switch_node->GetOutDataNodesAndAnchors()) { | |||
| const auto &out_node = pair.first; | |||
| const auto &out_iter = node_exec_cond.find(out_node); | |||
| if (out_iter == node_exec_cond.end()) { | |||
| GELOGE(FAILED, "Find for exec cond for node %s.", out_node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| if (!out_iter->second.IsValid()) { | |||
| GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(), | |||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||
| if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) { | |||
| GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(), | |||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||
| return FAILED; | |||
| } | |||
| } else if (iter->second.String() == out_iter->second.String()) { | |||
| GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(), | |||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||
| if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) { | |||
| GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(), | |||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||
| return FAILED; | |||
| } | |||
| const auto data_input_anchor = switch_node->GetInDataAnchor(SWITCH_DATA_INPUT); | |||
| GE_CHECK_NOTNULL(data_input_anchor); | |||
| const auto &peer_out_anchor = data_input_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||
| GELOGI("Add data edge %s:%d->%s:%d.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||
| if (GraphUtils::AddEdge(peer_out_anchor, pair.second) != GRAPH_SUCCESS) { | |||
| GELOGE(FAILED, "Add data edge %s:%d->%s:%d failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status BranchLogicalRemovePass::RemoveDeadInputForMerge(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond, | |||
| const NodePtr &merge_node) { | |||
| const auto &iter = node_exec_cond.find(merge_node); | |||
| if (iter == node_exec_cond.end()) { | |||
| GELOGE(FAILED, "Find for exec cond for node %s.", merge_node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| if (!iter->second.IsValid()) { | |||
| return SUCCESS; | |||
| } | |||
| for (const auto &in_data_anchor : merge_node->GetAllInDataAnchors()) { | |||
| const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (peer_out_anchor == nullptr) { continue; } | |||
| const auto &in_node = peer_out_anchor->GetOwnerNode(); | |||
| const auto &in_iter = node_exec_cond.find(in_node); | |||
| if (in_iter == node_exec_cond.end()) { | |||
| GELOGE(FAILED, "Find for exec cond for node %s.", in_node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| if (!in_iter->second.IsValid()) { | |||
| GELOGI("Remove data edge %s:%d->%s:%d.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(), | |||
| merge_node->GetName().c_str(), in_data_anchor->GetIdx()); | |||
| if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) { | |||
| GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(), | |||
| merge_node->GetName().c_str(), in_data_anchor->GetIdx()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_ | |||
| #define GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_ | |||
| #include "inc/graph_pass.h" | |||
| namespace ge { | |||
| class BranchLogicalRemovePass : public GraphPass { | |||
| public: | |||
| Status Run(ge::ComputeGraphPtr graph) override; | |||
| private: | |||
| static Status RemoveRedundantSwitch(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond, | |||
| const NodePtr &switch_node); | |||
| static Status RemoveDeadInputForMerge(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond, | |||
| const NodePtr &merge_node); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_ | |||
| @@ -75,6 +75,7 @@ | |||
| #include "graph/passes/var_is_initialized_op_pass.h" | |||
| #include "graph/passes/variable_prepare_op_pass.h" | |||
| #include "graph/passes/mark_force_unknown_for_cond_pass.h" | |||
| #include "graph/passes/branch_logical_remove_pass.h" | |||
| #include "graph/preprocess/insert_op/util_insert_aipp_op.h" | |||
| #include "graph/utils/type_utils.h" | |||
| #include "inc/pass_manager.h" | |||
| @@ -1725,6 +1726,7 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std:: | |||
| PP_RUN_AND_DUMP("CheckAndUpdateInput", CheckAndUpdateInput, user_input, graph_node->GetOptions()); | |||
| PP_RUN_AND_DUMP("GraphEquivalentTransformation", GraphEquivalentTransformation); | |||
| PP_RUN_AND_DUMP("ProcessOutput", ProcessNetOutput); | |||
| PP_RUN_AND_DUMP("RemoveLogicalDeadBranch", RemoveLogicalDeadBranch); | |||
| PP_RUN_AND_DUMP("ProcessMultiBatch", multibatch::ProcessMultiBatch, compute_graph_); | |||
| PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); | |||
| PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); | |||
| @@ -2181,6 +2183,37 @@ Status GraphPrepare::GraphEquivalentTransformation() { | |||
| return GEPass(compute_graph_).Run(names_to_pass); | |||
| } | |||
| Status GraphPrepare::RemoveLogicalDeadBranch() { | |||
| GE_TIMESTAMP_START(BranchLogicalRemovePass); | |||
| PassManager logical_dead_branch_remove_pass; | |||
| GE_CHK_STATUS_RET(logical_dead_branch_remove_pass.AddPass("RemoveLogicalDeadBranch::BranchLogicalRemovePass", | |||
| new (std::nothrow) BranchLogicalRemovePass)) | |||
| auto ret = logical_dead_branch_remove_pass.Run(compute_graph_); | |||
| if ((ret != SUCCESS) && (ret != NOT_CHANGED)) { | |||
| GELOGE(ret, "Run logical_dead_branch_remove_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret); | |||
| return ret; | |||
| } | |||
| NamesToPass names_to_passes; | |||
| MergePass merge_pass; | |||
| names_to_passes.emplace_back("MergePass", &merge_pass); | |||
| ret = GEPass(compute_graph_).Run(names_to_passes); | |||
| if ((ret != SUCCESS) && (ret != NOT_CHANGED)) { | |||
| GELOGE(ret, "Run merge_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret); | |||
| return ret; | |||
| } | |||
| PassManager prune_pass; | |||
| GE_CHK_STATUS_RET(prune_pass.AddPass("RemoveLogicalDeadBranch::PrunePass", new (std::nothrow) PrunePass)) | |||
| if ((ret != SUCCESS) && (ret != NOT_CHANGED)) { | |||
| GELOGE(ret, "Run prune_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret); | |||
| return ret; | |||
| } | |||
| GE_TIMESTAMP_END(BranchLogicalRemovePass, "GraphPrepare::RemoveLogicalDeadBranch"); | |||
| return SUCCESS; | |||
| } | |||
| Status GraphPrepare::ProcessBeforeInfershape() { | |||
| NamesToPass names_to_passes; | |||
| CondRemovePass condition_remove_pass; | |||
| @@ -88,6 +88,8 @@ class GraphPrepare { | |||
| void TypeConversionOfConstant(); | |||
| bool IsDynamicDims(const NodePtr &input_node); | |||
| Status RemoveLogicalDeadBranch(); | |||
| ge::ComputeGraphPtr compute_graph_; | |||
| GraphManagerOptions options_; | |||
| uint64_t session_id_ = 0; | |||
| @@ -258,6 +258,7 @@ set(COMMON_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/graph/passes/common_subexpression_elimination_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/transop_symmetry_elimination_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/branch_logical_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/switch_dead_branch_elimination.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" | |||
| @@ -508,6 +509,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/graph/passes/addn_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/branch_logical_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | |||
| @@ -671,6 +673,7 @@ set(PASS_TEST_FILES | |||
| "graph/passes/addn_pass_unittest.cc" | |||
| "graph/passes/save_pass_unittest.cc" | |||
| "graph/passes/merge_pass_unittest.cc" | |||
| "graph/passes/branch_logical_remove_pass_unittest.cc" | |||
| "graph/passes/switch_logic_remove_pass_unittest.cc" | |||
| "graph/passes/cond_branch_v1_unittest.cc" | |||
| "graph/passes/loop_branch_v1_unittest.cc" | |||