From 01dfd8974956525c7977f926a94aee25335c7b8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B6=9B?= Date: Sat, 26 Jun 2021 16:16:34 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!1849=20?= =?UTF-8?q?:=20add=20copy=20graph'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ge/graph/manager/graph_manager.cc | 2 +- ge/hybrid/model/hybrid_model.h | 1 - ge/hybrid/model/hybrid_model_builder.cc | 47 ++++--------------- ge/hybrid/model/hybrid_model_builder.h | 1 - ge/model/ge_root_model.h | 5 -- .../executor/subgraph_executor_unittest.cc | 3 -- .../model/hybrid_model_builder_unittest.cc | 26 ++-------- 7 files changed, 15 insertions(+), 70 deletions(-) diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 01a2e502..66026f8d 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -3131,10 +3131,10 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { } // Avoid repeatively prerun for graphs owns same graph_id in online inference concurrency if (count > 1 && graph_node->GetBuildFlag()) { + graph_node->Lock(); GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id); // In online inference concurrency senario, graph_node is allowed to be locked for 'count' times graph_node->SetSemSize(count); - graph_node->Lock(); graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); diff --git a/ge/hybrid/model/hybrid_model.h b/ge/hybrid/model/hybrid_model.h index 77246e20..9821242a 100644 --- a/ge/hybrid/model/hybrid_model.h +++ b/ge/hybrid/model/hybrid_model.h @@ -147,7 +147,6 @@ class HybridModel { GeRootModelPtr ge_root_model_; std::map input_nodes_; ComputeGraphPtr root_graph_; - ComputeGraphPtr orig_root_graph_; std::map device_variable_nodes_; //lint !e148 std::map host_variable_nodes_; //lint !e148 std::map> variable_tensors_; diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 8f96ea9d..1f68f374 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -147,7 +147,6 @@ Status HybridModelBuilder::Build() { GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); hybrid_model_.model_name_ = ge_root_model_->GetModelName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); - GE_CHK_STATUS_RET(CopyGraph(), "[Invoke][CopyGraph] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[Invoke][RecoverGraphUnknownFlag] failed, model_name_:[%s]", GetGraphName()); @@ -172,12 +171,11 @@ Status HybridModelBuilder::Build() { Status HybridModelBuilder::BuildForSingleOp() { GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); - hybrid_model_.root_graph_ = ge_root_model_->GetRootGraph(); hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); - const GeModelPtr ge_model = ret[hybrid_model_.root_graph_->GetName()]; - GE_CHK_STATUS_RET(IndexTaskDefs(hybrid_model_.root_graph_, ge_model), + const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()]; + GE_CHK_STATUS_RET(IndexTaskDefs(ge_root_model_->GetRootGraph(), ge_model), "[Invoke][IndexTaskDefs] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(LoadGraph(), "[Invoke][LoadGraph] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(InitWeights(), "[Invoke][InitWeights] failed, model_name_:[%s]", GetGraphName()); @@ -192,29 +190,6 @@ Status HybridModelBuilder::ValidateParams() { return SUCCESS; } -Status HybridModelBuilder::CopyGraph() { - GELOGD("Copy compute graph begin."); - auto root_graph = ge_root_model_->GetRootGraph(); - - ge_root_model_->IncreaseBuildTimes(); - std::string new_graph_name = ge_root_model_->GetRootGraph()->GetName() + "_" + - std::to_string(ge_root_model_->GetBuildTimes()); - ComputeGraphPtr new_root_graph = MakeShared(new_graph_name); - GE_CHECK_NOTNULL(new_root_graph); - int32_t depth = 0; - std::map node_old_2_new; - std::map op_desc_old_2_new; - graphStatus ret = GraphUtils::CopyComputeGraph(root_graph, new_root_graph, node_old_2_new, op_desc_old_2_new, depth); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Copy compute graph failed."); - return GRAPH_FAILED; - } - hybrid_model_.root_graph_ = new_root_graph; - - GELOGD("Copy compute graph[%s] success.", new_graph_name.c_str()); - return SUCCESS; -} - Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { auto op_desc = node->GetOpDesc(); GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), @@ -835,13 +810,12 @@ Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, } Status HybridModelBuilder::LoadGraph() { - auto root_graph = hybrid_model_.root_graph_; + auto root_graph = ge_root_model_->GetRootGraph(); if (!GetContext().GetHostExecFlag()) { std::shared_ptr merged_graph; GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), root_graph->GetAllNodesSize()); - hybrid_model_.orig_root_graph_ = root_graph; GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "[Invoke][UnfoldSubgraphs]Failed to unfold subgraphs, model_name_:%s.", GetGraphName()); root_graph = std::move(merged_graph); @@ -899,7 +873,6 @@ Status HybridModelBuilder::LoadGraph() { } for (auto &it : hybrid_model_.known_shape_sub_models_) { auto node_item = MutableNodeItem(it.first); - GE_CHECK_NOTNULL(node_item); AscendString graph_name; GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); @@ -1148,9 +1121,7 @@ Status HybridModelBuilder::InitWeights() { sub_weight_buffer->GetSize()); auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); if (subgraph != ge_root_model_->GetRootGraph()) { - subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_model.first); - } else { - subgraph = hybrid_model_.root_graph_; + subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); } GE_CHECK_NOTNULL(subgraph); hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); @@ -1329,7 +1300,7 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const } Status HybridModelBuilder::IndexTaskDefs() { - const auto &root_graph = hybrid_model_.root_graph_; + const auto root_graph = ge_root_model_->GetRootGraph(); const auto &root_graph_name = root_graph->GetName(); if (SetOutputNameAttr(*root_graph) != SUCCESS) { GELOGW("Set output name attr failed."); @@ -1363,7 +1334,7 @@ Status HybridModelBuilder::IndexTaskDefs() { Status HybridModelBuilder::IndexSpecialNodes() { GELOGD("Start to index special nodes"); - const auto &root_graph = hybrid_model_.root_graph_; + const auto &root_graph = ge_root_model_->GetRootGraph(); for (auto &node : root_graph->GetAllNodes()) { GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); @@ -1518,7 +1489,7 @@ Status HybridModelBuilder::InitRuntimeParams() { runtime_param_.session_id = ret ? static_cast(value) : 0; ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value); runtime_param_.logic_var_base = ret ? static_cast(value) : 0; - runtime_param_.graph_id = hybrid_model_.root_graph_->GetGraphID(); + runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID(); value = 0; for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { (void) ge::AttrUtils::GetInt(it.second, ATTR_MODEL_VAR_SIZE, value); @@ -1655,7 +1626,7 @@ Status HybridModelBuilder::TransAllVarData() { } Status HybridModelBuilder::CopyVarData() { - GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(hybrid_model_.root_graph_, + GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(ge_root_model_->GetRootGraph(), runtime_param_.session_id, hybrid_model_.device_id_), "[Invoke][CopyVarData] failed."); @@ -1738,7 +1709,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem } Status HybridModelBuilder::RecoverGraphUnknownFlag() { - const auto &root_graph = hybrid_model_.root_graph_; + const auto &root_graph = ge_root_model_->GetRootGraph(); for (auto &sub_graph : root_graph->GetAllSubgraphs()) { GE_CHECK_NOTNULL(sub_graph); for (const auto &node : sub_graph->GetDirectNode()) { diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 05830e82..9c1eb187 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -56,7 +56,6 @@ class HybridModelBuilder { Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph); Status ValidateParams(); Status LoadGraph(); - Status CopyGraph(); Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); static Status InitHcclExecutorOnDemand(const GeModelPtr &ge_model); Status LoadTask(NodeItem &node_item); diff --git a/ge/model/ge_root_model.h b/ge/model/ge_root_model.h index b6e3d081..9e8e116e 100755 --- a/ge/model/ge_root_model.h +++ b/ge/model/ge_root_model.h @@ -60,10 +60,6 @@ class GeRootModel { bool GetTrainFlag() const { return train_flag_; } - int32_t GetBuildTimes() const { return hybrid_build_times_; } - - void IncreaseBuildTimes() { hybrid_build_times_++; } - private: ComputeGraphPtr root_graph_ = nullptr; std::map subgraph_instance_name_to_model_; @@ -73,7 +69,6 @@ class GeRootModel { bool train_flag_ = false; std::string model_name_; bool is_specific_stream_ = false; - int32_t hybrid_build_times_ = 0; }; } // namespace ge using GeRootModelPtr = std::shared_ptr; diff --git a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc index 827705ae..2dc3b639 100644 --- a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc @@ -249,9 +249,6 @@ TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { graph_context.callback_manager = std::unique_ptr(new CallbackManager()); ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); - auto root_graph = hybrid_model.root_graph_; - switch_t = root_graph->FindNode("switch_t"); - switch_f = root_graph->FindNode("switch_f"); const auto node_it_t = hybrid_model.node_items_.find(switch_t); const auto node_it_f = hybrid_model.node_items_.find(switch_f); ASSERT_NE(hybrid_model.node_items_.end(), node_it_t); diff --git a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc index 10f7c0fe..5567aca2 100644 --- a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc +++ b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc @@ -214,17 +214,11 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { ASSERT_EQ(it->second->frame_index_, index); ASSERT_EQ(it->second->parent_frame_, -1); }; - auto root_graph = hybrid_model.root_graph_; - auto enter1_node = root_graph->FindNode("enter"); - auto active1_node = root_graph->FindNode("active1"); - auto active2_node = root_graph->FindNode("active2"); - auto active3_node = root_graph->FindNode("active3"); - auto output1_node = root_graph->FindNode("net_output"); - TestFrameGroup(enter1_node, control_group_index); - TestFrameGroup(active1_node, control_group_index); - TestFrameGroup(active2_node, control_group_index); - TestFrameGroup(active3_node, control_group_index); - TestFrameGroup(output1_node, -1); + TestFrameGroup(enter1, control_group_index); + TestFrameGroup(active1, control_group_index); + TestFrameGroup(active2, control_group_index); + TestFrameGroup(active3, control_group_index); + TestFrameGroup(output1, -1); engine_mapping.clear(); task_executor.clear(); @@ -379,14 +373,4 @@ TEST_F(UtestHybridModelBuilder, TestInitHcclExecutorOnDemand) { NodeExecutorManager::GetInstance().builders_.erase(NodeExecutorManager::ExecutorType::HCCL); ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), SUCCESS); } - -TEST_F(UtestHybridModelBuilder, copy_graph_success) { -ComputeGraphPtr graph = std::make_shared("test"); -GeRootModelPtr ge_root_model = make_shared(graph); -HybridModel hybrid_model(ge_root_model); -HybridModelBuilder hybrid_model_builder(hybrid_model); - -Status st = hybrid_model_builder.CopyGraph(); -EXPECT_EQ(st, SUCCESS); -} } // namespace ge