Browse Source

keep dtype add optype

pull/1812/head
liudingyan 4 years ago
parent
commit
3438a28d95
4 changed files with 89 additions and 15 deletions
  1. +2
    -1
      ge/ir_build/attr_options/attr_options.h
  2. +24
    -11
      ge/ir_build/attr_options/keep_dtype_option.cc
  3. +48
    -2
      ge/ir_build/attr_options/utils.cc
  4. +15
    -1
      tests/ut/ge/graph_ir/ge_ir_build_unittest.cc

+ 2
- 1
ge/ir_build/attr_options/attr_options.h View File

@@ -22,7 +22,8 @@
namespace ge {
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name);
bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type);
bool IsContainOpType(const std::string &cfg_line, std::string &op_type);
graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path);
graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path);
} // namespace

+ 24
- 11
ge/ir_build/attr_options/keep_dtype_option.cc View File

@@ -32,18 +32,24 @@ void KeepDtypeReportError(const std::vector<std::string> &invalid_list, const st
size_t list_size = invalid_list.size();
err_msg << "config file contains " << list_size;
if (list_size == 1) {
err_msg << " operator not in the graph, op name:";
err_msg << " operator not in the graph, ";
} else {
err_msg << " operators not in the graph, op names:";
err_msg << " operators not in the graph, ";
}
std::string cft_type;
for (size_t i = 0; i < list_size; i++) {
if (i == kMaxOpsNum) {
err_msg << "..";
break;
}
err_msg << invalid_list[i];
if (i != list_size - 1) {
bool istype = IsContainOpType(invalid_list[i], cft_type);
if (!istype) {
err_msg << "op name:";
} else {
err_msg << "op type:";
}
err_msg << cft_type;
if (i != (list_size - 1)) {
err_msg << " ";
}
}
@@ -72,7 +78,7 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) {
return GRAPH_FAILED;
}
std::string op_name;
std::string op_name, op_type;
std::vector<std::string> invalid_list;
while (std::getline(ifs, op_name)) {
if (op_name.empty()) {
@@ -80,13 +86,20 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) {
}
op_name = StringUtils::Trim(op_name);
bool is_find = false;
for (auto &node_ptr : graph->GetDirectNode()) {
bool is_type = IsContainOpType(op_name, op_type);
for (auto &node_ptr : graph->GetAllNodes()) {
auto op_desc = node_ptr->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) {
is_find = true;
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1);
if (is_type) {
if (IsOpTypeEqual(node_ptr, op_type)) {
is_find = true;
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1);
}
} else {
if (op_desc->GetName() == op_name || IsOriginalOpFind(op_desc, op_name)) {
is_find = true;
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1);
}
}
}
if (!is_find) {


+ 48
- 2
ge/ir_build/attr_options/utils.cc View File

@@ -16,9 +16,12 @@
#include "attr_options.h"
#include <vector>
#include "graph/debug/ge_attr_define.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "graph/common/omg_util.h"
namespace ge {
namespace {
const std::string CFG_PRE_OPTYPE = "OpType::";
}
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) {
std::vector<std::string> original_op_names;
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) {
@@ -33,4 +36,47 @@ bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) {
return false;
}
bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type) {
if (node == nullptr) {
REPORT_INNER_ERROR("E19999", "param node is nullptr, check invalid");
GELOGE(FAILED, "[Check][Param] node is nullptr");
return false;
}
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
REPORT_INNER_ERROR("E19999", "param node's op desc is nullptr, check invalid");
GELOGE(FAILED, "[Check][Param] node op desc is nullptr");
return false;
}
if (op_type != op_desc->GetType()) {
return false;
}
std::string origin_type;
auto ret = GetOriginalType(node, origin_type);
if (ret != SUCCESS) {
GELOGW("[Get][OriginalType] from op:%s failed.", node->GetName().c_str());
return false;
}
if (op_type != origin_type) {
return false;
}
return true;
}
bool IsContainOpType(const std::string &cfg_line, std::string &op_type) {
op_type = cfg_line;
size_t pos = op_type.find(CFG_PRE_OPTYPE);
if (pos != std::string::npos) {
if (pos == 0) {
op_type = cfg_line.substr(CFG_PRE_OPTYPE.length());
return true;
} else {
GELOGW("[Check][Param] %s must be at zero pos of %s", CFG_PRE_OPTYPE.c_str(), cfg_line.c_str());
}
return false;
}
GELOGW("[Check][Param] %s not contain optype", cfg_line.c_str());
return false;
}
} // namespace ge

+ 15
- 1
tests/ut/ge/graph_ir/ge_ir_build_unittest.cc View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <gtest/gtest.h>
#include "ir_build/option_utils.h"
#include "graph/testcase/ge_graph/graph_builder_utils.h"
@@ -371,4 +371,18 @@ TEST(UtestIrBuild, check_modify_mixlist_param) {
auto ret = aclgrphBuildModel(graph, build_options, model);
EXPECT_EQ(ret, GRAPH_PARAM_INVALID);
}

TEST(UtestIrBuild, check_cfg_optype_param) {
Graph graph = BuildIrGraph1();
FILE *fp = fopen("./keep.txt", "w+");
if (fp) {
fprintf(fp, "Test\n");
fprintf(fp, "OpType::Mul\n");
fprintf(fp, "Optype::Sub\n");
fclose(fp);
}
auto ret = aclgrphSetOpAttr(graph, ATTR_TYPE_KEEP_DTYPE, "./keep.txt");
(void)remove("./keep.txt");
EXPECT_EQ(ret, GRAPH_PARAM_INVALID);
}

Loading…
Cancel
Save