Browse Source

!1426 netout pass fix for onnx parse subgraph

From: @chen-hua-baker
Reviewed-by: @selfws,@sheng-nan,@wqtshg
Signed-off-by: @wqtshg
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
0a7842cfa5
2 changed files with 14 additions and 12 deletions
  1. +9
    -1
      ge/graph/passes/net_output_pass.cc
  2. +5
    -11
      ge/graph/passes/parallel_group_pass.cc

+ 9
- 1
ge/graph/passes/net_output_pass.cc View File

@@ -110,7 +110,15 @@ Status NetOutputPass::GetOutputNode(const ge::ComputeGraphPtr &graph, std::vecto
if (op_desc->HasAttr(ATTR_ATC_USER_DEFINE_OUTPUT_NODES)) { if (op_desc->HasAttr(ATTR_ATC_USER_DEFINE_OUTPUT_NODES)) {
is_user_define_ouput_nodes = true; is_user_define_ouput_nodes = true;
} }
output_nodes_info.push_back({ele.first, ele.second, -1});
int parent_index = -1;
auto output_desc = op_desc->MutableOutputDesc(ele.second);
if (output_desc == nullptr) {
GELOGE(FAILED, "[Get][OutputDesc]Can not find output tensor desc from node:%s, index %d",
op_desc->GetName().c_str(), ele.second);
return FAILED;
}
(void)ge::AttrUtils::GetInt(output_desc, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index);
output_nodes_info.push_back({ele.first, ele.second, parent_index});
} }
GELOGI("Output node set by user or leaf node, size:%zu.", output_nodes_info.size()); GELOGI("Output node set by user or leaf node, size:%zu.", output_nodes_info.size());
for (auto &ele : out_nodes_tmp) { for (auto &ele : out_nodes_tmp) {


+ 5
- 11
ge/graph/passes/parallel_group_pass.cc View File

@@ -128,8 +128,7 @@ Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t
NodePtr cur_node = nullptr; NodePtr cur_node = nullptr;
for (std::size_t i = 1; i < nodes.size(); i++) { for (std::size_t i = 1; i < nodes.size(); i++) {
cur_node = nodes[i]; cur_node = nodes[i];
GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(),
cur_node->GetName().c_str());
GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) { if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) {
GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.", GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.",
pre_node->GetName().c_str(), cur_node->GetName().c_str()); pre_node->GetName().c_str(), cur_node->GetName().c_str());
@@ -155,10 +154,8 @@ Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) {
return SUCCESS; return SUCCESS;
} }
} }
GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(),
cur_node->GetName().c_str());
return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(),
cur_node->GetInControlAnchor());
GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), cur_node->GetInControlAnchor());
} }


Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph, Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph,
@@ -200,9 +197,7 @@ Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph,


NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0); NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0);
GE_CHECK_NOTNULL(cast_node); GE_CHECK_NOTNULL(cast_node);
if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes,
cast_node, node,
node_2_switch_merge) != SUCCESS) {
if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, cast_node, node, node_2_switch_merge) != SUCCESS) {
GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str()); GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.",
graph->GetName().c_str()); graph->GetName().c_str());
@@ -247,8 +242,7 @@ void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::s
} }


Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_nodes, Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_nodes,
const std::vector<NodePtr> &merge_nodes,
const NodePtr &cast_node, const NodePtr &switch_node,
const std::vector<NodePtr> &merge_nodes, const NodePtr &cast_node, const NodePtr &switch_node,
std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) { std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
for (const auto &group_node : group_nodes) { for (const auto &group_node : group_nodes) {
auto itr = node_2_switch_merge.find(group_node); auto itr = node_2_switch_merge.find(group_node);


Loading…
Cancel
Save