Browse Source

add pass

tags/v1.3.0
wjm 3 years ago
parent
commit
e6253a449b
2 changed files with 26 additions and 22 deletions
  1. +24
    -22
      ge/graph/passes/subgraph_const_migration_pass.cc
  2. +2
    -0
      ge/graph/passes/subgraph_const_migration_pass.h

+ 24
- 22
ge/graph/passes/subgraph_const_migration_pass.cc View File

@@ -164,29 +164,9 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra

data_nodes[parent_index] = node;
GELOGD("%s, index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, node->GetName().c_str());
} else if (node->GetType() == CONSTANT) {
} else if (node->GetType() == CONSTANT && (node->GetOutDataAnchor(kZeroIndex) != nullptr)) {
set<string> peer_name_list;
const auto &out_anchor = node->GetOutDataAnchor(kZeroIndex);
GE_IF_BOOL_EXEC(out_anchor == nullptr, continue);
for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
const auto &peer_node = in_anchor->GetOwnerNode();
// Trim subgraph node name prefix.
string node_full_name = peer_node->GetName();
size_t pos = node_full_name.find(kMbatchNodeNameMark);
if (pos == string::npos) {
GELOGI("Can not find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str());
continue;
}

string fixed_name = node_full_name.substr(0, pos);
pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length());
if (pos != string::npos) {
fixed_name += node_full_name.substr(pos);
}

peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx()));
}

GetPeerNameList(node, peer_name_list);
if (peer_name_list.empty()) {
GELOGI("%s, Const: %s, no data output", subgraph->GetName().c_str(), node->GetName().c_str());
const auto in_all_nodes = node->GetInAllNodes();
@@ -217,6 +197,28 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra
return SUCCESS;
}

void SubgraphConstMigrationPass::GetPeerNameList(const NodePtr &node, set<string> &peer_name_list) {
const auto &out_anchor = node->GetOutDataAnchor(kZeroIndex);
for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
const auto &peer_node = in_anchor->GetOwnerNode();
// Trim subgraph node name prefix.
string node_full_name = peer_node->GetName();
size_t pos = node_full_name.find(kMbatchNodeNameMark);
if (pos == string::npos) {
GELOGI("Can not find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str());
continue;
}

string fixed_name = node_full_name.substr(0, pos);
pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length());
if (pos != string::npos) {
fixed_name += node_full_name.substr(pos);
}

peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx()));
}
}

///
/// @ingroup ge
/// @brief Get parent_index for Const node migration.


+ 2
- 0
ge/graph/passes/subgraph_const_migration_pass.h View File

@@ -133,6 +133,8 @@ class SubgraphConstMigrationPass : public GraphPass {
///
Status AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node,
const NodePtr &const_node, uint32_t parent_index);

void GetPeerNameList(const NodePtr &node, set<string> &peer_name_list);
};
} // namespace ge
#endif // GE_COMMON_SUBGRAPH_CONST_MIGRATION_H_

Loading…
Cancel
Save