diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index c9dfac07..bd6a2d3a 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -251,17 +251,24 @@ class Impl { omg_context_.dynamic_batch_size.clear(); omg_context_.dynamic_image_size.clear(); omg_context_.dynamic_dims.clear(); + omg_context_.user_attr_index_valid = false; }; ~Impl() { (void)generator_.Finalize(); }; graphStatus CheckOptions(const std::map &options); graphStatus CreateInputsForIRBuild(const ge::Graph &graph, vector &inputs); graphStatus UpdateDataOpAttr(const Graph &graph); + graphStatus CheckDataOpAttrIndexValid(const Graph &graph, const std::string &input_shape_range); graphStatus Init(const Graph &graph, const std::map &options); graphStatus BuildModel(const Graph &graph, const std::map &options, ModelBufferData &ge_models); graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, bool is_dynamic_input); + graphStatus GetInputShapeRange(const string &input_shape_range, + std::map>> &name_shape_range_map, + std::vector>> &index_shape_range_map); static graphStatus InferShapePrepare(const ComputeGraphPtr &compute_graph); + bool GetUsrAttrIndexValidFlag(); + bool IsAttrIndexSetByUser(const ComputeGraphPtr &compute_graph, size_t &data_num, vector &attr_index); void SetRtSocVersion(); void UpdateThreadContext(); void LoadOpsProto(); @@ -288,11 +295,113 @@ graphStatus Impl::InferShapePrepare(const ComputeGraphPtr &compute_graph) { return GRAPH_SUCCESS; } +bool Impl::GetUsrAttrIndexValidFlag() { + return omg_context_.user_attr_index_valid; +} + +bool Impl::IsAttrIndexSetByUser(const ComputeGraphPtr &compute_graph, + size_t &data_num, + vector &attr_index) { + bool all_zero_flag = true; + for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(input_node); + ge::OpDescPtr op = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op); + if (op->GetType() == DATA) { + data_num++; + GeAttrValue::INT index = 0; + if (AttrUtils::GetInt(op, ATTR_NAME_INDEX, index)) { + if (index != 0) { + all_zero_flag = false; + } + attr_index.push_back(index); + } else { + GELOGW("[Get][AttrIndex] Get index[%ld] failed for op[%s].", index, op->GetName().c_str()); + } + } + } + if (data_num > 1 && attr_index.size() == data_num && all_zero_flag) { + GELOGI("Attr indexes are not set by user."); + return false; + } + return true; +} + +graphStatus Impl::GetInputShapeRange(const string &input_shape_range, + std::map>> &name_shape_range_map, + std::vector>> &index_shape_range_map) { + if (input_shape_range.empty()) { + GELOGI("Input shape range is empty."); + return GRAPH_SUCCESS; + } + Status ret = GRAPH_PARAM_INVALID; + if (input_shape_range.find(":") != string::npos) { + ret = ParseInputShapeRange(input_shape_range, name_shape_range_map); + } else { + ret = ParseInputShapeRange(input_shape_range, index_shape_range_map); + } + if (ret != SUCCESS) { + GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] parse shape range[%s] failed.", input_shape_range.c_str()); + return GRAPH_PARAM_INVALID; + } + return GRAPH_SUCCESS; +} + +graphStatus Impl::CheckDataOpAttrIndexValid(const Graph &graph, const std::string &input_shape_range) { + auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + // when set input shape range by index, user must set data attr index, eg. "[1, 3, 3, -1],[1, 3~5, 6, -1]" + bool index_input_shape_range_flag = !input_shape_range.empty() && (input_shape_range.find(":") == string::npos); + size_t data_num = 0; + vector attr_index; + if (!IsAttrIndexSetByUser(compute_graph, data_num, attr_index)) { + if (index_input_shape_range_flag) { + std::string situation = "Data op index"; + std::string reason = "it must be set by user, total data op num[" + std::to_string(data_num) + "], " + "when set input shape range by index."; + REPORT_INPUT_ERROR("E19025", std::vector({"situation", "reason"}), + std::vector({situation, reason})); + GELOGE(GRAPH_FAILED, "[Check][AttrIndex] Data op index is not set, total data op num[%ld], " + "when set input shape range by index.", data_num); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } + + omg_context_.user_attr_index_valid = true; + for (size_t i = 0; i < data_num; ++i) { + if (std::find(attr_index.begin(), attr_index.end(), i) == attr_index.end()) { + omg_context_.user_attr_index_valid = false; + if (index_input_shape_range_flag) { + std::string situation = "Data op index[" + std::to_string(i) + "]"; + std::string reason = "it must be set by user, total data op num[" + std::to_string(data_num) + "], " + "when set input shape range by index"; + REPORT_INPUT_ERROR("E19025", std::vector({"situation", "reason"}), + std::vector({situation, reason})); + GELOGE(GRAPH_FAILED, "[Check][AttrIndex] Attr index [%ld] is not set, total data op num[%ld], " + "when set input shape range by index", i, data_num); + return GRAPH_FAILED; + } else { + GELOGW("[Check][AttrIndex] Attr index [%ld] is not set, total data op num[%ld].", i, data_num); + } + } + } + GELOGI("Data op attr indexes are set by user and valid."); + return GRAPH_SUCCESS; +} + graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { GELOGD("Enter Update Data Attr Process!"); std::string input_shape = (options_.find(kInputShape) == options_.end()) ? "" : options_[kInputShape]; std::string input_shape_range = (options_.find(kInputShapeRange) == options_.end()) ? "" : options_[kInputShapeRange]; + graphStatus ret = CheckDataOpAttrIndexValid(graph, input_shape_range); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "[Check][DataOpAttrIndex] fail, shape range[%s].", input_shape_range.c_str()); + return GRAPH_FAILED; + } + map> shape_map; vector>> user_shape_map; if (!input_shape.empty()) { @@ -301,20 +410,13 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { } std::map>> name_shape_range_map; std::vector>> index_shape_range_map; - if (!input_shape_range.empty()) { - Status ret = GRAPH_PARAM_INVALID; - if (input_shape_range.find(":") != string::npos) { - ret = ParseInputShapeRange(input_shape_range, name_shape_range_map); - } else { - ret = ParseInputShapeRange(input_shape_range, index_shape_range_map); - } - if (ret != SUCCESS) { - GELOGE(GRAPH_PARAM_INVALID, "[Parse][InputShapeRange] parse shape range[%s] failed.", input_shape_range.c_str()); - return GRAPH_PARAM_INVALID; - } - } auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); + ret = GetInputShapeRange(input_shape_range, name_shape_range_map, index_shape_range_map); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "[Get][InputShapeRange] fail, shape range[%s].", input_shape_range.c_str()); + return GRAPH_FAILED; + } for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { GE_CHECK_NOTNULL(input_node); ge::OpDescPtr op = input_node->GetOpDesc(); @@ -495,7 +597,9 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vectorGetOpDesc(); GE_CHECK_NOTNULL(op); if (op->GetType() == DATA) { - (void)AttrUtils::SetInt(op, ATTR_NAME_INDEX, index++); + if (!GetUsrAttrIndexValidFlag()) { + (void)AttrUtils::SetInt(op, ATTR_NAME_INDEX, index++); + } GELOGD("Data op inputDesc size: %zu", op->GetAllInputsDesc().size()); ge::GeTensorDesc tensor = op->GetInputDesc(0); string data_op_name = op->GetName(); diff --git a/inc/framework/omg/omg_inner_types.h b/inc/framework/omg/omg_inner_types.h index 1024f7e6..0b799bf2 100644 --- a/inc/framework/omg/omg_inner_types.h +++ b/inc/framework/omg/omg_inner_types.h @@ -125,6 +125,7 @@ struct OmgContext { std::vector getnext_nosink_nodes; bool fuzz_compile_flag = false; std::string atc_cmdline; + bool user_attr_index_valid = false; }; } // namespace ge diff --git a/tests/ut/ge/graph/build/model_builder_unittest.cc b/tests/ut/ge/graph/build/model_builder_unittest.cc index 628d0fda..d544e1a3 100644 --- a/tests/ut/ge/graph/build/model_builder_unittest.cc +++ b/tests/ut/ge/graph/build/model_builder_unittest.cc @@ -17,6 +17,7 @@ #include #include +#include "graph/common/local_context.h" #include "graph/anchor.h" #include "graph/attr_value.h" #include "graph/debug/ge_attr_define.h" @@ -165,7 +166,9 @@ void MakeSessionScopeReuseGraph(ge::ComputeGraphPtr graph) { } protected: - void SetUp() {} + void SetUp() { + SetLocalOmgContext(domi::GetContext()); + } void TearDown() { GetContext().out_nodes_map.clear(); } }; diff --git a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc index ec7b9488..fb4a5a8d 100644 --- a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc +++ b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc @@ -18,6 +18,9 @@ #include "ir_build/option_utils.h" #include "graph/testcase/ge_graph/graph_builder_utils.h" #include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "ge/ge_ir_build.h" +#include "graph/ops_stub.h" #define protected public #define private public @@ -37,6 +40,13 @@ class UtestIrCommon : public testing::Test { void TearDown() {} }; +class UtestIrBuild : public testing::Test { + protected: + void SetUp() {} + + void TearDown() {} +}; + static ge::OpDescPtr CreateOpDesc(const std::string &name, const std::string &type) { OpDescPtr op_desc = std::make_shared(name, type); ge::GeTensorDesc ge_tensor_desc; @@ -60,6 +70,59 @@ static ComputeGraphPtr BuildComputeGraph() { return builder.GetGraph(); } +// data not set attr index; +// but becasue of op proto, register attr index. so all data index is zero; +static Graph BuildIrGraph() { + auto data1 = op::Data("data1"); + auto data2 = op::Data("data2"); + auto data3 = op::Data("data3"); + std::vector inputs {data1, data2, data3}; + std::vector outputs; + + Graph graph("test_graph"); + graph.SetInputs(inputs).SetOutputs(outputs); + return graph; +} + +// data set attr index, but is not valid +static Graph BuildIrGraph1() { + auto data1 = op::Data("data1").set_attr_index(0); + auto data2 = op::Data("data2").set_attr_index(1); + auto data3 = op::Data("data3"); + std::vector inputs {data1, data2, data3}; + std::vector outputs; + + Graph graph("test_graph"); + graph.SetInputs(inputs).SetOutputs(outputs); + return graph; +} + +// data set attr index, but is not valid +static Graph BuildIrGraph2() { + auto data1 = op::Data("data1").set_attr_index(0); + auto data2 = op::Data("data2"); + auto data3 = op::Data("data3").set_attr_index(2); + std::vector inputs {data1, data2, data3}; + std::vector outputs; + + Graph graph("test_graph"); + graph.SetInputs(inputs).SetOutputs(outputs); + return graph; +} + +// data set attr index +static Graph BuildIrGraph3() { + auto data1 = op::Data("data1").set_attr_index(0); + auto data2 = op::Data("data2").set_attr_index(1); + auto data3 = op::Data("data3").set_attr_index(2); + std::vector inputs {data1, data2, data3}; + std::vector outputs; + + Graph graph("test_graph"); + graph.SetInputs(inputs).SetOutputs(outputs); + return graph; +} + TEST(UtestIrCommon, update_data_op_shape) { ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data"); map> shape_map; @@ -227,3 +290,63 @@ TEST(UtestIrCommon, check_param_failed) { ret = CheckLogParamValidAndSetLogLevel(param_invalid); } + +// Get attr index failed, when set input shape range +TEST(UtestIrBuild, check_data_op_attr_index_invalid_0) { + ComputeGraphPtr compute_graph = BuildComputeGraph(); + Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); + const map build_options = { + {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"} + }; + ModelBufferData model; + graphStatus ret = aclgrphBuildModel(graph, build_options, model); + EXPECT_EQ(ret, GRAPH_FAILED); +} + +// not set attr index, when set input shape range +TEST(UtestIrBuild, check_data_op_attr_index_invalid_1) { + Graph graph = BuildIrGraph(); + const map build_options = { + {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"} + }; + ModelBufferData model; + graphStatus ret = aclgrphBuildModel(graph, build_options, model); + EXPECT_EQ(ret, GRAPH_FAILED); +} + +// set attr index, but not valid, when set input shape range +TEST(UtestIrBuild, check_data_op_attr_index_invalid_2) { + Graph graph = BuildIrGraph1(); + const map build_options = { + {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"} + }; + ModelBufferData model; + graphStatus ret = aclgrphBuildModel(graph, build_options, model); + EXPECT_EQ(ret, GRAPH_FAILED); + + Graph graph2 = BuildIrGraph2(); + ret = aclgrphBuildModel(graph2, build_options, model); + EXPECT_EQ(ret, GRAPH_FAILED); +} + +// set attr index valid, when set input shape range +// only check data op attr index valid func. +TEST(UtestIrBuild, check_data_op_attr_index_valid) { + Graph graph = BuildIrGraph3(); + const map build_options = { + {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"} + }; + ModelBufferData model; + graphStatus ret = aclgrphBuildModel(graph, build_options, model); + EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); +} + +// set attr index invalid, when not set input shape range +// only check data op attr index valid func. +TEST(UtestIrBuild, check_data_attr_index_succ_no_input_range) { + Graph graph = BuildIrGraph1(); + const map build_options; + ModelBufferData model; + graphStatus ret = aclgrphBuildModel(graph, build_options, model); + EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); +} \ No newline at end of file