From ef6fce94ad605383d9f28fbea34b37e33169ce76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E7=A3=8A?= Date: Tue, 13 Apr 2021 20:13:22 +0800 Subject: [PATCH] add atc param device_id --- ge/offline/main.cc | 14 +++++++++----- inc/external/ge/ge_api_types.h | 4 ++++ metadef | 2 +- parser | 2 +- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/ge/offline/main.cc b/ge/offline/main.cc index 54a1d8fb..6162fced 100755 --- a/ge/offline/main.cc +++ b/ge/offline/main.cc @@ -220,6 +220,8 @@ DEFINE_string(performance_mode, "", "Optional; express high compile performance "normal: no need to compile, used saved .o files directly;" "high: need to recompile, high execute performance mode."); +DEFINE_string(device_id, "0", "Optional; user device id"); + class GFlagUtils { public: /** @@ -579,7 +581,7 @@ class GFlagUtils { if (fileName.size() > static_cast(PATH_MAX)) { ErrorManager::GetInstance().ATCReportErrMessage( "E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)}); - GELOGE(ge::FAILED, + GELOGE(ge::FAILED, "[Check][Path]Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX); return false; } @@ -638,7 +640,7 @@ static bool CheckInputFormat() { // only support NCHW ND ErrorManager::GetInstance().ATCReportErrMessage( "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport}); - GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", + GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kCaffeFormatSupport); return false; } else if ((FLAGS_framework == static_cast(domi::TENSORFLOW))) { // tf @@ -648,7 +650,7 @@ static bool CheckInputFormat() { // only support NCHW NHWC ND NCDHW NDHWC ErrorManager::GetInstance().ATCReportErrMessage( "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport}); - GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", + GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kTFFormatSupport); return false; } else if (FLAGS_framework == static_cast(domi::ONNX)) { @@ -658,7 +660,7 @@ static bool CheckInputFormat() { // only support NCHW ND ErrorManager::GetInstance().ATCReportErrMessage( "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport}); - GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", + GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kONNXFormatSupport); return false; } @@ -903,7 +905,7 @@ static Status ConvertModelToJson(int fwk_type, const string &model_file, const s ErrorManager::GetInstance().ATCReportErrMessage( "E10001", {"parameter", "value", "reason"}, {"--framework", std::to_string(fwk_type), kModelToJsonSupport}); - GELOGE(ge::FAILED, "[Convert][ModelToJson]Invalid value for --framework[%d], %s.", + GELOGE(ge::FAILED, "[Convert][ModelToJson]Invalid value for --framework[%d], %s.", fwk_type, kModelToJsonSupport); ret = ge::FAILED; } @@ -1084,6 +1086,7 @@ static void SetEnvForSingleOp(std::map &options) { options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path); options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path); options.emplace(ge::PERFORMANCE_MODE, FLAGS_performance_mode); + options.emplace(ge::TUNE_DEVICE_IDS, FLAGS_device_id); } domi::Status GenerateSingleOp(const std::string& json_file_path) { @@ -1176,6 +1179,7 @@ domi::Status GenerateOmModel() { options.insert(std::pair(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); options.insert(std::pair(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); options.insert(std::pair(string(ge::PRECISION_MODE), FLAGS_precision_mode)); + options.insert(std::pair(string(ge::TUNE_DEVICE_IDS), FLAGS_device_id)); options.insert(std::pair(string(ge::RUN_FLAG), to_string(0))); options.insert(std::pair(string(ge::TRAIN_FLAG), to_string(0))); diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h index 5ae5f036..12ee5e94 100644 --- a/inc/external/ge/ge_api_types.h +++ b/inc/external/ge/ge_api_types.h @@ -166,6 +166,8 @@ const std::string COMPRESS_FLAG = "ge.compressFlag"; const std::string PRECISION_MODE = "ge.exec.precision_mode"; +const std::string TUNE_DEVICE_IDS = "ge.exec.tuneDeviceIds"; + // Configure single op flag for FE // its value should be "0" or "1", default value is "0" const std::string SINGLE_OP_FLAG = "ge.exec.single_op"; @@ -407,6 +409,7 @@ const std::set ir_builder_suppported_options = {INPUT_FORMAT, DYNAMIC_DIMS, INSERT_OP_FILE, PRECISION_MODE, + TUNE_DEVICE_IDS, EXEC_DISABLE_REUSED_MEMORY, AUTO_TUNE_MODE, OUTPUT_TYPE, @@ -434,6 +437,7 @@ const std::set global_options = {CORE_TYPE, ENABLE_COMPRESS_WEIGHT, COMPRESS_WEIGHT_CONF, PRECISION_MODE, + TUNE_DEVICE_IDS, EXEC_DISABLE_REUSED_MEMORY, AUTO_TUNE_MODE, ENABLE_SINGLE_STREAM, diff --git a/metadef b/metadef index 4cf2633a..c1aea328 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 4cf2633a8f2290dee165ab11f8d6b8a07cba1412 +Subproject commit c1aea328cc04340188e796e639cd55a907488365 diff --git a/parser b/parser index a41249dc..06e784fa 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit a41249dc9b50e4c4988eb62a662b7df29ac24ee7 +Subproject commit 06e784fad01d7e9089cc7e8e0d00fce5b1901886