diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 872f94fb..ed654d4f 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -478,8 +478,30 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { return SUCCESS; } - (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); + + GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex)); + std::vector input_dims_str; + for (size_t i = 0; i < batch_shapes_.size(); ++i) { + auto shape = data_shape; + auto ret = CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to calculate the batched shape for data node %s, the shapes may not match", + data->GetName().c_str()); + return ret; + } + tensor.SetShape(shape); + string input_str; + int64_t tensor_size = 0; + (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); + input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + + TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + + std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + + formats::JoinToString(tensor.GetShape().GetDims()); + input_dims_str.emplace_back(input_str); + } + (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); + size_t max_shape_index = 0; int64_t max_size = 0; for (size_t i = 0; i < batch_shapes_.size(); ++i) {