| @@ -36,6 +36,8 @@ struct DuringPassNodeSets { | |||||
| std::unordered_set<NodePtr> nodes_re_pass; | std::unordered_set<NodePtr> nodes_re_pass; | ||||
| std::unordered_set<NodePtr> nodes_re_pass_immediately; | std::unordered_set<NodePtr> nodes_re_pass_immediately; | ||||
| std::unordered_set<NodePtr> nodes_last; | std::unordered_set<NodePtr> nodes_last; | ||||
| std::unordered_set<NodePtr> nodes_suspend; | |||||
| std::unordered_set<NodePtr> nodes_resume; | |||||
| }; | }; | ||||
| void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, | void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, | ||||
| @@ -55,8 +57,15 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i | |||||
| } | } | ||||
| } | } | ||||
| bool IsAllInNodesAlive(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_suspend) { | |||||
| return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; }); | |||||
| } | |||||
| void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, | void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, | ||||
| std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) { | |||||
| DuringPassNodeSets &during_pass_node_set) { | |||||
| auto &nodes_seen = during_pass_node_set.nodes_seen; | |||||
| const auto &nodes_last = during_pass_node_set.nodes_last; | |||||
| const auto &nodes_suspend = during_pass_node_set.nodes_suspend; | |||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| continue; | continue; | ||||
| @@ -64,16 +73,57 @@ void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &n | |||||
| if (nodes_last.count(node) != 0) { | if (nodes_last.count(node) != 0) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (nodes_suspend.count(node) > 0) { | |||||
| GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend); | |||||
| bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); | bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); | ||||
| if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) { | |||||
| if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) { | |||||
| nodes_to_pass.push_back(node); | nodes_to_pass.push_back(node); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) { | |||||
| for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { | |||||
| GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); | |||||
| nodes.push_front(node); | |||||
| } | |||||
| during_pass_node_set.nodes_re_pass_immediately.clear(); | |||||
| } | |||||
| void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) { | |||||
| for (auto &node : during_pass_node_set.nodes_resume) { | |||||
| const auto &it = during_pass_node_set.nodes_suspend.find(node); | |||||
| if (it != during_pass_node_set.nodes_suspend.end()) { | |||||
| during_pass_node_set.nodes_suspend.erase(node); | |||||
| GELOGD("The node %s resumed by pass.", node->GetName().c_str()); | |||||
| nodes.push_back(node); | |||||
| } else { | |||||
| GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| during_pass_node_set.nodes_resume.clear(); | |||||
| } | |||||
| void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, | |||||
| const std::unordered_set<NodePtr> &nodes_suspend, | |||||
| const std::unordered_set<NodePtr> &nodes_resume) { | |||||
| for (const auto &node : nodes_suspend) { | |||||
| GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str()); | |||||
| during_pass_node_set.nodes_suspend.emplace(node); | |||||
| } | |||||
| for (const auto &node : nodes_resume) { | |||||
| GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str()); | |||||
| during_pass_node_set.nodes_resume.emplace(node); | |||||
| } | |||||
| } | |||||
| void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, | void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, | ||||
| std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_to_re_pass, | |||||
| std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass, | |||||
| std::unordered_set<NodePtr> &nodes_re_pass) { | std::unordered_set<NodePtr> &nodes_re_pass) { | ||||
| for (const auto &node_to_re_pass : nodes_to_re_pass) { | for (const auto &node_to_re_pass : nodes_to_re_pass) { | ||||
| if (node_to_re_pass == nullptr) { | if (node_to_re_pass == nullptr) { | ||||
| @@ -113,15 +163,18 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo | |||||
| return result; | return result; | ||||
| } | } | ||||
| auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); | |||||
| const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); | |||||
| PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, | PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, | ||||
| during_pass_node_set.nodes_re_pass); | during_pass_node_set.nodes_re_pass); | ||||
| auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); | |||||
| const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); | |||||
| PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, | PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, | ||||
| during_pass_node_set.nodes_re_pass_immediately); | during_pass_node_set.nodes_re_pass_immediately); | ||||
| auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); | |||||
| PushToSuspendNodes(during_pass_node_set, name_to_pass.first, | |||||
| name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume()); | |||||
| const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); | |||||
| during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); | during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); | ||||
| if (nodes_deleted_by_pass.count(node) > 0) { | if (nodes_deleted_by_pass.count(node) > 0) { | ||||
| GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), | GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), | ||||
| @@ -221,8 +274,13 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | |||||
| GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (during_pass_node_set.nodes_suspend.count(node) > 0) { | |||||
| GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
| node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); | |||||
| AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); | |||||
| auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -253,11 +311,9 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | |||||
| // should be called each time at the begin of the iteration | // should be called each time at the begin of the iteration | ||||
| ClearOption(names_to_passes); | ClearOption(names_to_passes); | ||||
| } | } | ||||
| for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { | |||||
| GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); | |||||
| nodes.push_front(node); | |||||
| } | |||||
| during_pass_node_set.nodes_re_pass_immediately.clear(); | |||||
| AddRepassNodes(during_pass_node_set, nodes); | |||||
| AddResumeNodes(during_pass_node_set, nodes); | |||||
| } | } | ||||
| for (auto &node : during_pass_node_set.nodes_last) { | for (auto &node : during_pass_node_set.nodes_last) { | ||||
| @@ -51,11 +51,15 @@ class BaseNodePass { | |||||
| virtual ~BaseNodePass() = default; | virtual ~BaseNodePass() = default; | ||||
| std::unordered_set<NodePtr> GetNodesNeedRePass() { return nodes_need_re_pass_; } | |||||
| const std::unordered_set<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; } | |||||
| std::unordered_set<NodePtr> GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } | |||||
| const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } | |||||
| std::unordered_set<NodePtr> GetNodesDeleted() { return nodes_deleted_; } | |||||
| const std::unordered_set<NodePtr> &GetNodesDeleted() { return nodes_deleted_; } | |||||
| const std::unordered_set<NodePtr> &GetNodesSuspend() { return nodes_suspend_; } | |||||
| const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; } | |||||
| void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } | void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } | ||||
| @@ -65,6 +69,8 @@ class BaseNodePass { | |||||
| nodes_need_re_pass_.clear(); | nodes_need_re_pass_.clear(); | ||||
| nodes_deleted_.clear(); | nodes_deleted_.clear(); | ||||
| nodes_need_re_pass_immediately_.clear(); | nodes_need_re_pass_immediately_.clear(); | ||||
| nodes_suspend_.clear(); | |||||
| nodes_resume_.clear(); | |||||
| } | } | ||||
| protected: | protected: | ||||
| @@ -80,7 +86,7 @@ class BaseNodePass { | |||||
| /// optimized by other passes, call this function. | /// optimized by other passes, call this function. | ||||
| /// @param node | /// @param node | ||||
| /// | /// | ||||
| void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); } | |||||
| void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); } | |||||
| /// | /// | ||||
| /// Add a node to be optimized immediately again. If you add a new node to the graph, or | /// Add a node to be optimized immediately again. If you add a new node to the graph, or | ||||
| @@ -88,13 +94,13 @@ class BaseNodePass { | |||||
| /// optimized by other passes, call this function. | /// optimized by other passes, call this function. | ||||
| /// @param node | /// @param node | ||||
| /// | /// | ||||
| void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } | |||||
| void AddImmediateRePassNode(const NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } | |||||
| /// | /// | ||||
| /// Add a node and it's input/output data nodes to be optimized again. | /// Add a node and it's input/output data nodes to be optimized again. | ||||
| /// @param node | /// @param node | ||||
| /// | /// | ||||
| void AddRePassNodesWithInOut(NodePtr &node) { | |||||
| void AddRePassNodesWithInOut(const NodePtr &node) { | |||||
| AddRePassNode(node); | AddRePassNode(node); | ||||
| auto out_nodes = node->GetOutNodes(); | auto out_nodes = node->GetOutNodes(); | ||||
| for (auto &out_node : out_nodes) { | for (auto &out_node : out_nodes) { | ||||
| @@ -116,12 +122,34 @@ class BaseNodePass { | |||||
| /// | /// | ||||
| void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } | void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } | ||||
| /// | |||||
| /// If you suspend a node from the graph, especially following node. The remain | |||||
| /// iterate passes will stop process on the suspend node(if it can be | |||||
| /// reached by edge connections) till the last one. Obviously it is a waste of | |||||
| /// time. You can add the suspend nodes by calling this function, to stop the | |||||
| /// next iterations. | |||||
| /// @param node | |||||
| /// | |||||
| void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } | |||||
| /// | |||||
| /// If you resume a node from the graph, especially following node. The remain | |||||
| /// iterate passes will continue process on the resume node(if it can be | |||||
| /// reached by edge connections) till the last one. | |||||
| /// You can add the resume nodes by calling this function, to resume the | |||||
| /// next iterations. | |||||
| /// @param node | |||||
| /// | |||||
| void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } | |||||
| bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } | bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } | ||||
| private: | private: | ||||
| std::unordered_set<NodePtr> nodes_need_re_pass_; | std::unordered_set<NodePtr> nodes_need_re_pass_; | ||||
| std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; | std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; | ||||
| std::unordered_set<NodePtr> nodes_deleted_; | std::unordered_set<NodePtr> nodes_deleted_; | ||||
| std::unordered_set<NodePtr> nodes_suspend_; | |||||
| std::unordered_set<NodePtr> nodes_resume_; | |||||
| std::map<NodePassOption, std::string> options_; | std::map<NodePassOption, std::string> options_; | ||||
| }; | }; | ||||
| @@ -21,6 +21,8 @@ | |||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| #include "graph/shape_refiner.h" | #include "graph/shape_refiner.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/common/omg_util.h" | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "utils/tensor_utils.h" | #include "utils/tensor_utils.h" | ||||
| #include "utils/type_utils.h" | #include "utils/type_utils.h" | ||||
| @@ -117,7 +119,9 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||||
| const auto RePassNode = [&](const std::set<std::string> &re_pass_types) { | const auto RePassNode = [&](const std::set<std::string> &re_pass_types) { | ||||
| for (auto &n : node->GetOutDataNodes()) { | for (auto &n : node->GetOutDataNodes()) { | ||||
| GE_CHECK_NOTNULL(n); | GE_CHECK_NOTNULL(n); | ||||
| if (re_pass_types.count(n->GetType()) > 0) { | |||||
| std::string node_type; | |||||
| GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "Get original node type failed."); | |||||
| if (re_pass_types.count(node_type) > 0) { | |||||
| AddImmediateRePassNode(n); | AddImmediateRePassNode(n); | ||||
| (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false); | (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false); | ||||
| GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str()); | GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str()); | ||||
| @@ -126,17 +130,44 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| }; | }; | ||||
| if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) { | |||||
| return RePassNode({MERGE, REFMERGE}); // Re-Pass Merge | |||||
| const auto ExProcNode = [&](const std::set<std::string> &proc_types, | |||||
| const std::function<void(InferShapePass *, NodePtr)> &proc_func, | |||||
| const std::string &info) { | |||||
| for (auto &n : node->GetOutDataNodes()) { | |||||
| GE_CHECK_NOTNULL(n); | |||||
| std::string node_type; | |||||
| GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "Get original node type failed."); | |||||
| if (proc_types.count(node_type) > 0) { | |||||
| proc_func(this, n); | |||||
| GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| }; | |||||
| std::string node_type; | |||||
| GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original node type failed."); | |||||
| if (kNextIterationOpTypes.count(node_type) > 0) { | |||||
| return RePassNode(kMergeOpTypes); // Re-Pass Merge | |||||
| } | } | ||||
| if (node->GetType() == MERGE || node->GetType() == REFMERGE) { | |||||
| if (kMergeOpTypes.count(node_type) > 0) { | |||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | ||||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | ||||
| return RePassNode(kSwitchOpTypes); // Re-Pass Switch | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (kSwitchOpTypes.count(node_type) > 0) { | |||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | |||||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit | |||||
| } else { | |||||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit | |||||
| } | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -260,6 +260,10 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity. | |||||
| node->GetOpDesc()->SetType(IDENTITY); | |||||
| } | |||||
| std::unique_ptr<NodeItem> new_node; | std::unique_ptr<NodeItem> new_node; | ||||
| GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); | ||||
| GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); | GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); | ||||