@@ -33,6 +33,8 @@ namespace {
 const int kDoubleAttrN = 2;
 const int kFirstOutputDescIdx = 0;
 const int kMergedShapeSecondDim = 1;
+const size_t kNullTensorDimNum = 1;
+const int64_t kNullTensorDimValue = 0;
 const std::set<DataType> kSupportedTypeSet = {DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32,
                                               DT_INT64, DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_DOUBLE};
 }  // namespace
@@ -177,7 +179,14 @@ Status DynamicStitchKernel::StitchDataFollowIndices(int64_t data_unit, const vec
   int64_t src_offset = 0;
   std::set<int32_t> indices_set;
   for (int i = 0; i < n_; i++) {
-    auto indices_shape_size = input[i]->GetTensorDesc().GetShape().GetShapeSize();
+    GeShape indices_shape = input[i]->GetTensorDesc().GetShape();
+    size_t indices_dim_num = indices_shape.GetDimNum();
+    // skip null indices tensor
+    if (indices_dim_num == kNullTensorDimNum && indices_shape.GetDim(0) == kNullTensorDimValue) {
+      GELOGD("Input indices[%d] has null tensor, skip it.", i);
+      continue;
+    }
+    auto indices_shape_size = indices_shape.GetShapeSize();
     // to normalize logic, assume scalar as vector with shape of [1].
     indices_shape_size = (indices_shape_size == 0) ? 1 : indices_shape_size;
     // all index for input is less than size of input
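For reviewers, here is a minimal standalone sketch of the null-tensor check this hunk introduces. The `Shape` struct and `IsNullIndicesTensor` helper below are stand-ins invented for illustration only; the real code uses `ge::GeShape` from the GraphEngine headers and inlines the check in the stitch loop.

#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for ge::GeShape, exposing only the accessors the hunk relies on.
struct Shape {
  std::vector<int64_t> dims;
  size_t GetDimNum() const { return dims.size(); }
  int64_t GetDim(size_t idx) const { return dims[idx]; }
};

const size_t kNullTensorDimNum = 1;
const int64_t kNullTensorDimValue = 0;

// A "null" indices tensor is a rank-1 tensor of shape [0]: it contributes no
// indices, so the stitch loop skips it instead of treating it like a scalar.
bool IsNullIndicesTensor(const Shape &shape) {
  return shape.GetDimNum() == kNullTensorDimNum &&
         shape.GetDim(0) == kNullTensorDimValue;
}

int main() {
  std::printf("[0]   -> %d\n", IsNullIndicesTensor({{0}}));     // 1: skipped
  std::printf("[3]   -> %d\n", IsNullIndicesTensor({{3}}));     // 0: processed
  std::printf("[]    -> %d\n", IsNullIndicesTensor({{}}));      // 0: scalar
  std::printf("[2,0] -> %d\n", IsNullIndicesTensor({{2, 0}}));  // 0: rank 2
  return 0;
}

Note that the check deliberately matches only a rank-1 shape [0], the convention the new kNullTensorDimNum/kNullTensorDimValue constants encode; a true scalar (rank 0) still falls through to the existing normalization line, which treats it as a vector of shape [1].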