Browse Source

!1481 修复无输出的训练图无法进行小循环的bug

From: @xuepenginnanjing
Reviewed-by: @xchu42
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
d8e69d3e1e
2 changed files with 26 additions and 2 deletions
  1. +9
    -2
      ge/graph/passes/net_output_pass.cc
  2. +17
    -0
      tests/ut/ge/graph/passes/net_output_pass_unittest.cc

+ 9
- 2
ge/graph/passes/net_output_pass.cc View File

@@ -514,7 +514,7 @@ Status NetOutputPass::Run(ge::ComputeGraphPtr graph) {
GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null.");
return GE_GRAPH_PARAM_NULLPTR;
}
GELOGI("NetOutputPass Run.graph is [%s]", graph->GetName().c_str());
GELOGI("[NETOUTPUT PASS] Run.graph is [%s]", graph->GetName().c_str());
NodePtr output_node = graph->FindFirstNodeMatchType(NETOUTPUT);
// save user targets node
SaveAndRemoveTargets(graph);
@@ -552,10 +552,17 @@ Status NetOutputPass::AddNetOutputNodeToGraph(const ge::ComputeGraphPtr &graph,
// If user does not set out nodes and targets and no retval node, also add netoutput node
if ((graph->GetGraphOutNodesInfo().empty()) && (graph->GetGraphTargetNodesInfo().empty()) &&
!is_include_special_node_) {
GELOGI("[NETOUTPUT PASS] output_nodes and target_nodes and special nodes is empty!Add netoutput!");
GELOGI("[NETOUTPUT PASS] Both output, target and special nodes are empty! add net output node");
output_node = graph->AddNode(net_output_desc);
GE_CHK_STATUS_RET(AddCtrlEdgesBetweenLeafAndNetOutput(graph, output_node),
"add ctrl edge between leaf and netoutput failed");
if (!ge::AttrUtils::SetInt(output_node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, 0)) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_TRUE_BRANCH_STREAM.c_str(),
output_node->GetName().c_str(), output_node->GetType().c_str());
GELOGE(INTERNAL_ERROR, "set ATTR_NAME_TRUE_BRANCH_STREAM failed");
return INTERNAL_ERROR;
}
GELOGI("[NETOUTPUT PASS] Add net output node succeed");
return SUCCESS;
}
GELOGI("[NETOUTPUT PASS] Output node size:%lu.", output_nodes_info.size());


+ 17
- 0
tests/ut/ge/graph/passes/net_output_pass_unittest.cc View File

@@ -631,6 +631,23 @@ TEST_F(UtestGraphPassesNetOutputPass, no_output_no_target_no_retval_success) {
EXPECT_EQ(status, ge::SUCCESS);
}

TEST_F(UtestGraphPassesNetOutputPass, no_output_no_target_no_retval_no_outnodes_success) {
ge::ComputeGraphPtr compute_graph = build_graph();

ge::PassManager pass_managers;
pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
Status status = pass_managers.Run(compute_graph);
EXPECT_EQ(status, ge::SUCCESS);

NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
EXPECT_NE(net_out_node, nullptr);
EXPECT_EQ(net_out_node->GetInControlNodes().size(), 2);

int stream_label = -1;
EXPECT_TRUE(ge::AttrUtils::GetInt(net_out_node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, stream_label));
EXPECT_EQ(stream_label, 0);
}

TEST_F(UtestGraphPassesNetOutputPass, user_out_node_success) {
ge::ComputeGraphPtr compute_graph = build_graph();



Loading…
Cancel
Save