diff --git a/ge/host_kernels/strided_slice_kernel.cc b/ge/host_kernels/strided_slice_kernel.cc index 7b9c0608..d59202d2 100644 --- a/ge/host_kernels/strided_slice_kernel.cc +++ b/ge/host_kernels/strided_slice_kernel.cc @@ -223,6 +223,8 @@ Status StridedSliceKernel::InitParamWithAttrs(const std::vector orig_begin_vec, orig_end_vec, orig_stride_vec; GetOriginStrideVec(input, orig_begin_vec, orig_end_vec, orig_stride_vec); + // calculate begin_mask & end_mask by ellipsis_mask + ExpandStrideWithEllipsisMask(x_dims_num, x_dims, orig_begin_vec, orig_end_vec, orig_stride_vec); auto begin_dim_num = orig_begin_vec.size(); auto min_dim = x_dims_num > begin_dim_num ? begin_dim_num : x_dims_num; for (size_t i = 0; i < x_dims.size(); ++i) { @@ -281,6 +283,38 @@ void StridedSliceKernel::ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_ten } } +void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num, + const vector &x_dims, vector &orig_begin_vec, + vector &orig_end_vec, vector &orig_stride_vec) { + + if (attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) != 0) { + auto end_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK); + auto begin_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK); + if (begin_mask != 0 && x_dims_num != orig_begin_vec.size()) { + begin_mask *= begin_mask * (kMaskBitLeftUnit << (x_dims_num - orig_begin_vec.size() -1)); + attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) = begin_mask; + } + if (end_mask != 0 && x_dims_num != orig_end_vec.size()) { + end_mask *= end_mask * (kMaskBitLeftUnit << (x_dims_num - orig_end_vec.size() -1)); + attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) = end_mask; + } + for (auto i = 0; i < x_dims_num; ++i) { + bool ellipsis_mask_flag = attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) & (kMaskBitLeftUnit << i); + if (ellipsis_mask_flag) { + auto ellipsis_dim = i; + orig_begin_vec[i] = 0; + orig_end_vec[i] = x_dims.at(i); + orig_stride_vec[i] = 1; + if (auto j = 0; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) { + orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0); + orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j)); + orig_stride_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 1); + } + } + } + } +} + Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &end_i, int64_t &dim_i) const { auto i_temp = static_cast(i); bool begin_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) & (kMaskBitLeftUnit << i_temp)); @@ -302,10 +336,6 @@ Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &en } else { end_i = (end_i < 0 ? (dim_i + end_i) : end_i); } - if (ellipsis_mask_flag) { - begin_i = 0; - end_i = dim_i; - } } return SUCCESS; } @@ -316,8 +346,10 @@ Status StridedSliceKernel::StrideCal(const int64_t x_dims_i, int64_t &begin_i, i stride_i = kDefaultStrideSize; } else if (stride_i < 0) { stride_i = -stride_i; - begin_i = x_dims_i - begin_i - 1; - end_i = x_dims_i - end_i - 1; + if (begin_i < 0 && end_i < 0) { + begin_i = x_dims_i - begin_i - 1; + end_i = x_dims_i - end_i - 1; + } } if (end_i > x_dims_i) { diff --git a/ge/host_kernels/strided_slice_kernel.h b/ge/host_kernels/strided_slice_kernel.h index 315391fd..2daf53e7 100755 --- a/ge/host_kernels/strided_slice_kernel.h +++ b/ge/host_kernels/strided_slice_kernel.h @@ -36,6 +36,9 @@ class StridedSliceKernel : public Kernel { static Status StrideCal(const int64_t x_dims_i, int64_t &begin_i, int64_t &end_i, int64_t &stride_i, int64_t &dim_final) ; void ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_tensor, const size_t x_dims_num, vector &x_dims); + void ExpandStrideWithEllipsisMask(const size_t x_dims_num, + const vector &x_dims, vector &orig_begin_vec, + vector &orig_end_vec, vector &orig_stride_vec); void GetOutputDims(uint32_t dims_size, const std::vector &output_dims, vector &v_dims);