Browse Source

add scalar tensor value range process

tags/v1.5.1
wq160 3 years ago
parent
commit
2c7342bb3a
2 changed files with 70 additions and 5 deletions
  1. +25
    -5
      ge/graph/passes/infer_value_range_pass.cc
  2. +45
    -0
      tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc

+ 25
- 5
ge/graph/passes/infer_value_range_pass.cc View File

@@ -301,12 +301,26 @@ graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc,
GeTensorPtr &output_ptr) { GeTensorPtr &output_ptr) {
std::vector<std::pair<int64_t, int64_t>> value_range; std::vector<std::pair<int64_t, int64_t>> value_range;
(void)tensor_desc.GetValueRange(value_range); (void)tensor_desc.GetValueRange(value_range);
if (static_cast<int64_t>(value_range.size()) != tensor_desc.GetShape().GetShapeSize()) {
GELOGW("Value range of input %s is invalid.", tensor_desc.GetName().c_str());
size_t value_range_data_num = value_range.size();
auto tensor_shape = tensor_desc.GetShape();
bool value_range_and_tensor_shape_matched = true;
if (tensor_shape.IsScalar()){
// scalar tensor has only one value_range pair
if (value_range_data_num != 1) {
value_range_and_tensor_shape_matched = false;
}
} else {
// normal tensor, value_range size is equal to tensor shape size.
if (static_cast<int64_t>(value_range_data_num) != tensor_shape.GetShapeSize()) {
value_range_and_tensor_shape_matched = false;
}
}
if (!value_range_and_tensor_shape_matched) {
GELOGW("Input %s value range and tensor shape do not match. Value range size is %zu, tensor shape is %s.",
tensor_desc.GetName().c_str(), value_range_data_num, formats::ShapeToString(tensor_shape).c_str());
return GRAPH_PARAM_INVALID; return GRAPH_PARAM_INVALID;
} }


size_t value_range_data_num = value_range.size();
unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]()); unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]());
if (buf == nullptr) { if (buf == nullptr) {
REPORT_INNER_ERROR("E19999", "New buf failed"); REPORT_INNER_ERROR("E19999", "New buf failed");
@@ -494,10 +508,16 @@ void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, co
GELOGI("Output tensor of cpu kernel does not have data, no way to set value range."); GELOGI("Output tensor of cpu kernel does not have data, no way to set value range.");
return; return;
} }
for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) {
auto left_tensor_shape = left_tensor->GetTensorDesc().GetShape();
for (auto j = 0; j < left_tensor_shape.GetShapeSize(); ++j) {
auto left = static_cast<int64_t>(*(x + j)); auto left = static_cast<int64_t>(*(x + j));
auto right = static_cast<int64_t>(*(y + j)); auto right = static_cast<int64_t>(*(y + j));
value_range.emplace_back(std::make_pair(left, right));
value_range.emplace_back(left, right);
}

if (left_tensor_shape.IsScalar()) {
GELOGD("When inferring value range, output tensors of cpu kernel are scalar tensors.");
value_range.emplace_back(static_cast<int64_t>(*x), static_cast<int64_t>(*y));
} }
} }
} // namespace ge } // namespace ge

+ 45
- 0
tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc View File

@@ -293,6 +293,9 @@ class AddKernel : public Kernel {
} else if (input[0]->GetTensorDesc().GetDataType() == DT_INT32 || input[0]->GetTensorDesc().GetDataType() == DT_UINT32) { } else if (input[0]->GetTensorDesc().GetDataType() == DT_INT32 || input[0]->GetTensorDesc().GetDataType() == DT_UINT32) {
vector<int32_t> data_vec; vector<int32_t> data_vec;
auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize();
if (input[0]->GetTensorDesc().GetShape().IsScalar()) {
data_num = 1;
}
auto x1_data = reinterpret_cast<const int32_t *>(input[0]->GetData().data()); auto x1_data = reinterpret_cast<const int32_t *>(input[0]->GetData().data());
auto x2_data = reinterpret_cast<const int32_t *>(input[1]->GetData().data()); auto x2_data = reinterpret_cast<const int32_t *>(input[1]->GetData().data());
for (size_t i = 0; i < data_num; i++) { for (size_t i = 0; i < data_num; i++) {
@@ -410,6 +413,48 @@ TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHave
EXPECT_EQ(unknown_target_value_range, output_value_range); EXPECT_EQ(unknown_target_value_range, output_value_range);
} }


TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_ScalarOutput) {
// shape --- add --- sqrt
// constant /
auto graph = std::make_shared<ComputeGraph>("test_graph");
vector<int32_t> data_vec = {2};
GeTensorDesc const_td(ge::GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
GeTensorPtr const_tensor = std::make_shared<ge::GeTensor>(const_td, (uint8_t *)data_vec.data(), sizeof(int32_t));
auto const_op_desc = std::make_shared<OpDesc>("Constant", "Constant");
const_op_desc->AddOutputDesc(const_td);
EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS);
auto const_node = graph->AddNode(const_op_desc);

GeTensorDesc shape_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
std::vector<std::pair<int64_t, int64_t>> known_value_range = {make_pair(1, 100)};
shape_td.SetValueRange(known_value_range);
auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape");
shape_op_desc->AddOutputDesc(shape_td);
auto shape_node = graph->AddNode(shape_op_desc);

GeTensorDesc add_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
auto add_op_desc = std::make_shared<OpDesc>("Add", "Add");
add_op_desc->AddInputDesc(shape_td);
add_op_desc->AddInputDesc(const_td);
add_op_desc->AddOutputDesc(add_td);
auto add_node = graph->AddNode(add_op_desc);

ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1));

InferValueRangePass infer_pass;
EXPECT_EQ(infer_pass.Run(add_node), SUCCESS);

auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> out_value_range;
output_0_desc.GetValueRange(out_value_range);
EXPECT_EQ(out_value_range.size(), 1);

std::vector<int64_t> target_value_range = {3, 102};
std::vector<int64_t> output_value_range = {out_value_range[0].first, out_value_range[0].second};
EXPECT_EQ(output_value_range, target_value_range);
}

TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) { TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) {
// shape --- add --- sqrt // shape --- add --- sqrt
// constant / // constant /


Loading…
Cancel
Save