diff --git a/ge/graph/passes/infer_value_range_pass.cc b/ge/graph/passes/infer_value_range_pass.cc index b9cb88bc..e714e90a 100644 --- a/ge/graph/passes/infer_value_range_pass.cc +++ b/ge/graph/passes/infer_value_range_pass.cc @@ -286,6 +286,9 @@ graphStatus InferValueRangePass::GenerateWorstValueRange(NodePtr &node) { } std::vector> output_i_value_range(output_i_shape_size, {1, -1}); + if (output_i_shape.IsScalar()) { + output_i_value_range.emplace_back(1, -1); + } output_desc->SetValueRange(output_i_value_range); GELOGD("Node %s output %zu shape is %s, the generated worst value range is %s.", node->GetName().c_str(), i, formats::ShapeToString(output_i_shape).c_str(), formats::RangeToString(output_i_value_range).c_str()); diff --git a/ge/graph/passes/replace_with_empty_const_pass.cc b/ge/graph/passes/replace_with_empty_const_pass.cc index 3176d1ee..6cb31627 100644 --- a/ge/graph/passes/replace_with_empty_const_pass.cc +++ b/ge/graph/passes/replace_with_empty_const_pass.cc @@ -71,7 +71,7 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { GELOGI("Node %s Got empty output_desc_ptr, ignore current pass.", node->GetName().c_str()); return SUCCESS; } - if (!IsEmptyTenor(output_desc_ptr->GetShape())) { + if (!IsKnownEmptyTenor(output_desc_ptr->GetShape())) { is_all_output_empty = false; break; } @@ -107,12 +107,16 @@ Status ReplaceWithEmptyConstPass::GetOutputsOfCurrNode(const NodePtr &node_to_re return SUCCESS; } -bool ReplaceWithEmptyConstPass::IsEmptyTenor(const GeShape &shape) const { +bool ReplaceWithEmptyConstPass::IsKnownEmptyTenor(const GeShape &shape) const { + bool is_known_empty_tensor = false; for (auto dim : shape.GetDims()) { - if (dim == 0) { - return true; + if (dim < 0) { + // current dim is unknown dim, skip replace + return false; + } else if (dim == 0) { + is_known_empty_tensor = true; } } - return false; + return is_known_empty_tensor; } } // namespace ge diff --git a/ge/graph/passes/replace_with_empty_const_pass.h b/ge/graph/passes/replace_with_empty_const_pass.h index fde75358..90103432 100644 --- a/ge/graph/passes/replace_with_empty_const_pass.h +++ b/ge/graph/passes/replace_with_empty_const_pass.h @@ -26,7 +26,7 @@ class ReplaceWithEmptyConstPass : public FoldingPass { private: Status GetOutputsOfCurrNode(const NodePtr &node_to_replace, vector &outputs); - bool IsEmptyTenor(const GeShape &shape) const; + bool IsKnownEmptyTenor(const GeShape &shape) const; }; } // namespace ge #endif // GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ diff --git a/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc index fea1b27d..576d679c 100644 --- a/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc @@ -362,6 +362,54 @@ TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHave EXPECT_EQ(unknown_target_value_range, output_value_range); } +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveUnKnownValueRange_ScalarOutput) { + // shape --- add --- sqrt + // constant / + auto graph = std::make_shared("test_graph"); + vector data_vec = {1}; + GeTensorDesc const_tensor_desc(ge::GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); + GeTensorPtr const_tensor = + std::make_shared(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int64_t)); + + auto const_op_desc = std::make_shared("Constant", "Constant"); + const_op_desc->AddOutputDesc(const_tensor_desc); + EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS); + auto const_node = graph->AddNode(const_op_desc); + + GeTensorDesc shape_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); + std::vector> unknown_value_range = {make_pair(1, -1)}; + shape_tensor_desc.SetValueRange(unknown_value_range); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddOutputDesc(shape_tensor_desc); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc add_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); + auto add_op_desc = std::make_shared("Add", "Add"); + add_op_desc->AddInputDesc(shape_tensor_desc); + add_op_desc->AddInputDesc(const_tensor_desc); + add_op_desc->AddOutputDesc(add_tensor_desc); + auto add_node = graph->AddNode(add_op_desc); + + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); + + // test unknown value range + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); + auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); + std::vector> out_value_range; + output_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 1); + + std::vector unknown_target_value_range = {1, -1}; + std::vector output_value_range; + for (auto pair : out_value_range) { + output_value_range.push_back(pair.first); + output_value_range.push_back(pair.second); + } + EXPECT_EQ(unknown_target_value_range, output_value_range); +} + TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) { // shape --- add --- sqrt // constant /