| @@ -33,6 +33,8 @@ const int kCaseScopeWriteable = 2; | |||||
| const int kCaseWriteable = 3; | const int kCaseWriteable = 3; | ||||
| const int kCaseInvalidRWType = 5; | const int kCaseInvalidRWType = 5; | ||||
| const char *const kInputMutable = "_input_mutable"; | |||||
| // rw type of input. | // rw type of input. | ||||
| enum class InputRWType { | enum class InputRWType { | ||||
| kReadOnly, // Normal op input only read | kReadOnly, // Normal op input only read | ||||
| @@ -634,24 +636,29 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||||
| } | } | ||||
| Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | ||||
| for (const auto &node : compute_graph->GetDirectNode()) { | for (const auto &node : compute_graph->GetDirectNode()) { | ||||
| if (node->GetType() == HCOMALLREDUCE) { | |||||
| std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(pre_out_anchor); | |||||
| if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||||
| pre_out_anchor_set.emplace(pre_out_anchor); | |||||
| continue; | |||||
| } | |||||
| // need insert identity | |||||
| auto pre_node = pre_out_anchor->GetOwnerNode(); | |||||
| auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||||
| GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
| // op_desc of node should not be null | |||||
| const auto &op_desc = node->GetOpDesc(); | |||||
| bool mutable_input_flag = false; | |||||
| if (!AttrUtils::GetBool(op_desc, kInputMutable, mutable_input_flag) || !mutable_input_flag) { | |||||
| GELOGD("[Node:%s] Input is not mutable, ignore memory conflict handle", op_desc->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(pre_out_anchor); | |||||
| if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||||
| pre_out_anchor_set.emplace(pre_out_anchor); | |||||
| continue; | |||||
| } | } | ||||
| // need insert identity | |||||
| auto pre_node = pre_out_anchor->GetOwnerNode(); | |||||
| auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||||
| GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||