Browse Source

!1499 fix dump

From: @jiming6
Reviewed-by: @xchu42
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
c80cd136b7
4 changed files with 11 additions and 3 deletions
  1. +1
    -0
      ge/common/helper/model_helper.cc
  2. +2
    -2
      ge/hybrid/model/hybrid_model_builder.cc
  3. +5
    -0
      ge/model/ge_root_model.h
  4. +3
    -1
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 1
- 0
ge/common/helper/model_helper.cc View File

@@ -599,6 +599,7 @@ Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) {
is_first_model = false;
root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph()));
root_model_->SetModelId(cur_model->GetModelId());
root_model_->SetModelName(cur_model->GetName());
model_ = cur_model;
continue;
}


+ 2
- 2
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -134,7 +134,7 @@ HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model)

Status HybridModelBuilder::Build() {
GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName());
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName();
hybrid_model_.model_name_ = ge_root_model_->GetModelName();
GELOGI("[%s] Start to build hybrid model.", GetGraphName());
GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName());
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(),
@@ -277,7 +277,7 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
// not care result, if no this attr, stand for the op does not need force infershape
(void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape);
(void) AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape);
GELOGD("node [%s] is need do infershape, flag is %d",
op_desc->GetName().c_str(),
node_item.is_need_force_infershape);


+ 5
- 0
ge/model/ge_root_model.h View File

@@ -42,6 +42,10 @@ class GeRootModel {

std::vector<uint32_t> GetAllModelId() const { return model_ids_; }

void SetModelName(const std::string &model_name) { model_name_ = model_name; }
const std::string &GetModelName() const { return model_name_; }
Status CheckIsUnknownShape(bool &is_dynamic_shape);

void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; }
@@ -57,6 +61,7 @@ class GeRootModel {
// In multithread online secenario, same graph can owns different davinci_model for for concurrency
std::vector<uint32_t> model_ids_;
bool train_flag_ = false;
std::string model_name_;
};
} // namespace ge
using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>;


+ 3
- 1
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -154,9 +154,11 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) {

ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);

ASSERT_EQ(hybrid_model_builder.Build(), INTERNAL_ERROR);
ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR);
}

@@ -655,4 +657,4 @@ TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) {
ASSERT_EQ(builder.ParseDependentInputNodes(*node_item_1, deps), SUCCESS);
ASSERT_TRUE(model.GetNodeItem(node)->has_observer);
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
}
}

Loading…
Cancel
Save