|
|
@@ -111,8 +111,9 @@ void DynamicStitchKernel::ComputeMergedShape(const vector<ConstGeTensorPtr> &inp |
|
|
|
int32_t merged_first_dim = 0; |
|
|
|
int64_t indices_shape_size = 0; |
|
|
|
for (int i = 0; i < n_; i++) { |
|
|
|
indices_shape_size = input[i]->GetTensorDesc().GetShape().GetShapeSize(); |
|
|
|
indices_shape_size = indices_shape_size == 0 ? 1 : indices_shape_size; |
|
|
|
// shape is [] means scalar |
|
|
|
indices_shape_size = |
|
|
|
input[i]->GetTensorDesc().GetShape().GetDims().empty() ? 1 : input[i]->GetTensorDesc().GetShape().GetShapeSize(); |
|
|
|
const int32_t *input_indices = reinterpret_cast<const int32_t *>(input[i]->GetData().data()); |
|
|
|
for (int64_t j = 0; j < indices_shape_size; j++) { |
|
|
|
merged_first_dim = std::max(merged_first_dim, input_indices[j]); |
|
|
|