|
|
@@ -503,12 +503,24 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { |
|
|
|
|
|
|
|
/// |
|
|
|
/// @ingroup ge |
|
|
|
/// @brief Set shape to Data node in branch. |
|
|
|
/// @param [in] const NodePtr &data: data in branch. |
|
|
|
/// @brief Update Data node in Subgraph. |
|
|
|
/// @param [in] const NodePtr &data: data in Subgraph. |
|
|
|
/// @param [in] size_t index: The batch index. |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::UpdateShapeToData(const NodePtr &data, size_t index) { |
|
|
|
Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index) { |
|
|
|
int node_index = -1; |
|
|
|
if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) { |
|
|
|
GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
int parent_index = node_index + 1; |
|
|
|
if (!AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { |
|
|
|
GELOGE(FAILED, "Failed to set parent index for node %s", data->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); |
|
|
|
const auto &dims = data_shape.GetDims(); |
|
|
|
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { |
|
|
@@ -581,13 +593,15 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const |
|
|
|
(void)AttrUtils::SetBool(subgraph, "_no_reset_name", true); |
|
|
|
graph->AddSubgraph(subgraph->GetName(), subgraph); |
|
|
|
all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); |
|
|
|
GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]), |
|
|
|
"Update %s failed", all_branch_output_[subgraph]->GetName().c_str()); |
|
|
|
|
|
|
|
const string key_name = "branches" + std::to_string(i); |
|
|
|
op_desc->AddSubgraphName(key_name); |
|
|
|
op_desc->SetSubgraphInstanceName(i, subgraph->GetName()); |
|
|
|
|
|
|
|
for (const auto &data : input_nodes) { |
|
|
|
GE_CHK_STATUS_RET(UpdateShapeToData(data, i), "Update %s failed", subgraph->GetName().c_str()); |
|
|
|
GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -596,7 +610,28 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const |
|
|
|
const auto &op_desc = n->GetOpDesc(); |
|
|
|
op_desc->SetName(n->GetName() + kMultiBatchNodePostfix + "0"); |
|
|
|
if (n->GetType() == DATA) { |
|
|
|
GE_CHK_STATUS_RET(UpdateShapeToData(n, 0), "Update %s failed", branch->GetName().c_str()); |
|
|
|
GE_CHK_STATUS_RET(UpdateSubgraphData(n, 0), "Update %s failed", branch->GetName().c_str()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
/// |
|
|
|
/// @ingroup ge |
|
|
|
/// @brief Update output_node in Subgraph. |
|
|
|
/// @param [in] const NodePtr &data: output_node in Subgraph. |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) { |
|
|
|
const auto &op_desc = output_node->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) { |
|
|
|
GeTensorDescPtr tensor = op_desc->MutableInputDesc(index); |
|
|
|
GE_CHECK_NOTNULL(tensor); |
|
|
|
if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { |
|
|
|
GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|