From b6d89663fad8cab0e3e0a23e23372431b948ebd1 Mon Sep 17 00:00:00 2001 From: wjm Date: Sun, 7 Feb 2021 14:41:54 +0800 Subject: [PATCH] fix DirectOutput error in mult batch --- ge/graph/passes/multi_batch_clone_pass.cc | 26 +++++++++++------------ ge/graph/passes/multi_batch_clone_pass.h | 3 +-- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index b8fb6bde..a33e1f40 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -109,6 +109,7 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed."); GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed"); + GE_CHK_STATUS_RET(UpdateSubgraphOutput(), "Update subgraph output failed"); GELOGD("MultiBatchClonePass Leave"); return SUCCESS; } @@ -1057,8 +1058,6 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const subgraph->SetParentGraph(graph); graph->AddSubgraph(subgraph->GetName(), subgraph); all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); - GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]), - "Update %s failed", all_branch_output_[subgraph]->GetName().c_str()); const string key_name = "branches" + std::to_string(i); op_desc->AddSubgraphName(key_name); @@ -1085,21 +1084,22 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const /// /// @ingroup ge /// @brief Update output_node in Subgraph. -/// @param [in] const NodePtr &output_node: output_node in Subgraph. /// @return 0: SUCCESS / others: FAILED /// -Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) { - const auto &op_desc = output_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) { - GeTensorDescPtr tensor = op_desc->MutableInputDesc(index); - GE_CHECK_NOTNULL(tensor); - if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { - GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str()); - return FAILED; +Status MultiBatchClonePass::UpdateSubgraphOutput() { + for (const auto &item : all_branch_output_) { + const auto &output_node = item.second; + const auto &op_desc = output_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) { + GeTensorDescPtr tensor = op_desc->MutableInputDesc(index); + GE_CHECK_NOTNULL(tensor); + if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { + GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str()); + return FAILED; + } } } - return SUCCESS; } diff --git a/ge/graph/passes/multi_batch_clone_pass.h b/ge/graph/passes/multi_batch_clone_pass.h index 0dae88ca..69f7ddf9 100755 --- a/ge/graph/passes/multi_batch_clone_pass.h +++ b/ge/graph/passes/multi_batch_clone_pass.h @@ -136,10 +136,9 @@ class MultiBatchClonePass : public GraphPass { /// /// @ingroup ge /// @brief Update output_node in Subgraph. - /// @param [in] const NodePtr &output_node: output_node in Subgraph. /// @return 0: SUCCESS / others: FAILED /// - Status UpdateSubgraphOutput(const NodePtr &output_node); + Status UpdateSubgraphOutput(); /// /// @ingroup ge