From: @wangwenhua1 Reviewed-by: @xchu42,@wqtshg Signed-off-by: @wqtshgtags/v1.2.0
| @@ -78,19 +78,6 @@ Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_fil | |||||
| Status ModelHelper::SaveSizeToModelDef(const GeModelPtr &ge_model) { | Status ModelHelper::SaveSizeToModelDef(const GeModelPtr &ge_model) { | ||||
| vector<int64_t> om_info; | vector<int64_t> om_info; | ||||
| ModelPtr model_tmp = ge::MakeShared<ge::Model>(ge_model->GetName(), ge_model->GetPlatformVersion()); | |||||
| if (model_tmp == nullptr) { | |||||
| GELOGE(FAILED, "Create Model %s Ptr failed", ge_model->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| model_tmp->SetGraph(ge_model->GetGraph()); | |||||
| model_tmp->SetVersion(ge_model->GetVersion()); | |||||
| model_tmp->SetAttr(ge_model->MutableAttrMap()); | |||||
| ge::Buffer model_buffer; | |||||
| (void)model_tmp->Save(model_buffer); | |||||
| GELOGD("SaveSizeToModelDef modeldef_size is %zu", model_buffer.GetSize()); | |||||
| om_info.push_back(model_buffer.GetSize()); | |||||
| auto ge_model_weight = ge_model->GetWeight(); | auto ge_model_weight = ge_model->GetWeight(); | ||||
| GELOGD("SaveSizeToModelDef weight_data_size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); | GELOGD("SaveSizeToModelDef weight_data_size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); | ||||
| om_info.push_back(ge_model_weight.GetSize()); | om_info.push_back(ge_model_weight.GetSize()); | ||||
| @@ -71,7 +71,7 @@ const char *const kOutputTypeError = "The multiple out nodes set in output_type | |||||
| const size_t kNodeNameIndex = 0; | const size_t kNodeNameIndex = 0; | ||||
| const size_t kIndexStrIndex = 1; | const size_t kIndexStrIndex = 1; | ||||
| const size_t kDTValueIndex = 2; | const size_t kDTValueIndex = 2; | ||||
| const size_t kOmInfoSize = 5; | |||||
| const size_t kOmInfoSize = 4; | |||||
| } // namespace | } // namespace | ||||
| // When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored | // When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored | ||||
| @@ -828,7 +828,7 @@ void GetGroupName(ge::proto::ModelDef &model_def) { | |||||
| }); | }); | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def) { | |||||
| FMK_FUNC_HOST_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def, uint32_t modeldef_size) { | |||||
| std::cout << "============ Display Model Info start ============" << std::endl; | std::cout << "============ Display Model Info start ============" << std::endl; | ||||
| auto model_attr_map = model_def->mutable_attr(); | auto model_attr_map = model_def->mutable_attr(); | ||||
| @@ -879,15 +879,15 @@ FMK_FUNC_HOST_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def) { | |||||
| if (list_size == kOmInfoSize) { | if (list_size == kOmInfoSize) { | ||||
| std::cout << "om info: " | std::cout << "om info: " | ||||
| << "modeldef_size" | << "modeldef_size" | ||||
| << "[" << iter->second.list().i(0) << " B], " | |||||
| << "[" << modeldef_size << " B], " | |||||
| << "weight_data_size" | << "weight_data_size" | ||||
| << "[" << iter->second.list().i(1) << " B], " | |||||
| << "[" << iter->second.list().i(0) << " B], " | |||||
| << "tbe_kernels_size" | << "tbe_kernels_size" | ||||
| << "[" << iter->second.list().i(2) << " B], " | |||||
| << "[" << iter->second.list().i(1) << " B], " | |||||
| << "cust_aicpu_kernel_store_size" | << "cust_aicpu_kernel_store_size" | ||||
| << "[" << iter->second.list().i(3) << " B], " | |||||
| << "[" << iter->second.list().i(2) << " B], " | |||||
| << "task_info_size" | << "task_info_size" | ||||
| << "[" << iter->second.list().i(4) << " B]." << std::endl; | |||||
| << "[" << iter->second.list().i(3) << " B]." << std::endl; | |||||
| } else { | } else { | ||||
| std::cout << "Display Model Info error, please check!" << std::endl; | std::cout << "Display Model Info error, please check!" << std::endl; | ||||
| }; | }; | ||||
| @@ -955,7 +955,7 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *js | |||||
| ret = ModelSaver::SaveJsonToFile(json_file, j); | ret = ModelSaver::SaveJsonToFile(json_file, j); | ||||
| } else { | } else { | ||||
| PrintModelInfo(&model_def); | |||||
| PrintModelInfo(&model_def, ir_part.size); | |||||
| } | } | ||||
| } else { | } else { | ||||
| ret = INTERNAL_ERROR; | ret = INTERNAL_ERROR; | ||||
| @@ -43,8 +43,8 @@ namespace ge { | |||||
| * @brief init omg context | * @brief init omg context | ||||
| * @return void | * @return void | ||||
| */ | */ | ||||
| GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, | |||||
| bool is_dynamic_input); | |||||
| GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format, | |||||
| const string &net_format, bool is_dynamic_input); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -61,9 +61,10 @@ GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const st | |||||
| * @param [in] atc_params multiply atc params | * @param [in] atc_params multiply atc params | ||||
| * @return Status result code | * @return Status result code | ||||
| */ | */ | ||||
| GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params, const char *model_file, | |||||
| const char *weights_file, domi::FrameworkType type, const char *op_conf = nullptr, | |||||
| const char *target = nullptr, RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false); | |||||
| GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params, | |||||
| const char *model_file, const char *weights_file, domi::FrameworkType type, | |||||
| const char *op_conf = nullptr, const char *target = nullptr, | |||||
| RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -85,7 +86,8 @@ GE_FUNC_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const char | |||||
| * @param [key] encrypted key | * @param [key] encrypted key | ||||
| * @return Status result code | * @return Status result code | ||||
| */ | */ | ||||
| GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, const char *json_file); | |||||
| GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, | |||||
| const char *json_file); | |||||
| GE_FUNC_VISIBILITY void GetGroupName(ge::proto::ModelDef &model); | GE_FUNC_VISIBILITY void GetGroupName(ge::proto::ModelDef &model); | ||||
| @@ -93,18 +95,20 @@ GE_FUNC_VISIBILITY void FindParserSo(const string &path, vector<string> &fileLis | |||||
| GE_FUNC_VISIBILITY Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); | GE_FUNC_VISIBILITY Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); | ||||
| GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); | |||||
| GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, | |||||
| const std::string &output_format); | |||||
| GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||||
| GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node, | |||||
| std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||||
| GE_FUNC_VISIBILITY void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | GE_FUNC_VISIBILITY void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | ||||
| std::vector<std::string> &output_nodes_name); | |||||
| std::vector<std::string> &output_nodes_name); | |||||
| GE_FUNC_VISIBILITY void UpdateOmgCtxWithParserCtx(); | GE_FUNC_VISIBILITY void UpdateOmgCtxWithParserCtx(); | ||||
| GE_FUNC_VISIBILITY void UpdateParserCtxWithOmgCtx(); | GE_FUNC_VISIBILITY void UpdateParserCtxWithOmgCtx(); | ||||
| GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def); | |||||
| GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def, uint32_t modeldef_size); | |||||
| } // namespace ge | } // namespace ge | ||||
| namespace domi { | namespace domi { | ||||
| @@ -33,7 +33,7 @@ class UtestOmg : public testing::Test { | |||||
| TEST_F(UtestOmg, display_model_info_failed) { | TEST_F(UtestOmg, display_model_info_failed) { | ||||
| ge::proto::ModelDef model_def; | ge::proto::ModelDef model_def; | ||||
| PrintModelInfo(&model_def); | |||||
| PrintModelInfo(&model_def, 1); | |||||
| } | } | ||||
| TEST_F(UtestOmg, display_model_info_success) { | TEST_F(UtestOmg, display_model_info_success) { | ||||
| @@ -46,7 +46,6 @@ TEST_F(UtestOmg, display_model_info_success) { | |||||
| attr_def->mutable_list()->add_i(2); | attr_def->mutable_list()->add_i(2); | ||||
| attr_def->mutable_list()->add_i(3); | attr_def->mutable_list()->add_i(3); | ||||
| attr_def->mutable_list()->add_i(4); | attr_def->mutable_list()->add_i(4); | ||||
| attr_def->mutable_list()->add_i(5); | |||||
| PrintModelInfo(&model_def); | |||||
| PrintModelInfo(&model_def, 1); | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||