@@ -599,6 +599,7 @@ Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) { | |||||
is_first_model = false; | is_first_model = false; | ||||
root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph())); | root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph())); | ||||
root_model_->SetModelId(cur_model->GetModelId()); | root_model_->SetModelId(cur_model->GetModelId()); | ||||
root_model_->SetModelName(cur_model->GetName()); | |||||
model_ = cur_model; | model_ = cur_model; | ||||
continue; | continue; | ||||
} | } | ||||
@@ -134,7 +134,7 @@ HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) | |||||
Status HybridModelBuilder::Build() { | Status HybridModelBuilder::Build() { | ||||
GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | ||||
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); | |||||
hybrid_model_.model_name_ = ge_root_model_->GetModelName(); | |||||
GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | ||||
GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); | ||||
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), | GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), | ||||
@@ -277,7 +277,7 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt | |||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
// not care result, if no this attr, stand for the op does not need force infershape | // not care result, if no this attr, stand for the op does not need force infershape | ||||
(void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); | |||||
(void) AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); | |||||
GELOGD("node [%s] is need do infershape, flag is %d", | GELOGD("node [%s] is need do infershape, flag is %d", | ||||
op_desc->GetName().c_str(), | op_desc->GetName().c_str(), | ||||
node_item.is_need_force_infershape); | node_item.is_need_force_infershape); | ||||
@@ -42,6 +42,10 @@ class GeRootModel { | |||||
std::vector<uint32_t> GetAllModelId() const { return model_ids_; } | std::vector<uint32_t> GetAllModelId() const { return model_ids_; } | ||||
void SetModelName(const std::string &model_name) { model_name_ = model_name; } | |||||
const std::string &GetModelName() const { return model_name_; } | |||||
Status CheckIsUnknownShape(bool &is_dynamic_shape); | Status CheckIsUnknownShape(bool &is_dynamic_shape); | ||||
void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } | ||||
@@ -57,6 +61,7 @@ class GeRootModel { | |||||
// In multithread online secenario, same graph can owns different davinci_model for for concurrency | // In multithread online secenario, same graph can owns different davinci_model for for concurrency | ||||
std::vector<uint32_t> model_ids_; | std::vector<uint32_t> model_ids_; | ||||
bool train_flag_ = false; | bool train_flag_ = false; | ||||
std::string model_name_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; | ||||
@@ -154,9 +154,11 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ||||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | ||||
ge_root_model->SetModelName("test_name"); | |||||
HybridModel hybrid_model(ge_root_model); | HybridModel hybrid_model(ge_root_model); | ||||
HybridModelBuilder hybrid_model_builder(hybrid_model); | HybridModelBuilder hybrid_model_builder(hybrid_model); | ||||
ASSERT_EQ(hybrid_model_builder.Build(), INTERNAL_ERROR); | |||||
ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); | ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); | ||||
} | } | ||||
@@ -655,4 +657,4 @@ TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) { | |||||
ASSERT_EQ(builder.ParseDependentInputNodes(*node_item_1, deps), SUCCESS); | ASSERT_EQ(builder.ParseDependentInputNodes(*node_item_1, deps), SUCCESS); | ||||
ASSERT_TRUE(model.GetNodeItem(node)->has_observer); | ASSERT_TRUE(model.GetNodeItem(node)->has_observer); | ||||
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); | ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); | ||||
} | |||||
} |