Browse Source

!356 Bugfix: fix allreduce has duplicate input from one output anchor

From: @hugo1
Reviewed-by: @xchu42,@ji_chen
Signed-off-by: @ji_chen
tags/v1.1.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
e5a5d8b69a
1 changed files with 28 additions and 3 deletions
  1. +28
    -3
      ge/graph/optimize/mem_rw_conflict_optimize.cc

+ 28
- 3
ge/graph/optimize/mem_rw_conflict_optimize.cc View File

@@ -607,6 +607,30 @@ Status InsertIdentityAsNeeded(const NodePtr &node) {
}
return SUCCESS;
}
// Ensures that no HCOMALLREDUCE node in the graph receives two inputs from the same output anchor.
//
// For every direct node of compute_graph whose type is HCOMALLREDUCE, the input data anchors are
// walked in order. The first input fed by a given peer output anchor is left as is; every later
// input fed by that same anchor gets the node returned by CreateIdentityAfterSrcNode inserted
// between the output anchor and that input.
//
// compute_graph: graph whose direct nodes are scanned.
// Returns SUCCESS once all nodes have been processed. Failures are reported through the
// GE_CHECK_NOTNULL / GE_CHK_STATUS_RET macros (presumably an early error return; their definitions
// are not visible in this file section).
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) {
  for (const auto &node : compute_graph->GetDirectNode()) {
    if (node->GetType() != HCOMALLREDUCE) {
      continue;
    }
    // Output anchors already feeding one input of this allreduce node.
    std::set<OutDataAnchorPtr> pre_out_anchor_set;
    for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
      auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
      GE_CHECK_NOTNULL(pre_out_anchor);
      // insert().second is true only the first time this anchor is seen, so one lookup both
      // records the anchor and tells us whether this input is a duplicate.
      if (pre_out_anchor_set.insert(pre_out_anchor).second) {
        continue;
      }
      // need insert identity
      auto pre_node = pre_out_anchor->GetOwnerNode();
      // Guard the dereference below: the owner node is checked like every other pointer here.
      GE_CHECK_NOTNULL(pre_node);
      auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
      GE_CHECK_NOTNULL(identity_node);
      auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
      GE_CHK_STATUS_RET(ret, "Fail to insert identity.");
      GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
             pre_node->GetName().c_str(), node->GetName().c_str());
    }
  }
  return SUCCESS;
}
} // namespace

namespace ge {
@@ -665,13 +689,14 @@ Status GraphOptimize::CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_
return SUCCESS;
}
Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) {
GE_DUMP(compute_graph, "BeforeHandleMemConflict");
node_rwtype_map_.clear();
auto sub_graph_vec = compute_graph->GetAllSubgraphs();
if (sub_graph_vec.empty()) {
GELOGD("No sub graph here. Ignore memory conflict handle.");
return SUCCESS;
// only root graph: handle allreduce with several inputs from one output anchor
return HandleAllreduceDuplicateInput(compute_graph);
}
GE_DUMP(compute_graph, "BeforeHandleMemConflict");
// 1.loop all subgraph, mark rw type from inside to outside
Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec);
if (ret != SUCCESS) {


Loading…
Cancel
Save