Browse Source

fix

tags/v1.2.0
wjm 3 years ago
parent
commit
b83f971b7a
3 changed files with 27 additions and 6 deletions
  1. +6
    -5
      ge/common/helper/model_helper.cc
  2. +4
    -1
      ge/init/gelib.cc
  3. +17
    -0
      tests/ut/ge/graph/load/model_helper_unittest.cc

+ 6
- 5
ge/common/helper/model_helper.cc View File

@@ -891,6 +891,7 @@ Status ModelTool::GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef
model.model_data = nullptr;
}
};
GE_MAKE_GUARD(release, callback);

uint8_t *model_data = nullptr;
uint32_t model_len = 0;
@@ -903,17 +904,17 @@ Status ModelTool::GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef
return ret;
}

OmFileLoadHelper omFileLoadHelper;
ret = omFileLoadHelper.Init(model_data, model_len);
if (ret != ge::GRAPH_SUCCESS) {
OmFileLoadHelper om_load_helper;
ret = om_load_helper.Init(model_data, model_len);
if (ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Om file init failed"});
GELOGE(ge::FAILED, "Om file init failed.");
return ret;
}

ModelPartition ir_part;
ret = omFileLoadHelper.GetModelPartition(MODEL_DEF, ir_part);
if (ret != ge::GRAPH_SUCCESS) {
ret = om_load_helper.GetModelPartition(MODEL_DEF, ir_part);
if (ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Get model part failed"});
GELOGE(ge::FAILED, "Get model part failed.");
return ret;


+ 4
- 1
ge/init/gelib.cc View File

@@ -543,7 +543,10 @@ Status GEInit::Initialize(const map<string, string> &options) {
}

Status GEInit::Finalize() {
return GELib::GetInstance()->Finalize();
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr != nullptr) {
return instance_ptr->Finalize();
}
}

string GEInit::GetPath() {


+ 17
- 0
tests/ut/ge/graph/load/model_helper_unittest.cc View File

@@ -49,4 +49,21 @@ TEST_F(UtestModelHelper, save_size_to_modeldef)
ModelHelper model_helper;
EXPECT_EQ(SUCCESS, model_helper.SaveSizeToModelDef(ge_model));
}
TEST_F(UtestModelHelper, atc_test)
{
ge::proto::ModelDef model_def;
uint32_t modeldef_size = 0;
GEInit::Finalize();
char buffer[1024];
getcwd(buffer, 1024);
string path=buffer;
string file_path=path + "/Makefile";
ModelTool::GetModelInfoFromOm(file_path.c_str(), model_def, modeldef_size);
ModelTool::GetModelInfoFromOm("123.om", model_def, modeldef_size);
ModelTool::GetModelInfoFromPbtxt(file_path.c_str(), model_def);
ModelTool::GetModelInfoFromPbtxt("123.pbtxt", model_def);
}
} // namespace ge

Loading…
Cancel
Save