/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "graph/optimize/optimizer/allreduce_fusion_pass.h" #include #include "common/debug/log.h" #include "framework/common/debug/ge_log.h" #include "common/types.h" #include "common/util.h" #include "graph/anchor.h" #include "graph/node.h" #include "graph/op_desc.h" #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/debug/ge_attr_define.h" #include "hccl/base.h" #include "hccl/hcom.h" namespace ge { Status AllReducePass::Run(ge::ComputeGraphPtr graph) { GELOGI("FusionAllReducePass: start"); std::vector fusionOps; std::vector inputGradientSize; std::vector inputGradientTime; static const float inputGradientSizeTemp = 0.0; static const float inputGradientTimeTemp = 0.0; // Get all nodes for (auto nodePtr : graph->GetDirectNode()) { GE_IF_BOOL_EXEC(nullptr == nodePtr, GELOGW("FusionAllReducePass: null node exists"); continue;); ge::OpDescPtr opDescPtr = nodePtr->GetOpDesc(); GE_IF_BOOL_EXEC(nullptr == opDescPtr, GELOGW("FusionAllReducePass: desc of node %s is null", nodePtr->GetName().c_str()); continue;) GE_IF_BOOL_EXEC(HCOMALLREDUCE == opDescPtr->GetType(), // the op is allreduce and fusion > 0, then run fusion std::int64_t hcom_fusion = 1; GE_IF_BOOL_EXEC(!ge::AttrUtils::GetInt(opDescPtr, HCOM_ATTR_FUSION, hcom_fusion), GELOGW("FusionAllReducePass: not get hcom_fusion from opDescPtr " "by HCOM_ATTR_FUSION")); GELOGI("after GetInt, hcom_fusion is :%ld", hcom_fusion); GE_IF_BOOL_EXEC( hcom_fusion > 0, fusionOps.push_back(nodePtr); inputGradientSize.push_back(inputGradientSizeTemp); inputGradientTime.push_back(inputGradientTimeTemp);)) } // The number of allredecue operator must be more than 1 GE_IF_BOOL_EXEC(1 >= fusionOps.size(), GELOGW("FusionAllReducePass NOT_CHANGED: the graph has " "%lu allreduce operator", fusionOps.size()); return NOT_CHANGED;); string group = "group"; u32 gradientNum = fusionOps.size(); string model_name_str = graph->GetName(); const char *model_name = model_name_str.c_str(); model_feature modelFeature{model_name, gradientNum, inputGradientSize.data(), inputGradientTime.data()}; u32 segmentNum = 0; u32 segmentIndex[HCCL_MAX_SEGMENT_NUM] = {}; // Call HCCL function: hcom_gradient_segment GELOGI("FusionAllReducePass: invoking hcom_get_split_strategy"); GE_IF_BOOL_EXEC(HCCL_SUCCESS != hcom_get_split_strategy(group.c_str(), &modelFeature, HCCL_MAX_SEGMENT_NUM, &segmentNum, segmentIndex), GELOGE(FAILED, "FusionAllReducePass FAILED: the graph has %lu allreduce operator", fusionOps.size()); return FAILED;) GELOGI("FusionAllReducePass: invoke hcom_get_split_strategy successfully"); // check whether segmentNum is legal or not GE_IF_BOOL_EXEC((HCCL_MAX_SEGMENT_NUM < segmentNum || 1 > segmentNum || segmentNum > gradientNum), GELOGE(FAILED, "FusionAllReducePass FAILED: illegal segmentNum=%u, " "HCCL_MAX_SEGMENT_NUM=%u, gradientNum=%u", segmentNum, HCCL_MAX_SEGMENT_NUM, gradientNum); return FAILED;); // check whether segmentIndex is legal or not GE_IF_BOOL_EXEC((segmentIndex[segmentNum - 1] != gradientNum - 1), GELOGE(FAILED, "FusionAllReducePass FAILED: illegal segmentIndex[0]=%u, " "segmentIndex[segmentNum-1]=%u, gradientNum=%u", segmentIndex[0], segmentIndex[(segmentNum)-1], gradientNum); return FAILED;); for (uint32_t i = 0; i < segmentNum - 1; i++) { GE_IF_BOOL_EXEC(segmentIndex[i] >= segmentIndex[i + 1], GELOGE(FAILED, "FusionAllReducePass FAILED: illegal " "segmentIndex[%u]=%u, segmentIndex[%u]=%u", i, segmentIndex[i], i + 1, segmentIndex[i + 1]); return FAILED;); } // check whether fusion is needed or not GE_IF_BOOL_EXEC( segmentNum == gradientNum, GELOGE(NOT_CHANGED, "FusionAllReducePass NOT_CHANGED: segmentNum=%u, gradientNum=%u", segmentNum, gradientNum); return NOT_CHANGED;) std::unordered_set anchorPtrSet; std::vector fusionOpPeerOutDataAnchor; std::vector fusionOpPeerOutDataToInControl; std::vector fusionOpPeerOutControlAnchor; std::vector> fusionOpPeerInDataAnchor; std::vector> fusionOpPeerInControlFromOutData; std::vector fusionOpPeerInControlAnchor; ge::OutControlAnchorPtr previousNewAllreduceOutControlAnchor = nullptr; // Traversing the segmentNum uint32_t start = 0; uint32_t end = 0; for (uint32_t segmentIdx = 0; segmentIdx < segmentNum; segmentIdx++) { end = segmentIndex[segmentIdx]; GE_IF_BOOL_EXEC(end - start < 1, GELOGI("FusionAllReducePass: segmentIndex[%u]=%u", segmentIdx, segmentIndex[segmentIdx]); start = end + 1; continue;); ge::OpDescPtr originDescPtr = fusionOps[start]->GetOpDesc(); GE_CHECK_NOTNULL(originDescPtr); ge::OpDescPtr newAllreduceDesc = AttrUtils::CloneOpDesc(originDescPtr); GE_CHECK_NOTNULL(newAllreduceDesc); // Cleat buffer anchorPtrSet.clear(); fusionOpPeerOutDataAnchor.clear(); fusionOpPeerOutDataToInControl.clear(); fusionOpPeerOutControlAnchor.clear(); fusionOpPeerInDataAnchor.clear(); fusionOpPeerInControlFromOutData.clear(); fusionOpPeerInControlAnchor.clear(); // Traversing the Allreduce operators of each group int outDataAnchorIndex = 0; GE_CHK_STATUS_RET(GetPeerOutDataToInData(anchorPtrSet, fusionOpPeerOutDataAnchor, fusionOps[start]), "Get peer outDataAnchor to inDataAnchor failed"); GE_CHK_STATUS_RET(GetPeerInAnchorToOutData(anchorPtrSet, fusionOpPeerInDataAnchor, fusionOpPeerInControlFromOutData, fusionOps[start]), "Get peer inDataAnchor and inControlAnchor to outDataAnchor failed"); GE_CHK_STATUS_RET(GetPeerOutDataToInControl(anchorPtrSet, fusionOpPeerOutDataToInControl, fusionOps[start]), "Get peer outDataAnchor to inControlAnchor failed"); GE_CHK_STATUS_RET(GetPeerOutControlToInControl(anchorPtrSet, fusionOpPeerOutControlAnchor, fusionOps[start]), "Get peer outControlAnchor to inControlAnchor failed"); GE_CHK_STATUS_RET(GetPeerInControlFromOutControl(anchorPtrSet, fusionOpPeerInControlAnchor, fusionOps[start]), "Get peer outControlAnchor from inControlAnchor failed"); GE_CHK_STATUS_RET(graph->RemoveNode(fusionOps[start]), "FusionAllReducePass FAILED: remove node %s\n.", fusionOps[start]->GetName().c_str()); for (uint32_t idx = start + 1; idx <= end; idx++) { GE_CHK_STATUS_RET( GetPeerOutDataToInData(anchorPtrSet, fusionOpPeerOutDataAnchor, fusionOps[idx], newAllreduceDesc), "Get peer outDataAnchor to inDataAnchor failed"); GE_CHK_STATUS_RET(GetPeerOutDataToInControl(anchorPtrSet, fusionOpPeerOutDataToInControl, fusionOps[idx]), "Get peer outDataAnchor to inControlAnchor failed"); GE_CHK_STATUS_RET(GetPeerOutControlToInControl(anchorPtrSet, fusionOpPeerOutControlAnchor, fusionOps[idx]), "Get peer outControlAnchor to inControlAnchor failed"); GE_CHK_STATUS_RET( GetPeerAnchorFromOutData(anchorPtrSet, fusionOpPeerInDataAnchor, fusionOpPeerInControlFromOutData, fusionOps[idx], newAllreduceDesc, outDataAnchorIndex), "Get peerAnchor from outDataAnchor failed"); GE_CHK_STATUS_RET(GetPeerInControlFromOutControl(anchorPtrSet, fusionOpPeerInControlAnchor, fusionOps[idx]), "Get peer outControlAnchor from inControlAnchor failed"); // Delete the node GE_CHK_STATUS_RET(graph->RemoveNode(fusionOps[idx]), "FusionAllReducePass FAILED: remove node %s\n.", fusionOps[idx]->GetName().c_str()); } NodePtr newAllReducePtr = graph->AddNode(newAllreduceDesc); GE_CHECK_NOTNULL(newAllReducePtr); // Link the inputDataAnchor for (uint32_t i = 0; i < fusionOpPeerOutDataAnchor.size(); i++) { GE_CHK_STATUS_RET( GraphUtils::AddEdge(fusionOpPeerOutDataAnchor[i], newAllReducePtr->GetInDataAnchor(static_cast(i))), "FusionAllReducePass FAILED: add input data edge failed"); } // Link the inputControlAnchor for (uint32_t i = 0; i < fusionOpPeerOutControlAnchor.size(); i++) { GE_CHK_STATUS_RET(GraphUtils::AddEdge(fusionOpPeerOutControlAnchor[i], newAllReducePtr->GetInControlAnchor()), "FusionAllReducePass FAILED: add input control edge failed"); } for (uint32_t i = 0; i < fusionOpPeerOutDataToInControl.size(); i++) { GE_CHK_STATUS_RET(GraphUtils::AddEdge(fusionOpPeerOutDataToInControl[i], newAllReducePtr->GetInControlAnchor()), "FusionAllReducePass FAILED: add edge from out data to incontrol " "failed"); } // Link the outputDataAnchor for (uint32_t i = 0; i < fusionOpPeerInDataAnchor.size(); i++) { auto peerInDataAnchor = fusionOpPeerInDataAnchor[i].second; GE_CHK_STATUS_RET( GraphUtils::AddEdge(newAllReducePtr->GetOutDataAnchor(fusionOpPeerInDataAnchor[i].first), peerInDataAnchor), "FusionAllReducePass FAILED: add output data edge failed"); } for (uint32_t i = 0; i < fusionOpPeerInControlFromOutData.size(); i++) { auto peerInControlAnchor = fusionOpPeerInControlFromOutData[i].second; GE_CHK_STATUS_RET( GraphUtils::AddEdge(newAllReducePtr->GetOutDataAnchor(fusionOpPeerInControlFromOutData[i].first), peerInControlAnchor), "FusionAllReducePass FAILED: add edge from out data to in control " "failed"); } // Link the outputControlAnchor for (uint32_t i = 0; i < fusionOpPeerInControlAnchor.size(); i++) { GE_CHK_STATUS_RET(GraphUtils::AddEdge(newAllReducePtr->GetOutControlAnchor(), fusionOpPeerInControlAnchor[i]), "FusionAllReducePass FAILED: add output control edge failed"); } // Link the newAllreduce if (segmentIdx > 0 && previousNewAllreduceOutControlAnchor != nullptr) { GE_CHK_STATUS_RET( GraphUtils::AddEdge(previousNewAllreduceOutControlAnchor, newAllReducePtr->GetInControlAnchor()), "FusionAllReducePass FAILED: add input previous control edge failed"); } previousNewAllreduceOutControlAnchor = newAllReducePtr->GetOutControlAnchor(); start = end + 1; } return SUCCESS; } Status AllReducePass::GetPeerOutDataToInData(std::unordered_set &anchorSet, vector &peerOutDataAnchorVec, ge::NodePtr &srcNodePtr) { for (auto inDataAnchor : srcNodePtr->GetAllInDataAnchors()) { GE_IF_BOOL_EXEC(inDataAnchor == nullptr, continue;); OutDataAnchorPtr peerOutDataAnchor = inDataAnchor->GetPeerOutAnchor(); GE_IF_BOOL_EXEC(peerOutDataAnchor == nullptr, continue;); if (anchorSet.count(peerOutDataAnchor.get()) == 0) { peerOutDataAnchorVec.push_back(peerOutDataAnchor); anchorSet.insert(peerOutDataAnchor.get()); GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutDataAnchor, inDataAnchor)); } } return SUCCESS; } Status AllReducePass::GetPeerInAnchorToOutData( std::unordered_set &anchorSet, std::vector> &fusionOpPeerInDataAnchor, std::vector> &fusionOpPeerInControlFromOutData, ge::NodePtr &srcNodePtr) { for (auto outDataAnchor : srcNodePtr->GetAllOutDataAnchors()) { GE_IF_BOOL_EXEC(outDataAnchor == nullptr, continue;); for (auto peerInDataAnchor : outDataAnchor->GetPeerInDataAnchors()) { GE_IF_BOOL_EXEC(peerInDataAnchor == nullptr, continue;); if (anchorSet.count(peerInDataAnchor.get()) == 0) { std::pair pairPeerInDataAnchor; pairPeerInDataAnchor.first = 0; pairPeerInDataAnchor.second = peerInDataAnchor; fusionOpPeerInDataAnchor.push_back(pairPeerInDataAnchor); anchorSet.insert(peerInDataAnchor.get()); GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInDataAnchor)); } } for (auto peerInControlAnchorFromData : outDataAnchor->GetPeerInControlAnchors()) { GE_IF_BOOL_EXEC(peerInControlAnchorFromData == nullptr, continue;); if (anchorSet.count(peerInControlAnchorFromData.get()) == 0) { std::pair pairPeerInControlAnchorFromData; pairPeerInControlAnchorFromData.first = 0; pairPeerInControlAnchorFromData.second = peerInControlAnchorFromData; fusionOpPeerInControlFromOutData.push_back(pairPeerInControlAnchorFromData); anchorSet.insert(peerInControlAnchorFromData.get()); GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInControlAnchorFromData)); } } } return SUCCESS; } Status AllReducePass::GetPeerOutDataToInData(std::unordered_set &anchorSet, vector &peerOutDataAnchorVec, ge::NodePtr &srcNodePtr, ge::OpDescPtr &dstOpDescPtr) { for (auto inDataAnchor : srcNodePtr->GetAllInDataAnchors()) { GE_IF_BOOL_EXEC(inDataAnchor == nullptr, continue;); OutDataAnchorPtr peerOutDataAnchor = inDataAnchor->GetPeerOutAnchor(); GE_IF_BOOL_EXEC(peerOutDataAnchor == nullptr, continue;); if (anchorSet.count(peerOutDataAnchor.get()) == 0) { peerOutDataAnchorVec.push_back(peerOutDataAnchor); anchorSet.insert(peerOutDataAnchor.get()); if (dstOpDescPtr->AddInputDesc(inDataAnchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(inDataAnchor->GetIdx())) != ge::GRAPH_SUCCESS) { GELOGW("GetPeerOutDataToInData: AddInputDesc failed"); } GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutDataAnchor, inDataAnchor)); } } return SUCCESS; } Status AllReducePass::GetPeerOutDataToInControl(std::unordered_set &anchorSet, vector &peerOutDataToInControlVec, ge::NodePtr &srcNodePtr) { InControlAnchorPtr inControlAnchor = srcNodePtr->GetInControlAnchor(); GE_CHECK_NOTNULL(inControlAnchor); for (auto peerOutDataToInControl : inControlAnchor->GetPeerOutDataAnchors()) { GE_IF_BOOL_EXEC(peerOutDataToInControl == nullptr, continue;); if (anchorSet.count(peerOutDataToInControl.get()) == 0) { peerOutDataToInControlVec.push_back(peerOutDataToInControl); anchorSet.insert(peerOutDataToInControl.get()); GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutDataToInControl, inControlAnchor)); } } return SUCCESS; } Status AllReducePass::GetPeerOutControlToInControl(std::unordered_set &anchorSet, vector &peerOutControlToInControlVec, ge::NodePtr &srcNodePtr) { InControlAnchorPtr inControlAnchor = srcNodePtr->GetInControlAnchor(); GE_CHECK_NOTNULL(inControlAnchor); for (auto peerOutControlAnchor : inControlAnchor->GetPeerOutControlAnchors()) { GE_IF_BOOL_EXEC(peerOutControlAnchor == nullptr, continue;); if (anchorSet.count(peerOutControlAnchor.get()) == 0) { peerOutControlToInControlVec.push_back(peerOutControlAnchor); anchorSet.insert(peerOutControlAnchor.get()); GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutControlAnchor, inControlAnchor)); } } return SUCCESS; } Status AllReducePass::GetPeerAnchorFromOutData( std::unordered_set &anchorSet, vector> &peerInDataFromOutDataVec, vector> &peerInControlFromOutDataVec, ge::NodePtr &srcNodePtr, ge::OpDescPtr &dstOpDescPtr, int &index) { for (auto outDataAnchor : srcNodePtr->GetAllOutDataAnchors()) { GE_IF_BOOL_EXEC(outDataAnchor == nullptr, continue;) if (outDataAnchor->GetPeerInDataAnchors().size() > 0 || outDataAnchor->GetPeerInControlAnchors().size() > 0) { if (dstOpDescPtr->AddOutputDesc( outDataAnchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(outDataAnchor->GetIdx())) != ge::GRAPH_SUCCESS) { GELOGW("GetPeerAnchorFromOutData: AddOutputDesc failed"); } index++; } for (auto peerInDataAnchor : outDataAnchor->GetPeerInDataAnchors()) { GE_IF_BOOL_EXEC(peerInDataAnchor == nullptr, continue;) if (anchorSet.count(peerInDataAnchor.get()) == 0) { std::pair pairPeerInDataAnchor; pairPeerInDataAnchor.first = index; pairPeerInDataAnchor.second = peerInDataAnchor; peerInDataFromOutDataVec.push_back(pairPeerInDataAnchor); anchorSet.insert(peerInDataAnchor.get()); GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInDataAnchor)) } } for (auto peerInControlAnchorFromData : outDataAnchor->GetPeerInControlAnchors()) { GE_IF_BOOL_EXEC(peerInControlAnchorFromData == nullptr, continue;) if (anchorSet.count(peerInControlAnchorFromData.get()) == 0) { std::pair pairPeerInControlAnchorFromData; pairPeerInControlAnchorFromData.first = index; pairPeerInControlAnchorFromData.second = peerInControlAnchorFromData; peerInControlFromOutDataVec.push_back(pairPeerInControlAnchorFromData); anchorSet.insert(peerInControlAnchorFromData.get()); GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInControlAnchorFromData)) } } } return SUCCESS; } Status AllReducePass::GetPeerInControlFromOutControl(std::unordered_set &anchorSet, vector &peerInControlFromOutControlVec, ge::NodePtr &srcNodePtr) { OutControlAnchorPtr outControlAnchor = srcNodePtr->GetOutControlAnchor(); GE_CHECK_NOTNULL(outControlAnchor); for (auto peerInControlAnchor : outControlAnchor->GetPeerInControlAnchors()) { GE_IF_BOOL_EXEC(peerInControlAnchor == nullptr, continue;) if (anchorSet.count(peerInControlAnchor.get()) == 0) { peerInControlFromOutControlVec.push_back(peerInControlAnchor); anchorSet.insert(peerInControlAnchor.get()); GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outControlAnchor, peerInControlAnchor)) } } return SUCCESS; } } // namespace ge