From f3db5fe415e0372c2731a92715ee4238225f1872 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B6=9B?= Date: Sat, 27 Feb 2021 15:55:47 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!1156=20?= =?UTF-8?q?:=20fix=20slice=20kernel=20compute=20error=20question'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ge/host_kernels/slice_kernel.cc | 39 ++------------------------------- 1 file changed, 2 insertions(+), 37 deletions(-) diff --git a/ge/host_kernels/slice_kernel.cc b/ge/host_kernels/slice_kernel.cc index 0867ec2f..c3274465 100644 --- a/ge/host_kernels/slice_kernel.cc +++ b/ge/host_kernels/slice_kernel.cc @@ -16,8 +16,6 @@ #include "host_kernels/slice_kernel.h" -#include - #include "common/ge_inner_error_codes.h" #include "common/op/ge_op_utils.h" #include "common/types.h" @@ -33,30 +31,6 @@ const size_t kSliceInputSize = 3; const size_t kSliceInputIndexX = 0; const size_t kSliceInputIndexBegin = 1; const size_t kSliceInputIndexSize = 2; -const std::set kSupportedDataTypeToLength = { - DT_BOOL, - DT_INT64, - DT_UINT64, - DT_FLOAT, - DT_INT32, - DT_UINT32, - DT_INT8, - DT_UINT8, - DT_INT16, - DT_UINT16, - DT_FLOAT16, - DT_DOUBLE, - DT_DUAL, - DT_DUAL_SUB_INT8, - DT_DUAL_SUB_UINT8, - DT_COMPLEX64, - DT_COMPLEX128, - DT_QINT8, - DT_QINT16, - DT_QINT32, - DT_QUINT8, - DT_QUINT16, -}; } // namespace Status SliceKernel::Compute(const OpDescPtr attr, const std::vector &input, @@ -79,18 +53,9 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vectorGetTensorDesc().GetDataType(); - // check supported - if (kSupportedDataTypeToLength.count(data_type) == 0) { - GELOGW("input_x data_type is [%s], does not supported!", TypeUtils::DataTypeToSerialString(data_type).c_str()); - return NOT_CHANGED; - } - uint32_t type_size = 0; - bool is_success = TypeUtils::GetDataTypeLength(data_type, type_size); - if (!is_success) { - return NOT_CHANGED; - } // check data type of begin and size if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) { GELOGW("Data type of begin and size for slice are not DT_INT32."); @@ -104,7 +69,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vectorGetData().size() / type_size; + size_t data_size = x_->GetData().size() / sizeof(int32_t); size_t begin_size = begin->GetData().size() / sizeof(int32_t); size_t size_size = size->GetData().size() / sizeof(int32_t); const ge::GeShape &x_shape = x_->GetTensorDesc().GetShape();