From 55a5e8019df6c0e53c474c27e91a045d5492002c Mon Sep 17 00:00:00 2001 From: lianghao Date: Tue, 22 Jun 2021 21:49:01 +0800 Subject: [PATCH] FillKernel --- ge/host_kernels/fill_kernel.cc | 8 +++++ .../folding_kernel/fill_kernel_unittest.cc | 36 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/ge/host_kernels/fill_kernel.cc b/ge/host_kernels/fill_kernel.cc index e41c5bf3..ac46101b 100644 --- a/ge/host_kernels/fill_kernel.cc +++ b/ge/host_kernels/fill_kernel.cc @@ -45,6 +45,7 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vectorGetName().c_str()); GE_CHECK_NOTNULL(input.at(kFillDimsInputIndex)); GE_CHECK_NOTNULL(input.at(kFillDataInputIndex)); @@ -57,6 +58,13 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vectorGetOutputDescPtr(0); + GE_CHECK_NOTNULL(output_desc); + if (output_desc->GetShape().IsUnknownShape()) { + GELOGD("Output is unknown shape, [%s] skip FillKernel.", op_desc_ptr->GetName().c_str()); + return NOT_CHANGED; + } + GeTensorPtr output_ptr; output_ptr = MakeShared(op_desc_ptr->GetOutputDesc(0)); if (output_ptr == nullptr) { diff --git a/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc b/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc index f58d6d9b..c0cce260 100644 --- a/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc +++ b/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc @@ -64,6 +64,7 @@ class UtestGraphPassesFoldingKernelFillKernel : public testing::Test { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -124,6 +125,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillBoolShape2And3) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -230,6 +232,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsHaveNegativeNumber) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -284,6 +287,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsTypeNotSupport) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -310,6 +314,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsOverflow) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -336,6 +341,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsMulDataTypeOverflow) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -343,3 +349,33 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsMulDataTypeOverflow) { EXPECT_EQ(PARAM_INVALID, status); } + +TEST_F(UtestGraphPassesFoldingKernelFillKernel, OutputdescUnknown) { + ge::OpDescPtr op_dims = std::make_shared(); + vector dims_vec = {2}; + vector dims_value_vec = {2, 3}; + GeTensorDesc dims_tensor_desc(GeShape(dims_vec), FORMAT_NCHW, DT_INT32); + GeTensorPtr dim_tensor = std::make_shared(dims_tensor_desc, (uint8_t *) dims_value_vec.data(), + dims_value_vec.size() * sizeof(int32_t)); + OpDescUtils::SetWeights(op_dims, dim_tensor); + + ge::OpDescPtr op_value = std::make_shared(); + vector data_vec = {1}; + GeTensorDesc value_tensor_desc(GeShape(), FORMAT_NCHW, DT_BOOL); + GeTensorPtr value_tensor = + std::make_shared(value_tensor_desc, (uint8_t *) data_vec.data(), data_vec.size() * sizeof(bool)); + OpDescUtils::SetWeights(op_value, value_tensor); + + op_desc_ptr->AddInputDesc(dims_tensor_desc); + op_desc_ptr->AddInputDesc(value_tensor_desc); + + vector out_vec = {-1, -1}; + GeTensorDesc out_tensor_desc(GeShape(out_vec), FORMAT_NCHW, DT_INT32); + op_desc_ptr->AddOutputDesc(out_tensor_desc); + + std::vector input = {dim_tensor, value_tensor}; + std::vector outputs; + Status status = kernel->Compute(op_desc_ptr, input, outputs); + + EXPECT_EQ(NOT_CHANGED, status); +} \ No newline at end of file