|
|
@@ -128,8 +128,7 @@ Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t |
|
|
|
NodePtr cur_node = nullptr; |
|
|
|
for (std::size_t i = 1; i < nodes.size(); i++) { |
|
|
|
cur_node = nodes[i]; |
|
|
|
GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), |
|
|
|
cur_node->GetName().c_str()); |
|
|
|
GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); |
|
|
|
if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.", |
|
|
|
pre_node->GetName().c_str(), cur_node->GetName().c_str()); |
|
|
@@ -155,10 +154,8 @@ Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} |
|
|
|
GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), |
|
|
|
cur_node->GetName().c_str()); |
|
|
|
return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), |
|
|
|
cur_node->GetInControlAnchor()); |
|
|
|
GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); |
|
|
|
return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), cur_node->GetInControlAnchor()); |
|
|
|
} |
|
|
|
|
|
|
|
Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph, |
|
|
@@ -200,9 +197,7 @@ Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph, |
|
|
|
|
|
|
|
NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0); |
|
|
|
GE_CHECK_NOTNULL(cast_node); |
|
|
|
if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, |
|
|
|
cast_node, node, |
|
|
|
node_2_switch_merge) != SUCCESS) { |
|
|
|
if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, cast_node, node, node_2_switch_merge) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str()); |
|
|
|
REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", |
|
|
|
graph->GetName().c_str()); |
|
|
@@ -247,8 +242,7 @@ void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::s |
|
|
|
} |
|
|
|
|
|
|
|
Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_nodes, |
|
|
|
const std::vector<NodePtr> &merge_nodes, |
|
|
|
const NodePtr &cast_node, const NodePtr &switch_node, |
|
|
|
const std::vector<NodePtr> &merge_nodes, const NodePtr &cast_node, const NodePtr &switch_node, |
|
|
|
std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) { |
|
|
|
for (const auto &group_node : group_nodes) { |
|
|
|
auto itr = node_2_switch_merge.find(group_node); |
|
|
|