diff --git a/ge/common/helper/model_helper.cc b/ge/common/helper/model_helper.cc index 52917abe..1d5a4a9b 100644 --- a/ge/common/helper/model_helper.cc +++ b/ge/common/helper/model_helper.cc @@ -76,6 +76,48 @@ Status ModelHelper::SaveModelPartition(std::shared_ptr &om_fil return SUCCESS; } +Status ModelHelper::SaveSizeToModelDef(const GeModelPtr &ge_model) { + vector om_info; + ModelPtr model_tmp = ge::MakeShared(ge_model->GetName(), ge_model->GetPlatformVersion()); + if (model_tmp == nullptr) { + GELOGE(FAILED, "Create Model %s Ptr failed", ge_model->GetName().c_str()); + return FAILED; + } + model_tmp->SetGraph(ge_model->GetGraph()); + model_tmp->SetVersion(ge_model->GetVersion()); + model_tmp->SetAttr(ge_model->MutableAttrMap()); + ge::Buffer model_buffer; + (void)model_tmp->Save(model_buffer); + GELOGD("SaveSizeToModelDef modeldef_size is %zu", model_buffer.GetSize()); + om_info.push_back(model_buffer.GetSize()); + + auto ge_model_weight = ge_model->GetWeight(); + GELOGD("SaveSizeToModelDef weight_data_size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); + om_info.push_back(ge_model_weight.GetSize()); + + TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); + GELOGD("SaveSizeToModelDef tbe_kernels_size is %zu", tbe_kernel_store.DataSize()); + om_info.push_back(tbe_kernel_store.DataSize()); + + CustAICPUKernelStore cust_aicpu_kernel_store = ge_model->GetCustAICPUKernelStore(); + GELOGD("SaveSizeToModelDef cust aicpu kernels size is %zu", cust_aicpu_kernel_store.DataSize()); + om_info.push_back(cust_aicpu_kernel_store.DataSize()); + + std::shared_ptr model_task_def = ge_model->GetModelTaskDefPtr(); + if (model_task_def == nullptr) { + GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Create model task def ptr failed"); + return ACL_ERROR_GE_MEMORY_ALLOCATION; + } + size_t partition_task_size = model_task_def->ByteSizeLong(); + GELOGD("SaveSizeToModelDef task_info_size is %zu", partition_task_size); + om_info.push_back(partition_task_size); + + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(*(ge_model.get()), "om_info_list", om_info), + GELOGE(FAILED, "SetListInt of om_info_list failed."); + return FAILED); + + return SUCCESS; +} Status ModelHelper::SaveModelDef(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, ge::Buffer &model_buffer, size_t model_index) { @@ -87,7 +129,11 @@ Status ModelHelper::SaveModelDef(std::shared_ptr &om_file_save model_tmp->SetGraph(ge_model->GetGraph()); model_tmp->SetVersion(ge_model->GetVersion()); model_tmp->SetAttr(ge_model->MutableAttrMap()); - + Status ret = SaveSizeToModelDef(ge_model); + if (ret != SUCCESS) { + GELOGE(ret, "SaveSizeToModelDef failed"); + return ret; + } (void)model_tmp->Save(model_buffer); GELOGD("MODEL_DEF size is %zu", model_buffer.GetSize()); diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 016f9ef2..f8d4900a 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -336,6 +336,7 @@ class GeGenerator::Impl { bool GetVersionFromPath(const std::string &file_path, std::string &version); bool SetAtcVersionInfo(AttrHolder &obj); bool SetOppVersionInfo(AttrHolder &obj); + bool SetOmSystemInfo(AttrHolder &obj); }; Status GeGenerator::Initialize(const map &options) { @@ -546,6 +547,32 @@ bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) { return true; } +bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) { + std::string soc_version; + (void)ge::GetContext().GetOption(ge::SOC_VERSION, soc_version); + GELOGI("SetOmSystemInfo soc_version: %s", soc_version.c_str()); + if (!ge::AttrUtils::SetStr(obj, "soc_version", soc_version)) { + GELOGW("SetStr of soc_version failed."); + return false; + } + + // 0(Caffe) 1(MindSpore) 3(TensorFlow) 5(Onnx) + std::map framework_type_to_string = { + {"0", "Caffe"}, + {"1", "MindSpore"}, + {"3", "TensorFlow"}, + {"5", "Onnx"} + }; + std::string framework_type; + (void)ge::GetContext().GetOption(ge::FRAMEWORK_TYPE, framework_type); + GELOGI("SetOmSystemInfo framework_type: %s", framework_type.c_str()); + if (!ge::AttrUtils::SetStr(obj, "framework_type", framework_type_to_string[framework_type.c_str()])) { + GELOGW("SetStr of framework_type failed."); + return false; + } + return true; +} + Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector &inputs, ModelBufferData &model, bool is_offline) { rtContext_t ctx = nullptr; @@ -842,6 +869,9 @@ Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootMo if (!SetOppVersionInfo(*(model_root.get()))) { GELOGW("SetPackageVersionInfo of ops failed!"); } + if (!SetOmSystemInfo(*(model_root.get()))) { + GELOGW("SetOmsystemInfo failed!"); + } ModelHelper model_helper; model_helper.SetSaveMode(is_offline_); ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape); diff --git a/ge/offline/main.cc b/ge/offline/main.cc index 2b5bb41a..ed67b913 100755 --- a/ge/offline/main.cc +++ b/ge/offline/main.cc @@ -206,6 +206,8 @@ DEFINE_string(mdl_bank_path, "", "Optional; model bank path"); DEFINE_string(op_bank_path, "", "Optional; op bank path"); +DEFINE_string(display_model_info, "0", "Optional; display model info"); + class GFlagUtils { public: /** @@ -225,7 +227,8 @@ class GFlagUtils { "===== Basic Functionality =====\n" "[General]\n" " --h/help Show this help message\n" - " --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format " + " --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format; " + "6: display model info" "3: only pre-check; 5: convert ge dump txt file to JSON format\n" "\n[Input]\n" " --model Model file\n" @@ -313,7 +316,8 @@ class GFlagUtils { " --op_compiler_cache_dir Set the save path of operator compilation cache files.\n" "Default value: $HOME/atc_data\n" " --op_compiler_cache_mode Set the operator compilation cache mode." - "Options are disable(default), enable and force(force to refresh the cache)"); + "Options are disable(default), enable and force(force to refresh the cache)\n" + " --display_model_info enable for display model info; 0(default): close display, 1: open display"); gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true); // Using gflags to analyze input parameters @@ -862,7 +866,7 @@ domi::Status GenerateInfershapeJson() { static Status ConvertModelToJson(int fwk_type, const string &model_file, const string &json_file) { Status ret = ge::SUCCESS; if (fwk_type == -1) { - ret = ge::ConvertOmModelToJson(model_file.c_str(), json_file.c_str()); + ret = ge::ConvertOm(model_file.c_str(), json_file.c_str(), true); return ret; } @@ -1176,6 +1180,8 @@ domi::Status GenerateOmModel() { options.insert(std::pair(string(ge::MDL_BANK_PATH_FLAG), FLAGS_mdl_bank_path)); options.insert(std::pair(string(ge::OP_BANK_PATH_FLAG), FLAGS_op_bank_path)); + + options.insert(std::pair(string(ge::DISPLAY_MODEL_INFO), FLAGS_display_model_info)); // set enable scope fusion passes SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes); // print atc option map @@ -1188,6 +1194,11 @@ domi::Status GenerateOmModel() { return domi::FAILED; } + if (FLAGS_display_model_info == "1") { + GELOGI("need to display model info."); + return ge::ConvertOm(FLAGS_output.c_str(), "", false); + } + return domi::SUCCESS; } @@ -1201,6 +1212,26 @@ domi::Status ConvertModelToJson() { return domi::SUCCESS; } +domi::Status DisplayModelInfo() { + // No model path passed in + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "", + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"}); + return ge::FAILED, + "Input parameter[--om]'s value is empty!!"); + + // Check if the model path is valid + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"), + return ge::FAILED, + "model file path is invalid: %s.", FLAGS_om.c_str()); + + if (FLAGS_framework == -1) { + return ge::ConvertOm(FLAGS_om.c_str(), "", false); + } + + return ge::FAILED; +} + bool CheckRet(domi::Status ret) { if (ret != domi::SUCCESS) { if (FLAGS_mode == ONLY_PRE_CHECK) { @@ -1344,6 +1375,9 @@ int main(int argc, char* argv[]) { } else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) { GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; break, "ATC convert pbtxt to json execute failed!!"); + } else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) { + GE_CHK_BOOL_EXEC(DisplayModelInfo() == domi::SUCCESS, ret = domi::FAILED; + break, "ATC DisplayModelInfo failed!!"); } else { ErrorManager::GetInstance().ATCReportErrMessage( "E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport}); diff --git a/ge/session/omg.cc b/ge/session/omg.cc index 7ff52e82..11384cfb 100755 --- a/ge/session/omg.cc +++ b/ge/session/omg.cc @@ -71,6 +71,7 @@ const char *const kOutputTypeError = "The multiple out nodes set in output_type const size_t kNodeNameIndex = 0; const size_t kIndexStrIndex = 1; const size_t kDTValueIndex = 2; +const size_t kOmInfoSize = 5; } // namespace // When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored @@ -869,9 +870,78 @@ void GetGroupName(ge::proto::ModelDef &model_def) { }); } -FMK_FUNC_HOST_VISIBILITY Status ConvertOmModelToJson(const char *model_file, const char *json_file) { +FMK_FUNC_HOST_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def) { + std::cout << "============ Display Model Info start ============" << std::endl; + + auto model_attr_map = model_def->mutable_attr(); + // system info + auto iter = model_attr_map->find(ATTR_MODEL_ATC_VERSION); + auto atc_version = (iter != model_attr_map->end()) ? iter->second.s() : ""; + iter = model_attr_map->find("soc_version"); + auto soc_version = (iter != model_attr_map->end()) ? iter->second.s() : ""; + iter = model_attr_map->find("framework_type"); + auto framework_type = (iter != model_attr_map->end()) ? iter->second.s() : ""; + std::cout << "system info: " + << ATTR_MODEL_ATC_VERSION + << "[" << atc_version << "], " + << "soc_version" + << "[" << soc_version << "], " + << "framework_type" + << "[" << framework_type << "]." << std::endl; + + // resource info + iter = model_attr_map->find(ATTR_MODEL_MEMORY_SIZE); + auto memory_size = (iter != model_attr_map->end()) ? iter->second.i() : -1; + iter = model_attr_map->find(ATTR_MODEL_WEIGHT_SIZE); + auto weight_size = (iter != model_attr_map->end()) ? iter->second.i() : -1; + iter = model_attr_map->find(ATTR_MODEL_STREAM_NUM); + auto stream_num = (iter != model_attr_map->end()) ? iter->second.i() : -1; + iter = model_attr_map->find(ATTR_MODEL_EVENT_NUM); + auto event_num = (iter != model_attr_map->end()) ? iter->second.i() : -1; + std::cout << "resource info: " + << ATTR_MODEL_MEMORY_SIZE + << "[" << memory_size << " B], " + << ATTR_MODEL_WEIGHT_SIZE + << "[" << weight_size << " B], " + << ATTR_MODEL_STREAM_NUM + << "[" << stream_num << "], " + << ATTR_MODEL_EVENT_NUM + << "[" << event_num << "]." + << std::endl; + + // om info + iter = model_attr_map->find("om_info_list"); + if (iter == model_attr_map->end()) { + std::cout << "Display Model Info failed, attr \"om_info_list\" is not found in om, check the version is matched." + << std::endl; + std::cout << "============ Display Model Info end ============" << std::endl; + return; + } + auto list_size = iter->second.list().i_size(); + if (list_size == kOmInfoSize) { + std::cout << "om info: " + << "modeldef_size" + << "[" << iter->second.list().i(0) << " B], " + << "weight_data_size" + << "[" << iter->second.list().i(1) << " B], " + << "tbe_kernels_size" + << "[" << iter->second.list().i(2) << " B], " + << "cust_aicpu_kernel_store_size" + << "[" << iter->second.list().i(3) << " B], " + << "task_info_size" + << "[" << iter->second.list().i(4) << " B]." << std::endl; + } else { + std::cout << "Display Model Info error, please check!" << std::endl; + }; + + std::cout << "============ Display Model Info end ============" << std::endl; +} + +FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *json_file, bool is_covert_to_json) { GE_CHECK_NOTNULL(model_file); - GE_CHECK_NOTNULL(json_file); + if (is_covert_to_json) { + GE_CHECK_NOTNULL(json_file); + } ge::ModelData model; // Mode 2 does not need to verify the priority, and a default value of 0 is passed @@ -917,12 +987,16 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOmModelToJson(const char *model_file, con // De serialization bool flag = ReadProtoFromArray(ir_part.data, ir_part.size, &model_def); if (flag) { - GetGroupName(model_def); + if (is_covert_to_json) { + GetGroupName(model_def); - json j; - Pb2Json::Message2Json(model_def, kOmBlackFields, j, true); + json j; + Pb2Json::Message2Json(model_def, kOmBlackFields, j, true); - ret = ModelSaver::SaveJsonToFile(json_file, j); + ret = ModelSaver::SaveJsonToFile(json_file, j); + } else { + PrintModelInfo(&model_def); + } } else { ret = INTERNAL_ERROR; GELOGE(ret, "ReadProtoFromArray failed."); diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h index 8a10a9b0..d0f2105f 100644 --- a/inc/external/ge/ge_api_types.h +++ b/inc/external/ge/ge_api_types.h @@ -291,6 +291,9 @@ const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; // Configure model bank path const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path"; +// Configure display_model_info flag +const std::string DISPLAY_MODEL_INFO = "ge.display_model_info"; + // Configure op bank path const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path"; const std::string OP_BANK_UPDATE_FLAG = "ge.op_bank_update"; diff --git a/inc/framework/common/helper/model_helper.h b/inc/framework/common/helper/model_helper.h index bc0444bc..4a169dda 100644 --- a/inc/framework/common/helper/model_helper.h +++ b/inc/framework/common/helper/model_helper.h @@ -84,6 +84,7 @@ class ModelHelper { const uint8_t *data, size_t size, size_t model_index); Status SaveModelDef(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, Buffer &model_buffer, size_t model_index = 0); + Status SaveSizeToModelDef(const GeModelPtr &ge_model); Status SaveModelWeights(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, size_t model_index = 0); Status SaveModelTbeKernel(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, diff --git a/inc/framework/omg/omg.h b/inc/framework/omg/omg.h index e7ca05f7..62332b8d 100644 --- a/inc/framework/omg/omg.h +++ b/inc/framework/omg/omg.h @@ -73,7 +73,7 @@ Status ParseGraph(ge::Graph &graph, const std::map &atc_params, * @param [key] encrypted key * @return Status result code */ -Status ConvertOmModelToJson(const char *model_file, const char *json_file); +Status ConvertOm(const char *model_file, const char *json_file, bool is_covert_to_json); Status ConvertPbtxtToJson(const char *model_file, const char *json_file); /** @@ -103,6 +103,8 @@ void GetOutputNodesNameAndIndex(std::vector> &ou void UpdateOmgCtxWithParserCtx(); void UpdateParserCtxWithOmgCtx(); + +void PrintModelInfo(ge::proto::ModelDef *model_def); } // namespace ge namespace domi { diff --git a/inc/framework/omg/omg_inner_types.h b/inc/framework/omg/omg_inner_types.h index 454890aa..dab79053 100644 --- a/inc/framework/omg/omg_inner_types.h +++ b/inc/framework/omg/omg_inner_types.h @@ -46,7 +46,8 @@ enum RunMode { GEN_OM_MODEL = 0, // generate offline model file MODEL_TO_JSON = 1, // convert to JSON file ONLY_PRE_CHECK = 3, // only for pre-check - PBTXT_TO_JSON = 5 // pbtxt to json + PBTXT_TO_JSON = 5, // pbtxt to json + DISPLAY_OM_INFO = 6 // display model info }; /// diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index ebaf7708..6db99a45 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -49,6 +49,7 @@ include_directories(${GE_CODE_DIR}/metadef) include_directories(${GE_CODE_DIR}/metadef/graph) include_directories(${GE_CODE_DIR}/inc/external) include_directories(${GE_CODE_DIR}/metadef/inc/external) +include_directories(${GE_CODE_DIR}/parser) include_directories(${GE_CODE_DIR}/parser/parser) include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) include_directories(${GE_CODE_DIR}/metadef/inc/graph) @@ -302,6 +303,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/manager/graph_caching_allocator.cc" "${GE_CODE_DIR}/ge/graph/manager/rdma_pool_allocator.cc" "${GE_CODE_DIR}/ge/common/dump/dump_op.cc" + "${GE_CODE_DIR}/ge/common/model_saver.cc" "${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc" "${GE_CODE_DIR}/ge/common/ge/datatype_util.cc" "${GE_CODE_DIR}/metadef/register/ops_kernel_builder_registry.cc" @@ -309,6 +311,13 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/metadef/graph/utils/tuning_utils.cc" "${GE_CODE_DIR}/metadef/register/op_tiling_registry.cpp" "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" + "${GE_CODE_DIR}/parser/parser/common/pre_checker.cc" + "${GE_CODE_DIR}/parser/parser/common/convert/pb2json.cc" + "${GE_CODE_DIR}/parser/parser/common/parser_factory.cc" + "${GE_CODE_DIR}/parser/parser/common/model_saver.cc" + "${GE_CODE_DIR}/parser/parser/common/parser_types.cc" + "${GE_CODE_DIR}/parser/parser/common/parser_inner_ctx.cc" + "${GE_CODE_DIR}/ge/session/omg.cc" ) set(COMMON_FORMAT_SRC_FILES @@ -398,7 +407,6 @@ set(DISTINCT_GRAPH_LOAD_SRC_FILES "${GE_CODE_DIR}/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" "${GE_CODE_DIR}/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "${GE_CODE_DIR}/ge/model/ge_model.cc" - "${GE_CODE_DIR}/ge/common/helper/model_helper.cc" "${GE_CODE_DIR}/ge/common/helper/om_file_helper.cc" "${GE_CODE_DIR}/ge/common/debug/memory_dumper.cc" "${GE_CODE_DIR}/ge/executor/ge_executor.cc" @@ -429,7 +437,6 @@ set(GRAPH_BUILD_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/build/memory/hybrid_mem_assigner.cc" "${GE_CODE_DIR}/ge/graph/build/memory/max_block_mem_assigner.cc" "${GE_CODE_DIR}/ge/model/ge_model.cc" - "${GE_CODE_DIR}/ge/common/helper/model_helper.cc" "${GE_CODE_DIR}/ge/common/helper/om_file_helper.cc" "${GE_CODE_DIR}/ge/common/tbe_kernel_store.cc" "${GE_CODE_DIR}/ge/common/thread_pool.cc" @@ -574,6 +581,7 @@ set(DISTINCT_GRAPH_LOAD_TEST_FILES "graph/load/memcpy_async_task_info_unittest.cc" #"graph/graph_load_unittest.cc" "graph/ge_executor_unittest.cc" + "graph/load/model_helper_unittest.cc" ) set(PASS_TEST_FILES @@ -679,6 +687,7 @@ set(MULTI_PARTS_TEST_FILES "graph/variable_accelerate_ctrl_unittest.cc" "graph/build/logical_stream_allocator_unittest.cc" "graph/build/mem_assigner_unittest.cc" + "session/omg_omg_unittest.cc" ) set(SINGLE_OP_TEST_FILES diff --git a/tests/ut/ge/graph/load/model_helper_unittest.cc b/tests/ut/ge/graph/load/model_helper_unittest.cc new file mode 100644 index 00000000..455285bf --- /dev/null +++ b/tests/ut/ge/graph/load/model_helper_unittest.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#define private public +#define protected public +#include "framework/common/helper/model_helper.h" +#include "ge/model/ge_model.h" +#undef private +#undef protected + +#include "proto/task.pb.h" + +using namespace std; + +namespace ge { +class UtestModelHelper : public testing::Test { + protected: + void SetUp() override {} + + void TearDown() override {} +}; + +TEST_F(UtestModelHelper, save_size_to_modeldef_failed) +{ + GeModelPtr ge_model = ge::MakeShared(); + ModelHelper model_helper; + EXPECT_EQ(ACL_ERROR_GE_MEMORY_ALLOCATION, model_helper.SaveSizeToModelDef(ge_model)); +} + +TEST_F(UtestModelHelper, save_size_to_modeldef) +{ + GeModelPtr ge_model = ge::MakeShared(); + std::shared_ptr task = ge::MakeShared(); + ge_model->SetModelTaskDef(task); + ModelHelper model_helper; + EXPECT_EQ(SUCCESS, model_helper.SaveSizeToModelDef(ge_model)); +} +} // namespace ge diff --git a/tests/ut/ge/session/omg_omg_unittest.cc b/tests/ut/ge/session/omg_omg_unittest.cc new file mode 100644 index 00000000..b9c7f1ec --- /dev/null +++ b/tests/ut/ge/session/omg_omg_unittest.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "common/ge/ge_util.h" +#include "proto/ge_ir.pb.h" +#include "inc/framework/omg/omg.h" + + +using namespace std; + +namespace ge { +class UtestOmg : public testing::Test { + protected: + void SetUp() override {} + + void TearDown() override {} +}; + +TEST_F(UtestOmg, display_model_info_failed) { + ge::proto::ModelDef model_def; + PrintModelInfo(&model_def); +} + +TEST_F(UtestOmg, display_model_info_success) { + ge::proto::ModelDef model_def; + auto attrs = model_def.mutable_attr(); + ge::proto::AttrDef *attr_def_soc = &(*attrs)["soc_version"]; + attr_def_soc->set_s("Ascend310"); + ge::proto::AttrDef *attr_def = &(*attrs)["om_info_list"]; + attr_def->mutable_list()->add_i(1); + attr_def->mutable_list()->add_i(2); + attr_def->mutable_list()->add_i(3); + attr_def->mutable_list()->add_i(4); + attr_def->mutable_list()->add_i(5); + PrintModelInfo(&model_def); +} +} // namespace ge