|
|
@@ -33,6 +33,8 @@ namespace { |
|
|
|
const int kDoubleAttrN = 2; |
|
|
|
const int kFirstOutputDescIdx = 0; |
|
|
|
const int kMergedShapeSecondDim = 1; |
|
|
|
const size_t kNullTensorDimNum = 1; |
|
|
|
const int64_t kNullTensorDimValue = 0; |
|
|
|
const std::set<DataType> kSupportedTypeSet = {DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32, |
|
|
|
DT_INT64, DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}; |
|
|
|
} // namespace |
|
|
@@ -177,7 +179,14 @@ Status DynamicStitchKernel::StitchDataFollowIndices(int64_t data_unit, const vec |
|
|
|
int64_t src_offset = 0; |
|
|
|
std::set<int32_t> indices_set; |
|
|
|
for (int i = 0; i < n_; i++) { |
|
|
|
auto indices_shape_size = input[i]->GetTensorDesc().GetShape().GetShapeSize(); |
|
|
|
GeShape indices_shape = input[i]->GetTensorDesc().GetShape(); |
|
|
|
size_t indices_dim_num = indices_shape.GetDimNum(); |
|
|
|
// skip null indices tensor |
|
|
|
if (indices_dim_num == kNullTensorDimNum && indices_shape.GetDim(0) == kNullTensorDimValue) { |
|
|
|
GELOGD("Input indices[%d] has null tensor, skip it.", i); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto indices_shape_size = indices_shape.GetShapeSize(); |
|
|
|
// to normalize logic, assume scalar as vector with shape of [1]. |
|
|
|
indices_shape_size = (indices_shape_size == 0) ? 1 : indices_shape_size; |
|
|
|
// all index for input is less than size of input |
|
|
|