Browse Source

!1649 dts: ir check data attr index valid

From: @zhengyuanhua
Reviewed-by: 
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
dfe484b986
4 changed files with 245 additions and 14 deletions
  1. +117
    -13
      ge/ir_build/ge_ir_build.cc
  2. +1
    -0
      inc/framework/omg/omg_inner_types.h
  3. +4
    -1
      tests/ut/ge/graph/build/model_builder_unittest.cc
  4. +123
    -0
      tests/ut/ge/graph_ir/ge_ir_build_unittest.cc

+ 117
- 13
ge/ir_build/ge_ir_build.cc View File

@@ -251,17 +251,24 @@ class Impl {
omg_context_.dynamic_batch_size.clear();
omg_context_.dynamic_image_size.clear();
omg_context_.dynamic_dims.clear();
omg_context_.user_attr_index_valid = false;
};
~Impl() { (void)generator_.Finalize(); };
graphStatus CheckOptions(const std::map<std::string, std::string> &options);
graphStatus CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTensor> &inputs);
graphStatus UpdateDataOpAttr(const Graph &graph);
graphStatus CheckDataOpAttrIndexValid(const Graph &graph, const std::string &input_shape_range);
graphStatus Init(const Graph &graph, const std::map<std::string, std::string> &options);
graphStatus BuildModel(const Graph &graph, const std::map<std::string, std::string> &options,
ModelBufferData &ge_models);
graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format,
bool is_dynamic_input);
graphStatus GetInputShapeRange(const string &input_shape_range,
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &name_shape_range_map,
std::vector<std::vector<std::pair<int64_t, int64_t>>> &index_shape_range_map);
static graphStatus InferShapePrepare(const ComputeGraphPtr &compute_graph);
bool GetUsrAttrIndexValidFlag();
bool IsAttrIndexSetByUser(const ComputeGraphPtr &compute_graph, size_t &data_num, vector<int64_t> &attr_index);
void SetRtSocVersion();
void UpdateThreadContext();
void LoadOpsProto();
@@ -288,11 +295,113 @@ graphStatus Impl::InferShapePrepare(const ComputeGraphPtr &compute_graph) {
return GRAPH_SUCCESS;
}

bool Impl::GetUsrAttrIndexValidFlag() {
return omg_context_.user_attr_index_valid;
}

bool Impl::IsAttrIndexSetByUser(const ComputeGraphPtr &compute_graph,
size_t &data_num,
vector<int64_t> &attr_index) {
bool all_zero_flag = true;
for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(input_node);
ge::OpDescPtr op = input_node->GetOpDesc();
GE_CHECK_NOTNULL(op);
if (op->GetType() == DATA) {
data_num++;
GeAttrValue::INT index = 0;
if (AttrUtils::GetInt(op, ATTR_NAME_INDEX, index)) {
if (index != 0) {
all_zero_flag = false;
}
attr_index.push_back(index);
} else {
GELOGW("[Get][AttrIndex] Get index[%ld] failed for op[%s].", index, op->GetName().c_str());
}
}
}
if (data_num > 1 && attr_index.size() == data_num && all_zero_flag) {
GELOGI("Attr indexes are not set by user.");
return false;
}
return true;
}

graphStatus Impl::GetInputShapeRange(const string &input_shape_range,
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &name_shape_range_map,
std::vector<std::vector<std::pair<int64_t, int64_t>>> &index_shape_range_map) {
if (input_shape_range.empty()) {
GELOGI("Input shape range is empty.");
return GRAPH_SUCCESS;
}
Status ret = GRAPH_PARAM_INVALID;
if (input_shape_range.find(":") != string::npos) {
ret = ParseInputShapeRange(input_shape_range, name_shape_range_map);
} else {
ret = ParseInputShapeRange(input_shape_range, index_shape_range_map);
}
if (ret != SUCCESS) {
GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] parse shape range[%s] failed.", input_shape_range.c_str());
return GRAPH_PARAM_INVALID;
}
return GRAPH_SUCCESS;
}

graphStatus Impl::CheckDataOpAttrIndexValid(const Graph &graph, const std::string &input_shape_range) {
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);

// when set input shape range by index, user must set data attr index, eg. "[1, 3, 3, -1],[1, 3~5, 6, -1]"
bool index_input_shape_range_flag = !input_shape_range.empty() && (input_shape_range.find(":") == string::npos);
size_t data_num = 0;
vector<int64_t> attr_index;
if (!IsAttrIndexSetByUser(compute_graph, data_num, attr_index)) {
if (index_input_shape_range_flag) {
std::string situation = "Data op index";
std::string reason = "it must be set by user, total data op num[" + std::to_string(data_num) + "], "
"when set input shape range by index.";
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"situation", "reason"}),
std::vector<std::string>({situation, reason}));
GELOGE(GRAPH_FAILED, "[Check][AttrIndex] Data op index is not set, total data op num[%ld], "
"when set input shape range by index.", data_num);
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

omg_context_.user_attr_index_valid = true;
for (size_t i = 0; i < data_num; ++i) {
if (std::find(attr_index.begin(), attr_index.end(), i) == attr_index.end()) {
omg_context_.user_attr_index_valid = false;
if (index_input_shape_range_flag) {
std::string situation = "Data op index[" + std::to_string(i) + "]";
std::string reason = "it must be set by user, total data op num[" + std::to_string(data_num) + "], "
"when set input shape range by index";
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"situation", "reason"}),
std::vector<std::string>({situation, reason}));
GELOGE(GRAPH_FAILED, "[Check][AttrIndex] Attr index [%ld] is not set, total data op num[%ld], "
"when set input shape range by index", i, data_num);
return GRAPH_FAILED;
} else {
GELOGW("[Check][AttrIndex] Attr index [%ld] is not set, total data op num[%ld].", i, data_num);
}
}
}
GELOGI("Data op attr indexes are set by user and valid.");
return GRAPH_SUCCESS;
}

graphStatus Impl::UpdateDataOpAttr(const Graph &graph) {
GELOGD("Enter Update Data Attr Process!");
std::string input_shape = (options_.find(kInputShape) == options_.end()) ? "" : options_[kInputShape];
std::string input_shape_range = (options_.find(kInputShapeRange) == options_.end()) ? "" : options_[kInputShapeRange];

graphStatus ret = CheckDataOpAttrIndexValid(graph, input_shape_range);
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "[Check][DataOpAttrIndex] fail, shape range[%s].", input_shape_range.c_str());
return GRAPH_FAILED;
}

map<string, vector<int64_t>> shape_map;
vector<pair<string, vector<int64_t>>> user_shape_map;
if (!input_shape.empty()) {
@@ -301,20 +410,13 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) {
}
std::map<string, std::vector<std::pair<int64_t, int64_t>>> name_shape_range_map;
std::vector<std::vector<std::pair<int64_t, int64_t>>> index_shape_range_map;
if (!input_shape_range.empty()) {
Status ret = GRAPH_PARAM_INVALID;
if (input_shape_range.find(":") != string::npos) {
ret = ParseInputShapeRange(input_shape_range, name_shape_range_map);
} else {
ret = ParseInputShapeRange(input_shape_range, index_shape_range_map);
}
if (ret != SUCCESS) {
GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] parse shape range[%s] failed.", input_shape_range.c_str());
return GRAPH_PARAM_INVALID;
}
}
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
ret = GetInputShapeRange(input_shape_range, name_shape_range_map, index_shape_range_map);
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "[Get][InputShapeRange] fail, shape range[%s].", input_shape_range.c_str());
return GRAPH_FAILED;
}
for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(input_node);
ge::OpDescPtr op = input_node->GetOpDesc();
@@ -495,7 +597,9 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTe
ge::OpDescPtr op = input_node->GetOpDesc();
GE_CHECK_NOTNULL(op);
if (op->GetType() == DATA) {
(void)AttrUtils::SetInt(op, ATTR_NAME_INDEX, index++);
if (!GetUsrAttrIndexValidFlag()) {
(void)AttrUtils::SetInt(op, ATTR_NAME_INDEX, index++);
}
GELOGD("Data op inputDesc size: %zu", op->GetAllInputsDesc().size());
ge::GeTensorDesc tensor = op->GetInputDesc(0);
string data_op_name = op->GetName();


+ 1
- 0
inc/framework/omg/omg_inner_types.h View File

@@ -125,6 +125,7 @@ struct OmgContext {
std::vector<NodePtr> getnext_nosink_nodes;
bool fuzz_compile_flag = false;
std::string atc_cmdline;
bool user_attr_index_valid = false;
};
} // namespace ge



+ 4
- 1
tests/ut/ge/graph/build/model_builder_unittest.cc View File

@@ -17,6 +17,7 @@
#include <gtest/gtest.h>
#include <memory>

#include "graph/common/local_context.h"
#include "graph/anchor.h"
#include "graph/attr_value.h"
#include "graph/debug/ge_attr_define.h"
@@ -165,7 +166,9 @@ void MakeSessionScopeReuseGraph(ge::ComputeGraphPtr graph) {
}

protected:
void SetUp() {}
void SetUp() {
SetLocalOmgContext(domi::GetContext());
}

void TearDown() { GetContext().out_nodes_map.clear(); }
};


+ 123
- 0
tests/ut/ge/graph_ir/ge_ir_build_unittest.cc View File

@@ -18,6 +18,9 @@
#include "ir_build/option_utils.h"
#include "graph/testcase/ge_graph/graph_builder_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "ge/ge_ir_build.h"
#include "graph/ops_stub.h"

#define protected public
#define private public
@@ -37,6 +40,13 @@ class UtestIrCommon : public testing::Test {
void TearDown() {}
};

class UtestIrBuild : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
};

static ge::OpDescPtr CreateOpDesc(const std::string &name, const std::string &type) {
OpDescPtr op_desc = std::make_shared<ge::OpDesc>(name, type);
ge::GeTensorDesc ge_tensor_desc;
@@ -60,6 +70,59 @@ static ComputeGraphPtr BuildComputeGraph() {
return builder.GetGraph();
}

// data not set attr index;
// but becasue of op proto, register attr index. so all data index is zero;
static Graph BuildIrGraph() {
auto data1 = op::Data("data1");
auto data2 = op::Data("data2");
auto data3 = op::Data("data3");
std::vector<Operator> inputs {data1, data2, data3};
std::vector<Operator> outputs;

Graph graph("test_graph");
graph.SetInputs(inputs).SetOutputs(outputs);
return graph;
}

// data set attr index, but is not valid
static Graph BuildIrGraph1() {
auto data1 = op::Data("data1").set_attr_index(0);
auto data2 = op::Data("data2").set_attr_index(1);
auto data3 = op::Data("data3");
std::vector<Operator> inputs {data1, data2, data3};
std::vector<Operator> outputs;

Graph graph("test_graph");
graph.SetInputs(inputs).SetOutputs(outputs);
return graph;
}

// data set attr index, but is not valid
static Graph BuildIrGraph2() {
auto data1 = op::Data("data1").set_attr_index(0);
auto data2 = op::Data("data2");
auto data3 = op::Data("data3").set_attr_index(2);
std::vector<Operator> inputs {data1, data2, data3};
std::vector<Operator> outputs;

Graph graph("test_graph");
graph.SetInputs(inputs).SetOutputs(outputs);
return graph;
}

// data set attr index
static Graph BuildIrGraph3() {
auto data1 = op::Data("data1").set_attr_index(0);
auto data2 = op::Data("data2").set_attr_index(1);
auto data3 = op::Data("data3").set_attr_index(2);
std::vector<Operator> inputs {data1, data2, data3};
std::vector<Operator> outputs;

Graph graph("test_graph");
graph.SetInputs(inputs).SetOutputs(outputs);
return graph;
}

TEST(UtestIrCommon, update_data_op_shape) {
ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data");
map<string, vector<int64_t>> shape_map;
@@ -227,3 +290,63 @@ TEST(UtestIrCommon, check_param_failed) {

ret = CheckLogParamValidAndSetLogLevel(param_invalid);
}

// Get attr index failed, when set input shape range
TEST(UtestIrBuild, check_data_op_attr_index_invalid_0) {
ComputeGraphPtr compute_graph = BuildComputeGraph();
Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
const map<string, string> build_options = {
{"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
};
ModelBufferData model;
graphStatus ret = aclgrphBuildModel(graph, build_options, model);
EXPECT_EQ(ret, GRAPH_FAILED);
}

// not set attr index, when set input shape range
TEST(UtestIrBuild, check_data_op_attr_index_invalid_1) {
Graph graph = BuildIrGraph();
const map<string, string> build_options = {
{"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
};
ModelBufferData model;
graphStatus ret = aclgrphBuildModel(graph, build_options, model);
EXPECT_EQ(ret, GRAPH_FAILED);
}

// set attr index, but not valid, when set input shape range
TEST(UtestIrBuild, check_data_op_attr_index_invalid_2) {
Graph graph = BuildIrGraph1();
const map<string, string> build_options = {
{"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
};
ModelBufferData model;
graphStatus ret = aclgrphBuildModel(graph, build_options, model);
EXPECT_EQ(ret, GRAPH_FAILED);

Graph graph2 = BuildIrGraph2();
ret = aclgrphBuildModel(graph2, build_options, model);
EXPECT_EQ(ret, GRAPH_FAILED);
}

// set attr index valid, when set input shape range
// only check data op attr index valid func.
TEST(UtestIrBuild, check_data_op_attr_index_valid) {
Graph graph = BuildIrGraph3();
const map<string, string> build_options = {
{"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
};
ModelBufferData model;
graphStatus ret = aclgrphBuildModel(graph, build_options, model);
EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED);
}

// set attr index invalid, when not set input shape range
// only check data op attr index valid func.
TEST(UtestIrBuild, check_data_attr_index_succ_no_input_range) {
Graph graph = BuildIrGraph1();
const map<string, string> build_options;
ModelBufferData model;
graphStatus ret = aclgrphBuildModel(graph, build_options, model);
EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED);
}

Loading…
Cancel
Save