Browse Source

remove logical dead branch

pull/2074/head
chenyemeng 4 years ago
parent
commit
8f038ea4e6
8 changed files with 212 additions and 0 deletions
  1. +2
    -0
      ge/CMakeLists.txt
  2. +1
    -0
      ge/ge_inference.mk
  3. +1
    -0
      ge/ge_runner.mk
  4. +133
    -0
      ge/graph/passes/branch_logical_remove_pass.cc
  5. +37
    -0
      ge/graph/passes/branch_logical_remove_pass.h
  6. +33
    -0
      ge/graph/preprocess/graph_preprocess.cc
  7. +2
    -0
      ge/graph/preprocess/graph_preprocess.h
  8. +3
    -0
      tests/ut/ge/CMakeLists.txt

+ 2
- 0
ge/CMakeLists.txt View File

@@ -329,6 +329,7 @@ set(TRAIN_SRC_LIST
"graph/passes/stop_gradient_pass.cc" "graph/passes/stop_gradient_pass.cc"
"graph/passes/subgraph_pass.cc" "graph/passes/subgraph_pass.cc"
"graph/passes/data_pass.cc" "graph/passes/data_pass.cc"
"graph/passes/branch_logical_remove_pass.cc"
"graph/passes/switch_data_edges_bypass.cc" "graph/passes/switch_data_edges_bypass.cc"
"graph/passes/switch_logic_remove_pass.cc" "graph/passes/switch_logic_remove_pass.cc"
"graph/passes/merge_to_stream_merge_pass.cc" "graph/passes/merge_to_stream_merge_pass.cc"
@@ -626,6 +627,7 @@ set(INFER_SRC_LIST
"graph/passes/useless_control_out_remove_pass.cc" "graph/passes/useless_control_out_remove_pass.cc"
"graph/passes/transop_symmetry_elimination_pass.cc" "graph/passes/transop_symmetry_elimination_pass.cc"
"graph/passes/save_pass.cc" "graph/passes/save_pass.cc"
"graph/passes/branch_logical_remove_pass.cc"
"graph/passes/switch_dead_branch_elimination.cc" "graph/passes/switch_dead_branch_elimination.cc"
"graph/passes/switch_logic_remove_pass.cc" "graph/passes/switch_logic_remove_pass.cc"
"graph/passes/switch_data_edges_bypass.cc" "graph/passes/switch_data_edges_bypass.cc"


+ 1
- 0
ge/ge_inference.mk View File

@@ -203,6 +203,7 @@ OMG_HOST_SRC_FILES := \
graph/passes/common_subexpression_elimination_pass.cc \ graph/passes/common_subexpression_elimination_pass.cc \
graph/passes/transop_symmetry_elimination_pass.cc \ graph/passes/transop_symmetry_elimination_pass.cc \
graph/passes/save_pass.cc \ graph/passes/save_pass.cc \
graph/passes/branch_logical_remove_pass.cc \
graph/passes/switch_dead_branch_elimination.cc \ graph/passes/switch_dead_branch_elimination.cc \
graph/passes/switch_logic_remove_pass.cc \ graph/passes/switch_logic_remove_pass.cc \
graph/passes/switch_data_edges_bypass.cc \ graph/passes/switch_data_edges_bypass.cc \


+ 1
- 0
ge/ge_runner.mk View File

@@ -218,6 +218,7 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/stop_gradient_pass.cc \ graph/passes/stop_gradient_pass.cc \
graph/passes/subgraph_pass.cc \ graph/passes/subgraph_pass.cc \
graph/passes/data_pass.cc \ graph/passes/data_pass.cc \
graph/passes/branch_logical_remove_pass.cc \
graph/passes/switch_data_edges_bypass.cc \ graph/passes/switch_data_edges_bypass.cc \
graph/passes/switch_logic_remove_pass.cc \ graph/passes/switch_logic_remove_pass.cc \
graph/passes/merge_to_stream_merge_pass.cc \ graph/passes/merge_to_stream_merge_pass.cc \


+ 133
- 0
ge/graph/passes/branch_logical_remove_pass.cc View File

@@ -0,0 +1,133 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/branch_logical_remove_pass.h"
#include "graph/utils/node_utils.h"

namespace ge {
Status BranchLogicalRemovePass::Run(ComputeGraphPtr graph) {
BranchExecCondCalculator calculator(graph);
if (calculator.Calculate() != GRAPH_SUCCESS) {
GELOGE(FAILED, "Calculate branch exec cond for removing logical dead branch failed.");
return FAILED;
}
const auto &node_exec_cond = calculator.GetBranchExecCond();
if (node_exec_cond.empty()) {
GELOGI("No branch in graph %s, skip.", graph->GetName().c_str());
return SUCCESS;
}

for (const auto &node : graph->GetDirectNode()) {
const std::string &type = NodeUtils::GetNodeType(node);
if ((type == SWITCH) || (type == REFSWITCH)) {
if (RemoveRedundantSwitch(node_exec_cond, node) != SUCCESS) {
GELOGE(FAILED, "Remove redundant switch %s failed.", node->GetName().c_str());
return FAILED;
}
} else if ((type == MERGE) || (type == REFMERGE)) {
if (RemoveDeadInputForMerge(node_exec_cond, node) != SUCCESS) {
GELOGE(FAILED, "Remove dead input for merge %s failed.", node->GetName().c_str());
return FAILED;
}
}
}

return SUCCESS;
}

Status BranchLogicalRemovePass::RemoveRedundantSwitch(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond,
const NodePtr &switch_node) {
const auto &iter = node_exec_cond.find(switch_node);
if (iter == node_exec_cond.end()) {
GELOGE(FAILED, "Find for exec cond for node %s.", switch_node->GetName().c_str());
return FAILED;
}
if (!iter->second.IsValid()) {
return SUCCESS;
}
for (const auto &pair : switch_node->GetOutDataNodesAndAnchors()) {
const auto &out_node = pair.first;
const auto &out_iter = node_exec_cond.find(out_node);
if (out_iter == node_exec_cond.end()) {
GELOGE(FAILED, "Find for exec cond for node %s.", out_node->GetName().c_str());
return FAILED;
}
if (!out_iter->second.IsValid()) {
GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(),
pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(),
pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
return FAILED;
}
} else if (iter->second.String() == out_iter->second.String()) {
GELOGI("Remove data edge %s:%d->%s:%d.", switch_node->GetName().c_str(),
pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
if (GraphUtils::RemoveEdge(pair.second->GetPeerOutAnchor(), pair.second) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", switch_node->GetName().c_str(),
pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
return FAILED;
}

const auto data_input_anchor = switch_node->GetInDataAnchor(SWITCH_DATA_INPUT);
GE_CHECK_NOTNULL(data_input_anchor);
const auto &peer_out_anchor = data_input_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_anchor);
GELOGI("Add data edge %s:%d->%s:%d.", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
if (GraphUtils::AddEdge(peer_out_anchor, pair.second) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Add data edge %s:%d->%s:%d failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
pair.second->GetPeerOutAnchor()->GetIdx(), out_node->GetName().c_str(), pair.second->GetIdx());
return FAILED;
}
}
}

return SUCCESS;
}

Status BranchLogicalRemovePass::RemoveDeadInputForMerge(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond,
const NodePtr &merge_node) {
const auto &iter = node_exec_cond.find(merge_node);
if (iter == node_exec_cond.end()) {
GELOGE(FAILED, "Find for exec cond for node %s.", merge_node->GetName().c_str());
return FAILED;
}
if (!iter->second.IsValid()) {
return SUCCESS;
}
for (const auto &in_data_anchor : merge_node->GetAllInDataAnchors()) {
const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) { continue; }
const auto &in_node = peer_out_anchor->GetOwnerNode();
const auto &in_iter = node_exec_cond.find(in_node);
if (in_iter == node_exec_cond.end()) {
GELOGE(FAILED, "Find for exec cond for node %s.", in_node->GetName().c_str());
return FAILED;
}
if (!in_iter->second.IsValid()) {
GELOGI("Remove data edge %s:%d->%s:%d.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(),
merge_node->GetName().c_str(), in_data_anchor->GetIdx());
if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", in_node->GetName().c_str(), peer_out_anchor->GetIdx(),
merge_node->GetName().c_str(), in_data_anchor->GetIdx());
return FAILED;
}
}
}
return SUCCESS;
}
} // namespace ge

+ 37
- 0
ge/graph/passes/branch_logical_remove_pass.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_
#define GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_

#include "inc/graph_pass.h"

namespace ge {

class BranchLogicalRemovePass : public GraphPass {
public:
Status Run(ge::ComputeGraphPtr graph) override;

private:
static Status RemoveRedundantSwitch(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond,
const NodePtr &switch_node);

static Status RemoveDeadInputForMerge(const std::map<NodePtr, LogicOperatorItem> &node_exec_cond,
const NodePtr &merge_node);
};

} // namespace ge
#endif // GE_GRAPH_PASSES_BRANCH_LOGICAL_REMOVE_PASS_H_

+ 33
- 0
ge/graph/preprocess/graph_preprocess.cc View File

@@ -75,6 +75,7 @@
#include "graph/passes/var_is_initialized_op_pass.h" #include "graph/passes/var_is_initialized_op_pass.h"
#include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h"
#include "graph/passes/mark_force_unknown_for_cond_pass.h" #include "graph/passes/mark_force_unknown_for_cond_pass.h"
#include "graph/passes/branch_logical_remove_pass.h"
#include "graph/preprocess/insert_op/util_insert_aipp_op.h" #include "graph/preprocess/insert_op/util_insert_aipp_op.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
#include "inc/pass_manager.h" #include "inc/pass_manager.h"
@@ -1725,6 +1726,7 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std::
PP_RUN_AND_DUMP("CheckAndUpdateInput", CheckAndUpdateInput, user_input, graph_node->GetOptions()); PP_RUN_AND_DUMP("CheckAndUpdateInput", CheckAndUpdateInput, user_input, graph_node->GetOptions());
PP_RUN_AND_DUMP("GraphEquivalentTransformation", GraphEquivalentTransformation); PP_RUN_AND_DUMP("GraphEquivalentTransformation", GraphEquivalentTransformation);
PP_RUN_AND_DUMP("ProcessOutput", ProcessNetOutput); PP_RUN_AND_DUMP("ProcessOutput", ProcessNetOutput);
PP_RUN_AND_DUMP("RemoveLogicalDeadBranch", RemoveLogicalDeadBranch);
PP_RUN_AND_DUMP("ProcessMultiBatch", multibatch::ProcessMultiBatch, compute_graph_); PP_RUN_AND_DUMP("ProcessMultiBatch", multibatch::ProcessMultiBatch, compute_graph_);
PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); PP_RUN_AND_DUMP("InsertAipp", TryDoAipp);
PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape);
@@ -2181,6 +2183,37 @@ Status GraphPrepare::GraphEquivalentTransformation() {
return GEPass(compute_graph_).Run(names_to_pass); return GEPass(compute_graph_).Run(names_to_pass);
} }


Status GraphPrepare::RemoveLogicalDeadBranch() {
GE_TIMESTAMP_START(BranchLogicalRemovePass);
PassManager logical_dead_branch_remove_pass;
GE_CHK_STATUS_RET(logical_dead_branch_remove_pass.AddPass("RemoveLogicalDeadBranch::BranchLogicalRemovePass",
new (std::nothrow) BranchLogicalRemovePass))
auto ret = logical_dead_branch_remove_pass.Run(compute_graph_);
if ((ret != SUCCESS) && (ret != NOT_CHANGED)) {
GELOGE(ret, "Run logical_dead_branch_remove_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret);
return ret;
}

NamesToPass names_to_passes;
MergePass merge_pass;
names_to_passes.emplace_back("MergePass", &merge_pass);
ret = GEPass(compute_graph_).Run(names_to_passes);
if ((ret != SUCCESS) && (ret != NOT_CHANGED)) {
GELOGE(ret, "Run merge_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret);
return ret;
}

PassManager prune_pass;
GE_CHK_STATUS_RET(prune_pass.AddPass("RemoveLogicalDeadBranch::PrunePass", new (std::nothrow) PrunePass))
if ((ret != SUCCESS) && (ret != NOT_CHANGED)) {
GELOGE(ret, "Run prune_pass for RemoveLogicalDeadBranch failed, ret:%d.", ret);
return ret;
}
GE_TIMESTAMP_END(BranchLogicalRemovePass, "GraphPrepare::RemoveLogicalDeadBranch");

return SUCCESS;
}

Status GraphPrepare::ProcessBeforeInfershape() { Status GraphPrepare::ProcessBeforeInfershape() {
NamesToPass names_to_passes; NamesToPass names_to_passes;
CondRemovePass condition_remove_pass; CondRemovePass condition_remove_pass;


+ 2
- 0
ge/graph/preprocess/graph_preprocess.h View File

@@ -88,6 +88,8 @@ class GraphPrepare {
void TypeConversionOfConstant(); void TypeConversionOfConstant();
bool IsDynamicDims(const NodePtr &input_node); bool IsDynamicDims(const NodePtr &input_node);


Status RemoveLogicalDeadBranch();

ge::ComputeGraphPtr compute_graph_; ge::ComputeGraphPtr compute_graph_;
GraphManagerOptions options_; GraphManagerOptions options_;
uint64_t session_id_ = 0; uint64_t session_id_ = 0;


+ 3
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -258,6 +258,7 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/common_subexpression_elimination_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/common_subexpression_elimination_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_symmetry_elimination_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_symmetry_elimination_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/branch_logical_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/switch_dead_branch_elimination.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_dead_branch_elimination.cc"
"${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc"
@@ -508,6 +509,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/addn_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/addn_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/branch_logical_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc"
@@ -671,6 +673,7 @@ set(PASS_TEST_FILES
"graph/passes/addn_pass_unittest.cc" "graph/passes/addn_pass_unittest.cc"
"graph/passes/save_pass_unittest.cc" "graph/passes/save_pass_unittest.cc"
"graph/passes/merge_pass_unittest.cc" "graph/passes/merge_pass_unittest.cc"
"graph/passes/branch_logical_remove_pass_unittest.cc"
"graph/passes/switch_logic_remove_pass_unittest.cc" "graph/passes/switch_logic_remove_pass_unittest.cc"
"graph/passes/cond_branch_v1_unittest.cc" "graph/passes/cond_branch_v1_unittest.cc"
"graph/passes/loop_branch_v1_unittest.cc" "graph/passes/loop_branch_v1_unittest.cc"


Loading…
Cancel
Save