|
@@ -56,19 +56,29 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, |
|
|
|
|
|
|
|
|
void AddNextIterNodes(const std::vector<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, |
|
|
DuringPassNodeSets &during_pass_node_set) { |
|
|
DuringPassNodeSets &during_pass_node_set) { |
|
|
std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen; |
|
|
|
|
|
const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last; |
|
|
|
|
|
const std::unordered_set<NodePtr> &nodes_stopped = during_pass_node_set.nodes_stopped; |
|
|
|
|
|
for (auto &node : nodes) { |
|
|
for (auto &node : nodes) { |
|
|
if (node == nullptr) { |
|
|
if (node == nullptr) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
if (nodes_stopped.count(node) > 0) { |
|
|
|
|
|
|
|
|
if (during_pass_node_set.nodes_stopped.count(node) > 0) { |
|
|
GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str()); |
|
|
GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str()); |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
nodes_to_pass.push_back(node); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void GetNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::vector<NodePtr> &nodes_to_pass, |
|
|
|
|
|
DuringPassNodeSets &during_pass_node_set) { |
|
|
|
|
|
std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen; |
|
|
|
|
|
const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last; |
|
|
|
|
|
for (auto &node : nodes) { |
|
|
|
|
|
if (node == nullptr) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
if (nodes_last.count(node) != 0) { |
|
|
if (nodes_last.count(node) != 0) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
@@ -80,6 +90,20 @@ void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &n |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void PushToStoppedNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, |
|
|
|
|
|
const std::unordered_set<NodePtr> &nodes_stopped, |
|
|
|
|
|
const std::unordered_set<NodePtr> &nodes_restored) { |
|
|
|
|
|
for (const auto &node : nodes_stopped) { |
|
|
|
|
|
GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), pass_name.c_str()); |
|
|
|
|
|
during_pass_node_set.nodes_stopped.emplace(node); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
for (const auto &node : nodes_restored) { |
|
|
|
|
|
GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), pass_name.c_str()); |
|
|
|
|
|
during_pass_node_set.nodes_stopped.erase(node); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, |
|
|
void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, |
|
|
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass, |
|
|
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass, |
|
|
std::unordered_set<NodePtr> &nodes_re_pass) { |
|
|
std::unordered_set<NodePtr> &nodes_re_pass) { |
|
@@ -105,6 +129,8 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo |
|
|
} |
|
|
} |
|
|
GELOGD("Begin to run pass for node %s", node->GetName().c_str()); |
|
|
GELOGD("Begin to run pass for node %s", node->GetName().c_str()); |
|
|
for (const auto &name_to_pass : names_to_passes) { |
|
|
for (const auto &name_to_pass : names_to_passes) { |
|
|
|
|
|
const std::string &pass_name = name_to_pass.first; |
|
|
|
|
|
BaseNodePass *pass_node = name_to_pass.second; |
|
|
if (name_to_pass.second == nullptr) { |
|
|
if (name_to_pass.second == nullptr) { |
|
|
GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); |
|
|
GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); |
|
|
continue; |
|
|
continue; |
|
@@ -129,14 +155,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo |
|
|
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, |
|
|
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, |
|
|
during_pass_node_set.nodes_re_pass_immediately); |
|
|
during_pass_node_set.nodes_re_pass_immediately); |
|
|
|
|
|
|
|
|
for (const auto &node : name_to_pass.second->GetNodesStopped()) { |
|
|
|
|
|
GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), name_to_pass.first.c_str()); |
|
|
|
|
|
during_pass_node_set.nodes_stopped.emplace(node); |
|
|
|
|
|
} |
|
|
|
|
|
for (const auto &node : name_to_pass.second->GetNodesRestored()) { |
|
|
|
|
|
GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), name_to_pass.first.c_str()); |
|
|
|
|
|
during_pass_node_set.nodes_stopped.erase(node); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
PushToStoppedNodes(during_pass_node_set, pass_name, pass_node->GetNodesStopped(), pass_node->GetNodesRestored()); |
|
|
|
|
|
|
|
|
const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); |
|
|
const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); |
|
|
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); |
|
|
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); |
|
@@ -239,7 +258,9 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
const auto all_out_nodes = node->GetOutNodes(); |
|
|
|
|
|
|
|
|
std::vector<NodePtr> nodes_to_pass; |
|
|
|
|
|
GetNextIterNodes(node->GetOutAllNodes(), nodes_to_pass, during_pass_node_set); |
|
|
|
|
|
|
|
|
auto ret = RunPasses(node, names_to_passes, during_pass_node_set); |
|
|
auto ret = RunPasses(node, names_to_passes, during_pass_node_set); |
|
|
if (ret != SUCCESS) { |
|
|
if (ret != SUCCESS) { |
|
|
GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", |
|
|
GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", |
|
@@ -275,7 +296,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { |
|
|
} |
|
|
} |
|
|
during_pass_node_set.nodes_re_pass_immediately.clear(); |
|
|
during_pass_node_set.nodes_re_pass_immediately.clear(); |
|
|
|
|
|
|
|
|
AddNextIterNodes(all_out_nodes, nodes, during_pass_node_set); |
|
|
|
|
|
|
|
|
AddNextIterNodes(nodes_to_pass, nodes, during_pass_node_set); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for (auto &node : during_pass_node_set.nodes_last) { |
|
|
for (auto &node : during_pass_node_set.nodes_last) { |
|
|