From: @zhangxiaokun9 Reviewed-by: @xchu42,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
| @@ -565,13 +565,6 @@ NodePtr SwitchToStreamSwitchPass::CreateActiveNode(const ComputeGraphPtr &graph, | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | ||||
| return nullptr, "Create StreamActive node failed."); | return nullptr, "Create StreamActive node failed."); | ||||
| GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS, | |||||
| REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", | |||||
| node->GetName().c_str(), node->GetType().c_str(), | |||||
| active_node->GetName().c_str(), active_node->GetType().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "add edge failed"); | |||||
| return nullptr); | |||||
| GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, | GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, | ||||
| REPORT_CALL_ERROR("E19999", "Set switch branch node label:%s to node:%s(%s) failed", | REPORT_CALL_ERROR("E19999", "Set switch branch node label:%s to node:%s(%s) failed", | ||||
| node_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str()); | node_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str()); | ||||
| @@ -2394,6 +2394,10 @@ Status HybridModelBuilder::CreateMergeEnterGroup(const NodePtr &node, NodeItem * | |||||
| // Enter --> StreamActive --> StreamMerge | // Enter --> StreamActive --> StreamMerge | ||||
| for (const auto &dst_node : node->GetOutControlNodes()) { | for (const auto &dst_node : node->GetOutControlNodes()) { | ||||
| GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
| if (dst_node->GetType() != STREAMMERGE) { | |||||
| GELOGI("[%s] Skip Not StreamMerge node [%s]", node->GetName().c_str(), dst_node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| NodeItem *dst_node_item = nullptr; | NodeItem *dst_node_item = nullptr; | ||||
| GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | ||||
| "[%s] failed to get or create node item", dst_node->GetName().c_str()); | "[%s] failed to get or create node item", dst_node->GetName().c_str()); | ||||
| @@ -2459,7 +2463,7 @@ Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem | |||||
| if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { | if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { | ||||
| // Enter --> StreamActive --> StreamMerge | // Enter --> StreamActive --> StreamMerge | ||||
| return CreateMergeEnterGroup(node, node_item); | return CreateMergeEnterGroup(node, node_item); | ||||
| } else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { | |||||
| } else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { | |||||
| // NextIteration --> StreamActive {-->} StreamMerge | // NextIteration --> StreamActive {-->} StreamMerge | ||||
| return CreateMergeIterationGroup(node, node_item); | return CreateMergeIterationGroup(node, node_item); | ||||
| } | } | ||||
| @@ -121,8 +121,8 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
| add1->GetOpDesc()->SetOpKernelLibName("AIcoreEngine"); | add1->GetOpDesc()->SetOpKernelLibName("AIcoreEngine"); | ||||
| auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | ||||
| auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | ||||
| auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
| auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
| auto value0 = CreateNode(*graph, "const1", CONSTANT, 0, 1); | |||||
| auto value1 = CreateNode(*graph, "const2", CONSTANT, 0, 1); | |||||
| auto active1 = CreateNode(*graph, "active1", STREAMACTIVE, 0, 0); | auto active1 = CreateNode(*graph, "active1", STREAMACTIVE, 0, 0); | ||||
| auto active2 = CreateNode(*graph, "active2", STREAMACTIVE, 0, 0); | auto active2 = CreateNode(*graph, "active2", STREAMACTIVE, 0, 0); | ||||
| auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); | auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); | ||||
| @@ -151,14 +151,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
| GraphUtils::AddEdge(enter1->GetOutControlAnchor(), active1->GetInControlAnchor()); | GraphUtils::AddEdge(enter1->GetOutControlAnchor(), active1->GetInControlAnchor()); | ||||
| GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); | GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); | ||||
| GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | |||||
| //GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||||
| SetNextIteration(merge1, next1); | |||||
| GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); // Test for not merge. | |||||
| GraphUtils::AddEdge(loop1->GetOutControlAnchor(), active2->GetInControlAnchor()); | GraphUtils::AddEdge(loop1->GetOutControlAnchor(), active2->GetInControlAnchor()); | ||||
| GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f->GetInControlAnchor()); | GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f->GetInControlAnchor()); | ||||
| GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t->GetInControlAnchor()); | GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t->GetInControlAnchor()); | ||||
| GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | ||||
| SetNextIteration(merge1, next1); | |||||
| AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | ||||
| AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | ||||