diff --git a/ge/graph/passes/subgraph_const_migration_pass.cc b/ge/graph/passes/subgraph_const_migration_pass.cc index f3a6b998..71463f4c 100644 --- a/ge/graph/passes/subgraph_const_migration_pass.cc +++ b/ge/graph/passes/subgraph_const_migration_pass.cc @@ -23,7 +23,7 @@ namespace ge { constexpr uint32_t kZeroIndex = 0; constexpr uint32_t kCaseInputBase = 1; constexpr uint32_t kInvalidParent = 0x7fffffffU; -const char *const kMbatchNodeNameMark = "_ascend_mbatch_batch_"; +const string kMbatchNodeNameMark = "_ascend_mbatch_batch_"; bool IsSameConstNode(const NodePtr &src_node, const NodePtr &dst_node) { if ((src_node == nullptr) && (dst_node == nullptr)) { @@ -164,11 +164,16 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra string node_full_name = peer_node->GetName(); size_t pos = node_full_name.find(kMbatchNodeNameMark); if (pos == string::npos) { - GELOGE(FAILED, "Cannot find: %s of multi-batch in node: %s", kMbatchNodeNameMark, node_full_name.c_str()); + GELOGE(FAILED, "find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); return FAILED; } string fixed_name = node_full_name.substr(0, pos); + pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length()); + if (pos != string::npos) { + fixed_name += node_full_name.substr(pos); + } + peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx())); } @@ -336,14 +341,19 @@ Status SubgraphConstMigrationPass::AppendParallelNode(const NodePtr &func_node, /// @param [in] outputs: Parent index of Node output. /// @return 0: SUCCESS / others: FAILED /// -Status SubgraphConstMigrationPass::DetachParallelNode(const map &const_nodes, +Status SubgraphConstMigrationPass::DetachParallelNode(const ComputeGraphPtr &graph, + const map &const_nodes, const NodePtr &const_node, const NodePtr &data_node) { // Break Data and Move node. const auto &in_anchor = const_node->GetInControlAnchor(); const auto out_anchors = in_anchor->GetPeerOutControlAnchors(); for (const auto out_anchor : out_anchors) { GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); - GELOGI("Remove Edge: %s %s", out_anchor->GetOwnerNode()->GetName().c_str(), const_node->GetName().c_str()); + const auto owner_node = out_anchor->GetOwnerNode(); + GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), const_node->GetName().c_str()); + if (owner_node->GetInAllNodes().empty() && owner_node->GetOutAllNodes().empty()) { + graph->RemoveNode(owner_node); + } } const auto &ctrl_anchor = const_node->GetOutControlAnchor(); @@ -454,7 +464,7 @@ Status SubgraphConstMigrationPass::MoveNodeToParent(const ComputeGraphPtr &graph return FAILED; } - if (DetachParallelNode(item.second, move_node, it_data->second) != SUCCESS) { + if (DetachParallelNode(subgraph, item.second, move_node, it_data->second) != SUCCESS) { GELOGE(FAILED, "Data: %s not found, index: %u", move_node->GetName().c_str(), parent_index); return FAILED; } diff --git a/ge/graph/passes/subgraph_const_migration_pass.h b/ge/graph/passes/subgraph_const_migration_pass.h index 66c0011c..323be0ff 100755 --- a/ge/graph/passes/subgraph_const_migration_pass.h +++ b/ge/graph/passes/subgraph_const_migration_pass.h @@ -119,8 +119,8 @@ class SubgraphConstMigrationPass : public GraphPass { /// @param [in] outputs: Parent index of Node output. /// @return 0: SUCCESS / others: FAILED /// - Status DetachParallelNode(const map &const_nodes, const NodePtr &const_node, - const NodePtr &data_node); + Status DetachParallelNode(const ComputeGraphPtr &graph, const map &const_nodes, + const NodePtr &const_node, const NodePtr &data_node); /// /// @ingroup ge