Browse Source

fix DirectOutput error in mult batch

tags/v1.2.0
wjm 3 years ago
parent
commit
b6d89663fa
2 changed files with 14 additions and 15 deletions
  1. +13
    -13
      ge/graph/passes/multi_batch_clone_pass.cc
  2. +1
    -2
      ge/graph/passes/multi_batch_clone_pass.h

+ 13
- 13
ge/graph/passes/multi_batch_clone_pass.cc View File

@@ -109,6 +109,7 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) {
GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed."); GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed.");


GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed"); GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed");
GE_CHK_STATUS_RET(UpdateSubgraphOutput(), "Update subgraph output failed");
GELOGD("MultiBatchClonePass Leave"); GELOGD("MultiBatchClonePass Leave");
return SUCCESS; return SUCCESS;
} }
@@ -1057,8 +1058,6 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const
subgraph->SetParentGraph(graph); subgraph->SetParentGraph(graph);
graph->AddSubgraph(subgraph->GetName(), subgraph); graph->AddSubgraph(subgraph->GetName(), subgraph);
all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);
GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]),
"Update %s failed", all_branch_output_[subgraph]->GetName().c_str());


const string key_name = "branches" + std::to_string(i); const string key_name = "branches" + std::to_string(i);
op_desc->AddSubgraphName(key_name); op_desc->AddSubgraphName(key_name);
@@ -1085,21 +1084,22 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const
/// ///
/// @ingroup ge /// @ingroup ge
/// @brief Update output_node in Subgraph. /// @brief Update output_node in Subgraph.
/// @param [in] const NodePtr &output_node: output_node in Subgraph.
/// @return 0: SUCCESS / others: FAILED /// @return 0: SUCCESS / others: FAILED
/// ///
Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) {
const auto &op_desc = output_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) {
GeTensorDescPtr tensor = op_desc->MutableInputDesc(index);
GE_CHECK_NOTNULL(tensor);
if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) {
GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str());
return FAILED;
Status MultiBatchClonePass::UpdateSubgraphOutput() {
for (const auto &item : all_branch_output_) {
const auto &output_node = item.second;
const auto &op_desc = output_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) {
GeTensorDescPtr tensor = op_desc->MutableInputDesc(index);
GE_CHECK_NOTNULL(tensor);
if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) {
GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str());
return FAILED;
}
} }
} }

return SUCCESS; return SUCCESS;
} }




+ 1
- 2
ge/graph/passes/multi_batch_clone_pass.h View File

@@ -136,10 +136,9 @@ class MultiBatchClonePass : public GraphPass {
/// ///
/// @ingroup ge /// @ingroup ge
/// @brief Update output_node in Subgraph. /// @brief Update output_node in Subgraph.
/// @param [in] const NodePtr &output_node: output_node in Subgraph.
/// @return 0: SUCCESS / others: FAILED /// @return 0: SUCCESS / others: FAILED
/// ///
Status UpdateSubgraphOutput(const NodePtr &output_node);
Status UpdateSubgraphOutput();


/// ///
/// @ingroup ge /// @ingroup ge


Loading…
Cancel
Save