|
|
|
@@ -36,7 +36,6 @@ struct DuringPassNodeSets { |
|
|
|
std::unordered_set<NodePtr> nodes_re_pass; |
|
|
|
std::unordered_set<NodePtr> nodes_re_pass_immediately; |
|
|
|
std::unordered_set<NodePtr> nodes_last; |
|
|
|
std::unordered_set<NodePtr> nodes_stopped; |
|
|
|
}; |
|
|
|
|
|
|
|
void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, |
|
|
|
@@ -56,25 +55,8 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AddNextIterNodes(const std::vector<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, |
|
|
|
DuringPassNodeSets &during_pass_node_set) { |
|
|
|
for (auto &node : nodes) { |
|
|
|
if (node == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (during_pass_node_set.nodes_stopped.count(node) > 0) { |
|
|
|
GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
nodes_to_pass.push_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GetNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::vector<NodePtr> &nodes_to_pass, |
|
|
|
DuringPassNodeSets &during_pass_node_set) { |
|
|
|
std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen; |
|
|
|
const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last; |
|
|
|
void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, |
|
|
|
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) { |
|
|
|
for (auto &node : nodes) { |
|
|
|
if (node == nullptr) { |
|
|
|
continue; |
|
|
|
@@ -90,22 +72,8 @@ void GetNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::vector<NodePtr> & |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PushToStoppedNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, |
|
|
|
const std::unordered_set<NodePtr> &nodes_stopped, |
|
|
|
const std::unordered_set<NodePtr> &nodes_restored) { |
|
|
|
for (const auto &node : nodes_stopped) { |
|
|
|
GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), pass_name.c_str()); |
|
|
|
during_pass_node_set.nodes_stopped.emplace(node); |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &node : nodes_restored) { |
|
|
|
GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), pass_name.c_str()); |
|
|
|
during_pass_node_set.nodes_stopped.erase(node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, |
|
|
|
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass, |
|
|
|
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_to_re_pass, |
|
|
|
std::unordered_set<NodePtr> &nodes_re_pass) { |
|
|
|
for (const auto &node_to_re_pass : nodes_to_re_pass) { |
|
|
|
if (node_to_re_pass == nullptr) { |
|
|
|
@@ -129,8 +97,6 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo |
|
|
|
} |
|
|
|
GELOGD("Begin to run pass for node %s", node->GetName().c_str()); |
|
|
|
for (const auto &name_to_pass : names_to_passes) { |
|
|
|
const std::string &pass_name = name_to_pass.first; |
|
|
|
BaseNodePass *pass_node = name_to_pass.second; |
|
|
|
if (name_to_pass.second == nullptr) { |
|
|
|
GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); |
|
|
|
continue; |
|
|
|
@@ -147,17 +113,15 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); |
|
|
|
auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); |
|
|
|
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, |
|
|
|
during_pass_node_set.nodes_re_pass); |
|
|
|
|
|
|
|
const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); |
|
|
|
auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); |
|
|
|
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, |
|
|
|
during_pass_node_set.nodes_re_pass_immediately); |
|
|
|
|
|
|
|
PushToStoppedNodes(during_pass_node_set, pass_name, pass_node->GetNodesStopped(), pass_node->GetNodesRestored()); |
|
|
|
|
|
|
|
const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); |
|
|
|
auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); |
|
|
|
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); |
|
|
|
if (nodes_deleted_by_pass.count(node) > 0) { |
|
|
|
GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), |
|
|
|
@@ -258,8 +222,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<NodePtr> nodes_to_pass; |
|
|
|
GetNextIterNodes(node->GetOutNodes(), nodes_to_pass, during_pass_node_set); |
|
|
|
AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); |
|
|
|
|
|
|
|
auto ret = RunPasses(node, names_to_passes, during_pass_node_set); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
@@ -295,8 +258,6 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { |
|
|
|
nodes.push_front(node); |
|
|
|
} |
|
|
|
during_pass_node_set.nodes_re_pass_immediately.clear(); |
|
|
|
|
|
|
|
AddNextIterNodes(nodes_to_pass, nodes, during_pass_node_set); |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &node : during_pass_node_set.nodes_last) { |
|
|
|
|