@@ -36,7 +36,6 @@ struct DuringPassNodeSets { | |||
std::unordered_set<NodePtr> nodes_re_pass; | |||
std::unordered_set<NodePtr> nodes_re_pass_immediately; | |||
std::unordered_set<NodePtr> nodes_last; | |||
std::unordered_set<NodePtr> nodes_stopped; | |||
}; | |||
void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, | |||
@@ -56,25 +55,8 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i | |||
} | |||
} | |||
void AddNextIterNodes(const std::vector<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, | |||
DuringPassNodeSets &during_pass_node_set) { | |||
for (auto &node : nodes) { | |||
if (node == nullptr) { | |||
continue; | |||
} | |||
if (during_pass_node_set.nodes_stopped.count(node) > 0) { | |||
GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str()); | |||
continue; | |||
} | |||
nodes_to_pass.push_back(node); | |||
} | |||
} | |||
void GetNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::vector<NodePtr> &nodes_to_pass, | |||
DuringPassNodeSets &during_pass_node_set) { | |||
std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen; | |||
const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last; | |||
void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, | |||
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) { | |||
for (auto &node : nodes) { | |||
if (node == nullptr) { | |||
continue; | |||
@@ -90,22 +72,8 @@ void GetNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::vector<NodePtr> & | |||
} | |||
} | |||
void PushToStoppedNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, | |||
const std::unordered_set<NodePtr> &nodes_stopped, | |||
const std::unordered_set<NodePtr> &nodes_restored) { | |||
for (const auto &node : nodes_stopped) { | |||
GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), pass_name.c_str()); | |||
during_pass_node_set.nodes_stopped.emplace(node); | |||
} | |||
for (const auto &node : nodes_restored) { | |||
GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), pass_name.c_str()); | |||
during_pass_node_set.nodes_stopped.erase(node); | |||
} | |||
} | |||
void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, | |||
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass, | |||
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_to_re_pass, | |||
std::unordered_set<NodePtr> &nodes_re_pass) { | |||
for (const auto &node_to_re_pass : nodes_to_re_pass) { | |||
if (node_to_re_pass == nullptr) { | |||
@@ -129,8 +97,6 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo | |||
} | |||
GELOGD("Begin to run pass for node %s", node->GetName().c_str()); | |||
for (const auto &name_to_pass : names_to_passes) { | |||
const std::string &pass_name = name_to_pass.first; | |||
BaseNodePass *pass_node = name_to_pass.second; | |||
if (name_to_pass.second == nullptr) { | |||
GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); | |||
continue; | |||
@@ -147,17 +113,15 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo | |||
return result; | |||
} | |||
const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); | |||
auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); | |||
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, | |||
during_pass_node_set.nodes_re_pass); | |||
const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); | |||
auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); | |||
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, | |||
during_pass_node_set.nodes_re_pass_immediately); | |||
PushToStoppedNodes(during_pass_node_set, pass_name, pass_node->GetNodesStopped(), pass_node->GetNodesRestored()); | |||
const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); | |||
auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); | |||
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); | |||
if (nodes_deleted_by_pass.count(node) > 0) { | |||
GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), | |||
@@ -258,8 +222,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | |||
continue; | |||
} | |||
std::vector<NodePtr> nodes_to_pass; | |||
GetNextIterNodes(node->GetOutNodes(), nodes_to_pass, during_pass_node_set); | |||
AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); | |||
auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | |||
if (ret != SUCCESS) { | |||
@@ -295,8 +258,6 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | |||
nodes.push_front(node); | |||
} | |||
during_pass_node_set.nodes_re_pass_immediately.clear(); | |||
AddNextIterNodes(nodes_to_pass, nodes, during_pass_node_set); | |||
} | |||
for (auto &node : during_pass_node_set.nodes_last) { | |||
@@ -51,15 +51,11 @@ class BaseNodePass { | |||
virtual ~BaseNodePass() = default; | |||
const std::unordered_set<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; } | |||
std::unordered_set<NodePtr> GetNodesNeedRePass() { return nodes_need_re_pass_; } | |||
const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } | |||
std::unordered_set<NodePtr> GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } | |||
const std::unordered_set<NodePtr> &GetNodesDeleted() { return nodes_deleted_; } | |||
const std::unordered_set<NodePtr> &GetNodesStopped() { return nodes_stopped_; } | |||
const std::unordered_set<NodePtr> &GetNodesRestored() { return nodes_restored_; } | |||
std::unordered_set<NodePtr> GetNodesDeleted() { return nodes_deleted_; } | |||
void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } | |||
@@ -69,8 +65,6 @@ class BaseNodePass { | |||
nodes_need_re_pass_.clear(); | |||
nodes_deleted_.clear(); | |||
nodes_need_re_pass_immediately_.clear(); | |||
nodes_stopped_.clear(); | |||
nodes_restored_.clear(); | |||
} | |||
protected: | |||
@@ -86,7 +80,7 @@ class BaseNodePass { | |||
/// optimized by other passes, call this function. | |||
/// @param node | |||
/// | |||
void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); } | |||
void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); } | |||
/// | |||
/// Add a node to be optimized immediately again. If you add a new node to the graph, or | |||
@@ -94,13 +88,13 @@ class BaseNodePass { | |||
/// optimized by other passes, call this function. | |||
/// @param node | |||
/// | |||
void AddImmediateRePassNode(const NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } | |||
void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } | |||
/// | |||
/// Add a node and it's input/output data nodes to be optimized again. | |||
/// @param node | |||
/// | |||
void AddRePassNodesWithInOut(const NodePtr &node) { | |||
void AddRePassNodesWithInOut(NodePtr &node) { | |||
AddRePassNode(node); | |||
auto out_nodes = node->GetOutNodes(); | |||
for (auto &out_node : out_nodes) { | |||
@@ -122,34 +116,12 @@ class BaseNodePass { | |||
/// | |||
void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } | |||
/// | |||
/// If you stop a node from the graph, especially following node. The remain | |||
/// iterate passes will stop process on the stopped node(if it can be | |||
/// reached by edge connections) till the last one. Obviously it is a waste of | |||
/// time. You can add the stopped nodes by calling this function, to stop the | |||
/// next iterations. | |||
/// @param node | |||
/// | |||
void AddNodeStopped(const NodePtr &node) { nodes_stopped_.insert(node); } | |||
/// | |||
/// If you restore a node from the graph, especially following node. The remain | |||
/// iterate passes will continue process on the stopped node(if it can be | |||
/// reached by edge connections) till the last one. | |||
/// You can add the restored nodes by calling this function, to restore the | |||
/// next iterations. | |||
/// @param node | |||
/// | |||
void AddNodeRestored(const NodePtr &node) { nodes_restored_.insert(node); } | |||
bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } | |||
private: | |||
std::unordered_set<NodePtr> nodes_need_re_pass_; | |||
std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; | |||
std::unordered_set<NodePtr> nodes_deleted_; | |||
std::unordered_set<NodePtr> nodes_stopped_; | |||
std::unordered_set<NodePtr> nodes_restored_; | |||
std::map<NodePassOption, std::string> options_; | |||
}; | |||
@@ -126,19 +126,6 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||
return SUCCESS; | |||
}; | |||
const auto ExProcNode = [&](const std::set<std::string> &proc_types, | |||
const std::function<void(InferShapePass *, NodePtr)> &proc_func, | |||
const std::string &info) { | |||
for (auto &n : node->GetOutDataNodes()) { | |||
GE_CHECK_NOTNULL(n); | |||
if (proc_types.count(n->GetType()) > 0) { | |||
proc_func(this, n); | |||
GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str()); | |||
} | |||
} | |||
return SUCCESS; | |||
}; | |||
if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) { | |||
return RePassNode({MERGE, REFMERGE}); // Re-Pass Merge | |||
} | |||
@@ -146,20 +133,10 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||
if (node->GetType() == MERGE || node->GetType() == REFMERGE) { | |||
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | |||
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||
return RePassNode({SWITCH, REFSWITCH}); // Re-Pass Switch | |||
} | |||
return SUCCESS; | |||
} | |||
if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) { | |||
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | |||
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||
return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeRestored, "need restore"); // Restore Exit | |||
} else { | |||
return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeStopped, "need stop"); // Stop Exit | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
} // namespace ge |