@@ -565,13 +565,6 @@ NodePtr SwitchToStreamSwitchPass::CreateActiveNode(const ComputeGraphPtr &graph, | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | ||||
return nullptr, "Create StreamActive node failed."); | return nullptr, "Create StreamActive node failed."); | ||||
GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS, | |||||
REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", | |||||
node->GetName().c_str(), node->GetType().c_str(), | |||||
active_node->GetName().c_str(), active_node->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "add edge failed"); | |||||
return nullptr); | |||||
GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, | GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, | ||||
REPORT_CALL_ERROR("E19999", "Set switch branch node label:%s to node:%s(%s) failed", | REPORT_CALL_ERROR("E19999", "Set switch branch node label:%s to node:%s(%s) failed", | ||||
node_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str()); | node_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str()); | ||||
@@ -2394,6 +2394,10 @@ Status HybridModelBuilder::CreateMergeEnterGroup(const NodePtr &node, NodeItem * | |||||
// Enter --> StreamActive --> StreamMerge | // Enter --> StreamActive --> StreamMerge | ||||
for (const auto &dst_node : node->GetOutControlNodes()) { | for (const auto &dst_node : node->GetOutControlNodes()) { | ||||
GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
if (dst_node->GetType() != STREAMMERGE) { | |||||
GELOGI("[%s] Skip Not StreamMerge node [%s]", node->GetName().c_str(), dst_node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
NodeItem *dst_node_item = nullptr; | NodeItem *dst_node_item = nullptr; | ||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | ||||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | "[%s] failed to get or create node item", dst_node->GetName().c_str()); | ||||
@@ -2459,7 +2463,7 @@ Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem | |||||
if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { | if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { | ||||
// Enter --> StreamActive --> StreamMerge | // Enter --> StreamActive --> StreamMerge | ||||
return CreateMergeEnterGroup(node, node_item); | return CreateMergeEnterGroup(node, node_item); | ||||
} else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { | |||||
} else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { | |||||
// NextIteration --> StreamActive {-->} StreamMerge | // NextIteration --> StreamActive {-->} StreamMerge | ||||
return CreateMergeIterationGroup(node, node_item); | return CreateMergeIterationGroup(node, node_item); | ||||
} | } | ||||
@@ -121,8 +121,8 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
add1->GetOpDesc()->SetOpKernelLibName("AIcoreEngine"); | add1->GetOpDesc()->SetOpKernelLibName("AIcoreEngine"); | ||||
auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | ||||
auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | ||||
auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
auto value0 = CreateNode(*graph, "const1", CONSTANT, 0, 1); | |||||
auto value1 = CreateNode(*graph, "const2", CONSTANT, 0, 1); | |||||
auto active1 = CreateNode(*graph, "active1", STREAMACTIVE, 0, 0); | auto active1 = CreateNode(*graph, "active1", STREAMACTIVE, 0, 0); | ||||
auto active2 = CreateNode(*graph, "active2", STREAMACTIVE, 0, 0); | auto active2 = CreateNode(*graph, "active2", STREAMACTIVE, 0, 0); | ||||
auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); | auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); | ||||
@@ -151,14 +151,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
GraphUtils::AddEdge(enter1->GetOutControlAnchor(), active1->GetInControlAnchor()); | GraphUtils::AddEdge(enter1->GetOutControlAnchor(), active1->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); | GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | |||||
//GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||||
SetNextIteration(merge1, next1); | |||||
GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); // Test for not merge. | |||||
GraphUtils::AddEdge(loop1->GetOutControlAnchor(), active2->GetInControlAnchor()); | GraphUtils::AddEdge(loop1->GetOutControlAnchor(), active2->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f->GetInControlAnchor()); | GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t->GetInControlAnchor()); | GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | |||||
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | ||||
SetNextIteration(merge1, next1); | |||||
AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | ||||
AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | ||||