|
|
@@ -23,7 +23,7 @@ namespace ge { |
|
|
|
constexpr uint32_t kZeroIndex = 0; |
|
|
|
constexpr uint32_t kCaseInputBase = 1; |
|
|
|
constexpr uint32_t kInvalidParent = 0x7fffffffU; |
|
|
|
const char *const kMbatchNodeNameMark = "_ascend_mbatch_batch_"; |
|
|
|
const string kMbatchNodeNameMark = "_ascend_mbatch_batch_"; |
|
|
|
|
|
|
|
bool IsSameConstNode(const NodePtr &src_node, const NodePtr &dst_node) { |
|
|
|
if ((src_node == nullptr) && (dst_node == nullptr)) { |
|
|
@@ -164,11 +164,16 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra |
|
|
|
string node_full_name = peer_node->GetName(); |
|
|
|
size_t pos = node_full_name.find(kMbatchNodeNameMark); |
|
|
|
if (pos == string::npos) { |
|
|
|
GELOGE(FAILED, "Cannot find: %s of multi-batch in node: %s", kMbatchNodeNameMark, node_full_name.c_str()); |
|
|
|
GELOGE(FAILED, "find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
string fixed_name = node_full_name.substr(0, pos); |
|
|
|
pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length()); |
|
|
|
if (pos != string::npos) { |
|
|
|
fixed_name += node_full_name.substr(pos); |
|
|
|
} |
|
|
|
|
|
|
|
peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx())); |
|
|
|
} |
|
|
|
|
|
|
@@ -336,14 +341,19 @@ Status SubgraphConstMigrationPass::AppendParallelNode(const NodePtr &func_node, |
|
|
|
/// @param [in] outputs: Parent index of Node output. |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status SubgraphConstMigrationPass::DetachParallelNode(const map<string, NodePtr> &const_nodes, |
|
|
|
Status SubgraphConstMigrationPass::DetachParallelNode(const ComputeGraphPtr &graph, |
|
|
|
const map<string, NodePtr> &const_nodes, |
|
|
|
const NodePtr &const_node, const NodePtr &data_node) { |
|
|
|
// Break Data and Move node. |
|
|
|
const auto &in_anchor = const_node->GetInControlAnchor(); |
|
|
|
const auto out_anchors = in_anchor->GetPeerOutControlAnchors(); |
|
|
|
for (const auto out_anchor : out_anchors) { |
|
|
|
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); |
|
|
|
GELOGI("Remove Edge: %s %s", out_anchor->GetOwnerNode()->GetName().c_str(), const_node->GetName().c_str()); |
|
|
|
const auto owner_node = out_anchor->GetOwnerNode(); |
|
|
|
GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), const_node->GetName().c_str()); |
|
|
|
if (owner_node->GetInAllNodes().empty() && owner_node->GetOutAllNodes().empty()) { |
|
|
|
graph->RemoveNode(owner_node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
const auto &ctrl_anchor = const_node->GetOutControlAnchor(); |
|
|
@@ -454,7 +464,7 @@ Status SubgraphConstMigrationPass::MoveNodeToParent(const ComputeGraphPtr &graph |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (DetachParallelNode(item.second, move_node, it_data->second) != SUCCESS) { |
|
|
|
if (DetachParallelNode(subgraph, item.second, move_node, it_data->second) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "Data: %s not found, index: %u", move_node->GetName().c_str(), parent_index); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|