From 2c7342bb3ae956e6f4884cc0e99068d215c178b0 Mon Sep 17 00:00:00 2001 From: wq160 Date: Tue, 29 Jun 2021 21:10:52 +0800 Subject: [PATCH] add scalar tensor value range process --- ge/graph/passes/infer_value_range_pass.cc | 30 ++++++++++--- .../passes/infer_value_range_pass_unittest.cc | 45 +++++++++++++++++++ 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/ge/graph/passes/infer_value_range_pass.cc b/ge/graph/passes/infer_value_range_pass.cc index e714e90a..03a18fdb 100644 --- a/ge/graph/passes/infer_value_range_pass.cc +++ b/ge/graph/passes/infer_value_range_pass.cc @@ -301,12 +301,26 @@ graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc, GeTensorPtr &output_ptr) { std::vector> value_range; (void)tensor_desc.GetValueRange(value_range); - if (static_cast(value_range.size()) != tensor_desc.GetShape().GetShapeSize()) { - GELOGW("Value range of input %s is invalid.", tensor_desc.GetName().c_str()); + size_t value_range_data_num = value_range.size(); + auto tensor_shape = tensor_desc.GetShape(); + bool value_range_and_tensor_shape_matched = true; + if (tensor_shape.IsScalar()){ + // scalar tensor has only one value_range pair + if (value_range_data_num != 1) { + value_range_and_tensor_shape_matched = false; + } + } else { + // normal tensor, value_range size is equal to tensor shape size. + if (static_cast(value_range_data_num) != tensor_shape.GetShapeSize()) { + value_range_and_tensor_shape_matched = false; + } + } + if (!value_range_and_tensor_shape_matched) { + GELOGW("Input %s value range and tensor shape do not match. Value range size is %zu, tensor shape is %s.", + tensor_desc.GetName().c_str(), value_range_data_num, formats::ShapeToString(tensor_shape).c_str()); return GRAPH_PARAM_INVALID; } - size_t value_range_data_num = value_range.size(); unique_ptr buf(new (std::nothrow) T[value_range_data_num]()); if (buf == nullptr) { REPORT_INNER_ERROR("E19999", "New buf failed"); @@ -494,10 +508,16 @@ void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, co GELOGI("Output tensor of cpu kernel does not have data, no way to set value range."); return; } - for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) { + auto left_tensor_shape = left_tensor->GetTensorDesc().GetShape(); + for (auto j = 0; j < left_tensor_shape.GetShapeSize(); ++j) { auto left = static_cast(*(x + j)); auto right = static_cast(*(y + j)); - value_range.emplace_back(std::make_pair(left, right)); + value_range.emplace_back(left, right); + } + + if (left_tensor_shape.IsScalar()) { + GELOGD("When inferring value range, output tensors of cpu kernel are scalar tensors."); + value_range.emplace_back(static_cast(*x), static_cast(*y)); } } } // namespace ge diff --git a/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc index 576d679c..c39755b3 100644 --- a/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc @@ -293,6 +293,9 @@ class AddKernel : public Kernel { } else if (input[0]->GetTensorDesc().GetDataType() == DT_INT32 || input[0]->GetTensorDesc().GetDataType() == DT_UINT32) { vector data_vec; auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); + if (input[0]->GetTensorDesc().GetShape().IsScalar()) { + data_num = 1; + } auto x1_data = reinterpret_cast(input[0]->GetData().data()); auto x2_data = reinterpret_cast(input[1]->GetData().data()); for (size_t i = 0; i < data_num; i++) { @@ -410,6 +413,48 @@ TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHave EXPECT_EQ(unknown_target_value_range, output_value_range); } +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_ScalarOutput) { + // shape --- add --- sqrt + // constant / + auto graph = std::make_shared("test_graph"); + vector data_vec = {2}; + GeTensorDesc const_td(ge::GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); + GeTensorPtr const_tensor = std::make_shared(const_td, (uint8_t *)data_vec.data(), sizeof(int32_t)); + auto const_op_desc = std::make_shared("Constant", "Constant"); + const_op_desc->AddOutputDesc(const_td); + EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS); + auto const_node = graph->AddNode(const_op_desc); + + GeTensorDesc shape_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); + std::vector> known_value_range = {make_pair(1, 100)}; + shape_td.SetValueRange(known_value_range); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddOutputDesc(shape_td); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc add_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); + auto add_op_desc = std::make_shared("Add", "Add"); + add_op_desc->AddInputDesc(shape_td); + add_op_desc->AddInputDesc(const_td); + add_op_desc->AddOutputDesc(add_td); + auto add_node = graph->AddNode(add_op_desc); + + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); + + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); + + auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); + std::vector> out_value_range; + output_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 1); + + std::vector target_value_range = {3, 102}; + std::vector output_value_range = {out_value_range[0].first, out_value_range[0].second}; + EXPECT_EQ(output_value_range, target_value_range); +} + TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) { // shape --- add --- sqrt // constant /