Browse Source

Feature: repair dynamic_stitch_kernel folding bug

tags/v1.2.0
l00444296 3 years ago
parent
commit
cdbef14012
1 changed files with 10 additions and 1 deletions
  1. +10
    -1
      ge/host_kernels/dynamic_stitch_kernel.cc

+ 10
- 1
ge/host_kernels/dynamic_stitch_kernel.cc View File

@@ -33,6 +33,8 @@ namespace {
const int kDoubleAttrN = 2; const int kDoubleAttrN = 2;
const int kFirstOutputDescIdx = 0; const int kFirstOutputDescIdx = 0;
const int kMergedShapeSecondDim = 1; const int kMergedShapeSecondDim = 1;
const size_t kNullTensorDimNum = 1;
const int64_t kNullTensorDimValue = 0;
const std::set<DataType> kSupportedTypeSet = {DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32, const std::set<DataType> kSupportedTypeSet = {DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32,
DT_INT64, DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}; DT_INT64, DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_DOUBLE};
} // namespace } // namespace
@@ -177,7 +179,14 @@ Status DynamicStitchKernel::StitchDataFollowIndices(int64_t data_unit, const vec
int64_t src_offset = 0; int64_t src_offset = 0;
std::set<int32_t> indices_set; std::set<int32_t> indices_set;
for (int i = 0; i < n_; i++) { for (int i = 0; i < n_; i++) {
auto indices_shape_size = input[i]->GetTensorDesc().GetShape().GetShapeSize();
GeShape indices_shape = input[i]->GetTensorDesc().GetShape();
size_t indices_dim_num = indices_shape.GetDimNum();
// skip null indices tensor
if (indices_dim_num == kNullTensorDimNum && indices_shape.GetDim(0) == kNullTensorDimValue) {
GELOGD("Input indices[%d] has null tensor, skip it.", i);
continue;
}
auto indices_shape_size = indices_shape.GetShapeSize();
// to normalize logic, assume scalar as vector with shape of [1]. // to normalize logic, assume scalar as vector with shape of [1].
indices_shape_size = (indices_shape_size == 0) ? 1 : indices_shape_size; indices_shape_size = (indices_shape_size == 0) ? 1 : indices_shape_size;
// all index for input is less than size of input // all index for input is less than size of input


Loading…
Cancel
Save