From 1ce576831c72350196866563a07ca9a0f65bd686 Mon Sep 17 00:00:00 2001 From: wxl Date: Thu, 25 Feb 2021 21:05:25 +0800 Subject: [PATCH 1/2] fix slice constant folding bug --- ge/host_kernels/slice_kernel.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ge/host_kernels/slice_kernel.cc b/ge/host_kernels/slice_kernel.cc index c3274465..6b91db1d 100644 --- a/ge/host_kernels/slice_kernel.cc +++ b/ge/host_kernels/slice_kernel.cc @@ -56,6 +56,8 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vectorGetTensorDesc().GetDataType(); + uint32_t type_size = 0; + (void)TypeUtils::GetDataTypeLength(data_type, type_size); // check data type of begin and size if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) { GELOGW("Data type of begin and size for slice are not DT_INT32."); @@ -69,7 +71,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vectorGetData().size() / sizeof(int32_t); + size_t data_size = x_->GetData().size() / type_size; size_t begin_size = begin->GetData().size() / sizeof(int32_t); size_t size_size = size->GetData().size() / sizeof(int32_t); const ge::GeShape &x_shape = x_->GetTensorDesc().GetShape(); From db9a09e292c1d3841cb9f6e83033dd92ffad0b2a Mon Sep 17 00:00:00 2001 From: wxl Date: Fri, 26 Feb 2021 10:08:31 +0800 Subject: [PATCH 2/2] fix slice constant folding bug --- ge/host_kernels/slice_kernel.cc | 37 +++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/ge/host_kernels/slice_kernel.cc b/ge/host_kernels/slice_kernel.cc index 6b91db1d..0867ec2f 100644 --- a/ge/host_kernels/slice_kernel.cc +++ b/ge/host_kernels/slice_kernel.cc @@ -16,6 +16,8 @@ #include "host_kernels/slice_kernel.h" +#include + #include "common/ge_inner_error_codes.h" #include "common/op/ge_op_utils.h" #include "common/types.h" @@ -31,6 +33,30 @@ const size_t kSliceInputSize = 3; const size_t kSliceInputIndexX = 0; const size_t kSliceInputIndexBegin = 1; const size_t kSliceInputIndexSize = 2; +const std::set kSupportedDataTypeToLength = { + DT_BOOL, + DT_INT64, + DT_UINT64, + DT_FLOAT, + DT_INT32, + DT_UINT32, + DT_INT8, + DT_UINT8, + DT_INT16, + DT_UINT16, + DT_FLOAT16, + DT_DOUBLE, + DT_DUAL, + DT_DUAL_SUB_INT8, + DT_DUAL_SUB_UINT8, + DT_COMPLEX64, + DT_COMPLEX128, + DT_QINT8, + DT_QINT16, + DT_QINT32, + DT_QUINT8, + DT_QUINT16, +}; } // namespace Status SliceKernel::Compute(const OpDescPtr attr, const std::vector &input, @@ -53,11 +79,18 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vectorGetTensorDesc().GetDataType(); + // check supported + if (kSupportedDataTypeToLength.count(data_type) == 0) { + GELOGW("input_x data_type is [%s], does not supported!", TypeUtils::DataTypeToSerialString(data_type).c_str()); + return NOT_CHANGED; + } uint32_t type_size = 0; - (void)TypeUtils::GetDataTypeLength(data_type, type_size); + bool is_success = TypeUtils::GetDataTypeLength(data_type, type_size); + if (!is_success) { + return NOT_CHANGED; + } // check data type of begin and size if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) { GELOGW("Data type of begin and size for slice are not DT_INT32.");