| @@ -599,6 +599,7 @@ Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) { | |||||
| is_first_model = false; | is_first_model = false; | ||||
| root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph())); | root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph())); | ||||
| root_model_->SetModelId(cur_model->GetModelId()); | root_model_->SetModelId(cur_model->GetModelId()); | ||||
| root_model_->SetModelName(cur_model->GetName()); | |||||
| model_ = cur_model; | model_ = cur_model; | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -134,7 +134,7 @@ HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) | |||||
| Status HybridModelBuilder::Build() { | Status HybridModelBuilder::Build() { | ||||
| GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | ||||
| hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); | |||||
| hybrid_model_.model_name_ = ge_root_model_->GetModelName(); | |||||
| GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | ||||
| GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); | ||||
| GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), | GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), | ||||
| @@ -277,7 +277,7 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt | |||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| // not care result, if no this attr, stand for the op does not need force infershape | // not care result, if no this attr, stand for the op does not need force infershape | ||||
| (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); | |||||
| (void) AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); | |||||
| GELOGD("node [%s] is need do infershape, flag is %d", | GELOGD("node [%s] is need do infershape, flag is %d", | ||||
| op_desc->GetName().c_str(), | op_desc->GetName().c_str(), | ||||
| node_item.is_need_force_infershape); | node_item.is_need_force_infershape); | ||||
| @@ -42,6 +42,10 @@ class GeRootModel { | |||||
| std::vector<uint32_t> GetAllModelId() const { return model_ids_; } | std::vector<uint32_t> GetAllModelId() const { return model_ids_; } | ||||
| void SetModelName(const std::string &model_name) { model_name_ = model_name; } | |||||
| const std::string &GetModelName() const { return model_name_; } | |||||
| Status CheckIsUnknownShape(bool &is_dynamic_shape); | Status CheckIsUnknownShape(bool &is_dynamic_shape); | ||||
| void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | ||||
| @@ -57,6 +61,7 @@ class GeRootModel { | |||||
| // In multithread online secenario, same graph can owns different davinci_model for for concurrency | // In multithread online secenario, same graph can owns different davinci_model for for concurrency | ||||
| std::vector<uint32_t> model_ids_; | std::vector<uint32_t> model_ids_; | ||||
| bool train_flag_ = false; | bool train_flag_ = false; | ||||
| std::string model_name_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | ||||
| @@ -154,9 +154,11 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | ||||
| ge_root_model->SetModelName("test_name"); | |||||
| HybridModel hybrid_model(ge_root_model); | HybridModel hybrid_model(ge_root_model); | ||||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | HybridModelBuilder hybrid_model_builder(hybrid_model); | ||||
| ASSERT_EQ(hybrid_model_builder.Build(), INTERNAL_ERROR); | |||||
| ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); | ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); | ||||
| } | } | ||||
| @@ -655,4 +657,4 @@ TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) { | |||||
| ASSERT_EQ(builder.ParseDependentInputNodes(*node_item_1, deps), SUCCESS); | ASSERT_EQ(builder.ParseDependentInputNodes(*node_item_1, deps), SUCCESS); | ||||
| ASSERT_TRUE(model.GetNodeItem(node)->has_observer); | ASSERT_TRUE(model.GetNodeItem(node)->has_observer); | ||||
| ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); | ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); | ||||
| } | |||||
| } | |||||