|
|
@@ -223,6 +223,8 @@ Status StridedSliceKernel::InitParamWithAttrs(const std::vector<ConstGeTensorPtr |
|
|
|
|
|
|
|
vector<int64_t> orig_begin_vec, orig_end_vec, orig_stride_vec; |
|
|
|
GetOriginStrideVec(input, orig_begin_vec, orig_end_vec, orig_stride_vec); |
|
|
|
// calculate begin_mask & end_mask by ellipsis_mask |
|
|
|
ExpandStrideWithEllipsisMask(x_dims_num, x_dims, orig_begin_vec, orig_end_vec, orig_stride_vec); |
|
|
|
auto begin_dim_num = orig_begin_vec.size(); |
|
|
|
auto min_dim = x_dims_num > begin_dim_num ? begin_dim_num : x_dims_num; |
|
|
|
for (size_t i = 0; i < x_dims.size(); ++i) { |
|
|
@@ -281,6 +283,40 @@ void StridedSliceKernel::ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_ten |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num, |
|
|
|
const vector<int64_t> &x_dims, vector<int64_t> &orig_begin_vec, |
|
|
|
vector<int64_t> &orig_end_vec, vector<int64_t> &orig_stride_vec) { |
|
|
|
|
|
|
|
if (attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) != 0) { |
|
|
|
auto end_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK); |
|
|
|
auto begin_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK); |
|
|
|
if (begin_mask != 0 && x_dims_num != orig_begin_vec.size()) { |
|
|
|
begin_mask *= begin_mask * (kMaskBitLeftUnit << (x_dims_num - orig_begin_vec.size() - 1)); |
|
|
|
attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) = begin_mask; |
|
|
|
} |
|
|
|
if (end_mask != 0 && x_dims_num != orig_end_vec.size()) { |
|
|
|
end_mask *= end_mask * (kMaskBitLeftUnit << (x_dims_num - orig_end_vec.size() - 1)); |
|
|
|
attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) = end_mask; |
|
|
|
} |
|
|
|
for (auto i = 0; i < x_dims_num; ++i) { |
|
|
|
bool ellipsis_mask_flag = attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) & (kMaskBitLeftUnit << i); |
|
|
|
if (ellipsis_mask_flag) { |
|
|
|
auto ellipsis_dim = i; |
|
|
|
orig_begin_vec[i] = 0; |
|
|
|
orig_end_vec[i] = x_dims.at(i); |
|
|
|
orig_stride_vec[i] = 1; |
|
|
|
if (orig_begin_vec.size() < x_dims_num) { |
|
|
|
for (auto j = 1; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) { |
|
|
|
orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0); |
|
|
|
orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j)); |
|
|
|
orig_stride_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 1); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &end_i, int64_t &dim_i) const { |
|
|
|
auto i_temp = static_cast<uint32_t>(i); |
|
|
|
bool begin_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) & (kMaskBitLeftUnit << i_temp)); |
|
|
@@ -302,10 +338,6 @@ Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &en |
|
|
|
} else { |
|
|
|
end_i = (end_i < 0 ? (dim_i + end_i) : end_i); |
|
|
|
} |
|
|
|
if (ellipsis_mask_flag) { |
|
|
|
begin_i = 0; |
|
|
|
end_i = dim_i; |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
@@ -316,8 +348,10 @@ Status StridedSliceKernel::StrideCal(const int64_t x_dims_i, int64_t &begin_i, i |
|
|
|
stride_i = kDefaultStrideSize; |
|
|
|
} else if (stride_i < 0) { |
|
|
|
stride_i = -stride_i; |
|
|
|
begin_i = x_dims_i - begin_i - 1; |
|
|
|
end_i = x_dims_i - end_i - 1; |
|
|
|
if (begin_i < 0 && end_i < 0) { |
|
|
|
begin_i = x_dims_i - begin_i - 1; |
|
|
|
end_i = x_dims_i - end_i - 1; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (end_i > x_dims_i) { |
|
|
|