Browse Source

!339 Bugfix: add strided slice kernel ellipsis_mask support

From: @hugo1
Reviewed-by: @xchu42,@ji_chen
Signed-off-by: @ji_chen
tags/v1.1.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
2bb86763a2
2 changed files with 43 additions and 6 deletions
  1. +40
    -6
      ge/host_kernels/strided_slice_kernel.cc
  2. +3
    -0
      ge/host_kernels/strided_slice_kernel.h

+ 40
- 6
ge/host_kernels/strided_slice_kernel.cc View File

@@ -223,6 +223,8 @@ Status StridedSliceKernel::InitParamWithAttrs(const std::vector<ConstGeTensorPtr

vector<int64_t> orig_begin_vec, orig_end_vec, orig_stride_vec;
GetOriginStrideVec(input, orig_begin_vec, orig_end_vec, orig_stride_vec);
// calculate begin_mask & end_mask by ellipsis_mask
ExpandStrideWithEllipsisMask(x_dims_num, x_dims, orig_begin_vec, orig_end_vec, orig_stride_vec);
auto begin_dim_num = orig_begin_vec.size();
auto min_dim = x_dims_num > begin_dim_num ? begin_dim_num : x_dims_num;
for (size_t i = 0; i < x_dims.size(); ++i) {
@@ -281,6 +283,40 @@ void StridedSliceKernel::ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_ten
}
}

void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num,
const vector<int64_t> &x_dims, vector<int64_t> &orig_begin_vec,
vector<int64_t> &orig_end_vec, vector<int64_t> &orig_stride_vec) {
if (attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) != 0) {
auto end_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK);
auto begin_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK);
if (begin_mask != 0 && x_dims_num != orig_begin_vec.size()) {
begin_mask *= begin_mask * (kMaskBitLeftUnit << (x_dims_num - orig_begin_vec.size() - 1));
attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) = begin_mask;
}
if (end_mask != 0 && x_dims_num != orig_end_vec.size()) {
end_mask *= end_mask * (kMaskBitLeftUnit << (x_dims_num - orig_end_vec.size() - 1));
attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) = end_mask;
}
for (auto i = 0; i < x_dims_num; ++i) {
bool ellipsis_mask_flag = attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) & (kMaskBitLeftUnit << i);
if (ellipsis_mask_flag) {
auto ellipsis_dim = i;
orig_begin_vec[i] = 0;
orig_end_vec[i] = x_dims.at(i);
orig_stride_vec[i] = 1;
if (orig_begin_vec.size() < x_dims_num) {
for (auto j = 1; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) {
orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0);
orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j));
orig_stride_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 1);
}
}
}
}
}
}

Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &end_i, int64_t &dim_i) const {
auto i_temp = static_cast<uint32_t>(i);
bool begin_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) & (kMaskBitLeftUnit << i_temp));
@@ -302,10 +338,6 @@ Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &en
} else {
end_i = (end_i < 0 ? (dim_i + end_i) : end_i);
}
if (ellipsis_mask_flag) {
begin_i = 0;
end_i = dim_i;
}
}
return SUCCESS;
}
@@ -316,8 +348,10 @@ Status StridedSliceKernel::StrideCal(const int64_t x_dims_i, int64_t &begin_i, i
stride_i = kDefaultStrideSize;
} else if (stride_i < 0) {
stride_i = -stride_i;
begin_i = x_dims_i - begin_i - 1;
end_i = x_dims_i - end_i - 1;
if (begin_i < 0 && end_i < 0) {
begin_i = x_dims_i - begin_i - 1;
end_i = x_dims_i - end_i - 1;
}
}

if (end_i > x_dims_i) {


+ 3
- 0
ge/host_kernels/strided_slice_kernel.h View File

@@ -36,6 +36,9 @@ class StridedSliceKernel : public Kernel {
static Status StrideCal(const int64_t x_dims_i, int64_t &begin_i, int64_t &end_i, int64_t &stride_i,
int64_t &dim_final) ;
void ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_tensor, const size_t x_dims_num, vector<int64_t> &x_dims);
void ExpandStrideWithEllipsisMask(const size_t x_dims_num,
const vector<int64_t> &x_dims, vector<int64_t> &orig_begin_vec,
vector<int64_t> &orig_end_vec, vector<int64_t> &orig_stride_vec);

void GetOutputDims(uint32_t dims_size, const std::vector<int64_t> &output_dims, vector<int64_t> &v_dims);



Loading…
Cancel
Save