@@ -36,7 +36,6 @@ struct DuringPassNodeSets { | |||||
std::unordered_set<NodePtr> nodes_re_pass; | std::unordered_set<NodePtr> nodes_re_pass; | ||||
std::unordered_set<NodePtr> nodes_re_pass_immediately; | std::unordered_set<NodePtr> nodes_re_pass_immediately; | ||||
std::unordered_set<NodePtr> nodes_last; | std::unordered_set<NodePtr> nodes_last; | ||||
std::unordered_set<NodePtr> nodes_stopped; | |||||
}; | }; | ||||
void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, | void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, | ||||
@@ -56,25 +55,8 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i | |||||
} | } | ||||
} | } | ||||
void AddNextIterNodes(const std::vector<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, | |||||
DuringPassNodeSets &during_pass_node_set) { | |||||
for (auto &node : nodes) { | |||||
if (node == nullptr) { | |||||
continue; | |||||
} | |||||
if (during_pass_node_set.nodes_stopped.count(node) > 0) { | |||||
GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
nodes_to_pass.push_back(node); | |||||
} | |||||
} | |||||
void GetNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::vector<NodePtr> &nodes_to_pass, | |||||
DuringPassNodeSets &during_pass_node_set) { | |||||
std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen; | |||||
const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last; | |||||
void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, | |||||
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) { | |||||
for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
if (node == nullptr) { | if (node == nullptr) { | ||||
continue; | continue; | ||||
@@ -90,22 +72,8 @@ void GetNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::vector<NodePtr> & | |||||
} | } | ||||
} | } | ||||
void PushToStoppedNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, | |||||
const std::unordered_set<NodePtr> &nodes_stopped, | |||||
const std::unordered_set<NodePtr> &nodes_restored) { | |||||
for (const auto &node : nodes_stopped) { | |||||
GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), pass_name.c_str()); | |||||
during_pass_node_set.nodes_stopped.emplace(node); | |||||
} | |||||
for (const auto &node : nodes_restored) { | |||||
GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), pass_name.c_str()); | |||||
during_pass_node_set.nodes_stopped.erase(node); | |||||
} | |||||
} | |||||
void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, | void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, | ||||
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass, | |||||
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_to_re_pass, | |||||
std::unordered_set<NodePtr> &nodes_re_pass) { | std::unordered_set<NodePtr> &nodes_re_pass) { | ||||
for (const auto &node_to_re_pass : nodes_to_re_pass) { | for (const auto &node_to_re_pass : nodes_to_re_pass) { | ||||
if (node_to_re_pass == nullptr) { | if (node_to_re_pass == nullptr) { | ||||
@@ -129,8 +97,6 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo | |||||
} | } | ||||
GELOGD("Begin to run pass for node %s", node->GetName().c_str()); | GELOGD("Begin to run pass for node %s", node->GetName().c_str()); | ||||
for (const auto &name_to_pass : names_to_passes) { | for (const auto &name_to_pass : names_to_passes) { | ||||
const std::string &pass_name = name_to_pass.first; | |||||
BaseNodePass *pass_node = name_to_pass.second; | |||||
if (name_to_pass.second == nullptr) { | if (name_to_pass.second == nullptr) { | ||||
GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); | GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); | ||||
continue; | continue; | ||||
@@ -147,17 +113,15 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo | |||||
return result; | return result; | ||||
} | } | ||||
const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); | |||||
auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); | |||||
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, | PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, | ||||
during_pass_node_set.nodes_re_pass); | during_pass_node_set.nodes_re_pass); | ||||
const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); | |||||
auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); | |||||
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, | PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, | ||||
during_pass_node_set.nodes_re_pass_immediately); | during_pass_node_set.nodes_re_pass_immediately); | ||||
PushToStoppedNodes(during_pass_node_set, pass_name, pass_node->GetNodesStopped(), pass_node->GetNodesRestored()); | |||||
const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); | |||||
auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); | |||||
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); | during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); | ||||
if (nodes_deleted_by_pass.count(node) > 0) { | if (nodes_deleted_by_pass.count(node) > 0) { | ||||
GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), | GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), | ||||
@@ -258,8 +222,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | |||||
continue; | continue; | ||||
} | } | ||||
std::vector<NodePtr> nodes_to_pass; | |||||
GetNextIterNodes(node->GetOutNodes(), nodes_to_pass, during_pass_node_set); | |||||
AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); | |||||
auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
@@ -295,8 +258,6 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | |||||
nodes.push_front(node); | nodes.push_front(node); | ||||
} | } | ||||
during_pass_node_set.nodes_re_pass_immediately.clear(); | during_pass_node_set.nodes_re_pass_immediately.clear(); | ||||
AddNextIterNodes(nodes_to_pass, nodes, during_pass_node_set); | |||||
} | } | ||||
for (auto &node : during_pass_node_set.nodes_last) { | for (auto &node : during_pass_node_set.nodes_last) { | ||||
@@ -51,15 +51,11 @@ class BaseNodePass { | |||||
virtual ~BaseNodePass() = default; | virtual ~BaseNodePass() = default; | ||||
const std::unordered_set<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; } | |||||
std::unordered_set<NodePtr> GetNodesNeedRePass() { return nodes_need_re_pass_; } | |||||
const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } | |||||
std::unordered_set<NodePtr> GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } | |||||
const std::unordered_set<NodePtr> &GetNodesDeleted() { return nodes_deleted_; } | |||||
const std::unordered_set<NodePtr> &GetNodesStopped() { return nodes_stopped_; } | |||||
const std::unordered_set<NodePtr> &GetNodesRestored() { return nodes_restored_; } | |||||
std::unordered_set<NodePtr> GetNodesDeleted() { return nodes_deleted_; } | |||||
void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } | void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } | ||||
@@ -69,8 +65,6 @@ class BaseNodePass { | |||||
nodes_need_re_pass_.clear(); | nodes_need_re_pass_.clear(); | ||||
nodes_deleted_.clear(); | nodes_deleted_.clear(); | ||||
nodes_need_re_pass_immediately_.clear(); | nodes_need_re_pass_immediately_.clear(); | ||||
nodes_stopped_.clear(); | |||||
nodes_restored_.clear(); | |||||
} | } | ||||
protected: | protected: | ||||
@@ -86,7 +80,7 @@ class BaseNodePass { | |||||
/// optimized by other passes, call this function. | /// optimized by other passes, call this function. | ||||
/// @param node | /// @param node | ||||
/// | /// | ||||
void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); } | |||||
void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); } | |||||
/// | /// | ||||
/// Add a node to be optimized immediately again. If you add a new node to the graph, or | /// Add a node to be optimized immediately again. If you add a new node to the graph, or | ||||
@@ -94,13 +88,13 @@ class BaseNodePass { | |||||
/// optimized by other passes, call this function. | /// optimized by other passes, call this function. | ||||
/// @param node | /// @param node | ||||
/// | /// | ||||
void AddImmediateRePassNode(const NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } | |||||
void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } | |||||
/// | /// | ||||
/// Add a node and it's input/output data nodes to be optimized again. | /// Add a node and it's input/output data nodes to be optimized again. | ||||
/// @param node | /// @param node | ||||
/// | /// | ||||
void AddRePassNodesWithInOut(const NodePtr &node) { | |||||
void AddRePassNodesWithInOut(NodePtr &node) { | |||||
AddRePassNode(node); | AddRePassNode(node); | ||||
auto out_nodes = node->GetOutNodes(); | auto out_nodes = node->GetOutNodes(); | ||||
for (auto &out_node : out_nodes) { | for (auto &out_node : out_nodes) { | ||||
@@ -122,34 +116,12 @@ class BaseNodePass { | |||||
/// | /// | ||||
void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } | void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } | ||||
/// | |||||
/// If you stop a node from the graph, especially following node. The remain | |||||
/// iterate passes will stop process on the stopped node(if it can be | |||||
/// reached by edge connections) till the last one. Obviously it is a waste of | |||||
/// time. You can add the stopped nodes by calling this function, to stop the | |||||
/// next iterations. | |||||
/// @param node | |||||
/// | |||||
void AddNodeStopped(const NodePtr &node) { nodes_stopped_.insert(node); } | |||||
/// | |||||
/// If you restore a node from the graph, especially following node. The remain | |||||
/// iterate passes will continue process on the stopped node(if it can be | |||||
/// reached by edge connections) till the last one. | |||||
/// You can add the restored nodes by calling this function, to restore the | |||||
/// next iterations. | |||||
/// @param node | |||||
/// | |||||
void AddNodeRestored(const NodePtr &node) { nodes_restored_.insert(node); } | |||||
bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } | bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } | ||||
private: | private: | ||||
std::unordered_set<NodePtr> nodes_need_re_pass_; | std::unordered_set<NodePtr> nodes_need_re_pass_; | ||||
std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; | std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; | ||||
std::unordered_set<NodePtr> nodes_deleted_; | std::unordered_set<NodePtr> nodes_deleted_; | ||||
std::unordered_set<NodePtr> nodes_stopped_; | |||||
std::unordered_set<NodePtr> nodes_restored_; | |||||
std::map<NodePassOption, std::string> options_; | std::map<NodePassOption, std::string> options_; | ||||
}; | }; | ||||
@@ -126,19 +126,6 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
}; | }; | ||||
const auto ExProcNode = [&](const std::set<std::string> &proc_types, | |||||
const std::function<void(InferShapePass *, NodePtr)> &proc_func, | |||||
const std::string &info) { | |||||
for (auto &n : node->GetOutDataNodes()) { | |||||
GE_CHECK_NOTNULL(n); | |||||
if (proc_types.count(n->GetType()) > 0) { | |||||
proc_func(this, n); | |||||
GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str()); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
}; | |||||
if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) { | if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) { | ||||
return RePassNode({MERGE, REFMERGE}); // Re-Pass Merge | return RePassNode({MERGE, REFMERGE}); // Re-Pass Merge | ||||
} | } | ||||
@@ -146,20 +133,10 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||||
if (node->GetType() == MERGE || node->GetType() == REFMERGE) { | if (node->GetType() == MERGE || node->GetType() == REFMERGE) { | ||||
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | ||||
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | ||||
return RePassNode({SWITCH, REFSWITCH}); // Re-Pass Switch | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) { | |||||
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | |||||
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||||
return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeRestored, "need restore"); // Restore Exit | |||||
} else { | |||||
return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeStopped, "need stop"); // Stop Exit | |||||
} | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace ge | } // namespace ge |