diff --git a/ge/graph/passes/net_output_pass.cc b/ge/graph/passes/net_output_pass.cc index 100e73cd..28916706 100644 --- a/ge/graph/passes/net_output_pass.cc +++ b/ge/graph/passes/net_output_pass.cc @@ -110,7 +110,15 @@ Status NetOutputPass::GetOutputNode(const ge::ComputeGraphPtr &graph, std::vecto if (op_desc->HasAttr(ATTR_ATC_USER_DEFINE_OUTPUT_NODES)) { is_user_define_ouput_nodes = true; } - output_nodes_info.push_back({ele.first, ele.second, -1}); + int parent_index = -1; + auto output_desc = op_desc->MutableOutputDesc(ele.second); + if (output_desc == nullptr) { + GELOGE(FAILED, "[Get][OutputDesc]Can not find output tensor desc from node:%s, index %d", + op_desc->GetName().c_str(), ele.second); + return FAILED; + } + (void)ge::AttrUtils::GetInt(output_desc, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index); + output_nodes_info.push_back({ele.first, ele.second, parent_index}); } GELOGI("Output node set by user or leaf node, size:%zu.", output_nodes_info.size()); for (auto &ele : out_nodes_tmp) { diff --git a/ge/graph/passes/parallel_group_pass.cc b/ge/graph/passes/parallel_group_pass.cc index 0d033fbf..9c93f6cf 100644 --- a/ge/graph/passes/parallel_group_pass.cc +++ b/ge/graph/passes/parallel_group_pass.cc @@ -128,8 +128,7 @@ Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t NodePtr cur_node = nullptr; for (std::size_t i = 1; i < nodes.size(); i++) { cur_node = nodes[i]; - GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), - cur_node->GetName().c_str()); + GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) { GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.", pre_node->GetName().c_str(), cur_node->GetName().c_str()); @@ -155,10 +154,8 @@ Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) { return SUCCESS; } } - GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), - cur_node->GetName().c_str()); - return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), - cur_node->GetInControlAnchor()); + GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); + return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), cur_node->GetInControlAnchor()); } Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph, @@ -200,9 +197,7 @@ Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph, NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0); GE_CHECK_NOTNULL(cast_node); - if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, - cast_node, node, - node_2_switch_merge) != SUCCESS) { + if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, cast_node, node, node_2_switch_merge) != SUCCESS) { GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str()); REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str()); @@ -247,8 +242,7 @@ void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::s } Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set &group_nodes, - const std::vector &merge_nodes, - const NodePtr &cast_node, const NodePtr &switch_node, + const std::vector &merge_nodes, const NodePtr &cast_node, const NodePtr &switch_node, std::map, NodePtr>> &node_2_switch_merge) { for (const auto &group_node : group_nodes) { auto itr = node_2_switch_merge.find(group_node);