diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index f0ab392b..0868b729 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -36,7 +36,6 @@ struct DuringPassNodeSets { std::unordered_set nodes_re_pass; std::unordered_set nodes_re_pass_immediately; std::unordered_set nodes_last; - std::unordered_set nodes_stopped; }; void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &input_edge_nodes, @@ -56,25 +55,8 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &i } } -void AddNextIterNodes(const std::vector &nodes, std::deque &nodes_to_pass, - DuringPassNodeSets &during_pass_node_set) { - for (auto &node : nodes) { - if (node == nullptr) { - continue; - } - if (during_pass_node_set.nodes_stopped.count(node) > 0) { - GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str()); - continue; - } - - nodes_to_pass.push_back(node); - } -} - -void GetNextIterNodes(const Node::Vistor &nodes, std::vector &nodes_to_pass, - DuringPassNodeSets &during_pass_node_set) { - std::unordered_set &nodes_seen = during_pass_node_set.nodes_seen; - const std::unordered_set &nodes_last = during_pass_node_set.nodes_last; +void AddNextIterNodes(const Node::Vistor &nodes, std::deque &nodes_to_pass, + std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { for (auto &node : nodes) { if (node == nullptr) { continue; @@ -90,22 +72,8 @@ void GetNextIterNodes(const Node::Vistor &nodes, std::vector & } } -void PushToStoppedNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, - const std::unordered_set &nodes_stopped, - const std::unordered_set &nodes_restored) { - for (const auto &node : nodes_stopped) { - GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), pass_name.c_str()); - during_pass_node_set.nodes_stopped.emplace(node); - } - - for (const auto &node : nodes_restored) { - GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), pass_name.c_str()); - during_pass_node_set.nodes_stopped.erase(node); - } -} - void PushToRePassIfSeen(NodePtr &node, const std::pair &name_to_pass, - std::unordered_set &nodes_seen, const std::unordered_set &nodes_to_re_pass, + std::unordered_set &nodes_seen, std::unordered_set &nodes_to_re_pass, std::unordered_set &nodes_re_pass) { for (const auto &node_to_re_pass : nodes_to_re_pass) { if (node_to_re_pass == nullptr) { @@ -129,8 +97,6 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo } GELOGD("Begin to run pass for node %s", node->GetName().c_str()); for (const auto &name_to_pass : names_to_passes) { - const std::string &pass_name = name_to_pass.first; - BaseNodePass *pass_node = name_to_pass.second; if (name_to_pass.second == nullptr) { GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); continue; @@ -147,17 +113,15 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo return result; } - const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); + auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, during_pass_node_set.nodes_re_pass); - const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); + auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, during_pass_node_set.nodes_re_pass_immediately); - PushToStoppedNodes(during_pass_node_set, pass_name, pass_node->GetNodesStopped(), pass_node->GetNodesRestored()); - - const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); + auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); if (nodes_deleted_by_pass.count(node) > 0) { GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), @@ -258,8 +222,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { continue; } - std::vector nodes_to_pass; - GetNextIterNodes(node->GetOutNodes(), nodes_to_pass, during_pass_node_set); + AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); auto ret = RunPasses(node, names_to_passes, during_pass_node_set); if (ret != SUCCESS) { @@ -295,8 +258,6 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { nodes.push_front(node); } during_pass_node_set.nodes_re_pass_immediately.clear(); - - AddNextIterNodes(nodes_to_pass, nodes, during_pass_node_set); } for (auto &node : during_pass_node_set.nodes_last) { diff --git a/ge/graph/passes/base_pass.h b/ge/graph/passes/base_pass.h index 4dde1495..a9f4f000 100644 --- a/ge/graph/passes/base_pass.h +++ b/ge/graph/passes/base_pass.h @@ -51,15 +51,11 @@ class BaseNodePass { virtual ~BaseNodePass() = default; - const std::unordered_set &GetNodesNeedRePass() { return nodes_need_re_pass_; } + std::unordered_set GetNodesNeedRePass() { return nodes_need_re_pass_; } - const std::unordered_set &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } + std::unordered_set GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } - const std::unordered_set &GetNodesDeleted() { return nodes_deleted_; } - - const std::unordered_set &GetNodesStopped() { return nodes_stopped_; } - - const std::unordered_set &GetNodesRestored() { return nodes_restored_; } + std::unordered_set GetNodesDeleted() { return nodes_deleted_; } void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } @@ -69,8 +65,6 @@ class BaseNodePass { nodes_need_re_pass_.clear(); nodes_deleted_.clear(); nodes_need_re_pass_immediately_.clear(); - nodes_stopped_.clear(); - nodes_restored_.clear(); } protected: @@ -86,7 +80,7 @@ class BaseNodePass { /// optimized by other passes, call this function. /// @param node /// - void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); } + void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); } /// /// Add a node to be optimized immediately again. If you add a new node to the graph, or @@ -94,13 +88,13 @@ class BaseNodePass { /// optimized by other passes, call this function. /// @param node /// - void AddImmediateRePassNode(const NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } + void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } /// /// Add a node and it's input/output data nodes to be optimized again. /// @param node /// - void AddRePassNodesWithInOut(const NodePtr &node) { + void AddRePassNodesWithInOut(NodePtr &node) { AddRePassNode(node); auto out_nodes = node->GetOutNodes(); for (auto &out_node : out_nodes) { @@ -122,34 +116,12 @@ class BaseNodePass { /// void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } - /// - /// If you stop a node from the graph, especially following node. The remain - /// iterate passes will stop process on the stopped node(if it can be - /// reached by edge connections) till the last one. Obviously it is a waste of - /// time. You can add the stopped nodes by calling this function, to stop the - /// next iterations. - /// @param node - /// - void AddNodeStopped(const NodePtr &node) { nodes_stopped_.insert(node); } - - /// - /// If you restore a node from the graph, especially following node. The remain - /// iterate passes will continue process on the stopped node(if it can be - /// reached by edge connections) till the last one. - /// You can add the restored nodes by calling this function, to restore the - /// next iterations. - /// @param node - /// - void AddNodeRestored(const NodePtr &node) { nodes_restored_.insert(node); } - bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } private: std::unordered_set nodes_need_re_pass_; std::unordered_set nodes_need_re_pass_immediately_; std::unordered_set nodes_deleted_; - std::unordered_set nodes_stopped_; - std::unordered_set nodes_restored_; std::map options_; }; diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index acd240a5..46026023 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -126,19 +126,6 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { return SUCCESS; }; - const auto ExProcNode = [&](const std::set &proc_types, - const std::function &proc_func, - const std::string &info) { - for (auto &n : node->GetOutDataNodes()) { - GE_CHECK_NOTNULL(n); - if (proc_types.count(n->GetType()) > 0) { - proc_func(this, n); - GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str()); - } - } - return SUCCESS; - }; - if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) { return RePassNode({MERGE, REFMERGE}); // Re-Pass Merge } @@ -146,20 +133,10 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { if (node->GetType() == MERGE || node->GetType() == REFMERGE) { if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - return RePassNode({SWITCH, REFSWITCH}); // Re-Pass Switch } return SUCCESS; } - if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) { - if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeRestored, "need restore"); // Restore Exit - } else { - return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeStopped, "need stop"); // Stop Exit - } - } - return SUCCESS; } } // namespace ge