Browse Source

!698 Add keep_dtype attribute on operators to keep precision unchanged

From: @li-lei0106
Reviewed-by: 
Signed-off-by:
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
9aec8b4f0f
8 changed files with 164 additions and 1 deletions
  1. +2
    -1
      ge/hybrid/node_executor/hccl/hccl_node_executor.cc
  2. +12
    -0
      ge/ir_build/atc_ir_common.cc
  3. +1
    -0
      ge/ir_build/atc_ir_common.h
  4. +1
    -0
      ge/offline/CMakeLists.txt
  5. +107
    -0
      ge/offline/keep_dtype_option.cc
  6. +26
    -0
      ge/offline/keep_dtype_option.h
  7. +12
    -0
      ge/offline/main.cc
  8. +3
    -0
      ge/offline/module.mk

+ 2
- 1
ge/hybrid/node_executor/hccl/hccl_node_executor.cc View File

@@ -96,7 +96,8 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
op_info.root = root_id; op_info.root = root_id;
auto callback = [this, op_desc](HcclResult status) { auto callback = [this, op_desc](HcclResult status) {
if (status != HCCL_SUCCESS) { if (status != HCCL_SUCCESS) {
GELOGE(HCCL_E_INTERNAL, "node %s call HcomExecEnqueueOperation failed, ret: 0x%X", op_desc->GetName().c_str(), status);
GELOGE(HCCL_E_INTERNAL, "node %s call HcomExecEnqueueOperation failed, ret: 0x%X",
op_desc->GetName().c_str(), status);
} }
std::lock_guard<std::mutex> lock(this->hccl_mutex_); std::lock_guard<std::mutex> lock(this->hccl_mutex_);
this->cond_.notify_all(); this->cond_.notify_all();


+ 12
- 0
ge/ir_build/atc_ir_common.cc View File

@@ -51,6 +51,7 @@ const char *const kDigitError = "is not digit";
const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]";
const char *const kSelectImplmodeError = "only support high_performance, high_precision"; const char *const kSelectImplmodeError = "only support high_performance, high_precision";
const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \"";
const char *const kKeepDtypeError = "file not found";


vector<string> SplitInputShape(const std::string &input_shape) { vector<string> SplitInputShape(const std::string &input_shape) {
vector<string> shape_pair_vec; vector<string> shape_pair_vec;
@@ -439,6 +440,17 @@ Status CheckCompressWeightParamValid(const std::string enable_compress_weight, c
return ge::SUCCESS; return ge::SUCCESS;
} }


Status CheckKeepTypeParamValid(const std::string &keep_dtype) {
if ((!keep_dtype.empty()) && (!CheckInputPathValid(keep_dtype, "--keep_dtype"))) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--keep_dtype", keep_dtype, kKeepDtypeError});
GELOGE(ge::PARAM_INVALID, "keep dtype config file not found, file_name:%s", keep_dtype.c_str());
return ge::PARAM_INVALID;
}

return ge::SUCCESS;
}

int CheckLogParamValidAndSetLogLevel(const std::string log) { int CheckLogParamValidAndSetLogLevel(const std::string log) {
int ret = -1; int ret = -1;
if (log == "default") { if (log == "default") {


+ 1
- 0
ge/ir_build/atc_ir_common.h View File

@@ -76,6 +76,7 @@ Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory)
Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream); Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream);
Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode); Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode);
Status CheckInputFormat(const string &input_format); Status CheckInputFormat(const string &input_format);
Status CheckKeepTypeParamValid(const std::string &keep_dtype);
void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips);
void EraseEndSemicolon(std::string &param); void EraseEndSemicolon(std::string &param);
} }


+ 1
- 0
ge/offline/CMakeLists.txt View File

@@ -10,6 +10,7 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
set(SRC_LIST set(SRC_LIST
"main.cc" "main.cc"
"single_op_parser.cc" "single_op_parser.cc"
"keep_dtype_option.cc"
"../session/omg.cc" "../session/omg.cc"
"../ir_build/atc_ir_common.cc" "../ir_build/atc_ir_common.cc"
) )


+ 107
- 0
ge/offline/keep_dtype_option.cc View File

@@ -0,0 +1,107 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "keep_dtype_option.h"
#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>
#include "graph/debug/ge_attr_define.h"
#include "framework/common/util.h"
#include "common/util/error_manager/error_manager.h"
namespace ge {
namespace {
const size_t kMaxOpsNum = 10;
} // namespace
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) {
std::vector<std::string> original_op_names;
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) {
return false;
}
for (auto &origin_name : original_op_names) {
if (origin_name == op_name) {
return true;
}
}
return false;
}
void KeepDtypeReportError(const std::vector<std::string> &invalid_list) {
std::stringstream error_ops;
for (size_t i = 0; i < invalid_list.size(); i++) {
if (i == kMaxOpsNum) {
error_ops << "...";
break;
}
error_ops << invalid_list[i] << " ";
}
std::string err_msg = "config file contains ";
err_msg = err_msg.append(std::to_string(invalid_list.size()))
.append(" operators not in the graph, op names:")
.append(error_ops.str());
ErrorManager::GetInstance().ATCReportErrMessage(
"E10042", {"parameter", "reason"}, {"keep_dtype", err_msg.c_str()});
GELOGE(FAILED, "%s", err_msg.c_str());
}
Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype) {
GE_CHECK_NOTNULL(graph);
if (keep_dtype.empty()) {
return SUCCESS;
}
std::string real_path = RealPath(keep_dtype.c_str());
if (real_path.empty()) {
GELOGE(PARAM_INVALID, "Can not get real path for %s.", keep_dtype.c_str());
return PARAM_INVALID;
}
std::ifstream ifs(real_path);
if (!ifs.is_open()) {
GELOGE(FAILED, "Open file %s failed", keep_dtype.c_str());
return FAILED;
}
std::string op_name;
std::vector<std::string> invalid_list;
while (std::getline(ifs, op_name)) {
if (op_name.empty()) {
continue;
}
op_name = StringUtils::Trim(op_name);
bool is_find = false;
for (auto &node_ptr : graph->GetDirectNode()) {
auto op_desc = node_ptr->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) {
is_find = true;
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1);
}
}
if (!is_find) {
invalid_list.push_back(op_name);
}
}
if (!invalid_list.empty()) {
KeepDtypeReportError(invalid_list);
return PARAM_INVALID;
}
return SUCCESS;
}
} // namespace ge

+ 26
- 0
ge/offline/keep_dtype_option.h View File

@@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef KEEP_DTYPE_OPTION_H_
#define KEEP_DTYPE_OPTION_H_
#include <string>
#include "graph/compute_graph.h"
#include "framework/common/ge_inner_error_codes.h"
namespace ge {
Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype);
} // namespace
#endif // KEEP_DTYPE_OPTION_H_

+ 12
- 0
ge/offline/main.cc View File

@@ -43,6 +43,7 @@
#include "parser/common/register_tbe.h" #include "parser/common/register_tbe.h"
#include "register/op_registry.h" #include "register/op_registry.h"
#include "single_op_parser.h" #include "single_op_parser.h"
#include "keep_dtype_option.h"


using domi::BuildMode; using domi::BuildMode;
using domi::OpRegistrationData; using domi::OpRegistrationData;
@@ -109,6 +110,9 @@ DEFINE_string(precision_mode, "force_fp16",
"Optional; precision mode." "Optional; precision mode."
"Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); "Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype.");


DEFINE_string(keep_dtype, "",
"Optional; config file to specify the precision used by the operator during compilation.");

DEFINE_string(input_format, "", DEFINE_string(input_format, "",
"Optional; input_format, format of input data, NCHW;NHWC." "Optional; input_format, format of input data, NCHW;NHWC."
"Format:\"NHWC\""); "Format:\"NHWC\"");
@@ -421,6 +425,9 @@ class GFlagUtils {
FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS, FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS,
ret = ge::FAILED, "check compress weight failed!"); ret = ge::FAILED, "check compress weight failed!");


GE_CHK_BOOL_EXEC(ge::CheckKeepTypeParamValid(FLAGS_keep_dtype) == ge::SUCCESS,
ret = ge::FAILED, "check keep dtype failed!");

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
!ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED, !ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED,
"check_report file %s not found!!", FLAGS_check_report.c_str()); "check_report file %s not found!!", FLAGS_check_report.c_str());
@@ -979,6 +986,11 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output
} }
} }


Status ret = ge::DealKeepDtypeOption(ge::GraphUtils::GetComputeGraph(graph), FLAGS_keep_dtype);
if (ret != SUCCESS) {
return ret;
}

geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); geRet = ge_generator.GenerateOfflineModel(graph, output, inputs);
if (geRet != ge::SUCCESS) { if (geRet != ge::SUCCESS) {
DOMI_LOGE("GE GenerateOfflineModel execute failed"); DOMI_LOGE("GE GenerateOfflineModel execute failed");


+ 3
- 0
ge/offline/module.mk View File

@@ -10,6 +10,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg


LOCAL_SRC_FILES := \ LOCAL_SRC_FILES := \
main.cc \ main.cc \
keep_dtype_option.cc \
single_op_parser.cc \ single_op_parser.cc \
../session/omg.cc \ ../session/omg.cc \
../ir_build/atc_ir_common.cc \ ../ir_build/atc_ir_common.cc \
@@ -63,6 +64,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg


LOCAL_SRC_FILES := \ LOCAL_SRC_FILES := \
main.cc \ main.cc \
keep_dtype_option.cc \
single_op_parser.cc \ single_op_parser.cc \
../session/omg.cc \ ../session/omg.cc \
../ir_build/atc_ir_common.cc \ ../ir_build/atc_ir_common.cc \
@@ -116,6 +118,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg


LOCAL_SRC_FILES := \ LOCAL_SRC_FILES := \
main.cc \ main.cc \
keep_dtype_option.cc \
single_op_parser.cc \ single_op_parser.cc \
../session/omg.cc \ ../session/omg.cc \
../ir_build/atc_ir_common.cc \ ../ir_build/atc_ir_common.cc \


Loading…
Cancel
Save