diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index e1921f29..2eae6023 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -1148,6 +1148,10 @@ Status UpdateDynamicInputShapeRange(const ge::GeAttrValue::INT index, GeTensorDesc &desc) { auto origin_shape = desc.GetShape(); auto current_shape_range_vec = range_vec.at(index); + if (origin_shape.IsScalar()) { + GELOGI("Cur input %ld is scalar, no need set shape range.", index); + return SUCCESS; + } if (current_shape_range_vec.size() != origin_shape.GetDimNum()) { REPORT_INNER_ERROR("E19999", "Given shape_range dim num is %zu, current dim:%s num is %zu, not match, " "check invalid", current_shape_range_vec.size(), origin_shape.ToString().c_str(), diff --git a/ge/ir_build/option_utils.cc b/ge/ir_build/option_utils.cc index d7fccee3..3722301a 100755 --- a/ge/ir_build/option_utils.cc +++ b/ge/ir_build/option_utils.cc @@ -39,7 +39,7 @@ const size_t kSquareBracketsSize = 2; const size_t kRangePairSize = 2; const size_t kShapeRangeSize = 2; const size_t kShapeRangeStrIndex = 2; -const size_t kShapeRangeStrSize = 3; +const size_t kShapeRangeStrSize = 1; // datatype/formats from user to GE, Unified to util interface file later const std::map kOutputTypeSupportDatatype = { {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; @@ -456,8 +456,9 @@ Status ParseInputShapeRange(const std::string &shape_range, for (auto &shape_range_str : shape_range_set) { if (shape_range_str.size() < kShapeRangeStrSize) { // shape_range_str should be "[2~3,1" - // or ",[2~3,1". because we should trim '[' or ',[' - // so shape_range_str.size() < 3 is invalid + // or ",[2~3,1". because we should trim '[' or ',['. + // For scaler input, shape range should be "[]" + // so shape_range_str.size() < 1 is invalid continue; } // trim start bytes, after that, single input should be "1~20,3,3~6,-1" @@ -472,6 +473,11 @@ Status ParseInputShapeRange(const std::string &shape_range, std::vector> range_of_single_input; vector dim_range_set = ge::StringUtils::Split(shape_range_str, ','); for (const auto &range_pair_str : dim_range_set) { + if (range_pair_str.empty()) { + // for scaler input ,range is empty. use [0,0] as scaler range. + range_of_single_input.emplace_back(std::make_pair(0, 0)); + continue; + } vector range_pair_set = ge::StringUtils::Split(range_pair_str, '~'); pair range_pair; if (!ParseShapeRangePair(shape_range_str, range_pair_set, range_pair)) { diff --git a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc index 8ece7564..ebd0ab25 100644 --- a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc +++ b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc @@ -120,29 +120,63 @@ ComputeGraphPtr BuildGraph4_Subgraph(string graph_name) { return builder.GetGraph(); } +ComputeGraphPtr BuildGraph6() { + auto builder = ut::GraphBuilder("g6"); + auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {3, -1, -1, 5}); + auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {}); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0); + AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 1); + auto add = builder.AddNode("add", ADD, 2, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + builder.AddDataEdge(data1, 0, add, 0); + builder.AddDataEdge(data2, 0, add, 1); + builder.AddDataEdge(add, 0,netoutput, 0); + return builder.GetGraph(); +} + TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) { ge::GraphPrepare graph_prepare; - graph_prepare.compute_graph_ = BuildGraph1(); + graph_prepare.compute_graph_ = BuildGraph6(); // prepare user_input & graph option ge::GeTensorDesc tensor1; tensor1.SetFormat(ge::FORMAT_NCHW); tensor1.SetShape(ge::GeShape({3, 12, 5, 5})); tensor1.SetDataType(ge::DT_FLOAT); GeTensor input1(tensor1); - std::vector user_input = {input1}; + ge::GeTensorDesc tensor2; + tensor2.SetFormat(ge::FORMAT_NCHW); + tensor2.SetShape(ge::GeShape()); + tensor2.SetDataType(ge::DT_FLOAT); + GeTensor input2(tensor2); + std::vector user_input = {input1, input2}; std::map graph_option = {{"ge.exec.dynamicGraphExecuteMode","dynamic_execute"}, - {"ge.exec.dataInputsShapeRange","[3,1~20,2~10,5]"}}; + {"ge.exec.dataInputsShapeRange","[3,1~20,2~10,5],[]"}}; auto ret = graph_prepare.UpdateInput(user_input, graph_option); EXPECT_EQ(ret, ge::SUCCESS); - // check data node output shape_range and shape - auto data_node = graph_prepare.compute_graph_->FindNode("data1"); + // check data1 node output shape_range and shape + auto data_node = graph_prepare.compute_graph_->FindNode("input1"); auto data_output_desc = data_node->GetOpDesc()->GetOutputDescPtr(0); - vector expect_shape = {3,-1,-1,5}; - auto result_shape = data_output_desc->GetShape(); - EXPECT_EQ(result_shape.GetDimNum(), expect_shape.size()); - for(size_t i =0; i< expect_shape.size(); ++i){ - EXPECT_EQ(result_shape.GetDim(i), expect_shape.at(i)); + vector input1_expect_shape = {3,-1,-1,5}; + vector> intpu1_expect_shape_range = {{3,3},{1,20},{2,10},{5,5}}; + auto input1_result_shape = data_output_desc->GetShape(); + vector> input1_result_shape_range; + data_output_desc->GetShapeRange(input1_result_shape_range); + EXPECT_EQ(input1_result_shape.GetDimNum(), input1_expect_shape.size()); + EXPECT_EQ(input1_result_shape_range.size(), input1_expect_shape.size()); + for(size_t i =0; i< input1_expect_shape.size(); ++i){ + EXPECT_EQ(input1_result_shape.GetDim(i), input1_expect_shape.at(i)); + } + for(size_t i =0; i< intpu1_expect_shape_range.size(); ++i){ + EXPECT_EQ(input1_result_shape_range.at(i).first, intpu1_expect_shape_range.at(i).first); + EXPECT_EQ(input1_result_shape_range.at(i).second, intpu1_expect_shape_range.at(i).second); } + // check data2 node output shape_range and shape + auto data_node_2 = graph_prepare.compute_graph_->FindNode("input2"); + auto data_output_desc_2 = data_node_2->GetOpDesc()->GetOutputDescPtr(0); + vector> intput2_result_shape_range; + data_output_desc_2->GetShapeRange(intput2_result_shape_range); + EXPECT_EQ(intput2_result_shape_range.size(), 0); } TEST_F(UtestGraphPreproces, test_check_user_input) {