/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "./mds_utils.h"

namespace ge {
namespace {
// Per-thread counters used to build unique suffixes for inserted node names.
thread_local int64_t data_slice_count = 0;
thread_local int64_t data_gather_count = 0;
thread_local int64_t data_reduce_count = 0;
const std::string kPrefix = "mds";
}  // namespace

// Returns the axis index of the N (batch) dimension for the given format,
// or kNInvalidLocation when the format is not one of NCHW/NHWC/CHWN/HWCN.
int64_t MdsUtils::GetNLocation(Format fmt) {
  int64_t loc = kNInvalidLocation;
  switch (fmt) {
    case FORMAT_NCHW:
    case FORMAT_NHWC:
      loc = kNLocation0;
      break;
    case FORMAT_CHWN:
    case FORMAT_HWCN:
      loc = kNLocation3;
      break;
    default:
      GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str());
  }
  return loc;
}

// Returns the axis index of the H dimension for the given format,
// or kHInvalidLocation when the format is not one of NCHW/NHWC/CHWN/HWCN.
int64_t MdsUtils::GetHLocation(Format fmt) {
  int64_t loc = kHInvalidLocation;
  switch (fmt) {
    case FORMAT_HWCN:
      loc = kHLocation0;
      break;
    case FORMAT_NHWC:
    case FORMAT_CHWN:
      loc = kHLocation1;
      break;
    case FORMAT_NCHW:
      loc = kHLocation2;
      // FIX: missing break here fell through to default and logged a
      // spurious "unsupported format" error for FORMAT_NCHW.
      break;
    default:
      GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str());
  }
  return loc;
}

// Maps a CutType to the axis index to cut in the tensor's format.
// Returns kInvalidIndex for unknown cut types or unsupported formats.
int64_t MdsUtils::GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type) {
  Format fmt = ge_tensor_desc->GetFormat();
  switch (type) {
    case kCutN:
      return GetNLocation(fmt);
    case kCutH:
      return GetHLocation(fmt);
    default:;
  }
  GELOGE(FAILED, "[MDS]invalid CutType:%d", type);
  return kInvalidIndex;
}

// Checks whether the tensor can be split along the axis implied by `type`:
// the cut axis must exist, be divisible by kDeployNumber, and the tensor's
// ATTR_NAME_CUT_INFO entry for that axis must allow a split cut.
bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type) {
  if (ge_tensor_desc == nullptr) {
    REPORT_INNER_ERROR("E19999", "invalid input param: tensor is null!");
    GELOGE(FAILED, "[MDS]invalid input param: tensor is null!");
    return false;
  }
  if (type != kCutN && type != kCutH) {
    REPORT_INNER_ERROR("E19999", "invalid CutType:%d", type);
    GELOGE(FAILED, "[MDS]invalid CutType:%d", type);
    return false;
  }
  int64_t cut_index = GetIndexByFormat(ge_tensor_desc, type);
  if (cut_index == kInvalidIndex) {
    REPORT_INNER_ERROR("E19999", "invalid index param:%ld", cut_index);
    // FIX: "[MDS]" was passed as the entire format string with the real
    // message as an extra argument, so the log carried no content.
    GELOGE(FAILED, "[MDS]invalid index param:%ld", cut_index);
    return false;
  }
  auto dims = ge_tensor_desc->GetShape().GetDims();
  if (cut_index < 0 || static_cast<size_t>(cut_index) >= dims.size()) {
    REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type,
                       dims.size());
    GELOGE(FAILED, "[MDS]cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type, dims.size());
    return false;
  }
  if (dims[cut_index] % kDeployNumber != 0) {
    GELOGW("[MDS] cut_index %ld for CutType %d with dim %ld can not deploy", cut_index, type, dims[cut_index]);
    return false;
  }
  // NOTE(review): template argument reconstructed (source was garbled) — verify against header.
  vector<int64_t> cut_support_info;
  if (!(AttrUtils::GetListInt(*ge_tensor_desc, ATTR_NAME_CUT_INFO, cut_support_info))) {
    REPORT_INNER_ERROR("E19999", "call GetlistInt failed");
    GELOGE(FAILED, "[MDS]call GetlistInt failed");
    return false;
  }
  if (static_cast<size_t>(cut_index) >= cut_support_info.size()) {
    REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu",
                       cut_index, type, cut_support_info.size());
    GELOGE(FAILED, "[MDS]cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index, type,
           cut_support_info.size());
    return false;
  }
  if (cut_support_info[cut_index] < kNotSupport || cut_support_info[cut_index] > kAnyCutSupported) {
    REPORT_INNER_ERROR("E19999", "invalid cut info value:%ld", cut_support_info[cut_index]);
    GELOGE(FAILED, "[MDS]invalid cut info value:%ld", cut_support_info[cut_index]);
    return false;
  }
  return (cut_support_info[cut_index] & kSplitCutSupported) != 0;
}

// Shrinks the cut axis of `ge_tensor_desc` by `deploy_number` (integer division),
// i.e. the per-device shape after distributing the tensor.
Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, int64_t deploy_number) {
  GE_CHECK_NOTNULL(ge_tensor_desc);
  REQUIRE(deploy_number != 0, "[DistributedDeploy] failed, deploy_number should not be zero");
  auto index = MdsUtils::GetIndexByFormat(ge_tensor_desc, type);
  auto dims = ge_tensor_desc->GetShape().GetDims();
  REQUIRE(index >= 0 && static_cast<size_t>(index) < dims.size(),
          "[DistributedDeploy] failed, index %ld should less than %zu", index, dims.size());
  auto dim_after_deploy = dims[index] / deploy_number;
  MDS_REQUIRE_SUCCESS(ge_tensor_desc->MutableShape().SetDim(index, dim_after_deploy),
                      "[DistributedDeploy] update shape failed");
  return SUCCESS;
}

// Sets the fission factor (and optionally the communication group) on an HCOM op.
Status MdsUtils::SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, const std::string &group_name) {
  GE_CHECK_NOTNULL(hcom_op);
  REQUIRE(fission_factor > kDefaultFissionFactor, "fission_factor %ld need be bigger than %ld", fission_factor,
          kDefaultFissionFactor);
  REQUIRE(ge::AttrUtils::SetInt(hcom_op, ATTR_NAME_FISSION_FACTOR, fission_factor),
          "Failed to set attr fission_factor %ld for op:%s(%s)", fission_factor, hcom_op->GetName().c_str(),
          hcom_op->GetType().c_str());
  if (!group_name.empty()) {
    REQUIRE(ge::AttrUtils::SetStr(hcom_op, HCOM_ATTR_GROUP, group_name), "Failed to set attr group %s for op:%s(%s)",
            group_name.c_str(), hcom_op->GetName().c_str(), hcom_op->GetType().c_str());
  }
  return SUCCESS;
}

// Decides whether multi-die splitting is needed based on the device-type option.
bool MdsUtils::IsMDSNeeded() {
  std::string device_type;
  // FIX: GetOption returns SUCCESS (0) on success, so the original truthiness
  // test only entered the skip branch when retrieval FAILED; compare against
  // SUCCESS explicitly. NOTE(review): confirm GetOption's return convention.
  if (ge::GetContext().GetOption(ge::OPTION_DEVICE_TYPE, device_type) == SUCCESS &&
      device_type == kDefaultDeviceType) {
    GELOGI("[MDS]device type is %s, skip mds", device_type.c_str());
    return false;
  }
  // TODO: Parse the configuration file of the system to get the sys_config_exe_unit
  std::string sys_config_exe_unit = "DIE";
  return device_type != sys_config_exe_unit;
}

// Builds ATTR_NAME_DEPLOY_INFO for kDeployNumber devices; each entry carries the
// device id, a fixed "MultiMode" device type, the graph name and one graph input
// tensor whose data is the device id.
Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node) {
  GE_CHECK_NOTNULL(compute_graph);
  GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str());
  // build deploy info
  // NOTE(review): template arguments below reconstructed (source was garbled) — verify against header.
  vector<GeAttrValue::NAMED_ATTRS> deploy_info;
  GE_CHECK_NOTNULL(input_node);
  GE_CHECK_NOTNULL(input_node->GetOpDesc());
  for (int64_t j = 0; j < kDeployNumber; j++) {
    int64_t device_id = j;
    GeAttrValue::LIST_TENSOR graph_inputs;
    GeTensorPtr graph_input = MakeShared<GeTensor>(input_node->GetOpDesc()->GetOutputDesc(0));
    GE_CHECK_NOTNULL(graph_input);
    vector<uint8_t> data{static_cast<uint8_t>(device_id)};
    graph_input->SetData(data);
    // For now, only one graph_input
    graph_inputs.push_back(graph_input);
    GeAttrValue::NAMED_ATTRS thread_instance;
    thread_instance.SetName(std::to_string(device_id));
    (void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<int64_t>(device_id));
    // TODO: Change to enumeration from RTS header file
    (void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<std::string>("MultiMode"));
    (void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<std::string>(compute_graph->GetName()));
    (void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(graph_inputs));
    deploy_info.emplace_back(thread_instance);
    // FIX: device_id is int64_t — use %ld, not %d.
    GELOGD("[MDS]%s SetDeployInfo on device id: %ld", compute_graph->GetName().c_str(), device_id);
  }
  // set deploy info
  REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info),
          "Set attr failed for graph %s", compute_graph->GetName().c_str());
  return SUCCESS;
}

// Inspects the graph's input Data nodes and picks a cut strategy:
// - both N and H unknown (-1)            -> kDynamicCutAll (stop scanning)
// - N divisible by kDeployNumber         -> kCutN / kDynamicCutN (stop scanning)
// - H divisible by kDeployNumber         -> kCutH / kDynamicCutH (keep scanning;
//   a later input may still upgrade the choice to an N cut — original behavior kept)
CutType MdsUtils::TryGetGraphCutType(const ComputeGraphPtr &compute_graph) {
  bool is_unknown_graph = false;
  if (GraphUtils::IsUnknownShapeGraph(compute_graph)) {
    GELOGI("Graph %s is unknown shape graph", compute_graph->GetName().c_str());
    is_unknown_graph = true;
  }
  CutType selected_cut_type = kNoCut;
  for (const auto &data : compute_graph->GetInputNodes()) {
    GELOGI("Get graph input %s %s", data->GetName().c_str(), data->GetType().c_str());
    // NOTE(review): indices may be kInvalidIndex for unsupported formats;
    // GetDim is assumed to tolerate out-of-range indices — confirm.
    auto data_n_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutN);
    auto data_n_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_n_index);
    auto data_h_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutH);
    auto data_h_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_h_index);
    if (data_n_dim == -1 && data_h_dim == -1) {
      selected_cut_type = kDynamicCutAll;
      break;
    }
    if (data_n_dim % kDeployNumber == 0) {
      is_unknown_graph ? selected_cut_type = kDynamicCutN : selected_cut_type = kCutN;
      break;
    }
    if (data_h_dim % kDeployNumber == 0) {
      is_unknown_graph ? selected_cut_type = kDynamicCutH : selected_cut_type = kCutH;
    }
  }
  return selected_cut_type;
}

// Builds ATTR_NAME_DEPLOY_INFO from an explicit device_id -> graph-inputs map.
// Only the first entry is marked as needing to return results.
Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph,
                               const std::multimap<int64_t, GeAttrValue::LIST_TENSOR> &deploys,
                               const std::string &device_type) {
  GE_CHECK_NOTNULL(compute_graph);
  GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str());
  // build deploy info
  // NOTE(review): template arguments below reconstructed (source was garbled) — verify against header.
  vector<GeAttrValue::NAMED_ATTRS> deploy_info;
  for (const auto &pair : deploys) {
    int64_t device_id = pair.first;
    GeAttrValue::NAMED_ATTRS thread_instance;
    thread_instance.SetName(std::to_string(device_id));
    // Only the first deployed instance returns results to the caller.
    (void)thread_instance.SetAttr(kAttrNeedReturnResult, GeAttrValue::CreateFrom<bool>(deploy_info.empty()));
    (void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<int64_t>(device_id));
    (void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<std::string>(device_type));
    (void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<std::string>(compute_graph->GetName()));
    (void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(pair.second));
    deploy_info.emplace_back(thread_instance);
    // FIX: device_id is int64_t — use %ld, not %d.
    GELOGD("[MDS]%s SetDeployInfo on device id: %ld", compute_graph->GetName().c_str(), device_id);
  }
  // set deploy info
  REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info),
          "Set attr failed for graph %s", compute_graph->GetName().c_str());
  return SUCCESS;
}

// Inserts an HcomAllGather node between src and dst so each device's slice
// is gathered back into the full tensor.
Status MdsUtils::DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst) {
  auto src_node = src->GetOwnerNode();
  GE_CHECK_NOTNULL(src_node);
  auto dst_node = dst->GetOwnerNode();
  GE_CHECK_NOTNULL(dst_node);
  auto src_graph = src_node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(src_graph);
  std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_gather_count));
  auto hcom_allgather_node =
      AddDynamicInputOutputNode(src_graph, HCOMALLGATHER, HCOMALLGATHER + node_name_suffix, 1, 1);
  GE_CHECK_NOTNULL(hcom_allgather_node);
  MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, hcom_allgather_node),
                      "[DataGather] failed between %s and %s", src_node->GetName().c_str(),
                      dst_node->GetName().c_str());
  MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(hcom_allgather_node->GetOpDesc(), kDeployNumber, kDefaultGroup),
                      "[DataGather]set attr for node for %s(%s) failed", hcom_allgather_node->GetName().c_str(),
                      hcom_allgather_node->GetType().c_str());
  // FIX: the original message claimed "reduction type" (copy-paste from the
  // allreduce path) while the attribute being set is the rank size.
  REQUIRE(ge::AttrUtils::SetInt(hcom_allgather_node->GetOpDesc(), HCOM_ATTR_RANK_SIZE, kDefaultRankSize),
          "Failed to set attr rank size for op:%s(%s)", hcom_allgather_node->GetName().c_str(),
          hcom_allgather_node->GetType().c_str());
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(hcom_allgather_node, false),
                      "[DataGather] %s call infershape failed", hcom_allgather_node->GetName().c_str());
  data_gather_count++;
  return SUCCESS;
}

// gradients->ApplyMomentum
// we want to reduce gradients on different device(die), so graph topo changed to
// gradients->hcomallreducemean->ApplyMomentum; Because 'mean' is not currently supported by hcomallreduce,
// topo will end up like gradients->hcomallreducesum->div->ApplyMomentum
Status MdsUtils::DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst) {
  auto src_node = src->GetOwnerNode();
  GE_CHECK_NOTNULL(src_node);
  auto dst_node = dst->GetOwnerNode();
  GE_CHECK_NOTNULL(dst_node);
  auto src_graph = src_node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(src_graph);
  NodePtr all_reduce_node = nullptr;
  if (NeedInsertHcomAllReduce(src_node, all_reduce_node)) {
    // FIX: the failure message used all_reduce_node->GetName(), which may be
    // null when construction fails; report the source node instead.
    MDS_REQUIRE_SUCCESS(ConstructReduceNode(src_graph, src, dst, all_reduce_node),
                        "[DataReduce] construct allreduce node after %s failed", src_node->GetName().c_str());
    GE_CHECK_NOTNULL(all_reduce_node);
  } else {
    GE_CHECK_NOTNULL(all_reduce_node);
    MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(all_reduce_node->GetOpDesc(), kDeployNumber),
                        "[DataReduce][Modify] set attr for allreduce node for %s failed",
                        all_reduce_node->GetName().c_str());
  }
  return SUCCESS;
}

// tensor t with shape like [n,c,h,w], we want get [0:2/n, c, h, w] and [2/n : n, c, h, w] on different
// device; To achieve this goal, we use slice nodes.
// slice(t, [i * n/2, 0, 0, 0], [n/2, c, h, w]) i=0,1
// slice three input like : t->slice; data(0,1)->mul(n/2)->pack[i*n/2,0,0,0]->slice; const(n,c,h,w)->slice
Status MdsUtils::DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &input_node) {
  auto src_node = src->GetOwnerNode();
  GE_CHECK_NOTNULL(src_node);
  auto dst_node = dst->GetOwnerNode();
  GE_CHECK_NOTNULL(dst_node);
  auto src_graph = src_node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(src_graph);
  if (input_node == nullptr) {
    // Lazily create the shared Data node that carries the device index (0/1).
    std::string input_node_name = std::string(DATA) + "_" + kPrefix + "_" + std::to_string(0);
    input_node = AddSingleInputOutputNode(src_graph, input_node_name, DATA);
    GE_CHECK_NOTNULL(input_node);
    AddInputNode(input_node);
  }
  GeTensorDesc tensor = src_node->GetOpDesc()->GetOutputDesc(src->GetIdx());
  NodePtr slice_node = nullptr;
  MDS_REQUIRE_SUCCESS(ConstructSliceNode(src_graph, tensor, input_node.get(), slice_node),
                      "[DataSlice] construct slice node for %s failed", src_node->GetName().c_str());
  GE_CHECK_NOTNULL(slice_node);
  MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, slice_node), "[DataSlice] failed between %s and %s",
                      src_node->GetName().c_str(), dst_node->GetName().c_str());
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(slice_node, false), "[DataSlice] %s call infer shape failed",
                      slice_node->GetName().c_str());
  return SUCCESS;
}

// Builds the slice sub-structure described above DataSlice:
//   input_node(device idx) -> Mul(idx * n/deploy) -> Pack(offsets) -> Slice
//   Const(sizes) -> Slice
// On success `slice_node` is the Slice node whose input 0 is still unconnected
// (wired by the caller via InsertNodeAfter).
Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *input_node,
                                    NodePtr &slice_node) {
  // NOTE(review): template arguments below reconstructed (source was garbled) — verify against header.
  vector<int64_t> slice_sizes = tensor.GetShape().GetDims();
  REQUIRE(!slice_sizes.empty(), "[ConstructSliceNode] tensor to slice has no dims");
  // TODO: Express with graph structure
  slice_sizes[0] /= kDeployNumber;
  vector<GeTensorPtr> ge_tensors;
  GeTensorDesc ge_tensor_desc;
  ge_tensor_desc.SetDataType(DT_INT64);
  MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_sizes, ge_tensors),
                      "[ConstructTensorDescWithData] failed");
  REQUIRE(!ge_tensors.empty(), "[ConstructTensorDescWithData] returned no tensor");
  GeTensorPtr slice_size_tensor = ge_tensors[0];
  auto const_node_slice_size = AddConstNodeToGraph(slice_size_tensor, src_graph);
  GE_CHECK_NOTNULL(const_node_slice_size);

  vector<int64_t> slice_offset_other_dim{0};
  ge_tensors.clear();
  MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_offset_other_dim, ge_tensors, true),
                      "[ConstructTensorDescWithData] failed");
  REQUIRE(!ge_tensors.empty(), "[ConstructTensorDescWithData] returned no tensor");
  GeTensorPtr slice_offset_tensor = ge_tensors[0];
  auto const_node_slice_offset = AddConstNodeToGraph(slice_offset_tensor, src_graph);
  GE_CHECK_NOTNULL(const_node_slice_offset);

  vector<int64_t> slice_offset_first_dim{slice_sizes[0]};
  ge_tensors.clear();
  MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_offset_first_dim, ge_tensors, true),
                      "[ConstructTensorDescWithData] failed");
  REQUIRE(!ge_tensors.empty(), "[ConstructTensorDescWithData] returned no tensor");
  GeTensorPtr slice_offset_first_dim_tensor = ge_tensors[0];
  auto const_node_slice_offset_first_dim = AddConstNodeToGraph(slice_offset_first_dim_tensor, src_graph);
  GE_CHECK_NOTNULL(const_node_slice_offset_first_dim);

  std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_slice_count));
  NodePtr mul_node = AddDynamicInputOutputNode(src_graph, MUL, MUL + node_name_suffix, 2, 1);
  GE_CHECK_NOTNULL(mul_node);
  GE_CHECK_NOTNULL(input_node);
  MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(input_node->GetOutDataAnchor(0), mul_node->GetInDataAnchor(0)),
                      "[ConstructSliceNode] add edge failed");
  MDS_REQUIRE_SUCCESS(
      GraphUtils::AddEdge(const_node_slice_offset_first_dim->GetOutDataAnchor(0), mul_node->GetInDataAnchor(1)),
      "[ConstructSliceNode] add edge failed");
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(mul_node, false), "[DataSlice] %s call infer shape failed",
                      mul_node->GetName().c_str());

  // Pack the first-dim offset (from Mul) with zeros for every other dim.
  NodePtr pack_node = AddDynamicInputOutputNode(src_graph, PACK, PACK + node_name_suffix, slice_sizes.size(), 1);
  GE_CHECK_NOTNULL(pack_node);
  bool is_first_input = true;
  for (const auto &in_anchor : pack_node->GetAllInDataAnchors()) {
    if (is_first_input) {
      MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(mul_node->GetOutDataAnchor(0), in_anchor),
                          "[ConstructSliceNode] add edge failed");
      is_first_input = false;
    } else {
      MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_offset->GetOutDataAnchor(0), in_anchor),
                          "[ConstructSliceNode] add edge failed");
    }
  }
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(pack_node, false), "[DataSlice] %s call infer shape failed",
                      pack_node->GetName().c_str());

  slice_node = AddDynamicInputOutputNode(src_graph, SLICE, SLICE + node_name_suffix, 3, 1);
  GE_CHECK_NOTNULL(slice_node);
  MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(pack_node->GetOutDataAnchor(0), slice_node->GetInDataAnchor(1)),
                      "[ConstructSliceNode] add edge failed");
  MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_size->GetOutDataAnchor(0), slice_node->GetInDataAnchor(2)),
                      "[ConstructSliceNode] add edge failed");
  ++data_slice_count;
  return SUCCESS;
}

// Creates a node with one input "x" and one output "y" of the given tensor desc
// and adds it to `graph`. Returns nullptr (after reporting) on failure.
NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type,
                                           const GeTensorDesc &tensor) {
  GELOGI("Begin to create op: %s", name.c_str());
  OpDescBuilder op_desc_builder(name, type);
  OpDescPtr op_desc = op_desc_builder.AddInput("x", tensor).AddOutput("y", tensor).Build();
  if (op_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed", name.c_str(), type.c_str());
    GELOGE(FAILED, "[Create][OpDesc] failed, name:%s(%s).", name.c_str(), type.c_str());
    return nullptr;
  }
  NodePtr node = graph->AddNode(op_desc);
  if (node == nullptr) {
    REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(),
                      op_desc->GetType().c_str(), graph->GetName().c_str());
    GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(),
           graph->GetName().c_str());
    return nullptr;
  }
  return node;
}

// Creates a node with `input_num` dynamic inputs "x" and `output_num` dynamic
// outputs "y" and adds it to `graph`. Returns nullptr (after reporting) on failure.
NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const std::string &type,
                                            const std::string &node_name, size_t input_num, size_t output_num) {
  GELOGI("Begin to create op: %s", node_name.c_str());
  OpDescBuilder op_desc_builder(node_name, type);
  OpDescPtr op_desc = op_desc_builder.AddDynamicInput("x", input_num).AddDynamicOutput("y", output_num).Build();
  if (op_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed", node_name.c_str(), type.c_str());
    GELOGE(FAILED, "[Create][OpDesc] failed, name:%s(%s).", node_name.c_str(), type.c_str());
    return nullptr;
  }
  NodePtr node = graph->AddNode(op_desc);
  if (node == nullptr) {
    REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(),
                      op_desc->GetType().c_str(), graph->GetName().c_str());
    GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(),
           graph->GetName().c_str());
    return nullptr;
  }
  return node;
}

// Wraps `tensor` in a Const op and prepends it to `graph`.
// Returns nullptr when the op cannot be created or graph is null.
NodePtr MdsUtils::AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph) {
  auto const_desc = OpDescUtils::CreateConstOp(tensor);
  if (const_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "Create Const op failed");
    GELOGE(OUT_OF_MEMORY, "[Create][ConstOp] failed");
    return nullptr;
  }
  if (graph == nullptr) {
    GELOGW("input param graph is null");
    return nullptr;
  }
  return graph->AddNodeFront(const_desc);
}

// Inserts HcomAllReduce(sum) followed by RealDiv(kDeployNumber) between src and
// dst, emulating an all-reduce mean (see comment above DataReduce).
Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src,
                                     const InDataAnchorPtr &dst, NodePtr &reduce_node) {
  std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_reduce_count));
  reduce_node = AddDynamicInputOutputNode(src_graph, HCOMALLREDUCE, HCOMALLREDUCE + node_name_suffix, 1, 1);
  GE_CHECK_NOTNULL(reduce_node);
  MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, reduce_node),
                      "[DataReduce] failed insert %s between %s and %s", reduce_node->GetName().c_str(),
                      src->GetOwnerNode()->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str());
  MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(reduce_node->GetOpDesc(), kDeployNumber, kDefaultGroup),
                      "[DataReduce][Create] set attr for allreduce node for %s failed",
                      reduce_node->GetName().c_str());
  REQUIRE(ge::AttrUtils::SetStr(reduce_node->GetOpDesc(), HCOM_ATTR_REDUCE_TYPE, kDefaultReduction),
          "Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(),
          reduce_node->GetName().c_str(), reduce_node->GetType().c_str());
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(reduce_node, false), "[DataReduce] %s call infershape failed",
                      reduce_node->GetName().c_str());

  // Divide the summed gradients by the deploy count to obtain the mean.
  auto div_node = AddDynamicInputOutputNode(src_graph, REALDIV, REALDIV + node_name_suffix, 2, 1);
  GE_CHECK_NOTNULL(div_node);
  // NOTE(review): template arguments below reconstructed (source was garbled) — verify against header.
  vector<int64_t> slice_sizes{kDeployNumber};
  vector<GeTensorPtr> ge_tensors;
  GeTensorDesc ge_tensor_desc;
  ge_tensor_desc.SetDataType(DT_INT64);
  MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_sizes, ge_tensors),
                      "[ConstructReduceNode] failed");
  REQUIRE(!ge_tensors.empty(), "[ConstructReduceNode] failed");
  auto const_node_div_input = AddConstNodeToGraph(ge_tensors[0], src_graph);
  GE_CHECK_NOTNULL(const_node_div_input);
  MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_div_input->GetOutDataAnchor(0), div_node->GetInDataAnchor(1)),
                      "[ConstructSliceNode] add edge failed");
  MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(reduce_node->GetOutDataAnchor(0), {dst}, div_node),
                      "[DataReduce] failed insert %s between %s and %s", div_node->GetName().c_str(),
                      reduce_node->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str());
  MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(div_node, false), "[DataReduce] %s call infershape failed",
                      div_node->GetName().c_str());
  return SUCCESS;
}

bool MdsUtils::NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node) {
  // TODO: recognize that the graph is originally a multi-p model, that is, there is already an allreduce node,
  // so there is no need to insert it
  return true;
}
}  // namespace ge