|
|
@@ -16,6 +16,8 @@ |
|
|
|
|
|
|
|
#include "host_kernels/slice_kernel.h" |
|
|
|
|
|
|
|
#include <set> |
|
|
|
|
|
|
|
#include "common/ge_inner_error_codes.h" |
|
|
|
#include "common/op/ge_op_utils.h" |
|
|
|
#include "common/types.h" |
|
|
@@ -31,6 +33,30 @@ const size_t kSliceInputSize = 3; |
|
|
|
const size_t kSliceInputIndexX = 0; |
|
|
|
const size_t kSliceInputIndexBegin = 1; |
|
|
|
const size_t kSliceInputIndexSize = 2; |
|
|
|
const std::set<ge::DataType> kSupportedDataTypeToLength = { |
|
|
|
DT_BOOL, |
|
|
|
DT_INT64, |
|
|
|
DT_UINT64, |
|
|
|
DT_FLOAT, |
|
|
|
DT_INT32, |
|
|
|
DT_UINT32, |
|
|
|
DT_INT8, |
|
|
|
DT_UINT8, |
|
|
|
DT_INT16, |
|
|
|
DT_UINT16, |
|
|
|
DT_FLOAT16, |
|
|
|
DT_DOUBLE, |
|
|
|
DT_DUAL, |
|
|
|
DT_DUAL_SUB_INT8, |
|
|
|
DT_DUAL_SUB_UINT8, |
|
|
|
DT_COMPLEX64, |
|
|
|
DT_COMPLEX128, |
|
|
|
DT_QINT8, |
|
|
|
DT_QINT16, |
|
|
|
DT_QINT32, |
|
|
|
DT_QUINT8, |
|
|
|
DT_QUINT16, |
|
|
|
}; |
|
|
|
} // namespace |
|
|
|
|
|
|
|
Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTensorPtr> &input, |
|
|
@@ -56,6 +82,16 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso |
|
|
|
|
|
|
|
// data type in input_x |
|
|
|
auto data_type = x_->GetTensorDesc().GetDataType(); |
|
|
|
// check supported |
|
|
|
if (kSupportedDataTypeToLength.count(data_type) == 0) { |
|
|
|
GELOGW("input_x data_type is [%s], does not supported!", TypeUtils::DataTypeToSerialString(data_type).c_str()); |
|
|
|
return NOT_CHANGED; |
|
|
|
} |
|
|
|
uint32_t type_size = 0; |
|
|
|
bool is_success = TypeUtils::GetDataTypeLength(data_type, type_size); |
|
|
|
if (!is_success) { |
|
|
|
return NOT_CHANGED; |
|
|
|
} |
|
|
|
// check data type of begin and size |
|
|
|
if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) { |
|
|
|
GELOGW("Data type of begin and size for slice are not DT_INT32."); |
|
|
@@ -69,7 +105,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso |
|
|
|
GE_CHECK_NOTNULL(begin_data); |
|
|
|
GE_CHECK_NOTNULL(size_data); |
|
|
|
|
|
|
|
size_t data_size = x_->GetData().size() / sizeof(int32_t); |
|
|
|
size_t data_size = x_->GetData().size() / type_size; |
|
|
|
size_t begin_size = begin->GetData().size() / sizeof(int32_t); |
|
|
|
size_t size_size = size->GetData().size() / sizeof(int32_t); |
|
|
|
const ge::GeShape &x_shape = x_->GetTensorDesc().GetShape(); |
|
|
|