diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index a1551eb2..19682d23 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -19,9 +19,7 @@ #include #include -#include "framework/common/debug/log.h" -#include "framework/common/debug/ge_log.h" -#include "graph/compute_graph.h" +#include "common/debug/log.h" #include "graph/utils/graph_utils.h" namespace ge { @@ -30,95 +28,161 @@ constexpr int kMaxRePassTimes = 10000; constexpr size_t kMaxOneInNodes = 1000; // Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later constexpr int kMaxRecursiveDepth = 20; -struct DuringPassNodeSets { - std::unordered_set nodes_seen; - std::unordered_set nodes_deleted; - std::unordered_set nodes_re_pass; - std::unordered_set nodes_re_pass_immediately; - std::unordered_set nodes_last; - std::unordered_set nodes_suspend; - std::unordered_set nodes_resume; -}; - -void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &input_edge_nodes, - std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { - nodes_last.clear(); + +void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, + GEPass::GraphLevelState &g_state) { for (auto &node : graph->GetDirectNode()) { if (node == nullptr) { continue; } size_t in_nums = node->GetInNodes().size(); if (in_nums == 0) { - input_edge_nodes.push_back(node); - nodes_seen.insert(node.get()); + g_state.AddNodeToQueueIfNotSeen(node); } else if (in_nums > kMaxOneInNodes) { - nodes_last.insert(node); + g_state.nodes_last.insert(node); } } } -bool IsAllInNodesAlive(const Node::Vistor &nodes, const std::unordered_set &nodes_suspend) { - return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; }); +bool AllNodesIn(const Node::Vistor &nodes, const std::unordered_set &nodes_set) { + return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { + return nodes_set.count(n) == 0; + }); +} + +bool AnyNodesIn(const Node::Vistor &nodes, const std::unordered_set &nodes_set) { + return std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { + return nodes_set.count(n) > 0; + }); } -void AddNextIterNodes(const Node::Vistor &nodes, std::deque &nodes_to_pass, - DuringPassNodeSets &during_pass_node_set) { - auto &nodes_seen = during_pass_node_set.nodes_seen; - const auto &nodes_last = during_pass_node_set.nodes_last; - const auto &nodes_suspend = during_pass_node_set.nodes_suspend; - for (auto &node : nodes) { +bool IsNodeReadyToQueue(const NodePtr &node, GEPass::GraphLevelState &g_state) { + if (node == nullptr) { + GELOGW("node is null"); + return false; + } + + if (g_state.nodes_deleted.count(node) > 0) { + GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); + return false; + } + + if (g_state.nodes_last.count(node) != 0) { + return false; + } + + if (!node->IsAllInNodesSeen(g_state.nodes_seen)) { + return false; + } + + // Because a node's output nodes are added to the queue before the node itself is passed, if a pass on the node + // suspends one of its output nodes, the logic that follows is the same as above. + // TODO Note that this guarantee is best-effort only: if, while passing a node, one of its upstream nodes `A` is + // suspended, the passes of `A`'s direct and indirect successors are not affected by that suspension. + // In theory, if we collected the node's output nodes before the pass, removed the suspended/deleted ones after + // the pass, and only then added them to the queue, this extra check would not be needed here. + if (g_state.nodes_suspend.count(node) > 0) { + GELOGD("The node %s has been added to the suspend-iteration node list, its iteration will be suspended.", + node->GetName().c_str()); + return false; + } + if (AnyNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) { + GELOGD("The 
node %s has an input node in the suspend-iteration node list, so its iteration will be suspended.", + node->GetName().c_str()); + return false; + } + return true; +} + +void CollectOutNodesBeforePass(const NodePtr &node, std::unordered_set &out_nodes_before_pass) { + for (const auto &out_node : node->GetOutNodes()) { + out_nodes_before_pass.insert(out_node); + } +} + +void AddNextIterNodes(const NodePtr &cur_node, std::unordered_set &out_nodes_before_pass, + GEPass::GraphLevelState &g_state) { + for (auto &node : cur_node->GetOutNodes()) { if (node == nullptr) { continue; } - if (nodes_last.count(node) != 0) { - continue; + if (out_nodes_before_pass.erase(node) == 0) { + // a new output node came up after passing the node + GELOGD("New output node %s comes up after passing %s.", node->GetName().c_str(), cur_node->GetName().c_str()); + } + + if (IsNodeReadyToQueue(node, g_state)) { + g_state.AddNodeToQueueIfNotSeen(node); } - if (nodes_suspend.count(node) > 0) { - GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str()); + } + // A-->B-->C + // \ + // D--->E + // If B has been deleted by the pass, two cases need to be considered: + // 1. A & C & E have been re-passed by B. Good choice. + // 2. A & C & E were not added to re-pass; C will not be passed because no one triggers it, + // while E will be passed because D will trigger it. + // So here we need to add the nodes which have no input node to the queue. + for (const auto &node : out_nodes_before_pass) { + if (!node->GetInAllNodes().empty()) { + GELOGD("Node %s used to be an output of node %s, but after the pass it is not. " + "It may be triggered by another node, so there is no need to add it to the queue now.", + node->GetName().c_str(), cur_node->GetName().c_str()); continue; } - - bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend); - bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); - if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) { - nodes_to_pass.push_back(node); + if (IsNodeReadyToQueue(node, g_state)) { + // an edge may have been unlinked; add these nodes to the queue, otherwise they can never be passed + GELOGI("Node %s may be lost from current node %s, add it to the queue if not seen.", + node->GetName().c_str(), cur_node->GetName().c_str()); + g_state.AddNodeToQueueIfNotSeen(node); } } } -void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque &nodes) { - for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { - GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); - nodes.push_front(node); +void AddImmediateRepassNodesToQueue(NodePtr &cur_node, + const std::unordered_map re_pass_imm_nodes_to_pass_names, + GEPass::GraphLevelState &g_state) { + for (const auto &node_2_pass_names : re_pass_imm_nodes_to_pass_names) { + auto repass_imm_node = node_2_pass_names.first; + if (repass_imm_node == nullptr) { + GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s", + node_2_pass_names.second.c_str(), + cur_node->GetName().c_str(), cur_node->GetType().c_str()); + continue; + } + if (g_state.nodes_passed.count(repass_imm_node) > 0) { + GELOGD("The node %s specified by pass %s has been passed, it will re-pass immediately", + repass_imm_node->GetName().c_str(), node_2_pass_names.second.c_str()); + g_state.AddNodeToQueueFront(repass_imm_node); + continue; + } + GELOGW("The node %s specified by pass %s has not been passed yet, it will not re-pass immediately", + repass_imm_node->GetName().c_str(), node_2_pass_names.second.c_str()); } - during_pass_node_set.nodes_re_pass_immediately.clear(); } -void AddResumeNodes(DuringPassNodeSets 
&during_pass_node_set, std::deque &nodes) { - for (auto &node : during_pass_node_set.nodes_resume) { - const auto &it = during_pass_node_set.nodes_suspend.find(node); - if (it != during_pass_node_set.nodes_suspend.end()) { - during_pass_node_set.nodes_suspend.erase(node); - GELOGD("The node %s resumed by pass.", node->GetName().c_str()); - nodes.push_back(node); - } else { - GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str()); +void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) { + for (auto &node : g_state.nodes_last) { + // todo: why can a nodes_last node already appear in nodes_seen? Check the merge history (git blame). + if (node->IsAllInNodesSeen(g_state.nodes_seen)) { + g_state.AddNodeToQueueIfNotSeen(node); } } - during_pass_node_set.nodes_resume.clear(); + g_state.nodes_last.clear(); } -void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, - const std::unordered_set &nodes_suspend, - const std::unordered_set &nodes_resume) { - for (const auto &node : nodes_suspend) { - GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str()); - during_pass_node_set.nodes_suspend.emplace(node); - } - - for (const auto &node : nodes_resume) { - GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str()); - during_pass_node_set.nodes_resume.emplace(node); +void AddResumeNodesToQueue(const std::unordered_map resume_nodes_to_pass_names, + GEPass::GraphLevelState &g_state) { + // Currently we don't keep the order of suspend and resume nodes, so it's hard to know + // which one comes first. Simple way: if a node has both the suspend and the resume state, we resume it. + // Better way: record the order in which nodes are suspended/resumed, then apply suspend/resume in that order here. 
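+  // A minimal sketch of that "better way" (hypothetical, not part of this change): record suspend/resume
+  // events in the order the passes report them and replay them before queueing, e.g.
+  //   std::vector<std::pair<NodePtr, bool>> events;  // second == true means resume, false means suspend
+  //   for (const auto &ev : events) {
+  //     if (ev.second) {
+  //       g_state.nodes_suspend.erase(ev.first);
+  //     } else {
+  //       g_state.nodes_suspend.insert(ev.first);
+  //     }
+  //   }
+  // With ordered events, the "resume wins" tie-break used below would become unnecessary.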
+ for (const auto &node_2_pass_names : resume_nodes_to_pass_names) { + auto node = node_2_pass_names.first; + if (g_state.nodes_suspend.erase(node) > 0) { + if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) { + g_state.nodes.push_back(node); + GELOGD("Node %s has been resumed by pass %s, add to queue.", + node->GetName().c_str(), node_2_pass_names.second.c_str()); + } + } } } @@ -140,54 +204,6 @@ void PushToRePassIfSeen(NodePtr &node, const std::pairGetName().c_str()); - for (const auto &name_to_pass : names_to_passes) { - if (name_to_pass.second == nullptr) { - GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); - continue; - } - - GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str()); - name_to_pass.second->init(); - auto result = name_to_pass.second->Run(node); - if (result != SUCCESS) { - REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", - name_to_pass.first.c_str(), node->GetName().c_str(), result); - GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result " - "%u, the passes will be terminated immediately.", - name_to_pass.first.c_str(), node->GetName().c_str(), result); - return result; - } - - const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); - PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, - during_pass_node_set.nodes_re_pass); - - const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); - PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, - during_pass_node_set.nodes_re_pass_immediately); - - PushToSuspendNodes(during_pass_node_set, name_to_pass.first, - name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume()); - - const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); - during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); - if (nodes_deleted_by_pass.count(node) > 0) { - GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), - name_to_pass.first.c_str()); - break; - } - } - - return SUCCESS; -} - void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) { for (auto &name_to_pass : names_to_pass) { name_to_pass.second->SetOption(option, ""); @@ -199,27 +215,10 @@ void ClearOption(NamesToPass names_to_pass) { name_to_pass.second->ClearOptions(); } } - -bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) { - if (node == nullptr) { - GELOGW("node is null"); - return false; - } - if (during_pass_node_set.nodes_deleted.count(node) > 0) { - GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); - return false; - } - if (during_pass_node_set.nodes_suspend.count(node) > 0) { - GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", - node->GetName().c_str()); - return false; - } - - return true; -} } // namespace -Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map) { +Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map, + bool is_repass_io_immediately) { if (node == nullptr) { REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); @@ -235,7 +234,7 @@ Status 
BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector return FAILED; } - AddRePassNodesWithInOut(node); + is_repass_io_immediately ? AddImmediateRePassNodesWithInOut(node) : AddRePassNodesWithInOut(node); if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) { REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str()); @@ -263,6 +262,12 @@ Status GEPass::Run(const NamesToPass &names_to_passes) { GELOGW("No passes input, the GEPass will do nothing"); return INTERNAL_ERROR; } + for (const auto &name_to_pass : names_to_passes) { + if (name_to_pass.second == nullptr) { + GELOGE(INTERNAL_ERROR, "[Check][Param] There is a null pointer in passes(%s)", name_to_pass.first.c_str()); + return INTERNAL_ERROR; + } + } if (depth_ > kMaxRecursiveDepth) { GELOGE(PARAM_INVALID, @@ -275,81 +280,101 @@ Status GEPass::Run(const NamesToPass &names_to_passes) { return RunPassesOneGraph(names_to_passes); } +void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names_to_pass) { + for (auto &name_to_pass : names_to_pass) { + name_to_pass.second->OnStartPassGraph(graph); + } +} + +Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) { + std::unordered_map resume_nodes_to_pass_names; + for (auto &name_to_pass : names_to_passes) { + name_to_pass.second->init(); + auto ret = name_to_pass.second->OnSuspendNodesLeaked(); + if (ret != SUCCESS) { + GELOGE(ret, "Internal error happened when pass %s handled leaked suspend nodes.", + name_to_pass.first.c_str()); + return ret; + } + for (const auto &resume_node : name_to_pass.second->GetNodesResume()) { + resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ","); + } + } + AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state); + return SUCCESS; +} + Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); - std::deque nodes; - DuringPassNodeSets during_pass_node_set; - GetAllNodesNoInputEdge(graph_, nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); - GELOGD("Start points count %zu", nodes.size()); - int re_pass_times = 0; + NotifyPassGraphStart(graph_, names_to_passes); + GraphLevelState g_state; + g_state.re_pass_times = 0; + GetAllNodesNoInputEdge(graph_, g_state); + GELOGD("Start points count %zu", g_state.nodes.size()); do { - for (auto &node : during_pass_node_set.nodes_re_pass) { - nodes.push_back(node); - during_pass_node_set.nodes_seen.insert(node.get()); + if (!g_state.nodes_suspend.empty()) { + auto ret = HandleLeakedSuspendNodes(names_to_passes, g_state); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to handle leaked suspend nodes, break base pass."); + return ret; + } + if (g_state.nodes.empty()) { + // There are leaked suspend nodes, but no pass resumed them + GELOGE(INTERNAL_ERROR, "There are suspend nodes but no pass resumes them, which means " "some nodes in this graph will never be passed."); + return INTERNAL_ERROR; + } + } + auto ret = RunPassesGraphRepass(names_to_passes, g_state); + if (ret != SUCCESS) { + return ret; } - during_pass_node_set.nodes_re_pass.clear(); + } while (!g_state.nodes_suspend.empty()); - while (!nodes.empty()) { - NodePtr node = nodes.front(); - nodes.pop_front(); + return SUCCESS; +} - (void)during_pass_node_set.nodes_re_pass.erase(node); - if (!CheckNode(node, during_pass_node_set)) { - continue; - } - AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); - auto ret = RunPasses(node, 
names_to_passes, during_pass_node_set); - if (ret != SUCCESS) { - GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", - node->GetName().c_str(), node->GetType().c_str(), ret); - return ret; - } +Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state) { + RepassLevelState rp_state; + do { + for (auto &node : rp_state.nodes_re_pass) { + GELOGD("Add node %s to queue for re-pass.", node->GetName().c_str()); + g_state.AddNodeToQueue(node); + } + rp_state.nodes_re_pass.clear(); - bool has_sub_graph = false; - ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph); - if (ret != SUCCESS) { - GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str()); - return ret; - } + while (!g_state.nodes.empty()) { + auto node = g_state.PopFront(); - if (has_sub_graph) { - GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); - SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); - ret = RunPasses(node, names_to_passes, during_pass_node_set); - if (ret != SUCCESS) { - GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", - node->GetName().c_str(), node->GetType().c_str(), ret); - return ret; - } - - // There is only one option scene, so set and clear options around the `RunPasses` func. - // if there are more than one scene to set options, the `ClearOption` function - // should be called each time at the begin of the iteration - ClearOption(names_to_passes); + if (g_state.nodes_deleted.count(node) > 0) { + GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); + continue; } + (void)rp_state.nodes_re_pass.erase(node); // todo: why is this erase needed here? + g_state.nodes_seen.insert(node.get()); // todo: why mark the node as seen at this point? - AddRepassNodes(during_pass_node_set, nodes); - AddResumeNodes(during_pass_node_set, nodes); - } + std::unordered_set out_nodes_before_pass; + CollectOutNodesBeforePass(node, out_nodes_before_pass); - for (auto &node : during_pass_node_set.nodes_last) { - bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen); - if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) { - nodes.push_back(node); + auto ret = RunPassesNodeOnce(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; } + AddNextIterNodes(node, out_nodes_before_pass, g_state); } - during_pass_node_set.nodes_last.clear(); - } while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); + AddLastNodesToQueue(g_state); + } while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes); - if (re_pass_times == kMaxRePassTimes) { + if (g_state.re_pass_times == kMaxRePassTimes) { GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); } GELOGD("All passes runs end"); - return SUCCESS; } + Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); has_sub_graph = false; @@ -371,4 +396,95 @@ Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names } return SUCCESS; } + +Status GEPass::RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes, + GraphLevelState &g_state, RepassLevelState &rp_state) { + auto ret = 
RunPassesOnNode(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; + } + + bool has_sub_graph = false; + ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str()); + return ret; + } + + if (has_sub_graph) { + GELOGD("There are subgraphs on node %s, run passes for the second time", node->GetName().c_str()); + SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); + ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; + } + + // There is only one option scene, so set and clear options around the `RunPassesOnNode` func. + // If there is more than one scene to set options, the `ClearOption` function + // should be called at the beginning of each iteration. + ClearOption(names_to_passes); + } + return SUCCESS; +} + +Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state, + RepassLevelState &rp_state) { + if (node == nullptr) { + REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); + GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); + return FAILED; + } + GELOGD("Begin to run pass for node %s", node->GetName().c_str()); + for (const auto &name_to_pass : names_to_passes) { + GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str()); + name_to_pass.second->init(); + auto result = name_to_pass.second->Run(node); + if (result != SUCCESS) { + REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", name_to_pass.first.c_str(), + node->GetName().c_str(), result); + GELOGE(INTERNAL_ERROR, + "[Process][Pass] %s on node %s failed, result " + "%u, the passes will be terminated immediately.", + name_to_pass.first.c_str(), node->GetName().c_str(), result); + return result; + } + if (name_to_pass.second->GetNodesDeleted().count(node) > 0) { + GELOGD("The node %s was deleted by pass %s, stop the remaining passes", node->GetName().c_str(), + name_to_pass.first.c_str()); + break; + } + } + + g_state.nodes_passed.insert(node); + + std::unordered_map repass_imm_nodes_to_pass_names; + std::unordered_map resume_nodes_to_pass_names; + // if multiple passes add the same node for immediate re-pass, duplicates need to be removed here + for (const auto &name_to_pass : names_to_passes) { + PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen, name_to_pass.second->GetNodesNeedRePass(), + rp_state.nodes_re_pass); + // collect imm_nodes && resume_nodes reported by these passes + for (const auto &imm_node : name_to_pass.second->GetNodesNeedRePassImmediately()) { + repass_imm_nodes_to_pass_names[imm_node].append(name_to_pass.first + ","); + } + for (const auto &resume_node : name_to_pass.second->GetNodesResume()) { + resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ","); + } + + for (const auto &suspend_node : name_to_pass.second->GetNodesSuspend()) { + GELOGD("The iteration suspend of node %s has been set by pass %s", suspend_node->GetName().c_str(), + name_to_pass.first.c_str()); + g_state.nodes_suspend.insert(suspend_node); + } + const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); + 
g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); + } + AddImmediateRepassNodesToQueue(node, repass_imm_nodes_to_pass_names, g_state); + AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state); + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/passes/base_pass.h b/ge/graph/passes/base_pass.h index d0f125b2..5d744c08 100644 --- a/ge/graph/passes/base_pass.h +++ b/ge/graph/passes/base_pass.h @@ -22,7 +22,6 @@ #include #include #include - #include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "graph/compute_graph.h" @@ -61,23 +60,32 @@ class BaseNodePass { const std::unordered_set &GetNodesResume() { return nodes_resume_; } + virtual Status OnSuspendNodesLeaked() { return SUCCESS; } + void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } void ClearOptions() { options_.clear(); } void init() { nodes_need_re_pass_.clear(); - nodes_deleted_.clear(); nodes_need_re_pass_immediately_.clear(); + nodes_deleted_.clear(); nodes_suspend_.clear(); nodes_resume_.clear(); } + virtual void OnStartPassGraph(const ComputeGraphPtr &graph) { + current_graph_name_ = graph->GetName(); + } + protected: - Status IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map); + const string &GetCurrentGraphName() const { + return current_graph_name_; + } + Status IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map, bool is_repass_io_immediately = false); - Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list &io_map) { - return IsolateAndDeleteNode(node, std::vector(io_map)); + Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list &io_map, bool is_repass_io_immediately = false) { + return IsolateAndDeleteNode(node, std::vector(io_map), is_repass_io_immediately); } /// @@ -112,6 +120,22 @@ class BaseNodePass { } } + /// + /// Add a node and it's input/output data nodes to be optimized immediately again. + /// @param node + /// + void AddImmediateRePassNodesWithInOut(const NodePtr &node) { + auto in_nodes = node->GetInNodes(); + for (auto &in_node : in_nodes) { + AddImmediateRePassNode(in_node); + } + AddImmediateRePassNode(node); + auto out_nodes = node->GetOutNodes(); + for (auto &out_node : out_nodes) { + AddImmediateRePassNode(out_node); + } + } + /// /// If you deleted a node from the graph, especially current node. The remain /// iterate passes will continue process on the deleted node(if it can be @@ -123,23 +147,15 @@ class BaseNodePass { void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } /// - /// If you suspend a node from the graph, especially following node. The remain - /// iterate passes will stop process on the suspend node(if it can be + /// If you postpone a node from the graph, especially following node. The remain + /// iterate passes will stop process on the postpone node(if it can be /// reached by edge connections) till the last one. Obviously it is a waste of - /// time. You can add the suspend nodes by calling this function, to stop the + /// time. You can add the postpone nodes by calling this function, to stop the /// next iterations. /// @param node /// void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } - /// - /// If you resume a node from the graph, especially following node. The remain - /// iterate passes will continue process on the resume node(if it can be - /// reached by edge connections) till the last one. 
- /// You can add the resume nodes by calling this function, to resume the - /// next iterations. - /// @param node - /// void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } @@ -151,6 +167,7 @@ class BaseNodePass { std::unordered_set nodes_suspend_; std::unordered_set nodes_resume_; std::map options_; + std::string current_graph_name_; }; using NamesToPass = std::vector>; @@ -160,12 +177,60 @@ class GEPass { explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {} virtual ~GEPass() = default; Status Run(const NamesToPass &names_to_passes); + /* +* todo +* OneGraph: nodes_deleted, nodes_seen, nodes_passed, nodes_suspended +* RePass: nodes_re_pass +* GraphOneTime: nodes_last +* NodeOneTime: nodes_re_pass_immediately, nodes_resume +*/ + struct GraphLevelState { + std::unordered_set nodes_deleted; + std::unordered_set nodes_seen; + std::unordered_set nodes_passed; + std::unordered_set nodes_suspend; + std::unordered_set nodes_last; + std::deque nodes; + int re_pass_times; + + void AddNodeToQueueFront(NodePtr node) { + nodes_seen.insert(node.get()); + nodes.emplace_front(std::move(node)); + } + + void AddNodeToQueue(NodePtr node) { + nodes_seen.insert(node.get()); + nodes.emplace_back(std::move(node)); + } + void AddNodeToQueueIfNotSeen(NodePtr node) { + if (nodes_seen.insert(node.get()).second) { + nodes.emplace_back(std::move(node)); + } + } + NodePtr PopFront() { + NodePtr node = nodes.front(); + nodes.pop_front(); + return node; + } + }; + struct RepassLevelState { + std::unordered_set nodes_re_pass; + }; + struct GraphOneTimeLevelState { + std::unordered_set nodes_last; + }; private: GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth) : graph_(graph), root_graph_(root_graph), depth_(depth) {} + Status RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes, + GraphLevelState &g_state, RepassLevelState &rp_state); + Status RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state); Status RunPassesOneGraph(const NamesToPass &names_to_passes); Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph); + Status RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state, + RepassLevelState &rp_state); + Status HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state); ComputeGraphPtr graph_; ComputeGraphPtr root_graph_; int depth_; diff --git a/ge/graph/passes/infer_base_pass.cc b/ge/graph/passes/infer_base_pass.cc index 25c45677..231fb967 100644 --- a/ge/graph/passes/infer_base_pass.cc +++ b/ge/graph/passes/infer_base_pass.cc @@ -84,8 +84,11 @@ Status InferBasePass::Run(NodePtr &node) { bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; } void InferBasePass::AddChangedNodesImmediateRepass(const std::set &changed_nodes) { -// need passed_nodes set to solve the problem that multi-input operators do repass in advance. -// when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes. + // need passed_nodes set to solve the problem that multi-input operators do repass in advance. + // when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes. 
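+  // Illustrative note (an assumption, not part of this change): consider an Add node whose two inputs are both
+  // re-inferred in this round. Without a passed-nodes record, Add could be re-passed before it had ever been
+  // passed at all; with GraphLevelState::nodes_passed, AddImmediateRepassNodesToQueue only front-queues nodes
+  // that were already passed, so every changed node can safely be handed to AddImmediateRePassNode here.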
+ for (const auto &node : changed_nodes) { + AddImmediateRePassNode(node); + } } graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set &changed_nodes) { diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 60a2f09a..47fd77e4 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -1,6 +1,6 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd - * + * Copyright 2020-2021 Huawei Technologies Co., Ltd + *+ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -22,13 +22,16 @@ #include "graph/shape_refiner.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" -#include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -namespace ge { +#include "external/graph/operator_factory.h" +namespace ge { +namespace { +constexpr int kSwitchExitAnchorIndex = 0; +constexpr int kSwitchPredAnchorIndex = 1; void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { desc_str += "["; std::vector> shape_range; @@ -47,129 +50,302 @@ void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { desc_str += "},"; } } +void UpdateShapeAndDType(const GeTensorDescPtr &src, GeTensorDescPtr &dst) { + dst->SetOriginShape(src->GetOriginShape()); + dst->SetShape(src->GetShape()); + dst->SetDataType(src->GetDataType()); + dst->SetOriginDataType(src->GetOriginDataType()); + vector> src_shape_range; + src->GetShapeRange(src_shape_range); + dst->SetShapeRange(src_shape_range); + dst->SetOriginShapeRange(src_shape_range); + ge::TensorUtils::SetRealDimCnt(*dst, static_cast(src->GetShape().GetDims().size())); +} +} // namespace -std::string GetInTensorInfoWithString(const ge::NodePtr &node) { - ge::OpDescPtr op_desc = node->GetOpDesc(); +std::string InferShapePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const { std::stringstream ss; - ss << "{"; - int32_t in_idx = 0; - for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { - if (input_desc == nullptr) { - in_idx++; + ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),"; + ss << "(format:" << TypeUtils::FormatToSerialString(tensor_desc->GetFormat()) << "),"; + ss << "(dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()) << "),"; + ss << "(origin_shape:" << tensor_desc->GetOriginShape().ToString() << "),"; + ss << "(origin_format:" << TypeUtils::FormatToSerialString(tensor_desc->GetOriginFormat()) << "),"; + ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetOriginDataType()) << "),"; + string range_str; + SerialShapeRange(tensor_desc, range_str); + ss << "(shape_range:" << range_str << ")"; + return ss.str(); +} +Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) { + if (node->GetType() != SWITCH) { + return SUCCESS; + } + auto pred_node = NodeUtils::GetInDataNodeByIndex(*node, kSwitchPredAnchorIndex); + GE_CHECK_NOTNULL(pred_node); + if (pred_node->GetType() != LOOPCOND) { + return SUCCESS; + } + + for (const auto &anchor_2_node : NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kSwitchExitAnchorIndex)) { + GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", anchor_2_node.second->GetName().c_str(), + anchor_2_node.second->GetType().c_str()); + auto iter = 
graphs_2_suspend_nodes_.find(GetCurrentGraphName()); + if (iter == graphs_2_suspend_nodes_.end()) { continue; } - if (in_idx > 0) { - ss << " "; + auto &suspend_nodes = graphs_2_suspend_nodes_[GetCurrentGraphName()]; + if (suspend_nodes.nodes_set.insert(anchor_2_node.second).second) { + suspend_nodes.nodes.push(anchor_2_node.second); + AddNodeSuspend(anchor_2_node.second); } - ss << "input_" << in_idx << " " << "tensor: ["; - ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; - ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; - ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; - ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),"; - ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),"; - ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),"; - string range_str; - SerialShapeRange(input_desc, range_str); - ss << "(shape_range:" << range_str << ")]"; - in_idx++; } - return ss.str(); + return SUCCESS; } -Status InferShapePass::Run(NodePtr &node) { - // kOptimizeAfterSubGraph exist means after subgraph - auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); - if (ret != GRAPH_SUCCESS) { - // select INFERSHAPE failed info - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - auto root_graph = ge::GraphUtils::FindRootGraph(graph); - GE_CHECK_NOTNULL(root_graph); - analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), - analyzer::INFER_SHAPE, node, "InferShapeFailed!"}; - (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); - (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), - root_graph->GetGraphID()); - - REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s", - node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); - GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s", - node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); - return GE_GRAPH_INFERSHAPE_FAILED; - } - - GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node)); - bool need_repass = false; - auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass); - if (has_attr) { - if (!OptionExists(kOptimizeAfterSubGraph)) { - return SUCCESS; +Status InferShapePass::Infer(NodePtr &node) { + auto ret = SuspendV1LoopExitNodes(node); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to suspend exit node in v1 control flow loop."); + return ret; + } + bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + auto opdesc = node->GetOpDesc(); + if (node->Verify() != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + Operator op = OpDescUtils::CreateOperatorFromNode(node); + + if (!is_unknown_graph) { + auto inference_context = ShapeRefiner::CreateInferenceContext(node); + GE_CHECK_NOTNULL(inference_context); + vector marks; + inference_context->GetMarks(marks); + GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), marks.size()); + op.SetInferenceContext(inference_context); + } + + graphStatus status = CallInferShapeFunc(node, op); + if 
(status != GRAPH_NODE_NEED_REPASS && status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) { + // node like netoutput return param_invalid, but valid ? + REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str()); + return GRAPH_FAILED; + } + UpdateCurNodeOutputDesc(node); + if (!is_unknown_graph) { + auto ctx_after_infer = op.GetInferenceContext(); + if (ctx_after_infer != nullptr) { + vector marks; + ctx_after_infer->GetMarks(marks); + GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), marks.size()); + if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !marks.empty()) { + GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), + marks.size()); + ShapeRefiner::PushToContextMap(node, ctx_after_infer); + } } - if (need_repass) { - AddImmediateRePassNode(node); - GELOGD("Node %s need repass immediately.", node->GetName().c_str()); - } else { - // clear attr on while - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); + } + return (status == GRAPH_NODE_NEED_REPASS) ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS; +} + +bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) { + // check shape range + vector> src_shape_range; + vector> dst_shape_range; + src->GetShapeRange(src_shape_range); + dst->GetShapeRange(dst_shape_range); + if (src_shape_range.size() != dst_shape_range.size()) { + GELOGI("Src shape range size is %zu, dst shape range size is %zu, not same.", src_shape_range.size(), + dst_shape_range.size()); + return false; + } + for (size_t i = 0; i < src_shape_range.size(); ++i) { + if (src_shape_range[i].first != dst_shape_range[i].first || + src_shape_range[i].second != dst_shape_range[i].second) { + GELOGI("Current dim %zu. 
Src shape range is [%lu-%lu], dst shape range is [%lu-%lu], not same.", + i, src_shape_range[i].first, src_shape_range[i].second, dst_shape_range[i].first, dst_shape_range[i].second); + return false; } } + + // check shape + auto src_shape = src->GetShape(); + auto dst_shape = dst->GetShape(); + if (src_shape.GetDims() != dst_shape.GetDims() || src->GetOriginShape().GetDims() != dst->GetOriginShape().GetDims() || + src->GetDataType() != dst->GetDataType() || src->GetOriginDataType() != dst->GetOriginDataType()) { + GELOGD( + "Src shape is %s, origin_shape is %s, data_type is %s, origin data_type is %s; " + "Dst shape is %s, origin_shape is %s, data_type is %s, original data_type is %s, not same.", + src_shape.ToString().c_str(), src->GetOriginShape().ToString().c_str(), + TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst_shape.ToString().c_str(), + dst->GetOriginShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); + return false; + } + return true; +} + +void InferShapePass::UpdateCurNodeOutputDesc(NodePtr &node) { + auto op_desc = node->GetOpDesc(); + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + GE_IF_BOOL_EXEC(output_tensor == nullptr, continue); + GE_IF_BOOL_EXEC(output_tensor->MutableShape().GetDims().empty(), + output_tensor->SetOriginShape(output_tensor->GetShape())); + + ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetOriginShape().GetDims() + .size())); + output_tensor->SetOriginDataType(output_tensor->GetDataType()); + // set output origin shape range + std::vector> range; + (void)output_tensor->GetShapeRange(range); + output_tensor->SetOriginShapeRange(range); + GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", + node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); + } +} + +graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { + changed = false; + if (SameTensorDesc(src, dst)) { + GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); + return SUCCESS; + } + + changed = true; + UpdateShapeAndDType(src, dst); + GELOGD( + "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s." 
+ "To dst Node: shape: [%s], datatype: %s, original datatype is %s.", + src->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst->GetShape().ToString().c_str(), + TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); return SUCCESS; } -Status InferShapePass::RePassLoopNode(const NodePtr &node) { - const auto RePassNode = [&](const std::set &re_pass_types) { - for (auto &n : node->GetOutDataNodes()) { - GE_CHECK_NOTNULL(n); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str()); - if (re_pass_types.count(node_type) > 0) { - AddImmediateRePassNode(n); - (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false); - GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str()); - } +graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) { + auto op_desc = node->GetOpDesc(); + const auto &op_type = op_desc->GetType(); + auto ret = op_desc->CallInferFunc(op); + if (ret == GRAPH_PARAM_INVALID) { + // Op ir no infer func, try to get infer func from operator factory + auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType().c_str()); + if (node_op.IsEmpty()) { + GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); + return ret; } - return SUCCESS; - }; - - const auto ExProcNode = [&](const std::set &proc_types, - const std::function &proc_func, - const std::string &info) { - for (auto &n : node->GetOutDataNodes()) { - GE_CHECK_NOTNULL(n); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str()); - if (proc_types.count(node_type) > 0) { - proc_func(this, n); - GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str()); + + GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); + auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); + node_op.BreakConnect(); + if (temp_op_desc == nullptr) { + REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr."); + GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null"); + return GRAPH_FAILED; + } + if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { + GELOGW("InferShapeAndType UpdateInputName failed"); + for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { + if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { + break; + } + return GRAPH_SUCCESS; } } - return SUCCESS; - }; - - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type), - "[Get][OriginalType] of node:%s failed.", node->GetName().c_str()); - if (kNextIterationOpTypes.count(node_type) > 0) { - return RePassNode(kMergeOpTypes); // Re-Pass Merge + if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { + GELOGW("InferShapeAndType UpdateOutputName failed"); + } + op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); + ret = op_desc->CallInferFunc(op); + GELOGI("op CallInferFunc second. 
ret: %u", ret); } + return ret; +} - if (kMergeOpTypes.count(node_type) > 0) { - if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - return RePassNode(kSwitchOpTypes); // Re-Pass Switch +graphStatus InferShapePass::UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) { + GELOGD("Enter update parent node shape for class branch op process"); + // check sub_graph shape.If not same ,do unknown shape process + auto ref_out_tensor = src.at(0); + ge::GeShape &ref_out_tensor_shape = ref_out_tensor->MutableShape(); + for (auto &tensor : src) { + if (ref_out_tensor->GetDataType() != tensor->GetDataType()) { + REPORT_INNER_ERROR("E19999", "Does not support diff dtype among all ref output, shape:%s", + ref_out_tensor_shape.ToString().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype output"); + return GRAPH_FAILED; + } + auto shape = tensor->MutableShape(); + if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { + GELOGD("Shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", shape.GetShapeSize(), + ref_out_tensor_shape.GetShapeSize()); + ref_out_tensor_shape = GeShape(UNKNOWN_RANK); + break; } + for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { + if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { + continue; + } + GELOGD("j: %zu ,shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", j, shape.GetShapeSize(), + ref_out_tensor_shape.GetShapeSize()); + (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); + } + } + UpdateShapeAndDType(ref_out_tensor, dst); + return GRAPH_SUCCESS; +} +graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, + GeTensorDescPtr &dst) { + // check sub_graph shape. Get max for update. 
+ if (src.empty()) { + // TODO LOG return SUCCESS; } - if (kSwitchOpTypes.count(node_type) > 0) { - if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit - } else { - return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit + int64_t max_size = 0; + size_t max_shape_index = 0; + auto &ref_out_tensor = src.at(0); + for (size_t j = 0; j < src.size(); ++j) { + auto &tensor = src.at(j); + if (ref_out_tensor->GetDataType() != tensor->GetDataType()) { + REPORT_INNER_ERROR("E19999", "node does not support diff dtype among all ref output"); + GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype among all ref output"); + return GRAPH_FAILED; + } + + auto shape = tensor->MutableShape(); + int64_t size = 1; + for (auto dim : shape.GetDims()) { + if (dim != 0 && INT64_MAX / dim < size) { + REPORT_INNER_ERROR("E19999", "The shape:%s size overflow", shape.ToString().c_str()); + GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow"); + return PARAM_INVALID; + } + size *= dim; } - } + if (size > max_size) { + max_size = size; + max_shape_index = j; + } + } + UpdateShapeAndDType(src.at(max_shape_index), dst); + return GRAPH_SUCCESS; +} +Status InferShapePass::OnSuspendNodesLeaked() { + auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName()); + if (iter == graphs_2_suspend_nodes_.end()) { + GELOGW("There is no suspend nodes on graph %s", GetCurrentGraphName().c_str()); + return SUCCESS; + } + if (!iter->second.nodes.empty()) { + AddNodeResume(iter->second.PopSuspendedNode()); + } return SUCCESS; } } // namespace ge diff --git a/ge/graph/passes/infershape_pass.h b/ge/graph/passes/infershape_pass.h index 9c5d432d..02af16ca 100644 --- a/ge/graph/passes/infershape_pass.h +++ b/ge/graph/passes/infershape_pass.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,22 +17,40 @@ #ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ #define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ -#include "graph/passes/base_pass.h" +#include "graph/passes/infer_base_pass.h" +#include namespace ge { -class InferShapePass : public BaseNodePass { +class InferShapePass : public InferBasePass { public: - /// - /// Entry of the InferShapePass optimizer - /// @param [in] graph: Input ComputeGraph - /// @return SUCCESS: Execution succeed - /// @return OTHERS: Execution failed - /// @author - /// - Status Run(ge::NodePtr &node) override; + std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override; + graphStatus Infer(NodePtr &node) override; + + graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; + graphStatus UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) override; + graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, + GeTensorDescPtr &dst) override; + + Status OnSuspendNodesLeaked() override; private: - Status RePassLoopNode(const NodePtr &node); + graphStatus CallInferShapeFunc(NodePtr &node, Operator &op); + bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst); + void UpdateCurNodeOutputDesc(NodePtr &node); + Status SuspendV1LoopExitNodes(const NodePtr &node); + struct SuspendNodes { + std::stack nodes; + std::unordered_set nodes_set; + + NodePtr PopSuspendedNode() { + auto top_node = nodes.top(); + nodes.pop(); + nodes_set.erase(top_node); + return top_node; + } + }; + std::map graphs_2_suspend_nodes_; + }; } // namespace ge #endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index d7f33b4b..99b4d6cc 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -1999,6 +1999,22 @@ Status GraphPrepare::CheckUserInput(const std::vector &user_input) { Status GraphPrepare::InferShapeForPreprocess() { GELOGI("Start infershape for preprocess."); + // Prepare dummy_shape for v1 control_flow op before infershape + for (const auto &node : compute_graph_->GetAllNodes()) { + string type; + GetOriginalType(node, type); + if (type == MERGE || type == REFMERGE) { + for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + GELOGD("Prepare for infershape: update %s input_shape as dummy.", node->GetName().c_str()); + NodeUtils::UpdateInputShape(*node, i, GeShape(DUMMY_SHAPE)); + } + } else if (type == WHILE) { + for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + GELOGD("Prepare for infershape: update %s output_shape as dummy.", node->GetName().c_str()); + NodeUtils::UpdateOutputShape(*node, i, GeShape(DUMMY_SHAPE)); + } + } + } GEPass ge_passes(compute_graph_); NamesToPass names_to_passes; AssertPass assert_pass; diff --git a/tests/ut/ge/graph/passes/addn_pass_unittest.cc b/tests/ut/ge/graph/passes/addn_pass_unittest.cc index 6107a7d8..54a8819d 100644 --- a/tests/ut/ge/graph/passes/addn_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/addn_pass_unittest.cc @@ -72,7 +72,7 @@ TEST(UtestGraphPassesAddnPass, null_pass) { AddNPass *addn_pass = nullptr; NamesToPass names_to_pass; names_to_pass.emplace_back("Test", addn_pass); - EXPECT_EQ(pass.Run(names_to_pass), SUCCESS); + EXPECT_NE(pass.Run(names_to_pass), SUCCESS); } TEST(UtestGraphPassesAddnPass, null_graph) { diff --git a/tests/ut/ge/graph/passes/base_pass_unittest.cc b/tests/ut/ge/graph/passes/base_pass_unittest.cc index c687e07f..133d6dc9 100644 --- 
a/tests/ut/ge/graph/passes/base_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/base_pass_unittest.cc @@ -17,7 +17,6 @@ #include #include #include -#include #include #include "gtest/gtest.h" @@ -26,8 +25,6 @@ #include "graph/passes/base_pass.h" #undef protected -#include "external/graph/ge_error_codes.h" -#include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "graph/node.h" #include "graph/utils/graph_utils.h" @@ -67,6 +64,54 @@ class UtestTestPass : public BaseNodePass { names_to_add_repass_.erase(iter); } } + + iter = names_to_add_repass_immediate_.find(node->GetName()); + if (iter != names_to_add_repass_immediate_.end()) { + auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); + for (const auto &node_name : iter->second) { + for (auto &node_re_pass : all_nodes) { + if (node_re_pass->GetName() == node_name) { + AddImmediateRePassNode(node_re_pass); + break; + } + } + } + if (!dead_loop_) { + names_to_add_repass_immediate_.erase(iter); + } + } + + iter = names_to_add_suspend_.find(node->GetName()); + if (iter != names_to_add_suspend_.end()) { + auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); + for (const auto &node_name : iter->second) { + for (auto &node_re_pass : all_nodes) { + if (node_re_pass->GetName() == node_name) { + AddNodeSuspend(node_re_pass); + break; + } + } + } + if (!dead_loop_) { + names_to_add_suspend_.erase(iter); + } + } + + iter = names_to_add_resume_.find(node->GetName()); + if (iter != names_to_add_resume_.end()) { + auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); + for (const auto &node_name : iter->second) { + for (auto &node_re_pass : all_nodes) { + if (node_re_pass->GetName() == node_name) { + AddNodeResume(node_re_pass); + break; + } + } + } + if (!dead_loop_) { + names_to_add_resume_.erase(iter); + } + } // simulate infershape pass if(node->GetType() == WHILE){ bool need_repass = false; @@ -85,6 +130,20 @@ class UtestTestPass : public BaseNodePass { } return SUCCESS; } + + Status OnSuspendNodesLeaked() override { + // resume all node remain in suspend_nodes when leaked + auto compute_graph = (iter_nodes_.size() > 0) ? 
iter_nodes_[0]->GetOwnerComputeGraph() : nullptr; + if (compute_graph == nullptr) { + return SUCCESS; + } + + for (const auto &node_name : names_to_add_resume_onleaked_) { + auto node_to_resume = compute_graph->FindNode(node_name); + AddNodeResume(node_to_resume); + } + return SUCCESS; + } void clear() { iter_nodes_.clear(); } std::vector GetIterNodes() { return iter_nodes_; } @@ -94,12 +153,31 @@ class UtestTestPass : public BaseNodePass { void AddDelNodeName(const std::string &iter_node, const std::string &del_node) { names_to_add_del_[iter_node].insert(del_node); } + void AddRePassImmediateNodeName(const std::string &iter_node, const std::string &re_pass_node) { + names_to_add_repass_immediate_[iter_node].insert(re_pass_node); + } + + void AddSuspendNodeName(const std::string &iter_node, const std::string &suspend_node) { + names_to_add_suspend_[iter_node].insert(suspend_node); + } + void AddResumeNodeName(const std::string &iter_node, const std::string &resume_node) { + names_to_add_resume_[iter_node].insert(resume_node); + } + void AddResumeNodeNameOnLeaked(const std::string &resume_node) { + names_to_add_resume_onleaked_.insert(resume_node); + } + unsigned int GetRunTimes() { return run_times_; } private: std::vector iter_nodes_; std::map> names_to_add_del_; std::map> names_to_add_repass_; + std::map> names_to_add_repass_immediate_; + std::map> names_to_add_suspend_; + std::map> names_to_add_resume_; + std::unordered_set names_to_add_resume_onleaked_; + bool dead_loop_; unsigned int run_times_; }; @@ -200,6 +278,26 @@ ComputeGraphPtr BuildGraph3() { return builder.GetGraph(); } +/// cast1--shape1 +/// / +/// data1 +/// \ +/// transdata1--shape2 +ComputeGraphPtr BuildGraph4() { + auto builder = ut::GraphBuilder("g1"); + auto data1 = builder.AddNode("data1", DATA, 0, 1); + auto cast1 = builder.AddNode("cast1", CAST, 1, 1); + auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1); + auto transdata1 = builder.AddNode("transdata1", TRANSDATA, 1, 1); + auto shape2 = builder.AddNode("shape2", SHAPE, 1, 1); + + builder.AddDataEdge(data1, 0, cast1, 0); + builder.AddDataEdge(data1, 0, transdata1, 0); + builder.AddDataEdge(cast1, 0, shape1, 0); + builder.AddDataEdge(transdata1, 0, shape2, 0); + return builder.GetGraph(); +} + void CheckIterOrder(UtestTestPass *pass, std::vector> &nodes_layers) { std::unordered_set layer_nodes; size_t layer_index = 0; @@ -509,15 +607,369 @@ ComputeGraphPtr BuildWhileGraph1() { } TEST_F(UTESTGraphPassesBasePass, while_infershape) { -NamesToPass names_to_pass; -auto test_pass = UtestTestPass(); -names_to_pass.push_back(std::make_pair("test", &test_pass)); + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + auto graph = BuildWhileGraph1(); + auto ge_pass = GEPass(graph); + auto while_node = graph->FindNode("while"); + EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // repass pre_node immediately + test_pass->AddRePassImmediateNodeName("reshape1", "add1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + + EXPECT_EQ(test_pass->GetIterNodes().size(), 9);// todo + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + 
@@ -509,15 +607,369 @@ ComputeGraphPtr BuildWhileGraph1() {
 }
 
 TEST_F(UTESTGraphPassesBasePass, while_infershape) {
-NamesToPass names_to_pass;
-auto test_pass = UtestTestPass();
-names_to_pass.push_back(std::make_pair("test", &test_pass));
+  NamesToPass names_to_pass;
+  auto test_pass = UtestTestPass();
+  names_to_pass.push_back(std::make_pair("test", &test_pass));
+
+  auto graph = BuildWhileGraph1();
+  auto ge_pass = GEPass(graph);
+  auto while_node = graph->FindNode("while");
+  EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 1);
+  EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
+}
+
+TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // repass pre_node immediately
+  test_pass->AddRePassImmediateNodeName("reshape1", "add1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 9);  // todo
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "add1", "sum1"});
+  CheckIterOrder(test_pass, layers);
+}
+
+TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // repass cur_node immediately
+  test_pass->AddRePassImmediateNodeName("reshape1", "reshape1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1"});
+  layers.push_back({"reshape1", "sum1"});
+  CheckIterOrder(test_pass, layers);
+}
+
+TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // repass next_node immediately
+  test_pass->AddRePassImmediateNodeName("reshape1", "sum1");
+  // repass node after next_node immediately
+  test_pass->AddRePassImmediateNodeName("add1", "sum1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 8);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1"});
+  CheckIterOrder(test_pass, layers);
+}
+/**
+ * A->B->C
+ * If node B suspends its pre_node A and C resumes A, it is a useless operation, so iter_order should follow the normal order.
+ * When C resumes A, A will pass again.
+ */
+TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_C_resume_A) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // add1->reshape1->sum1
+  test_pass->AddSuspendNodeName("reshape1", "add1");
+  test_pass->AddResumeNodeName("sum1", "add1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1"});
+  layers.push_back({"add1"});
+  CheckIterOrder(test_pass, layers);
+}
+
+/**
+ * A->B->C
+ * If node B suspends its pre_node A and B resumes A, it is a useless operation, so iter_order should follow the normal order.
+ * When B resumes A, A will pass again.
+ */
+TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_B_resume_A) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // add1->reshape1->sum1
+  test_pass->AddSuspendNodeName("reshape1", "add1");
+  test_pass->AddResumeNodeName("reshape1", "add1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1", "add1"});
+  CheckIterOrder(test_pass, layers);
+}
+
+/**
+ * A->B->C
+ * If node B resumes C (which is not suspended), it is a useless operation; C will not pass again.
+ */
+TEST_F(UTESTGraphPassesBasePass, B_resume_node_not_suspended) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // add1->reshape1->sum1
+  test_pass->AddResumeNodeName("reshape1", "sum1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 8);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1"});
+  CheckIterOrder(test_pass, layers);
+}
-auto graph = BuildWhileGraph1();
-auto ge_pass = GEPass(graph);
-auto while_node = graph->FindNode("while");
-EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1);
-EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
+/**
+ * A->B->C
+ * If node B suspends its pre_node A, it is a useless operation, so iter_order should follow the normal order.
+ * Because nobody resumes it, A is a leaked node, so Run returns a failure.
+ */
+TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_nobody_resume_it_return_failed) {
+  NamesToPass names_to_pass;
+  auto test_pass = UtestTestPass();
+  names_to_pass.push_back(std::make_pair("test", &test_pass));
+  // suspend pre_node immediately
+  test_pass.AddSuspendNodeName("reshape1", "add1");
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  EXPECT_EQ(ge_pass.Run(names_to_pass), INTERNAL_ERROR);
+}
+
+/**
+ * A->B->C
+ * If node B suspends its pre_node A, it is a useless operation,
+ * so iter_order should follow the normal order.
+ * A is resumed on leak, which means A will pass again.
+ */
+TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_resume_it_onleaked) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend pre_node immediately
+  test_pass->AddSuspendNodeName("reshape1", "add1");
+  test_pass->AddResumeNodeNameOnLeaked("add1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1"});
+  layers.push_back({"add1"});
+  CheckIterOrder(test_pass, layers);
 }
+
+///      cast1--shape1
+///     /
+/// data1
+///     \
+///      transdata1--shape2
+/**
+ * suspend cur node
+ * cast1 suspends itself, shape2 resumes cast1.
+ * iter order follows: data1; cast1,transdata1; shape2; cast1; shape1
+ */
+TEST_F(UTESTGraphPassesBasePass, cast1_suspend_cur_node_shape2_resume_cast1) {
+  auto graph = BuildGraph4();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend pre_node immediately
+  test_pass->AddSuspendNodeName("cast1", "cast1");
+  test_pass->AddResumeNodeName("shape2", "cast1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1"});
+  layers.push_back({"cast1", "transdata1"});
+  layers.push_back({"shape2"});
+  layers.push_back({"cast1", "shape1"});
+  CheckIterOrder(test_pass, layers);
+}
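The suspend/resume tests above all go through the UtestTestPass stub. For orientation, a minimal pass built directly on the same hooks could look roughly like the sketch below; it only assumes the BaseNodePass members already used by the stub (Run, AddNodeSuspend, AddNodeResume, OnSuspendNodesLeaked) and is an illustration, not part of this change:

// Sketch: suspend every CAST node the first time it is visited, and resume
// whatever is still suspended when the framework reports leaked nodes.
class SuspendCastOncePass : public BaseNodePass {
 public:
  Status Run(NodePtr &node) override {
    if (node->GetType() == CAST && seen_.insert(node->GetName()).second) {
      AddNodeSuspend(node);       // postpone the iteration of this node
      suspended_.push_back(node);
    }
    return SUCCESS;
  }
  Status OnSuspendNodesLeaked() override {
    for (const auto &node : suspended_) {  // nothing resumed them during the pass, do it now
      AddNodeResume(node);
    }
    suspended_.clear();
    return SUCCESS;
  }
 private:
  std::unordered_set<std::string> seen_;
  std::vector<NodePtr> suspended_;
};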
+/**
+ * suspend cur node
+ * cast1 suspends itself, then resumes cast1.
+ * iter order follows: data1; cast1,cast1,transdata1; shape2; shape1.
+ */
+TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itself_then_resume_itself) {
+  auto graph = BuildGraph4();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend pre_node immediately
+  test_pass->AddSuspendNodeName("cast1", "cast1");
+  test_pass->AddResumeNodeName("cast1", "cast1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1"});
+  layers.push_back({"cast1", "transdata1", "cast1", "shape1", "shape2"});
+  CheckIterOrder(test_pass, layers);
+}
+/**
+ * suspend cur node
+ * cast1 suspends itself, then cast1 is resumed on leak.
+ * iter order follows: data1; cast1,transdata1,shape2; cast1,shape1.
+ */
+TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itself_then_resume_onleaked) {
+  auto graph = BuildGraph4();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend pre_node immediately
+  test_pass->AddSuspendNodeName("cast1", "cast1");
+  test_pass->AddResumeNodeNameOnLeaked("cast1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1"});
+  layers.push_back({"cast1", "transdata1", "shape2"});
+  layers.push_back({"cast1", "shape1"});
+  CheckIterOrder(test_pass, layers);
+}
+/**
+ * suspend next node
+ * data1 suspends cast1, then cast1 is resumed on leak.
+ * iter order follows: data1; transdata1, shape2; cast1, shape1.
+ */
+TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_resume_cast1_onleaked) {
+  auto graph = BuildGraph4();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend pre_node immediately
+  test_pass->AddSuspendNodeName("data1", "cast1");
+  test_pass->AddResumeNodeNameOnLeaked("cast1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 5);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1"});
+  layers.push_back({"transdata1", "shape2"});
+  layers.push_back({"cast1", "shape1"});
+  CheckIterOrder(test_pass, layers);
+}
+
+/**
+ * suspend next node
+ * data1 suspends cast1, nobody resumes it.
+ * iter order follows: data1; transdata1, shape2;
+ * Run returns a failure because the node leaked.
+ */
+TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_nobody_resume) {
+  auto graph = BuildGraph4();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend pre_node immediately
+  test_pass->AddSuspendNodeName("data1", "cast1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), INTERNAL_ERROR);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 3);
+}
+
+
+TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) {
+  NamesToPass names_to_pass;
+  auto test_pass = UtestTestPass();
+  names_to_pass.push_back(std::make_pair("test", &test_pass));
+
+  // repass pre_node immediately
+  test_pass.AddRePassImmediateNodeName("reshape1", "add1");
+
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
+  EXPECT_EQ(test_pass.GetIterNodes().size(), 9);  // todo
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "add1", "sum1"});
+  CheckIterOrder(&test_pass, layers);
+}
+///                sum1
+///               /    \.
+///              /      \.
+///             /        \.
+///      reshape1        addn1
+///         |      c       |
+///       add1  <------  shape1
+///       /  \             |
+///      |    |            |
+///   data1 const1       const2
+TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) {
+  NamesToPass names_to_pass;
+  auto test_pass = UtestTestPass();
+  names_to_pass.push_back(std::make_pair("test", &test_pass));
+
+  // repass cur_node immediately
+  test_pass.AddRePassImmediateNodeName("reshape1", "reshape1");
+
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
+  EXPECT_EQ(test_pass.GetIterNodes().size(), 9);  // todo
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1"});
+  layers.push_back({"reshape1", "sum1"});
+  CheckIterOrder(&test_pass, layers);
+}
+
+TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) {
+  NamesToPass names_to_pass;
+  auto test_pass = UtestTestPass();
+  names_to_pass.push_back(std::make_pair("test", &test_pass));
+
+  // repass next_node immediately
+  test_pass.AddRePassImmediateNodeName("reshape1", "sum1");
+  // repass node after next_node immediately
+  test_pass.AddRePassImmediateNodeName("add1", "sum1");
+
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
+  EXPECT_EQ(test_pass.GetIterNodes().size(), 8);  // todo
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1"});
+  CheckIterOrder(&test_pass, layers);
+}
+/*
+TEST_F(UTESTGraphPassesBasePass, suspend_pre_node) {
+  NamesToPass names_to_pass;
+  auto test_pass = UtestTestPass();
+  names_to_pass.push_back(std::make_pair("test", &test_pass));
+
+  // repass next_node immediately
+  test_pass.AddRePassNodeName("reshape1", "sum1");
+  // repass node after next_node immediately
+  test_pass.AddRePassNodeName("add1", "sum1");
+
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
+  EXPECT_EQ(test_pass.GetIterNodes().size(), 8);  // todo
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1"});
+  CheckIterOrder(&test_pass, layers);
+}*/
 }  // namespace ge
diff --git a/tests/ut/ge/graph/passes/infershape_pass_unittest.cc b/tests/ut/ge/graph/passes/infershape_pass_unittest.cc
index 13e66c50..080894ee 100644
--- a/tests/ut/ge/graph/passes/infershape_pass_unittest.cc
+++ b/tests/ut/ge/graph/passes/infershape_pass_unittest.cc
@@ -29,13 +29,77 @@ using namespace std;
 using namespace testing;
 
 namespace ge {
+namespace {
+// do nothing stub infer_func
+const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; };
+// infer from input to output stub infer_func (input size == output size)
+const auto stub_mapping_func = [](Operator &op) {
+  size_t in_num = op.GetInputsSize();
+  for (size_t i = 0; i < in_num; ++i) {
+    auto in_desc = op.GetInputDesc(i);
+    auto out_desc = op.GetOutputDesc(i);
+    out_desc.SetShape(in_desc.GetShape());
+    out_desc.SetDataType(in_desc.GetDataType());
+    op.UpdateOutputDesc(out_desc.GetName(), out_desc);
+  }
+  return GRAPH_SUCCESS;
+};
+// merge infer_func
+
+// while infer_func
+const auto while_infer_func = [](Operator &op) {
+  size_t in_num = op.GetInputsSize();
+  size_t out_num = op.GetOutputsSize();
+  if (in_num != out_num) {
+    return GRAPH_FAILED;
+  }
+  bool need_infer_again = false;
+  for (size_t i = 0; i < in_num; ++i) {
+    auto in_desc = op.GetDynamicInputDesc("input", i);
+    auto out_desc = op.GetDynamicOutputDesc("output", i);
+    auto data_shape = in_desc.GetShape();
+    auto out_shape = out_desc.GetShape();
+    if (out_shape.GetDims() == DUMMY_SHAPE) {
+      return GRAPH_SUCCESS;
+    }
+    // check datatype between output and input
+    if (in_desc.GetDataType() != out_desc.GetDataType()) {
+      return GRAPH_FAILED;
+    }
+
+    if (data_shape.GetDims() != out_shape.GetDims()) {
+      need_infer_again = true;
+      if (data_shape.GetDimNum() != out_shape.GetDimNum()) {
+        in_desc.SetUnknownDimNumShape();
+      } else {
+        size_t data_dim_num = data_shape.GetDimNum();
+        std::vector<std::pair<int64_t, int64_t>> data_shape_range = {data_dim_num, std::make_pair(1, UNKNOWN_DIM)};
+        for (size_t j = 0; j < data_dim_num; ++j) {
+          if (data_shape.GetDim(j) != out_shape.GetDim(j)) {
+            data_shape.SetDim(j, UNKNOWN_DIM);
+          }
+          if (data_shape.GetDim(j) != UNKNOWN_DIM) {
+            data_shape_range[j] = std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j));
+          }
+        }
+        in_desc.SetShape(data_shape);
+        in_desc.SetShapeRange(data_shape_range);
+      }
+      op.UpdateDynamicOutputDesc("output", i, in_desc);
+      op.UpdateDynamicInputDesc("input", i, in_desc);
+    }
+  }
+  return need_infer_again ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS;
+};
+}
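The while_infer_func stub above mimics the shape merging a While loop needs: whenever the fed-back output shape differs from the input shape, the differing dimensions are replaced by UNKNOWN_DIM, a matching shape range is recorded, and GRAPH_NODE_NEED_REPASS is returned so the node gets inferred again. A small worked example of that merge rule (illustration only, using the same conventions as the stub; UNKNOWN_DIM is -1):

// input shape {2, 3, 4}, fed-back output shape {2, 5, 4}
//   -> dim 1 differs, so the merged shape becomes {2, -1, 4}
//   -> the shape range becomes {(2, 2), (1, -1), (4, 4)}
//   -> need_infer_again is true, so the lambda returns GRAPH_NODE_NEED_REPASS
// input shape {2, -1, 4}, fed-back output shape {2, -1, 4}
//   -> no dimension differs, need_infer_again stays false, the lambda returns GRAPH_SUCCESS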
 class UtestGraphInfershapePass : public testing::Test {
  protected:
   void SetUp() {}
   void TearDown() {}
 };
 
-static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
+static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num,
+                          std::function<graphStatus(Operator &)> infer_func = stub_func) {
   OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
   op_desc->SetStreamId(0);
   static int32_t index = 0;
@@ -61,14 +125,11 @@ static NodePtr CreateNode(ComputeGraph &graph, const string
   op_desc->SetWorkspaceBytes({});
   op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");
 
-  const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; };
-  op_desc->AddInferFunc(stub_func);
-  op_desc->AddInferFormatFunc(stub_func);
-  op_desc->AddVerifierFunc(stub_func);
-
+  op_desc->AddInferFunc(infer_func);
   return graph.AddNode(op_desc);
 }
 
+/*
 TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
   GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
   string type = "AddN";
@@ -82,6 +143,7 @@ TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
   InferShapePass infershape_pass;
   EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED);
 }
+*/
 
 TEST_F(UtestGraphInfershapePass, delete_need_infer_again) {
   auto graph = std::make_shared<ComputeGraph>("test");
@@ -94,7 +156,43 @@ TEST_F(UtestGraphInfershapePass, delete_need_infer_again) {
   infershape_pass.options_[kOptimizeAfterSubGraph] = "yes";
   EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS);
 }
+TEST_F(UtestGraphInfershapePass, infer_from_pre_to_next) {
+  /*
+   *  cast->shape
+   */
+  auto graph = std::make_shared<ComputeGraph>("test_infer_shape");
+  auto data1 = CreateNode(*graph, "dataq", DATA, 0, 1);
+  auto cast1 = CreateNode(*graph, "cast1", CAST, 1, 1, stub_mapping_func);
+  auto cast_in_desc = cast1->GetOpDesc()->MutableInputDesc(0);
+  cast_in_desc->SetShape(GeShape({1, 2, 3}));
+  cast_in_desc->SetDataType(DT_INT32);
+  auto transdata1 = CreateNode(*graph, "transdata1", TRANSDATA, 1, 1, stub_mapping_func);
+  GraphUtils::AddEdge(data1->GetOutDataAnchor(0), cast1->GetInDataAnchor(0));
+  GraphUtils::AddEdge(cast1->GetOutDataAnchor(0), transdata1->GetInDataAnchor(0));
+  // check before infer cast1
+  auto cast_before = graph->FindNode("cast1");
+  vector<int64_t> expect_cast1_shape_dim = {1, 2, 3};
+  auto real_cast1_before_shape_dim = cast_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims();
+  auto transdata1_before = graph->FindNode("transdata1");
+  vector<int64_t> expect_transdata1_shape_dim = {};
+  auto real_transdata1_before_shape_dim = transdata1_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims();
+  EXPECT_EQ(real_cast1_before_shape_dim, expect_cast1_shape_dim);
+  EXPECT_EQ(real_transdata1_before_shape_dim, expect_transdata1_shape_dim);
+  // run infershape pass
+  InferShapePass infer_shape_pass;
+  infer_shape_pass.Run(cast_before);
+  // check cast1 add transdata1 to repass_immediately
+  infer_shape_pass.GetNodesNeedRePassImmediately();
+  EXPECT_TRUE(!infer_shape_pass.GetNodesNeedRePassImmediately().empty());
+  // check transdata input_shape & datatype after infer
+  auto transdata1_after = graph->FindNode("transdata1");
+  auto transdata1_opdesc = transdata1_before->GetOpDesc();
+  auto real_transdata1_after_shape_dim = transdata1_opdesc->GetInputDesc(0).GetShape().GetDims();
+  EXPECT_EQ(real_transdata1_after_shape_dim, expect_cast1_shape_dim);
+  auto transdata1_datatype_after = transdata1_opdesc->GetInputDesc(0).GetDataType();
+  EXPECT_EQ(transdata1_datatype_after, DT_INT32);
+}
 
 TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) {
 /*******************************************************************************
  *      Exit         Identify