| @@ -329,6 +329,7 @@ set(TRAIN_SRC_LIST | |||||
| "graph/passes/stop_gradient_pass.cc" | "graph/passes/stop_gradient_pass.cc" | ||||
| "graph/passes/subgraph_pass.cc" | "graph/passes/subgraph_pass.cc" | ||||
| "graph/passes/data_pass.cc" | "graph/passes/data_pass.cc" | ||||
| "graph/passes/branch_logical_remove_pass.cc" | |||||
| "graph/passes/switch_data_edges_bypass.cc" | "graph/passes/switch_data_edges_bypass.cc" | ||||
| "graph/passes/switch_logic_remove_pass.cc" | "graph/passes/switch_logic_remove_pass.cc" | ||||
| "graph/passes/merge_to_stream_merge_pass.cc" | "graph/passes/merge_to_stream_merge_pass.cc" | ||||
| @@ -626,6 +627,7 @@ set(INFER_SRC_LIST | |||||
| "graph/passes/useless_control_out_remove_pass.cc" | "graph/passes/useless_control_out_remove_pass.cc" | ||||
| "graph/passes/transop_symmetry_elimination_pass.cc" | "graph/passes/transop_symmetry_elimination_pass.cc" | ||||
| "graph/passes/save_pass.cc" | "graph/passes/save_pass.cc" | ||||
| "graph/passes/branch_logical_remove_pass.cc" | |||||
| "graph/passes/switch_dead_branch_elimination.cc" | "graph/passes/switch_dead_branch_elimination.cc" | ||||
| "graph/passes/switch_logic_remove_pass.cc" | "graph/passes/switch_logic_remove_pass.cc" | ||||
| "graph/passes/switch_data_edges_bypass.cc" | "graph/passes/switch_data_edges_bypass.cc" | ||||
| @@ -203,6 +203,7 @@ OMG_HOST_SRC_FILES := \ | |||||
| graph/passes/common_subexpression_elimination_pass.cc \ | graph/passes/common_subexpression_elimination_pass.cc \ | ||||
| graph/passes/transop_symmetry_elimination_pass.cc \ | graph/passes/transop_symmetry_elimination_pass.cc \ | ||||
| graph/passes/save_pass.cc \ | graph/passes/save_pass.cc \ | ||||
| graph/passes/branch_logical_remove_pass.cc \ | |||||
| graph/passes/switch_dead_branch_elimination.cc \ | graph/passes/switch_dead_branch_elimination.cc \ | ||||
| graph/passes/switch_logic_remove_pass.cc \ | graph/passes/switch_logic_remove_pass.cc \ | ||||
| graph/passes/switch_data_edges_bypass.cc \ | graph/passes/switch_data_edges_bypass.cc \ | ||||
| @@ -218,6 +218,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| graph/passes/stop_gradient_pass.cc \ | graph/passes/stop_gradient_pass.cc \ | ||||
| graph/passes/subgraph_pass.cc \ | graph/passes/subgraph_pass.cc \ | ||||
| graph/passes/data_pass.cc \ | graph/passes/data_pass.cc \ | ||||
| graph/passes/branch_logical_remove_pass.cc \ | |||||
| graph/passes/switch_data_edges_bypass.cc \ | graph/passes/switch_data_edges_bypass.cc \ | ||||
| graph/passes/switch_logic_remove_pass.cc \ | graph/passes/switch_logic_remove_pass.cc \ | ||||
| graph/passes/merge_to_stream_merge_pass.cc \ | graph/passes/merge_to_stream_merge_pass.cc \ | ||||
| @@ -0,0 +1,133 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/passes/branch_logical_remove_pass.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| namespace ge { | |||||
| Status BranchLogicalRemovePass::Run(ComputeGraphPtr graph) { | |||||
| BranchExecCondCalculator calculator(graph); | |||||
| if (calculator.Calculate() != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Calculate branch exec cond for removing logical dead branch failed."); | |||||
| return FAILED; | |||||
| } | |||||
| const auto &node_exec_cond = calculator.GetBranchExecCond(); | |||||
| if (node_exec_cond.empty()) { | |||||
| GELOGI("No branch in graph %s, skip.", graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| for (const auto &node : graph->GetDirectNode()) { | |||||
| const std::string &type = NodeUtils::GetNodeType(node); | |||||
| if ((type == SWITCH) || (type == REFSWITCH)) { | |||||
| if (RemoveRedundantSwitch(node_exec_cond, node) != SUCCESS) { | |||||
| GELOGE(FAILED, "Remove redundant switch %s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } else if ((type == MERGE) || (type == REFMERGE)) { | |||||
| if (RemoveDeadInputForMerge(node_exec_cond, node) != SUCCESS) { | |||||
| GELOGE(FAILED, "Remove dead input for merge %s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status BranchLogicalRemovePass::RemoveRedundantSwitch(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond, | |||||
| const NodePtr &switch_node) { | |||||
| const auto &iter = node_exec_cond.find(switch_node); | |||||
| if (iter == node_exec_cond.end()) { | |||||
| GELOGE(FAILED, "Find for exec cond for node %s.", switch_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (!iter->second.IsValid()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| for (const auto &pair : switch_node->GetOutDataNodesAndAnchors()) { | |||||
| const auto &out_node = pair.first; | |||||
| const auto &out_iter = node_exec_cond.find(out_node); | |||||
| if (out_iter == node_exec_cond.end()) { | |||||
| GELOGE(FAILED, "Find for exec cond for node %s.", out_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (!out_iter->second.IsValid()) { | |||||
| GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(), | |||||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||||
| if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(), | |||||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||||
| return FAILED; | |||||
| } | |||||
| } else if (iter->second.String() == out_iter->second.String()) { | |||||
| GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(), | |||||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||||
| if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(), | |||||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||||
| return FAILED; | |||||
| } | |||||
| const auto data_input_anchor = switch_node->GetInDataAnchor(SWITCH_DATA_INPUT); | |||||
| GE_CHECK_NOTNULL(data_input_anchor); | |||||
| const auto &peer_out_anchor = data_input_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||||
| GELOGI("Add data edge %s:%d->%s:%d.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||||
| if (GraphUtils::AddEdge(peer_out_anchor, pair.second) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Add data edge %s:%d->%s:%d failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status BranchLogicalRemovePass::RemoveDeadInputForMerge(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond, | |||||
| const NodePtr &merge_node) { | |||||
| const auto &iter = node_exec_cond.find(merge_node); | |||||
| if (iter == node_exec_cond.end()) { | |||||
| GELOGE(FAILED, "Find for exec cond for node %s.", merge_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (!iter->second.IsValid()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| for (const auto &in_data_anchor : merge_node->GetAllInDataAnchors()) { | |||||
| const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (peer_out_anchor == nullptr) { continue; } | |||||
| const auto &in_node = peer_out_anchor->GetOwnerNode(); | |||||
| const auto &in_iter = node_exec_cond.find(in_node); | |||||
| if (in_iter == node_exec_cond.end()) { | |||||
| GELOGE(FAILED, "Find for exec cond for node %s.", in_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (!in_iter->second.IsValid()) { | |||||
| GELOGI("Remove data edge %s:%d->%s:%d.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(), | |||||
| merge_node->GetName().c_str(), in_data_anchor->GetIdx()); | |||||
| if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(), | |||||
| merge_node->GetName().c_str(), in_data_anchor->GetIdx()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_ | |||||
| #define GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_ | |||||
| #include "inc/graph_pass.h" | |||||
| namespace ge { | |||||
| class BranchLogicalRemovePass : public GraphPass { | |||||
| public: | |||||
| Status Run(ge::ComputeGraphPtr graph) override; | |||||
| private: | |||||
| static Status RemoveRedundantSwitch(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond, | |||||
| const NodePtr &switch_node); | |||||
| static Status RemoveDeadInputForMerge(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond, | |||||
| const NodePtr &merge_node); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_ | |||||
| @@ -75,6 +75,7 @@ | |||||
| #include "graph/passes/var_is_initialized_op_pass.h" | #include "graph/passes/var_is_initialized_op_pass.h" | ||||
| #include "graph/passes/variable_prepare_op_pass.h" | #include "graph/passes/variable_prepare_op_pass.h" | ||||
| #include "graph/passes/mark_force_unknown_for_cond_pass.h" | #include "graph/passes/mark_force_unknown_for_cond_pass.h" | ||||
| #include "graph/passes/branch_logical_remove_pass.h" | |||||
| #include "graph/preprocess/insert_op/util_insert_aipp_op.h" | #include "graph/preprocess/insert_op/util_insert_aipp_op.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
| @@ -1725,6 +1726,7 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std:: | |||||
| PP_RUN_AND_DUMP("CheckAndUpdateInput", CheckAndUpdateInput, user_input, graph_node->GetOptions()); | PP_RUN_AND_DUMP("CheckAndUpdateInput", CheckAndUpdateInput, user_input, graph_node->GetOptions()); | ||||
| PP_RUN_AND_DUMP("GraphEquivalentTransformation", GraphEquivalentTransformation); | PP_RUN_AND_DUMP("GraphEquivalentTransformation", GraphEquivalentTransformation); | ||||
| PP_RUN_AND_DUMP("ProcessOutput", ProcessNetOutput); | PP_RUN_AND_DUMP("ProcessOutput", ProcessNetOutput); | ||||
| PP_RUN_AND_DUMP("RemoveLogicalDeadBranch", RemoveLogicalDeadBranch); | |||||
| PP_RUN_AND_DUMP("ProcessMultiBatch", multibatch::ProcessMultiBatch, compute_graph_); | PP_RUN_AND_DUMP("ProcessMultiBatch", multibatch::ProcessMultiBatch, compute_graph_); | ||||
| PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); | PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); | ||||
| PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); | PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); | ||||
| @@ -2181,6 +2183,37 @@ Status GraphPrepare::GraphEquivalentTransformation() { | |||||
| return GEPass(compute_graph_).Run(names_to_pass); | return GEPass(compute_graph_).Run(names_to_pass); | ||||
| } | } | ||||
| Status GraphPrepare::RemoveLogicalDeadBranch() { | |||||
| GE_TIMESTAMP_START(BranchLogicalRemovePass); | |||||
| PassManager logical_dead_branch_remove_pass; | |||||
| GE_CHK_STATUS_RET(logical_dead_branch_remove_pass.AddPass("RemoveLogicalDeadBranch::BranchLogicalRemovePass", | |||||
| new (std::nothrow) BranchLogicalRemovePass)) | |||||
| auto ret = logical_dead_branch_remove_pass.Run(compute_graph_); | |||||
| if ((ret != SUCCESS) && (ret != NOT_CHANGED)) { | |||||
| GELOGE(ret, "Run logical_dead_branch_remove_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret); | |||||
| return ret; | |||||
| } | |||||
| NamesToPass names_to_passes; | |||||
| MergePass merge_pass; | |||||
| names_to_passes.emplace_back("MergePass", &merge_pass); | |||||
| ret = GEPass(compute_graph_).Run(names_to_passes); | |||||
| if ((ret != SUCCESS) && (ret != NOT_CHANGED)) { | |||||
| GELOGE(ret, "Run merge_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret); | |||||
| return ret; | |||||
| } | |||||
| PassManager prune_pass; | |||||
| GE_CHK_STATUS_RET(prune_pass.AddPass("RemoveLogicalDeadBranch::PrunePass", new (std::nothrow) PrunePass)) | |||||
| if ((ret != SUCCESS) && (ret != NOT_CHANGED)) { | |||||
| GELOGE(ret, "Run prune_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret); | |||||
| return ret; | |||||
| } | |||||
| GE_TIMESTAMP_END(BranchLogicalRemovePass, "GraphPrepare::RemoveLogicalDeadBranch"); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphPrepare::ProcessBeforeInfershape() { | Status GraphPrepare::ProcessBeforeInfershape() { | ||||
| NamesToPass names_to_passes; | NamesToPass names_to_passes; | ||||
| CondRemovePass condition_remove_pass; | CondRemovePass condition_remove_pass; | ||||
| @@ -88,6 +88,8 @@ class GraphPrepare { | |||||
| void TypeConversionOfConstant(); | void TypeConversionOfConstant(); | ||||
| bool IsDynamicDims(const NodePtr &input_node); | bool IsDynamicDims(const NodePtr &input_node); | ||||
| Status RemoveLogicalDeadBranch(); | |||||
| ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
| GraphManagerOptions options_; | GraphManagerOptions options_; | ||||
| uint64_t session_id_ = 0; | uint64_t session_id_ = 0; | ||||
| @@ -258,6 +258,7 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/common_subexpression_elimination_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/common_subexpression_elimination_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/transop_symmetry_elimination_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transop_symmetry_elimination_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/branch_logical_remove_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/switch_dead_branch_elimination.cc" | "${GE_CODE_DIR}/ge/graph/passes/switch_dead_branch_elimination.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" | "${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" | ||||
| @@ -508,6 +509,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/addn_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/addn_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/branch_logical_remove_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | ||||
| @@ -671,6 +673,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/addn_pass_unittest.cc" | "graph/passes/addn_pass_unittest.cc" | ||||
| "graph/passes/save_pass_unittest.cc" | "graph/passes/save_pass_unittest.cc" | ||||
| "graph/passes/merge_pass_unittest.cc" | "graph/passes/merge_pass_unittest.cc" | ||||
| "graph/passes/branch_logical_remove_pass_unittest.cc" | |||||
| "graph/passes/switch_logic_remove_pass_unittest.cc" | "graph/passes/switch_logic_remove_pass_unittest.cc" | ||||
| "graph/passes/cond_branch_v1_unittest.cc" | "graph/passes/cond_branch_v1_unittest.cc" | ||||
| "graph/passes/loop_branch_v1_unittest.cc" | "graph/passes/loop_branch_v1_unittest.cc" | ||||