diff --git a/ge/host_kernels/slice_kernel.cc b/ge/host_kernels/slice_kernel.cc index 0867ec2f..c3274465 100644 --- a/ge/host_kernels/slice_kernel.cc +++ b/ge/host_kernels/slice_kernel.cc @@ -16,8 +16,6 @@ #include "host_kernels/slice_kernel.h" -#include - #include "common/ge_inner_error_codes.h" #include "common/op/ge_op_utils.h" #include "common/types.h" @@ -33,30 +31,6 @@ const size_t kSliceInputSize = 3; const size_t kSliceInputIndexX = 0; const size_t kSliceInputIndexBegin = 1; const size_t kSliceInputIndexSize = 2; -const std::set kSupportedDataTypeToLength = { - DT_BOOL, - DT_INT64, - DT_UINT64, - DT_FLOAT, - DT_INT32, - DT_UINT32, - DT_INT8, - DT_UINT8, - DT_INT16, - DT_UINT16, - DT_FLOAT16, - DT_DOUBLE, - DT_DUAL, - DT_DUAL_SUB_INT8, - DT_DUAL_SUB_UINT8, - DT_COMPLEX64, - DT_COMPLEX128, - DT_QINT8, - DT_QINT16, - DT_QINT32, - DT_QUINT8, - DT_QUINT16, -}; } // namespace Status SliceKernel::Compute(const OpDescPtr attr, const std::vector &input, @@ -79,18 +53,9 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vectorGetTensorDesc().GetDataType(); - // check supported - if (kSupportedDataTypeToLength.count(data_type) == 0) { - GELOGW("input_x data_type is [%s], does not supported!", TypeUtils::DataTypeToSerialString(data_type).c_str()); - return NOT_CHANGED; - } - uint32_t type_size = 0; - bool is_success = TypeUtils::GetDataTypeLength(data_type, type_size); - if (!is_success) { - return NOT_CHANGED; - } // check data type of begin and size if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) { GELOGW("Data type of begin and size for slice are not DT_INT32."); @@ -104,7 +69,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vectorGetData().size() / type_size; + size_t data_size = x_->GetData().size() / sizeof(int32_t); size_t begin_size = begin->GetData().size() / sizeof(int32_t); size_t size_size = size->GetData().size() / sizeof(int32_t); const ge::GeShape &x_shape = x_->GetTensorDesc().GetShape();