diff --git a/ge/common/helper/model_helper.cc b/ge/common/helper/model_helper.cc index 74238bc1..e95c3429 100644 --- a/ge/common/helper/model_helper.cc +++ b/ge/common/helper/model_helper.cc @@ -599,6 +599,7 @@ Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) { is_first_model = false; root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph())); root_model_->SetModelId(cur_model->GetModelId()); + root_model_->SetModelName(cur_model->GetName()); model_ = cur_model; continue; } diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index edf9eb92..f9ffbaca 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -134,7 +134,7 @@ HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) Status HybridModelBuilder::Build() { GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); - hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); + hybrid_model_.model_name_ = ge_root_model_->GetModelName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), @@ -160,7 +160,7 @@ Status HybridModelBuilder::Build() { Status HybridModelBuilder::BuildForSingleOp() { GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); - hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); + hybrid_model_.model_name_ = ge_root_model_->GetModelName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()]; diff --git a/ge/model/ge_root_model.h b/ge/model/ge_root_model.h index 0747d77c..899be5d6 100755 --- a/ge/model/ge_root_model.h +++ b/ge/model/ge_root_model.h @@ -42,6 +42,10 @@ class GeRootModel { std::vector GetAllModelId() const { return model_ids_; } + void SetModelName(const std::string &model_name) { model_name_ = model_name; } + + const std::string &GetModelName() const { return model_name_; }; + Status CheckIsUnknownShape(bool &is_dynamic_shape); void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } @@ -57,6 +61,7 @@ class GeRootModel { // In multithread online secenario, same graph can owns different davinci_model for for concurrency std::vector model_ids_; bool train_flag_ = false; + std::string model_name_; }; } // namespace ge using GeRootModelPtr = std::shared_ptr;