diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index a1551eb2..b74eb5cd 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -19,9 +19,7 @@ #include #include -#include "framework/common/debug/log.h" -#include "framework/common/debug/ge_log.h" -#include "graph/compute_graph.h" +#include "common/debug/log.h" #include "graph/utils/graph_utils.h" namespace ge { @@ -30,47 +28,35 @@ constexpr int kMaxRePassTimes = 10000; constexpr size_t kMaxOneInNodes = 1000; // Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later constexpr int kMaxRecursiveDepth = 20; -struct DuringPassNodeSets { - std::unordered_set nodes_seen; - std::unordered_set nodes_deleted; - std::unordered_set nodes_re_pass; - std::unordered_set nodes_re_pass_immediately; - std::unordered_set nodes_last; - std::unordered_set nodes_suspend; - std::unordered_set nodes_resume; -}; - -void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &input_edge_nodes, - std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { - nodes_last.clear(); + +void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, + GEPass::GraphLevelState &g_state) { for (auto &node : graph->GetDirectNode()) { if (node == nullptr) { continue; } size_t in_nums = node->GetInNodes().size(); if (in_nums == 0) { - input_edge_nodes.push_back(node); - nodes_seen.insert(node.get()); + g_state.AddNodeToQueueIfNotSeen(node); } else if (in_nums > kMaxOneInNodes) { - nodes_last.insert(node); + g_state.nodes_last.insert(node); } } } -bool IsAllInNodesAlive(const Node::Vistor &nodes, const std::unordered_set &nodes_suspend) { - return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; }); +bool AllNodesIn(const Node::Vistor &nodes, const std::unordered_set &nodes_set) { + return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { + return nodes_set.count(n) > 0; + }); } -void AddNextIterNodes(const Node::Vistor &nodes, std::deque &nodes_to_pass, - DuringPassNodeSets &during_pass_node_set) { - auto &nodes_seen = during_pass_node_set.nodes_seen; - const auto &nodes_last = during_pass_node_set.nodes_last; - const auto &nodes_suspend = during_pass_node_set.nodes_suspend; - for (auto &node : nodes) { +void AddNextIterNodes(const NodePtr &cur_node, GEPass::GraphLevelState &g_state) { + const auto &nodes_suspend = g_state.nodes_suspend; + for (auto &node : cur_node->GetOutNodes()) { if (node == nullptr) { continue; } - if (nodes_last.count(node) != 0) { + if (g_state.nodes_last.count(node) != 0) { continue; } if (nodes_suspend.count(node) > 0) { @@ -78,47 +64,64 @@ void AddNextIterNodes(const Node::Vistor &nodes, std::deque &n continue; } - bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend); - bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); - if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) { - nodes_to_pass.push_back(node); + if (node->IsAllInNodesSeen(g_state.nodes_seen) && AllNodesIn(node->GetInAllNodes(), nodes_suspend)) { + g_state.AddNodeToQueueIfNotSeen(node); } } } -void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque &nodes) { - for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { - GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); - nodes.push_front(node); +void AddImmediateRepassNodesToQueue(NodePtr &cur_node, const std::pair &name_to_pass, + const std::unordered_set &nodes_im_re_pass, + GEPass::GraphLevelState &g_state) { + for (const auto &node : nodes_im_re_pass) { + if (node == nullptr) { + GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s", name_to_pass.first.c_str(), + cur_node->GetName().c_str(), cur_node->GetType().c_str()); + continue; + } + if (g_state.nodes_passed.count(node) > 0) { + g_state.AddNodeToQueueFront(node); + continue; + } + // exp: constant folding add new const need repass immediate + if (AllNodesIn(node->GetInAllNodes(), g_state.nodes_passed)) { + g_state.AddNodeToQueueFront(node); + continue; + } + GELOGW("The node %s specified by pass %s has un-passed in_nodes, it will not repass immediately", + node->GetName().c_str(), name_to_pass.first.c_str()); } - during_pass_node_set.nodes_re_pass_immediately.clear(); } -void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque &nodes) { - for (auto &node : during_pass_node_set.nodes_resume) { - const auto &it = during_pass_node_set.nodes_suspend.find(node); - if (it != during_pass_node_set.nodes_suspend.end()) { - during_pass_node_set.nodes_suspend.erase(node); - GELOGD("The node %s resumed by pass.", node->GetName().c_str()); - nodes.push_back(node); - } else { - GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str()); +void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) { + for (auto &node : g_state.nodes_last) { + // todo 为什么会在node_seen中看到node_last,blame一下看看历史合入记录 + if (node->IsAllInNodesSeen(g_state.nodes_seen)) { + g_state.AddNodeToQueueIfNotSeen(node); } } - during_pass_node_set.nodes_resume.clear(); + g_state.nodes_last.clear(); } -void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, - const std::unordered_set &nodes_suspend, - const std::unordered_set &nodes_resume) { +void SuspendAndResume(const std::string &pass_name, + const std::unordered_set &nodes_suspend, + const std::unordered_set &nodes_resume, + GEPass::GraphLevelState &g_state) { + // TODO 当前没有记录NodePass中suspend和resume的顺序,因此无法辨别NodePass中是先做Suspend还是Resume。 + // 因此此处的简单处理是如果在NodePass的过程中,触发了suspend/resume,那么框架以resume为准 + // 更好的处理方式是,在NodePass做suspend/resume时,做顺序的记录,在此函数中按序做回放 for (const auto &node : nodes_suspend) { GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str()); - during_pass_node_set.nodes_suspend.emplace(node); + g_state.nodes_suspend.insert(node); } for (const auto &node : nodes_resume) { - GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str()); - during_pass_node_set.nodes_resume.emplace(node); + if (g_state.nodes_suspend.erase(node) > 0) { + if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) { + g_state.nodes.push_back(node); + GELOGD("Node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str()); + } + } } } @@ -140,54 +143,6 @@ void PushToRePassIfSeen(NodePtr &node, const std::pairGetName().c_str()); - for (const auto &name_to_pass : names_to_passes) { - if (name_to_pass.second == nullptr) { - GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); - continue; - } - - GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str()); - name_to_pass.second->init(); - auto result = name_to_pass.second->Run(node); - if (result != SUCCESS) { - REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", - name_to_pass.first.c_str(), node->GetName().c_str(), result); - GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result " - "%u, the passes will be terminated immediately.", - name_to_pass.first.c_str(), node->GetName().c_str(), result); - return result; - } - - const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); - PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, - during_pass_node_set.nodes_re_pass); - - const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); - PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, - during_pass_node_set.nodes_re_pass_immediately); - - PushToSuspendNodes(during_pass_node_set, name_to_pass.first, - name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume()); - - const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); - during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); - if (nodes_deleted_by_pass.count(node) > 0) { - GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), - name_to_pass.first.c_str()); - break; - } - } - - return SUCCESS; -} - void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) { for (auto &name_to_pass : names_to_pass) { name_to_pass.second->SetOption(option, ""); @@ -200,26 +155,39 @@ void ClearOption(NamesToPass names_to_pass) { } } -bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) { +bool ShouldNodePassActually(const NodePtr &node, const GEPass::GraphLevelState &g_state) { if (node == nullptr) { GELOGW("node is null"); return false; } - if (during_pass_node_set.nodes_deleted.count(node) > 0) { + // 因为在PassNode之前,会首先将node的输出节点添加queue,因此若在pass node时,删除了node的输出节点, + // 那么会出现:已经删除的节点出现在queue中,并且被pop出来,因此这里做确认,如果node已经被删除过了,就跳过pass + if (g_state.nodes_deleted.count(node) > 0) { GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); return false; } - if (during_pass_node_set.nodes_suspend.count(node) > 0) { + + // 因为在PassNode之前,会首先将node的输出节点添加queue,因此若在pass node时,suspend了node的输出节点,后续逻辑与上面相同 + // TODO 需要注意的是,这里的保证是一次”尽力而为“,若pass node时,将node之前的节点`A`添加到了suspend, + // 那么`A`节点的后继和间接后继节点的pass不会受到suspend的影响 + // 理论上来说,如果在pass node之前,首先收集node的输出节点,在pass后,将输出节点做suspend、delete的去除,然后加queue, + // 这样处理就不需要在这里做额外的确认了 + if (g_state.nodes_suspend.count(node) > 0) { + GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", + node->GetName().c_str()); + return false; + } + if (!AllNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) { GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", node->GetName().c_str()); return false; } - return true; } } // namespace -Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map) { +Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map, + bool is_repass_io_immediately) { if (node == nullptr) { REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); @@ -235,7 +203,7 @@ Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector return FAILED; } - AddRePassNodesWithInOut(node); + is_repass_io_immediately ? AddImmediateRePassNodesWithInOut(node) : AddRePassNodesWithInOut(node); if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) { REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str()); @@ -263,6 +231,12 @@ Status GEPass::Run(const NamesToPass &names_to_passes) { GELOGW("No passes input, the GEPass will do nothing"); return INTERNAL_ERROR; } + for (const auto &name_to_pass : names_to_passes) { + if (name_to_pass.second == nullptr) { + GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s)", name_to_pass.first.c_str()); + return INTERNAL_ERROR; + } + } if (depth_ > kMaxRecursiveDepth) { GELOGE(PARAM_INVALID, @@ -275,81 +249,93 @@ Status GEPass::Run(const NamesToPass &names_to_passes) { return RunPassesOneGraph(names_to_passes); } -Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { - GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); - std::deque nodes; - DuringPassNodeSets during_pass_node_set; - GetAllNodesNoInputEdge(graph_, nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); - GELOGD("Start points count %zu", nodes.size()); - int re_pass_times = 0; +void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names_to_pass) { + for (auto &name_to_pass : names_to_pass) { + name_to_pass.second->OnStartPassGraph(graph); + } +} - do { - for (auto &node : during_pass_node_set.nodes_re_pass) { - nodes.push_back(node); - during_pass_node_set.nodes_seen.insert(node.get()); +Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) { + for (auto &name_to_pass : names_to_passes) { + name_to_pass.second->init(); + auto ret = name_to_pass.second->OnSuspendNodesLeaked(); + if (ret != SUCCESS) { + // todo error + return ret; } - during_pass_node_set.nodes_re_pass.clear(); - - while (!nodes.empty()) { - NodePtr node = nodes.front(); - nodes.pop_front(); + SuspendAndResume(name_to_pass.first, + name_to_pass.second->GetNodesSuspend(), + name_to_pass.second->GetNodesResume(), + g_state); + } + return SUCCESS; +} - (void)during_pass_node_set.nodes_re_pass.erase(node); - if (!CheckNode(node, during_pass_node_set)) { - continue; - } - AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); +Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { + GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); + NotifyPassGraphStart(graph_, names_to_passes); + GraphLevelState g_state; + g_state.re_pass_times = 0; + GetAllNodesNoInputEdge(graph_, g_state); + GELOGD("Start points count %zu", g_state.nodes.size()); - auto ret = RunPasses(node, names_to_passes, during_pass_node_set); + do { + if (!g_state.nodes_suspend.empty()) { + auto ret = HandleLeakedSuspendNodes(names_to_passes, g_state); if (ret != SUCCESS) { - GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", - node->GetName().c_str(), node->GetType().c_str(), ret); + // todo log return ret; } - - bool has_sub_graph = false; - ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph); - if (ret != SUCCESS) { - GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str()); - return ret; + if (g_state.nodes.empty()) { + // todo 报错,因为suspend泄露场景,没有子类做进一步的resume,此处可能已经彻底泄露,需要报错 + return INTERNAL_ERROR; } + } + auto ret = RunPassesGraphRepass(names_to_passes, g_state); + if (ret != SUCCESS) { + return ret; + } + } while (!g_state.nodes_suspend.empty()); + + return SUCCESS; +} - if (has_sub_graph) { - GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); - SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); - ret = RunPasses(node, names_to_passes, during_pass_node_set); - if (ret != SUCCESS) { - GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", - node->GetName().c_str(), node->GetType().c_str(), ret); - return ret; - } - - // There is only one option scene, so set and clear options around the `RunPasses` func. - // if there are more than one scene to set options, the `ClearOption` function - // should be called each time at the begin of the iteration - ClearOption(names_to_passes); - } - AddRepassNodes(during_pass_node_set, nodes); - AddResumeNodes(during_pass_node_set, nodes); +Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state) { + RepassLevelState rp_state; + do { + for (auto &node : rp_state.nodes_re_pass) { + g_state.AddNodeToQueue(node); } + rp_state.nodes_re_pass.clear(); + + while (!g_state.nodes.empty()) { + auto node = g_state.PopFront(); - for (auto &node : during_pass_node_set.nodes_last) { - bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen); - if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) { - nodes.push_back(node); + (void)rp_state.nodes_re_pass.erase(node); // todo 回忆一下为什么 + if (!ShouldNodePassActually(node, g_state)) { + continue; + } + g_state.nodes_seen.insert(node.get()); // todo 为什么这里seen + AddNextIterNodes(node, g_state); + + auto ret = RunPassesNodeOnce(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; } } - during_pass_node_set.nodes_last.clear(); - } while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); + AddLastNodesToQueue(g_state); + } while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes); - if (re_pass_times == kMaxRePassTimes) { + if (g_state.re_pass_times == kMaxRePassTimes) { GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); } GELOGD("All passes runs end"); - return SUCCESS; } + Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); has_sub_graph = false; @@ -371,4 +357,87 @@ Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names } return SUCCESS; } + +Status GEPass::RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes, + GraphLevelState &g_state, RepassLevelState &rp_state) { + auto ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; + } + + bool has_sub_graph = false; + ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str()); + return ret; + } + + if (has_sub_graph) { + GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); + SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); + ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; + } + + // There is only one option scene, so set and clear options around the `RunPasses` func. + // if there are more than one scene to set options, the `ClearOption` function + // should be called each time at the begin of the iteration + ClearOption(names_to_passes); + } + return SUCCESS; +} + +Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state, + RepassLevelState &rp_state) { + if (node == nullptr) { + REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); + GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); + return FAILED; + } + GELOGD("Begin to run pass for node %s", node->GetName().c_str()); + for (const auto &name_to_pass : names_to_passes) { + GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str()); + name_to_pass.second->init(); + auto result = name_to_pass.second->Run(node); + if (result != SUCCESS) { + REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", + name_to_pass.first.c_str(), node->GetName().c_str(), result); + GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result " + "%u, the passes will be terminated immediately.", + name_to_pass.first.c_str(), node->GetName().c_str(), result); + return result; + } + if (name_to_pass.second->GetNodesDeleted().count(node) > 0) { + GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), + name_to_pass.first.c_str()); + break; + } + } + + g_state.nodes_passed.insert(node); + + for (const auto &name_to_pass : names_to_passes) { + PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen, + name_to_pass.second->GetNodesNeedRePass(), + rp_state.nodes_re_pass); + + AddImmediateRepassNodesToQueue(node, name_to_pass, + name_to_pass.second->GetNodesNeedRePassImmediately(), + g_state); + SuspendAndResume(name_to_pass.first, + name_to_pass.second->GetNodesSuspend(), + name_to_pass.second->GetNodesResume(), + g_state); + + const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); + g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); + } + + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/passes/base_pass.h b/ge/graph/passes/base_pass.h index d0f125b2..5d744c08 100644 --- a/ge/graph/passes/base_pass.h +++ b/ge/graph/passes/base_pass.h @@ -22,7 +22,6 @@ #include #include #include - #include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "graph/compute_graph.h" @@ -61,23 +60,32 @@ class BaseNodePass { const std::unordered_set &GetNodesResume() { return nodes_resume_; } + virtual Status OnSuspendNodesLeaked() { return SUCCESS; } + void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } void ClearOptions() { options_.clear(); } void init() { nodes_need_re_pass_.clear(); - nodes_deleted_.clear(); nodes_need_re_pass_immediately_.clear(); + nodes_deleted_.clear(); nodes_suspend_.clear(); nodes_resume_.clear(); } + virtual void OnStartPassGraph(const ComputeGraphPtr &graph) { + current_graph_name_ = graph->GetName(); + } + protected: - Status IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map); + const string &GetCurrentGraphName() const { + return current_graph_name_; + } + Status IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map, bool is_repass_io_immediately = false); - Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list &io_map) { - return IsolateAndDeleteNode(node, std::vector(io_map)); + Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list &io_map, bool is_repass_io_immediately = false) { + return IsolateAndDeleteNode(node, std::vector(io_map), is_repass_io_immediately); } /// @@ -112,6 +120,22 @@ class BaseNodePass { } } + /// + /// Add a node and it's input/output data nodes to be optimized immediately again. + /// @param node + /// + void AddImmediateRePassNodesWithInOut(const NodePtr &node) { + auto in_nodes = node->GetInNodes(); + for (auto &in_node : in_nodes) { + AddImmediateRePassNode(in_node); + } + AddImmediateRePassNode(node); + auto out_nodes = node->GetOutNodes(); + for (auto &out_node : out_nodes) { + AddImmediateRePassNode(out_node); + } + } + /// /// If you deleted a node from the graph, especially current node. The remain /// iterate passes will continue process on the deleted node(if it can be @@ -123,23 +147,15 @@ class BaseNodePass { void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } /// - /// If you suspend a node from the graph, especially following node. The remain - /// iterate passes will stop process on the suspend node(if it can be + /// If you postpone a node from the graph, especially following node. The remain + /// iterate passes will stop process on the postpone node(if it can be /// reached by edge connections) till the last one. Obviously it is a waste of - /// time. You can add the suspend nodes by calling this function, to stop the + /// time. You can add the postpone nodes by calling this function, to stop the /// next iterations. /// @param node /// void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } - /// - /// If you resume a node from the graph, especially following node. The remain - /// iterate passes will continue process on the resume node(if it can be - /// reached by edge connections) till the last one. - /// You can add the resume nodes by calling this function, to resume the - /// next iterations. - /// @param node - /// void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } @@ -151,6 +167,7 @@ class BaseNodePass { std::unordered_set nodes_suspend_; std::unordered_set nodes_resume_; std::map options_; + std::string current_graph_name_; }; using NamesToPass = std::vector>; @@ -160,12 +177,60 @@ class GEPass { explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {} virtual ~GEPass() = default; Status Run(const NamesToPass &names_to_passes); + /* +* todo +* OneGraph: nodes_deleted, nodes_seen, nodes_passed, nodes_suspended +* RePass: nodes_re_pass +* GraphOneTime: nodes_last +* NodeOneTime: nodes_re_pass_immediately, nodes_resume +*/ + struct GraphLevelState { + std::unordered_set nodes_deleted; + std::unordered_set nodes_seen; + std::unordered_set nodes_passed; + std::unordered_set nodes_suspend; + std::unordered_set nodes_last; + std::deque nodes; + int re_pass_times; + + void AddNodeToQueueFront(NodePtr node) { + nodes_seen.insert(node.get()); + nodes.emplace_front(std::move(node)); + } + + void AddNodeToQueue(NodePtr node) { + nodes_seen.insert(node.get()); + nodes.emplace_back(std::move(node)); + } + void AddNodeToQueueIfNotSeen(NodePtr node) { + if (nodes_seen.insert(node.get()).second) { + nodes.emplace_back(std::move(node)); + } + } + NodePtr PopFront() { + NodePtr node = nodes.front(); + nodes.pop_front(); + return node; + } + }; + struct RepassLevelState { + std::unordered_set nodes_re_pass; + }; + struct GraphOneTimeLevelState { + std::unordered_set nodes_last; + }; private: GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth) : graph_(graph), root_graph_(root_graph), depth_(depth) {} + Status RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes, + GraphLevelState &g_state, RepassLevelState &rp_state); + Status RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state); Status RunPassesOneGraph(const NamesToPass &names_to_passes); Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph); + Status RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state, + RepassLevelState &rp_state); + Status HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state); ComputeGraphPtr graph_; ComputeGraphPtr root_graph_; int depth_; diff --git a/ge/graph/passes/folding_pass.cc b/ge/graph/passes/folding_pass.cc index 819c3b40..d586a57b 100755 --- a/ge/graph/passes/folding_pass.cc +++ b/ge/graph/passes/folding_pass.cc @@ -363,7 +363,7 @@ Status FoldingPass::ConnectNodeToInAnchor(InDataAnchorPtr &in_anchor, NodePtr &n in_anchor->GetIdx()); return INTERNAL_ERROR; } - AddRePassNodesWithInOut(node); + AddImmediateRePassNode(node); return SUCCESS; } } // namespace ge diff --git a/ge/graph/passes/infer_base_pass.h b/ge/graph/passes/infer_base_pass.h index 3900b5db..911d7fc7 100644 --- a/ge/graph/passes/infer_base_pass.h +++ b/ge/graph/passes/infer_base_pass.h @@ -36,7 +36,7 @@ class InferBasePass : public BaseNodePass { * @param dst, output TensorDesc to be updated * @return */ - virtual graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0; + virtual graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0; /** * Update the output TensorDesc for nodes which contain subgraphs. diff --git a/ge/graph/passes/infer_value_range_pass.cc b/ge/graph/passes/infer_value_range_pass.cc index e714e90a..1d305fbb 100644 --- a/ge/graph/passes/infer_value_range_pass.cc +++ b/ge/graph/passes/infer_value_range_pass.cc @@ -207,7 +207,7 @@ bool InferValueRangePass::InputHasUnknownValueRange(const NodePtr &node) const { return has_unknown_value_range; } -graphStatus InferValueRangePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { +graphStatus InferValueRangePass::UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { if (src == nullptr || dst == nullptr) { REPORT_CALL_ERROR("E19999", "While updating tensor desc, input desc is null."); GELOGE(GRAPH_FAILED, "[Param][check] While updating tensor desc, input desc is null."); diff --git a/ge/graph/passes/infer_value_range_pass.h b/ge/graph/passes/infer_value_range_pass.h index eb485c87..cc971d17 100644 --- a/ge/graph/passes/infer_value_range_pass.h +++ b/ge/graph/passes/infer_value_range_pass.h @@ -26,7 +26,7 @@ class InferValueRangePass : public InferBasePass { private: std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override; - graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; + graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; graphStatus UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) override; graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, GeTensorDescPtr &dst) override; diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 60a2f09a..58ee0348 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -1,6 +1,6 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd - * + * Copyright 2020-2021 Huawei Technologies Co., Ltd + *+ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -22,13 +22,16 @@ #include "graph/shape_refiner.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" -#include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -namespace ge { +#include "external/graph/operator_factory.h" +namespace ge { +namespace { +constexpr int kSwitchExitAnchorIndex = 0; +constexpr int kSwitchPredAnchorIndex = 1; void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { desc_str += "["; std::vector> shape_range; @@ -47,129 +50,280 @@ void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { desc_str += "},"; } } +void UpdateShapeAndDType(const GeTensorDescPtr &src, GeTensorDescPtr &dst) { + dst->SetOriginShape(src->GetOriginShape()); + dst->SetShape(src->GetShape()); + dst->SetDataType(src->GetDataType()); + dst->SetOriginDataType(src->GetOriginDataType()); + vector> src_shape_range; + src->GetShapeRange(src_shape_range); + dst->SetShapeRange(src_shape_range); + dst->SetOriginShapeRange(src_shape_range); + ge::TensorUtils::SetRealDimCnt(*dst, static_cast(src->GetShape().GetDims().size())); +} +} // namespace -std::string GetInTensorInfoWithString(const ge::NodePtr &node) { - ge::OpDescPtr op_desc = node->GetOpDesc(); +std::string InferShapePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const { std::stringstream ss; - ss << "{"; - int32_t in_idx = 0; - for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { - if (input_desc == nullptr) { - in_idx++; - continue; - } - if (in_idx > 0) { - ss << " "; - } - ss << "input_" << in_idx << " " << "tensor: ["; - ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; - ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; - ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; - ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),"; - ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),"; - ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),"; - string range_str; - SerialShapeRange(input_desc, range_str); - ss << "(shape_range:" << range_str << ")]"; - in_idx++; - } + ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),"; + ss << "(format:" << TypeUtils::FormatToSerialString(tensor_desc->GetFormat()) << "),"; + ss << "(dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()) << "),"; + ss << "(origin_shape:" << tensor_desc->GetOriginShape().ToString() << "),"; + ss << "(origin_format:" << TypeUtils::FormatToSerialString(tensor_desc->GetOriginFormat()) << "),"; + ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetOriginDataType()) << "),"; + string range_str; + SerialShapeRange(tensor_desc, range_str); + ss << "(shape_range:" << range_str << ")"; return ss.str(); } +Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) { + if (node->GetType() != SWITCH) { + return SUCCESS; + } + auto pred_node = NodeUtils::GetInDataNodeByIndex(*node, kSwitchPredAnchorIndex); + GE_CHECK_NOTNULL(pred_node); + if (pred_node->GetType() != LOOPCOND) { + return SUCCESS; + } -Status InferShapePass::Run(NodePtr &node) { - // kOptimizeAfterSubGraph exist means after subgraph - auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); - if (ret != GRAPH_SUCCESS) { - // select INFERSHAPE failed info - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - auto root_graph = ge::GraphUtils::FindRootGraph(graph); - GE_CHECK_NOTNULL(root_graph); - analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), - analyzer::INFER_SHAPE, node, "InferShapeFailed!"}; - (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); - (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), - root_graph->GetGraphID()); - - REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s", - node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); - GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s", - node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); - return GE_GRAPH_INFERSHAPE_FAILED; - } - - GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node)); - bool need_repass = false; - auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass); - if (has_attr) { - if (!OptionExists(kOptimizeAfterSubGraph)) { - return SUCCESS; - } - if (need_repass) { - AddImmediateRePassNode(node); - GELOGD("Node %s need repass immediately.", node->GetName().c_str()); - } else { - // clear attr on while - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); + for (const auto &anchor_2_node : NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kSwitchExitAnchorIndex)) { + GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", anchor_2_node.second->GetName().c_str(), + anchor_2_node.second->GetType().c_str()); + auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName()); + auto &suspend_nodes = graphs_2_suspend_nodes_[GetCurrentGraphName()]; + if (suspend_nodes.nodes_set.insert(anchor_2_node.second).second) { + suspend_nodes.nodes.push(anchor_2_node.second); + AddNodeSuspend(anchor_2_node.second); } } return SUCCESS; } -Status InferShapePass::RePassLoopNode(const NodePtr &node) { - const auto RePassNode = [&](const std::set &re_pass_types) { - for (auto &n : node->GetOutDataNodes()) { - GE_CHECK_NOTNULL(n); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str()); - if (re_pass_types.count(node_type) > 0) { - AddImmediateRePassNode(n); - (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false); - GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str()); +Status InferShapePass::Infer(NodePtr &node) { + auto ret = SuspendV1LoopExitNodes(node); + if (ret != SUCCESS) { + //todo LOG + return ret; + } + bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + auto opdesc = node->GetOpDesc(); + if (node->Verify() != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + Operator op = OpDescUtils::CreateOperatorFromNode(node); + + if (!is_unknown_graph) { + auto inference_context = ShapeRefiner::CreateInferenceContext(node); + GE_CHECK_NOTNULL(inference_context); + GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); + op.SetInferenceContext(inference_context); + } + + graphStatus status = CallInferShapeFunc(node, op); + if (status != GRAPH_NODE_NEED_REPASS && status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) { + // node like netoutput return param_invalid, but valid ? + REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str()); + return GRAPH_FAILED; + } + if (!is_unknown_graph) { + auto ctx_after_infer = op.GetInferenceContext(); + if (ctx_after_infer != nullptr) { + GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); + if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { + GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), + ctx_after_infer->GetMarks().size()); + ShapeRefiner::PushToContextMap(node, ctx_after_infer); } } - return SUCCESS; - }; - - const auto ExProcNode = [&](const std::set &proc_types, - const std::function &proc_func, - const std::string &info) { - for (auto &n : node->GetOutDataNodes()) { - GE_CHECK_NOTNULL(n); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str()); - if (proc_types.count(node_type) > 0) { - proc_func(this, n); - GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str()); - } + } + return (status == GRAPH_NODE_NEED_REPASS) ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS; +} + +bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) { + // check shape range + vector> src_shape_range; + vector> dst_shape_range; + src->GetShapeRange(src_shape_range); + dst->GetShapeRange(dst_shape_range); + if (src_shape_range.size() != dst_shape_range.size()) { + GELOGI("Src shape range size is %zu, dst shape range size is %zu, not same.", src_shape_range.size(), + dst_shape_range.size()); + return false; + } + for (size_t i = 0; i < src_shape_range.size(); ++i) { + if (src_shape_range[i].first != dst_shape_range[i].first || + src_shape_range[i].second != dst_shape_range[i].second) { + GELOGI("Current dim %zu. Src shape range is [%lu-%lu], dst shape range is [%lu-%lu], not same.", + i, src_shape_range[i].first, src_shape_range[i].second, dst_shape_range[i].first, dst_shape_range[i].second); + return false; } + } + + // check shape + auto src_shape = src->GetShape(); + auto dst_shape = dst->GetShape(); + if (src_shape.GetDims() != dst_shape.GetDims() || src->GetOriginShape().GetDims() != dst->GetOriginShape().GetDims() || + src->GetDataType() != dst->GetDataType() || src->GetOriginDataType() != dst->GetOriginDataType()) { + GELOGD( + "Src shape is %s, origin_shape is %s, data_type is %s, origin data_type is %s; " + "Dst shape is %s, origin_shape is %s, data_type is %s, original data_type is %s, not same.", + src_shape.ToString().c_str(), src->GetOriginShape().ToString().c_str(), + TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst_shape.ToString().c_str(), + dst->GetOriginShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); + return false; + } + return true; +} + +graphStatus InferShapePass::UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { + // refresh src itself + src->SetOriginShape(src->GetShape()); + src->SetOriginDataType(src->GetDataType()); + TensorUtils::SetRealDimCnt(*src, static_cast(src->GetOriginShape().GetDims().size())); + vector> src_shape_range; + src->GetShapeRange(src_shape_range); + src->SetOriginShapeRange(src_shape_range); + + changed = false; + if (SameTensorDesc(src, dst)) { + GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); return SUCCESS; - }; + } + + changed = true; + UpdateShapeAndDType(src, dst); + GELOGD( + "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s." + "To dst Node: shape: [%s], datatype: %s, original datatype is %s.", + src->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst->GetShape().ToString().c_str(), + TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); + return SUCCESS; +} + +graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) { + auto op_desc = node->GetOpDesc(); + const auto &op_type = op_desc->GetType(); + auto ret = op_desc->CallInferFunc(op); + if (ret == GRAPH_PARAM_INVALID) { + // Op ir no infer func, try to get infer func from operator factory + auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); + if (node_op.IsEmpty()) { + GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); + return ret; + } - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type), - "[Get][OriginalType] of node:%s failed.", node->GetName().c_str()); - if (kNextIterationOpTypes.count(node_type) > 0) { - return RePassNode(kMergeOpTypes); // Re-Pass Merge + GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); + auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); + node_op.BreakConnect(); + if (temp_op_desc == nullptr) { + REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr."); + GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null"); + return GRAPH_FAILED; + } + if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { + GELOGW("InferShapeAndType UpdateInputName failed"); + for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { + if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { + break; + } + return GRAPH_SUCCESS; + } + } + if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { + GELOGW("InferShapeAndType UpdateOutputName failed"); + } + op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); + ret = op_desc->CallInferFunc(op); + GELOGI("op CallInferFunc second. ret: %u", ret); } + return ret; +} - if (kMergeOpTypes.count(node_type) > 0) { - if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - return RePassNode(kSwitchOpTypes); // Re-Pass Switch +graphStatus InferShapePass::UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) { + GELOGD("Enter update parent node shape for class branch op process"); + // check sub_graph shape.If not same ,do unknown shape process + auto ref_out_tensor = src.at(0); + ge::GeShape &ref_out_tensor_shape = ref_out_tensor->MutableShape(); + for (auto &tensor : src) { + if (ref_out_tensor->GetDataType() != tensor->GetDataType()) { + REPORT_INNER_ERROR("E19999", "Does not support diff dtype among all ref output, shape:%s", + ref_out_tensor_shape.ToString().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype output"); + return GRAPH_FAILED; + } + auto shape = tensor->MutableShape(); + if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { + GELOGD("Shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", shape.GetShapeSize(), + ref_out_tensor_shape.GetShapeSize()); + ref_out_tensor_shape = GeShape(UNKNOWN_RANK); + break; } + for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { + if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { + continue; + } + GELOGD("j: %zu ,shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", j, shape.GetShapeSize(), + ref_out_tensor_shape.GetShapeSize()); + (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); + } + } + UpdateShapeAndDType(ref_out_tensor, dst); + return GRAPH_SUCCESS; +} +graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, + GeTensorDescPtr &dst) { + // check sub_graph shape. Get max for update. + if (src.empty()) { + // TODO LOG return SUCCESS; } - if (kSwitchOpTypes.count(node_type) > 0) { - if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit - } else { - return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit + int64_t max_size = 0; + size_t max_shape_index = 0; + auto &ref_out_tensor = src.at(0); + for (size_t j = 0; j < src.size(); ++j) { + auto &tensor = src.at(j); + if (ref_out_tensor->GetDataType() != tensor->GetDataType()) { + REPORT_INNER_ERROR("E19999", "node does not support diff dtype among all ref output"); + GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype among all ref output"); + return GRAPH_FAILED; + } + + auto shape = tensor->MutableShape(); + int64_t size = 1; + for (auto dim : shape.GetDims()) { + if (dim != 0 && INT64_MAX / dim < size) { + REPORT_INNER_ERROR("E19999", "The shape:%s size overflow", shape.ToString().c_str()); + GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow"); + return PARAM_INVALID; + } + size *= dim; } - } + if (size > max_size) { + max_size = size; + max_shape_index = j; + } + } + UpdateShapeAndDType(src.at(max_shape_index), dst); + return GRAPH_SUCCESS; +} +Status InferShapePass::OnSuspendNodesLeaked() { + auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName()); + if (iter == graphs_2_suspend_nodes_.end()) { + // todo log warn + return SUCCESS; + } + if (!iter->second.nodes.empty()) { + AddNodeResume(iter->second.PopSuspendedNode()); + } return SUCCESS; } } // namespace ge diff --git a/ge/graph/passes/infershape_pass.h b/ge/graph/passes/infershape_pass.h index 9c5d432d..2b7fdd3b 100644 --- a/ge/graph/passes/infershape_pass.h +++ b/ge/graph/passes/infershape_pass.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,22 +17,39 @@ #ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ #define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ -#include "graph/passes/base_pass.h" +#include "graph/passes/infer_base_pass.h" +#include namespace ge { -class InferShapePass : public BaseNodePass { +class InferShapePass : public InferBasePass { public: - /// - /// Entry of the InferShapePass optimizer - /// @param [in] graph: Input ComputeGraph - /// @return SUCCESS: Execution succeed - /// @return OTHERS: Execution failed - /// @author - /// - Status Run(ge::NodePtr &node) override; + std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override; + graphStatus Infer(NodePtr &node) override; + + graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; + graphStatus UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) override; + graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, + GeTensorDescPtr &dst) override; + + Status OnSuspendNodesLeaked() override; private: - Status RePassLoopNode(const NodePtr &node); + graphStatus CallInferShapeFunc(NodePtr &node, Operator &op); + bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst); + Status SuspendV1LoopExitNodes(const NodePtr &node); + struct SuspendNodes { + std::stack nodes; + std::unordered_set nodes_set; + + NodePtr PopSuspendedNode() { + auto top_node = nodes.top(); + nodes.pop(); + nodes_set.erase(top_node); + return top_node; + } + }; + std::map graphs_2_suspend_nodes_; + }; } // namespace ge #endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ diff --git a/ge/graph/passes/merge_pass.cc b/ge/graph/passes/merge_pass.cc index fec9c6d0..27ea85fd 100644 --- a/ge/graph/passes/merge_pass.cc +++ b/ge/graph/passes/merge_pass.cc @@ -31,6 +31,7 @@ namespace ge { const int kValueIndexOutputIndex = 1; const size_t kCaseNoInput = 0; const size_t kCaseOneInput = 1; +const bool kWillRepassImmediately = true; Status MergePass::Run(NodePtr &node) { GELOGD("MergePass running"); @@ -82,14 +83,14 @@ Status MergePass::Run(NodePtr &node) { } auto in_node = in_data_nodes.at(0); if (IsMergeInputNeedOptimized(in_node)) { - if (IsolateAndDeleteNode(in_node, {0}) != SUCCESS) { + if (IsolateAndDeleteNode(in_node, {0}, kWillRepassImmediately) != SUCCESS) { REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed", in_node->GetName().c_str(), in_node->GetType().c_str()); GELOGE(FAILED, "[Remove][Node] %s failed.", in_node->GetName().c_str()); return FAILED; } } - return IsolateAndDeleteNode(node, merge_io_map); + return IsolateAndDeleteNode(node, merge_io_map, kWillRepassImmediately); } default: { // Case C: input_count > 1, the merge node can not be optimized diff --git a/ge/graph/passes/switch_dead_branch_elimination.cc b/ge/graph/passes/switch_dead_branch_elimination.cc index 3c6c57d0..d67a8d01 100644 --- a/ge/graph/passes/switch_dead_branch_elimination.cc +++ b/ge/graph/passes/switch_dead_branch_elimination.cc @@ -28,6 +28,7 @@ namespace { const std::vector::size_type kDataInputIndex = 0; const std::vector::size_type kPredInputIndex = 1; const int kDefaultInputIndex = -1; +const bool kWillRepassImmediately = true; bool ParsePred(const ConstGeTensorPtr &tensor) { if (tensor == nullptr) { @@ -134,7 +135,7 @@ Status SwitchDeadBranchElimination::DeleteSwitchNode(NodePtr &node, NodePtr &pre return FAILED; } switch_io_map[out_index] = kDataInputIndex; - return IsolateAndDeleteNode(node, switch_io_map); + return IsolateAndDeleteNode(node, switch_io_map, kWillRepassImmediately); } Status SwitchDeadBranchElimination::Run(NodePtr &node) { diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index bc8646e7..ca5f90ea 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -1999,6 +1999,22 @@ Status GraphPrepare::CheckUserInput(const std::vector &user_input) { Status GraphPrepare::InferShapeForPreprocess() { GELOGI("Start infershape for preprocess."); + // Prepare dummy_shape for v1 control_flow op before infershape + for (const auto &node : compute_graph_->GetAllNodes()) { + string type; + GetOriginalType(node, type); + if (type == MERGE || type == REFMERGE) { + for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + GELOGD("Prepare for infershape: update %s input_shape as dummy.", node->GetName().c_str()); + NodeUtils::UpdateInputShape(*node, i, GeShape(DUMMY_SHAPE)); + } + } else if (type == WHILE) { + for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + GELOGD("Prepare for infershape: update %s output_shape as dummy.", node->GetName().c_str()); + NodeUtils::UpdateOutputShape(*node, i, GeShape(DUMMY_SHAPE)); + } + } + } GEPass ge_passes(compute_graph_); NamesToPass names_to_passes; AssertPass assert_pass;