diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index 66a60ab9..af87dafa 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -565,13 +565,6 @@ NodePtr SwitchToStreamSwitchPass::CreateActiveNode(const ComputeGraphPtr &graph, op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); return nullptr, "Create StreamActive node failed."); - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS, - REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", - node->GetName().c_str(), node->GetType().c_str(), - active_node->GetName().c_str(), active_node->GetType().c_str()); - GELOGE(INTERNAL_ERROR, "add edge failed"); - return nullptr); - GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, REPORT_CALL_ERROR("E19999", "Set switch branch node label:%s to node:%s(%s) failed", node_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str()); diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index bb3c8dc8..b00b8ec8 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -2394,6 +2394,10 @@ Status HybridModelBuilder::CreateMergeEnterGroup(const NodePtr &node, NodeItem * // Enter --> StreamActive --> StreamMerge for (const auto &dst_node : node->GetOutControlNodes()) { GE_CHECK_NOTNULL(dst_node); + if (dst_node->GetType() != STREAMMERGE) { + GELOGI("[%s] Skip Not StreamMerge node [%s]", node->GetName().c_str(), dst_node->GetName().c_str()); + continue; + } NodeItem *dst_node_item = nullptr; GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), "[%s] failed to get or create node item", dst_node->GetName().c_str()); @@ -2459,7 +2463,7 @@ Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { // Enter --> StreamActive --> StreamMerge return CreateMergeEnterGroup(node, node_item); - } else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { + } else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { // NextIteration --> StreamActive {-->} StreamMerge return CreateMergeIterationGroup(node, node_item); } diff --git a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc index f8f37698..1037c764 100644 --- a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc +++ b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc @@ -121,8 +121,8 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { add1->GetOpDesc()->SetOpKernelLibName("AIcoreEngine"); auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); - auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); - auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); + auto value0 = CreateNode(*graph, "const1", CONSTANT, 0, 1); + auto value1 = CreateNode(*graph, "const2", CONSTANT, 0, 1); auto active1 = CreateNode(*graph, "active1", STREAMACTIVE, 0, 0); auto active2 = CreateNode(*graph, "active2", STREAMACTIVE, 0, 0); auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); @@ -151,14 +151,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { GraphUtils::AddEdge(enter1->GetOutControlAnchor(), active1->GetInControlAnchor()); GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); + GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); + //GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); + SetNextIteration(merge1, next1); + + GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); // Test for not merge. + GraphUtils::AddEdge(loop1->GetOutControlAnchor(), active2->GetInControlAnchor()); GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f->GetInControlAnchor()); GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t->GetInControlAnchor()); - GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); - GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); - SetNextIteration(merge1, next1); AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true);