diff --git a/ge/common/helper/model_helper.cc b/ge/common/helper/model_helper.cc index 74238bc1..e95c3429 100644 --- a/ge/common/helper/model_helper.cc +++ b/ge/common/helper/model_helper.cc @@ -599,6 +599,7 @@ Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) { is_first_model = false; root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph())); root_model_->SetModelId(cur_model->GetModelId()); + root_model_->SetModelName(cur_model->GetName()); model_ = cur_model; continue; } diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index edf9eb92..9b3cb692 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -134,7 +134,7 @@ HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) Status HybridModelBuilder::Build() { GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); - hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); + hybrid_model_.model_name_ = ge_root_model_->GetModelName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), @@ -277,7 +277,7 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); // not care result, if no this attr, stand for the op does not need force infershape - (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); + (void) AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); GELOGD("node [%s] is need do infershape, flag is %d", op_desc->GetName().c_str(), node_item.is_need_force_infershape); diff --git a/ge/model/ge_root_model.h b/ge/model/ge_root_model.h index 0747d77c..8c44272d 100755 --- a/ge/model/ge_root_model.h +++ b/ge/model/ge_root_model.h @@ -42,6 +42,10 @@ class GeRootModel { std::vector GetAllModelId() const { return model_ids_; } + void SetModelName(const std::string &model_name) { model_name_ = model_name; } + + const std::string &GetModelName() const { return model_name_; } + Status CheckIsUnknownShape(bool &is_dynamic_shape); void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } @@ -57,6 +61,7 @@ class GeRootModel { // In multithread online secenario, same graph can owns different davinci_model for for concurrency std::vector model_ids_; bool train_flag_ = false; + std::string model_name_; }; } // namespace ge using GeRootModelPtr = std::shared_ptr; diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 25115340..b5aac527 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -154,9 +154,11 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) { ComputeGraphPtr graph = std::make_shared("test"); GeRootModelPtr ge_root_model = make_shared(graph); + ge_root_model->SetModelName("test_name"); HybridModel hybrid_model(ge_root_model); HybridModelBuilder hybrid_model_builder(hybrid_model); + ASSERT_EQ(hybrid_model_builder.Build(), INTERNAL_ERROR); ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); } @@ -655,4 +657,4 @@ TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) { ASSERT_EQ(builder.ParseDependentInputNodes(*node_item_1, deps), SUCCESS); ASSERT_TRUE(model.GetNodeItem(node)->has_observer); ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); -} \ No newline at end of file +}