|
@@ -1,397 +0,0 @@ |
|
|
/** |
|
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
|
|
* |
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
|
|
* You may obtain a copy of the License at |
|
|
|
|
|
* |
|
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
|
|
* |
|
|
|
|
|
* Unless required by applicable law or agreed to in writing, software |
|
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
|
|
* limitations under the License. |
|
|
|
|
|
*/ |
|
|
|
|
|
|
|
|
|
|
|
#include "graph/optimize/optimizer/allreduce_fusion_pass.h" |
|
|
|
|
|
#include <string> |
|
|
|
|
|
#include "common/debug/log.h" |
|
|
|
|
|
#include "framework/common/debug/ge_log.h" |
|
|
|
|
|
#include "common/types.h" |
|
|
|
|
|
#include "common/util.h" |
|
|
|
|
|
#include "graph/anchor.h" |
|
|
|
|
|
#include "graph/node.h" |
|
|
|
|
|
#include "graph/op_desc.h" |
|
|
|
|
|
#include "graph/utils/attr_utils.h" |
|
|
|
|
|
#include "graph/utils/graph_utils.h" |
|
|
|
|
|
#include "graph/utils/tensor_utils.h" |
|
|
|
|
|
#include "graph/debug/ge_attr_define.h" |
|
|
|
|
|
#include "hccl/base.h" |
|
|
|
|
|
#include "hccl/hcom.h" |
|
|
|
|
|
|
|
|
|
|
|
namespace ge { |
|
|
|
|
|
Status AllReducePass::Run(ge::ComputeGraphPtr graph) { |
|
|
|
|
|
GELOGI("FusionAllReducePass: start"); |
|
|
|
|
|
std::vector<NodePtr> fusionOps; |
|
|
|
|
|
std::vector<float> inputGradientSize; |
|
|
|
|
|
std::vector<float> inputGradientTime; |
|
|
|
|
|
|
|
|
|
|
|
static const float inputGradientSizeTemp = 0.0; |
|
|
|
|
|
static const float inputGradientTimeTemp = 0.0; |
|
|
|
|
|
|
|
|
|
|
|
// Get all nodes |
|
|
|
|
|
for (auto nodePtr : graph->GetDirectNode()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(nullptr == nodePtr, GELOGW("FusionAllReducePass: null node exists"); continue;); |
|
|
|
|
|
|
|
|
|
|
|
ge::OpDescPtr opDescPtr = nodePtr->GetOpDesc(); |
|
|
|
|
|
GE_IF_BOOL_EXEC(nullptr == opDescPtr, |
|
|
|
|
|
GELOGW("FusionAllReducePass: desc of node %s is null", nodePtr->GetName().c_str()); |
|
|
|
|
|
continue;) |
|
|
|
|
|
GE_IF_BOOL_EXEC(HCOMALLREDUCE == opDescPtr->GetType(), |
|
|
|
|
|
// the op is allreduce and fusion > 0, then run fusion |
|
|
|
|
|
std::int64_t hcom_fusion = 1; |
|
|
|
|
|
GE_IF_BOOL_EXEC(!ge::AttrUtils::GetInt(opDescPtr, HCOM_ATTR_FUSION, hcom_fusion), |
|
|
|
|
|
GELOGW("FusionAllReducePass: not get hcom_fusion from opDescPtr " |
|
|
|
|
|
"by HCOM_ATTR_FUSION")); |
|
|
|
|
|
GELOGI("after GetInt, hcom_fusion is :%ld", hcom_fusion); GE_IF_BOOL_EXEC( |
|
|
|
|
|
hcom_fusion > 0, fusionOps.push_back(nodePtr); inputGradientSize.push_back(inputGradientSizeTemp); |
|
|
|
|
|
inputGradientTime.push_back(inputGradientTimeTemp);)) |
|
|
|
|
|
} |
|
|
|
|
|
// The number of allredecue operator must be more than 1 |
|
|
|
|
|
GE_IF_BOOL_EXEC(1 >= fusionOps.size(), GELOGW("FusionAllReducePass NOT_CHANGED: the graph has " |
|
|
|
|
|
"%lu allreduce operator", |
|
|
|
|
|
fusionOps.size()); |
|
|
|
|
|
return NOT_CHANGED;); |
|
|
|
|
|
|
|
|
|
|
|
string group = "group"; |
|
|
|
|
|
u32 gradientNum = fusionOps.size(); |
|
|
|
|
|
string model_name_str = graph->GetName(); |
|
|
|
|
|
const char *model_name = model_name_str.c_str(); |
|
|
|
|
|
model_feature modelFeature{model_name, gradientNum, inputGradientSize.data(), inputGradientTime.data()}; |
|
|
|
|
|
|
|
|
|
|
|
u32 segmentNum = 0; |
|
|
|
|
|
u32 segmentIndex[HCCL_MAX_SEGMENT_NUM] = {}; |
|
|
|
|
|
|
|
|
|
|
|
// Call HCCL function: hcom_gradient_segment |
|
|
|
|
|
GELOGI("FusionAllReducePass: invoking hcom_get_split_strategy"); |
|
|
|
|
|
GE_IF_BOOL_EXEC(HCCL_SUCCESS != hcom_get_split_strategy(group.c_str(), &modelFeature, HCCL_MAX_SEGMENT_NUM, |
|
|
|
|
|
&segmentNum, segmentIndex), |
|
|
|
|
|
GELOGE(FAILED, "FusionAllReducePass FAILED: the graph has %lu allreduce operator", fusionOps.size()); |
|
|
|
|
|
return FAILED;) |
|
|
|
|
|
GELOGI("FusionAllReducePass: invoke hcom_get_split_strategy successfully"); |
|
|
|
|
|
|
|
|
|
|
|
// check whether segmentNum is legal or not |
|
|
|
|
|
GE_IF_BOOL_EXEC((HCCL_MAX_SEGMENT_NUM < segmentNum || 1 > segmentNum || segmentNum > gradientNum), |
|
|
|
|
|
GELOGE(FAILED, |
|
|
|
|
|
"FusionAllReducePass FAILED: illegal segmentNum=%u, " |
|
|
|
|
|
"HCCL_MAX_SEGMENT_NUM=%u, gradientNum=%u", |
|
|
|
|
|
segmentNum, HCCL_MAX_SEGMENT_NUM, gradientNum); |
|
|
|
|
|
return FAILED;); |
|
|
|
|
|
|
|
|
|
|
|
// check whether segmentIndex is legal or not |
|
|
|
|
|
GE_IF_BOOL_EXEC((segmentIndex[segmentNum - 1] != gradientNum - 1), |
|
|
|
|
|
GELOGE(FAILED, |
|
|
|
|
|
"FusionAllReducePass FAILED: illegal segmentIndex[0]=%u, " |
|
|
|
|
|
"segmentIndex[segmentNum-1]=%u, gradientNum=%u", |
|
|
|
|
|
segmentIndex[0], segmentIndex[(segmentNum)-1], gradientNum); |
|
|
|
|
|
return FAILED;); |
|
|
|
|
|
|
|
|
|
|
|
for (uint32_t i = 0; i < segmentNum - 1; i++) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(segmentIndex[i] >= segmentIndex[i + 1], GELOGE(FAILED, |
|
|
|
|
|
"FusionAllReducePass FAILED: illegal " |
|
|
|
|
|
"segmentIndex[%u]=%u, segmentIndex[%u]=%u", |
|
|
|
|
|
i, segmentIndex[i], i + 1, segmentIndex[i + 1]); |
|
|
|
|
|
return FAILED;); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// check whether fusion is needed or not |
|
|
|
|
|
GE_IF_BOOL_EXEC( |
|
|
|
|
|
segmentNum == gradientNum, |
|
|
|
|
|
GELOGE(NOT_CHANGED, "FusionAllReducePass NOT_CHANGED: segmentNum=%u, gradientNum=%u", segmentNum, gradientNum); |
|
|
|
|
|
return NOT_CHANGED;) |
|
|
|
|
|
|
|
|
|
|
|
std::unordered_set<void *> anchorPtrSet; |
|
|
|
|
|
std::vector<ge::OutDataAnchorPtr> fusionOpPeerOutDataAnchor; |
|
|
|
|
|
std::vector<ge::OutDataAnchorPtr> fusionOpPeerOutDataToInControl; |
|
|
|
|
|
std::vector<ge::OutControlAnchorPtr> fusionOpPeerOutControlAnchor; |
|
|
|
|
|
std::vector<std::pair<int, ge::InDataAnchorPtr>> fusionOpPeerInDataAnchor; |
|
|
|
|
|
std::vector<std::pair<int, ge::InControlAnchorPtr>> fusionOpPeerInControlFromOutData; |
|
|
|
|
|
std::vector<ge::InControlAnchorPtr> fusionOpPeerInControlAnchor; |
|
|
|
|
|
ge::OutControlAnchorPtr previousNewAllreduceOutControlAnchor = nullptr; |
|
|
|
|
|
|
|
|
|
|
|
// Traversing the segmentNum |
|
|
|
|
|
uint32_t start = 0; |
|
|
|
|
|
uint32_t end = 0; |
|
|
|
|
|
for (uint32_t segmentIdx = 0; segmentIdx < segmentNum; segmentIdx++) { |
|
|
|
|
|
end = segmentIndex[segmentIdx]; |
|
|
|
|
|
GE_IF_BOOL_EXEC(end - start < 1, |
|
|
|
|
|
GELOGI("FusionAllReducePass: segmentIndex[%u]=%u", segmentIdx, segmentIndex[segmentIdx]); |
|
|
|
|
|
start = end + 1; continue;); |
|
|
|
|
|
|
|
|
|
|
|
ge::OpDescPtr originDescPtr = fusionOps[start]->GetOpDesc(); |
|
|
|
|
|
GE_CHECK_NOTNULL(originDescPtr); |
|
|
|
|
|
ge::OpDescPtr newAllreduceDesc = AttrUtils::CloneOpDesc(originDescPtr); |
|
|
|
|
|
GE_CHECK_NOTNULL(newAllreduceDesc); |
|
|
|
|
|
|
|
|
|
|
|
// Cleat buffer |
|
|
|
|
|
anchorPtrSet.clear(); |
|
|
|
|
|
fusionOpPeerOutDataAnchor.clear(); |
|
|
|
|
|
fusionOpPeerOutDataToInControl.clear(); |
|
|
|
|
|
fusionOpPeerOutControlAnchor.clear(); |
|
|
|
|
|
fusionOpPeerInDataAnchor.clear(); |
|
|
|
|
|
fusionOpPeerInControlFromOutData.clear(); |
|
|
|
|
|
fusionOpPeerInControlAnchor.clear(); |
|
|
|
|
|
|
|
|
|
|
|
// Traversing the Allreduce operators of each group |
|
|
|
|
|
int outDataAnchorIndex = 0; |
|
|
|
|
|
GE_CHK_STATUS_RET(GetPeerOutDataToInData(anchorPtrSet, fusionOpPeerOutDataAnchor, fusionOps[start]), |
|
|
|
|
|
"Get peer outDataAnchor to inDataAnchor failed"); |
|
|
|
|
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(GetPeerInAnchorToOutData(anchorPtrSet, fusionOpPeerInDataAnchor, fusionOpPeerInControlFromOutData, |
|
|
|
|
|
fusionOps[start]), |
|
|
|
|
|
"Get peer inDataAnchor and inControlAnchor to outDataAnchor failed"); |
|
|
|
|
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(GetPeerOutDataToInControl(anchorPtrSet, fusionOpPeerOutDataToInControl, fusionOps[start]), |
|
|
|
|
|
"Get peer outDataAnchor to inControlAnchor failed"); |
|
|
|
|
|
GE_CHK_STATUS_RET(GetPeerOutControlToInControl(anchorPtrSet, fusionOpPeerOutControlAnchor, fusionOps[start]), |
|
|
|
|
|
"Get peer outControlAnchor to inControlAnchor failed"); |
|
|
|
|
|
GE_CHK_STATUS_RET(GetPeerInControlFromOutControl(anchorPtrSet, fusionOpPeerInControlAnchor, fusionOps[start]), |
|
|
|
|
|
"Get peer outControlAnchor from inControlAnchor failed"); |
|
|
|
|
|
GE_CHK_STATUS_RET(graph->RemoveNode(fusionOps[start]), "FusionAllReducePass FAILED: remove node %s\n.", |
|
|
|
|
|
fusionOps[start]->GetName().c_str()); |
|
|
|
|
|
|
|
|
|
|
|
for (uint32_t idx = start + 1; idx <= end; idx++) { |
|
|
|
|
|
GE_CHK_STATUS_RET( |
|
|
|
|
|
GetPeerOutDataToInData(anchorPtrSet, fusionOpPeerOutDataAnchor, fusionOps[idx], newAllreduceDesc), |
|
|
|
|
|
"Get peer outDataAnchor to inDataAnchor failed"); |
|
|
|
|
|
GE_CHK_STATUS_RET(GetPeerOutDataToInControl(anchorPtrSet, fusionOpPeerOutDataToInControl, fusionOps[idx]), |
|
|
|
|
|
"Get peer outDataAnchor to inControlAnchor failed"); |
|
|
|
|
|
GE_CHK_STATUS_RET(GetPeerOutControlToInControl(anchorPtrSet, fusionOpPeerOutControlAnchor, fusionOps[idx]), |
|
|
|
|
|
"Get peer outControlAnchor to inControlAnchor failed"); |
|
|
|
|
|
GE_CHK_STATUS_RET( |
|
|
|
|
|
GetPeerAnchorFromOutData(anchorPtrSet, fusionOpPeerInDataAnchor, fusionOpPeerInControlFromOutData, |
|
|
|
|
|
fusionOps[idx], newAllreduceDesc, outDataAnchorIndex), |
|
|
|
|
|
"Get peerAnchor from outDataAnchor failed"); |
|
|
|
|
|
GE_CHK_STATUS_RET(GetPeerInControlFromOutControl(anchorPtrSet, fusionOpPeerInControlAnchor, fusionOps[idx]), |
|
|
|
|
|
"Get peer outControlAnchor from inControlAnchor failed"); |
|
|
|
|
|
|
|
|
|
|
|
// Delete the node |
|
|
|
|
|
GE_CHK_STATUS_RET(graph->RemoveNode(fusionOps[idx]), "FusionAllReducePass FAILED: remove node %s\n.", |
|
|
|
|
|
fusionOps[idx]->GetName().c_str()); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
NodePtr newAllReducePtr = graph->AddNode(newAllreduceDesc); |
|
|
|
|
|
GE_CHECK_NOTNULL(newAllReducePtr); |
|
|
|
|
|
// Link the inputDataAnchor |
|
|
|
|
|
for (uint32_t i = 0; i < fusionOpPeerOutDataAnchor.size(); i++) { |
|
|
|
|
|
GE_CHK_STATUS_RET( |
|
|
|
|
|
GraphUtils::AddEdge(fusionOpPeerOutDataAnchor[i], newAllReducePtr->GetInDataAnchor(static_cast<int>(i))), |
|
|
|
|
|
"FusionAllReducePass FAILED: add input data edge failed"); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Link the inputControlAnchor |
|
|
|
|
|
for (uint32_t i = 0; i < fusionOpPeerOutControlAnchor.size(); i++) { |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::AddEdge(fusionOpPeerOutControlAnchor[i], newAllReducePtr->GetInControlAnchor()), |
|
|
|
|
|
"FusionAllReducePass FAILED: add input control edge failed"); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
for (uint32_t i = 0; i < fusionOpPeerOutDataToInControl.size(); i++) { |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::AddEdge(fusionOpPeerOutDataToInControl[i], newAllReducePtr->GetInControlAnchor()), |
|
|
|
|
|
"FusionAllReducePass FAILED: add edge from out data to incontrol " |
|
|
|
|
|
"failed"); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Link the outputDataAnchor |
|
|
|
|
|
for (uint32_t i = 0; i < fusionOpPeerInDataAnchor.size(); i++) { |
|
|
|
|
|
auto peerInDataAnchor = fusionOpPeerInDataAnchor[i].second; |
|
|
|
|
|
GE_CHK_STATUS_RET( |
|
|
|
|
|
GraphUtils::AddEdge(newAllReducePtr->GetOutDataAnchor(fusionOpPeerInDataAnchor[i].first), peerInDataAnchor), |
|
|
|
|
|
"FusionAllReducePass FAILED: add output data edge failed"); |
|
|
|
|
|
} |
|
|
|
|
|
for (uint32_t i = 0; i < fusionOpPeerInControlFromOutData.size(); i++) { |
|
|
|
|
|
auto peerInControlAnchor = fusionOpPeerInControlFromOutData[i].second; |
|
|
|
|
|
GE_CHK_STATUS_RET( |
|
|
|
|
|
GraphUtils::AddEdge(newAllReducePtr->GetOutDataAnchor(fusionOpPeerInControlFromOutData[i].first), |
|
|
|
|
|
peerInControlAnchor), |
|
|
|
|
|
"FusionAllReducePass FAILED: add edge from out data to in control " |
|
|
|
|
|
"failed"); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Link the outputControlAnchor |
|
|
|
|
|
for (uint32_t i = 0; i < fusionOpPeerInControlAnchor.size(); i++) { |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::AddEdge(newAllReducePtr->GetOutControlAnchor(), fusionOpPeerInControlAnchor[i]), |
|
|
|
|
|
"FusionAllReducePass FAILED: add output control edge failed"); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Link the newAllreduce |
|
|
|
|
|
if (segmentIdx > 0 && previousNewAllreduceOutControlAnchor != nullptr) { |
|
|
|
|
|
GE_CHK_STATUS_RET( |
|
|
|
|
|
GraphUtils::AddEdge(previousNewAllreduceOutControlAnchor, newAllReducePtr->GetInControlAnchor()), |
|
|
|
|
|
"FusionAllReducePass FAILED: add input previous control edge failed"); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
previousNewAllreduceOutControlAnchor = newAllReducePtr->GetOutControlAnchor(); |
|
|
|
|
|
start = end + 1; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status AllReducePass::GetPeerOutDataToInData(std::unordered_set<void *> &anchorSet, |
|
|
|
|
|
vector<ge::OutDataAnchorPtr> &peerOutDataAnchorVec, |
|
|
|
|
|
ge::NodePtr &srcNodePtr) { |
|
|
|
|
|
for (auto inDataAnchor : srcNodePtr->GetAllInDataAnchors()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(inDataAnchor == nullptr, continue;); |
|
|
|
|
|
OutDataAnchorPtr peerOutDataAnchor = inDataAnchor->GetPeerOutAnchor(); |
|
|
|
|
|
GE_IF_BOOL_EXEC(peerOutDataAnchor == nullptr, continue;); |
|
|
|
|
|
if (anchorSet.count(peerOutDataAnchor.get()) == 0) { |
|
|
|
|
|
peerOutDataAnchorVec.push_back(peerOutDataAnchor); |
|
|
|
|
|
anchorSet.insert(peerOutDataAnchor.get()); |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutDataAnchor, inDataAnchor)); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status AllReducePass::GetPeerInAnchorToOutData( |
|
|
|
|
|
std::unordered_set<void *> &anchorSet, std::vector<std::pair<int, ge::InDataAnchorPtr>> &fusionOpPeerInDataAnchor, |
|
|
|
|
|
std::vector<std::pair<int, ge::InControlAnchorPtr>> &fusionOpPeerInControlFromOutData, ge::NodePtr &srcNodePtr) { |
|
|
|
|
|
for (auto outDataAnchor : srcNodePtr->GetAllOutDataAnchors()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(outDataAnchor == nullptr, continue;); |
|
|
|
|
|
for (auto peerInDataAnchor : outDataAnchor->GetPeerInDataAnchors()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(peerInDataAnchor == nullptr, continue;); |
|
|
|
|
|
if (anchorSet.count(peerInDataAnchor.get()) == 0) { |
|
|
|
|
|
std::pair<int, ge::InDataAnchorPtr> pairPeerInDataAnchor; |
|
|
|
|
|
pairPeerInDataAnchor.first = 0; |
|
|
|
|
|
pairPeerInDataAnchor.second = peerInDataAnchor; |
|
|
|
|
|
fusionOpPeerInDataAnchor.push_back(pairPeerInDataAnchor); |
|
|
|
|
|
anchorSet.insert(peerInDataAnchor.get()); |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInDataAnchor)); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
for (auto peerInControlAnchorFromData : outDataAnchor->GetPeerInControlAnchors()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(peerInControlAnchorFromData == nullptr, continue;); |
|
|
|
|
|
if (anchorSet.count(peerInControlAnchorFromData.get()) == 0) { |
|
|
|
|
|
std::pair<uint32_t, ge::InControlAnchorPtr> pairPeerInControlAnchorFromData; |
|
|
|
|
|
pairPeerInControlAnchorFromData.first = 0; |
|
|
|
|
|
pairPeerInControlAnchorFromData.second = peerInControlAnchorFromData; |
|
|
|
|
|
fusionOpPeerInControlFromOutData.push_back(pairPeerInControlAnchorFromData); |
|
|
|
|
|
anchorSet.insert(peerInControlAnchorFromData.get()); |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInControlAnchorFromData)); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status AllReducePass::GetPeerOutDataToInData(std::unordered_set<void *> &anchorSet, |
|
|
|
|
|
vector<ge::OutDataAnchorPtr> &peerOutDataAnchorVec, |
|
|
|
|
|
ge::NodePtr &srcNodePtr, ge::OpDescPtr &dstOpDescPtr) { |
|
|
|
|
|
for (auto inDataAnchor : srcNodePtr->GetAllInDataAnchors()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(inDataAnchor == nullptr, continue;); |
|
|
|
|
|
OutDataAnchorPtr peerOutDataAnchor = inDataAnchor->GetPeerOutAnchor(); |
|
|
|
|
|
GE_IF_BOOL_EXEC(peerOutDataAnchor == nullptr, continue;); |
|
|
|
|
|
if (anchorSet.count(peerOutDataAnchor.get()) == 0) { |
|
|
|
|
|
peerOutDataAnchorVec.push_back(peerOutDataAnchor); |
|
|
|
|
|
anchorSet.insert(peerOutDataAnchor.get()); |
|
|
|
|
|
if (dstOpDescPtr->AddInputDesc(inDataAnchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(inDataAnchor->GetIdx())) != |
|
|
|
|
|
ge::GRAPH_SUCCESS) { |
|
|
|
|
|
GELOGW("GetPeerOutDataToInData: AddInputDesc failed"); |
|
|
|
|
|
} |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutDataAnchor, inDataAnchor)); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status AllReducePass::GetPeerOutDataToInControl(std::unordered_set<void *> &anchorSet, |
|
|
|
|
|
vector<ge::OutDataAnchorPtr> &peerOutDataToInControlVec, |
|
|
|
|
|
ge::NodePtr &srcNodePtr) { |
|
|
|
|
|
InControlAnchorPtr inControlAnchor = srcNodePtr->GetInControlAnchor(); |
|
|
|
|
|
GE_CHECK_NOTNULL(inControlAnchor); |
|
|
|
|
|
for (auto peerOutDataToInControl : inControlAnchor->GetPeerOutDataAnchors()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(peerOutDataToInControl == nullptr, continue;); |
|
|
|
|
|
if (anchorSet.count(peerOutDataToInControl.get()) == 0) { |
|
|
|
|
|
peerOutDataToInControlVec.push_back(peerOutDataToInControl); |
|
|
|
|
|
anchorSet.insert(peerOutDataToInControl.get()); |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutDataToInControl, inControlAnchor)); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status AllReducePass::GetPeerOutControlToInControl(std::unordered_set<void *> &anchorSet, |
|
|
|
|
|
vector<ge::OutControlAnchorPtr> &peerOutControlToInControlVec, |
|
|
|
|
|
ge::NodePtr &srcNodePtr) { |
|
|
|
|
|
InControlAnchorPtr inControlAnchor = srcNodePtr->GetInControlAnchor(); |
|
|
|
|
|
GE_CHECK_NOTNULL(inControlAnchor); |
|
|
|
|
|
for (auto peerOutControlAnchor : inControlAnchor->GetPeerOutControlAnchors()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(peerOutControlAnchor == nullptr, continue;); |
|
|
|
|
|
if (anchorSet.count(peerOutControlAnchor.get()) == 0) { |
|
|
|
|
|
peerOutControlToInControlVec.push_back(peerOutControlAnchor); |
|
|
|
|
|
anchorSet.insert(peerOutControlAnchor.get()); |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutControlAnchor, inControlAnchor)); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status AllReducePass::GetPeerAnchorFromOutData( |
|
|
|
|
|
std::unordered_set<void *> &anchorSet, vector<std::pair<int, ge::InDataAnchorPtr>> &peerInDataFromOutDataVec, |
|
|
|
|
|
vector<std::pair<int, ge::InControlAnchorPtr>> &peerInControlFromOutDataVec, ge::NodePtr &srcNodePtr, |
|
|
|
|
|
ge::OpDescPtr &dstOpDescPtr, int &index) { |
|
|
|
|
|
for (auto outDataAnchor : srcNodePtr->GetAllOutDataAnchors()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(outDataAnchor == nullptr, continue;) |
|
|
|
|
|
if (outDataAnchor->GetPeerInDataAnchors().size() > 0 || outDataAnchor->GetPeerInControlAnchors().size() > 0) { |
|
|
|
|
|
if (dstOpDescPtr->AddOutputDesc( |
|
|
|
|
|
outDataAnchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(outDataAnchor->GetIdx())) != ge::GRAPH_SUCCESS) { |
|
|
|
|
|
GELOGW("GetPeerAnchorFromOutData: AddOutputDesc failed"); |
|
|
|
|
|
} |
|
|
|
|
|
index++; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
for (auto peerInDataAnchor : outDataAnchor->GetPeerInDataAnchors()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(peerInDataAnchor == nullptr, continue;) |
|
|
|
|
|
if (anchorSet.count(peerInDataAnchor.get()) == 0) { |
|
|
|
|
|
std::pair<int, ge::InDataAnchorPtr> pairPeerInDataAnchor; |
|
|
|
|
|
pairPeerInDataAnchor.first = index; |
|
|
|
|
|
pairPeerInDataAnchor.second = peerInDataAnchor; |
|
|
|
|
|
peerInDataFromOutDataVec.push_back(pairPeerInDataAnchor); |
|
|
|
|
|
anchorSet.insert(peerInDataAnchor.get()); |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInDataAnchor)) |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
for (auto peerInControlAnchorFromData : outDataAnchor->GetPeerInControlAnchors()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(peerInControlAnchorFromData == nullptr, continue;) |
|
|
|
|
|
if (anchorSet.count(peerInControlAnchorFromData.get()) == 0) { |
|
|
|
|
|
std::pair<int, ge::InControlAnchorPtr> pairPeerInControlAnchorFromData; |
|
|
|
|
|
pairPeerInControlAnchorFromData.first = index; |
|
|
|
|
|
pairPeerInControlAnchorFromData.second = peerInControlAnchorFromData; |
|
|
|
|
|
peerInControlFromOutDataVec.push_back(pairPeerInControlAnchorFromData); |
|
|
|
|
|
anchorSet.insert(peerInControlAnchorFromData.get()); |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInControlAnchorFromData)) |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status AllReducePass::GetPeerInControlFromOutControl(std::unordered_set<void *> &anchorSet, |
|
|
|
|
|
vector<ge::InControlAnchorPtr> &peerInControlFromOutControlVec, |
|
|
|
|
|
ge::NodePtr &srcNodePtr) { |
|
|
|
|
|
OutControlAnchorPtr outControlAnchor = srcNodePtr->GetOutControlAnchor(); |
|
|
|
|
|
GE_CHECK_NOTNULL(outControlAnchor); |
|
|
|
|
|
for (auto peerInControlAnchor : outControlAnchor->GetPeerInControlAnchors()) { |
|
|
|
|
|
GE_IF_BOOL_EXEC(peerInControlAnchor == nullptr, continue;) |
|
|
|
|
|
if (anchorSet.count(peerInControlAnchor.get()) == 0) { |
|
|
|
|
|
peerInControlFromOutControlVec.push_back(peerInControlAnchor); |
|
|
|
|
|
anchorSet.insert(peerInControlAnchor.get()); |
|
|
|
|
|
GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outControlAnchor, peerInControlAnchor)) |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
} // namespace ge |
|
|
|