From 2874ec935f700ce196c28a5b70c7d8070f7c9ff0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=8D=8E?= Date: Fri, 25 Jun 2021 14:08:54 +0800 Subject: [PATCH] fix parallel group pass --- ge/graph/passes/parallel_group_pass.cc | 58 +++++++++----- ge/graph/passes/parallel_group_pass.h | 1 + .../passes/parallel_group_pass_unittest.cc | 80 ++++++++++++++++++- 3 files changed, 119 insertions(+), 20 deletions(-) diff --git a/ge/graph/passes/parallel_group_pass.cc b/ge/graph/passes/parallel_group_pass.cc index 9c93f6cf..795002f1 100644 --- a/ge/graph/passes/parallel_group_pass.cc +++ b/ge/graph/passes/parallel_group_pass.cc @@ -15,7 +15,7 @@ */ #include "graph/passes/parallel_group_pass.h" - +#include #include "framework/common/debug/ge_log.h" #include "common/ge/ge_util.h" #include "framework/common/ge_inner_error_codes.h" @@ -299,24 +299,19 @@ Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cu for (const auto &switch_node : cur_itr->second.first) { int64_t pre_id = pre_node->GetOpDesc()->GetId(); int64_t switch_id = switch_node->GetOpDesc()->GetId(); - // avoid ring - if (pre_id > switch_id) { - auto merge_node = cur_itr->second.second; - if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) { - GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", - pre_node->GetName().c_str(), switch_node->GetName().c_str()); - REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", - pre_node->GetName().c_str(), switch_node->GetName().c_str()); - return FAILED; - } - } else { - if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) { - GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", - pre_node->GetName().c_str(), switch_node->GetName().c_str()); - REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", - pre_node->GetName().c_str(), switch_node->GetName().c_str()); - return FAILED; - } + NodePtr first_node = pre_node; + NodePtr second_node = switch_node; + if (pre_id > switch_id && IsIndirectConnect(switch_node, pre_node)) { + // avoid ring, merge->pre_node + first_node = cur_itr->second.second; + second_node = pre_node; + } + if (AddCtrlEdge(first_node, second_node) != SUCCESS) { + GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", + first_node->GetName().c_str(), second_node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", + first_node->GetName().c_str(), second_node->GetName().c_str()); + return FAILED; } } } else { @@ -345,4 +340,29 @@ bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) { return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) && stream_switch_type == kLoopType); } + +bool ParallelGroupPass::IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b) { + if (node_a == nullptr || node_b == nullptr) { + GELOGW("node_a or node_b is nullptr."); + return false; + } + int64_t end_id = node_b->GetOpDesc()->GetId(); + std::queue nodes; + nodes.push(node_a); + while (!nodes.empty()) { + NodePtr tmp_node = nodes.front(); + nodes.pop(); + if (tmp_node == nullptr || tmp_node->GetOpDesc() == nullptr || + tmp_node->GetOpDesc()->GetId() > end_id) { + continue; + } + if (tmp_node == node_b) { + return true; + } + for (const auto &out_node : tmp_node->GetOutAllNodes()) { + nodes.push(out_node); + } + } + return false; +} } // namespace ge diff --git a/ge/graph/passes/parallel_group_pass.h b/ge/graph/passes/parallel_group_pass.h index cdcdabab..93b0b158 100644 --- a/ge/graph/passes/parallel_group_pass.h +++ b/ge/graph/passes/parallel_group_pass.h @@ -48,6 +48,7 @@ class ParallelGroupPass : public GraphPass { bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc); bool IsWhileStreamSwitch(OpDescPtr switch_op_desc); + bool IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b); }; } // namespace ge #endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H diff --git a/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc index 374fe837..a6c3ff6a 100644 --- a/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc @@ -19,7 +19,8 @@ #include #define private public - +#include "inc/graph/ge_local_context.h" +#include "inc/external/ge/ge_api_types.h" #include "common/ge_inner_error_codes.h" #include "inc/pass_manager.h" #include "utils/graph_utils.h" @@ -225,6 +226,70 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { output_true_node_->GetOpDesc()->SetIsInputConst({false}); } + void BuildDefaultGraph3() { + /// input + /// \ + /// sqrt pred + /// \ / + /// Switch + /// | | + /// F T ------ + /// / \_/_ \ + /// / / \ \ + /// Merge sqrt2 sqrt3 + /// / \ \ + /// sqrt1 \ relu + /// \ \ + /// \ sqrt4 + /// \ / + /// Merge1 + input_node_ = NewNode("input", RELU, 0, 1); + AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + pred_node_ = NewNode("pred", GREATER, 2, 1); + sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); + cast_node_ = NewNode("cast", CAST, 2, 2); + + switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); + switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); + output_false_node_ = NewNode("false_output", RELU, 1, 2); + AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + output_true_node_ = NewNode("true_output", RELU, 1, 2); + AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); + sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1); + AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + sqrt_node2_ = NewNode("sqrt2", SQRT, 1, 1); + AttrUtils::SetStr(sqrt_node2_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + sqrt_node3_ = NewNode("sqrt3", SQRT, 1, 1); + relu_node_ = NewNode("relu", RELU, 1, 1); + sqrt_node4_ = NewNode("sqrt4", SQRT, 1, 1); + AttrUtils::SetStr(sqrt_node4_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1); + + GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); + GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); + + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), sqrt_node2_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), sqrt_node3_->GetInDataAnchor(0)); + + GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node3_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node4_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node2_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node4_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(1)); + output_false_node_->GetOpDesc()->SetIsInputConst({false}); + output_true_node_->GetOpDesc()->SetIsInputConst({false}); + } + ComputeGraphPtr graph_; ComputeGraphPtr sub_graph_; GeTensorDescPtr default_tensor_desc_; @@ -235,6 +300,9 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { NodePtr cast_node1_; NodePtr sqrt_node_; NodePtr sqrt_node1_; + NodePtr sqrt_node2_; + NodePtr sqrt_node3_; + NodePtr sqrt_node4_; NodePtr input_node_; NodePtr input_node1_; NodePtr switch_node_t; @@ -278,6 +346,16 @@ TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) { EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor())); } +TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph3) { + std::map options; + options.emplace(OPTION_GRAPH_RUN_MODE, "1"); + GetThreadLocalContext().SetGraphOption(options); + BuildDefaultGraph3(); + auto ret = pass_.Run(graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); + EXPECT_EQ(true, merge_node1_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor())); +} + TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) { BuildDefaultGraph1(); NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true);