diff --git a/ge/host_kernels/dynamic_stitch_kernel.cc b/ge/host_kernels/dynamic_stitch_kernel.cc index d26237f4..32611b03 100644 --- a/ge/host_kernels/dynamic_stitch_kernel.cc +++ b/ge/host_kernels/dynamic_stitch_kernel.cc @@ -33,6 +33,8 @@ namespace { const int kDoubleAttrN = 2; const int kFirstOutputDescIdx = 0; const int kMergedShapeSecondDim = 1; +const size_t kNullTensorDimNum = 1; +const int64_t kNullTensorDimValue = 0; const std::set kSupportedTypeSet = {DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32, DT_INT64, DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}; } // namespace @@ -177,7 +179,14 @@ Status DynamicStitchKernel::StitchDataFollowIndices(int64_t data_unit, const vec int64_t src_offset = 0; std::set indices_set; for (int i = 0; i < n_; i++) { - auto indices_shape_size = input[i]->GetTensorDesc().GetShape().GetShapeSize(); + GeShape indices_shape = input[i]->GetTensorDesc().GetShape(); + size_t indices_dim_num = indices_shape.GetDimNum(); + // skip null indices tensor + if (indices_dim_num == kNullTensorDimNum && indices_shape.GetDim(0) == kNullTensorDimValue) { + GELOGD("Input indices[%d] has null tensor, skip it.", i); + continue; + } + auto indices_shape_size = indices_shape.GetShapeSize(); // to normalize logic, assume scalar as vector with shape of [1]. indices_shape_size = (indices_shape_size == 0) ? 1 : indices_shape_size; // all index for input is less than size of input