| @@ -22,6 +22,8 @@ | |||||
| #include "graph/preprocess/multi_batch_options.h" | #include "graph/preprocess/multi_batch_options.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -478,8 +480,28 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { | |||||
| if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { | if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); | (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); | ||||
| GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex)); | |||||
| std::vector<std::string> input_dims_str; | |||||
| for (size_t i = 0; i < batch_shapes_.size(); ++i) { | |||||
| auto shape = data_shape; | |||||
| auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", data->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| tensor.SetShape(shape); | |||||
| int64_t tensor_size = 0; | |||||
| (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); | |||||
| string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + | |||||
| TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + | |||||
| std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + | |||||
| formats::JoinToString(tensor.GetShape().GetDims()); | |||||
| input_dims_str.emplace_back(input_str); | |||||
| } | |||||
| (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); | |||||
| size_t max_shape_index = 0; | size_t max_shape_index = 0; | ||||
| int64_t max_size = 0; | int64_t max_size = 0; | ||||
| for (size_t i = 0; i < batch_shapes_.size(); ++i) { | for (size_t i = 0; i < batch_shapes_.size(); ++i) { | ||||
| @@ -593,7 +615,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const | |||||
| graph->AddSubgraph(subgraph->GetName(), subgraph); | graph->AddSubgraph(subgraph->GetName(), subgraph); | ||||
| all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); | all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); | ||||
| GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]), | GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]), | ||||
| "Update %s failed", all_branch_output_[subgraph]->GetName().c_str()); | |||||
| "Update %s failed", all_branch_output_[subgraph]->GetName().c_str()); | |||||
| const string key_name = "branches" + std::to_string(i); | const string key_name = "branches" + std::to_string(i); | ||||
| op_desc->AddSubgraphName(key_name); | op_desc->AddSubgraphName(key_name); | ||||