| @@ -23,6 +23,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); | bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); | ||||
| bool IsContainOpType(const std::string &cfg_line, std::string &op_type); | |||||
| graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | ||||
| graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | ||||
| } // namespace | } // namespace | ||||
| @@ -32,17 +32,23 @@ void KeepDtypeReportError(const std::vector<std::string> &invalid_list, const st | |||||
| size_t list_size = invalid_list.size(); | size_t list_size = invalid_list.size(); | ||||
| err_msg << "config file contains " << list_size; | err_msg << "config file contains " << list_size; | ||||
| if (list_size == 1) { | if (list_size == 1) { | ||||
| err_msg << " operator not in the graph, op name:"; | |||||
| err_msg << " operator not in the graph, "; | |||||
| } else { | } else { | ||||
| err_msg << " operators not in the graph, op names:"; | |||||
| err_msg << " operators not in the graph, "; | |||||
| } | } | ||||
| std::string cft_type; | |||||
| for (size_t i = 0; i < list_size; i++) { | for (size_t i = 0; i < list_size; i++) { | ||||
| if (i == kMaxOpsNum) { | if (i == kMaxOpsNum) { | ||||
| err_msg << ".."; | err_msg << ".."; | ||||
| break; | break; | ||||
| } | } | ||||
| err_msg << invalid_list[i]; | |||||
| bool istype = IsContainOpType(invalid_list[i], cft_type); | |||||
| if (!istype) { | |||||
| err_msg << "op name:"; | |||||
| } else { | |||||
| err_msg << "op type:"; | |||||
| } | |||||
| err_msg << cft_type; | |||||
| if (i != list_size - 1) { | if (i != list_size - 1) { | ||||
| err_msg << " "; | err_msg << " "; | ||||
| } | } | ||||
| @@ -72,7 +78,7 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { | |||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| std::string op_name; | |||||
| std::string op_name, op_type; | |||||
| std::vector<std::string> invalid_list; | std::vector<std::string> invalid_list; | ||||
| while (std::getline(ifs, op_name)) { | while (std::getline(ifs, op_name)) { | ||||
| if (op_name.empty()) { | if (op_name.empty()) { | ||||
| @@ -80,11 +86,13 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { | |||||
| } | } | ||||
| op_name = StringUtils::Trim(op_name); | op_name = StringUtils::Trim(op_name); | ||||
| bool is_find = false; | bool is_find = false; | ||||
| bool is_type = IsContainOpType(op_name, op_type); | |||||
| for (auto &node_ptr : graph->GetDirectNode()) { | for (auto &node_ptr : graph->GetDirectNode()) { | ||||
| auto op_desc = node_ptr->GetOpDesc(); | auto op_desc = node_ptr->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) { | |||||
| if ((!is_type && op_desc->GetName() == op_name) || (is_type && op_desc->GetType() == op_type) || | |||||
| IsOriginalOpFind(op_desc, op_name)) { | |||||
| is_find = true; | is_find = true; | ||||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); | ||||
| } | } | ||||
| @@ -19,18 +19,39 @@ | |||||
| #include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| const std::string CFG_PRE_OPTYPE = "OpType::"; | |||||
| } | |||||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | ||||
| std::string attrOp = op_name; | |||||
| std::vector<std::string> original_op_names; | std::vector<std::string> original_op_names; | ||||
| if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | |||||
| return false; | |||||
| bool istype = IsContainOpType(op_name, attrOp); | |||||
| if (!istype) { | |||||
| if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | |||||
| return false; | |||||
| } | |||||
| } else { | |||||
| if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, original_op_names)) { | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| for (auto &origin_name : original_op_names) { | for (auto &origin_name : original_op_names) { | ||||
| if (origin_name == op_name) { | |||||
| if (origin_name == attrOp) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| bool IsContainOpType(const std::string &cfg_line, std::string &op_type) { | |||||
| op_type = cfg_line; | |||||
| size_t pos = op_type.find(CFG_PRE_OPTYPE); | |||||
| if (pos != std::string::npos) { | |||||
| op_type = cfg_line.substr(pos+CFG_PRE_OPTYPE.length()); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -1 +1 @@ | |||||
| Subproject commit e189fc7f4da9f7714f009d70da4db627de17955d | |||||
| Subproject commit 1427fa9c6ad1849a941b946e52a469775bb21d87 | |||||
| @@ -1 +1 @@ | |||||
| Subproject commit db5ce472de0086c3e2abdaab3b0685c1d2656c96 | |||||
| Subproject commit 3073129b68c0fae12a8b7531d60782e39128a28c | |||||