| @@ -19,9 +19,7 @@ | |||||
| #include <queue> | #include <queue> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "common/debug/log.h" | |||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -30,95 +28,161 @@ constexpr int kMaxRePassTimes = 10000; | |||||
| constexpr size_t kMaxOneInNodes = 1000; | constexpr size_t kMaxOneInNodes = 1000; | ||||
| // Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later | // Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later | ||||
| constexpr int kMaxRecursiveDepth = 20; | constexpr int kMaxRecursiveDepth = 20; | ||||
| struct DuringPassNodeSets { | |||||
| std::unordered_set<Node *> nodes_seen; | |||||
| std::unordered_set<NodePtr> nodes_deleted; | |||||
| std::unordered_set<NodePtr> nodes_re_pass; | |||||
| std::unordered_set<NodePtr> nodes_re_pass_immediately; | |||||
| std::unordered_set<NodePtr> nodes_last; | |||||
| std::unordered_set<NodePtr> nodes_suspend; | |||||
| std::unordered_set<NodePtr> nodes_resume; | |||||
| }; | |||||
| void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, | |||||
| std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) { | |||||
| nodes_last.clear(); | |||||
| void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, | |||||
| GEPass::GraphLevelState &g_state) { | |||||
| for (auto &node : graph->GetDirectNode()) { | for (auto &node : graph->GetDirectNode()) { | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| size_t in_nums = node->GetInNodes().size(); | size_t in_nums = node->GetInNodes().size(); | ||||
| if (in_nums == 0) { | if (in_nums == 0) { | ||||
| input_edge_nodes.push_back(node); | |||||
| nodes_seen.insert(node.get()); | |||||
| g_state.AddNodeToQueueIfNotSeen(node); | |||||
| } else if (in_nums > kMaxOneInNodes) { | } else if (in_nums > kMaxOneInNodes) { | ||||
| nodes_last.insert(node); | |||||
| g_state.nodes_last.insert(node); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| bool IsAllInNodesAlive(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_suspend) { | |||||
| return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; }); | |||||
| bool AllNodesIn(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_set) { | |||||
| return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { | |||||
| return nodes_set.count(n) == 0; | |||||
| }); | |||||
| } | |||||
| bool AnyNodesIn(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_set) { | |||||
| return std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { | |||||
| return nodes_set.count(n) > 0; | |||||
| }); | |||||
| } | } | ||||
| void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, | |||||
| DuringPassNodeSets &during_pass_node_set) { | |||||
| auto &nodes_seen = during_pass_node_set.nodes_seen; | |||||
| const auto &nodes_last = during_pass_node_set.nodes_last; | |||||
| const auto &nodes_suspend = during_pass_node_set.nodes_suspend; | |||||
| for (auto &node : nodes) { | |||||
| bool IsNodeReadyToQueue(const NodePtr &node, GEPass::GraphLevelState &g_state) { | |||||
| if (node == nullptr) { | |||||
| GELOGW("node is null"); | |||||
| return false; | |||||
| } | |||||
| if (g_state.nodes_deleted.count(node) > 0) { | |||||
| GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| if (g_state.nodes_last.count(node) != 0) { | |||||
| return false; | |||||
| } | |||||
| if (!node->IsAllInNodesSeen(g_state.nodes_seen)) { | |||||
| return false; | |||||
| } | |||||
| // The output nodes of a node are added to the queue before PassNode runs, so if an output node of the node is suspended while the node is being passed, the subsequent handling is the same as above. | |||||
| // TODO: note that this guarantee is only "best effort": if, while passing the node, a node `A` that precedes it is added to suspend, | |||||
| // then the pass of the direct and indirect successors of `A` is not affected by the suspend. | |||||
| // In theory, if the output nodes of the node were collected before passing it, and after the pass the suspended and deleted ones were removed before adding the rest to the queue, | |||||
| // no extra check would be needed here. | |||||
| if (g_state.nodes_suspend.count(node) > 0) { | |||||
| GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
| node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| if (AnyNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) { | |||||
| GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
| node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void CollectOutNodesBeforePass(const NodePtr &node, std::unordered_set<NodePtr> &out_nodes_before_pass) { | |||||
| for (const auto &out_node : node->GetOutNodes()) { | |||||
| out_nodes_before_pass.insert(out_node); | |||||
| } | |||||
| } | |||||
| // Queues the successors of `cur_node` after the passes have run on it. | |||||
| // `out_nodes_before_pass` holds the outputs of `cur_node` recorded before the pass; every output that is still | |||||
| // connected is erased from it here, so what remains afterwards was disconnected from `cur_node` by the pass. | |||||
| void AddNextIterNodes(const NodePtr &cur_node, std::unordered_set<NodePtr> &out_nodes_before_pass, | |||||
| GEPass::GraphLevelState &g_state) { | |||||
| for (auto &node : cur_node->GetOutNodes()) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (nodes_last.count(node) != 0) { | |||||
| continue; | |||||
| if (out_nodes_before_pass.erase(node) == 0) { | |||||
| // The node was not an output before the pass, so it is a new output created by the pass. | |||||
| GELOGD("New output nodes %s come up after pass %s.", node->GetName().c_str(), cur_node->GetName().c_str()); | |||||
| } | |||||
| if (IsNodeReadyToQueue(node, g_state)) { | |||||
| g_state.AddNodeToQueueIfNotSeen(node); | |||||
| } | } | ||||
| if (nodes_suspend.count(node) > 0) { | |||||
| GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str()); | |||||
| } | |||||
| // A-->B-->C | |||||
| // \ | |||||
| // D--->E | |||||
| // If B has been deleted after the pass, two cases need to be considered: | |||||
| // 1. A & C & E have been added to repass by B. Good choice. | |||||
| // 2. A & C & E are not added to repass: C will not be passed because no one triggers it, | |||||
| // while E will be passed because D will trigger it. | |||||
| // So here we need to add the nodes which have no input node to the queue. | |||||
| for (const auto &node : out_nodes_before_pass) { | |||||
| if (!node->GetInAllNodes().empty()) { | |||||
| // Both %s placeholders need an argument; the original call passed none. | |||||
| GELOGD("Node %s used to be output of node %s, but after pass it doesnt. " | |||||
| "It may triggered by other node, so no need add to queue now.", | |||||
| node->GetName().c_str(), cur_node->GetName().c_str()); | |||||
| continue; | |||||
| } | } | ||||
| bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend); | |||||
| bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); | |||||
| if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) { | |||||
| nodes_to_pass.push_back(node); | |||||
| if (IsNodeReadyToQueue(node, g_state)) { | |||||
| // unlink edge may happen, add these node to queue otherwise they can not pass | |||||
| GELOGI("Node %s may lost from cur node %s, add to queue if not seen.", | |||||
| node->GetName().c_str(), cur_node->GetName().c_str()); | |||||
| g_state.AddNodeToQueueIfNotSeen(node); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) { | |||||
| for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { | |||||
| GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); | |||||
| nodes.push_front(node); | |||||
| // Pushes the nodes that passes asked to re-pass immediately to the front of the queue. | |||||
| // `re_pass_imm_nodes_to_pass_names` maps each requested node to the names of the passes that requested it | |||||
| // (used for logging only). A node is queued only if it is already in `g_state.nodes_passed`. | |||||
| void AddImmediateRepassNodesToQueue(NodePtr &cur_node, | |||||
| const std::unordered_map<NodePtr, std::string> &re_pass_imm_nodes_to_pass_names, | |||||
| GEPass::GraphLevelState &g_state) { | |||||
| for (const auto &node_2_pass_names : re_pass_imm_nodes_to_pass_names) { | |||||
| const auto &repass_imm_node = node_2_pass_names.first; | |||||
| if (repass_imm_node == nullptr) { | |||||
| GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s", | |||||
| node_2_pass_names.second.c_str(), | |||||
| cur_node->GetName().c_str(), cur_node->GetType().c_str()); | |||||
| continue; | |||||
| } | |||||
| if (g_state.nodes_passed.count(repass_imm_node) > 0) { | |||||
| GELOGD("The node %s specified by pass %s has been passed, it will repass immediately", | |||||
| repass_imm_node->GetName().c_str(), node_2_pass_names.second.c_str()); | |||||
| g_state.AddNodeToQueueFront(repass_imm_node); | |||||
| continue; | |||||
| } | |||||
| GELOGW("The node %s specified by pass %s has un-passed in_nodes, it will not repass immediately", | |||||
| repass_imm_node->GetName().c_str(), node_2_pass_names.second.c_str()); | |||||
| } | } | ||||
| during_pass_node_set.nodes_re_pass_immediately.clear(); | |||||
| } | } | ||||
| void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) { | |||||
| for (auto &node : during_pass_node_set.nodes_resume) { | |||||
| const auto &it = during_pass_node_set.nodes_suspend.find(node); | |||||
| if (it != during_pass_node_set.nodes_suspend.end()) { | |||||
| during_pass_node_set.nodes_suspend.erase(node); | |||||
| GELOGD("The node %s resumed by pass.", node->GetName().c_str()); | |||||
| nodes.push_back(node); | |||||
| } else { | |||||
| GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str()); | |||||
| void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) { | |||||
| for (auto &node : g_state.nodes_last) { | |||||
| // todo: why would a node from nodes_last already show up in nodes_seen? Run git blame and check the merge history | |||||
| if (node->IsAllInNodesSeen(g_state.nodes_seen)) { | |||||
| g_state.AddNodeToQueueIfNotSeen(node); | |||||
| } | } | ||||
| } | } | ||||
| during_pass_node_set.nodes_resume.clear(); | |||||
| g_state.nodes_last.clear(); | |||||
| } | } | ||||
| void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, | |||||
| const std::unordered_set<NodePtr> &nodes_suspend, | |||||
| const std::unordered_set<NodePtr> &nodes_resume) { | |||||
| for (const auto &node : nodes_suspend) { | |||||
| GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str()); | |||||
| during_pass_node_set.nodes_suspend.emplace(node); | |||||
| } | |||||
| for (const auto &node : nodes_resume) { | |||||
| GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str()); | |||||
| during_pass_node_set.nodes_resume.emplace(node); | |||||
| // Removes every node in `resume_nodes_to_pass_names` from the suspend set and, when the node was actually | |||||
| // suspended and is either already seen or has all of its input nodes seen, appends it to the back of the queue. | |||||
| // The mapped string holds the names of the passes that resumed the node (used for logging only). | |||||
| void AddResumeNodesToQueue(const std::unordered_map<NodePtr, std::string> &resume_nodes_to_pass_names, | |||||
| GEPass::GraphLevelState &g_state) { | |||||
| // Currently we dont keep the order of suspend nodes and resume nodes, so its hard to know | |||||
| // which one comes first. Simple way : if a node both have suspend & resume state, we will resume it. | |||||
| // Better way: keep the order when suspend/resume a node, and in this func suspend/resume in order. | |||||
| for (const auto &node_2_pass_names : resume_nodes_to_pass_names) { | |||||
| const auto &node = node_2_pass_names.first; | |||||
| if (g_state.nodes_suspend.erase(node) > 0) { | |||||
| if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) { | |||||
| g_state.nodes.push_back(node); | |||||
| GELOGD("Node %s has been resumed by pass %s, add to queue.", | |||||
| node->GetName().c_str(), node_2_pass_names.second.c_str()); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -140,54 +204,6 @@ void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass | |||||
| } | } | ||||
| } | } | ||||
| Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNodeSets &during_pass_node_set) { | |||||
| if (node == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); | |||||
| GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("Begin to run pass for node %s", node->GetName().c_str()); | |||||
| for (const auto &name_to_pass : names_to_passes) { | |||||
| if (name_to_pass.second == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); | |||||
| continue; | |||||
| } | |||||
| GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str()); | |||||
| name_to_pass.second->init(); | |||||
| auto result = name_to_pass.second->Run(node); | |||||
| if (result != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", | |||||
| name_to_pass.first.c_str(), node->GetName().c_str(), result); | |||||
| GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result " | |||||
| "%u, the passes will be terminated immediately.", | |||||
| name_to_pass.first.c_str(), node->GetName().c_str(), result); | |||||
| return result; | |||||
| } | |||||
| const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); | |||||
| PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, | |||||
| during_pass_node_set.nodes_re_pass); | |||||
| const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); | |||||
| PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, | |||||
| during_pass_node_set.nodes_re_pass_immediately); | |||||
| PushToSuspendNodes(during_pass_node_set, name_to_pass.first, | |||||
| name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume()); | |||||
| const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); | |||||
| during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); | |||||
| if (nodes_deleted_by_pass.count(node) > 0) { | |||||
| GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), | |||||
| name_to_pass.first.c_str()); | |||||
| break; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) { | void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) { | ||||
| for (auto &name_to_pass : names_to_pass) { | for (auto &name_to_pass : names_to_pass) { | ||||
| name_to_pass.second->SetOption(option, ""); | name_to_pass.second->SetOption(option, ""); | ||||
| @@ -199,27 +215,10 @@ void ClearOption(NamesToPass names_to_pass) { | |||||
| name_to_pass.second->ClearOptions(); | name_to_pass.second->ClearOptions(); | ||||
| } | } | ||||
| } | } | ||||
| bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) { | |||||
| if (node == nullptr) { | |||||
| GELOGW("node is null"); | |||||
| return false; | |||||
| } | |||||
| if (during_pass_node_set.nodes_deleted.count(node) > 0) { | |||||
| GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| if (during_pass_node_set.nodes_suspend.count(node) > 0) { | |||||
| GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
| node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map) { | |||||
| Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map, | |||||
| bool is_repass_io_immediately) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); | REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); | ||||
| GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); | GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); | ||||
| @@ -235,7 +234,7 @@ Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| AddRePassNodesWithInOut(node); | |||||
| is_repass_io_immediately ? AddImmediateRePassNodesWithInOut(node) : AddRePassNodesWithInOut(node); | |||||
| if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) { | if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str()); | REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str()); | ||||
| @@ -263,6 +262,12 @@ Status GEPass::Run(const NamesToPass &names_to_passes) { | |||||
| GELOGW("No passes input, the GEPass will do nothing"); | GELOGW("No passes input, the GEPass will do nothing"); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| for (const auto &name_to_pass : names_to_passes) { | |||||
| if (name_to_pass.second == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s)", name_to_pass.first.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | |||||
| if (depth_ > kMaxRecursiveDepth) { | if (depth_ > kMaxRecursiveDepth) { | ||||
| GELOGE(PARAM_INVALID, | GELOGE(PARAM_INVALID, | ||||
| @@ -275,81 +280,101 @@ Status GEPass::Run(const NamesToPass &names_to_passes) { | |||||
| return RunPassesOneGraph(names_to_passes); | return RunPassesOneGraph(names_to_passes); | ||||
| } | } | ||||
| void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names_to_pass) { | |||||
| for (auto &name_to_pass : names_to_pass) { | |||||
| name_to_pass.second->OnStartPassGraph(graph); | |||||
| } | |||||
| } | |||||
| Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) { | |||||
| std::unordered_map<NodePtr, std::string> resume_nodes_to_pass_names; | |||||
| for (auto &name_to_pass : names_to_passes) { | |||||
| name_to_pass.second->init(); | |||||
| auto ret = name_to_pass.second->OnSuspendNodesLeaked(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Internal Error happened when pass %s handle on suspend nodes leaked.", | |||||
| name_to_pass.first.c_str()); | |||||
| return ret; | |||||
| } | |||||
| for (const auto &resume_node : name_to_pass.second->GetNodesResume()){ | |||||
| resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ","); | |||||
| } | |||||
| } | |||||
| AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state); | |||||
| return SUCCESS; | |||||
| } | |||||
| // Runs all passes over graph_. Start nodes are the nodes without input nodes; the re-pass loop is repeated | |||||
| // for as long as suspended nodes remain, giving each pass a chance to resume them via OnSuspendNodesLeaked. | |||||
| Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | ||||
| GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); | GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); | ||||
| std::deque<NodePtr> nodes; | |||||
| DuringPassNodeSets during_pass_node_set; | |||||
| GetAllNodesNoInputEdge(graph_, nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); | |||||
| GELOGD("Start points count %zu", nodes.size()); | |||||
| int re_pass_times = 0; | |||||
| NotifyPassGraphStart(graph_, names_to_passes); | |||||
| GraphLevelState g_state; | |||||
| g_state.re_pass_times = 0; | |||||
| GetAllNodesNoInputEdge(graph_, g_state); | |||||
| GELOGD("Start points count %zu", g_state.nodes.size()); | |||||
| do { | do { | ||||
| for (auto &node : during_pass_node_set.nodes_re_pass) { | |||||
| nodes.push_back(node); | |||||
| during_pass_node_set.nodes_seen.insert(node.get()); | |||||
| if (!g_state.nodes_suspend.empty()) { | |||||
| auto ret = HandleLeakedSuspendNodes(names_to_passes, g_state); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Failed to handle leaked suspend nodes, break base pass."); | |||||
| return ret; | |||||
| } | |||||
| if (g_state.nodes.empty()) { | |||||
| // There are suspend nodes leaked, but no pass resume it | |||||
| GELOGE(INTERNAL_ERROR, "There are suspend nodes but no pass resume, which means " | |||||
| "some nodes in this graph never pass."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | |||||
| auto ret = RunPassesGraphRepass(names_to_passes, g_state); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | } | ||||
| during_pass_node_set.nodes_re_pass.clear(); | |||||
| } while (!g_state.nodes_suspend.empty()); | |||||
| while (!nodes.empty()) { | |||||
| NodePtr node = nodes.front(); | |||||
| nodes.pop_front(); | |||||
| return SUCCESS; | |||||
| } | |||||
| (void)during_pass_node_set.nodes_re_pass.erase(node); | |||||
| if (!CheckNode(node, during_pass_node_set)) { | |||||
| continue; | |||||
| } | |||||
| AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); | |||||
| auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", | |||||
| node->GetName().c_str(), node->GetType().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| // Drains the node queue in g_state, running every pass on each popped node, and repeats while passes keep | |||||
| // requesting re-passes or the queue is refilled, up to kMaxRePassTimes rounds. | |||||
| Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state) { | |||||
| RepassLevelState rp_state; | |||||
| do { | |||||
| for (auto &node : rp_state.nodes_re_pass) { | |||||
| GELOGD("Add node %s to queue for re-pass.", node->GetName().c_str()); | |||||
| g_state.AddNodeToQueue(node); | |||||
| } | |||||
| rp_state.nodes_re_pass.clear(); | |||||
| bool has_sub_graph = false; | |||||
| ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| while (!g_state.nodes.empty()) { | |||||
| auto node = g_state.PopFront(); | |||||
| if (has_sub_graph) { | |||||
| GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); | |||||
| SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); | |||||
| ret = RunPasses(node, names_to_passes, during_pass_node_set); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", | |||||
| node->GetName().c_str(), node->GetType().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| // There is only one option scene, so set and clear options around the `RunPasses` func. | |||||
| // if there are more than one scene to set options, the `ClearOption` function | |||||
| // should be called each time at the begin of the iteration | |||||
| ClearOption(names_to_passes); | |||||
| if (g_state.nodes_deleted.count(node) > 0) { | |||||
| GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | |||||
| // Actually skip the deleted node, as the log states; drop it from the re-pass set so it is not re-queued. | |||||
| (void)rp_state.nodes_re_pass.erase(node); | |||||
| continue; | |||||
| } | } | ||||
| (void)rp_state.nodes_re_pass.erase(node);// todo why | |||||
| g_state.nodes_seen.insert(node.get()); // todo: why is the node marked as seen here | |||||
| AddRepassNodes(during_pass_node_set, nodes); | |||||
| AddResumeNodes(during_pass_node_set, nodes); | |||||
| } | |||||
| std::unordered_set<NodePtr> out_nodes_before_pass; | |||||
| CollectOutNodesBeforePass(node, out_nodes_before_pass); | |||||
| for (auto &node : during_pass_node_set.nodes_last) { | |||||
| bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen); | |||||
| if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) { | |||||
| nodes.push_back(node); | |||||
| auto ret = RunPassesNodeOnce(node, names_to_passes, g_state, rp_state); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(), | |||||
| node->GetType().c_str(), ret); | |||||
| return ret; | |||||
| } | } | ||||
| AddNextIterNodes(node, out_nodes_before_pass, g_state); | |||||
| } | } | ||||
| during_pass_node_set.nodes_last.clear(); | |||||
| } while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); | |||||
| AddLastNodesToQueue(g_state); | |||||
| } while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes); | |||||
| if (re_pass_times == kMaxRePassTimes) { | |||||
| if (g_state.re_pass_times == kMaxRePassTimes) { | |||||
| GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); | GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); | ||||
| } | } | ||||
| GELOGD("All passes runs end"); | GELOGD("All passes runs end"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { | Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { | ||||
| auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); | auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); | ||||
| has_sub_graph = false; | has_sub_graph = false; | ||||
| @@ -371,4 +396,95 @@ Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GEPass::RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes, | |||||
| GraphLevelState &g_state, RepassLevelState &rp_state) { | |||||
| auto ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(), | |||||
| node->GetType().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| bool has_sub_graph = false; | |||||
| ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| if (has_sub_graph) { | |||||
| GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); | |||||
| SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); | |||||
| ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", node->GetName().c_str(), | |||||
| node->GetType().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| // There is only one option scene, so set and clear options around the `RunPasses` func. | |||||
| // if there are more than one scene to set options, the `ClearOption` function | |||||
| // should be called each time at the begin of the iteration | |||||
| ClearOption(names_to_passes); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state, | |||||
| RepassLevelState &rp_state) { | |||||
| if (node == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); | |||||
| GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("Begin to run pass for node %s", node->GetName().c_str()); | |||||
| for (const auto &name_to_pass : names_to_passes) { | |||||
| GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str()); | |||||
| name_to_pass.second->init(); | |||||
| auto result = name_to_pass.second->Run(node); | |||||
| if (result != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", name_to_pass.first.c_str(), | |||||
| node->GetName().c_str(), result); | |||||
| GELOGE(INTERNAL_ERROR, | |||||
| "[Process][Pass] %s on node %s failed, result " | |||||
| "%u, the passes will be terminated immediately.", | |||||
| name_to_pass.first.c_str(), node->GetName().c_str(), result); | |||||
| return result; | |||||
| } | |||||
| if (name_to_pass.second->GetNodesDeleted().count(node) > 0) { | |||||
| GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), | |||||
| name_to_pass.first.c_str()); | |||||
| break; | |||||
| } | |||||
| } | |||||
| g_state.nodes_passed.insert(node); | |||||
| std::unordered_map<NodePtr, std::string> repass_imm_nodes_to_pass_names; | |||||
| std::unordered_map<NodePtr, std::string> resume_nodes_to_pass_names; | |||||
| // if multi pass add one node to repass immediately, here need to remove duplication | |||||
| for (const auto &name_to_pass : names_to_passes) { | |||||
| PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen, name_to_pass.second->GetNodesNeedRePass(), | |||||
| rp_state.nodes_re_pass); | |||||
| // collect imm_node && resume_node among these passes | |||||
| for (const auto &imm_node : name_to_pass.second->GetNodesNeedRePassImmediately()) { | |||||
| repass_imm_nodes_to_pass_names[imm_node].append(name_to_pass.first + ","); | |||||
| } | |||||
| for (const auto &resume_node : name_to_pass.second->GetNodesResume()) { | |||||
| resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ","); | |||||
| } | |||||
| for (const auto &suspend_node : name_to_pass.second->GetNodesSuspend()) { | |||||
| GELOGD("The iteration suspend of node %s has been set by pass %s", suspend_node->GetName().c_str(), | |||||
| name_to_pass.first.c_str()); | |||||
| g_state.nodes_suspend.insert(suspend_node); | |||||
| } | |||||
| const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); | |||||
| g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); | |||||
| } | |||||
| AddImmediateRepassNodesToQueue(node, repass_imm_nodes_to_pass_names, g_state); | |||||
| AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -22,7 +22,6 @@ | |||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "framework/common/types.h" | #include "framework/common/types.h" | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| @@ -61,23 +60,32 @@ class BaseNodePass { | |||||
| const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; } | const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; } | ||||
| virtual Status OnSuspendNodesLeaked() { return SUCCESS; } | |||||
| void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } | void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } | ||||
| void ClearOptions() { options_.clear(); } | void ClearOptions() { options_.clear(); } | ||||
| void init() { | void init() { | ||||
| nodes_need_re_pass_.clear(); | nodes_need_re_pass_.clear(); | ||||
| nodes_deleted_.clear(); | |||||
| nodes_need_re_pass_immediately_.clear(); | nodes_need_re_pass_immediately_.clear(); | ||||
| nodes_deleted_.clear(); | |||||
| nodes_suspend_.clear(); | nodes_suspend_.clear(); | ||||
| nodes_resume_.clear(); | nodes_resume_.clear(); | ||||
| } | } | ||||
| virtual void OnStartPassGraph(const ComputeGraphPtr &graph) { | |||||
| current_graph_name_ = graph->GetName(); | |||||
| } | |||||
| protected: | protected: | ||||
| Status IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map); | |||||
| const string &GetCurrentGraphName() const { | |||||
| return current_graph_name_; | |||||
| } | |||||
| Status IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map, bool is_repass_io_immediately = false); | |||||
| Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list<int> &io_map) { | |||||
| return IsolateAndDeleteNode(node, std::vector<int>(io_map)); | |||||
| Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list<int> &io_map, bool is_repass_io_immediately = false) { | |||||
| return IsolateAndDeleteNode(node, std::vector<int>(io_map), is_repass_io_immediately); | |||||
| } | } | ||||
| /// | /// | ||||
| @@ -112,6 +120,22 @@ class BaseNodePass { | |||||
| } | } | ||||
| } | } | ||||
| /// | |||||
| /// Add a node and its input/output data nodes to be optimized immediately again. | |||||
| /// @param node | |||||
| /// | |||||
| void AddImmediateRePassNodesWithInOut(const NodePtr &node) { | |||||
| auto in_nodes = node->GetInNodes(); | |||||
| for (auto &in_node : in_nodes) { | |||||
| AddImmediateRePassNode(in_node); | |||||
| } | |||||
| AddImmediateRePassNode(node); | |||||
| auto out_nodes = node->GetOutNodes(); | |||||
| for (auto &out_node : out_nodes) { | |||||
| AddImmediateRePassNode(out_node); | |||||
| } | |||||
| } | |||||
| /// | /// | ||||
| /// If you deleted a node from the graph, especially current node. The remain | /// If you deleted a node from the graph, especially current node. The remain | ||||
| /// iterate passes will continue process on the deleted node(if it can be | /// iterate passes will continue process on the deleted node(if it can be | ||||
| @@ -123,23 +147,15 @@ class BaseNodePass { | |||||
| void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } | void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } | ||||
| /// | /// | ||||
| /// If you suspend a node from the graph, especially following node. The remain | |||||
| /// iterate passes will stop process on the suspend node(if it can be | |||||
| /// If you postpone a node in the graph, especially a following node, the remaining | |||||
| /// iterative passes will stop processing the postponed node (if it can be | |||||
| /// reached by edge connections) till the last one. Obviously it is a waste of | /// reached by edge connections) till the last one. Obviously it is a waste of | ||||
| /// time. You can add the suspend nodes by calling this function, to stop the | |||||
| /// time. You can add the postponed nodes by calling this function, to stop the | |||||
| /// next iterations. | /// next iterations. | ||||
| /// @param node | /// @param node | ||||
| /// | /// | ||||
| void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } | void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } | ||||
| /// | |||||
| /// If you resume a node from the graph, especially following node. The remain | |||||
| /// iterate passes will continue process on the resume node(if it can be | |||||
| /// reached by edge connections) till the last one. | |||||
| /// You can add the resume nodes by calling this function, to resume the | |||||
| /// next iterations. | |||||
| /// @param node | |||||
| /// | |||||
| void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } | void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } | ||||
| bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } | bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } | ||||
| @@ -151,6 +167,7 @@ class BaseNodePass { | |||||
| std::unordered_set<NodePtr> nodes_suspend_; | std::unordered_set<NodePtr> nodes_suspend_; | ||||
| std::unordered_set<NodePtr> nodes_resume_; | std::unordered_set<NodePtr> nodes_resume_; | ||||
| std::map<NodePassOption, std::string> options_; | std::map<NodePassOption, std::string> options_; | ||||
| std::string current_graph_name_; | |||||
| }; | }; | ||||
| using NamesToPass = std::vector<std::pair<std::string, BaseNodePass *>>; | using NamesToPass = std::vector<std::pair<std::string, BaseNodePass *>>; | ||||
| @@ -160,12 +177,60 @@ class GEPass { | |||||
| explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {} | explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {} | ||||
| virtual ~GEPass() = default; | virtual ~GEPass() = default; | ||||
| Status Run(const NamesToPass &names_to_passes); | Status Run(const NamesToPass &names_to_passes); | ||||
| /* | |||||
| * todo | |||||
| * OneGraph: nodes_deleted, nodes_seen, nodes_passed, nodes_suspended | |||||
| * RePass: nodes_re_pass | |||||
| * GraphOneTime: nodes_last | |||||
| * NodeOneTime: nodes_re_pass_immediately, nodes_resume | |||||
| */ | |||||
| struct GraphLevelState { | |||||
| std::unordered_set<NodePtr> nodes_deleted; | |||||
| std::unordered_set<Node *> nodes_seen; | |||||
| std::unordered_set<NodePtr> nodes_passed; | |||||
| std::unordered_set<NodePtr> nodes_suspend; | |||||
| std::unordered_set<NodePtr> nodes_last; | |||||
| std::deque<NodePtr> nodes; | |||||
| int re_pass_times; | |||||
| void AddNodeToQueueFront(NodePtr node) { | |||||
| nodes_seen.insert(node.get()); | |||||
| nodes.emplace_front(std::move(node)); | |||||
| } | |||||
| void AddNodeToQueue(NodePtr node) { | |||||
| nodes_seen.insert(node.get()); | |||||
| nodes.emplace_back(std::move(node)); | |||||
| } | |||||
| void AddNodeToQueueIfNotSeen(NodePtr node) { | |||||
| if (nodes_seen.insert(node.get()).second) { | |||||
| nodes.emplace_back(std::move(node)); | |||||
| } | |||||
| } | |||||
| NodePtr PopFront() { | |||||
| NodePtr node = nodes.front(); | |||||
| nodes.pop_front(); | |||||
| return node; | |||||
| } | |||||
| }; | |||||
| struct RepassLevelState { | |||||
| std::unordered_set<NodePtr> nodes_re_pass; | |||||
| }; | |||||
| struct GraphOneTimeLevelState { | |||||
| std::unordered_set<NodePtr> nodes_last; | |||||
| }; | |||||
| private: | private: | ||||
| GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth) | GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth) | ||||
| : graph_(graph), root_graph_(root_graph), depth_(depth) {} | : graph_(graph), root_graph_(root_graph), depth_(depth) {} | ||||
| Status RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes, | |||||
| GraphLevelState &g_state, RepassLevelState &rp_state); | |||||
| Status RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state); | |||||
| Status RunPassesOneGraph(const NamesToPass &names_to_passes); | Status RunPassesOneGraph(const NamesToPass &names_to_passes); | ||||
| Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph); | Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph); | ||||
| Status RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state, | |||||
| RepassLevelState &rp_state); | |||||
| Status HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state); | |||||
| ComputeGraphPtr graph_; | ComputeGraphPtr graph_; | ||||
| ComputeGraphPtr root_graph_; | ComputeGraphPtr root_graph_; | ||||
| int depth_; | int depth_; | ||||
| @@ -84,8 +84,11 @@ Status InferBasePass::Run(NodePtr &node) { | |||||
| bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; } | bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; } | ||||
| void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) { | void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) { | ||||
| // A passed_nodes set is needed to solve the problem of multi-input operators being re-passed in advance. | |||||
| // When there is a passed_nodes set, we should call AddImmediateRePassNode for all nodes in changed_nodes. | |||||
| // A passed_nodes set is needed to solve the problem of multi-input operators being re-passed in advance. | |||||
| // When there is a passed_nodes set, we should call AddImmediateRePassNode for all nodes in changed_nodes. | |||||
| for (const auto &node : changed_nodes) { | |||||
| AddImmediateRePassNode(node); | |||||
| } | |||||
| } | } | ||||
| graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) { | graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) { | ||||
| @@ -1,6 +1,6 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| * You may obtain a copy of the License at | * You may obtain a copy of the License at | ||||
| @@ -22,13 +22,16 @@ | |||||
| #include "graph/shape_refiner.h" | #include "graph/shape_refiner.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "graph/common/omg_util.h" | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| namespace ge { | |||||
| #include "external/graph/operator_factory.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| constexpr int kSwitchExitAnchorIndex = 0; | |||||
| constexpr int kSwitchPredAnchorIndex = 1; | |||||
| void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { | void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { | ||||
| desc_str += "["; | desc_str += "["; | ||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | std::vector<std::pair<int64_t, int64_t>> shape_range; | ||||
| @@ -47,129 +50,302 @@ void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { | |||||
| desc_str += "},"; | desc_str += "},"; | ||||
| } | } | ||||
| } | } | ||||
| void UpdateShapeAndDType(const GeTensorDescPtr &src, GeTensorDescPtr &dst) { | |||||
| dst->SetOriginShape(src->GetOriginShape()); | |||||
| dst->SetShape(src->GetShape()); | |||||
| dst->SetDataType(src->GetDataType()); | |||||
| dst->SetOriginDataType(src->GetOriginDataType()); | |||||
| vector<pair<int64_t, int64_t>> src_shape_range; | |||||
| src->GetShapeRange(src_shape_range); | |||||
| dst->SetShapeRange(src_shape_range); | |||||
| dst->SetOriginShapeRange(src_shape_range); | |||||
| ge::TensorUtils::SetRealDimCnt(*dst, static_cast<uint32_t>(src->GetShape().GetDims().size())); | |||||
| } | |||||
| } // namespace | |||||
| std::string GetInTensorInfoWithString(const ge::NodePtr &node) { | |||||
| ge::OpDescPtr op_desc = node->GetOpDesc(); | |||||
| std::string InferShapePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const { | |||||
| std::stringstream ss; | std::stringstream ss; | ||||
| ss << "{"; | |||||
| int32_t in_idx = 0; | |||||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||||
| if (input_desc == nullptr) { | |||||
| in_idx++; | |||||
| ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),"; | |||||
| ss << "(format:" << TypeUtils::FormatToSerialString(tensor_desc->GetFormat()) << "),"; | |||||
| ss << "(dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()) << "),"; | |||||
| ss << "(origin_shape:" << tensor_desc->GetOriginShape().ToString() << "),"; | |||||
| ss << "(origin_format:" << TypeUtils::FormatToSerialString(tensor_desc->GetOriginFormat()) << "),"; | |||||
| ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetOriginDataType()) << "),"; | |||||
| string range_str; | |||||
| SerialShapeRange(tensor_desc, range_str); | |||||
| ss << "(shape_range:" << range_str << ")"; | |||||
| return ss.str(); | |||||
| } | |||||
| Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) { | |||||
| if (node->GetType() != SWITCH) { | |||||
| return SUCCESS; | |||||
| } | |||||
| auto pred_node = NodeUtils::GetInDataNodeByIndex(*node, kSwitchPredAnchorIndex); | |||||
| GE_CHECK_NOTNULL(pred_node); | |||||
| if (pred_node->GetType() != LOOPCOND) { | |||||
| return SUCCESS; | |||||
| } | |||||
| for (const auto &anchor_2_node : NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kSwitchExitAnchorIndex)) { | |||||
| GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", anchor_2_node.second->GetName().c_str(), | |||||
| anchor_2_node.second->GetType().c_str()); | |||||
| auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName()); | |||||
| if (iter == graphs_2_suspend_nodes_.end()) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (in_idx > 0) { | |||||
| ss << " "; | |||||
| auto &suspend_nodes = graphs_2_suspend_nodes_[GetCurrentGraphName()]; | |||||
| if (suspend_nodes.nodes_set.insert(anchor_2_node.second).second) { | |||||
| suspend_nodes.nodes.push(anchor_2_node.second); | |||||
| AddNodeSuspend(anchor_2_node.second); | |||||
| } | } | ||||
| ss << "input_" << in_idx << " " << "tensor: ["; | |||||
| ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; | |||||
| ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; | |||||
| ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; | |||||
| ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),"; | |||||
| ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),"; | |||||
| ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),"; | |||||
| string range_str; | |||||
| SerialShapeRange(input_desc, range_str); | |||||
| ss << "(shape_range:" << range_str << ")]"; | |||||
| in_idx++; | |||||
| } | } | ||||
| return ss.str(); | |||||
| return SUCCESS; | |||||
| } | } | ||||
| Status InferShapePass::Run(NodePtr &node) { | |||||
| // kOptimizeAfterSubGraph exist means after subgraph | |||||
| auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| // select INFERSHAPE failed info | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| auto root_graph = ge::GraphUtils::FindRootGraph(graph); | |||||
| GE_CHECK_NOTNULL(root_graph); | |||||
| analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), | |||||
| analyzer::INFER_SHAPE, node, "InferShapeFailed!"}; | |||||
| (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); | |||||
| (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), | |||||
| root_graph->GetGraphID()); | |||||
| REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s", | |||||
| node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); | |||||
| GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s", | |||||
| node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); | |||||
| return GE_GRAPH_INFERSHAPE_FAILED; | |||||
| } | |||||
| GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node)); | |||||
| bool need_repass = false; | |||||
| auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass); | |||||
| if (has_attr) { | |||||
| if (!OptionExists(kOptimizeAfterSubGraph)) { | |||||
| return SUCCESS; | |||||
| Status InferShapePass::Infer(NodePtr &node) { | |||||
| auto ret = SuspendV1LoopExitNodes(node); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Failed to suspend exit node in v1 control flow loop."); | |||||
| return ret; | |||||
| } | |||||
| bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||||
| auto opdesc = node->GetOpDesc(); | |||||
| if (node->Verify() != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node); | |||||
| if (!is_unknown_graph) { | |||||
| auto inference_context = ShapeRefiner::CreateInferenceContext(node); | |||||
| GE_CHECK_NOTNULL(inference_context); | |||||
| vector<AscendString> marks; | |||||
| inference_context->GetMarks(marks); | |||||
| GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), marks.size()); | |||||
| op.SetInferenceContext(inference_context); | |||||
| } | |||||
| graphStatus status = CallInferShapeFunc(node, op); | |||||
| if (status != GRAPH_NODE_NEED_REPASS && status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) { | |||||
| // Nodes like netoutput return GRAPH_PARAM_INVALID, which is tolerated here - TODO: confirm that this is really valid. | |||||
| REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| UpdateCurNodeOutputDesc(node); | |||||
| if (!is_unknown_graph) { | |||||
| auto ctx_after_infer = op.GetInferenceContext(); | |||||
| if (ctx_after_infer != nullptr) { | |||||
| vector<AscendString> marks; | |||||
| ctx_after_infer->GetMarks(marks); | |||||
| GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), marks.size()); | |||||
| if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !marks.empty()) { | |||||
| GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), | |||||
| marks.size()); | |||||
| ShapeRefiner::PushToContextMap(node, ctx_after_infer); | |||||
| } | |||||
| } | } | ||||
| if (need_repass) { | |||||
| AddImmediateRePassNode(node); | |||||
| GELOGD("Node %s need repass immediately.", node->GetName().c_str()); | |||||
| } else { | |||||
| // clear attr on while | |||||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||||
| } | |||||
| return (status == GRAPH_NODE_NEED_REPASS) ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS; | |||||
| } | |||||
| bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) { | |||||
| // check shape range | |||||
| vector<std::pair<int64_t, int64_t>> src_shape_range; | |||||
| vector<std::pair<int64_t, int64_t>> dst_shape_range; | |||||
| src->GetShapeRange(src_shape_range); | |||||
| dst->GetShapeRange(dst_shape_range); | |||||
| if (src_shape_range.size() != dst_shape_range.size()) { | |||||
| GELOGI("Src shape range size is %zu, dst shape range size is %zu, not same.", src_shape_range.size(), | |||||
| dst_shape_range.size()); | |||||
| return false; | |||||
| } | |||||
| for (size_t i = 0; i < src_shape_range.size(); ++i) { | |||||
| if (src_shape_range[i].first != dst_shape_range[i].first || | |||||
| src_shape_range[i].second != dst_shape_range[i].second) { | |||||
| GELOGI("Current dim %zu. Src shape range is [%lu-%lu], dst shape range is [%lu-%lu], not same.", | |||||
| i, src_shape_range[i].first, src_shape_range[i].second, dst_shape_range[i].first, dst_shape_range[i].second); | |||||
| return false; | |||||
| } | } | ||||
| } | } | ||||
| // check shape | |||||
| auto src_shape = src->GetShape(); | |||||
| auto dst_shape = dst->GetShape(); | |||||
| if (src_shape.GetDims() != dst_shape.GetDims() || src->GetOriginShape().GetDims() != dst->GetOriginShape().GetDims() || | |||||
| src->GetDataType() != dst->GetDataType() || src->GetOriginDataType() != dst->GetOriginDataType()) { | |||||
| GELOGD( | |||||
| "Src shape is %s, origin_shape is %s, data_type is %s, origin data_type is %s; " | |||||
| "Dst shape is %s, origin_shape is %s, data_type is %s, original data_type is %s, not same.", | |||||
| src_shape.ToString().c_str(), src->GetOriginShape().ToString().c_str(), | |||||
| TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst_shape.ToString().c_str(), | |||||
| dst->GetOriginShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void InferShapePass::UpdateCurNodeOutputDesc(NodePtr &node) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| for (const auto &out_anchor : node->GetAllOutDataAnchors()) { | |||||
| auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); | |||||
| GE_IF_BOOL_EXEC(output_tensor == nullptr, continue); | |||||
| GE_IF_BOOL_EXEC(output_tensor->MutableShape().GetDims().empty(), | |||||
| output_tensor->SetOriginShape(output_tensor->GetShape())); | |||||
| ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims() | |||||
| .size())); | |||||
| output_tensor->SetOriginDataType(output_tensor->GetDataType()); | |||||
| // set output origin shape range | |||||
| std::vector<std::pair<int64_t, int64_t>> range; | |||||
| (void)output_tensor->GetShapeRange(range); | |||||
| output_tensor->SetOriginShapeRange(range); | |||||
| GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", | |||||
| node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), | |||||
| TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | |||||
| } | |||||
| } | |||||
| graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { | |||||
| changed = false; | |||||
| if (SameTensorDesc(src, dst)) { | |||||
| GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); | |||||
| return SUCCESS; | |||||
| } | |||||
| changed = true; | |||||
| UpdateShapeAndDType(src, dst); | |||||
| GELOGD( | |||||
| "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s." | |||||
| "To dst Node: shape: [%s], datatype: %s, original datatype is %s.", | |||||
| src->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst->GetShape().ToString().c_str(), | |||||
| TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||||
| const auto RePassNode = [&](const std::set<std::string> &re_pass_types) { | |||||
| for (auto &n : node->GetOutDataNodes()) { | |||||
| GE_CHECK_NOTNULL(n); | |||||
| std::string node_type; | |||||
| GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str()); | |||||
| if (re_pass_types.count(node_type) > 0) { | |||||
| AddImmediateRePassNode(n); | |||||
| (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false); | |||||
| GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str()); | |||||
| } | |||||
| graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| const auto &op_type = op_desc->GetType(); | |||||
| auto ret = op_desc->CallInferFunc(op); | |||||
| if (ret == GRAPH_PARAM_INVALID) { | |||||
| // The op IR has no infer func; try to get an infer func from the operator factory. | |||||
| auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType().c_str()); | |||||
| if (node_op.IsEmpty()) { | |||||
| GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); | |||||
| return ret; | |||||
| } | } | ||||
| return SUCCESS; | |||||
| }; | |||||
| const auto ExProcNode = [&](const std::set<std::string> &proc_types, | |||||
| const std::function<void(InferShapePass *, NodePtr)> &proc_func, | |||||
| const std::string &info) { | |||||
| for (auto &n : node->GetOutDataNodes()) { | |||||
| GE_CHECK_NOTNULL(n); | |||||
| std::string node_type; | |||||
| GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str()); | |||||
| if (proc_types.count(node_type) > 0) { | |||||
| proc_func(this, n); | |||||
| GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str()); | |||||
| GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); | |||||
| auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); | |||||
| node_op.BreakConnect(); | |||||
| if (temp_op_desc == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr."); | |||||
| GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { | |||||
| GELOGW("InferShapeAndType UpdateInputName failed"); | |||||
| for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { | |||||
| if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { | |||||
| break; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | |||||
| }; | |||||
| std::string node_type; | |||||
| GE_CHK_STATUS_RET(GetOriginalType(node, node_type), | |||||
| "[Get][OriginalType] of node:%s failed.", node->GetName().c_str()); | |||||
| if (kNextIterationOpTypes.count(node_type) > 0) { | |||||
| return RePassNode(kMergeOpTypes); // Re-Pass Merge | |||||
| if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { | |||||
| GELOGW("InferShapeAndType UpdateOutputName failed"); | |||||
| } | |||||
| op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); | |||||
| ret = op_desc->CallInferFunc(op); | |||||
| GELOGI("op CallInferFunc second. ret: %u", ret); | |||||
| } | } | ||||
| return ret; | |||||
| } | |||||
| if (kMergeOpTypes.count(node_type) > 0) { | |||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | |||||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||||
| return RePassNode(kSwitchOpTypes); // Re-Pass Switch | |||||
| graphStatus InferShapePass::UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) { | |||||
| GELOGD("Enter update parent node shape for class branch op process"); | |||||
| // Check the subgraph output shapes; if they are not the same, fall back to unknown-shape processing. | |||||
| auto ref_out_tensor = src.at(0); | |||||
| ge::GeShape &ref_out_tensor_shape = ref_out_tensor->MutableShape(); | |||||
| for (auto &tensor : src) { | |||||
| if (ref_out_tensor->GetDataType() != tensor->GetDataType()) { | |||||
| REPORT_INNER_ERROR("E19999", "Does not support diff dtype among all ref output, shape:%s", | |||||
| ref_out_tensor_shape.ToString().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype output"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto shape = tensor->MutableShape(); | |||||
| if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { | |||||
| GELOGD("Shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", shape.GetShapeSize(), | |||||
| ref_out_tensor_shape.GetShapeSize()); | |||||
| ref_out_tensor_shape = GeShape(UNKNOWN_RANK); | |||||
| break; | |||||
| } | } | ||||
| for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { | |||||
| if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { | |||||
| continue; | |||||
| } | |||||
| GELOGD("j: %zu ,shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", j, shape.GetShapeSize(), | |||||
| ref_out_tensor_shape.GetShapeSize()); | |||||
| (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); | |||||
| } | |||||
| } | |||||
| UpdateShapeAndDType(ref_out_tensor, dst); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src, | |||||
| GeTensorDescPtr &dst) { | |||||
| // Check the subgraph output shapes and pick the one with the largest element count for the update. | |||||
| if (src.empty()) { | |||||
| // TODO: log that src is empty. | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (kSwitchOpTypes.count(node_type) > 0) { | |||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | |||||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit | |||||
| } else { | |||||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit | |||||
| int64_t max_size = 0; | |||||
| size_t max_shape_index = 0; | |||||
| auto &ref_out_tensor = src.at(0); | |||||
| for (size_t j = 0; j < src.size(); ++j) { | |||||
| auto &tensor = src.at(j); | |||||
| if (ref_out_tensor->GetDataType() != tensor->GetDataType()) { | |||||
| REPORT_INNER_ERROR("E19999", "node does not support diff dtype among all ref output"); | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype among all ref output"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto shape = tensor->MutableShape(); | |||||
| int64_t size = 1; | |||||
| for (auto dim : shape.GetDims()) { | |||||
| if (dim != 0 && INT64_MAX / dim < size) { | |||||
| REPORT_INNER_ERROR("E19999", "The shape:%s size overflow", shape.ToString().c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| size *= dim; | |||||
| } | } | ||||
| } | |||||
| if (size > max_size) { | |||||
| max_size = size; | |||||
| max_shape_index = j; | |||||
| } | |||||
| } | |||||
| UpdateShapeAndDType(src.at(max_shape_index), dst); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| Status InferShapePass::OnSuspendNodesLeaked() { | |||||
| auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName()); | |||||
| if (iter == graphs_2_suspend_nodes_.end()) { | |||||
| GELOGW("There is no suspend nodes on graph %s", GetCurrentGraphName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| if (!iter->second.nodes.empty()) { | |||||
| AddNodeResume(iter->second.PopSuspendedNode()); | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -17,22 +17,40 @@ | |||||
| #ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | #ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | ||||
| #define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | #define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | ||||
| #include "graph/passes/base_pass.h" | |||||
| #include "graph/passes/infer_base_pass.h" | |||||
| #include <stack> | |||||
| namespace ge { | namespace ge { | ||||
| class InferShapePass : public BaseNodePass { | |||||
| class InferShapePass : public InferBasePass { | |||||
| public: | public: | ||||
| /// | |||||
| /// Entry of the InferShapePass optimizer | |||||
| /// @param [in] graph: Input ComputeGraph | |||||
| /// @return SUCCESS: Execution succeed | |||||
| /// @return OTHERS: Execution failed | |||||
| /// @author | |||||
| /// | |||||
| Status Run(ge::NodePtr &node) override; | |||||
| std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override; | |||||
| graphStatus Infer(NodePtr &node) override; | |||||
| graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; | |||||
| graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override; | |||||
| graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src, | |||||
| GeTensorDescPtr &dst) override; | |||||
| Status OnSuspendNodesLeaked() override; | |||||
| private: | private: | ||||
| Status RePassLoopNode(const NodePtr &node); | |||||
| graphStatus CallInferShapeFunc(NodePtr &node, Operator &op); | |||||
| bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst); | |||||
| void UpdateCurNodeOutputDesc(NodePtr &node); | |||||
| Status SuspendV1LoopExitNodes(const NodePtr &node); | |||||
| struct SuspendNodes { | |||||
| // Suspended nodes; PopSuspendedNode() removes the most recently pushed one (std::stack order). | |||||
| std::stack<NodePtr> nodes; | |||||
| // Set that PopSuspendedNode() keeps in step with `nodes` by erasing the popped node. | |||||
| std::unordered_set<NodePtr> nodes_set; | |||||
| // Pops the top node from `nodes`, erases it from `nodes_set` and returns it. | |||||
| // Returns nullptr when nothing is suspended, instead of calling top()/pop() on an | |||||
| // empty stack, which is undefined behavior. | |||||
| NodePtr PopSuspendedNode() { | |||||
| if (nodes.empty()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto top_node = nodes.top(); | |||||
| nodes.pop(); | |||||
| nodes_set.erase(top_node); | |||||
| return top_node; | |||||
| } | |||||
| }; | |||||
| std::map<std::string, SuspendNodes> graphs_2_suspend_nodes_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | #endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | ||||
| @@ -1999,6 +1999,22 @@ Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) { | |||||
| Status GraphPrepare::InferShapeForPreprocess() { | Status GraphPrepare::InferShapeForPreprocess() { | ||||
| GELOGI("Start infershape for preprocess."); | GELOGI("Start infershape for preprocess."); | ||||
| // Prepare dummy_shape for v1 control_flow op before infershape | |||||
| for (const auto &node : compute_graph_->GetAllNodes()) { | |||||
| string type; | |||||
| GetOriginalType(node, type); | |||||
| if (type == MERGE || type == REFMERGE) { | |||||
| for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { | |||||
| GELOGD("Prepare for infershape: update %s input_shape as dummy.", node->GetName().c_str()); | |||||
| NodeUtils::UpdateInputShape(*node, i, GeShape(DUMMY_SHAPE)); | |||||
| } | |||||
| } else if (type == WHILE) { | |||||
| for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { | |||||
| GELOGD("Prepare for infershape: update %s output_shape as dummy.", node->GetName().c_str()); | |||||
| NodeUtils::UpdateOutputShape(*node, i, GeShape(DUMMY_SHAPE)); | |||||
| } | |||||
| } | |||||
| } | |||||
| GEPass ge_passes(compute_graph_); | GEPass ge_passes(compute_graph_); | ||||
| NamesToPass names_to_passes; | NamesToPass names_to_passes; | ||||
| AssertPass assert_pass; | AssertPass assert_pass; | ||||
| @@ -72,7 +72,7 @@ TEST(UtestGraphPassesAddnPass, null_pass) { | |||||
| AddNPass *addn_pass = nullptr; | AddNPass *addn_pass = nullptr; | ||||
| NamesToPass names_to_pass; | NamesToPass names_to_pass; | ||||
| names_to_pass.emplace_back("Test", addn_pass); | names_to_pass.emplace_back("Test", addn_pass); | ||||
| EXPECT_EQ(pass.Run(names_to_pass), SUCCESS); | |||||
| EXPECT_NE(pass.Run(names_to_pass), SUCCESS); | |||||
| } | } | ||||
| TEST(UtestGraphPassesAddnPass, null_graph) { | TEST(UtestGraphPassesAddnPass, null_graph) { | ||||
| @@ -17,7 +17,6 @@ | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <map> | #include <map> | ||||
| #include <set> | #include <set> | ||||
| #include <string> | |||||
| #include <vector> | #include <vector> | ||||
| #include "gtest/gtest.h" | #include "gtest/gtest.h" | ||||
| @@ -26,8 +25,6 @@ | |||||
| #include "graph/passes/base_pass.h" | #include "graph/passes/base_pass.h" | ||||
| #undef protected | #undef protected | ||||
| #include "external/graph/ge_error_codes.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "framework/common/types.h" | #include "framework/common/types.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| @@ -67,6 +64,54 @@ class UtestTestPass : public BaseNodePass { | |||||
| names_to_add_repass_.erase(iter); | names_to_add_repass_.erase(iter); | ||||
| } | } | ||||
| } | } | ||||
| iter = names_to_add_repass_immediate_.find(node->GetName()); | |||||
| if (iter != names_to_add_repass_immediate_.end()) { | |||||
| auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); | |||||
| for (const auto &node_name : iter->second) { | |||||
| for (auto &node_re_pass : all_nodes) { | |||||
| if (node_re_pass->GetName() == node_name) { | |||||
| AddImmediateRePassNode(node_re_pass); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!dead_loop_) { | |||||
| names_to_add_repass_immediate_.erase(iter); | |||||
| } | |||||
| } | |||||
| iter = names_to_add_suspend_.find(node->GetName()); | |||||
| if (iter != names_to_add_suspend_.end()) { | |||||
| auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); | |||||
| for (const auto &node_name : iter->second) { | |||||
| for (auto &node_re_pass : all_nodes) { | |||||
| if (node_re_pass->GetName() == node_name) { | |||||
| AddNodeSuspend(node_re_pass); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!dead_loop_) { | |||||
| names_to_add_suspend_.erase(iter); | |||||
| } | |||||
| } | |||||
| iter = names_to_add_resume_.find(node->GetName()); | |||||
| if (iter != names_to_add_resume_.end()) { | |||||
| auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); | |||||
| for (const auto &node_name : iter->second) { | |||||
| for (auto &node_re_pass : all_nodes) { | |||||
| if (node_re_pass->GetName() == node_name) { | |||||
| AddNodeResume(node_re_pass); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!dead_loop_) { | |||||
| names_to_add_resume_.erase(iter); | |||||
| } | |||||
| } | |||||
| // simulate infershape pass | // simulate infershape pass | ||||
| if(node->GetType() == WHILE){ | if(node->GetType() == WHILE){ | ||||
| bool need_repass = false; | bool need_repass = false; | ||||
| @@ -85,6 +130,20 @@ class UtestTestPass : public BaseNodePass { | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status OnSuspendNodesLeaked() override { | |||||
| // resume every node registered through AddResumeNodeNameOnLeaked, looked up by name in the owner graph of iter_nodes_[0] | |||||
| auto compute_graph = (iter_nodes_.size() > 0) ? iter_nodes_[0]->GetOwnerComputeGraph() : nullptr; | |||||
| if (compute_graph == nullptr) { | |||||
| return SUCCESS; | |||||
| } | |||||
| for (const auto &node_name : names_to_add_resume_onleaked_) { | |||||
| auto node_to_resume = compute_graph->FindNode(node_name); | |||||
| AddNodeResume(node_to_resume); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void clear() { iter_nodes_.clear(); } | void clear() { iter_nodes_.clear(); } | ||||
| std::vector<NodePtr> GetIterNodes() { return iter_nodes_; } | std::vector<NodePtr> GetIterNodes() { return iter_nodes_; } | ||||
| @@ -94,12 +153,31 @@ class UtestTestPass : public BaseNodePass { | |||||
| void AddDelNodeName(const std::string &iter_node, const std::string &del_node) { | void AddDelNodeName(const std::string &iter_node, const std::string &del_node) { | ||||
| names_to_add_del_[iter_node].insert(del_node); | names_to_add_del_[iter_node].insert(del_node); | ||||
| } | } | ||||
| void AddRePassImmediateNodeName(const std::string &iter_node, const std::string &re_pass_node) { | |||||
| names_to_add_repass_immediate_[iter_node].insert(re_pass_node); | |||||
| } | |||||
| void AddSuspendNodeName(const std::string &iter_node, const std::string &suspend_node) { | |||||
| names_to_add_suspend_[iter_node].insert(suspend_node); | |||||
| } | |||||
| void AddResumeNodeName(const std::string &iter_node, const std::string &resume_node) { | |||||
| names_to_add_resume_[iter_node].insert(resume_node); | |||||
| } | |||||
| void AddResumeNodeNameOnLeaked(const std::string &resume_node) { | |||||
| names_to_add_resume_onleaked_.insert(resume_node); | |||||
| } | |||||
| unsigned int GetRunTimes() { return run_times_; } | unsigned int GetRunTimes() { return run_times_; } | ||||
| private: | private: | ||||
| std::vector<NodePtr> iter_nodes_; | std::vector<NodePtr> iter_nodes_; | ||||
| std::map<std::string, std::unordered_set<std::string>> names_to_add_del_; | std::map<std::string, std::unordered_set<std::string>> names_to_add_del_; | ||||
| std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_; | std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_; | ||||
| std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_immediate_; | |||||
| std::map<std::string, std::unordered_set<std::string>> names_to_add_suspend_; | |||||
| std::map<std::string, std::unordered_set<std::string>> names_to_add_resume_; | |||||
| std::unordered_set<std::string> names_to_add_resume_onleaked_; | |||||
| bool dead_loop_; | bool dead_loop_; | ||||
| unsigned int run_times_; | unsigned int run_times_; | ||||
| }; | }; | ||||
| @@ -200,6 +278,26 @@ ComputeGraphPtr BuildGraph3() { | |||||
| return builder.GetGraph(); | return builder.GetGraph(); | ||||
| } | } | ||||
| /// cast1--shape1 | |||||
| /// / | |||||
| /// data1 | |||||
| /// \ | |||||
| /// transdata1--shape2 | |||||
| ComputeGraphPtr BuildGraph4() { | |||||
| auto builder = ut::GraphBuilder("g1"); | |||||
| auto data1 = builder.AddNode("data1", DATA, 0, 1); | |||||
| auto cast1 = builder.AddNode("cast1", CAST, 1, 1); | |||||
| auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1); | |||||
| auto transdata1 = builder.AddNode("transdata1", TRANSDATA, 1, 1); | |||||
| auto shape2 = builder.AddNode("shape2", SHAPE, 1, 1); | |||||
| builder.AddDataEdge(data1, 0, cast1, 0); | |||||
| builder.AddDataEdge(data1, 0, transdata1, 0); | |||||
| builder.AddDataEdge(cast1, 0, shape1, 0); | |||||
| builder.AddDataEdge(transdata1, 0, shape2, 0); | |||||
| return builder.GetGraph(); | |||||
| } | |||||
| void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::string>> &nodes_layers) { | void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::string>> &nodes_layers) { | ||||
| std::unordered_set<std::string> layer_nodes; | std::unordered_set<std::string> layer_nodes; | ||||
| size_t layer_index = 0; | size_t layer_index = 0; | ||||
| @@ -509,15 +607,369 @@ ComputeGraphPtr BuildWhileGraph1() { | |||||
| } | } | ||||
| TEST_F(UTESTGraphPassesBasePass, while_infershape) { | TEST_F(UTESTGraphPassesBasePass, while_infershape) { | ||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| auto graph = BuildWhileGraph1(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto while_node = graph->FindNode("while"); | |||||
| EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||||
| } | |||||
| TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) { | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // repass pre_node immediately | |||||
| test_pass->AddRePassImmediateNodeName("reshape1", "add1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); | |||||
| EXPECT_EQ(test_pass->GetIterNodes().size(), 9);// todo | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "add1", "sum1"}); | |||||
| CheckIterOrder(test_pass, layers); | |||||
| } | |||||
| TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) { | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // repass cur_node immediately | |||||
| test_pass->AddRePassImmediateNodeName("reshape1", "reshape1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); | |||||
| EXPECT_EQ(test_pass->GetIterNodes().size(), 9); | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1"}); | |||||
| layers.push_back({"reshape1", "sum1"}); | |||||
| CheckIterOrder(test_pass, layers); | |||||
| } | |||||
| TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) { | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // repass next_node immediately | |||||
| test_pass->AddRePassImmediateNodeName("reshape1", "sum1"); | |||||
| // repass node after next_node immediately | |||||
| test_pass->AddRePassImmediateNodeName("add1", "sum1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); | |||||
| EXPECT_EQ(test_pass->GetIterNodes().size(), 8); | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "sum1"}); | |||||
| CheckIterOrder(test_pass, layers); | |||||
| } | |||||
| /** | |||||
| * A->B->C | |||||
| * if node B suspends its pre_node A, and C resumes A, the suspend is a useless operation, so iter_order should follow normal order; | |||||
| * when C resumes A, A will be passed again. | |||||
| */ | |||||
| TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_C_resume_A) { | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // add1->reshape1->sum1 | |||||
| test_pass->AddSuspendNodeName("reshape1", "add1"); | |||||
| test_pass->AddResumeNodeName("sum1", "add1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); | |||||
| EXPECT_EQ(test_pass->GetIterNodes().size(), 9); | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "sum1"}); | |||||
| layers.push_back({"add1"}); | |||||
| CheckIterOrder(test_pass, layers); | |||||
| } | |||||
| /** | |||||
| * A->B->C | |||||
| * if node B suspends its pre_node A, and B resumes A, the suspend is a useless operation, so iter_order should follow normal order; | |||||
| * when B resumes A, A will be passed again. | |||||
| */ | |||||
| TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_B_resume_A) { | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // add1->reshape1->sum1 | |||||
| test_pass->AddSuspendNodeName("reshape1", "add1"); | |||||
| test_pass->AddResumeNodeName("reshape1", "add1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); | |||||
| EXPECT_EQ(test_pass->GetIterNodes().size(), 9); | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "sum1", "add1"}); | |||||
| CheckIterOrder(test_pass, layers); | |||||
| } | |||||
| /** | |||||
| * A->B->C | |||||
| * if node B resumes C (which is not suspended), it is a useless operation: C is not passed an extra time. | |||||
| */ | |||||
| TEST_F(UTESTGraphPassesBasePass, B_resume_node_not_suspended) { | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // add1->reshape1->sum1 | |||||
| test_pass->AddResumeNodeName("reshape1", "sum1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); | |||||
| EXPECT_EQ(test_pass->GetIterNodes().size(), 8); | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "sum1"}); | |||||
| CheckIterOrder(test_pass, layers); | |||||
| } | |||||
| auto graph = BuildWhileGraph1(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto while_node = graph->FindNode("while"); | |||||
| EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||||
| /** | |||||
| * A->B->C | |||||
| * if node B suspends its pre_node A, it is a useless operation, so iter_order should follow normal order. | |||||
| * Because nobody resumes it, A is a leaked node, so Run returns a failure (INTERNAL_ERROR). | |||||
| */ | |||||
| TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_nobody_resume_it_return_failed) { | |||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| // suspend pre_node immediately | |||||
| test_pass.AddSuspendNodeName("reshape1", "add1"); | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), INTERNAL_ERROR); | |||||
| } | |||||
| /** | |||||
| * A->B->C | |||||
| * if node B suspends its pre_node A, it is a useless operation, | |||||
| * so iter_order should follow normal order; | |||||
| * A is resumed on leaked, which means A will be passed again | |||||
| */ | |||||
| TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_resume_it_onleaked) { | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // suspend pre_node immediately | |||||
| test_pass->AddSuspendNodeName("reshape1", "add1"); | |||||
| test_pass->AddResumeNodeNameOnLeaked("add1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "sum1"}); | |||||
| layers.push_back({"add1"}); | |||||
| CheckIterOrder(test_pass, layers); | |||||
| } | } | ||||
| /// cast1--shape1 | |||||
| /// / | |||||
| /// data1 | |||||
| /// \ | |||||
| /// transdata1--shape2 | |||||
| /** | |||||
| * suspend cur node | |||||
| * cast1 suspend itself, shape2 resume cast1 | |||||
| * iter order follows : data1; cast1,transdata1; shape2; cast1 ; shape1 | |||||
| */ | |||||
| TEST_F(UTESTGraphPassesBasePass, cast1_suspend_cur_node_shape2_resume_cast1) { | |||||
| auto graph = BuildGraph4(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // cast1 suspends itself (cur node); shape2 resumes cast1 | |||||
| test_pass->AddSuspendNodeName("cast1", "cast1"); | |||||
| test_pass->AddResumeNodeName("shape2", "cast1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); | |||||
| EXPECT_EQ(test_pass->GetIterNodes().size(), 6); | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1"}); | |||||
| layers.push_back({"cast1","transdata1"}); | |||||
| layers.push_back({"shape2"}); | |||||
| layers.push_back({"cast1", "shape1"}); | |||||
| CheckIterOrder(test_pass, layers); | |||||
| } | |||||
| /** | |||||
| * suspend cur node | |||||
| * cast1 suspend itself, then resume cast1 | |||||
| * iter order follows : data1; cast1,cast1,transdata1; shape2; shape1. | |||||
| */ | |||||
| TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_itself) { | |||||
| auto graph = BuildGraph4(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // cast1 suspends itself (cur node) and also resumes itself | |||||
| test_pass->AddSuspendNodeName("cast1", "cast1"); | |||||
| test_pass->AddResumeNodeName("cast1", "cast1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); | |||||
| EXPECT_EQ(test_pass->GetIterNodes().size(), 6); | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1"}); | |||||
| layers.push_back({"cast1","transdata1","cast1","shape1", "shape2"}); | |||||
| CheckIterOrder(test_pass, layers); | |||||
| } | |||||
| /** | |||||
| * suspend cur node | |||||
| * cast1 suspends itself, then cast1 is resumed on leaked | |||||
| * iter order follows : data1; cast1, transdata1, shape2; cast1, shape1. | |||||
| */ | |||||
| TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_onleaked) { | |||||
| auto graph = BuildGraph4(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // cast1 suspends itself (cur node); it is resumed in OnSuspendNodesLeaked | |||||
| test_pass->AddSuspendNodeName("cast1", "cast1"); | |||||
| test_pass->AddResumeNodeNameOnLeaked("cast1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); | |||||
| EXPECT_EQ(test_pass->GetIterNodes().size(), 6); | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1"}); | |||||
| layers.push_back({"cast1","transdata1", "shape2"}); | |||||
| layers.push_back({"cast1","shape1"}); | |||||
| CheckIterOrder(test_pass, layers); | |||||
| } | |||||
| /** | |||||
| * suspend next node | |||||
| * data1 suspend cast1, then resume cast1 on leaked | |||||
| * iter order follows : data1; transdata1, shape2; cast1, shape1. | |||||
| */ | |||||
| TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_resume_cast1_onleaked) { | |||||
| auto graph = BuildGraph4(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // data1 suspends its next node cast1; cast1 is resumed in OnSuspendNodesLeaked | |||||
| test_pass->AddSuspendNodeName("data1", "cast1"); | |||||
| test_pass->AddResumeNodeNameOnLeaked("cast1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); | |||||
| EXPECT_EQ(test_pass->GetIterNodes().size(), 5); | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1"}); | |||||
| layers.push_back({"transdata1", "shape2"}); | |||||
| layers.push_back({"cast1","shape1"}); | |||||
| CheckIterOrder(test_pass, layers); | |||||
| } | |||||
| /** | |||||
| * suspend next node | |||||
| * data1 suspend cast1, nobody resume it | |||||
| * iter order follows : data1; transdata1, shape2; | |||||
| * run ret is failed ,because node leaked | |||||
| */ | |||||
| TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_nobody_resume) { | |||||
| auto graph = BuildGraph4(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second); | |||||
| // data1 suspends its next node cast1 and nobody resumes it | |||||
| test_pass->AddSuspendNodeName("data1", "cast1"); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass_), INTERNAL_ERROR); | |||||
| EXPECT_EQ(test_pass->GetIterNodes().size(), 3); | |||||
| } | |||||
| TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) { | |||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| // repass pre_node immediately | |||||
| test_pass.AddRePassImmediateNodeName("reshape1", "add1"); | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||||
| EXPECT_EQ(test_pass.GetIterNodes().size(), 9);// todo | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "add1", "sum1"}); | |||||
| CheckIterOrder(&test_pass, layers); | |||||
| } | |||||
| /// sum1 | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// reshape1 addn1 | |||||
| /// | c | | |||||
| /// add1 <--- shape1 | |||||
| /// / \ | | |||||
| /// | | | | |||||
| /// data1 const1 const2 | |||||
| TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) { | |||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| // repass cur_node immediately | |||||
| test_pass.AddRePassImmediateNodeName("reshape1", "reshape1"); | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||||
| EXPECT_EQ(test_pass.GetIterNodes().size(), 9);// todo | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1"}); | |||||
| layers.push_back({"reshape1", "sum1"}); | |||||
| CheckIterOrder(&test_pass, layers); | |||||
| } | |||||
| TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) { | |||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| // repass next_node immediately | |||||
| test_pass.AddRePassImmediateNodeName("reshape1", "sum1"); | |||||
| // repass node after next_node immediately | |||||
| test_pass.AddRePassImmediateNodeName("add1", "sum1"); | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||||
| EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "sum1"}); | |||||
| CheckIterOrder(&test_pass, layers); | |||||
| } | |||||
| /* | |||||
| TEST_F(UTESTGraphPassesBasePass, suspend_pre_node) { | |||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| // repass next_node immediately | |||||
| test_pass.AddRePassNodeName("reshape1", "sum1"); | |||||
| // repass node after next_node immediately | |||||
| test_pass.AddRePassNodeName("add1", "sum1"); | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||||
| EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "sum1"}); | |||||
| CheckIterOrder(&test_pass, layers); | |||||
| }*/ | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -29,13 +29,77 @@ | |||||
| using namespace std; | using namespace std; | ||||
| using namespace testing; | using namespace testing; | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| // do nothing stub infer_func | |||||
| const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; | |||||
| // infer from input to output stub infer_func (input size == output size) | |||||
| const auto stub_mapping_func = [](Operator &op) { | |||||
| size_t in_num = op.GetInputsSize(); | |||||
| for (size_t i = 0; i < in_num; ++i) { | |||||
| auto in_desc = op.GetInputDesc(i); | |||||
| auto out_desc = op.GetOutputDesc(i); | |||||
| out_desc.SetShape(in_desc.GetShape()); | |||||
| out_desc.SetDataType(in_desc.GetDataType()); | |||||
| op.UpdateOutputDesc(out_desc.GetName(), out_desc); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| }; | |||||
| // merge infer_func | |||||
| // while infer_func | |||||
| const auto while_infer_func = [](Operator &op) { | |||||
| size_t in_num = op.GetInputsSize(); | |||||
| size_t out_num = op.GetOutputsSize(); | |||||
| if (in_num != out_num) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| bool need_infer_again = false; | |||||
| for (size_t i = 0; i < in_num; ++i) { | |||||
| auto in_desc = op.GetDynamicInputDesc("input", i); | |||||
| auto out_desc = op.GetDynamicOutputDesc("output", i); | |||||
| auto data_shape = in_desc.GetShape(); | |||||
| auto out_shape = out_desc.GetShape(); | |||||
| if(out_shape.GetDims() == DUMMY_SHAPE){ | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| // check datatype between output and input | |||||
| if (in_desc.GetDataType() != out_desc.GetDataType()) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (data_shape.GetDims() != out_shape.GetDims()) { | |||||
| need_infer_again = true; | |||||
| if (data_shape.GetDimNum() != out_shape.GetDimNum()) { | |||||
| in_desc.SetUnknownDimNumShape(); | |||||
| } else { | |||||
| size_t data_dim_num = data_shape.GetDimNum(); | |||||
| std::vector<std::pair<int64_t, int64_t>> data_shape_range = {data_dim_num, std::make_pair(1, UNKNOWN_DIM)}; | |||||
| for (size_t j = 0; j < data_dim_num; ++j) { | |||||
| if (data_shape.GetDim(j) != out_shape.GetDim(j)) { | |||||
| data_shape.SetDim(j, UNKNOWN_DIM); | |||||
| } | |||||
| if (data_shape.GetDim(j) != UNKNOWN_DIM) { | |||||
| data_shape_range[j] = std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j)); | |||||
| } | |||||
| } | |||||
| in_desc.SetShape(data_shape); | |||||
| in_desc.SetShapeRange(data_shape_range); | |||||
| } | |||||
| op.UpdateDynamicOutputDesc("output", i, in_desc); | |||||
| op.UpdateDynamicInputDesc("input", i, in_desc); | |||||
| } | |||||
| } | |||||
| return need_infer_again ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS; | |||||
| }; | |||||
| } | |||||
| class UtestGraphInfershapePass : public testing::Test { | class UtestGraphInfershapePass : public testing::Test { | ||||
| protected: | protected: | ||||
| void SetUp() {} | void SetUp() {} | ||||
| void TearDown() {} | void TearDown() {} | ||||
| }; | }; | ||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num, | |||||
| std::function<graphStatus(Operator &)> infer_func = stub_func) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | ||||
| op_desc->SetStreamId(0); | op_desc->SetStreamId(0); | ||||
| static int32_t index = 0; | static int32_t index = 0; | ||||
| @@ -61,14 +125,11 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string | |||||
| op_desc->SetWorkspaceBytes({}); | op_desc->SetWorkspaceBytes({}); | ||||
| op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); | op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); | ||||
| const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; | |||||
| op_desc->AddInferFunc(stub_func); | |||||
| op_desc->AddInferFormatFunc(stub_func); | |||||
| op_desc->AddVerifierFunc(stub_func); | |||||
| op_desc->AddInferFunc(infer_func); | |||||
| return graph.AddNode(op_desc); | return graph.AddNode(op_desc); | ||||
| } | } | ||||
| /* | |||||
| TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { | TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { | ||||
| GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); | GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); | ||||
| string type = "AddN"; | string type = "AddN"; | ||||
| @@ -82,6 +143,7 @@ TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { | |||||
| InferShapePass infershape_pass; | InferShapePass infershape_pass; | ||||
| EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); | EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); | ||||
| } | } | ||||
| */ | |||||
| TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { | TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { | ||||
| auto graph = std::make_shared<ComputeGraph>("test"); | auto graph = std::make_shared<ComputeGraph>("test"); | ||||
| @@ -94,7 +156,43 @@ TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { | |||||
| infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; | infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; | ||||
| EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); | EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); | ||||
| } | } | ||||
| TEST_F(UtestGraphInfershapePass, infer_from_pre_to_next) { | |||||
| /* | |||||
| * cast->shape | |||||
| */ | |||||
| auto graph = std::make_shared<ComputeGraph>("test_infer_shape"); | |||||
| auto data1 = CreateNode(*graph, "dataq", DATA, 0, 1); | |||||
| auto cast1 = CreateNode(*graph, "cast1", CAST, 1, 1, stub_mapping_func); | |||||
| auto cast_in_desc = cast1->GetOpDesc()->MutableInputDesc(0); | |||||
| cast_in_desc->SetShape(GeShape({1,2,3})); | |||||
| cast_in_desc->SetDataType(DT_INT32); | |||||
| auto transdata1 = CreateNode(*graph, "transdata1", TRANSDATA, 1, 1, stub_mapping_func); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), cast1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(cast1->GetOutDataAnchor(0), transdata1->GetInDataAnchor(0)); | |||||
| // check before infer cast1 | |||||
| auto cast_before = graph->FindNode("cast1"); | |||||
| vector<int64_t> expect_cast1_shape_dim = {1,2,3}; | |||||
| auto real_cast1_before_shape_dim = cast_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims(); | |||||
| auto transdata1_before = graph->FindNode("transdata1"); | |||||
| vector<int64_t> expect_transdata1_shape_dim = {}; | |||||
| auto real_transdata1_before_shape_dim = transdata1_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims(); | |||||
| EXPECT_EQ(real_cast1_before_shape_dim, expect_cast1_shape_dim); | |||||
| EXPECT_EQ(real_transdata1_before_shape_dim, expect_transdata1_shape_dim); | |||||
| // run infershape pass | |||||
| InferShapePass infer_shape_pass; | |||||
| infer_shape_pass.Run(cast_before); | |||||
| // check cast1 add transdata1 to repass_immediately | |||||
| infer_shape_pass.GetNodesNeedRePassImmediately(); | |||||
| EXPECT_TRUE(!infer_shape_pass.GetNodesNeedRePassImmediately().empty()); | |||||
| // check transdata input_shape & datatype after infer | |||||
| auto transdata1_after = graph->FindNode("transdata1"); | |||||
| auto transdata1_opdesc = transdata1_before->GetOpDesc(); | |||||
| auto real_transdata1_after_shape_dim = transdata1_opdesc->GetInputDesc(0).GetShape().GetDims(); | |||||
| EXPECT_EQ(real_transdata1_after_shape_dim, expect_cast1_shape_dim); | |||||
| auto transdata1_datatype_after = transdata1_opdesc->GetInputDesc(0).GetDataType(); | |||||
| EXPECT_EQ(transdata1_datatype_after, DT_INT32); | |||||
| } | |||||
| TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) { | TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) { | ||||
| /******************************************************************************* | /******************************************************************************* | ||||
| * Exit Identify | * Exit Identify | ||||