| @@ -484,20 +484,19 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { | |||||
| std::vector<std::string> input_dims_str; | std::vector<std::string> input_dims_str; | ||||
| for (size_t i = 0; i < batch_shapes_.size(); ++i) { | for (size_t i = 0; i < batch_shapes_.size(); ++i) { | ||||
| auto shape = data_shape; | auto shape = data_shape; | ||||
| auto ret = CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); | |||||
| auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Failed to calculate the batched shape for data node %s, the shapes may not match", | GELOGE(ret, "Failed to calculate the batched shape for data node %s, the shapes may not match", | ||||
| data->GetName().c_str()); | |||||
| data->GetName().c_str()); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| tensor.SetShape(shape); | tensor.SetShape(shape); | ||||
| string input_str; | |||||
| int64_t tensor_size = 0; | int64_t tensor_size = 0; | ||||
| (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); | |||||
| input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + | |||||
| TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + | |||||
| std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + | |||||
| formats::JoinToString(tensor.GetShape().GetDims()); | |||||
| (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); | |||||
| string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + | |||||
| TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + | |||||
| std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + | |||||
| formats::JoinToString(tensor.GetShape().GetDims()); | |||||
| input_dims_str.emplace_back(input_str); | input_dims_str.emplace_back(input_str); | ||||
| } | } | ||||
| (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); | (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); | ||||