Browse Source

回退 'Pull Request !1156 : fix slice kernel compute error question'

tags/v1.2.0
王涛 Gitee 3 years ago
parent
commit
f3db5fe415
1 changed files with 2 additions and 37 deletions
  1. +2
    -37
      ge/host_kernels/slice_kernel.cc

+ 2
- 37
ge/host_kernels/slice_kernel.cc View File

@@ -16,8 +16,6 @@


#include "host_kernels/slice_kernel.h" #include "host_kernels/slice_kernel.h"


#include <set>

#include "common/ge_inner_error_codes.h" #include "common/ge_inner_error_codes.h"
#include "common/op/ge_op_utils.h" #include "common/op/ge_op_utils.h"
#include "common/types.h" #include "common/types.h"
@@ -33,30 +31,6 @@ const size_t kSliceInputSize = 3;
const size_t kSliceInputIndexX = 0; const size_t kSliceInputIndexX = 0;
const size_t kSliceInputIndexBegin = 1; const size_t kSliceInputIndexBegin = 1;
const size_t kSliceInputIndexSize = 2; const size_t kSliceInputIndexSize = 2;
const std::set<ge::DataType> kSupportedDataTypeToLength = {
DT_BOOL,
DT_INT64,
DT_UINT64,
DT_FLOAT,
DT_INT32,
DT_UINT32,
DT_INT8,
DT_UINT8,
DT_INT16,
DT_UINT16,
DT_FLOAT16,
DT_DOUBLE,
DT_DUAL,
DT_DUAL_SUB_INT8,
DT_DUAL_SUB_UINT8,
DT_COMPLEX64,
DT_COMPLEX128,
DT_QINT8,
DT_QINT16,
DT_QINT32,
DT_QUINT8,
DT_QUINT16,
};
} // namespace } // namespace


Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTensorPtr> &input, Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTensorPtr> &input,
@@ -79,18 +53,9 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
GELOGW("input tensor is nullptr."); GELOGW("input tensor is nullptr.");
return NOT_CHANGED; return NOT_CHANGED;
} }

// data type in input_x // data type in input_x
auto data_type = x_->GetTensorDesc().GetDataType(); auto data_type = x_->GetTensorDesc().GetDataType();
// check supported
if (kSupportedDataTypeToLength.count(data_type) == 0) {
GELOGW("input_x data_type is [%s], does not supported!", TypeUtils::DataTypeToSerialString(data_type).c_str());
return NOT_CHANGED;
}
uint32_t type_size = 0;
bool is_success = TypeUtils::GetDataTypeLength(data_type, type_size);
if (!is_success) {
return NOT_CHANGED;
}
// check data type of begin and size // check data type of begin and size
if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) { if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) {
GELOGW("Data type of begin and size for slice are not DT_INT32."); GELOGW("Data type of begin and size for slice are not DT_INT32.");
@@ -104,7 +69,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
GE_CHECK_NOTNULL(begin_data); GE_CHECK_NOTNULL(begin_data);
GE_CHECK_NOTNULL(size_data); GE_CHECK_NOTNULL(size_data);


size_t data_size = x_->GetData().size() / type_size;
size_t data_size = x_->GetData().size() / sizeof(int32_t);
size_t begin_size = begin->GetData().size() / sizeof(int32_t); size_t begin_size = begin->GetData().size() / sizeof(int32_t);
size_t size_size = size->GetData().size() / sizeof(int32_t); size_t size_size = size->GetData().size() / sizeof(int32_t);
const ge::GeShape &x_shape = x_->GetTensorDesc().GetShape(); const ge::GeShape &x_shape = x_->GetTensorDesc().GetShape();


Loading…
Cancel
Save