Browse Source

!1952 deal with unknown value range

Merge pull request !1952 from 王强/master
tags/v1.5.1
i-robot Gitee 3 years ago
parent
commit
bd759ac31e
3 changed files with 49 additions and 6 deletions
  1. +19
    -5
      ge/graph/passes/infer_value_range_pass.cc
  2. +1
    -1
      ge/graph/passes/infer_value_range_pass.h
  3. +29
    -0
      tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc

+ 19
- 5
ge/graph/passes/infer_value_range_pass.cc View File

@@ -85,8 +85,16 @@ graphStatus InferValueRangePass::Infer(NodePtr &node) {
return GRAPH_SUCCESS;
}

// if input value range has -1, cpu kernel cannot calculate correctly, so set {1:-1}
if (InputHasUnknownValueRange(node)) {
// Deal with scenes with unknown value range
bool has_unknown_value_range = false;
bool has_zero_in_value_range = false;
CheckInputValueRange(node, has_unknown_value_range, has_zero_in_value_range);
if (has_unknown_value_range) {
if (has_zero_in_value_range) {
// When there is zero in input value range, it is unreasonable to always set output value range {1:-1}.
GELOGW("Node %s has -1 and 0 in value range, skip setting value range.", node->GetName().c_str());
return GRAPH_NOT_CHANGED;
}
GELOGI("Node %s has unknown value range in input tensors, set value range {1:-1}, and skip cpu kernel.",
node->GetName().c_str());
return GenerateWorstValueRange(node);
@@ -188,14 +196,21 @@ bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) const
return input_is_const_or_has_value_range;
}

bool InferValueRangePass::InputHasUnknownValueRange(const NodePtr &node) const {
bool has_unknown_value_range = false;
void InferValueRangePass::CheckInputValueRange(const NodePtr &node, bool &has_unknown_value_range,
bool &has_zero_in_value_range) const {
has_unknown_value_range = false;
has_zero_in_value_range = false;
auto cur_op_desc = node->GetOpDesc();
for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) {
std::vector<std::pair<int64_t, int64_t>> input_desc_value_range;
input_desc->GetValueRange(input_desc_value_range);
if (!input_desc_value_range.empty()) {
for (const auto &range : input_desc_value_range) {
if (range.first == 0 || range.second == 0) {
GELOGD("Node %s input tensors have zero in value range %s.", node->GetName().c_str(),
formats::RangeToString(input_desc_value_range).c_str());
has_zero_in_value_range = true;
}
if (range.first == -1 || range.second == -1) {
GELOGD("Node %s input tensors have unknown value range, value range is %s.", node->GetName().c_str(),
formats::RangeToString(input_desc_value_range).c_str());
@@ -204,7 +219,6 @@ bool InferValueRangePass::InputHasUnknownValueRange(const NodePtr &node) const {
}
}
}
return has_unknown_value_range;
}

graphStatus InferValueRangePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {


+ 1
- 1
ge/graph/passes/infer_value_range_pass.h View File

@@ -34,7 +34,7 @@ class InferValueRangePass : public InferBasePass {

bool InputIsDynamic(const NodePtr &node) const;
bool InputIsConstOrHasValueRange(const NodePtr &node) const;
bool InputHasUnknownValueRange(const NodePtr &node) const;
void CheckInputValueRange(const NodePtr &node, bool &has_unknown_value_range, bool &has_zero_in_value_range) const;
graphStatus GenerateWorstValueRange(NodePtr &node);
template <typename T>
graphStatus ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr);


+ 29
- 0
tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc View File

@@ -365,6 +365,35 @@ TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHave
EXPECT_EQ(unknown_target_value_range, output_value_range);
}

TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveZeroInValueRange) {
// shape --- add --- sqrt
auto graph = std::make_shared<ComputeGraph>("test_graph");
GeTensorDesc shape_tensor_desc(GeShape({2}), ge::FORMAT_NCHW, ge::DT_INT64);
std::vector<std::pair<int64_t, int64_t>> unknown_value_range = {make_pair(1, -1), make_pair(0, 240)};
shape_tensor_desc.SetValueRange(unknown_value_range);
auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape");
shape_op_desc->AddOutputDesc(shape_tensor_desc);
auto shape_node = graph->AddNode(shape_op_desc);

GeTensorDesc add_tensor_desc(GeShape({2}), ge::FORMAT_NCHW, ge::DT_INT64);
auto add_op_desc = std::make_shared<OpDesc>("Add", "Add");
add_op_desc->AddInputDesc(shape_tensor_desc);
add_op_desc->AddInputDesc(shape_tensor_desc);
add_op_desc->AddOutputDesc(add_tensor_desc);
auto add_node = graph->AddNode(add_op_desc);

ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1));

// test unknown value range
InferValueRangePass infer_pass;
EXPECT_EQ(infer_pass.Run(add_node), SUCCESS);
auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> out_value_range;
output_0_desc.GetValueRange(out_value_range);
EXPECT_EQ(out_value_range.size(), 0);
}

TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveUnKnownValueRange_ScalarOutput) {
// shape --- add --- sqrt
// constant /


Loading…
Cancel
Save