diff --git a/ge/graph/passes/subgraph_const_migration_pass.cc b/ge/graph/passes/subgraph_const_migration_pass.cc index 7cf75661..d8ad41e1 100644 --- a/ge/graph/passes/subgraph_const_migration_pass.cc +++ b/ge/graph/passes/subgraph_const_migration_pass.cc @@ -164,29 +164,9 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra data_nodes[parent_index] = node; GELOGD("%s, index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, node->GetName().c_str()); - } else if (node->GetType() == CONSTANT) { + } else if (node->GetType() == CONSTANT && (node->GetOutDataAnchor(kZeroIndex) != nullptr)) { set peer_name_list; - const auto &out_anchor = node->GetOutDataAnchor(kZeroIndex); - GE_IF_BOOL_EXEC(out_anchor == nullptr, continue); - for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - const auto &peer_node = in_anchor->GetOwnerNode(); - // Trim subgraph node name prefix. - string node_full_name = peer_node->GetName(); - size_t pos = node_full_name.find(kMbatchNodeNameMark); - if (pos == string::npos) { - GELOGI("Can not find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); - continue; - } - - string fixed_name = node_full_name.substr(0, pos); - pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length()); - if (pos != string::npos) { - fixed_name += node_full_name.substr(pos); - } - - peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx())); - } - + GetPeerNameList(node, peer_name_list); if (peer_name_list.empty()) { GELOGI("%s, Const: %s, no data output", subgraph->GetName().c_str(), node->GetName().c_str()); const auto in_all_nodes = node->GetInAllNodes(); @@ -217,6 +197,28 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra return SUCCESS; } +void SubgraphConstMigrationPass::GetPeerNameList(const NodePtr &node, set &peer_name_list) { + const auto &out_anchor = node->GetOutDataAnchor(kZeroIndex); + for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + const auto &peer_node = in_anchor->GetOwnerNode(); + // Trim subgraph node name prefix. + string node_full_name = peer_node->GetName(); + size_t pos = node_full_name.find(kMbatchNodeNameMark); + if (pos == string::npos) { + GELOGI("Can not find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); + continue; + } + + string fixed_name = node_full_name.substr(0, pos); + pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length()); + if (pos != string::npos) { + fixed_name += node_full_name.substr(pos); + } + + peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx())); + } +} + /// /// @ingroup ge /// @brief Get parent_index for Const node migration. diff --git a/ge/graph/passes/subgraph_const_migration_pass.h b/ge/graph/passes/subgraph_const_migration_pass.h index d93da839..2834fd66 100755 --- a/ge/graph/passes/subgraph_const_migration_pass.h +++ b/ge/graph/passes/subgraph_const_migration_pass.h @@ -133,6 +133,8 @@ class SubgraphConstMigrationPass : public GraphPass { /// Status AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node, const NodePtr &const_node, uint32_t parent_index); + + void GetPeerNameList(const NodePtr &node, set &peer_name_list); }; } // namespace ge #endif // GE_COMMON_SUBGRAPH_CONST_MIGRATION_H_ \ No newline at end of file