diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index 0868b729..2f94c6ad 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -36,6 +36,8 @@ struct DuringPassNodeSets { std::unordered_set nodes_re_pass; std::unordered_set nodes_re_pass_immediately; std::unordered_set nodes_last; + std::unordered_set nodes_suspend; + std::unordered_set nodes_resume; }; void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &input_edge_nodes, @@ -55,8 +57,15 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &i } } +bool IsAllInNodesAlive(const Node::Vistor &nodes, const std::unordered_set &nodes_suspend) { + return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; }); +} + void AddNextIterNodes(const Node::Vistor &nodes, std::deque &nodes_to_pass, - std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { + DuringPassNodeSets &during_pass_node_set) { + auto &nodes_seen = during_pass_node_set.nodes_seen; + const auto &nodes_last = during_pass_node_set.nodes_last; + const auto &nodes_suspend = during_pass_node_set.nodes_suspend; for (auto &node : nodes) { if (node == nullptr) { continue; @@ -64,16 +73,57 @@ void AddNextIterNodes(const Node::Vistor &nodes, std::deque &n if (nodes_last.count(node) != 0) { continue; } + if (nodes_suspend.count(node) > 0) { + GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str()); + continue; + } + bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend); bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); - if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) { + if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) { nodes_to_pass.push_back(node); } } } +void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque &nodes) { + for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { + 
GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); + nodes.push_front(node); + } + during_pass_node_set.nodes_re_pass_immediately.clear(); +} + +void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque &nodes) { + for (auto &node : during_pass_node_set.nodes_resume) { + const auto &it = during_pass_node_set.nodes_suspend.find(node); + if (it != during_pass_node_set.nodes_suspend.end()) { + during_pass_node_set.nodes_suspend.erase(node); + GELOGD("The node %s resumed by pass.", node->GetName().c_str()); + nodes.push_back(node); + } else { + GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str()); + } + } + during_pass_node_set.nodes_resume.clear(); +} + +void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, + const std::unordered_set &nodes_suspend, + const std::unordered_set &nodes_resume) { + for (const auto &node : nodes_suspend) { + GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str()); + during_pass_node_set.nodes_suspend.emplace(node); + } + + for (const auto &node : nodes_resume) { + GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str()); + during_pass_node_set.nodes_resume.emplace(node); + } +} + void PushToRePassIfSeen(NodePtr &node, const std::pair &name_to_pass, - std::unordered_set &nodes_seen, std::unordered_set &nodes_to_re_pass, + std::unordered_set &nodes_seen, const std::unordered_set &nodes_to_re_pass, std::unordered_set &nodes_re_pass) { for (const auto &node_to_re_pass : nodes_to_re_pass) { if (node_to_re_pass == nullptr) { @@ -113,15 +163,18 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo return result; } - auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); + const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); PushToRePassIfSeen(node, name_to_pass, 
during_pass_node_set.nodes_seen, nodes_to_re_pass, during_pass_node_set.nodes_re_pass); - auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); + const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, during_pass_node_set.nodes_re_pass_immediately); - auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); + PushToSuspendNodes(during_pass_node_set, name_to_pass.first, + name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume()); + + const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); if (nodes_deleted_by_pass.count(node) > 0) { GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), @@ -221,8 +274,13 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); continue; } + if (during_pass_node_set.nodes_suspend.count(node) > 0) { + GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", + node->GetName().c_str()); + continue; + } - AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); + AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); auto ret = RunPasses(node, names_to_passes, during_pass_node_set); if (ret != SUCCESS) { @@ -253,11 +311,9 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { // should be called each time at the begin of the iteration ClearOption(names_to_passes); } - for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { - GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); - nodes.push_front(node); 
- } - during_pass_node_set.nodes_re_pass_immediately.clear(); + + AddRepassNodes(during_pass_node_set, nodes); + AddResumeNodes(during_pass_node_set, nodes); } for (auto &node : during_pass_node_set.nodes_last) { diff --git a/ge/graph/passes/base_pass.h b/ge/graph/passes/base_pass.h index a9f4f000..d0f125b2 100644 --- a/ge/graph/passes/base_pass.h +++ b/ge/graph/passes/base_pass.h @@ -51,11 +51,15 @@ class BaseNodePass { virtual ~BaseNodePass() = default; - std::unordered_set GetNodesNeedRePass() { return nodes_need_re_pass_; } + const std::unordered_set &GetNodesNeedRePass() { return nodes_need_re_pass_; } - std::unordered_set GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } + const std::unordered_set &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } - std::unordered_set GetNodesDeleted() { return nodes_deleted_; } + const std::unordered_set &GetNodesDeleted() { return nodes_deleted_; } + + const std::unordered_set &GetNodesSuspend() { return nodes_suspend_; } + + const std::unordered_set &GetNodesResume() { return nodes_resume_; } void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } @@ -65,6 +69,8 @@ class BaseNodePass { nodes_need_re_pass_.clear(); nodes_deleted_.clear(); nodes_need_re_pass_immediately_.clear(); + nodes_suspend_.clear(); + nodes_resume_.clear(); } protected: @@ -80,7 +86,7 @@ class BaseNodePass { /// optimized by other passes, call this function. /// @param node /// - void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); } + void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); } /// /// Add a node to be optimized immediately again. If you add a new node to the graph, or @@ -88,13 +94,13 @@ class BaseNodePass { /// optimized by other passes, call this function. 
/// @param node /// - void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } + void AddImmediateRePassNode(const NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } /// /// Add a node and it's input/output data nodes to be optimized again. /// @param node /// - void AddRePassNodesWithInOut(NodePtr &node) { + void AddRePassNodesWithInOut(const NodePtr &node) { AddRePassNode(node); auto out_nodes = node->GetOutNodes(); for (auto &out_node : out_nodes) { @@ -116,12 +122,34 @@ class BaseNodePass { /// void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } + /// + /// Suspend the iteration of a node, typically a node following the current + /// one. Without this, the remaining passes keep processing the node (if it + /// can be reached by edge connections) till the last one, which is a waste + /// of time. Call this function to add the suspended node, so that the next + /// iterations skip it until it is resumed. + /// @param node + /// + void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } + + /// + /// Resume the iteration of a suspended node, typically a node following the + /// current one. The remaining passes will continue to process the resumed + /// node (if it can be reached by edge connections) till the last one. + /// Call this function to add the resumed node, so that the next iterations + /// process it again. 
+ /// @param node + /// + void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } + bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } private: std::unordered_set nodes_need_re_pass_; std::unordered_set nodes_need_re_pass_immediately_; std::unordered_set nodes_deleted_; + std::unordered_set nodes_suspend_; + std::unordered_set nodes_resume_; std::map options_; }; diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 46026023..cb649240 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -21,6 +21,8 @@ #include "framework/common/util.h" #include "graph/shape_refiner.h" #include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "utils/tensor_utils.h" #include "utils/type_utils.h" @@ -117,7 +119,9 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { const auto RePassNode = [&](const std::set &re_pass_types) { for (auto &n : node->GetOutDataNodes()) { GE_CHECK_NOTNULL(n); - if (re_pass_types.count(n->GetType()) > 0) { + std::string node_type; + GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "Get original node type failed."); + if (re_pass_types.count(node_type) > 0) { AddImmediateRePassNode(n); (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false); GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str()); @@ -126,17 +130,44 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { return SUCCESS; }; - if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) { - return RePassNode({MERGE, REFMERGE}); // Re-Pass Merge + const auto ExProcNode = [&](const std::set &proc_types, + const std::function &proc_func, + const std::string &info) { + for (auto &n : node->GetOutDataNodes()) { + GE_CHECK_NOTNULL(n); + std::string node_type; + 
GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "Get original node type failed."); + if (proc_types.count(node_type) > 0) { + proc_func(this, n); + GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str()); + } + } + return SUCCESS; + }; + + std::string node_type; + GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original node type failed."); + if (kNextIterationOpTypes.count(node_type) > 0) { + return RePassNode(kMergeOpTypes); // Re-Pass Merge } - if (node->GetType() == MERGE || node->GetType() == REFMERGE) { + if (kMergeOpTypes.count(node_type) > 0) { if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); + return RePassNode(kSwitchOpTypes); // Re-Pass Switch } return SUCCESS; } + if (kSwitchOpTypes.count(node_type) > 0) { + if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { + node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); + return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit + } else { + return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit + } + } + return SUCCESS; } } // namespace ge diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index e108dddf..4a1663fc 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -260,6 +260,10 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n return SUCCESS; } + if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity. + node->GetOpDesc()->SetType(IDENTITY); + } + std::unique_ptr new_node; GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor));