Browse Source

atc test keepdtype

tags/v1.5.1
liudingyan 3 years ago
parent
commit
beea153eb9
4 changed files with 97 additions and 23 deletions
  1. +3
    -2
      ge/ir_build/attr_options/attr_options.h
  2. +24
    -11
      ge/ir_build/attr_options/keep_dtype_option.cc
  3. +37
    -2
      ge/ir_build/attr_options/utils.cc
  4. +33
    -8
      tests/ut/ge/graph_ir/ge_ir_build_unittest.cc

+ 3
- 2
ge/ir_build/attr_options/attr_options.h View File

@@ -18,11 +18,12 @@

#include <string>
#include "graph/compute_graph.h"
#include "external/graph/ge_error_codes.h"
#include "graph/ge_error_codes.h"

namespace ge {
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name);

bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type);
bool IsContainOpType(const std::string &cfg_line, std::string &op_type);
graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path);
graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path);
} // namespace

+ 24
- 11
ge/ir_build/attr_options/keep_dtype_option.cc View File

@@ -32,18 +32,24 @@ void KeepDtypeReportError(const std::vector<std::string> &invalid_list, const st
size_t list_size = invalid_list.size();
err_msg << "config file contains " << list_size;
if (list_size == 1) {
err_msg << " operator not in the graph, op name:";
err_msg << " operator not in the graph, ";
} else {
err_msg << " operators not in the graph, op names:";
err_msg << " operators not in the graph, ";
}
std::string cft_type;
for (size_t i = 0; i < list_size; i++) {
if (i == kMaxOpsNum) {
err_msg << "..";
break;
}
err_msg << invalid_list[i];
if (i != list_size - 1) {
bool istype = IsContainOpType(invalid_list[i], cft_type);
if (!istype) {
err_msg << "op name:";
} else {
err_msg << "op type:";
}
err_msg << cft_type;
if (i != (list_size - 1)) {
err_msg << " ";
}
}
@@ -72,7 +78,7 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) {
return GRAPH_FAILED;
}
std::string op_name;
std::string op_name, op_type;
std::vector<std::string> invalid_list;
while (std::getline(ifs, op_name)) {
if (op_name.empty()) {
@@ -80,13 +86,20 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) {
}
op_name = StringUtils::Trim(op_name);
bool is_find = false;
for (auto &node_ptr : graph->GetDirectNode()) {
bool is_type = IsContainOpType(op_name, op_type);
for (auto &node_ptr : graph->GetAllNodes()) {
auto op_desc = node_ptr->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) {
is_find = true;
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1);
if (is_type) {
if (IsOpTypeEqual(node_ptr, op_type)) {
is_find = true;
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1);
}
} else {
if (op_desc->GetName() == op_name || IsOriginalOpFind(op_desc, op_name)) {
is_find = true;
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1);
}
}
}
if (!is_find) {


+ 37
- 2
ge/ir_build/attr_options/utils.cc View File

@@ -16,9 +16,12 @@
#include "ir_build/attr_options/attr_options.h"
#include <vector>
#include "graph/debug/ge_attr_define.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "graph/common/omg_util.h"
namespace ge {
namespace {
const std::string CFG_PRE_OPTYPE = "OpType::";
}
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) {
std::vector<std::string> original_op_names;
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) {
@@ -33,4 +36,36 @@ bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) {
return false;
}
bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type) {
if (op_type != node->GetOpDesc()->GetType()) {
return false;
}
std::string origin_type;
auto ret = GetOriginalType(node, origin_type);
if (ret != SUCCESS) {
GELOGW("[Get][OriginalType] from op:%s failed.", node->GetName().c_str());
return false;
}
if (op_type != origin_type) {
return false;
}
return true;
}
bool IsContainOpType(const std::string &cfg_line, std::string &op_type) {
op_type = cfg_line;
size_t pos = op_type.find(CFG_PRE_OPTYPE);
if (pos != std::string::npos) {
if (pos == 0) {
op_type = cfg_line.substr(CFG_PRE_OPTYPE.length());
return true;
} else {
GELOGW("[Check][Param] %s must be at zero pos of %s", CFG_PRE_OPTYPE.c_str(), cfg_line.c_str());
}
return false;
}
GELOGW("[Check][Param] %s not contain optype", cfg_line.c_str());
return false;
}
} // namespace ge

+ 33
- 8
tests/ut/ge/graph_ir/ge_ir_build_unittest.cc View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <gtest/gtest.h>
#include "ir_build/option_utils.h"
#include "graph/testcase/ge_graph/graph_builder_utils.h"
@@ -21,7 +21,7 @@
#include "graph/utils/graph_utils.h"
#include "ge/ge_ir_build.h"
#include "graph/ops_stub.h"
#include "ge/ir_build/attr_options/attr_options.h"
#define protected public
#define private public

@@ -70,6 +70,22 @@ static ComputeGraphPtr BuildComputeGraph() {
return builder.GetGraph();
}

static ComputeGraphPtr BuildComputeGraph1() {
auto builder = ut::GraphBuilder("test");
auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3});
auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10});
auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
auto node1 = builder.AddNode("addd", "Mul", 2, 1);
auto node2 = builder.AddNode("ffm", "FrameworkOp", 2, 1);
auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);

builder.AddDataEdge(data1, 0, addn1, 0);
builder.AddDataEdge(data2, 0, addn1, 1);
builder.AddDataEdge(addn1, 0,netoutput, 0);

return builder.GetGraph();
}

// data not set attr index;
// but becasue of op proto, register attr index. so all data index is zero;
static Graph BuildIrGraph() {
@@ -89,10 +105,12 @@ static Graph BuildIrGraph1() {
auto data1 = op::Data("data1").set_attr_index(0);
auto data2 = op::Data("data2").set_attr_index(1);
auto data3 = op::Data("data3");
std::vector<Operator> inputs {data1, data2, data3};
auto data4 = op::Data("Test");
std::vector<Operator> inputs {data1, data2, data3, data4};
std::vector<Operator> outputs;

Graph graph("test_graph");
graph.AddNodeByOp(Operator("gg", "Mul"));
graph.SetInputs(inputs).SetOutputs(outputs);
return graph;
}
@@ -373,9 +391,16 @@ TEST(UtestIrBuild, check_modify_mixlist_param) {
EXPECT_EQ(ret, GRAPH_PARAM_INVALID);
}

TEST(UtestIrCommon, check_dynamic_imagesize_input_shape_valid_format_empty) {
std::map<std::string, std::vector<int64_t>> shape_map;
std::string dynamic_image_size = "";
bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, "123", dynamic_image_size);
EXPECT_EQ(ret, false);
TEST(UtestIrBuild, atc_cfg_optype_param) {
ComputeGraphPtr graph = BuildComputeGraph1();
FILE *fp = fopen("./keep.txt", "w+");
if (fp) {
fprintf(fp, "Test\n");
fprintf(fp, "OpType::Mul\n");
fprintf(fp, "Optype::Sub\n");
fclose(fp);
}
auto ret = KeepDtypeFunc(graph, "./keep.txt");
(void)remove("./keep.txt");
EXPECT_EQ(ret, GRAPH_PARAM_INVALID);
}

Loading…
Cancel
Save