Browse Source

!1120 Feature:display model info

From: @wangwenhua1
Reviewed-by: @xchu42,@wqtshg
Signed-off-by: @wqtshg
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
45d6a09b12
4 changed files with 24 additions and 34 deletions
  1. +0
    -13
      ge/common/helper/model_helper.cc
  2. +8
    -8
      ge/session/omg.cc
  3. +14
    -10
      inc/framework/omg/omg.h
  4. +2
    -3
      tests/ut/ge/session/omg_omg_unittest.cc

+ 0
- 13
ge/common/helper/model_helper.cc View File

@@ -78,19 +78,6 @@ Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_fil

Status ModelHelper::SaveSizeToModelDef(const GeModelPtr &ge_model) {
vector<int64_t> om_info;
ModelPtr model_tmp = ge::MakeShared<ge::Model>(ge_model->GetName(), ge_model->GetPlatformVersion());
if (model_tmp == nullptr) {
GELOGE(FAILED, "Create Model %s Ptr failed", ge_model->GetName().c_str());
return FAILED;
}
model_tmp->SetGraph(ge_model->GetGraph());
model_tmp->SetVersion(ge_model->GetVersion());
model_tmp->SetAttr(ge_model->MutableAttrMap());
ge::Buffer model_buffer;
(void)model_tmp->Save(model_buffer);
GELOGD("SaveSizeToModelDef modeldef_size is %zu", model_buffer.GetSize());
om_info.push_back(model_buffer.GetSize());

auto ge_model_weight = ge_model->GetWeight();
GELOGD("SaveSizeToModelDef weight_data_size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData());
om_info.push_back(ge_model_weight.GetSize());


+ 8
- 8
ge/session/omg.cc View File

@@ -71,7 +71,7 @@ const char *const kOutputTypeError = "The multiple out nodes set in output_type
const size_t kNodeNameIndex = 0;
const size_t kIndexStrIndex = 1;
const size_t kDTValueIndex = 2;
const size_t kOmInfoSize = 5;
const size_t kOmInfoSize = 4;
} // namespace

// When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored
@@ -828,7 +828,7 @@ void GetGroupName(ge::proto::ModelDef &model_def) {
});
}

FMK_FUNC_HOST_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def) {
FMK_FUNC_HOST_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def, uint32_t modeldef_size) {
std::cout << "============ Display Model Info start ============" << std::endl;

auto model_attr_map = model_def->mutable_attr();
@@ -879,15 +879,15 @@ FMK_FUNC_HOST_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def) {
if (list_size == kOmInfoSize) {
std::cout << "om info: "
<< "modeldef_size"
<< "[" << iter->second.list().i(0) << " B], "
<< "[" << modeldef_size << " B], "
<< "weight_data_size"
<< "[" << iter->second.list().i(1) << " B], "
<< "[" << iter->second.list().i(0) << " B], "
<< "tbe_kernels_size"
<< "[" << iter->second.list().i(2) << " B], "
<< "[" << iter->second.list().i(1) << " B], "
<< "cust_aicpu_kernel_store_size"
<< "[" << iter->second.list().i(3) << " B], "
<< "[" << iter->second.list().i(2) << " B], "
<< "task_info_size"
<< "[" << iter->second.list().i(4) << " B]." << std::endl;
<< "[" << iter->second.list().i(3) << " B]." << std::endl;
} else {
std::cout << "Display Model Info error, please check!" << std::endl;
};
@@ -955,7 +955,7 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOm(const char *model_file, const char *js

ret = ModelSaver::SaveJsonToFile(json_file, j);
} else {
PrintModelInfo(&model_def);
PrintModelInfo(&model_def, ir_part.size);
}
} else {
ret = INTERNAL_ERROR;


+ 14
- 10
inc/framework/omg/omg.h View File

@@ -43,8 +43,8 @@ namespace ge {
* @brief init omg context
* @return void
*/
GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format,
bool is_dynamic_input);
GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const string &input_format,
const string &net_format, bool is_dynamic_input);

/**
* @ingroup domi_omg
@@ -61,9 +61,10 @@ GE_FUNC_VISIBILITY Status InitDomiOmgContext(const string &input_shape, const st
* @param [in] atc_params multiply atc params
* @return Status result code
*/
GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params, const char *model_file,
const char *weights_file, domi::FrameworkType type, const char *op_conf = nullptr,
const char *target = nullptr, RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false);
GE_FUNC_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params,
const char *model_file, const char *weights_file, domi::FrameworkType type,
const char *op_conf = nullptr, const char *target = nullptr,
RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false);

/**
* @ingroup domi_omg
@@ -85,7 +86,8 @@ GE_FUNC_VISIBILITY Status ConvertPbtxtToJson(const char *model_file, const char
* @param [key] encrypted key
* @return Status result code
*/
GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, const char *json_file);
GE_FUNC_VISIBILITY Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file,
const char *json_file);

GE_FUNC_VISIBILITY void GetGroupName(ge::proto::ModelDef &model);

@@ -93,18 +95,20 @@ GE_FUNC_VISIBILITY void FindParserSo(const string &path, vector<string> &fileLis

GE_FUNC_VISIBILITY Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file);

GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format);
GE_FUNC_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type,
const std::string &output_format);

GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
GE_FUNC_VISIBILITY Status GetOutputLeaf(ge::NodePtr node,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);

GE_FUNC_VISIBILITY void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);
std::vector<std::string> &output_nodes_name);

GE_FUNC_VISIBILITY void UpdateOmgCtxWithParserCtx();

GE_FUNC_VISIBILITY void UpdateParserCtxWithOmgCtx();

GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def);
GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def, uint32_t modeldef_size);
} // namespace ge

namespace domi {


+ 2
- 3
tests/ut/ge/session/omg_omg_unittest.cc View File

@@ -33,7 +33,7 @@ class UtestOmg : public testing::Test {
TEST_F(UtestOmg, display_model_info_failed) {
ge::proto::ModelDef model_def;
PrintModelInfo(&model_def);
PrintModelInfo(&model_def, 1);
}
TEST_F(UtestOmg, display_model_info_success) {
@@ -46,7 +46,6 @@ TEST_F(UtestOmg, display_model_info_success) {
attr_def->mutable_list()->add_i(2);
attr_def->mutable_list()->add_i(3);
attr_def->mutable_list()->add_i(4);
attr_def->mutable_list()->add_i(5);
PrintModelInfo(&model_def);
PrintModelInfo(&model_def, 1);
}
} // namespace ge

Loading…
Cancel
Save