diff --git a/ge/graph/passes/net_output_pass.cc b/ge/graph/passes/net_output_pass.cc index 3ac1100d..aca7058d 100644 --- a/ge/graph/passes/net_output_pass.cc +++ b/ge/graph/passes/net_output_pass.cc @@ -514,7 +514,7 @@ Status NetOutputPass::Run(ge::ComputeGraphPtr graph) { GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null."); return GE_GRAPH_PARAM_NULLPTR; } - GELOGI("NetOutputPass Run.graph is [%s]", graph->GetName().c_str()); + GELOGI("[NETOUTPUT PASS] Run.graph is [%s]", graph->GetName().c_str()); NodePtr output_node = graph->FindFirstNodeMatchType(NETOUTPUT); // save user targets node SaveAndRemoveTargets(graph); @@ -552,10 +552,17 @@ Status NetOutputPass::AddNetOutputNodeToGraph(const ge::ComputeGraphPtr &graph, // If user does not set out nodes and targets and no retval node, also add netoutput node if ((graph->GetGraphOutNodesInfo().empty()) && (graph->GetGraphTargetNodesInfo().empty()) && !is_include_special_node_) { - GELOGI("[NETOUTPUT PASS] output_nodes and target_nodes and special nodes is empty!Add netoutput!"); + GELOGI("[NETOUTPUT PASS] Both output, target and special nodes are empty! add net output node"); output_node = graph->AddNode(net_output_desc); GE_CHK_STATUS_RET(AddCtrlEdgesBetweenLeafAndNetOutput(graph, output_node), "add ctrl edge between leaf and netoutput failed"); + if (!ge::AttrUtils::SetInt(output_node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, 0)) { + REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_TRUE_BRANCH_STREAM.c_str(), + output_node->GetName().c_str(), output_node->GetType().c_str()); + GELOGE(INTERNAL_ERROR, "set ATTR_NAME_TRUE_BRANCH_STREAM failed"); + return INTERNAL_ERROR; + } + GELOGI("[NETOUTPUT PASS] Add net output node succeed"); return SUCCESS; } GELOGI("[NETOUTPUT PASS] Output node size:%lu.", output_nodes_info.size()); diff --git a/tests/ut/ge/graph/passes/net_output_pass_unittest.cc b/tests/ut/ge/graph/passes/net_output_pass_unittest.cc index 031985f3..ac6cd63a 100644 --- a/tests/ut/ge/graph/passes/net_output_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/net_output_pass_unittest.cc @@ -631,6 +631,23 @@ TEST_F(UtestGraphPassesNetOutputPass, no_output_no_target_no_retval_success) { EXPECT_EQ(status, ge::SUCCESS); } +TEST_F(UtestGraphPassesNetOutputPass, no_output_no_target_no_retval_no_outnodes_success) { + ge::ComputeGraphPtr compute_graph = build_graph(); + + ge::PassManager pass_managers; + pass_managers.AddPass("", new (std::nothrow) NetOutputPass); + Status status = pass_managers.Run(compute_graph); + EXPECT_EQ(status, ge::SUCCESS); + + NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); + EXPECT_NE(net_out_node, nullptr); + EXPECT_EQ(net_out_node->GetInControlNodes().size(), 2); + + int stream_label = -1; + EXPECT_TRUE(ge::AttrUtils::GetInt(net_out_node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, stream_label)); + EXPECT_EQ(stream_label, 0); +} + TEST_F(UtestGraphPassesNetOutputPass, user_out_node_success) { ge::ComputeGraphPtr compute_graph = build_graph();