/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "./mds_pass.h"

namespace ge {
/**
 * Entry point of the model-deploy-scheduler (MDS) pass.
 *
 * Stores `graph` in compute_graph_ and returns SUCCESS immediately when MdsUtils::IsMDSNeeded()
 * is false. Otherwise runs, in order: SMDPProcess() (before cut), CutProcess(),
 * SMDPProcess(false) (after cut), SwapProcess(), PiplineProcess() and SetDeployInfo().
 * NOTE(review): MDS_REQUIRE_SUCCESS is defined elsewhere; it is assumed to log the message and
 * return early on a non-SUCCESS status, so the first failing stage aborts the pass.
 *
 * @param graph root compute graph to process; must not be null.
 * @return SUCCESS when all stages succeed (or MDS is not needed), otherwise a failure status.
 */
Status ModelDeploySchedulerPass::Run(ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  compute_graph_ = graph;
  if (!MdsUtils::IsMDSNeeded()) {
    return SUCCESS;
  }
  GELOGI("[MDS][%s] start to deploy.", GetGraphName());
  MDS_REQUIRE_SUCCESS(SMDPProcess(), "[MDS][SMDPProcess] failed, graph_name:[%s]", GetGraphName());
  MDS_REQUIRE_SUCCESS(CutProcess(), "[MDS][CutProcess] failed, graph_name:[%s]", GetGraphName());
  MDS_REQUIRE_SUCCESS(SMDPProcess(false), "[MDS][SMDPProcess] failed, graph_name:[%s]", GetGraphName());
  MDS_REQUIRE_SUCCESS(SwapProcess(), "[MDS][SwapProcess] failed, graph_name:[%s]", GetGraphName());
  MDS_REQUIRE_SUCCESS(PiplineProcess(), "[MDS][PiplineProcess] failed, graph_name:[%s]", GetGraphName());
  MDS_REQUIRE_SUCCESS(SetDeployInfo(), "[MDS][SetDeployInfo] failed, graph_name:[%s]", GetGraphName());
  // Use the same name accessor as every other log line in this pass.
  GELOGI("[MDS][%s] deploy successfully.", GetGraphName());
  return SUCCESS;
}

/**
 * Cuts compute_graph_ according to the type returned by MdsUtils::TryGetGraphCutType().
 *
 * Graphs that own subgraphs, or that are themselves subgraphs (have a parent graph), are skipped
 * with a warning. kCutN / kCutH run the static cut, kDynamicCutN / kDynamicCutH the dynamic one,
 * and kDynamicCutAll runs DynamicCutAll(). Any other type is logged and treated as "nothing to cut".
 *
 * @return SUCCESS when cutting finished or was skipped; a failure status if a cut step fails.
 */
Status ModelDeploySchedulerPass::CutProcess() {
  GE_CHECK_NOTNULL(compute_graph_);
  if (!compute_graph_->GetAllSubgraphs().empty() || compute_graph_->GetParentGraph() != nullptr) {
    GELOGW("[MDS][CutProcess] graph with subgraphs is not supported now. graph_name:[%s]", GetGraphName());
    return SUCCESS;
  }
  auto type = MdsUtils::TryGetGraphCutType(compute_graph_);
  switch (type) {
    case kCutN:
      MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
                          GetGraphName());
      break;
    case kCutH:
      MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
                          GetGraphName());
      break;
    case kDynamicCutN:
      MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
                          GetGraphName());
      break;
    case kDynamicCutH:
      MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
                          GetGraphName());
      break;
    case kDynamicCutAll:
      MDS_REQUIRE_SUCCESS(DynamicCutAll(compute_graph_), "[MDS][DynamicCutAll] failed, graph_name:[%s]",
                          GetGraphName());
      break;
    default:
      GELOGI("[MDS][CutProcess] could not cut, just return. graph_name:[%s]", GetGraphName());
      break;
  }
  // Every path must produce a status. Previously the non-default cases `break`-ed out of the
  // switch and flowed off the end of this non-void function, which is undefined behavior.
  return SUCCESS;
}

/**
 * Applies the N-dimension cut to every direct node of `compute_graph`.
 *
 * Each node is cut by the kernel returned by mds_cut_pass::GetKernelByType(), falling back to the
 * generic DeploySchedulerKernel::Instance() when no type-specific kernel exists. Nodes whose op
 * desc carries ATTR_NAME_GRADIENT_NODE == true are appended to grad_compute_nodes_.
 *
 * @param compute_graph graph whose direct nodes are cut; must not be null.
 * @param is_dynamic    when true, DynamicCutN is used instead of CutN.
 * @return SUCCESS, or a failure status if a kernel cut or HcomNodeFusionProcess() fails.
 */
Status ModelDeploySchedulerPass::CutNProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) {
  GE_CHECK_NOTNULL(compute_graph);
  // step 0: Cut
  for (const auto &node : compute_graph->GetDirectNode()) {
    auto op_kernel = mds_cut_pass::GetKernelByType(node);
    if (op_kernel == nullptr) {
      op_kernel = DeploySchedulerKernel::Instance();
    }
    if (is_dynamic) {
      MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutN(node), "[MDS][DYNAMIC_CUTN] failed, node:[%s]",
                          node->GetName().c_str());
    } else {
      MDS_REQUIRE_SUCCESS(op_kernel->CutN(node), "[MDS][CUTN] failed, node:[%s]", node->GetName().c_str());
    }
    bool is_grad_compute_node = false;
    if (ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) &&
        is_grad_compute_node) {
      grad_compute_nodes_.push_back(node);
    }
  }
  // TODO: for single output multi reference insertion allgather, allreduce nodes, do breadth fusion optimization
  // NOTE(review): HcomNodeFusionProcess() takes no graph argument, so it cannot act on
  // `compute_graph` when that is a clone (see DynamicCutAll); confirm once it is implemented.
  MDS_REQUIRE_SUCCESS(HcomNodeFusionProcess(), "[MDS][CUTN][HcomNodeFusionProcess] failed, compute graph:[%s]",
                      compute_graph->GetName().c_str());
  return SUCCESS;
}

/**
 * Applies the H-dimension cut to every direct node of `compute_graph`.
 *
 * Kernel selection is the same as in CutNProcessImply: the type-specific kernel from
 * mds_cut_pass::GetKernelByType(), or DeploySchedulerKernel::Instance() as fallback.
 *
 * @param compute_graph graph whose direct nodes are cut; must not be null.
 * @param is_dynamic    when true, DynamicCutH is used instead of CutH.
 * @return SUCCESS, or a failure status if any kernel cut fails.
 */
Status ModelDeploySchedulerPass::CutHProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) {
  GE_CHECK_NOTNULL(compute_graph);
  for (const auto &node : compute_graph->GetDirectNode()) {
    auto op_kernel = mds_cut_pass::GetKernelByType(node);
    if (op_kernel == nullptr) {
      op_kernel = DeploySchedulerKernel::Instance();
    }
    if (is_dynamic) {
      MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutH(node), "[MDS][DYNAMIC_CUTH] failed, node:[%s]",
                          node->GetName().c_str());
    } else {
      MDS_REQUIRE_SUCCESS(op_kernel->CutH(node), "[MDS][CUTH] failed, node:[%s]", node->GetName().c_str());
    }
  }
  return SUCCESS;
}

/**
 * Clones `compute_graph` twice, applying a dynamic N cut to the first clone and a dynamic H cut to
 * the second. The clones are not yet attached anywhere (see TODO below).
 *
 * @param compute_graph graph to clone; must not be null.
 * @return SUCCESS, or a failure status if cloning or either cut fails.
 */
Status ModelDeploySchedulerPass::DynamicCutAll(const ComputeGraphPtr &compute_graph) {
  GE_CHECK_NOTNULL(compute_graph);
  // Give each clone its own out-parameters so the node lists collected for one copy are not
  // mixed with the other's (assumes CloneGraph appends into these vectors).
  std::vector<NodePtr> input_nodes0;
  std::vector<NodePtr> output_nodes0;
  auto compute_graph0 = GraphUtils::CloneGraph(compute_graph, "", input_nodes0, output_nodes0);
  GE_CHECK_NOTNULL(compute_graph0);
  std::vector<NodePtr> input_nodes1;
  std::vector<NodePtr> output_nodes1;
  auto compute_graph1 = GraphUtils::CloneGraph(compute_graph, "", input_nodes1, output_nodes1);
  GE_CHECK_NOTNULL(compute_graph1);

  // NOTE(review): CutNProcessImply appends gradient nodes of this clone to grad_compute_nodes_.
  MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph0, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
                      compute_graph0->GetName().c_str());
  MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph1, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
                      compute_graph1->GetName().c_str());
  // TODO: Create a case node, put the two graphs under the two branches of case
  return SUCCESS;
}

/**
 * Runs the SMDP sub-stages. Run() calls this with no argument before cutting and with `false`
 * after cutting.
 *
 * @param before_cut true: run SMDPModelState() then SMDPWeight(); false: run SMDPGradient().
 * @return SUCCESS, or a failure status from the first failing sub-stage.
 */
Status ModelDeploySchedulerPass::SMDPProcess(bool before_cut) {
  if (before_cut) {
    MDS_REQUIRE_SUCCESS(SMDPModelState(), "[SMDPProcess][SMDPModelState] failed, graph_name:[%s]", GetGraphName());
    MDS_REQUIRE_SUCCESS(SMDPWeight(), "[SMDPProcess][SMDPWeight] failed, graph_name:[%s]", GetGraphName());
  } else {
    MDS_REQUIRE_SUCCESS(SMDPGradient(), "[SMDPProcess][SMDPGradient] failed, graph_name:[%s]", GetGraphName());
  }
  return SUCCESS;
}

/**
 * Builds one set of graph inputs per device (device ids 0 .. kDeployNumber-1) and records them on
 * compute_graph_ via MdsUtils::SetDeployInfo().
 *
 * For each device, every node from MdsUtils::GetInputNodes() contributes one tensor. The tensor
 * uses the node's first output desc, and its data is a single byte holding the device id.
 * Fails (via REQUIRE) if the graph already carries ATTR_NAME_DEPLOY_INFO.
 *
 * @return the status of MdsUtils::SetDeployInfo(), or a failure status on invalid state.
 */
Status ModelDeploySchedulerPass::SetDeployInfo() {
  GE_CHECK_NOTNULL(compute_graph_);
  vector<GeAttrValue::NAMED_ATTRS> deployInfo;
  REQUIRE(!ge::AttrUtils::GetListNamedAttrs(compute_graph_, ATTR_NAME_DEPLOY_INFO, deployInfo),
          "%s already has deployed before!", GetGraphName());
  std::multimap<int64_t, GraphInputs> deploys;
  for (int64_t j = 0; j < kDeployNumber; j++) {
    int64_t device_id = j;
    GraphInputs graph_inputs;
    // For now, only one input_node in input_nodes
    for (const auto &input_node : MdsUtils::GetInputNodes()) {
      GE_CHECK_NOTNULL(input_node);
      GE_CHECK_NOTNULL(input_node->GetOpDesc());
      GeTensorPtr graph_input = MakeShared<GeTensor>(input_node->GetOpDesc()->GetOutputDesc(0));
      // MakeShared's result is not otherwise checked before use; guard against a null pointer.
      GE_CHECK_NOTNULL(graph_input);
      // NOTE(review): the device id is stored as a single byte; ids above 255 would be truncated,
      // and the byte size may not match the output desc's dtype — confirm intended encoding.
      vector<uint8_t> data{static_cast<uint8_t>(device_id)};
      graph_input->SetData(data);
      graph_inputs.push_back(graph_input);
    }
    deploys.emplace(j, graph_inputs);
  }
  return MdsUtils::SetDeployInfo(compute_graph_, deploys);
}

// Placeholder stage: not implemented yet; always returns SUCCESS.
Status ModelDeploySchedulerPass::SwapProcess() { return SUCCESS; }

// Placeholder stage: not implemented yet; always returns SUCCESS.
Status ModelDeploySchedulerPass::PiplineProcess() { return SUCCESS; }

// Placeholder for HCOM node fusion after an N cut: not implemented yet; always returns SUCCESS.
Status ModelDeploySchedulerPass::HcomNodeFusionProcess() { return SUCCESS; }

// Placeholder SMDP sub-stage (before cut): not implemented yet; always returns SUCCESS.
Status ModelDeploySchedulerPass::SMDPModelState() { return SUCCESS; }

// Placeholder SMDP sub-stage (before cut): not implemented yet; always returns SUCCESS.
Status ModelDeploySchedulerPass::SMDPWeight() { return SUCCESS; }

// Placeholder SMDP sub-stage (after cut): not implemented yet; always returns SUCCESS.
Status ModelDeploySchedulerPass::SMDPGradient() { return SUCCESS; }
}  // namespace ge