diff --git a/ge/graph/passes/infer_value_range_pass.cc b/ge/graph/passes/infer_value_range_pass.cc index 03a18fdb..c183a599 100644 --- a/ge/graph/passes/infer_value_range_pass.cc +++ b/ge/graph/passes/infer_value_range_pass.cc @@ -85,8 +85,16 @@ graphStatus InferValueRangePass::Infer(NodePtr &node) { return GRAPH_SUCCESS; } - // if input value range has -1, cpu kernel cannot calculate correctly, so set {1:-1} - if (InputHasUnknownValueRange(node)) { + // Deal with scenes with unknown value range + bool has_unknown_value_range = false; + bool has_zero_in_value_range = false; + CheckInputValueRange(node, has_unknown_value_range, has_zero_in_value_range); + if (has_unknown_value_range) { + if (has_zero_in_value_range) { + // When there is zero in input value range, it is unreasonable to always set output value range {1:-1}. + GELOGW("Node %s has -1 and 0 in value range, skip setting value range.", node->GetName().c_str()); + return GRAPH_NOT_CHANGED; + } GELOGI("Node %s has unknown value range in input tensors, set value range {1:-1}, and skip cpu kernel.", node->GetName().c_str()); return GenerateWorstValueRange(node); @@ -188,14 +196,21 @@ bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) const return input_is_const_or_has_value_range; } -bool InferValueRangePass::InputHasUnknownValueRange(const NodePtr &node) const { - bool has_unknown_value_range = false; +void InferValueRangePass::CheckInputValueRange(const NodePtr &node, bool &has_unknown_value_range, + bool &has_zero_in_value_range) const { + has_unknown_value_range = false; + has_zero_in_value_range = false; auto cur_op_desc = node->GetOpDesc(); for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) { std::vector> input_desc_value_range; input_desc->GetValueRange(input_desc_value_range); if (!input_desc_value_range.empty()) { for (const auto &range : input_desc_value_range) { + if (range.first == 0 || range.second == 0) { + GELOGD("Node %s input tensors have zero in value range %s.", node->GetName().c_str(), + formats::RangeToString(input_desc_value_range).c_str()); + has_zero_in_value_range = true; + } if (range.first == -1 || range.second == -1) { GELOGD("Node %s input tensors have unknown value range, value range is %s.", node->GetName().c_str(), formats::RangeToString(input_desc_value_range).c_str()); @@ -204,7 +219,6 @@ bool InferValueRangePass::InputHasUnknownValueRange(const NodePtr &node) const { } } } - return has_unknown_value_range; } graphStatus InferValueRangePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { diff --git a/ge/graph/passes/infer_value_range_pass.h b/ge/graph/passes/infer_value_range_pass.h index eb485c87..503b5a9f 100644 --- a/ge/graph/passes/infer_value_range_pass.h +++ b/ge/graph/passes/infer_value_range_pass.h @@ -34,7 +34,7 @@ class InferValueRangePass : public InferBasePass { bool InputIsDynamic(const NodePtr &node) const; bool InputIsConstOrHasValueRange(const NodePtr &node) const; - bool InputHasUnknownValueRange(const NodePtr &node) const; + void CheckInputValueRange(const NodePtr &node, bool &has_unknown_value_range, bool &has_zero_in_value_range) const; graphStatus GenerateWorstValueRange(NodePtr &node); template graphStatus ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr); diff --git a/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc index c39755b3..014d87dc 100644 --- a/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc @@ -365,6 +365,35 @@ TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHave EXPECT_EQ(unknown_target_value_range, output_value_range); } +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveZeroInValueRange) { + // shape --- add --- sqrt + auto graph = std::make_shared("test_graph"); + GeTensorDesc shape_tensor_desc(GeShape({2}), ge::FORMAT_NCHW, ge::DT_INT64); + std::vector> unknown_value_range = {make_pair(1, -1), make_pair(0, 240)}; + shape_tensor_desc.SetValueRange(unknown_value_range); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddOutputDesc(shape_tensor_desc); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc add_tensor_desc(GeShape({2}), ge::FORMAT_NCHW, ge::DT_INT64); + auto add_op_desc = std::make_shared("Add", "Add"); + add_op_desc->AddInputDesc(shape_tensor_desc); + add_op_desc->AddInputDesc(shape_tensor_desc); + add_op_desc->AddOutputDesc(add_tensor_desc); + auto add_node = graph->AddNode(add_op_desc); + + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); + + // test unknown value range + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); + auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); + std::vector> out_value_range; + output_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 0); +} + TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveUnKnownValueRange_ScalarOutput) { // shape --- add --- sqrt // constant /