|
@@ -220,6 +220,8 @@ DEFINE_string(performance_mode, "", "Optional; express high compile performance |
|
|
"normal: no need to compile, used saved .o files directly;" |
|
|
"normal: no need to compile, used saved .o files directly;" |
|
|
"high: need to recompile, high execute performance mode."); |
|
|
"high: need to recompile, high execute performance mode."); |
|
|
|
|
|
|
|
|
|
|
|
DEFINE_string(device_id, "0", "Optional; user device id"); |
|
|
|
|
|
|
|
|
class GFlagUtils { |
|
|
class GFlagUtils { |
|
|
public: |
|
|
public: |
|
|
/** |
|
|
/** |
|
@@ -579,7 +581,7 @@ class GFlagUtils { |
|
|
if (fileName.size() > static_cast<int>(PATH_MAX)) { |
|
|
if (fileName.size() > static_cast<int>(PATH_MAX)) { |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
"E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)}); |
|
|
"E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)}); |
|
|
GELOGE(ge::FAILED, |
|
|
|
|
|
|
|
|
GELOGE(ge::FAILED, |
|
|
"[Check][Path]Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX); |
|
|
"[Check][Path]Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX); |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
@@ -638,7 +640,7 @@ static bool CheckInputFormat() { |
|
|
// only support NCHW ND |
|
|
// only support NCHW ND |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport}); |
|
|
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport}); |
|
|
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", |
|
|
|
|
|
|
|
|
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", |
|
|
FLAGS_input_format.c_str(), kCaffeFormatSupport); |
|
|
FLAGS_input_format.c_str(), kCaffeFormatSupport); |
|
|
return false; |
|
|
return false; |
|
|
} else if ((FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW))) { // tf |
|
|
} else if ((FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW))) { // tf |
|
@@ -648,7 +650,7 @@ static bool CheckInputFormat() { |
|
|
// only support NCHW NHWC ND NCDHW NDHWC |
|
|
// only support NCHW NHWC ND NCDHW NDHWC |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport}); |
|
|
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport}); |
|
|
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", |
|
|
|
|
|
|
|
|
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", |
|
|
FLAGS_input_format.c_str(), kTFFormatSupport); |
|
|
FLAGS_input_format.c_str(), kTFFormatSupport); |
|
|
return false; |
|
|
return false; |
|
|
} else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) { |
|
|
} else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) { |
|
@@ -658,7 +660,7 @@ static bool CheckInputFormat() { |
|
|
// only support NCHW ND |
|
|
// only support NCHW ND |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport}); |
|
|
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport}); |
|
|
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", |
|
|
|
|
|
|
|
|
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.", |
|
|
FLAGS_input_format.c_str(), kONNXFormatSupport); |
|
|
FLAGS_input_format.c_str(), kONNXFormatSupport); |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
@@ -903,7 +905,7 @@ static Status ConvertModelToJson(int fwk_type, const string &model_file, const s |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
"E10001", {"parameter", "value", "reason"}, |
|
|
"E10001", {"parameter", "value", "reason"}, |
|
|
{"--framework", std::to_string(fwk_type), kModelToJsonSupport}); |
|
|
{"--framework", std::to_string(fwk_type), kModelToJsonSupport}); |
|
|
GELOGE(ge::FAILED, "[Convert][ModelToJson]Invalid value for --framework[%d], %s.", |
|
|
|
|
|
|
|
|
GELOGE(ge::FAILED, "[Convert][ModelToJson]Invalid value for --framework[%d], %s.", |
|
|
fwk_type, kModelToJsonSupport); |
|
|
fwk_type, kModelToJsonSupport); |
|
|
ret = ge::FAILED; |
|
|
ret = ge::FAILED; |
|
|
} |
|
|
} |
|
@@ -1084,6 +1086,7 @@ static void SetEnvForSingleOp(std::map<string, string> &options) { |
|
|
options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path); |
|
|
options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path); |
|
|
options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path); |
|
|
options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path); |
|
|
options.emplace(ge::PERFORMANCE_MODE, FLAGS_performance_mode); |
|
|
options.emplace(ge::PERFORMANCE_MODE, FLAGS_performance_mode); |
|
|
|
|
|
options.emplace(ge::TUNE_DEVICE_IDS, FLAGS_device_id); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
domi::Status GenerateSingleOp(const std::string& json_file_path) { |
|
|
domi::Status GenerateSingleOp(const std::string& json_file_path) { |
|
@@ -1176,6 +1179,7 @@ domi::Status GenerateOmModel() { |
|
|
options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); |
|
|
options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); |
|
|
options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); |
|
|
options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); |
|
|
options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode)); |
|
|
options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode)); |
|
|
|
|
|
options.insert(std::pair<string, string>(string(ge::TUNE_DEVICE_IDS), FLAGS_device_id)); |
|
|
|
|
|
|
|
|
options.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0))); |
|
|
options.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0))); |
|
|
options.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0))); |
|
|
options.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0))); |
|
|