diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index f36c1c0d..b862a7d6 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -3132,10 +3132,10 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { } // Avoid repeatively prerun for graphs owns same graph_id in online inference concurrency if (count > 1 && graph_node->GetBuildFlag()) { + graph_node->Lock(); GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id); // In online inference concurrency senario, graph_node is allowed to be locked for 'count' times graph_node->SetSemSize(count); - graph_node->Lock(); graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index c050875e..d3f00253 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -147,7 +147,6 @@ Status HybridModelBuilder::Build() { GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); hybrid_model_.model_name_ = ge_root_model_->GetModelName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); - GE_CHK_STATUS_RET(CopyGraph(), "[Invoke][CopyGraph] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[Invoke][RecoverGraphUnknownFlag] failed, model_name_:[%s]", GetGraphName()); @@ -175,8 +174,8 @@ Status HybridModelBuilder::BuildForSingleOp() { hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); - const GeModelPtr ge_model = ret[hybrid_model_.root_graph_->GetName()]; - GE_CHK_STATUS_RET(IndexTaskDefs(hybrid_model_.root_graph_, ge_model), + const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()]; + GE_CHK_STATUS_RET(IndexTaskDefs(ge_root_model_->GetRootGraph(), ge_model), "[Invoke][IndexTaskDefs] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(LoadGraph(), "[Invoke][LoadGraph] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(InitWeights(), "[Invoke][InitWeights] failed, model_name_:[%s]", GetGraphName()); @@ -191,29 +190,6 @@ Status HybridModelBuilder::ValidateParams() { return SUCCESS; } -Status HybridModelBuilder::CopyGraph() { - GELOGD("Copy compute graph begin."); - auto root_graph = ge_root_model_->GetRootGraph(); - - ge_root_model_->IncreaseBuildTimes(); - std::string new_graph_name = ge_root_model_->GetRootGraph()->GetName() + "_" + - std::to_string(ge_root_model_->GetBuildTimes()); - ComputeGraphPtr new_root_graph = MakeShared(new_graph_name); - GE_CHECK_NOTNULL(new_root_graph); - int32_t depth = 0; - std::map node_old_2_new; - std::map op_desc_old_2_new; - graphStatus ret = GraphUtils::CopyComputeGraph(root_graph, new_root_graph, node_old_2_new, op_desc_old_2_new, depth); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Copy compute graph failed."); - return GRAPH_FAILED; - } - hybrid_model_.root_graph_ = new_root_graph; - - GELOGD("Copy compute graph[%s] success.", new_graph_name.c_str()); - return SUCCESS; -} - Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { auto op_desc = node->GetOpDesc(); GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), @@ -838,7 +814,7 @@ Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, } Status HybridModelBuilder::LoadGraph() { - auto root_graph = hybrid_model_.root_graph_; + auto root_graph = ge_root_model_->GetRootGraph(); if (!GetContext().GetHostExecFlag()) { std::shared_ptr merged_graph; GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", @@ -852,6 +828,7 @@ Status HybridModelBuilder::LoadGraph() { root_graph->GetAllNodesSize()); } + hybrid_model_.root_graph_ = root_graph; GE_CHK_STATUS_RET(RelinkNextIteration(), "[%s] Relink NextIteration failed", GetGraphName()); // Reset node id by topological order across all subgraphs int64_t index = 0; @@ -900,7 +877,6 @@ Status HybridModelBuilder::LoadGraph() { } for (auto &it : hybrid_model_.known_shape_sub_models_) { auto node_item = MutableNodeItem(it.first); - GE_CHECK_NOTNULL(node_item); AscendString graph_name; GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); @@ -1149,9 +1125,7 @@ Status HybridModelBuilder::InitWeights() { sub_weight_buffer->GetSize()); auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); if (subgraph != ge_root_model_->GetRootGraph()) { - subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_model.first); - } else { - subgraph = hybrid_model_.root_graph_; + subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); } GE_CHECK_NOTNULL(subgraph); hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); @@ -1308,7 +1282,7 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const } Status HybridModelBuilder::IndexTaskDefs() { - const auto &root_graph = hybrid_model_.root_graph_; + const auto root_graph = ge_root_model_->GetRootGraph(); const auto &root_graph_name = root_graph->GetName(); if (SetOutputNameAttr(*root_graph) != SUCCESS) { GELOGW("Set output name attr failed."); @@ -1342,7 +1316,7 @@ Status HybridModelBuilder::IndexTaskDefs() { Status HybridModelBuilder::IndexSpecialNodes() { GELOGD("Start to index special nodes"); - const auto &root_graph = hybrid_model_.root_graph_; + const auto &root_graph = ge_root_model_->GetRootGraph(); for (auto &node : root_graph->GetAllNodes()) { GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); @@ -1497,7 +1471,7 @@ Status HybridModelBuilder::InitRuntimeParams() { runtime_param_.session_id = ret ? static_cast(value) : 0; ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value); runtime_param_.logic_var_base = ret ? static_cast(value) : 0; - runtime_param_.graph_id = hybrid_model_.root_graph_->GetGraphID(); + runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID(); value = 0; for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { (void) ge::AttrUtils::GetInt(it.second, ATTR_MODEL_VAR_SIZE, value); @@ -1634,7 +1608,7 @@ Status HybridModelBuilder::TransAllVarData() { } Status HybridModelBuilder::CopyVarData() { - GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(hybrid_model_.root_graph_, + GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(ge_root_model_->GetRootGraph(), runtime_param_.session_id, hybrid_model_.device_id_), "[Invoke][CopyVarData] failed."); @@ -1717,7 +1691,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem } Status HybridModelBuilder::RecoverGraphUnknownFlag() { - const auto &root_graph = hybrid_model_.root_graph_; + const auto &root_graph = ge_root_model_->GetRootGraph(); for (auto &sub_graph : root_graph->GetAllSubgraphs()) { GE_CHECK_NOTNULL(sub_graph); for (const auto &node : sub_graph->GetDirectNode()) { diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 3ab43b7f..92974441 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -56,7 +56,6 @@ class HybridModelBuilder { Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph); Status ValidateParams(); Status LoadGraph(); - Status CopyGraph(); Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); Status LoadTask(NodeItem &node_item); Status LoadTasks(); diff --git a/ge/model/ge_root_model.h b/ge/model/ge_root_model.h index b6e3d081..9e8e116e 100755 --- a/ge/model/ge_root_model.h +++ b/ge/model/ge_root_model.h @@ -60,10 +60,6 @@ class GeRootModel { bool GetTrainFlag() const { return train_flag_; } - int32_t GetBuildTimes() const { return hybrid_build_times_; } - - void IncreaseBuildTimes() { hybrid_build_times_++; } - private: ComputeGraphPtr root_graph_ = nullptr; std::map subgraph_instance_name_to_model_; @@ -73,7 +69,6 @@ class GeRootModel { bool train_flag_ = false; std::string model_name_; bool is_specific_stream_ = false; - int32_t hybrid_build_times_ = 0; }; } // namespace ge using GeRootModelPtr = std::shared_ptr; diff --git a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc index 827705ae..2dc3b639 100644 --- a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc @@ -249,9 +249,6 @@ TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { graph_context.callback_manager = std::unique_ptr(new CallbackManager()); ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); - auto root_graph = hybrid_model.root_graph_; - switch_t = root_graph->FindNode("switch_t"); - switch_f = root_graph->FindNode("switch_f"); const auto node_it_t = hybrid_model.node_items_.find(switch_t); const auto node_it_f = hybrid_model.node_items_.find(switch_f); ASSERT_NE(hybrid_model.node_items_.end(), node_it_t); diff --git a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc index 95669b73..2ab82350 100644 --- a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc +++ b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc @@ -214,17 +214,11 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { ASSERT_EQ(it->second->frame_index_, index); ASSERT_EQ(it->second->parent_frame_, -1); }; - auto root_graph = hybrid_model.root_graph_; - auto enter1_node = root_graph->FindNode("enter"); - auto active1_node = root_graph->FindNode("active1"); - auto active2_node = root_graph->FindNode("active2"); - auto active3_node = root_graph->FindNode("active3"); - auto output1_node = root_graph->FindNode("net_output"); - TestFrameGroup(enter1_node, control_group_index); - TestFrameGroup(active1_node, control_group_index); - TestFrameGroup(active2_node, control_group_index); - TestFrameGroup(active3_node, control_group_index); - TestFrameGroup(output1_node, -1); + TestFrameGroup(enter1, control_group_index); + TestFrameGroup(active1, control_group_index); + TestFrameGroup(active2, control_group_index); + TestFrameGroup(active3, control_group_index); + TestFrameGroup(output1, -1); engine_mapping.clear(); task_executor.clear(); @@ -352,14 +346,4 @@ EXPECT_EQ(hybrid_model_builder.InitVariableTensors(), SUCCESS); EXPECT_EQ(hybrid_model_builder.hybrid_model_.variable_tensors_.size(), 1); HostMemManager::Instance().var_memory_base_map_.clear(); } - -TEST_F(UtestHybridModelBuilder, copy_graph_success) { -ComputeGraphPtr graph = std::make_shared("test"); -GeRootModelPtr ge_root_model = make_shared(graph); -HybridModel hybrid_model(ge_root_model); -HybridModelBuilder hybrid_model_builder(hybrid_model); - -Status st = hybrid_model_builder.CopyGraph(); -EXPECT_EQ(st, SUCCESS); -} } // namespace ge