|
|
@@ -109,6 +109,7 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { |
|
|
|
GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed."); |
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed"); |
|
|
|
GE_CHK_STATUS_RET(UpdateSubgraphOutput(), "Update subgraph output failed"); |
|
|
|
GELOGD("MultiBatchClonePass Leave"); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
@@ -1057,8 +1058,6 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const |
|
|
|
subgraph->SetParentGraph(graph); |
|
|
|
graph->AddSubgraph(subgraph->GetName(), subgraph); |
|
|
|
all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); |
|
|
|
GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]), |
|
|
|
"Update %s failed", all_branch_output_[subgraph]->GetName().c_str()); |
|
|
|
|
|
|
|
const string key_name = "branches" + std::to_string(i); |
|
|
|
op_desc->AddSubgraphName(key_name); |
|
|
@@ -1085,21 +1084,22 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const |
|
|
|
/// |
|
|
|
/// @ingroup ge |
|
|
|
/// @brief Update output_node in Subgraph. |
|
|
|
/// @param [in] const NodePtr &output_node: output_node in Subgraph. |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) { |
|
|
|
const auto &op_desc = output_node->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) { |
|
|
|
GeTensorDescPtr tensor = op_desc->MutableInputDesc(index); |
|
|
|
GE_CHECK_NOTNULL(tensor); |
|
|
|
if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { |
|
|
|
GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
Status MultiBatchClonePass::UpdateSubgraphOutput() { |
|
|
|
for (const auto &item : all_branch_output_) { |
|
|
|
const auto &output_node = item.second; |
|
|
|
const auto &op_desc = output_node->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) { |
|
|
|
GeTensorDescPtr tensor = op_desc->MutableInputDesc(index); |
|
|
|
GE_CHECK_NOTNULL(tensor); |
|
|
|
if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { |
|
|
|
GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|