|
@@ -607,6 +607,30 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { |
|
|
} |
|
|
} |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { |
|
|
|
|
|
for (const auto &node : compute_graph->GetDirectNode()) { |
|
|
|
|
|
if (node->GetType() == HCOMALLREDUCE) { |
|
|
|
|
|
std::set<OutDataAnchorPtr> pre_out_anchor_set; |
|
|
|
|
|
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { |
|
|
|
|
|
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); |
|
|
|
|
|
GE_CHECK_NOTNULL(pre_out_anchor); |
|
|
|
|
|
if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { |
|
|
|
|
|
pre_out_anchor_set.emplace(pre_out_anchor); |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
// need insert identity |
|
|
|
|
|
auto pre_node = pre_out_anchor->GetOwnerNode(); |
|
|
|
|
|
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); |
|
|
|
|
|
GE_CHECK_NOTNULL(identity_node); |
|
|
|
|
|
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); |
|
|
|
|
|
GE_CHK_STATUS_RET(ret, "Fail to insert identity."); |
|
|
|
|
|
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), |
|
|
|
|
|
pre_node->GetName().c_str(), node->GetName().c_str()); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
|
namespace ge { |
|
|
namespace ge { |
|
@@ -665,13 +689,14 @@ Status GraphOptimize::CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_ |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { |
|
|
Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { |
|
|
|
|
|
GE_DUMP(compute_graph, "BeforeHandleMemConflict"); |
|
|
node_rwtype_map_.clear(); |
|
|
node_rwtype_map_.clear(); |
|
|
auto sub_graph_vec = compute_graph->GetAllSubgraphs(); |
|
|
auto sub_graph_vec = compute_graph->GetAllSubgraphs(); |
|
|
if (sub_graph_vec.empty()) { |
|
|
if (sub_graph_vec.empty()) { |
|
|
GELOGD("No sub graph here. Ignore memory conflict handle."); |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
|
|
|
// only root graph, to handle allreduce servral input from one output anchor |
|
|
|
|
|
return HandleAllreduceDuplicateInput(compute_graph); |
|
|
} |
|
|
} |
|
|
GE_DUMP(compute_graph, "BeforeHandleMemConflict"); |
|
|
|
|
|
|
|
|
|
|
|
// 1.loop all subgraph, mark rw type from inside to outside |
|
|
// 1.loop all subgraph, mark rw type from inside to outside |
|
|
Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec); |
|
|
Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec); |
|
|
if (ret != SUCCESS) { |
|
|
if (ret != SUCCESS) { |
|
|