From f1a6967e04128952072c6c65f776d574b1620333 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Sat, 21 Nov 2020 10:29:30 +0800 Subject: [PATCH 1/2] modified: ge/graph/optimize/mem_rw_conflict_optimize.cc --- ge/graph/optimize/mem_rw_conflict_optimize.cc | 31 +++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index b9533588..828ed20f 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -607,6 +607,30 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { } return SUCCESS; } +Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { + for (const auto &node : compute_graph->GetDirectNode()) { + if (node->GetType() == HCOMALLREDUCE) { + std::set<OutDataAnchorPtr> pre_out_anchor_set; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(pre_out_anchor); + if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { + pre_out_anchor_set.emplace(pre_out_anchor); + continue; + } + // need insert identity + auto pre_node = pre_out_anchor->GetOwnerNode(); + auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); + GE_CHECK_NOTNULL(identity_node); + auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); + GE_CHK_STATUS_RET(ret, "Fail to insert identity."); + GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), + pre_node->GetName().c_str(), node->GetName().c_str()) + } + } + } + return SUCCESS; +} } // namespace namespace ge { @@ -665,13 +689,14 @@ Status GraphOptimize::CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_ return SUCCESS; } Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { + GE_DUMP(compute_graph, "BeforeHandleMemConflict"); 
node_rwtype_map_.clear(); auto sub_graph_vec = compute_graph->GetAllSubgraphs(); if (sub_graph_vec.empty()) { - GELOGD("No sub graph here. Ignore memory conflict handle."); - return SUCCESS; + // only root graph: handle allreduce with several inputs from one output anchor + return HandleAllreduceDuplicateInput(compute_graph); } - GE_DUMP(compute_graph, "BeforeHandleMemConflict"); + // 1.loop all subgraph, mark rw type from inside to outside Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec); if (ret != SUCCESS) { From 9da00765c74f3416769277958e5e2bb4a31c969b Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Sat, 21 Nov 2020 10:33:52 +0800 Subject: [PATCH 2/2] modified: ge/graph/optimize/mem_rw_conflict_optimize.cc --- ge/graph/optimize/mem_rw_conflict_optimize.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index 828ed20f..fdb825a9 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -625,7 +625,7 @@ Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); GE_CHK_STATUS_RET(ret, "Fail to insert identity."); GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), - pre_node->GetName().c_str(), node->GetName().c_str()) + pre_node->GetName().c_str(), node->GetName().c_str()); } } }