/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "./base_mds_kernel.h"

namespace ge {
namespace mds_cut_pass {
// Resolves the kernel registered for a node's type. For framework ops, the
// original type is read from the ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE attribute.
shared_ptr<Kernel> GetKernelByType(const NodePtr &node) {
  if (node == nullptr) {
    REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
    GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
    return nullptr;
  }
  KernelFactory &factory = KernelFactory::Instance();
  string type = node->GetType();
  if (type == FRAMEWORKOP) {
    if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) {
      REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(),
                        node->GetName().c_str(), node->GetType().c_str());
      return nullptr;
    }
  }
  return factory.Create(type);
}
}  // namespace mds_cut_pass

shared_ptr<DeploySchedulerKernel> DeploySchedulerKernel::Instance() {
  static const std::shared_ptr<DeploySchedulerKernel> instance_ptr =
      shared_ptr<DeploySchedulerKernel>(new (std::nothrow) DeploySchedulerKernel());
  return instance_ptr;
}

Status DeploySchedulerKernel::CutN(const ge::NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  for (auto &in_anchor : node->GetAllInDataAnchors()) {
    GE_CHECK_NOTNULL(in_anchor);
    auto src_anchor = in_anchor->GetPeerOutAnchor();
    if (src_anchor == nullptr) {
      continue;
    }
    auto tensor_desc = op_desc->MutableInputDesc(in_anchor->GetIdx());
    GE_CHECK_NOTNULL(tensor_desc);
    auto src_node = src_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(src_node);
    auto src_op_desc = src_node->GetOpDesc();
    GE_CHECK_NOTNULL(src_op_desc);
    auto src_tensor_desc = src_op_desc->MutableOutputDesc(src_anchor->GetIdx());
    GE_CHECK_NOTNULL(src_tensor_desc);
    // The peer output shape has already been cut.
    if (MdsUtils::IsDistributedDeploySupported(src_tensor_desc, kCutN)) {
      if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) {
        // Both sides support cutting along N: propagate the cut shape.
        tensor_desc->SetShape(src_tensor_desc->GetShape());
      } else {
        // Producer is cut but this input cannot be: gather the full tensor.
        MDS_REQUIRE_SUCCESS(MdsUtils::DataGather(src_anchor, in_anchor),
                            "[CutN] failed to gather between node[%s][%d] and node[%s][%d]",
                            src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
                            in_anchor->GetIdx());
      }
    } else {
      if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) {
        // Producer is not cut but this input can be: slice the tensor.
        MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_),
                            "[CutN] failed to slice between node[%s][%d] and node[%s][%d]",
                            src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
                            in_anchor->GetIdx());
      } else {
        // Neither side is cut: just keep the shapes consistent.
        tensor_desc->SetShape(src_tensor_desc->GetShape());
      }
    }
    // Insert an HCOM allreduce for gradient-producing nodes cut along N.
    bool is_grad_compute_node = false;
    if (ge::AttrUtils::GetBool(src_node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) &&
        is_grad_compute_node) {
      MDS_REQUIRE_SUCCESS(MdsUtils::DataReduce(src_anchor, in_anchor),
                          "[CutN] failed to reduce between node[%s][%d] and node[%s][%d]",
                          src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
                          in_anchor->GetIdx());
    }
  }
  // Call infer shape to update the output shape.
  MDS_REQUIRE_SUCCESS(node->InferShapeAndType(), "[CutN] %s call infershape failed", node->GetName().c_str());
  return SUCCESS;
}
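// CutH mirrors CutN but cuts along the height (H) axis. A note on the extra
// branch below (my reading, not stated in the original): when both producer
// and consumer are cut along H, an operator's receptive field can straddle the
// cut boundary, so instead of simply propagating the cut shape as CutN does,
// neighbouring slices exchange boundary rows via HaloExchangeProcess (the
// "overlap" mentioned in the error messages).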
Status DeploySchedulerKernel::CutH(const ge::NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  for (auto &in_anchor : node->GetAllInDataAnchors()) {
    GE_CHECK_NOTNULL(in_anchor);
    auto src_anchor = in_anchor->GetPeerOutAnchor();
    if (src_anchor == nullptr) {
      continue;
    }
    auto tensor_desc = op_desc->MutableInputDesc(in_anchor->GetIdx());
    GE_CHECK_NOTNULL(tensor_desc);
    auto src_node = src_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(src_node);
    auto src_op_desc = src_node->GetOpDesc();
    GE_CHECK_NOTNULL(src_op_desc);
    auto src_tensor_desc = src_op_desc->MutableOutputDesc(src_anchor->GetIdx());
    GE_CHECK_NOTNULL(src_tensor_desc);
    // The peer output shape has already been cut.
    if (MdsUtils::IsDistributedDeploySupported(src_tensor_desc, kCutH)) {
      if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) {
        // Both sides are cut along H: exchange halo rows with the neighbour slice.
        MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx()),
                            "[CutH] failed to do overlap between node[%s][%d] and node[%s][%d]",
                            src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
                            in_anchor->GetIdx());
      } else {
        // Producer is cut but this input cannot be: gather the full tensor.
        MDS_REQUIRE_SUCCESS(MdsUtils::DataGather(src_anchor, in_anchor),
                            "[CutH] failed to gather between node[%s][%d] and node[%s][%d]",
                            src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
                            in_anchor->GetIdx());
      }
    } else {
      if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) {
        // Producer is not cut but this input can be: slice the tensor.
        MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_),
                            "[CutH] failed to slice between node[%s][%d] and node[%s][%d]",
                            src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
                            in_anchor->GetIdx());
      } else {
        MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx(), true),
                            "[CutH] failed to do overlap between node[%s][%d] and node[%s][%d]",
                            src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
                            in_anchor->GetIdx());
      }
    }
  }
  // Call infer shape to update the output shape.
  MDS_REQUIRE_SUCCESS(node->InferShapeAndType(), "[CutH] %s call infershape failed", node->GetName().c_str());
  return SUCCESS;
}
}  // namespace ge
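// Usage sketch (illustrative only; `graph` and the calling pass are assumed
// here, not defined in this file): a scheduling pass would typically fetch the
// singleton and apply a cut to each node it has selected for distribution,
// roughly as follows:
//
//   auto kernel = DeploySchedulerKernel::Instance();
//   GE_CHECK_NOTNULL(kernel);
//   for (const auto &node : graph->GetDirectNode()) {
//     MDS_REQUIRE_SUCCESS(kernel->CutN(node), "[Deploy] CutN failed on node %s",
//                         node->GetName().c_str());
//   }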