| @@ -607,6 +607,30 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | |||||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||||
| if (node->GetType() == HCOMALLREDUCE) { | |||||
| std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(pre_out_anchor); | |||||
| if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||||
| pre_out_anchor_set.emplace(pre_out_anchor); | |||||
| continue; | |||||
| } | |||||
| // need insert identity | |||||
| auto pre_node = pre_out_anchor->GetOwnerNode(); | |||||
| auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||||
| GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||||
| pre_node->GetName().c_str(), node->GetName().c_str()) | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| @@ -665,13 +689,14 @@ Status GraphOptimize::CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_ | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { | Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { | ||||
| GE_DUMP(compute_graph, "BeforeHandleMemConflict"); | |||||
| node_rwtype_map_.clear(); | node_rwtype_map_.clear(); | ||||
| auto sub_graph_vec = compute_graph->GetAllSubgraphs(); | auto sub_graph_vec = compute_graph->GetAllSubgraphs(); | ||||
| if (sub_graph_vec.empty()) { | if (sub_graph_vec.empty()) { | ||||
| GELOGD("No sub graph here. Ignore memory conflict handle."); | |||||
| return SUCCESS; | |||||
| // only root graph: handle an allreduce that takes several inputs from one output anchor | |||||
| return HandleAllreduceDuplicateInput(compute_graph); | |||||
| } | } | ||||
| GE_DUMP(compute_graph, "BeforeHandleMemConflict"); | |||||
| // 1.loop all subgraph, mark rw type from inside to outside | // 1.loop all subgraph, mark rw type from inside to outside | ||||
| Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec); | Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||