Browse Source

!857 Feature:display model info

From: @wangwenhua1
Reviewed-by: @xchu42,@ji_chen
Signed-off-by:
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
87914cc74b
4 changed files with 33 additions and 12 deletions
  1. +6
    -8
      ge/generator/ge_generator.cc
  2. +5
    -2
      ge/offline/main.cc
  3. +14
    -2
      ge/session/omg.cc
  4. +8
    -0
      inc/framework/common/ge_types.h

+ 6
- 8
ge/generator/ge_generator.cc View File

@@ -556,17 +556,15 @@ bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) {
return false;
}

// 0(Caffe) 1(MindSpore) 3(TensorFlow) 5(Onnx)
std::map<string, string> framework_type_to_string = {
{"0", "Caffe"},
{"1", "MindSpore"},
{"3", "TensorFlow"},
{"5", "Onnx"}
};
std::string framework_type;
(void)ge::GetContext().GetOption(ge::FRAMEWORK_TYPE, framework_type);
GELOGI("SetOmSystemInfo framework_type: %s", framework_type.c_str());
if (!ge::AttrUtils::SetStr(obj, "framework_type", framework_type_to_string[framework_type.c_str()])) {
auto iter = ge::kFwkTypeToStr.find(framework_type);
if (iter == ge::kFwkTypeToStr.end()) {
GELOGW("Can not find framework_type in the map.");
return false;
}
if (!ge::AttrUtils::SetStr(obj, "framework_type", iter->second)) {
GELOGW("SetStr of framework_type failed.");
return false;
}


+ 5
- 2
ge/offline/main.cc View File

@@ -232,8 +232,7 @@ class GFlagUtils {
"[General]\n"
" --h/help Show this help message\n"
" --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format; "
"6: display model info"
"3: only pre-check; 5: convert ge dump txt file to JSON format\n"
"3: only pre-check; 5: convert ge dump txt file to JSON format; 6: display model info\n"
"\n[Input]\n"
" --model Model file\n"
" --weight Weight file. Required when framework is Caffe\n"
@@ -463,6 +462,10 @@ class GFlagUtils {
ge::CheckEnableSingleStreamParamValid(std::string(FLAGS_enable_single_stream)) == ge::SUCCESS,
ret = ge::FAILED, "check enable single stream failed!");

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((FLAGS_display_model_info != "0") && (FLAGS_display_model_info != "1"),
ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"display_model_info"});
ret = ge::FAILED, "Input parameter[--display_model_info]'s value must be 1 or 0.");

return ret;
}



+ 14
- 2
ge/session/omg.cc View File

@@ -963,6 +963,7 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *js
OmFileLoadHelper omFileLoadHelper;
ge::graphStatus status = omFileLoadHelper.Init(model_data, model_len);
if (status != ge::GRAPH_SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Om file init failed"});
GELOGE(ge::FAILED, "Om file init failed.");
if (model.model_data != nullptr) {
delete[] reinterpret_cast<char *>(model.model_data);
@@ -974,6 +975,7 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *js
ModelPartition ir_part;
status = omFileLoadHelper.GetModelPartition(MODEL_DEF, ir_part);
if (status != ge::GRAPH_SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Get model part failed"});
GELOGE(ge::FAILED, "Get model part failed.");
if (model.model_data != nullptr) {
delete[] reinterpret_cast<char *>(model.model_data);
@@ -999,9 +1001,12 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *js
}
} else {
ret = INTERNAL_ERROR;
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"ReadProtoFromArray failed"});
GELOGE(ret, "ReadProtoFromArray failed.");
}
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E10003",
{"parameter", "value", "reason"}, {"om", model_file, "invalid om file"});
GELOGE(PARAM_INVALID, "ParseModelContent failed because of invalid om file. Please check --om param.");
}

@@ -1011,6 +1016,8 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *js
}
return ret;
} catch (const std::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"},
{"Convert om model to json failed, exception message[" + std::string(e.what()) + "]"});
GELOGE(FAILED, "Convert om model to json failed, exception message : %s.", e.what());
return FAILED;
}
@@ -1041,7 +1048,8 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const

if (!flag) {
free_model_data(&model.model_data);
GELOGE(FAILED, "ParseFromString fail.");
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"ParseFromString failed"});
GELOGE(FAILED, "ParseFromString failed.");
return FAILED;
}
GetGroupName(model_def);
@@ -1057,9 +1065,13 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const
return SUCCESS;
} catch (google::protobuf::FatalException &e) {
free_model_data(&model.model_data);
GELOGE(FAILED, "ParseFromString fail. exception message : %s", e.what());
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"ParseFromString failed, exception message["
+ std::string(e.what()) + "]"});
GELOGE(FAILED, "ParseFromString failed. exception message : %s", e.what());
return FAILED;
} catch (const std::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"},
{"Convert pbtxt to json failed, exception message[" + std::string(e.what()) + "]"});
GELOGE(FAILED, "Convert pbtxt to json failed, exception message : %s.", e.what());
return FAILED;
}


+ 8
- 0
inc/framework/common/ge_types.h View File

@@ -40,6 +40,14 @@ enum FrameworkType {
ONNX,
};

const std::map<std::string, std::string> kFwkTypeToStr = {
{"0", "Caffe"},
{"1", "MindSpore"},
{"3", "TensorFlow"},
{"4", "Android_NN"},
{"5", "Onnx"}
};

enum OpEngineType {
ENGINE_SYS = 0, // default engine
ENGINE_AICORE = 1,


Loading…
Cancel
Save