From ab65fa1a409928f3ad116d129185fd1297c61214 Mon Sep 17 00:00:00 2001 From: liudingyan Date: Sat, 19 Jun 2021 10:04:32 +0800 Subject: [PATCH] keep dtype add optype --- ge/ir_build/attr_options/attr_options.h | 1 + ge/ir_build/attr_options/keep_dtype_option.cc | 20 +++++++++----- ge/ir_build/attr_options/utils.cc | 27 ++++++++++++++++--- metadef | 2 +- parser | 2 +- 5 files changed, 41 insertions(+), 11 deletions(-) diff --git a/ge/ir_build/attr_options/attr_options.h b/ge/ir_build/attr_options/attr_options.h index 7c0f4f4f..6cddff94 100644 --- a/ge/ir_build/attr_options/attr_options.h +++ b/ge/ir_build/attr_options/attr_options.h @@ -23,6 +23,7 @@ namespace ge { bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); +bool IsContainOpType(const std::string &cfg_line, std::string &op_type); graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); } // namespace diff --git a/ge/ir_build/attr_options/keep_dtype_option.cc b/ge/ir_build/attr_options/keep_dtype_option.cc index dfdd0df3..de2d1ef7 100644 --- a/ge/ir_build/attr_options/keep_dtype_option.cc +++ b/ge/ir_build/attr_options/keep_dtype_option.cc @@ -32,17 +32,23 @@ void KeepDtypeReportError(const std::vector &invalid_list, const st size_t list_size = invalid_list.size(); err_msg << "config file contains " << list_size; if (list_size == 1) { - err_msg << " operator not in the graph, op name:"; + err_msg << " operator not in the graph, "; } else { - err_msg << " operators not in the graph, op names:"; + err_msg << " operators not in the graph, "; } - + std::string cft_type; for (size_t i = 0; i < list_size; i++) { if (i == kMaxOpsNum) { err_msg << ".."; break; } - err_msg << invalid_list[i]; + bool istype = IsContainOpType(invalid_list[i], cft_type); + if (!istype) { + err_msg << "op name:"; + } else { + err_msg << "op type:"; + } + err_msg << cft_type; if (i != list_size - 1) { err_msg << " "; } @@ -72,7 +78,7 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { return GRAPH_FAILED; } - std::string op_name; + std::string op_name, op_type; std::vector invalid_list; while (std::getline(ifs, op_name)) { if (op_name.empty()) { @@ -80,11 +86,13 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { } op_name = StringUtils::Trim(op_name); bool is_find = false; + bool is_type = IsContainOpType(op_name, op_type); for (auto &node_ptr : graph->GetDirectNode()) { auto op_desc = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) { + if ((!is_type && op_desc->GetName() == op_name) || (is_type && op_desc->GetType() == op_type) || + IsOriginalOpFind(op_desc, op_name)) { is_find = true; (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); } diff --git a/ge/ir_build/attr_options/utils.cc b/ge/ir_build/attr_options/utils.cc index f0b559ec..abefaf84 100644 --- a/ge/ir_build/attr_options/utils.cc +++ b/ge/ir_build/attr_options/utils.cc @@ -19,18 +19,39 @@ #include "common/util/error_manager/error_manager.h" namespace ge { + namespace { + const std::string CFG_PRE_OPTYPE = "OpType::"; +} bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { + std::string attrOp = op_name; std::vector original_op_names; - if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { - return false; + bool istype = IsContainOpType(op_name, attrOp); + if (!istype) { + if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { + return false; + } + } else { + if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, original_op_names)) { + return false; + } } for (auto &origin_name : original_op_names) { - if (origin_name == op_name) { + if (origin_name == attrOp) { return true; } } return false; } + +bool IsContainOpType(const std::string &cfg_line, std::string &op_type) { + op_type = cfg_line; + size_t pos = op_type.find(CFG_PRE_OPTYPE); + if (pos != std::string::npos) { + op_type = cfg_line.substr(pos+CFG_PRE_OPTYPE.length()); + return true; + } + return false; +} } // namespace ge \ No newline at end of file diff --git a/metadef b/metadef index e189fc7f..1427fa9c 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit e189fc7f4da9f7714f009d70da4db627de17955d +Subproject commit 1427fa9c6ad1849a941b946e52a469775bb21d87 diff --git a/parser b/parser index db5ce472..3073129b 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit db5ce472de0086c3e2abdaab3b0685c1d2656c96 +Subproject commit 3073129b68c0fae12a8b7531d60782e39128a28c