diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index b9533588..fdb825a9 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -607,6 +607,30 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { } return SUCCESS; } +Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { + for (const auto &node : compute_graph->GetDirectNode()) { + if (node->GetType() == HCOMALLREDUCE) { + std::set pre_out_anchor_set; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(pre_out_anchor); + if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { + pre_out_anchor_set.emplace(pre_out_anchor); + continue; + } + // need insert identity + auto pre_node = pre_out_anchor->GetOwnerNode(); + auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); + GE_CHECK_NOTNULL(identity_node); + auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); + GE_CHK_STATUS_RET(ret, "Fail to insert identity."); + GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), + pre_node->GetName().c_str(), node->GetName().c_str()); + } + } + } + return SUCCESS; +} } // namespace namespace ge { @@ -665,13 +689,14 @@ Status GraphOptimize::CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_ return SUCCESS; } Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { + GE_DUMP(compute_graph, "BeforeHandleMemConflict"); node_rwtype_map_.clear(); auto sub_graph_vec = compute_graph->GetAllSubgraphs(); if (sub_graph_vec.empty()) { - GELOGD("No sub graph here. Ignore memory conflict handle."); - return SUCCESS; + // only root graph, to handle allreduce servral input from one output anchor + return HandleAllreduceDuplicateInput(compute_graph); } - GE_DUMP(compute_graph, "BeforeHandleMemConflict"); + // 1.loop all subgraph, mark rw type from inside to outside Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec); if (ret != SUCCESS) {