| @@ -74,6 +74,49 @@ Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_fil | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ModelHelper::SaveSizeToModelDef(const GeModelPtr &ge_model) { | |||||
| vector<int64_t> om_info; | |||||
| ModelPtr model_tmp = ge::MakeShared<ge::Model>(ge_model->GetName(), ge_model->GetPlatformVersion()); | |||||
| if (model_tmp == nullptr) { | |||||
| GELOGE(FAILED, "Create Model %s Ptr failed", ge_model->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| model_tmp->SetGraph(ge_model->GetGraph()); | |||||
| model_tmp->SetVersion(ge_model->GetVersion()); | |||||
| model_tmp->SetAttr(ge_model->MutableAttrMap()); | |||||
| ge::Buffer model_buffer; | |||||
| (void)model_tmp->Save(model_buffer); | |||||
| GELOGD("SaveSizeToModelDef modeldef_size is %zu", model_buffer.GetSize()); | |||||
| om_info.push_back(model_buffer.GetSize()); | |||||
| auto ge_model_weight = ge_model->GetWeight(); | |||||
| GELOGD("SaveSizeToModelDef weight_data_size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); | |||||
| om_info.push_back(ge_model_weight.GetSize()); | |||||
| TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); | |||||
| GELOGD("SaveSizeToModelDef tbe_kernels_size is %zu", tbe_kernel_store.DataSize()); | |||||
| om_info.push_back(tbe_kernel_store.DataSize()); | |||||
| CustAICPUKernelStore cust_aicpu_kernel_store = ge_model->GetCustAICPUKernelStore(); | |||||
| GELOGD("SaveSizeToModelDef cust aicpu kernels size is %zu", cust_aicpu_kernel_store.DataSize()); | |||||
| om_info.push_back(cust_aicpu_kernel_store.DataSize()); | |||||
| std::shared_ptr<ModelTaskDef> model_task_def = ge_model->GetModelTaskDefPtr(); | |||||
| if (model_task_def == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Create model task def ptr failed"); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| size_t partition_task_size = model_task_def->ByteSizeLong(); | |||||
| GELOGD("SaveSizeToModelDef task_info_size is %zu", partition_task_size); | |||||
| om_info.push_back(partition_task_size); | |||||
| GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(*(ge_model.get()), "om_info_list", om_info), | |||||
| GELOGE(FAILED, "SetListInt of om_info_list failed."); | |||||
| return FAILED); | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model, | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model, | ||||
| const SaveParam &save_param, | const SaveParam &save_param, | ||||
| const std::string &output_file, | const std::string &output_file, | ||||
| @@ -94,6 +137,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod | |||||
| model_tmp->SetGraph(ge_model->GetGraph()); | model_tmp->SetGraph(ge_model->GetGraph()); | ||||
| model_tmp->SetVersion(ge_model->GetVersion()); | model_tmp->SetVersion(ge_model->GetVersion()); | ||||
| model_tmp->SetAttr(ge_model->MutableAttrMap()); | model_tmp->SetAttr(ge_model->MutableAttrMap()); | ||||
| Status ret = SaveSizeToModelDef(ge_model); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "SaveSizeToModelDef failed"); | |||||
| return ret; | |||||
| } | |||||
| ge::Buffer model_buffer; | ge::Buffer model_buffer; | ||||
| (void)model_tmp->Save(model_buffer); | (void)model_tmp->Save(model_buffer); | ||||
| @@ -249,6 +249,7 @@ class GeGenerator::Impl { | |||||
| bool GetVersionFromPath(const std::string &file_path, std::string &version); | bool GetVersionFromPath(const std::string &file_path, std::string &version); | ||||
| bool SetAtcVersionInfo(AttrHolder &obj); | bool SetAtcVersionInfo(AttrHolder &obj); | ||||
| bool SetOppVersionInfo(AttrHolder &obj); | bool SetOppVersionInfo(AttrHolder &obj); | ||||
| bool SetOmSystemInfo(AttrHolder &obj); | |||||
| }; | }; | ||||
| Status GeGenerator::Initialize(const map<string, string> &options) { return Initialize(options, domi::GetContext()); } | Status GeGenerator::Initialize(const map<string, string> &options) { return Initialize(options, domi::GetContext()); } | ||||
| @@ -462,6 +463,32 @@ bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) { | |||||
| std::string soc_version; | |||||
| (void)ge::GetContext().GetOption(ge::SOC_VERSION, soc_version); | |||||
| GELOGI("SetOmSystemInfo soc_version: %s", soc_version.c_str()); | |||||
| if (!ge::AttrUtils::SetStr(obj, "soc_version", soc_version)) { | |||||
| GELOGW("SetStr of soc_version failed."); | |||||
| return false; | |||||
| } | |||||
| // 0(Caffe) 1(MindSpore) 3(TensorFlow) 5(Onnx) | |||||
| std::map<string, string> framework_type_to_string = { | |||||
| {"0", "Caffe"}, | |||||
| {"1", "MindSpore"}, | |||||
| {"3", "TensorFlow"}, | |||||
| {"5", "Onnx"} | |||||
| }; | |||||
| std::string framework_type; | |||||
| (void)ge::GetContext().GetOption(ge::FRAMEWORK_TYPE, framework_type); | |||||
| GELOGI("SetOmSystemInfo framework_type: %s", framework_type.c_str()); | |||||
| if (!ge::AttrUtils::SetStr(obj, "framework_type", framework_type_to_string[framework_type.c_str()])) { | |||||
| GELOGW("SetStr of framework_type failed."); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | ||||
| ModelBufferData &model, bool is_offline) { | ModelBufferData &model, bool is_offline) { | ||||
| rtContext_t ctx = nullptr; | rtContext_t ctx = nullptr; | ||||
| @@ -664,6 +691,9 @@ Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr & | |||||
| if (!SetOppVersionInfo(*(model.get()))) { | if (!SetOppVersionInfo(*(model.get()))) { | ||||
| GELOGW("SetPackageVersionInfo of ops failed!"); | GELOGW("SetPackageVersionInfo of ops failed!"); | ||||
| } | } | ||||
| if (!SetOmSystemInfo(*(model_root.get()))) { | |||||
| GELOGW("SetOmsystemInfo failed!"); | |||||
| } | |||||
| ModelHelper model_helper; | ModelHelper model_helper; | ||||
| model_helper.SetSaveMode(is_offline_); | model_helper.SetSaveMode(is_offline_); | ||||
| Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff); | Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff); | ||||
| @@ -194,6 +194,7 @@ DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0( | |||||
| "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler"); | "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler"); | ||||
| DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass," | DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass," | ||||
| "multiple names can be set and separated by ','."); | "multiple names can be set and separated by ','."); | ||||
| DEFINE_string(display_model_info, "0", "Optional; display model info"); | |||||
| class GFlagUtils { | class GFlagUtils { | ||||
| public: | public: | ||||
| @@ -215,7 +216,7 @@ class GFlagUtils { | |||||
| "[General]\n" | "[General]\n" | ||||
| " --h/help Show this help message\n" | " --h/help Show this help message\n" | ||||
| " --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format " | " --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format " | ||||
| "3: only pre-check; 5: convert pbtxt file to JSON format\n" | |||||
| "3: only pre-check; 5: convert pbtxt file to JSON format; 6: display model info\n" | |||||
| "\n[Input]\n" | "\n[Input]\n" | ||||
| " --model Model file\n" | " --model Model file\n" | ||||
| " --weight Weight file. Required when framework is Caffe\n" | " --weight Weight file. Required when framework is Caffe\n" | ||||
| @@ -296,7 +297,8 @@ class GFlagUtils { | |||||
| " --save_original_model Control whether to output original model. E.g.: true: output original model\n" | " --save_original_model Control whether to output original model. E.g.: true: output original model\n" | ||||
| " --log Generate log with level. Support debug, info, warning, error, null\n" | " --log Generate log with level. Support debug, info, warning, error, null\n" | ||||
| " --dump_mode The switch of dump json with shape, to be used with mode 1." | " --dump_mode The switch of dump json with shape, to be used with mode 1." | ||||
| "0(default): disable; 1: enable."); | |||||
| "0(default): disable; 1: enable.\n" | |||||
| " --display_model_info enable for display model info; 0(default): close display, 1: open display"); | |||||
| gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true); | gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true); | ||||
| // Using gflags to analyze input parameters | // Using gflags to analyze input parameters | ||||
| @@ -1133,6 +1135,8 @@ domi::Status GenerateOmModel() { | |||||
| options.insert(std::pair<string, string>(string(ge::ENABLE_SINGLE_STREAM), FLAGS_enable_single_stream)); | options.insert(std::pair<string, string>(string(ge::ENABLE_SINGLE_STREAM), FLAGS_enable_single_stream)); | ||||
| options.insert(std::pair<string, string>(string(ge::DISPLAY_MODEL_INFO), FLAGS_display_model_info)); | |||||
| SetDynamicInputSizeOptions(); | SetDynamicInputSizeOptions(); | ||||
| if (!FLAGS_save_original_model.empty()) { | if (!FLAGS_save_original_model.empty()) { | ||||
| @@ -1152,10 +1156,34 @@ domi::Status GenerateOmModel() { | |||||
| if (ret != domi::SUCCESS) { | if (ret != domi::SUCCESS) { | ||||
| return domi::FAILED; | return domi::FAILED; | ||||
| } | } | ||||
| if (FLAGS_display_model_info == "1") { | |||||
| GELOGI("need to display model info."); | |||||
| return ge::ConvertOm(FLAGS_output.c_str(), "", false); | |||||
| } | |||||
| return domi::SUCCESS; | return domi::SUCCESS; | ||||
| } | } | ||||
| domi::Status DisplayModelInfo() { | |||||
| // No model path passed in | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "", | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"}); | |||||
| return ge::FAILED, | |||||
| "Input parameter[--om]'s value is empty!!"); | |||||
| // Check if the model path is valid | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"), | |||||
| return ge::FAILED, | |||||
| "model file path is invalid: %s.", FLAGS_om.c_str()); | |||||
| if (FLAGS_framework == -1) { | |||||
| return ge::ConvertOm(FLAGS_om.c_str(), "", false); | |||||
| } | |||||
| return ge::FAILED; | |||||
| } | |||||
| domi::Status ConvertModelToJson() { | domi::Status ConvertModelToJson() { | ||||
| Status ret = GFlagUtils::CheckConverJsonParamFlags(); | Status ret = GFlagUtils::CheckConverJsonParamFlags(); | ||||
| GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check convert json params flags failed!"); | GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check convert json params flags failed!"); | ||||
| @@ -1176,6 +1204,8 @@ bool CheckRet(domi::Status ret) { | |||||
| GELOGW("ATC convert model to json file failed."); | GELOGW("ATC convert model to json file failed."); | ||||
| } else if (FLAGS_mode == PBTXT_TO_JSON) { | } else if (FLAGS_mode == PBTXT_TO_JSON) { | ||||
| GELOGW("ATC convert pbtxt to json file failed."); | GELOGW("ATC convert pbtxt to json file failed."); | ||||
| } else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) { | |||||
| GELOGW("ATC display om info failed."); | |||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -1190,6 +1220,8 @@ bool CheckRet(domi::Status ret) { | |||||
| GELOGI("ATC convert model to json file success."); | GELOGI("ATC convert model to json file success."); | ||||
| } else if (FLAGS_mode == PBTXT_TO_JSON) { | } else if (FLAGS_mode == PBTXT_TO_JSON) { | ||||
| GELOGI("ATC convert pbtxt to json file success."); | GELOGI("ATC convert pbtxt to json file success."); | ||||
| } else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) { | |||||
| GELOGW("ATC display om info success."); | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -1309,6 +1341,9 @@ int main(int argc, char* argv[]) { | |||||
| } else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) { | } else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) { | ||||
| GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; | GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; | ||||
| break, "ATC convert pbtxt to json execute failed!!"); | break, "ATC convert pbtxt to json execute failed!!"); | ||||
| } else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) { | |||||
| GE_CHK_BOOL_EXEC(DisplayModelInfo() == domi::SUCCESS, ret = domi::FAILED; | |||||
| break, "ATC DisplayModelInfo failed!!"); | |||||
| } else { | } else { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
| "E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport}); | "E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport}); | ||||
| @@ -68,6 +68,7 @@ const std::string kScopeIdAttr = "fusion_scope"; | |||||
| const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; | const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; | ||||
| const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; | const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; | ||||
| const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; | const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; | ||||
| const size_t kOmInfoSize = 5; | |||||
| } // namespace | } // namespace | ||||
| // When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored | // When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored | ||||
| @@ -865,9 +866,78 @@ void GetGroupName(ge::proto::ModelDef &model_def) { | |||||
| }); | }); | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY Status ConvertOmModelToJson(const char *model_file, const char *json_file) { | |||||
| FMK_FUNC_HOST_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def) { | |||||
| std::cout << "============ Display Model Info start ============" << std::endl; | |||||
| auto model_attr_map = model_def->mutable_attr(); | |||||
| // system info | |||||
| auto iter = model_attr_map->find(ATTR_MODEL_ATC_VERSION); | |||||
| auto atc_version = (iter != model_attr_map->end()) ? iter->second.s() : ""; | |||||
| iter = model_attr_map->find("soc_version"); | |||||
| auto soc_version = (iter != model_attr_map->end()) ? iter->second.s() : ""; | |||||
| iter = model_attr_map->find("framework_type"); | |||||
| auto framework_type = (iter != model_attr_map->end()) ? iter->second.s() : ""; | |||||
| std::cout << "system info: " | |||||
| << ATTR_MODEL_ATC_VERSION | |||||
| << "[" << atc_version << "], " | |||||
| << "soc_version" | |||||
| << "[" << soc_version << "], " | |||||
| << "framework_type" | |||||
| << "[" << framework_type << "]." << std::endl; | |||||
| // resource info | |||||
| iter = model_attr_map->find(ATTR_MODEL_MEMORY_SIZE); | |||||
| auto memory_size = (iter != model_attr_map->end()) ? iter->second.i() : -1; | |||||
| iter = model_attr_map->find(ATTR_MODEL_WEIGHT_SIZE); | |||||
| auto weight_size = (iter != model_attr_map->end()) ? iter->second.i() : -1; | |||||
| iter = model_attr_map->find(ATTR_MODEL_STREAM_NUM); | |||||
| auto stream_num = (iter != model_attr_map->end()) ? iter->second.i() : -1; | |||||
| iter = model_attr_map->find(ATTR_MODEL_EVENT_NUM); | |||||
| auto event_num = (iter != model_attr_map->end()) ? iter->second.i() : -1; | |||||
| std::cout << "resource info: " | |||||
| << ATTR_MODEL_MEMORY_SIZE | |||||
| << "[" << memory_size << " B], " | |||||
| << ATTR_MODEL_WEIGHT_SIZE | |||||
| << "[" << weight_size << " B], " | |||||
| << ATTR_MODEL_STREAM_NUM | |||||
| << "[" << stream_num << "], " | |||||
| << ATTR_MODEL_EVENT_NUM | |||||
| << "[" << event_num << "]." | |||||
| << std::endl; | |||||
| // om info | |||||
| iter = model_attr_map->find("om_info_list"); | |||||
| if (iter == model_attr_map->end()) { | |||||
| std::cout << "Display Model Info failed, attr \"om_info_list\" is not found in om, check the version is matched." | |||||
| << std::endl; | |||||
| std::cout << "============ Display Model Info end ============" << std::endl; | |||||
| return; | |||||
| } | |||||
| auto list_size = iter->second.list().i_size(); | |||||
| if (list_size == kOmInfoSize) { | |||||
| std::cout << "om info: " | |||||
| << "modeldef_size" | |||||
| << "[" << iter->second.list().i(0) << " B], " | |||||
| << "weight_data_size" | |||||
| << "[" << iter->second.list().i(1) << " B], " | |||||
| << "tbe_kernels_size" | |||||
| << "[" << iter->second.list().i(2) << " B], " | |||||
| << "cust_aicpu_kernel_store_size" | |||||
| << "[" << iter->second.list().i(3) << " B], " | |||||
| << "task_info_size" | |||||
| << "[" << iter->second.list().i(4) << " B]." << std::endl; | |||||
| } else { | |||||
| std::cout << "Display Model Info error, please check!" << std::endl; | |||||
| }; | |||||
| std::cout << "============ Display Model Info end ============" << std::endl; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *json_file, bool is_covert_to_json) { | |||||
| GE_CHECK_NOTNULL(model_file); | GE_CHECK_NOTNULL(model_file); | ||||
| GE_CHECK_NOTNULL(json_file); | |||||
| if (is_covert_to_json) { | |||||
| GE_CHECK_NOTNULL(json_file); | |||||
| } | |||||
| ge::ModelData model; | ge::ModelData model; | ||||
| // Mode 2 does not need to verify the priority, and a default value of 0 is passed | // Mode 2 does not need to verify the priority, and a default value of 0 is passed | ||||
| @@ -913,12 +983,16 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOmModelToJson(const char *model_file, con | |||||
| // De serialization | // De serialization | ||||
| bool flag = ReadProtoFromArray(ir_part.data, ir_part.size, &model_def); | bool flag = ReadProtoFromArray(ir_part.data, ir_part.size, &model_def); | ||||
| if (flag) { | if (flag) { | ||||
| GetGroupName(model_def); | |||||
| if (is_covert_to_json) { | |||||
| GetGroupName(model_def); | |||||
| json j; | |||||
| Pb2Json::Message2Json(model_def, kOmBlackFields, j, true); | |||||
| json j; | |||||
| Pb2Json::Message2Json(model_def, kOmBlackFields, j, true); | |||||
| ret = ModelSaver::SaveJsonToFile(json_file, j); | |||||
| ret = ModelSaver::SaveJsonToFile(json_file, j); | |||||
| } else { | |||||
| PrintModelInfo(&model_def); | |||||
| } | |||||
| } else { | } else { | ||||
| ret = INTERNAL_ERROR; | ret = INTERNAL_ERROR; | ||||
| GELOGE(ret, "ReadProtoFromArray failed."); | GELOGE(ret, "ReadProtoFromArray failed."); | ||||
| @@ -233,6 +233,9 @@ const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; | |||||
| // 0: close debug; 1: open TBE compiler; 2: open ccec compiler | // 0: close debug; 1: open TBE compiler; 2: open ccec compiler | ||||
| const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; | const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; | ||||
| // Configure display_model_info flag | |||||
| const std::string DISPLAY_MODEL_INFO = "ge.display_model_info"; | |||||
| // Graph run mode | // Graph run mode | ||||
| enum GraphRunMode { PREDICTION = 0, TRAIN }; | enum GraphRunMode { PREDICTION = 0, TRAIN }; | ||||
| @@ -61,6 +61,7 @@ class ModelHelper { | |||||
| Status GenerateGeModel(OmFileLoadHelper& om_load_helper); | Status GenerateGeModel(OmFileLoadHelper& om_load_helper); | ||||
| Status LoadModelData(OmFileLoadHelper& om_load_helper); | Status LoadModelData(OmFileLoadHelper& om_load_helper); | ||||
| void SetModelToGeModel(ge::Model& model); | void SetModelToGeModel(ge::Model& model); | ||||
| Status SaveSizeToModelDef(const GeModelPtr &ge_model); | |||||
| Status LoadWeights(OmFileLoadHelper& om_load_helper); | Status LoadWeights(OmFileLoadHelper& om_load_helper); | ||||
| Status LoadTask(OmFileLoadHelper& om_load_helper); | Status LoadTask(OmFileLoadHelper& om_load_helper); | ||||
| Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | ||||
| @@ -74,7 +74,7 @@ Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params, | |||||
| * @param [key] encrypted key | * @param [key] encrypted key | ||||
| * @return Status result code | * @return Status result code | ||||
| */ | */ | ||||
| Status ConvertOmModelToJson(const char *model_file, const char *json_file); | |||||
| Status ConvertOm(const char *model_file, const char *json_file, bool is_covert_to_json); | |||||
| Status ConvertPbtxtToJson(const char *model_file, const char *json_file); | Status ConvertPbtxtToJson(const char *model_file, const char *json_file); | ||||
| /** | /** | ||||
| @@ -106,6 +106,8 @@ void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &ou | |||||
| void UpdateOmgCtxWithParserCtx(); | void UpdateOmgCtxWithParserCtx(); | ||||
| void UpdateParserCtxWithOmgCtx(); | void UpdateParserCtxWithOmgCtx(); | ||||
| void PrintModelInfo(ge::proto::ModelDef *model_def); | |||||
| } // namespace ge | } // namespace ge | ||||
| namespace domi { | namespace domi { | ||||
| @@ -47,7 +47,8 @@ enum RunMode { | |||||
| GEN_OM_MODEL = 0, // generate offline model file | GEN_OM_MODEL = 0, // generate offline model file | ||||
| MODEL_TO_JSON = 1, // convert to JSON file | MODEL_TO_JSON = 1, // convert to JSON file | ||||
| ONLY_PRE_CHECK = 3, // only for pre-check | ONLY_PRE_CHECK = 3, // only for pre-check | ||||
| PBTXT_TO_JSON = 5 // pbtxt to json | |||||
| PBTXT_TO_JSON = 5, // pbtxt to json | |||||
| DISPLAY_OM_INFO = 6 // display model info | |||||
| }; | }; | ||||
| /// | /// | ||||