Browse Source

fix dump

tags/v1.3.0
wjm 3 years ago
parent
commit
9df5c3fc80
3 changed files with 8 additions and 2 deletions
  1. +1
    -0
      ge/common/helper/model_helper.cc
  2. +2
    -2
      ge/hybrid/model/hybrid_model_builder.cc
  3. +5
    -0
      ge/model/ge_root_model.h

+ 1
- 0
ge/common/helper/model_helper.cc View File

@@ -599,6 +599,7 @@ Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) {
is_first_model = false;
root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph()));
root_model_->SetModelId(cur_model->GetModelId());
root_model_->SetModelName(cur_model->GetName());
model_ = cur_model;
continue;
}


+ 2
- 2
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -134,7 +134,7 @@ HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model)

Status HybridModelBuilder::Build() {
GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName());
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName();
hybrid_model_.model_name_ = ge_root_model_->GetModelName();
GELOGI("[%s] Start to build hybrid model.", GetGraphName());
GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName());
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(),
@@ -160,7 +160,7 @@ Status HybridModelBuilder::Build() {

Status HybridModelBuilder::BuildForSingleOp() {
GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName());
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName();
hybrid_model_.model_name_ = ge_root_model_->GetModelName();
GELOGI("[%s] Start to build hybrid model.", GetGraphName());
auto ret = ge_root_model_->GetSubgraphInstanceNameToModel();
const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()];


+ 5
- 0
ge/model/ge_root_model.h View File

@@ -42,6 +42,10 @@ class GeRootModel {

std::vector<uint32_t> GetAllModelId() const { return model_ids_; }

void SetModelName(const std::string &model_name) { model_name_ = model_name; }
const std::string &GetModelName() const { return model_name_; };
Status CheckIsUnknownShape(bool &is_dynamic_shape);

void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; }
@@ -57,6 +61,7 @@ class GeRootModel {
// In multithread online secenario, same graph can owns different davinci_model for for concurrency
std::vector<uint32_t> model_ids_;
bool train_flag_ = false;
std::string model_name_;
};
} // namespace ge
using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>;


Loading…
Cancel
Save