Browse Source

Remove control edge for Pre-Cast to StreamActive

tags/v1.3.0
zhangxiaokun 3 years ago
parent
commit
b05613e3c5
3 changed files with 13 additions and 13 deletions
  1. +0
    -7
      ge/graph/passes/switch_to_stream_switch_pass.cc
  2. +5
    -1
      ge/hybrid/model/hybrid_model_builder.cc
  3. +8
    -5
      tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc

+ 0
- 7
ge/graph/passes/switch_to_stream_switch_pass.cc View File

@@ -565,13 +565,6 @@ NodePtr SwitchToStreamSwitchPass::CreateActiveNode(const ComputeGraphPtr &graph,
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
return nullptr, "Create StreamActive node failed.");

GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS,
REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
node->GetName().c_str(), node->GetType().c_str(),
active_node->GetName().c_str(), active_node->GetType().c_str());
GELOGE(INTERNAL_ERROR, "add edge failed");
return nullptr);

GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS,
REPORT_CALL_ERROR("E19999", "Set switch branch node label:%s to node:%s(%s) failed",
node_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());


+ 5
- 1
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -2394,6 +2394,10 @@ Status HybridModelBuilder::CreateMergeEnterGroup(const NodePtr &node, NodeItem *
// Enter --> StreamActive --> StreamMerge
for (const auto &dst_node : node->GetOutControlNodes()) {
GE_CHECK_NOTNULL(dst_node);
if (dst_node->GetType() != STREAMMERGE) {
GELOGI("[%s] Skip Not StreamMerge node [%s]", node->GetName().c_str(), dst_node->GetName().c_str());
continue;
}
NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str());
@@ -2459,7 +2463,7 @@ Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem
if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) {
// Enter --> StreamActive --> StreamMerge
return CreateMergeEnterGroup(node, node_item);
} else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) {
} else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) {
// NextIteration --> StreamActive {-->} StreamMerge
return CreateMergeIterationGroup(node, node_item);
}


+ 8
- 5
tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc View File

@@ -121,8 +121,8 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) {
add1->GetOpDesc()->SetOpKernelLibName("AIcoreEngine");
auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1);
auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1);
auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1);
auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1);
auto value0 = CreateNode(*graph, "const1", CONSTANT, 0, 1);
auto value1 = CreateNode(*graph, "const2", CONSTANT, 0, 1);
auto active1 = CreateNode(*graph, "active1", STREAMACTIVE, 0, 0);
auto active2 = CreateNode(*graph, "active2", STREAMACTIVE, 0, 0);
auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0);
@@ -151,14 +151,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) {
GraphUtils::AddEdge(enter1->GetOutControlAnchor(), active1->GetInControlAnchor());
GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor());

GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor());
//GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor());
SetNextIteration(merge1, next1);

GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); // Test for not merge.

GraphUtils::AddEdge(loop1->GetOutControlAnchor(), active2->GetInControlAnchor());
GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f->GetInControlAnchor());
GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t->GetInControlAnchor());

GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor());

GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));
SetNextIteration(merge1, next1);

AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true);
AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true);


Loading…
Cancel
Save