From d40a346167e4bf23ac84e869b24f80d89035c54e Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Thu, 19 Nov 2020 21:41:48 +0800 Subject: [PATCH 1/3] modified: ge/host_kernels/strided_slice_kernel.cc modified: ge/host_kernels/strided_slice_kernel.h --- ge/host_kernels/strided_slice_kernel.cc | 44 +++++++++++++++++++++---- ge/host_kernels/strided_slice_kernel.h | 3 ++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/ge/host_kernels/strided_slice_kernel.cc b/ge/host_kernels/strided_slice_kernel.cc index 7b9c0608..d59202d2 100644 --- a/ge/host_kernels/strided_slice_kernel.cc +++ b/ge/host_kernels/strided_slice_kernel.cc @@ -223,6 +223,8 @@ Status StridedSliceKernel::InitParamWithAttrs(const std::vector orig_begin_vec, orig_end_vec, orig_stride_vec; GetOriginStrideVec(input, orig_begin_vec, orig_end_vec, orig_stride_vec); + // calculate begin_mask & end_mask by ellipsis_mask + ExpandStrideWithEllipsisMask(x_dims_num, x_dims, orig_begin_vec, orig_end_vec, orig_stride_vec); auto begin_dim_num = orig_begin_vec.size(); auto min_dim = x_dims_num > begin_dim_num ? begin_dim_num : x_dims_num; for (size_t i = 0; i < x_dims.size(); ++i) { @@ -281,6 +283,38 @@ void StridedSliceKernel::ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_ten } } +void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num, + const vector &x_dims, vector &orig_begin_vec, + vector &orig_end_vec, vector &orig_stride_vec) { + + if (attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) != 0) { + auto end_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK); + auto begin_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK); + if (begin_mask != 0 && x_dims_num != orig_begin_vec.size()) { + begin_mask *= begin_mask * (kMaskBitLeftUnit << (x_dims_num - orig_begin_vec.size() -1)); + attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) = begin_mask; + } + if (end_mask != 0 && x_dims_num != orig_end_vec.size()) { + end_mask *= end_mask * (kMaskBitLeftUnit << (x_dims_num - orig_end_vec.size() -1)); + attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) = end_mask; + } + for (auto i = 0; i < x_dims_num; ++i) { + bool ellipsis_mask_flag = attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) & (kMaskBitLeftUnit << i); + if (ellipsis_mask_flag) { + auto ellipsis_dim = i; + orig_begin_vec[i] = 0; + orig_end_vec[i] = x_dims.at(i); + orig_stride_vec[i] = 1; + if (auto j = 0; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) { + orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0); + orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j)); + orig_stride_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 1); + } + } + } + } +} + Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &end_i, int64_t &dim_i) const { auto i_temp = static_cast(i); bool begin_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) & (kMaskBitLeftUnit << i_temp)); @@ -302,10 +336,6 @@ Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &en } else { end_i = (end_i < 0 ? (dim_i + end_i) : end_i); } - if (ellipsis_mask_flag) { - begin_i = 0; - end_i = dim_i; - } } return SUCCESS; } @@ -316,8 +346,10 @@ Status StridedSliceKernel::StrideCal(const int64_t x_dims_i, int64_t &begin_i, i stride_i = kDefaultStrideSize; } else if (stride_i < 0) { stride_i = -stride_i; - begin_i = x_dims_i - begin_i - 1; - end_i = x_dims_i - end_i - 1; + if (begin_i < 0 && end_i < 0) { + begin_i = x_dims_i - begin_i - 1; + end_i = x_dims_i - end_i - 1; + } } if (end_i > x_dims_i) { diff --git a/ge/host_kernels/strided_slice_kernel.h b/ge/host_kernels/strided_slice_kernel.h index 315391fd..2daf53e7 100755 --- a/ge/host_kernels/strided_slice_kernel.h +++ b/ge/host_kernels/strided_slice_kernel.h @@ -36,6 +36,9 @@ class StridedSliceKernel : public Kernel { static Status StrideCal(const int64_t x_dims_i, int64_t &begin_i, int64_t &end_i, int64_t &stride_i, int64_t &dim_final) ; void ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_tensor, const size_t x_dims_num, vector &x_dims); + void ExpandStrideWithEllipsisMask(const size_t x_dims_num, + const vector &x_dims, vector &orig_begin_vec, + vector &orig_end_vec, vector &orig_stride_vec); void GetOutputDims(uint32_t dims_size, const std::vector &output_dims, vector &v_dims); From c7ee494caff75b9df7a9021f5e7abf5afefebd1b Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Thu, 19 Nov 2020 21:50:51 +0800 Subject: [PATCH 2/3] modified: ge/host_kernels/strided_slice_kernel.cc --- ge/host_kernels/strided_slice_kernel.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ge/host_kernels/strided_slice_kernel.cc b/ge/host_kernels/strided_slice_kernel.cc index d59202d2..c73d7c7f 100644 --- a/ge/host_kernels/strided_slice_kernel.cc +++ b/ge/host_kernels/strided_slice_kernel.cc @@ -305,10 +305,12 @@ void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num, orig_begin_vec[i] = 0; orig_end_vec[i] = x_dims.at(i); orig_stride_vec[i] = 1; - if (auto j = 0; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) { - orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0); - orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j)); - orig_stride_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 1); + if (orig_begin_vec.size() < x_dims_num) { + for (auto j = 0; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) { + orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0); + orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j)); + orig_stride_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 1); + } } } } From 2a5548b1924d9a4ae1c947dcd155f033314e565f Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Thu, 19 Nov 2020 21:54:20 +0800 Subject: [PATCH 3/3] modified: ge/host_kernels/strided_slice_kernel.cc --- ge/host_kernels/strided_slice_kernel.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ge/host_kernels/strided_slice_kernel.cc b/ge/host_kernels/strided_slice_kernel.cc index c73d7c7f..b76e5c6d 100644 --- a/ge/host_kernels/strided_slice_kernel.cc +++ b/ge/host_kernels/strided_slice_kernel.cc @@ -291,11 +291,11 @@ void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num, auto end_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK); auto begin_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK); if (begin_mask != 0 && x_dims_num != orig_begin_vec.size()) { - begin_mask *= begin_mask * (kMaskBitLeftUnit << (x_dims_num - orig_begin_vec.size() -1)); + begin_mask *= begin_mask * (kMaskBitLeftUnit << (x_dims_num - orig_begin_vec.size() - 1)); attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) = begin_mask; } if (end_mask != 0 && x_dims_num != orig_end_vec.size()) { - end_mask *= end_mask * (kMaskBitLeftUnit << (x_dims_num - orig_end_vec.size() -1)); + end_mask *= end_mask * (kMaskBitLeftUnit << (x_dims_num - orig_end_vec.size() - 1)); attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) = end_mask; } for (auto i = 0; i < x_dims_num; ++i) { @@ -306,7 +306,7 @@ void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num, orig_end_vec[i] = x_dims.at(i); orig_stride_vec[i] = 1; if (orig_begin_vec.size() < x_dims_num) { - for (auto j = 0; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) { + for (auto j = 1; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) { orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0); orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j)); orig_stride_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 1);