Browse Source

add atc param device_id

tags/v1.3.0
李磊 3 years ago
parent
commit
ef6fce94ad
4 changed files with 15 additions and 7 deletions
  1. +9
    -5
      ge/offline/main.cc
  2. +4
    -0
      inc/external/ge/ge_api_types.h
  3. +1
    -1
      metadef
  4. +1
    -1
      parser

+ 9
- 5
ge/offline/main.cc View File

@@ -220,6 +220,8 @@ DEFINE_string(performance_mode, "", "Optional; express high compile performance
"normal: no need to compile, used saved .o files directly;" "normal: no need to compile, used saved .o files directly;"
"high: need to recompile, high execute performance mode."); "high: need to recompile, high execute performance mode.");


DEFINE_string(device_id, "0", "Optional; user device id");

class GFlagUtils { class GFlagUtils {
public: public:
/** /**
@@ -579,7 +581,7 @@ class GFlagUtils {
if (fileName.size() > static_cast<int>(PATH_MAX)) { if (fileName.size() > static_cast<int>(PATH_MAX)) {
ErrorManager::GetInstance().ATCReportErrMessage( ErrorManager::GetInstance().ATCReportErrMessage(
"E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)}); "E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)});
GELOGE(ge::FAILED,
GELOGE(ge::FAILED,
"[Check][Path]Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX); "[Check][Path]Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX);
return false; return false;
} }
@@ -638,7 +640,7 @@ static bool CheckInputFormat() {
// only support NCHW ND // only support NCHW ND
ErrorManager::GetInstance().ATCReportErrMessage( ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport}); "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport});
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
FLAGS_input_format.c_str(), kCaffeFormatSupport); FLAGS_input_format.c_str(), kCaffeFormatSupport);
return false; return false;
} else if ((FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW))) { // tf } else if ((FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW))) { // tf
@@ -648,7 +650,7 @@ static bool CheckInputFormat() {
// only support NCHW NHWC ND NCDHW NDHWC // only support NCHW NHWC ND NCDHW NDHWC
ErrorManager::GetInstance().ATCReportErrMessage( ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport}); "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport});
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
FLAGS_input_format.c_str(), kTFFormatSupport); FLAGS_input_format.c_str(), kTFFormatSupport);
return false; return false;
} else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) { } else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) {
@@ -658,7 +660,7 @@ static bool CheckInputFormat() {
// only support NCHW ND // only support NCHW ND
ErrorManager::GetInstance().ATCReportErrMessage( ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport}); "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport});
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
FLAGS_input_format.c_str(), kONNXFormatSupport); FLAGS_input_format.c_str(), kONNXFormatSupport);
return false; return false;
} }
@@ -903,7 +905,7 @@ static Status ConvertModelToJson(int fwk_type, const string &model_file, const s
ErrorManager::GetInstance().ATCReportErrMessage( ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, "E10001", {"parameter", "value", "reason"},
{"--framework", std::to_string(fwk_type), kModelToJsonSupport}); {"--framework", std::to_string(fwk_type), kModelToJsonSupport});
GELOGE(ge::FAILED, "[Convert][ModelToJson]Invalid value for --framework[%d], %s.",
GELOGE(ge::FAILED, "[Convert][ModelToJson]Invalid value for --framework[%d], %s.",
fwk_type, kModelToJsonSupport); fwk_type, kModelToJsonSupport);
ret = ge::FAILED; ret = ge::FAILED;
} }
@@ -1084,6 +1086,7 @@ static void SetEnvForSingleOp(std::map<string, string> &options) {
options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path); options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path);
options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path); options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path);
options.emplace(ge::PERFORMANCE_MODE, FLAGS_performance_mode); options.emplace(ge::PERFORMANCE_MODE, FLAGS_performance_mode);
options.emplace(ge::TUNE_DEVICE_IDS, FLAGS_device_id);
} }


domi::Status GenerateSingleOp(const std::string& json_file_path) { domi::Status GenerateSingleOp(const std::string& json_file_path) {
@@ -1176,6 +1179,7 @@ domi::Status GenerateOmModel() {
options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes));
options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf));
options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode)); options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode));
options.insert(std::pair<string, string>(string(ge::TUNE_DEVICE_IDS), FLAGS_device_id));


options.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0))); options.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0)));
options.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0))); options.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0)));


+ 4
- 0
inc/external/ge/ge_api_types.h View File

@@ -166,6 +166,8 @@ const std::string COMPRESS_FLAG = "ge.compressFlag";


const std::string PRECISION_MODE = "ge.exec.precision_mode"; const std::string PRECISION_MODE = "ge.exec.precision_mode";


const std::string TUNE_DEVICE_IDS = "ge.exec.tuneDeviceIds";

// Configure single op flag for FE // Configure single op flag for FE
// its value should be "0" or "1", default value is "0" // its value should be "0" or "1", default value is "0"
const std::string SINGLE_OP_FLAG = "ge.exec.single_op"; const std::string SINGLE_OP_FLAG = "ge.exec.single_op";
@@ -407,6 +409,7 @@ const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT,
DYNAMIC_DIMS, DYNAMIC_DIMS,
INSERT_OP_FILE, INSERT_OP_FILE,
PRECISION_MODE, PRECISION_MODE,
TUNE_DEVICE_IDS,
EXEC_DISABLE_REUSED_MEMORY, EXEC_DISABLE_REUSED_MEMORY,
AUTO_TUNE_MODE, AUTO_TUNE_MODE,
OUTPUT_TYPE, OUTPUT_TYPE,
@@ -434,6 +437,7 @@ const std::set<std::string> global_options = {CORE_TYPE,
ENABLE_COMPRESS_WEIGHT, ENABLE_COMPRESS_WEIGHT,
COMPRESS_WEIGHT_CONF, COMPRESS_WEIGHT_CONF,
PRECISION_MODE, PRECISION_MODE,
TUNE_DEVICE_IDS,
EXEC_DISABLE_REUSED_MEMORY, EXEC_DISABLE_REUSED_MEMORY,
AUTO_TUNE_MODE, AUTO_TUNE_MODE,
ENABLE_SINGLE_STREAM, ENABLE_SINGLE_STREAM,


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 4cf2633a8f2290dee165ab11f8d6b8a07cba1412
Subproject commit c1aea328cc04340188e796e639cd55a907488365

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit a41249dc9b50e4c4988eb62a662b7df29ac24ee7
Subproject commit 06e784fad01d7e9089cc7e8e0d00fce5b1901886

Loading…
Cancel
Save