Browse Source

!1179 provide interface for atc

From: @jiming6
Reviewed-by: @xchu42,@wqtshg
Signed-off-by:
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
44eb9f9926
6 changed files with 206 additions and 1 deletions
  1. +94
    -0
      ge/common/helper/model_helper.cc
  2. +1
    -1
      ge/graph/optimize/common/params.h
  3. +21
    -0
      ge/init/gelib.cc
  4. +36
    -0
      inc/framework/omg/ge_init.h
  5. +35
    -0
      inc/framework/omg/model_tool.h
  6. +19
    -0
      tests/ut/ge/graph/load/model_helper_unittest.cc

+ 94
- 0
ge/common/helper/model_helper.cc View File

@@ -17,6 +17,7 @@
#include "framework/common/helper/model_helper.h" #include "framework/common/helper/model_helper.h"


#include "common/model_parser/model_parser.h" #include "common/model_parser/model_parser.h"
#include "framework/omg/model_tool.h"
#include "framework/omg/version.h" #include "framework/omg/version.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h" #include "graph/utils/graph_utils.h"
@@ -873,4 +874,97 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetModelNam
GE_CHK_BOOL_EXEC_WARN(!model_name.empty(), return FAILED, "Get model_name failed, check params --output"); GE_CHK_BOOL_EXEC_WARN(!model_name.empty(), return FAILED, "Get model_name failed, check params --output");
return SUCCESS; return SUCCESS;
} }

Status ModelTool::GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef &model_def, uint32_t &modeldef_size) {
GE_CHECK_NOTNULL(model_file);
ge::ModelData model;
int32_t priority = 0;

Status ret = ModelParserBase::LoadFromFile(model_file, "", priority, model);
if (ret != SUCCESS) {
GELOGE(ret, "LoadFromFile failed.");
return ret;
}
std::function<void()> callback = [&]() {
if (model.model_data != nullptr) {
delete[] reinterpret_cast<char *>(model.model_data);
model.model_data = nullptr;
}
};
GE_MAKE_GUARD(release, callback);

uint8_t *model_data = nullptr;
uint32_t model_len = 0;
ret = ModelParserBase::ParseModelContent(model, model_data, model_len);
if (ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E10003",
{"parameter", "value", "reason"}, {"om", model_file, "invalid om file"});
GELOGE(ACL_ERROR_GE_PARAM_INVALID,
"ParseModelContent failed because of invalid om file. Please check --om param.");
return ret;
}

OmFileLoadHelper om_load_helper;
ret = om_load_helper.Init(model_data, model_len);
if (ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Om file init failed"});
GELOGE(ge::FAILED, "Om file init failed.");
return ret;
}

ModelPartition ir_part;
ret = om_load_helper.GetModelPartition(MODEL_DEF, ir_part);
if (ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Get model part failed"});
GELOGE(ge::FAILED, "Get model part failed.");
return ret;
}

bool flag = ReadProtoFromArray(ir_part.data, ir_part.size, &model_def);
if (!flag) {
ret = INTERNAL_ERROR;
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"ReadProtoFromArray failed"});
GELOGE(ret, "ReadProtoFromArray failed.");
return ret;
}
modeldef_size = ir_part.size;
return ret;
}

Status ModelTool::GetModelInfoFromPbtxt(const char *model_file, ge::proto::ModelDef &model_def) {
GE_CHECK_NOTNULL(model_file);
ge::ModelData model;
int32_t priority = 0;

Status ret = ModelParserBase::LoadFromFile(model_file, "", priority, model);
auto free_model_data = [](void **ptr) -> void {
if (ptr != nullptr && *ptr != nullptr) {
delete[] reinterpret_cast<char *>(*ptr);
*ptr = nullptr;
}
};
if (ret != SUCCESS) {
free_model_data(&model.model_data);
GELOGE(ret, "LoadFromFile failed.");
return ret;
}

try {
bool flag = google::protobuf::TextFormat::ParseFromString(reinterpret_cast<char *>(model.model_data), &model_def);
if (!flag) {
free_model_data(&model.model_data);
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"ParseFromString failed"});
GELOGE(FAILED, "ParseFromString failed.");
return FAILED;
}
free_model_data(&model.model_data);
return SUCCESS;
} catch (google::protobuf::FatalException &e) {
free_model_data(&model.model_data);
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"ParseFromString failed, exception message["
+ std::string(e.what()) + "]"});
GELOGE(FAILED, "ParseFromString failed. exception message : %s", e.what());
return FAILED;
}
}
} // namespace ge } // namespace ge

+ 1
- 1
ge/graph/optimize/common/params.h View File

@@ -55,7 +55,7 @@ class Params : public Singleton<Params> {
Params() : target_("MINI") {} Params() : target_("MINI") {}


string target_; string target_;
uint8_t target_8bit_ = 0;
uint8_t target_8bit_ = TARGET_TYPE_MINI_8BIT;
}; };
} // namespace ge } // namespace ge




+ 21
- 0
ge/init/gelib.cc View File

@@ -31,6 +31,7 @@
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h" #include "framework/common/debug/log.h"
#include "framework/common/util.h" #include "framework/common/util.h"
#include "framework/omg/ge_init.h"
#include "analyzer/analyzer.h" #include "analyzer/analyzer.h"
#include "ge/ge_api_types.h" #include "ge/ge_api_types.h"
#include "ge_local_engine/engine/host_cpu_engine.h" #include "ge_local_engine/engine/host_cpu_engine.h"
@@ -531,4 +532,24 @@ void GELib::RollbackInit() {
HostMemManager::Instance().Finalize(); HostMemManager::Instance().Finalize();
VarManagerPool::Instance().Destory(); VarManagerPool::Instance().Destory();
} }

Status GEInit::Initialize(const map<string, string> &options) {
Status ret = SUCCESS;
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
ret = GELib::Initialize(options);
}
return ret;
}

Status GEInit::Finalize() {
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr != nullptr) {
return instance_ptr->Finalize();
}
}

string GEInit::GetPath() {
return GELib::GetPath();
}
} // namespace ge } // namespace ge

+ 36
- 0
inc/framework/omg/ge_init.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_OMG_GE_INIT_H_
#define INC_FRAMEWORK_OMG_GE_INIT_H_
#include <map>
#include <string>
#include "common/ge_inner_error_codes.h"

namespace ge {
class GE_FUNC_VISIBILITY GEInit {
public:
// GE Environment Initialize, return Status: SUCCESS,FAILED
static Status Initialize(const std::map<std::string, std::string> &options);

static std::string GetPath();

// GE Environment Finalize, return Status: SUCCESS,FAILED
static Status Finalize();
};
} // namespace ge

#endif // INC_FRAMEWORK_OMG_GE_INIT_H_

+ 35
- 0
inc/framework/omg/model_tool.h View File

@@ -0,0 +1,35 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_FRAMEWORK_OMG_MODEL_TOOL_H_
#define INC_FRAMEWORK_OMG_MODEL_TOOL_H_

#include <memory>
#include <string>

#include "framework/common/debug/ge_log.h"
#include "proto/ge_ir.pb.h"

namespace ge {
class GE_FUNC_VISIBILITY ModelTool {
public:
static Status GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef &model_def, uint32_t &modeldef_size);

static Status GetModelInfoFromPbtxt(const char *model_file, ge::proto::ModelDef &model_def);
};
} // namespace ge

#endif // INC_FRAMEWORK_OMG_MODEL_TOOL_H_

+ 19
- 0
tests/ut/ge/graph/load/model_helper_unittest.cc View File

@@ -18,6 +18,8 @@
#define private public #define private public
#define protected public #define protected public
#include "framework/common/helper/model_helper.h" #include "framework/common/helper/model_helper.h"
#include "framework/omg/model_tool.h"
#include "framework/omg/ge_init.h"
#include "ge/model/ge_model.h" #include "ge/model/ge_model.h"
#undef private #undef private
#undef protected #undef protected
@@ -49,4 +51,21 @@ TEST_F(UtestModelHelper, save_size_to_modeldef)
ModelHelper model_helper; ModelHelper model_helper;
EXPECT_EQ(SUCCESS, model_helper.SaveSizeToModelDef(ge_model)); EXPECT_EQ(SUCCESS, model_helper.SaveSizeToModelDef(ge_model));
} }
TEST_F(UtestModelHelper, atc_test)
{
ge::proto::ModelDef model_def;
uint32_t modeldef_size = 0;
GEInit::Finalize();
char buffer[1024];
getcwd(buffer, 1024);
string path=buffer;
string file_path=path + "/Makefile";
ModelTool::GetModelInfoFromOm(file_path.c_str(), model_def, modeldef_size);
ModelTool::GetModelInfoFromOm("123.om", model_def, modeldef_size);
ModelTool::GetModelInfoFromPbtxt(file_path.c_str(), model_def);
ModelTool::GetModelInfoFromPbtxt("123.pbtxt", model_def);
}
} // namespace ge } // namespace ge

Loading…
Cancel
Save