Browse Source

Pre Merge pull request !1779 from liudingyan/master

pull/1779/MERGE
liudingyan Gitee 4 years ago
parent
commit
ef8a15b4c0
5 changed files with 41 additions and 11 deletions
  1. +1
    -0
      ge/ir_build/attr_options/attr_options.h
  2. +14
    -6
      ge/ir_build/attr_options/keep_dtype_option.cc
  3. +24
    -3
      ge/ir_build/attr_options/utils.cc
  4. +1
    -1
      metadef
  5. +1
    -1
      parser

+ 1
- 0
ge/ir_build/attr_options/attr_options.h View File

@@ -23,6 +23,7 @@
namespace ge {
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name);
bool IsContainOpType(const std::string &cfg_line, std::string &op_type);
graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path);
graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path);
} // namespace

+ 14
- 6
ge/ir_build/attr_options/keep_dtype_option.cc View File

@@ -32,17 +32,23 @@ void KeepDtypeReportError(const std::vector<std::string> &invalid_list, const st
size_t list_size = invalid_list.size();
err_msg << "config file contains " << list_size;
if (list_size == 1) {
err_msg << " operator not in the graph, op name:";
err_msg << " operator not in the graph, ";
} else {
err_msg << " operators not in the graph, op names:";
err_msg << " operators not in the graph, ";
}
std::string cft_type;
for (size_t i = 0; i < list_size; i++) {
if (i == kMaxOpsNum) {
err_msg << "..";
break;
}
err_msg << invalid_list[i];
bool istype = IsContainOpType(invalid_list[i], cft_type);
if (!istype) {
err_msg << "op name:";
} else {
err_msg << "op type:";
}
err_msg << cft_type;
if (i != list_size - 1) {
err_msg << " ";
}
@@ -72,7 +78,7 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) {
return GRAPH_FAILED;
}
std::string op_name;
std::string op_name, op_type;
std::vector<std::string> invalid_list;
while (std::getline(ifs, op_name)) {
if (op_name.empty()) {
@@ -80,11 +86,13 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) {
}
op_name = StringUtils::Trim(op_name);
bool is_find = false;
bool is_type = IsContainOpType(op_name, op_type);
for (auto &node_ptr : graph->GetDirectNode()) {
auto op_desc = node_ptr->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) {
if ((!is_type && op_desc->GetName() == op_name) || (is_type && op_desc->GetType() == op_type) ||
IsOriginalOpFind(op_desc, op_name)) {
is_find = true;
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1);
}


+ 24
- 3
ge/ir_build/attr_options/utils.cc View File

@@ -19,18 +19,39 @@
#include "common/util/error_manager/error_manager.h"
namespace ge {
namespace {
const std::string CFG_PRE_OPTYPE = "OpType::";
}
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) {
std::string attrOp = op_name;
std::vector<std::string> original_op_names;
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) {
return false;
bool istype = IsContainOpType(op_name, attrOp);
if (!istype) {
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) {
return false;
}
} else {
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, original_op_names)) {
return false;
}
}
for (auto &origin_name : original_op_names) {
if (origin_name == op_name) {
if (origin_name == attrOp) {
return true;
}
}
return false;
}
bool IsContainOpType(const std::string &cfg_line, std::string &op_type) {
op_type = cfg_line;
size_t pos = op_type.find(CFG_PRE_OPTYPE);
if (pos != std::string::npos) {
op_type = cfg_line.substr(pos+CFG_PRE_OPTYPE.length());
return true;
}
return false;
}
} // namespace ge

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit e189fc7f4da9f7714f009d70da4db627de17955d
Subproject commit 1427fa9c6ad1849a941b946e52a469775bb21d87

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit db5ce472de0086c3e2abdaab3b0685c1d2656c96
Subproject commit 3073129b68c0fae12a8b7531d60782e39128a28c

Loading…
Cancel
Save