/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "./mds_pass.h"

namespace ge {
// Entry point of the pass.
// Stores `graph` in compute_graph_ and, when MdsUtils::IsMDSNeeded() is true, runs the
// deployment steps in this fixed order: SMDPProcess() -> CutProcess() -> SMDPProcess(false)
// -> SwapProcess() -> PiplineProcess() -> SetDeployInfo().
// Returns SUCCESS when MDS is not needed or when every step succeeds; each step is wrapped in
// MDS_REQUIRE_SUCCESS (presumably logs and returns the failing status - see its definition).
Status ModelDeploySchedulerPass::Run(ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  compute_graph_ = graph;
  if (!MdsUtils::IsMDSNeeded()) {
    return SUCCESS;
  }
  GELOGI("[MDS][%s] start to deploy.", GetGraphName());
  // First SMDP call relies on the default argument declared in the header
  // (presumably before_cut = true - confirm in mds_pass.h).
  MDS_REQUIRE_SUCCESS(SMDPProcess(), "[MDS][SMDPProcess] failed, graph_name:[%s]", GetGraphName());
  MDS_REQUIRE_SUCCESS(CutProcess(), "[MDS][CutProcess] failed, graph_name:[%s]", GetGraphName());
  MDS_REQUIRE_SUCCESS(SMDPProcess(false), "[MDS][SMDPProcess] failed, graph_name:[%s]", GetGraphName());
  MDS_REQUIRE_SUCCESS(SwapProcess(), "[MDS][SwapProcess] failed, graph_name:[%s]", GetGraphName());
  MDS_REQUIRE_SUCCESS(PiplineProcess(), "[MDS][PiplineProcess] failed, graph_name:[%s]", GetGraphName());
  MDS_REQUIRE_SUCCESS(SetDeployInfo(), "[MDS][SetDeployInfo] failed, graph_name:[%s]", GetGraphName());
  GELOGI("[MDS][%s] deploy successfully.", graph->GetName().c_str());
  return SUCCESS;
}

// Dispatches to the cut implementation selected by MdsUtils::TryGetGraphCutType().
// Graphs that own subgraphs, or that have a parent graph, are skipped with a warning and
// SUCCESS is returned. An unrecognised cut type is logged and also treated as SUCCESS.
Status ModelDeploySchedulerPass::CutProcess() {
  GE_CHECK_NOTNULL(compute_graph_);
  if (!compute_graph_->GetAllSubgraphs().empty() || compute_graph_->GetParentGraph() != nullptr) {
    GELOGW("[MDS][CutProcess] graph with subgraphs is not supported now. graph_name:[%s]", GetGraphName());
    return SUCCESS;
  }
  auto type = MdsUtils::TryGetGraphCutType(compute_graph_);
  switch (type) {
    case kCutN:
      MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
                          GetGraphName());
      break;
    case kCutH:
      MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
                          GetGraphName());
      break;
    case kDynamicCutN:
      MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
                          GetGraphName());
      break;
    case kDynamicCutH:
      MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
                          GetGraphName());
      break;
    case kDynamicCutAll:
      MDS_REQUIRE_SUCCESS(DynamicCutAll(compute_graph_), "[MDS][DynamicCutAll] failed, graph_name:[%s]",
                          GetGraphName());
      break;
    default:
      GELOGI("[MDS][CutProcess] could not cut, just return. graph_name:[%s]", GetGraphName());
      break;
  }
  // Every case breaks out of the switch, so the function must return here; without this
  // statement control would flow off the end of a value-returning function (undefined behaviour).
  return SUCCESS;
}

// Applies the N cut to every direct node of `compute_graph`.
// For each node the kernel from mds_cut_pass::GetKernelByType() is used, falling back to
// DeploySchedulerKernel::Instance() when none is registered. `is_dynamic` selects DynamicCutN
// over CutN. Nodes whose ATTR_NAME_GRADIENT_NODE attribute is true are appended to
// grad_compute_nodes_. HcomNodeFusionProcess() runs once after all nodes are processed.
Status ModelDeploySchedulerPass::CutNProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) {
  GE_CHECK_NOTNULL(compute_graph);
  // step 0: Cut
  for (const auto &node : compute_graph->GetDirectNode()) {
    auto op_kernel = mds_cut_pass::GetKernelByType(node);
    if (op_kernel == nullptr) {
      op_kernel = DeploySchedulerKernel::Instance();
    }
    if (is_dynamic) {
      MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutN(node), "[MDS][DYNAMIC_CUTN] failed, node:[%s]",
                          node->GetName().c_str());
    } else {
      MDS_REQUIRE_SUCCESS(op_kernel->CutN(node), "[MDS][CUTN] failed, node:[%s]", node->GetName().c_str());
    }
    bool is_grad_compute_node = false;
    if (ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) &&
        is_grad_compute_node) {
      grad_compute_nodes_.push_back(node);
    }
  }
  // TODO: apply breadth-fusion optimization to the allgather/allreduce nodes that were inserted
  // for single-output, multi-reference cases.
  MDS_REQUIRE_SUCCESS(HcomNodeFusionProcess(), "[MDS][CUTN][HcomNodeFusionProcess] failed, compute graph:[%s]",
                      compute_graph->GetName().c_str());
  return SUCCESS;
}

// Applies the H cut to every direct node of `compute_graph`.
// Kernel selection mirrors CutNProcessImply: GetKernelByType() first, then the
// DeploySchedulerKernel singleton. `is_dynamic` selects DynamicCutH over CutH.
Status ModelDeploySchedulerPass::CutHProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) {
  GE_CHECK_NOTNULL(compute_graph);
  for (NodePtr &node : compute_graph->GetDirectNode()) {
    auto op_kernel = mds_cut_pass::GetKernelByType(node);
    if (op_kernel == nullptr) {
      op_kernel = DeploySchedulerKernel::Instance();
    }
    if (is_dynamic) {
      MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutH(node), "[MDS][DYNAMIC_CUTH] failed, node:[%s]",
                          node->GetName().c_str());
    } else {
      MDS_REQUIRE_SUCCESS(op_kernel->CutH(node), "[MDS][CUTH] failed, node:[%s]", node->GetName().c_str());
    }
  }
  return SUCCESS;
}

// Clones `compute_graph` twice and applies the dynamic N cut to the first clone and the
// dynamic H cut to the second. The clones are not yet attached to the original graph (see TODO).
Status ModelDeploySchedulerPass::DynamicCutAll(const ComputeGraphPtr &compute_graph) {
  GE_CHECK_NOTNULL(compute_graph);
  // NOTE(review): element type inferred from the CloneGraph call - confirm against GraphUtils.
  std::vector<NodePtr> input_nodes;
  std::vector<NodePtr> output_nodes;
  auto compute_graph0 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes);
  // The failure message below dereferences the clone, so reject a null clone up front.
  GE_CHECK_NOTNULL(compute_graph0);
  auto compute_graph1 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes);
  GE_CHECK_NOTNULL(compute_graph1);
  MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph0, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
                      compute_graph0->GetName().c_str());
  MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph1, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
                      compute_graph1->GetName().c_str());
  // TODO: create a case node, place the two graphs under its two branches, add the case node to
  // the original compute_graph, and construct the case node's inputs.
  return SUCCESS;
}

// Runs the SMDP steps for the requested phase.
// before_cut == true  : SMDPModelState() followed by SMDPWeight().
// before_cut == false : SMDPGradient().
Status ModelDeploySchedulerPass::SMDPProcess(bool before_cut) {
  if (before_cut) {
    MDS_REQUIRE_SUCCESS(SMDPModelState(), "[SMDPProcess][SMDPModelState] failed, graph_name:[%s]", GetGraphName());
    MDS_REQUIRE_SUCCESS(SMDPWeight(), "[SMDPProcess][SMDPWeight] failed, graph_name:[%s]", GetGraphName());
  } else {
    MDS_REQUIRE_SUCCESS(SMDPGradient(), "[SMDPProcess][SMDPGradient] failed, graph_name:[%s]", GetGraphName());
  }
  return SUCCESS;
}

// Builds one deploy entry per index in [0, kDeployNumber) and hands the result to
// MdsUtils::SetDeployInfo(). Each entry holds, for every node returned by
// MdsUtils::GetInputNodes(), a tensor created from that node's output desc 0 whose data is a
// single byte carrying the device id. The REQUIRE guards against a graph that already carries
// ATTR_NAME_DEPLOY_INFO.
Status ModelDeploySchedulerPass::SetDeployInfo() {
  // NOTE(review): element type inferred from AttrUtils::GetListNamedAttrs - confirm.
  std::vector<GeAttrValue::NAMED_ATTRS> deploy_info;
  REQUIRE(!ge::AttrUtils::GetListNamedAttrs(compute_graph_, ATTR_NAME_DEPLOY_INFO, deploy_info),
          "%s already has deployed before!", GetGraphName());
  std::multimap<int64_t, GraphInputs> deploys;
  for (int64_t j = 0; j < kDeployNumber; j++) {
    int64_t device_id = j;
    GraphInputs graph_inputs;
    // For now, only one input_node in input_nodes
    for (const auto &input_node : MdsUtils::GetInputNodes()) {
      GE_CHECK_NOTNULL(input_node);
      const auto op_desc = input_node->GetOpDesc();
      GE_CHECK_NOTNULL(op_desc);
      // NOTE(review): GeTensor / uint8_t inferred from GeTensorPtr and SetData usage - confirm.
      GeTensorPtr graph_input = MakeShared<GeTensor>(op_desc->GetOutputDesc(0));
      // Checked before use in case the allocation helper yields a null pointer.
      GE_CHECK_NOTNULL(graph_input);
      std::vector<uint8_t> data{static_cast<uint8_t>(device_id)};
      graph_input->SetData(data);
      graph_inputs.push_back(graph_input);
    }
    deploys.emplace(j, graph_inputs);
  }
  return MdsUtils::SetDeployInfo(compute_graph_, deploys);
}

// Placeholder: currently does nothing and reports SUCCESS.
Status ModelDeploySchedulerPass::SwapProcess() {
  return SUCCESS;
}

// Placeholder: currently does nothing and reports SUCCESS.
Status ModelDeploySchedulerPass::PiplineProcess() {
  return SUCCESS;
}

// Placeholder: currently does nothing and reports SUCCESS.
Status ModelDeploySchedulerPass::HcomNodeFusionProcess() {
  return SUCCESS;
}

// Placeholder: currently does nothing and reports SUCCESS.
Status ModelDeploySchedulerPass::SMDPModelState() {
  return SUCCESS;
}

// Placeholder: currently does nothing and reports SUCCESS.
Status ModelDeploySchedulerPass::SMDPWeight() {
  return SUCCESS;
}

// Placeholder: currently does nothing and reports SUCCESS.
Status ModelDeploySchedulerPass::SMDPGradient() {
  // TODO: mark the buffer pool id.
  return SUCCESS;
}
}  // namespace ge