diff --git a/ge/graph/passes/data_pass.cc b/ge/graph/passes/data_pass.cc index 2a689cd5..5bbd2fb1 100644 --- a/ge/graph/passes/data_pass.cc +++ b/ge/graph/passes/data_pass.cc @@ -185,17 +185,13 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) { const auto &parent_graph = compute_graph->GetParentGraph(); GE_CHECK_NOTNULL(parent_graph); - bool flag = false; - (void)AttrUtils::GetBool(compute_graph, "_no_reset_name", flag); - if (!flag) { - for (const NodePtr &node : compute_graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { - continue; - } - - node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); + for (const NodePtr &node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { + continue; } + + node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); } return PostParseSubgraph(compute_graph, subgraph_name, parent_node); diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 496ad214..37f4b637 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -503,12 +503,24 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { /// /// @ingroup ge -/// @brief Set shape to Data node in branch. -/// @param [in] const NodePtr &data: data in branch. +/// @brief Update Data node in Subgraph. +/// @param [in] const NodePtr &data: data in Subgraph. /// @param [in] size_t index: The batch index. /// @return 0: SUCCESS / others: FAILED /// -Status MultiBatchClonePass::UpdateShapeToData(const NodePtr &data, size_t index) { +Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index) { + int node_index = -1; + if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) { + GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str()); + return FAILED; + } + + int parent_index = node_index + 1; + if (!AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "Failed to set parent index for node %s", data->GetName().c_str()); + return FAILED; + } + auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); const auto &dims = data_shape.GetDims(); if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { @@ -581,13 +593,15 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const (void)AttrUtils::SetBool(subgraph, "_no_reset_name", true); graph->AddSubgraph(subgraph->GetName(), subgraph); all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); + GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]), + "Update %s failed", all_branch_output_[subgraph]->GetName().c_str()); const string key_name = "branches" + std::to_string(i); op_desc->AddSubgraphName(key_name); op_desc->SetSubgraphInstanceName(i, subgraph->GetName()); for (const auto &data : input_nodes) { - GE_CHK_STATUS_RET(UpdateShapeToData(data, i), "Update %s failed", subgraph->GetName().c_str()); + GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str()); } } @@ -596,7 +610,28 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const const auto &op_desc = n->GetOpDesc(); op_desc->SetName(n->GetName() + kMultiBatchNodePostfix + "0"); if (n->GetType() == DATA) { - GE_CHK_STATUS_RET(UpdateShapeToData(n, 0), "Update %s failed", branch->GetName().c_str()); + GE_CHK_STATUS_RET(UpdateSubgraphData(n, 0), "Update %s failed", branch->GetName().c_str()); + } + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Update output_node in Subgraph. +/// @param [in] const NodePtr &data: output_node in Subgraph. +/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) { + const auto &op_desc = output_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) { + GeTensorDescPtr tensor = op_desc->MutableInputDesc(index); + GE_CHECK_NOTNULL(tensor); + if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { + GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str()); + return FAILED; } } diff --git a/ge/graph/passes/multi_batch_clone_pass.h b/ge/graph/passes/multi_batch_clone_pass.h index 5921970a..3dbf91db 100755 --- a/ge/graph/passes/multi_batch_clone_pass.h +++ b/ge/graph/passes/multi_batch_clone_pass.h @@ -105,12 +105,20 @@ class MultiBatchClonePass : public GraphPass { /// /// @ingroup ge - /// @brief Set shape to Data node in branch. - /// @param [in] const NodePtr &data: data in branch. + /// @brief Update Data node in Subgraph. + /// @param [in] const NodePtr &data: data in Subgraph. /// @param [in] size_t index: The batch index. /// @return 0: SUCCESS / others: FAILED /// - Status UpdateShapeToData(const NodePtr &data, size_t index); + Status UpdateSubgraphData(const NodePtr &data, size_t index); + + /// + /// @ingroup ge + /// @brief Update output_node in Subgraph. + /// @param [in] const NodePtr &data: output_node in Subgraph. + /// @return 0: SUCCESS / others: FAILED + /// + Status UpdateSubgraphOutput(const NodePtr &output_node); /// /// @ingroup ge diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index eae97b04..c8880b2e 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -29,7 +29,6 @@ #include "framework/omg/omg_inner_types.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" -#include "graph/passes/data_pass.h" #include "graph/passes/multi_batch_clone_pass.h" #include "graph/passes/prune_pass.h" #include "graph/preprocess/multi_batch_options.h" @@ -1698,7 +1697,6 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { if (multi_batch_with_switchn == nullptr) { PassManager pass_manager; GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); - GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass)); // set subgraph parent index return pass_manager.Run(graph); } }