| @@ -33,6 +33,8 @@ const int kCaseScopeWriteable = 2; | |||
| const int kCaseWriteable = 3; | |||
| const int kCaseInvalidRWType = 5; | |||
| const char *const kInputMutable = "_input_mutable"; | |||
| // rw type of input. | |||
| enum class InputRWType { | |||
| kReadOnly, // Normal op input only read | |||
| @@ -634,24 +636,29 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||
| } | |||
| Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | |||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||
| if (node->GetType() == HCOMALLREDUCE) { | |||
| std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL(pre_out_anchor); | |||
| if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||
| pre_out_anchor_set.emplace(pre_out_anchor); | |||
| continue; | |||
| } | |||
| // need insert identity | |||
| auto pre_node = pre_out_anchor->GetOwnerNode(); | |||
| auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||
| GE_CHECK_NOTNULL(identity_node); | |||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||
| GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||
| // op_desc of node should not be null | |||
| const auto &op_desc = node->GetOpDesc(); | |||
| bool mutable_input_flag = false; | |||
| if (!AttrUtils::GetBool(op_desc, kInputMutable, mutable_input_flag) || !mutable_input_flag) { | |||
| GELOGD("[Node:%s] Input is not mutable, ignore memory conflict handle", op_desc->GetName().c_str()); | |||
| continue; | |||
| } | |||
| std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL(pre_out_anchor); | |||
| if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||
| pre_out_anchor_set.emplace(pre_out_anchor); | |||
| continue; | |||
| } | |||
| // need insert identity | |||
| auto pre_node = pre_out_anchor->GetOwnerNode(); | |||
| auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||
| GE_CHECK_NOTNULL(identity_node); | |||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||
| GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||
| } | |||
| } | |||
| return SUCCESS; | |||