|
@@ -484,20 +484,19 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { |
|
|
std::vector<std::string> input_dims_str; |
|
|
std::vector<std::string> input_dims_str; |
|
|
for (size_t i = 0; i < batch_shapes_.size(); ++i) { |
|
|
for (size_t i = 0; i < batch_shapes_.size(); ++i) { |
|
|
auto shape = data_shape; |
|
|
auto shape = data_shape; |
|
|
auto ret = CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); |
|
|
|
|
|
|
|
|
auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); |
|
|
if (ret != SUCCESS) { |
|
|
if (ret != SUCCESS) { |
|
|
GELOGE(ret, "Failed to calculate the batched shape for data node %s, the shapes may not match", |
|
|
GELOGE(ret, "Failed to calculate the batched shape for data node %s, the shapes may not match", |
|
|
data->GetName().c_str()); |
|
|
|
|
|
|
|
|
data->GetName().c_str()); |
|
|
return ret; |
|
|
return ret; |
|
|
} |
|
|
} |
|
|
tensor.SetShape(shape); |
|
|
tensor.SetShape(shape); |
|
|
string input_str; |
|
|
|
|
|
int64_t tensor_size = 0; |
|
|
int64_t tensor_size = 0; |
|
|
(void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); |
|
|
|
|
|
input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + |
|
|
|
|
|
TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + |
|
|
|
|
|
std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + |
|
|
|
|
|
formats::JoinToString(tensor.GetShape().GetDims()); |
|
|
|
|
|
|
|
|
(void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); |
|
|
|
|
|
string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + |
|
|
|
|
|
TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + |
|
|
|
|
|
std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + |
|
|
|
|
|
formats::JoinToString(tensor.GetShape().GetDims()); |
|
|
input_dims_str.emplace_back(input_str); |
|
|
input_dims_str.emplace_back(input_str); |
|
|
} |
|
|
} |
|
|
(void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); |
|
|
(void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); |
|
|