|
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "host_kernels/strided_slice_kernel.h"
- #include "common/fp16_t.h"
- #include "common/math/math_util.h"
- #include "framework/common/types.h"
- #include "graph/utils/type_utils.h"
- #include "host_kernels/kernel_utils.h"
- #include "inc/kernel_factory.h"
-
- namespace ge {
- namespace {
- const int32_t kNumOne = 1;
- const size_t kStridedSliceInputSize = 4;
- const size_t kStridedSliceInputIndex = 0;
- const size_t kStridedSliceBeginIndex = 1;
- const size_t kStridedSliceEndIndex = 2;
- const size_t kStridedSliceStrideIndex = 3;
- const int32_t kDefaultStrideSize = 1;
- const uint32_t kMaskBitLeftUnit = 1;
- const std::set<DataType> kIndexNumberType = {DT_INT32, DT_INT64};
-
- bool IsEllipsisMaskValid(const GeTensorDescPtr &input_desc, const uint32_t ellipsis_mask) {
- if (ellipsis_mask != 0) {
- auto ellipsis_num = 0;
- auto input_shape = input_desc->GetShape();
- for (size_t i = 0; i < input_shape.GetDimNum(); ++i) {
- auto i_temp = static_cast<uint32_t>(i);
- bool ellipsis_mask_flag = (ellipsis_mask) & (kMaskBitLeftUnit << i_temp);
- if (ellipsis_mask_flag) {
- ++ellipsis_num;
- }
- if (ellipsis_num > 1) {
- GELOGW("Only one non-zero bit is allowed in ellipsis_mask.");
- return false;
- }
- }
- }
- return true;
- }
-
- void GetOriginStrideVec(const std::vector<ge::ConstGeTensorPtr> &input, vector<int64_t> &orig_begin_vec,
- vector<int64_t> &orig_end_vec, vector<int64_t> &orig_stride_vec) {
- ConstGeTensorPtr begin_tensor = input[kStridedSliceBeginIndex];
- ConstGeTensorPtr end_tensor = input[kStridedSliceEndIndex];
- ConstGeTensorPtr stride_tensor = input[kStridedSliceStrideIndex];
-
- auto data_type = begin_tensor->GetTensorDesc().GetDataType();
- size_t vec_size = begin_tensor->GetData().size() / GetSizeByDataType(data_type);
- if (data_type == DT_INT32) {
- const int32_t *begin = reinterpret_cast<const int32_t *>(begin_tensor->GetData().data());
- const int32_t *end = reinterpret_cast<const int32_t *>(end_tensor->GetData().data());
- const int32_t *stride = reinterpret_cast<const int32_t *>(stride_tensor->GetData().data());
- for (size_t i = 0; i < vec_size; ++i) {
- orig_begin_vec.emplace_back(begin[i]);
- orig_end_vec.emplace_back(end[i]);
- orig_stride_vec.emplace_back(stride[i]);
- }
- } else {
- const int64_t *begin = reinterpret_cast<const int64_t *>(begin_tensor->GetData().data());
- const int64_t *end = reinterpret_cast<const int64_t *>(end_tensor->GetData().data());
- const int64_t *stride = reinterpret_cast<const int64_t *>(stride_tensor->GetData().data());
- for (size_t i = 0; i < vec_size; ++i) {
- orig_begin_vec.emplace_back(begin[i]);
- orig_end_vec.emplace_back(end[i]);
- orig_stride_vec.emplace_back(stride[i]);
- }
- }
- }
- } // namespace
- Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<ge::ConstGeTensorPtr> &input,
- vector<ge::GeTensorPtr> &v_output) {
- GELOGD("StridedSliceKernel in");
- // 1.Check input and attrs
- if (CheckAndGetAttr(attr) != SUCCESS) {
- GELOGW("Check and get attrs failed.Ignore kernel");
- return NOT_CHANGED;
- }
- if (CheckInputParam(input) != SUCCESS) {
- GELOGW("Check input params failed.Ignore kernel");
- return NOT_CHANGED;
- }
- // 2.Init param with mask attrs.
- std::vector<int64_t> input_dims;
- std::vector<int64_t> begin_vec;
- std::vector<int64_t> output_dims;
- std::vector<int64_t> stride_vec;
- if (InitParamWithAttrs(input, input_dims, begin_vec, output_dims, stride_vec) != SUCCESS) {
- GELOGW("Init param with mask attrs failed.Ignore kernel.");
- return NOT_CHANGED;
- }
-
- // 3.Set sliced data to output_ptr
- ConstGeTensorPtr weight0 = input[kStridedSliceInputIndex];
- auto data_type = weight0->GetTensorDesc().GetDataType();
- size_t data_size = weight0->GetData().size() / GetSizeByDataType(data_type);
- void *data = reinterpret_cast<void *>(const_cast<uint8_t *>(weight0->GetData().data()));
- GE_CHECK_NOTNULL(data);
- // Index 0 can always gets a GeTensorDesc object from any OpDescPtr.
- auto output_tensor_desc = attr->GetOutputDesc(0);
- GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc);
- if (output_ptr == nullptr) {
- GELOGE(MEMALLOC_FAILED, "MakeShared GeTensor failed, node name %s.", attr->GetName().c_str());
- return NOT_CHANGED;
- }
- auto ret = OpUtils::SetOutputSliceData(data, static_cast<int64_t>(data_size), data_type, input_dims, begin_vec,
- output_dims, output_ptr.get(), stride_vec);
- if (ret != SUCCESS) {
- GELOGE(INTERNAL_ERROR, "SetOutputSliceData failed");
- return NOT_CHANGED;
- }
-
- // 4.Set output data_type and shape
- GeTensorDesc &t_d = output_ptr->MutableTensorDesc();
- t_d.SetDataType(static_cast<DataType>(data_type));
-
- auto final_dim_size = static_cast<uint32_t>(output_dims.size());
- vector<int64_t> v_dims;
- GetOutputDims(final_dim_size, output_dims, v_dims);
- t_d.SetShape(GeShape(v_dims));
- v_output.push_back(output_ptr);
- GELOGI("StridedSliceKernel success");
- return SUCCESS;
- }
- Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr) {
- if (attr == nullptr) {
- GELOGE(PARAM_INVALID, "input opdescptr is nullptr.");
- return PARAM_INVALID;
- }
- // Get all op attr value of strided_slice
- for (auto &attr_2_value : attr_value_map_) {
- if (!AttrUtils::GetInt(attr, attr_2_value.first, attr_2_value.second)) {
- GELOGE(PARAM_INVALID, "Get %s attr failed", attr_2_value.first.c_str());
- return PARAM_INVALID;
- }
- }
- // Check ellipsis_mask is valid
- const auto &input_desc = attr->MutableInputDesc(kStridedSliceInputIndex);
- GE_CHECK_NOTNULL(input_desc);
- auto ellipsis_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK);
- if (!IsEllipsisMaskValid(input_desc, ellipsis_mask)) {
- return PARAM_INVALID;
- }
- return SUCCESS;
- }
- Status StridedSliceKernel::CheckInputParam(const std::vector<ConstGeTensorPtr> &input) {
- if (input.size() != kStridedSliceInputSize) {
- GELOGE(PARAM_INVALID, "The number of input for strided slice must be %zu.", kStridedSliceInputSize);
- return PARAM_INVALID;
- }
-
- ConstGeTensorPtr weight0 = input[kStridedSliceInputIndex];
- ConstGeTensorPtr begin_tensor = input[kStridedSliceBeginIndex];
- ConstGeTensorPtr end_tensor = input[kStridedSliceEndIndex];
- ConstGeTensorPtr stride_tensor = input[kStridedSliceStrideIndex];
- GE_CHECK_NOTNULL(weight0);
- GE_CHECK_NOTNULL(begin_tensor);
- GE_CHECK_NOTNULL(end_tensor);
- GE_CHECK_NOTNULL(stride_tensor);
-
- // check if begin,end,strides data type is supported
- auto begin_tensor_desc = begin_tensor->GetTensorDesc();
- auto end_tensor_desc = begin_tensor->GetTensorDesc();
- auto stride_tensor_desc = begin_tensor->GetTensorDesc();
- if (begin_tensor_desc.GetDataType() != end_tensor_desc.GetDataType() ||
- end_tensor_desc.GetDataType() != stride_tensor_desc.GetDataType()) {
- GELOGW("Data type of StridedSlice OP(begin,end,strides) must be same.");
- return PARAM_INVALID;
- }
- if (kIndexNumberType.find(begin_tensor_desc.GetDataType()) == kIndexNumberType.end()) {
- GELOGW("Data type of StridedSlice OP(begin,end,strides) must be int32 or int64");
- return PARAM_INVALID;
- }
-
- // check data
- auto x_data_type = weight0->GetTensorDesc().GetDataType();
- auto x_data_size = GetSizeByDataType(x_data_type);
- if (x_data_size < 0) {
- GELOGW("Data type of x input %s is not supported.", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
- return PARAM_INVALID;
- }
- size_t weight0_size = weight0->GetData().size() / x_data_size;
- size_t begin_data_size = begin_tensor->GetData().size();
- size_t end_data_size = end_tensor->GetData().size();
- size_t stride_data_size = stride_tensor->GetData().size();
- if ((weight0_size == 0) || (begin_data_size == 0) || (end_data_size == 0) || (stride_data_size == 0)) {
- GELOGW("Data size of inputs is 0.");
- return PARAM_INVALID;
- }
- // check dim size
- if (!((begin_data_size == end_data_size) && (end_data_size == stride_data_size))) {
- GELOGW("The sizes of begin, end and stride is not supported.");
- return PARAM_INVALID;
- }
- return SUCCESS;
- }
-
- Status StridedSliceKernel::InitParamWithAttrs(const std::vector<ConstGeTensorPtr> &input,
- std::vector<int64_t> &input_dims, std::vector<int64_t> &begin_vec,
- std::vector<int64_t> &output_dims, std::vector<int64_t> &stride_vec) {
- ConstGeTensorPtr weight0 = input[kStridedSliceInputIndex];
- ConstGeTensorPtr begin_tensor = input[kStridedSliceBeginIndex];
-
- const GeShape x_shape = weight0->GetTensorDesc().GetShape();
- auto x_dims = x_shape.GetDims();
- auto x_dims_num = x_shape.GetDimNum();
- // handle new_axis_mask
- ExpandDimsWithNewAxis(begin_tensor, x_dims_num, x_dims);
-
- vector<int64_t> orig_begin_vec, orig_end_vec, orig_stride_vec;
- GetOriginStrideVec(input, orig_begin_vec, orig_end_vec, orig_stride_vec);
- // calculate begin_mask & end_mask by ellipsis_mask
- ExpandStrideWithEllipsisMask(x_dims_num, x_dims, orig_begin_vec, orig_end_vec, orig_stride_vec);
- auto begin_dim_num = orig_begin_vec.size();
- auto min_dim = x_dims_num > begin_dim_num ? begin_dim_num : x_dims_num;
- for (size_t i = 0; i < x_dims.size(); ++i) {
- auto i_temp = static_cast<uint32_t>(i);
- bool new_axis_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_NEW_AXIS_MASK) & (kMaskBitLeftUnit << i_temp));
- if (new_axis_mask_flag) {
- output_dims.push_back(1);
- input_dims.push_back(1);
- begin_vec.push_back(0);
- stride_vec.push_back(1);
- continue;
- }
-
- int64_t begin_i = 0;
- int64_t end_i = 0;
- int64_t stride_i = 1;
- if (i < min_dim) {
- begin_i = orig_begin_vec[i];
- end_i = orig_end_vec[i];
- stride_i = orig_stride_vec[i];
- } else {
- begin_i = 0;
- end_i = x_dims.at(i);
- stride_i = 1;
- }
- GELOGD("Before mask calculate. Begin is : %ld\t,end is : %ld\t stride is : %ld\t x_dim_i is : %ld",
- begin_i, end_i, stride_i, x_dims.at(i));
- auto ret = MaskCal(i, begin_i, end_i, x_dims.at(i));
- if (ret != SUCCESS) {
- GELOGW("MaskCal failed, because of data overflow.");
- return NOT_CHANGED;
- }
- int64_t dim_final;
- GELOGD("Before stride calculate. Begin is : %ld\t,end is : %ld\t stride is : %ld\t x_dim_i is : %ld",
- begin_i, end_i, stride_i, x_dims.at(i));
- (void) StrideCal(x_dims.at(i), begin_i, end_i, stride_i, dim_final);
- output_dims.push_back(dim_final);
- input_dims.push_back(x_dims.at(i));
- begin_vec.push_back(begin_i);
- stride_vec.push_back(stride_i);
- }
- return SUCCESS;
- }
-
- void StridedSliceKernel::ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_tensor, const size_t x_dims_num,
- vector<int64_t> &x_dims) {
- auto begin_data_type_size = GetSizeByDataType(begin_tensor->GetTensorDesc().GetDataType());
- if (begin_data_type_size == 0) {
- GELOGW("Param begin_data_type_size should not be zero.");
- return;
- }
- size_t begin_vec_size = begin_tensor->GetData().size() / begin_data_type_size;
- auto final_dim_num = x_dims_num < begin_vec_size ? begin_vec_size : x_dims_num;
- for (size_t i = 0; i < final_dim_num; i++) {
- auto i_temp = static_cast<uint32_t>(i);
- bool new_axis_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_NEW_AXIS_MASK) & (kMaskBitLeftUnit << i_temp));
- if (new_axis_mask_flag) {
- x_dims.insert(x_dims.begin() + i, 1);
- }
- }
- }
-
- void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num,
- const vector<int64_t> &x_dims,
- vector<int64_t> &orig_begin_vec,
- vector<int64_t> &orig_end_vec,
- vector<int64_t> &orig_stride_vec) {
-
- if (attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) != 0) {
- auto end_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK);
- auto begin_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK);
- if (begin_mask != 0 && x_dims_num != orig_begin_vec.size()) {
- begin_mask *= begin_mask * (kMaskBitLeftUnit << (x_dims_num - orig_begin_vec.size() - 1));
- attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) = begin_mask;
- }
- if (end_mask != 0 && x_dims_num != orig_end_vec.size()) {
- end_mask *= end_mask * (kMaskBitLeftUnit << (x_dims_num - orig_end_vec.size() - 1));
- attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) = end_mask;
- }
- for (size_t i = 0; i < x_dims_num; ++i) {
- bool ellipsis_mask_flag = attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) & (kMaskBitLeftUnit << i);
- if (ellipsis_mask_flag) {
- auto ellipsis_dim = i;
- orig_begin_vec[i] = 0;
- orig_end_vec[i] = x_dims.at(i);
- orig_stride_vec[i] = 1;
- if (orig_begin_vec.size() < x_dims_num) {
- for (size_t j = 1; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) {
- orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0);
- orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim + j));
- orig_stride_vec.insert((orig_stride_vec.begin() + ellipsis_dim + j), 1);
- }
- }
- }
- }
- }
- }
-
- Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &end_i, int64_t &dim_i) const {
- auto i_temp = static_cast<uint32_t>(i);
- bool begin_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) & (kMaskBitLeftUnit << i_temp));
- bool end_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) & (kMaskBitLeftUnit << i_temp));
- bool shrink_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK) & (kMaskBitLeftUnit << i_temp));
- if (shrink_mask_flag) {
- begin_i = (begin_i < 0 ? (dim_i + begin_i) : begin_i);
- FMK_INT32_ADDCHECK(begin_i, kNumOne)
- end_i = begin_i + kNumOne;
- } else {
- if (begin_mask_flag) {
- begin_i = 0;
- } else {
- begin_i = (begin_i < 0 ? (dim_i + begin_i) : begin_i);
- }
- if (end_mask_flag) {
- end_i = dim_i;
- } else {
- end_i = (end_i < 0 ? (dim_i + end_i) : end_i);
- }
- }
- return SUCCESS;
- }
-
- Status StridedSliceKernel::StrideCal(const int64_t x_dims_i, int64_t &begin_i, int64_t &end_i, int64_t &stride_i,
- int64_t &dim_final) {
- if (stride_i == 0) {
- stride_i = kDefaultStrideSize;
- } else if (stride_i < 0) {
- stride_i = -stride_i;
- if (begin_i < 0 && end_i < 0) {
- begin_i = x_dims_i - begin_i - 1;
- end_i = x_dims_i - end_i - 1;
- }
- }
-
- if (end_i > x_dims_i) {
- end_i = x_dims_i;
- }
-
- if ((begin_i == 0) && (end_i == 0)) {
- dim_final = x_dims_i;
- } else {
- dim_final = abs(end_i - begin_i) / stride_i;
- }
- return SUCCESS;
- }
-
- void StridedSliceKernel::GetOutputDims(uint32_t dims_size, const std::vector<int64_t> &output_dims,
- vector<int64_t> &v_dims) {
- for (uint32_t k = 0; k < dims_size; k++) {
- bool shrink_mask_i = (attr_value_map_.at(STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK) & (kMaskBitLeftUnit << k));
- if (shrink_mask_i) {
- continue;
- }
- v_dims.push_back(output_dims[k]);
- }
- }
-
- REGISTER_KERNEL(STRIDEDSLICE, StridedSliceKernel);  // NOTE(review): presumably registers StridedSliceKernel for the STRIDEDSLICE op type via inc/kernel_factory.h — confirm against the macro definition.
- } // namespace ge
|