@@ -76,6 +76,48 @@ Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_fil | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status ModelHelper::SaveSizeToModelDef(const GeModelPtr &ge_model) { | |||||
vector<int64_t> om_info; | |||||
ModelPtr model_tmp = ge::MakeShared<ge::Model>(ge_model->GetName(), ge_model->GetPlatformVersion()); | |||||
if (model_tmp == nullptr) { | |||||
GELOGE(FAILED, "Create Model %s Ptr failed", ge_model->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
model_tmp->SetGraph(ge_model->GetGraph()); | |||||
model_tmp->SetVersion(ge_model->GetVersion()); | |||||
model_tmp->SetAttr(ge_model->MutableAttrMap()); | |||||
ge::Buffer model_buffer; | |||||
(void)model_tmp->Save(model_buffer); | |||||
GELOGD("SaveSizeToModelDef modeldef_size is %zu", model_buffer.GetSize()); | |||||
om_info.push_back(model_buffer.GetSize()); | |||||
auto ge_model_weight = ge_model->GetWeight(); | |||||
GELOGD("SaveSizeToModelDef weight_data_size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); | |||||
om_info.push_back(ge_model_weight.GetSize()); | |||||
TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); | |||||
GELOGD("SaveSizeToModelDef tbe_kernels_size is %zu", tbe_kernel_store.DataSize()); | |||||
om_info.push_back(tbe_kernel_store.DataSize()); | |||||
CustAICPUKernelStore cust_aicpu_kernel_store = ge_model->GetCustAICPUKernelStore(); | |||||
GELOGD("SaveSizeToModelDef cust aicpu kernels size is %zu", cust_aicpu_kernel_store.DataSize()); | |||||
om_info.push_back(cust_aicpu_kernel_store.DataSize()); | |||||
std::shared_ptr<ModelTaskDef> model_task_def = ge_model->GetModelTaskDefPtr(); | |||||
if (model_task_def == nullptr) { | |||||
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Create model task def ptr failed"); | |||||
return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
} | |||||
size_t partition_task_size = model_task_def->ByteSizeLong(); | |||||
GELOGD("SaveSizeToModelDef task_info_size is %zu", partition_task_size); | |||||
om_info.push_back(partition_task_size); | |||||
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(*(ge_model.get()), "om_info_list", om_info), | |||||
GELOGE(FAILED, "SetListInt of om_info_list failed."); | |||||
return FAILED); | |||||
return SUCCESS; | |||||
} | |||||
Status ModelHelper::SaveModelDef(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, | Status ModelHelper::SaveModelDef(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, | ||||
const GeModelPtr &ge_model, ge::Buffer &model_buffer, size_t model_index) { | const GeModelPtr &ge_model, ge::Buffer &model_buffer, size_t model_index) { | ||||
@@ -87,7 +129,11 @@ Status ModelHelper::SaveModelDef(std::shared_ptr<OmFileSaveHelper> &om_file_save | |||||
model_tmp->SetGraph(ge_model->GetGraph()); | model_tmp->SetGraph(ge_model->GetGraph()); | ||||
model_tmp->SetVersion(ge_model->GetVersion()); | model_tmp->SetVersion(ge_model->GetVersion()); | ||||
model_tmp->SetAttr(ge_model->MutableAttrMap()); | model_tmp->SetAttr(ge_model->MutableAttrMap()); | ||||
Status ret = SaveSizeToModelDef(ge_model); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "SaveSizeToModelDef failed"); | |||||
return ret; | |||||
} | |||||
(void)model_tmp->Save(model_buffer); | (void)model_tmp->Save(model_buffer); | ||||
GELOGD("MODEL_DEF size is %zu", model_buffer.GetSize()); | GELOGD("MODEL_DEF size is %zu", model_buffer.GetSize()); | ||||
@@ -336,6 +336,7 @@ class GeGenerator::Impl { | |||||
bool GetVersionFromPath(const std::string &file_path, std::string &version); | bool GetVersionFromPath(const std::string &file_path, std::string &version); | ||||
bool SetAtcVersionInfo(AttrHolder &obj); | bool SetAtcVersionInfo(AttrHolder &obj); | ||||
bool SetOppVersionInfo(AttrHolder &obj); | bool SetOppVersionInfo(AttrHolder &obj); | ||||
bool SetOmSystemInfo(AttrHolder &obj); | |||||
}; | }; | ||||
Status GeGenerator::Initialize(const map<string, string> &options) { | Status GeGenerator::Initialize(const map<string, string> &options) { | ||||
@@ -546,6 +547,32 @@ bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) { | |||||
return true; | return true; | ||||
} | } | ||||
bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) { | |||||
std::string soc_version; | |||||
(void)ge::GetContext().GetOption(ge::SOC_VERSION, soc_version); | |||||
GELOGI("SetOmSystemInfo soc_version: %s", soc_version.c_str()); | |||||
if (!ge::AttrUtils::SetStr(obj, "soc_version", soc_version)) { | |||||
GELOGW("SetStr of soc_version failed."); | |||||
return false; | |||||
} | |||||
// 0(Caffe) 1(MindSpore) 3(TensorFlow) 5(Onnx) | |||||
std::map<string, string> framework_type_to_string = { | |||||
{"0", "Caffe"}, | |||||
{"1", "MindSpore"}, | |||||
{"3", "TensorFlow"}, | |||||
{"5", "Onnx"} | |||||
}; | |||||
std::string framework_type; | |||||
(void)ge::GetContext().GetOption(ge::FRAMEWORK_TYPE, framework_type); | |||||
GELOGI("SetOmSystemInfo framework_type: %s", framework_type.c_str()); | |||||
if (!ge::AttrUtils::SetStr(obj, "framework_type", framework_type_to_string[framework_type.c_str()])) { | |||||
GELOGW("SetStr of framework_type failed."); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | ||||
ModelBufferData &model, bool is_offline) { | ModelBufferData &model, bool is_offline) { | ||||
rtContext_t ctx = nullptr; | rtContext_t ctx = nullptr; | ||||
@@ -842,6 +869,9 @@ Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootMo | |||||
if (!SetOppVersionInfo(*(model_root.get()))) { | if (!SetOppVersionInfo(*(model_root.get()))) { | ||||
GELOGW("SetPackageVersionInfo of ops failed!"); | GELOGW("SetPackageVersionInfo of ops failed!"); | ||||
} | } | ||||
if (!SetOmSystemInfo(*(model_root.get()))) { | |||||
GELOGW("SetOmsystemInfo failed!"); | |||||
} | |||||
ModelHelper model_helper; | ModelHelper model_helper; | ||||
model_helper.SetSaveMode(is_offline_); | model_helper.SetSaveMode(is_offline_); | ||||
ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape); | ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape); | ||||
@@ -206,6 +206,8 @@ DEFINE_string(mdl_bank_path, "", "Optional; model bank path"); | |||||
DEFINE_string(op_bank_path, "", "Optional; op bank path"); | DEFINE_string(op_bank_path, "", "Optional; op bank path"); | ||||
DEFINE_string(display_model_info, "0", "Optional; display model info"); | |||||
class GFlagUtils { | class GFlagUtils { | ||||
public: | public: | ||||
/** | /** | ||||
@@ -225,7 +227,8 @@ class GFlagUtils { | |||||
"===== Basic Functionality =====\n" | "===== Basic Functionality =====\n" | ||||
"[General]\n" | "[General]\n" | ||||
" --h/help Show this help message\n" | " --h/help Show this help message\n" | ||||
" --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format " | |||||
" --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format; " | |||||
"6: display model info" | |||||
"3: only pre-check; 5: convert ge dump txt file to JSON format\n" | "3: only pre-check; 5: convert ge dump txt file to JSON format\n" | ||||
"\n[Input]\n" | "\n[Input]\n" | ||||
" --model Model file\n" | " --model Model file\n" | ||||
@@ -313,7 +316,8 @@ class GFlagUtils { | |||||
" --op_compiler_cache_dir Set the save path of operator compilation cache files.\n" | " --op_compiler_cache_dir Set the save path of operator compilation cache files.\n" | ||||
"Default value: $HOME/atc_data\n" | "Default value: $HOME/atc_data\n" | ||||
" --op_compiler_cache_mode Set the operator compilation cache mode." | " --op_compiler_cache_mode Set the operator compilation cache mode." | ||||
"Options are disable(default), enable and force(force to refresh the cache)"); | |||||
"Options are disable(default), enable and force(force to refresh the cache)\n" | |||||
" --display_model_info enable for display model info; 0(default): close display, 1: open display"); | |||||
gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true); | gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true); | ||||
// Using gflags to analyze input parameters | // Using gflags to analyze input parameters | ||||
@@ -862,7 +866,7 @@ domi::Status GenerateInfershapeJson() { | |||||
static Status ConvertModelToJson(int fwk_type, const string &model_file, const string &json_file) { | static Status ConvertModelToJson(int fwk_type, const string &model_file, const string &json_file) { | ||||
Status ret = ge::SUCCESS; | Status ret = ge::SUCCESS; | ||||
if (fwk_type == -1) { | if (fwk_type == -1) { | ||||
ret = ge::ConvertOmModelToJson(model_file.c_str(), json_file.c_str()); | |||||
ret = ge::ConvertOm(model_file.c_str(), json_file.c_str(), true); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -1176,6 +1180,8 @@ domi::Status GenerateOmModel() { | |||||
options.insert(std::pair<string, string>(string(ge::MDL_BANK_PATH_FLAG), FLAGS_mdl_bank_path)); | options.insert(std::pair<string, string>(string(ge::MDL_BANK_PATH_FLAG), FLAGS_mdl_bank_path)); | ||||
options.insert(std::pair<string, string>(string(ge::OP_BANK_PATH_FLAG), FLAGS_op_bank_path)); | options.insert(std::pair<string, string>(string(ge::OP_BANK_PATH_FLAG), FLAGS_op_bank_path)); | ||||
options.insert(std::pair<string, string>(string(ge::DISPLAY_MODEL_INFO), FLAGS_display_model_info)); | |||||
// set enable scope fusion passes | // set enable scope fusion passes | ||||
SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes); | SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes); | ||||
// print atc option map | // print atc option map | ||||
@@ -1188,6 +1194,11 @@ domi::Status GenerateOmModel() { | |||||
return domi::FAILED; | return domi::FAILED; | ||||
} | } | ||||
if (FLAGS_display_model_info == "1") { | |||||
GELOGI("need to display model info."); | |||||
return ge::ConvertOm(FLAGS_output.c_str(), "", false); | |||||
} | |||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
} | } | ||||
@@ -1201,6 +1212,26 @@ domi::Status ConvertModelToJson() { | |||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
} | } | ||||
domi::Status DisplayModelInfo() { | |||||
// No model path passed in | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "", | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"}); | |||||
return ge::FAILED, | |||||
"Input parameter[--om]'s value is empty!!"); | |||||
// Check if the model path is valid | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"), | |||||
return ge::FAILED, | |||||
"model file path is invalid: %s.", FLAGS_om.c_str()); | |||||
if (FLAGS_framework == -1) { | |||||
return ge::ConvertOm(FLAGS_om.c_str(), "", false); | |||||
} | |||||
return ge::FAILED; | |||||
} | |||||
bool CheckRet(domi::Status ret) { | bool CheckRet(domi::Status ret) { | ||||
if (ret != domi::SUCCESS) { | if (ret != domi::SUCCESS) { | ||||
if (FLAGS_mode == ONLY_PRE_CHECK) { | if (FLAGS_mode == ONLY_PRE_CHECK) { | ||||
@@ -1344,6 +1375,9 @@ int main(int argc, char* argv[]) { | |||||
} else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) { | } else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) { | ||||
GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; | GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; | ||||
break, "ATC convert pbtxt to json execute failed!!"); | break, "ATC convert pbtxt to json execute failed!!"); | ||||
} else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) { | |||||
GE_CHK_BOOL_EXEC(DisplayModelInfo() == domi::SUCCESS, ret = domi::FAILED; | |||||
break, "ATC DisplayModelInfo failed!!"); | |||||
} else { | } else { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
"E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport}); | "E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport}); | ||||
@@ -71,6 +71,7 @@ const char *const kOutputTypeError = "The multiple out nodes set in output_type | |||||
const size_t kNodeNameIndex = 0; | const size_t kNodeNameIndex = 0; | ||||
const size_t kIndexStrIndex = 1; | const size_t kIndexStrIndex = 1; | ||||
const size_t kDTValueIndex = 2; | const size_t kDTValueIndex = 2; | ||||
const size_t kOmInfoSize = 5; | |||||
} // namespace | } // namespace | ||||
// When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored | // When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored | ||||
@@ -869,9 +870,78 @@ void GetGroupName(ge::proto::ModelDef &model_def) { | |||||
}); | }); | ||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY Status ConvertOmModelToJson(const char *model_file, const char *json_file) { | |||||
FMK_FUNC_HOST_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def) { | |||||
std::cout << "============ Display Model Info start ============" << std::endl; | |||||
auto model_attr_map = model_def->mutable_attr(); | |||||
// system info | |||||
auto iter = model_attr_map->find(ATTR_MODEL_ATC_VERSION); | |||||
auto atc_version = (iter != model_attr_map->end()) ? iter->second.s() : ""; | |||||
iter = model_attr_map->find("soc_version"); | |||||
auto soc_version = (iter != model_attr_map->end()) ? iter->second.s() : ""; | |||||
iter = model_attr_map->find("framework_type"); | |||||
auto framework_type = (iter != model_attr_map->end()) ? iter->second.s() : ""; | |||||
std::cout << "system info: " | |||||
<< ATTR_MODEL_ATC_VERSION | |||||
<< "[" << atc_version << "], " | |||||
<< "soc_version" | |||||
<< "[" << soc_version << "], " | |||||
<< "framework_type" | |||||
<< "[" << framework_type << "]." << std::endl; | |||||
// resource info | |||||
iter = model_attr_map->find(ATTR_MODEL_MEMORY_SIZE); | |||||
auto memory_size = (iter != model_attr_map->end()) ? iter->second.i() : -1; | |||||
iter = model_attr_map->find(ATTR_MODEL_WEIGHT_SIZE); | |||||
auto weight_size = (iter != model_attr_map->end()) ? iter->second.i() : -1; | |||||
iter = model_attr_map->find(ATTR_MODEL_STREAM_NUM); | |||||
auto stream_num = (iter != model_attr_map->end()) ? iter->second.i() : -1; | |||||
iter = model_attr_map->find(ATTR_MODEL_EVENT_NUM); | |||||
auto event_num = (iter != model_attr_map->end()) ? iter->second.i() : -1; | |||||
std::cout << "resource info: " | |||||
<< ATTR_MODEL_MEMORY_SIZE | |||||
<< "[" << memory_size << " B], " | |||||
<< ATTR_MODEL_WEIGHT_SIZE | |||||
<< "[" << weight_size << " B], " | |||||
<< ATTR_MODEL_STREAM_NUM | |||||
<< "[" << stream_num << "], " | |||||
<< ATTR_MODEL_EVENT_NUM | |||||
<< "[" << event_num << "]." | |||||
<< std::endl; | |||||
// om info | |||||
iter = model_attr_map->find("om_info_list"); | |||||
if (iter == model_attr_map->end()) { | |||||
std::cout << "Display Model Info failed, attr \"om_info_list\" is not found in om, check the version is matched." | |||||
<< std::endl; | |||||
std::cout << "============ Display Model Info end ============" << std::endl; | |||||
return; | |||||
} | |||||
auto list_size = iter->second.list().i_size(); | |||||
if (list_size == kOmInfoSize) { | |||||
std::cout << "om info: " | |||||
<< "modeldef_size" | |||||
<< "[" << iter->second.list().i(0) << " B], " | |||||
<< "weight_data_size" | |||||
<< "[" << iter->second.list().i(1) << " B], " | |||||
<< "tbe_kernels_size" | |||||
<< "[" << iter->second.list().i(2) << " B], " | |||||
<< "cust_aicpu_kernel_store_size" | |||||
<< "[" << iter->second.list().i(3) << " B], " | |||||
<< "task_info_size" | |||||
<< "[" << iter->second.list().i(4) << " B]." << std::endl; | |||||
} else { | |||||
std::cout << "Display Model Info error, please check!" << std::endl; | |||||
}; | |||||
std::cout << "============ Display Model Info end ============" << std::endl; | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *json_file, bool is_covert_to_json) { | |||||
GE_CHECK_NOTNULL(model_file); | GE_CHECK_NOTNULL(model_file); | ||||
GE_CHECK_NOTNULL(json_file); | |||||
if (is_covert_to_json) { | |||||
GE_CHECK_NOTNULL(json_file); | |||||
} | |||||
ge::ModelData model; | ge::ModelData model; | ||||
// Mode 2 does not need to verify the priority, and a default value of 0 is passed | // Mode 2 does not need to verify the priority, and a default value of 0 is passed | ||||
@@ -917,12 +987,16 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOmModelToJson(const char *model_file, con | |||||
// De serialization | // De serialization | ||||
bool flag = ReadProtoFromArray(ir_part.data, ir_part.size, &model_def); | bool flag = ReadProtoFromArray(ir_part.data, ir_part.size, &model_def); | ||||
if (flag) { | if (flag) { | ||||
GetGroupName(model_def); | |||||
if (is_covert_to_json) { | |||||
GetGroupName(model_def); | |||||
json j; | |||||
Pb2Json::Message2Json(model_def, kOmBlackFields, j, true); | |||||
json j; | |||||
Pb2Json::Message2Json(model_def, kOmBlackFields, j, true); | |||||
ret = ModelSaver::SaveJsonToFile(json_file, j); | |||||
ret = ModelSaver::SaveJsonToFile(json_file, j); | |||||
} else { | |||||
PrintModelInfo(&model_def); | |||||
} | |||||
} else { | } else { | ||||
ret = INTERNAL_ERROR; | ret = INTERNAL_ERROR; | ||||
GELOGE(ret, "ReadProtoFromArray failed."); | GELOGE(ret, "ReadProtoFromArray failed."); | ||||
@@ -291,6 +291,9 @@ const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; | |||||
// Configure model bank path | // Configure model bank path | ||||
const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path"; | const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path"; | ||||
// Configure display_model_info flag | |||||
const std::string DISPLAY_MODEL_INFO = "ge.display_model_info"; | |||||
// Configure op bank path | // Configure op bank path | ||||
const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path"; | const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path"; | ||||
const std::string OP_BANK_UPDATE_FLAG = "ge.op_bank_update"; | const std::string OP_BANK_UPDATE_FLAG = "ge.op_bank_update"; | ||||
@@ -84,6 +84,7 @@ class ModelHelper { | |||||
const uint8_t *data, size_t size, size_t model_index); | const uint8_t *data, size_t size, size_t model_index); | ||||
Status SaveModelDef(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model, | Status SaveModelDef(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model, | ||||
Buffer &model_buffer, size_t model_index = 0); | Buffer &model_buffer, size_t model_index = 0); | ||||
Status SaveSizeToModelDef(const GeModelPtr &ge_model); | |||||
Status SaveModelWeights(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model, | Status SaveModelWeights(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model, | ||||
size_t model_index = 0); | size_t model_index = 0); | ||||
Status SaveModelTbeKernel(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model, | Status SaveModelTbeKernel(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model, | ||||
@@ -73,7 +73,7 @@ Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params, | |||||
* @param [key] encrypted key | * @param [key] encrypted key | ||||
* @return Status result code | * @return Status result code | ||||
*/ | */ | ||||
Status ConvertOmModelToJson(const char *model_file, const char *json_file); | |||||
Status ConvertOm(const char *model_file, const char *json_file, bool is_covert_to_json); | |||||
Status ConvertPbtxtToJson(const char *model_file, const char *json_file); | Status ConvertPbtxtToJson(const char *model_file, const char *json_file); | ||||
/** | /** | ||||
@@ -103,6 +103,8 @@ void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &ou | |||||
void UpdateOmgCtxWithParserCtx(); | void UpdateOmgCtxWithParserCtx(); | ||||
void UpdateParserCtxWithOmgCtx(); | void UpdateParserCtxWithOmgCtx(); | ||||
void PrintModelInfo(ge::proto::ModelDef *model_def); | |||||
} // namespace ge | } // namespace ge | ||||
namespace domi { | namespace domi { | ||||
@@ -46,7 +46,8 @@ enum RunMode { | |||||
GEN_OM_MODEL = 0, // generate offline model file | GEN_OM_MODEL = 0, // generate offline model file | ||||
MODEL_TO_JSON = 1, // convert to JSON file | MODEL_TO_JSON = 1, // convert to JSON file | ||||
ONLY_PRE_CHECK = 3, // only for pre-check | ONLY_PRE_CHECK = 3, // only for pre-check | ||||
PBTXT_TO_JSON = 5 // pbtxt to json | |||||
PBTXT_TO_JSON = 5, // pbtxt to json | |||||
DISPLAY_OM_INFO = 6 // display model info | |||||
}; | }; | ||||
/// | /// | ||||
@@ -49,6 +49,7 @@ include_directories(${GE_CODE_DIR}/metadef) | |||||
include_directories(${GE_CODE_DIR}/metadef/graph) | include_directories(${GE_CODE_DIR}/metadef/graph) | ||||
include_directories(${GE_CODE_DIR}/inc/external) | include_directories(${GE_CODE_DIR}/inc/external) | ||||
include_directories(${GE_CODE_DIR}/metadef/inc/external) | include_directories(${GE_CODE_DIR}/metadef/inc/external) | ||||
include_directories(${GE_CODE_DIR}/parser) | |||||
include_directories(${GE_CODE_DIR}/parser/parser) | include_directories(${GE_CODE_DIR}/parser/parser) | ||||
include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) | include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) | ||||
include_directories(${GE_CODE_DIR}/metadef/inc/graph) | include_directories(${GE_CODE_DIR}/metadef/inc/graph) | ||||
@@ -302,6 +303,7 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/manager/graph_caching_allocator.cc" | "${GE_CODE_DIR}/ge/graph/manager/graph_caching_allocator.cc" | ||||
"${GE_CODE_DIR}/ge/graph/manager/rdma_pool_allocator.cc" | "${GE_CODE_DIR}/ge/graph/manager/rdma_pool_allocator.cc" | ||||
"${GE_CODE_DIR}/ge/common/dump/dump_op.cc" | "${GE_CODE_DIR}/ge/common/dump/dump_op.cc" | ||||
"${GE_CODE_DIR}/ge/common/model_saver.cc" | |||||
"${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc" | "${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc" | ||||
"${GE_CODE_DIR}/ge/common/ge/datatype_util.cc" | "${GE_CODE_DIR}/ge/common/ge/datatype_util.cc" | ||||
"${GE_CODE_DIR}/metadef/register/ops_kernel_builder_registry.cc" | "${GE_CODE_DIR}/metadef/register/ops_kernel_builder_registry.cc" | ||||
@@ -309,6 +311,13 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/metadef/graph/utils/tuning_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/tuning_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/register/op_tiling_registry.cpp" | "${GE_CODE_DIR}/metadef/register/op_tiling_registry.cpp" | ||||
"${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" | "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" | ||||
"${GE_CODE_DIR}/parser/parser/common/pre_checker.cc" | |||||
"${GE_CODE_DIR}/parser/parser/common/convert/pb2json.cc" | |||||
"${GE_CODE_DIR}/parser/parser/common/parser_factory.cc" | |||||
"${GE_CODE_DIR}/parser/parser/common/model_saver.cc" | |||||
"${GE_CODE_DIR}/parser/parser/common/parser_types.cc" | |||||
"${GE_CODE_DIR}/parser/parser/common/parser_inner_ctx.cc" | |||||
"${GE_CODE_DIR}/ge/session/omg.cc" | |||||
) | ) | ||||
set(COMMON_FORMAT_SRC_FILES | set(COMMON_FORMAT_SRC_FILES | ||||
@@ -672,6 +681,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
"graph/variable_accelerate_ctrl_unittest.cc" | "graph/variable_accelerate_ctrl_unittest.cc" | ||||
"graph/build/logical_stream_allocator_unittest.cc" | "graph/build/logical_stream_allocator_unittest.cc" | ||||
"graph/build/mem_assigner_unittest.cc" | "graph/build/mem_assigner_unittest.cc" | ||||
"session/omg_omg_unittest.cc" | |||||
) | ) | ||||
set(SINGLE_OP_TEST_FILES | set(SINGLE_OP_TEST_FILES | ||||
@@ -0,0 +1,38 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "common/ge/ge_util.h" | |||||
#include "proto/ge_ir.pb.h" | |||||
#include "inc/framework/omg/omg.h" | |||||
using namespace ge; | |||||
using namespace std; | |||||
class UTEST_omg_omg : public testing::Test { | |||||
protected: | |||||
void SetUp() override {} | |||||
void TearDown() override {} | |||||
}; | |||||
TEST_F(UTEST_omg_omg, display_model_info_success) | |||||
{ | |||||
ge::proto::ModelDef model_def; | |||||
PrintModelInfo(&model_def); | |||||
} |