| @@ -51,6 +51,7 @@ const char *const kDigitError = "is not digit"; | |||
| const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; | |||
| const char *const kSelectImplmodeError = "only support high_performance, high_precision"; | |||
| const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; | |||
| const char *const kKeepDtypeError = "file not found"; | |||
| vector<string> SplitInputShape(const std::string &input_shape) { | |||
| vector<string> shape_pair_vec; | |||
| @@ -438,6 +439,17 @@ Status CheckCompressWeightParamValid(const std::string enable_compress_weight, c | |||
| return ge::SUCCESS; | |||
| } | |||
| Status CheckKeepTypeParamValid(const std::string &keep_dtype) { | |||
| if ((!keep_dtype.empty()) && (!CheckInputPathValid(keep_dtype, "--keep_dtype"))) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E10001", {"parameter", "value", "reason"}, {"--keep_dtype", keep_dtype, kKeepDtypeError}); | |||
| GELOGE(ge::PARAM_INVALID, "keep dtype config file not found, file_name:%s", keep_dtype.c_str()); | |||
| return ge::PARAM_INVALID; | |||
| } | |||
| return ge::SUCCESS; | |||
| } | |||
| int CheckLogParamValidAndSetLogLevel(const std::string log) { | |||
| int ret = -1; | |||
| if (log == "default") { | |||
| @@ -76,6 +76,7 @@ Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory) | |||
| Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream); | |||
| Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode); | |||
| Status CheckInputFormat(const std::string &input_format); | |||
| Status CheckKeepTypeParamValid(const std::string &keep_dtype); | |||
| void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | |||
| void EraseEndSemicolon(std::string ¶m); | |||
| } | |||
| @@ -10,6 +10,7 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
| set(SRC_LIST | |||
| "main.cc" | |||
| "single_op_parser.cc" | |||
| "keep_dtype_option.cc" | |||
| "../session/omg.cc" | |||
| "../ir_build/atc_ir_common.cc" | |||
| ) | |||
| @@ -0,0 +1,116 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "keep_dtype_option.h" | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include <sstream> | |||
| #include <vector> | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "framework/common/util.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| namespace ge { | |||
| namespace { | |||
| const size_t kMaxOpsNum = 10; | |||
| } // namespace | |||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | |||
| std::vector<std::string> original_op_names; | |||
| if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | |||
| return false; | |||
| } | |||
| for (auto &origin_name : original_op_names) { | |||
| if (origin_name == op_name) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| void KeepDtypeReportError(const std::vector<std::string> &invalid_list) { | |||
| std::stringstream err_msg; | |||
| size_t list_size = invalid_list.size(); | |||
| err_msg << "config file contains " << list_size; | |||
| if (list_size == 1) { | |||
| err_msg << " operator not in the graph, op name:"; | |||
| } else { | |||
| err_msg << " operators not in the graph, op names:"; | |||
| } | |||
| for (size_t i = 0; i < list_size; i++) { | |||
| if (i == kMaxOpsNum) { | |||
| err_msg << ".."; | |||
| break; | |||
| } | |||
| err_msg << invalid_list[i]; | |||
| if (i != list_size - 1) { | |||
| err_msg << " "; | |||
| } | |||
| } | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E10042", {"parameter", "reason"}, {"keep_dtype", err_msg.str().c_str()}); | |||
| GELOGE(FAILED, "%s", err_msg.str().c_str()); | |||
| } | |||
| Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| if (keep_dtype.empty()) { | |||
| return SUCCESS; | |||
| } | |||
| std::string real_path = RealPath(keep_dtype.c_str()); | |||
| if (real_path.empty()) { | |||
| GELOGE(PARAM_INVALID, "Can not get real path for %s.", keep_dtype.c_str()); | |||
| return PARAM_INVALID; | |||
| } | |||
| std::ifstream ifs(real_path); | |||
| if (!ifs.is_open()) { | |||
| GELOGE(FAILED, "Open file %s failed", keep_dtype.c_str()); | |||
| return FAILED; | |||
| } | |||
| std::string op_name; | |||
| std::vector<std::string> invalid_list; | |||
| while (std::getline(ifs, op_name)) { | |||
| if (op_name.empty()) { | |||
| continue; | |||
| } | |||
| op_name = StringUtils::Trim(op_name); | |||
| bool is_find = false; | |||
| for (auto &node_ptr : graph->GetDirectNode()) { | |||
| auto op_desc = node_ptr->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) { | |||
| is_find = true; | |||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); | |||
| } | |||
| } | |||
| if (!is_find) { | |||
| invalid_list.push_back(op_name); | |||
| } | |||
| } | |||
| ifs.close(); | |||
| if (!invalid_list.empty()) { | |||
| KeepDtypeReportError(invalid_list); | |||
| return PARAM_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef KEEP_DTYPE_OPTION_H_ | |||
| #define KEEP_DTYPE_OPTION_H_ | |||
| #include <string> | |||
| #include "graph/compute_graph.h" | |||
| #include "framework/common/ge_inner_error_codes.h" | |||
| namespace ge { | |||
| Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype); | |||
| } // namespace | |||
| #endif // KEEP_DTYPE_OPTION_H_ | |||
| @@ -43,6 +43,7 @@ | |||
| #include "parser/common/register_tbe.h" | |||
| #include "register/op_registry.h" | |||
| #include "single_op_parser.h" | |||
| #include "keep_dtype_option.h" | |||
| using domi::BuildMode; | |||
| using domi::OpRegistrationData; | |||
| @@ -109,6 +110,9 @@ DEFINE_string(precision_mode, "force_fp16", | |||
| "Optional; precision mode." | |||
| "Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); | |||
| DEFINE_string(keep_dtype, "", | |||
| "Optional; config file to specify the precision used by the operator during compilation."); | |||
| DEFINE_string(input_format, "", | |||
| "Optional; input_format, format of input data, NCHW;NHWC." | |||
| "Format:\"NHWC\""); | |||
| @@ -285,6 +289,8 @@ class GFlagUtils { | |||
| "\n[Operator Tuning]\n" | |||
| " --precision_mode precision mode, support force_fp16(default), allow_mix_precision, " | |||
| "allow_fp32_to_fp16, must_keep_origin_dtype.\n" | |||
| " --keep_dtype Retains the precision of certain operators in inference " | |||
| "scenarios by using a configuration file.\n" | |||
| " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" | |||
| " --op_select_implmode Set op select implmode. Support high_precision, high_performance. " | |||
| "default: high_performance\n" | |||
| @@ -421,6 +427,9 @@ class GFlagUtils { | |||
| FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS, | |||
| ret = ge::FAILED, "check compress weight failed!"); | |||
| GE_CHK_BOOL_EXEC(ge::CheckKeepTypeParamValid(FLAGS_keep_dtype) == ge::SUCCESS, | |||
| ret = ge::FAILED, "check keep dtype failed!"); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| !ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED, | |||
| "check_report file %s not found!!", FLAGS_check_report.c_str()); | |||
| @@ -979,6 +988,13 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output | |||
| } | |||
| } | |||
| Status ret = ge::DealKeepDtypeOption(ge::GraphUtils::GetComputeGraph(graph), FLAGS_keep_dtype); | |||
| if (ret != SUCCESS) { | |||
| (void)ge_generator.Finalize(); | |||
| (void)ge::GELib::GetInstance()->Finalize(); | |||
| return ret; | |||
| } | |||
| geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); | |||
| if (geRet != ge::SUCCESS) { | |||
| DOMI_LOGE("GE GenerateOfflineModel execute failed"); | |||
| @@ -11,6 +11,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||
| LOCAL_SRC_FILES := \ | |||
| main.cc \ | |||
| single_op_parser.cc \ | |||
| keep_dtype_option.cc \ | |||
| ../session/omg.cc \ | |||
| ../ir_build/atc_ir_common.cc \ | |||
| @@ -64,6 +65,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||
| LOCAL_SRC_FILES := \ | |||
| main.cc \ | |||
| single_op_parser.cc \ | |||
| keep_dtype_option.cc \ | |||
| ../session/omg.cc \ | |||
| ../ir_build/atc_ir_common.cc \ | |||
| @@ -117,6 +119,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||
| LOCAL_SRC_FILES := \ | |||
| main.cc \ | |||
| single_op_parser.cc \ | |||
| keep_dtype_option.cc \ | |||
| ../session/omg.cc \ | |||
| ../ir_build/atc_ir_common.cc \ | |||
| @@ -1 +1 @@ | |||
| Subproject commit af156f825aa53a24bd30ae4065e3ea356cf555ef | |||
| Subproject commit 98a7ac86170097104a94d72b64bd1a8644c5b3c5 | |||