@@ -96,7 +96,8 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
op_info.root = root_id; | op_info.root = root_id; | ||||
auto callback = [this, op_desc](HcclResult status) { | auto callback = [this, op_desc](HcclResult status) { | ||||
if (status != HCCL_SUCCESS) { | if (status != HCCL_SUCCESS) { | ||||
GELOGE(HCCL_E_INTERNAL, "node %s call HcomExecEnqueueOperation failed, ret: 0x%X", op_desc->GetName().c_str(), status); | |||||
GELOGE(HCCL_E_INTERNAL, "node %s call HcomExecEnqueueOperation failed, ret: 0x%X", | |||||
op_desc->GetName().c_str(), status); | |||||
} | } | ||||
std::lock_guard<std::mutex> lock(this->hccl_mutex_); | std::lock_guard<std::mutex> lock(this->hccl_mutex_); | ||||
this->cond_.notify_all(); | this->cond_.notify_all(); | ||||
@@ -51,6 +51,7 @@ const char *const kDigitError = "is not digit"; | |||||
const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; | const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; | ||||
const char *const kSelectImplmodeError = "only support high_performance, high_precision"; | const char *const kSelectImplmodeError = "only support high_performance, high_precision"; | ||||
const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; | const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; | ||||
const char *const kKeepDtypeError = "file not found"; | |||||
vector<string> SplitInputShape(const std::string &input_shape) { | vector<string> SplitInputShape(const std::string &input_shape) { | ||||
vector<string> shape_pair_vec; | vector<string> shape_pair_vec; | ||||
@@ -439,6 +440,17 @@ Status CheckCompressWeightParamValid(const std::string enable_compress_weight, c | |||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
} | } | ||||
Status CheckKeepTypeParamValid(const std::string &keep_dtype) { | |||||
if ((!keep_dtype.empty()) && (!CheckInputPathValid(keep_dtype, "--keep_dtype"))) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E10001", {"parameter", "value", "reason"}, {"--keep_dtype", keep_dtype, kKeepDtypeError}); | |||||
GELOGE(ge::PARAM_INVALID, "keep dtype config file not found, file_name:%s", keep_dtype.c_str()); | |||||
return ge::PARAM_INVALID; | |||||
} | |||||
return ge::SUCCESS; | |||||
} | |||||
int CheckLogParamValidAndSetLogLevel(const std::string log) { | int CheckLogParamValidAndSetLogLevel(const std::string log) { | ||||
int ret = -1; | int ret = -1; | ||||
if (log == "default") { | if (log == "default") { | ||||
@@ -76,6 +76,7 @@ Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory) | |||||
Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream); | Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream); | ||||
Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode); | Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode); | ||||
Status CheckInputFormat(const string &input_format); | Status CheckInputFormat(const string &input_format); | ||||
Status CheckKeepTypeParamValid(const std::string &keep_dtype); | |||||
void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); | ||||
void EraseEndSemicolon(std::string ¶m); | void EraseEndSemicolon(std::string ¶m); | ||||
} | } | ||||
@@ -10,6 +10,7 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
set(SRC_LIST | set(SRC_LIST | ||||
"main.cc" | "main.cc" | ||||
"single_op_parser.cc" | "single_op_parser.cc" | ||||
"keep_dtype_option.cc" | |||||
"../session/omg.cc" | "../session/omg.cc" | ||||
"../ir_build/atc_ir_common.cc" | "../ir_build/atc_ir_common.cc" | ||||
) | ) | ||||
@@ -0,0 +1,107 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "keep_dtype_option.h" | |||||
#include <fstream> | |||||
#include <iostream> | |||||
#include <sstream> | |||||
#include <vector> | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "framework/common/util.h" | |||||
#include "common/util/error_manager/error_manager.h" | |||||
namespace ge { | |||||
namespace { | |||||
const size_t kMaxOpsNum = 10; | |||||
} // namespace | |||||
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | |||||
std::vector<std::string> original_op_names; | |||||
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | |||||
return false; | |||||
} | |||||
for (auto &origin_name : original_op_names) { | |||||
if (origin_name == op_name) { | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
void KeepDtypeReportError(const std::vector<std::string> &invalid_list) { | |||||
std::stringstream error_ops; | |||||
for (size_t i = 0; i < invalid_list.size(); i++) { | |||||
if (i == kMaxOpsNum) { | |||||
error_ops << "..."; | |||||
break; | |||||
} | |||||
error_ops << invalid_list[i] << " "; | |||||
} | |||||
std::string err_msg = "config file contains "; | |||||
err_msg = err_msg.append(std::to_string(invalid_list.size())) | |||||
.append(" operators not in the graph, op names:") | |||||
.append(error_ops.str()); | |||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E10042", {"parameter", "reason"}, {"keep_dtype", err_msg.c_str()}); | |||||
GELOGE(FAILED, "%s", err_msg.c_str()); | |||||
} | |||||
Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
if (keep_dtype.empty()) { | |||||
return SUCCESS; | |||||
} | |||||
std::string real_path = RealPath(keep_dtype.c_str()); | |||||
if (real_path.empty()) { | |||||
GELOGE(PARAM_INVALID, "Can not get real path for %s.", keep_dtype.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
std::ifstream ifs(real_path); | |||||
if (!ifs.is_open()) { | |||||
GELOGE(FAILED, "Open file %s failed", keep_dtype.c_str()); | |||||
return FAILED; | |||||
} | |||||
std::string op_name; | |||||
std::vector<std::string> invalid_list; | |||||
while (std::getline(ifs, op_name)) { | |||||
if (op_name.empty()) { | |||||
continue; | |||||
} | |||||
op_name = StringUtils::Trim(op_name); | |||||
bool is_find = false; | |||||
for (auto &node_ptr : graph->GetDirectNode()) { | |||||
auto op_desc = node_ptr->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) { | |||||
is_find = true; | |||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); | |||||
} | |||||
} | |||||
if (!is_find) { | |||||
invalid_list.push_back(op_name); | |||||
} | |||||
} | |||||
if (!invalid_list.empty()) { | |||||
KeepDtypeReportError(invalid_list); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,26 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef KEEP_DTYPE_OPTION_H_ | |||||
#define KEEP_DTYPE_OPTION_H_ | |||||
#include <string> | |||||
#include "graph/compute_graph.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | |||||
Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype); | |||||
} // namespace | |||||
#endif // KEEP_DTYPE_OPTION_H_ |
@@ -43,6 +43,7 @@ | |||||
#include "parser/common/register_tbe.h" | #include "parser/common/register_tbe.h" | ||||
#include "register/op_registry.h" | #include "register/op_registry.h" | ||||
#include "single_op_parser.h" | #include "single_op_parser.h" | ||||
#include "keep_dtype_option.h" | |||||
using domi::BuildMode; | using domi::BuildMode; | ||||
using domi::OpRegistrationData; | using domi::OpRegistrationData; | ||||
@@ -109,6 +110,9 @@ DEFINE_string(precision_mode, "force_fp16", | |||||
"Optional; precision mode." | "Optional; precision mode." | ||||
"Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); | "Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); | ||||
DEFINE_string(keep_dtype, "", | |||||
"Optional; config file to specify the precision used by the operator during compilation."); | |||||
DEFINE_string(input_format, "", | DEFINE_string(input_format, "", | ||||
"Optional; input_format, format of input data, NCHW;NHWC." | "Optional; input_format, format of input data, NCHW;NHWC." | ||||
"Format:\"NHWC\""); | "Format:\"NHWC\""); | ||||
@@ -421,6 +425,9 @@ class GFlagUtils { | |||||
FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS, | FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS, | ||||
ret = ge::FAILED, "check compress weight failed!"); | ret = ge::FAILED, "check compress weight failed!"); | ||||
GE_CHK_BOOL_EXEC(ge::CheckKeepTypeParamValid(FLAGS_keep_dtype) == ge::SUCCESS, | |||||
ret = ge::FAILED, "check keep dtype failed!"); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
!ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED, | !ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED, | ||||
"check_report file %s not found!!", FLAGS_check_report.c_str()); | "check_report file %s not found!!", FLAGS_check_report.c_str()); | ||||
@@ -979,6 +986,11 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output | |||||
} | } | ||||
} | } | ||||
Status ret = ge::DealKeepDtypeOption(ge::GraphUtils::GetComputeGraph(graph), FLAGS_keep_dtype); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | |||||
geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); | geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); | ||||
if (geRet != ge::SUCCESS) { | if (geRet != ge::SUCCESS) { | ||||
DOMI_LOGE("GE GenerateOfflineModel execute failed"); | DOMI_LOGE("GE GenerateOfflineModel execute failed"); | ||||
@@ -10,6 +10,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
main.cc \ | main.cc \ | ||||
keep_dtype_option.cc \ | |||||
single_op_parser.cc \ | single_op_parser.cc \ | ||||
../session/omg.cc \ | ../session/omg.cc \ | ||||
../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||
@@ -63,6 +64,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
main.cc \ | main.cc \ | ||||
keep_dtype_option.cc \ | |||||
single_op_parser.cc \ | single_op_parser.cc \ | ||||
../session/omg.cc \ | ../session/omg.cc \ | ||||
../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||
@@ -116,6 +118,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
main.cc \ | main.cc \ | ||||
keep_dtype_option.cc \ | |||||
single_op_parser.cc \ | single_op_parser.cc \ | ||||
../session/omg.cc \ | ../session/omg.cc \ | ||||
../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||