| @@ -16,8 +16,6 @@ | |||||
| #include "host_kernels/slice_kernel.h" | #include "host_kernels/slice_kernel.h" | ||||
| #include <set> | |||||
| #include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
| #include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
| #include "common/types.h" | #include "common/types.h" | ||||
| @@ -33,30 +31,6 @@ const size_t kSliceInputSize = 3; | |||||
| const size_t kSliceInputIndexX = 0; | const size_t kSliceInputIndexX = 0; | ||||
| const size_t kSliceInputIndexBegin = 1; | const size_t kSliceInputIndexBegin = 1; | ||||
| const size_t kSliceInputIndexSize = 2; | const size_t kSliceInputIndexSize = 2; | ||||
| const std::set<ge::DataType> kSupportedDataTypeToLength = { | |||||
| DT_BOOL, | |||||
| DT_INT64, | |||||
| DT_UINT64, | |||||
| DT_FLOAT, | |||||
| DT_INT32, | |||||
| DT_UINT32, | |||||
| DT_INT8, | |||||
| DT_UINT8, | |||||
| DT_INT16, | |||||
| DT_UINT16, | |||||
| DT_FLOAT16, | |||||
| DT_DOUBLE, | |||||
| DT_DUAL, | |||||
| DT_DUAL_SUB_INT8, | |||||
| DT_DUAL_SUB_UINT8, | |||||
| DT_COMPLEX64, | |||||
| DT_COMPLEX128, | |||||
| DT_QINT8, | |||||
| DT_QINT16, | |||||
| DT_QINT32, | |||||
| DT_QUINT8, | |||||
| DT_QUINT16, | |||||
| }; | |||||
| } // namespace | } // namespace | ||||
| Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTensorPtr> &input, | Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTensorPtr> &input, | ||||
| @@ -79,18 +53,9 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso | |||||
| GELOGW("input tensor is nullptr."); | GELOGW("input tensor is nullptr."); | ||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| } | } | ||||
| // data type in input_x | // data type in input_x | ||||
| auto data_type = x_->GetTensorDesc().GetDataType(); | auto data_type = x_->GetTensorDesc().GetDataType(); | ||||
| // check supported | |||||
| if (kSupportedDataTypeToLength.count(data_type) == 0) { | |||||
| GELOGW("input_x data_type is [%s], does not supported!", TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return NOT_CHANGED; | |||||
| } | |||||
| uint32_t type_size = 0; | |||||
| bool is_success = TypeUtils::GetDataTypeLength(data_type, type_size); | |||||
| if (!is_success) { | |||||
| return NOT_CHANGED; | |||||
| } | |||||
| // check data type of begin and size | // check data type of begin and size | ||||
| if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) { | if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) { | ||||
| GELOGW("Data type of begin and size for slice are not DT_INT32."); | GELOGW("Data type of begin and size for slice are not DT_INT32."); | ||||
| @@ -104,7 +69,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso | |||||
| GE_CHECK_NOTNULL(begin_data); | GE_CHECK_NOTNULL(begin_data); | ||||
| GE_CHECK_NOTNULL(size_data); | GE_CHECK_NOTNULL(size_data); | ||||
| size_t data_size = x_->GetData().size() / type_size; | |||||
| size_t data_size = x_->GetData().size() / sizeof(int32_t); | |||||
| size_t begin_size = begin->GetData().size() / sizeof(int32_t); | size_t begin_size = begin->GetData().size() / sizeof(int32_t); | ||||
| size_t size_size = size->GetData().size() / sizeof(int32_t); | size_t size_size = size->GetData().size() / sizeof(int32_t); | ||||
| const ge::GeShape &x_shape = x_->GetTensorDesc().GetShape(); | const ge::GeShape &x_shape = x_->GetTensorDesc().GetShape(); | ||||