Merge pull request !1976 from zhaoxinxin/mastertags/v1.5.1
@@ -1,374 +1,475 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/base_pass.h" | |||||
#include <queue> | |||||
#include <unordered_set> | |||||
#include "framework/common/debug/log.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "graph/compute_graph.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
namespace ge { | |||||
namespace { | |||||
// Cap on the number of outer re-pass rounds in GEPass::RunPassesOneGraph; a warning is logged when it is reached.
constexpr int kMaxRePassTimes = 10000; | |||||
// A node with more than this many input nodes is put into `nodes_last` by GetAllNodesNoInputEdge.
constexpr size_t kMaxOneInNodes = 1000; | |||||
// Largest `depth_` that GEPass::Run accepts; deeper subgraph nesting fails with PARAM_INVALID.
// Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later | |||||
constexpr int kMaxRecursiveDepth = 20; | |||||
// Bookkeeping shared by the helpers in this file while passes run over one graph.
struct DuringPassNodeSets { | |||||
// Raw pointers of the nodes that have already been put into the pass queue.
std::unordered_set<Node *> nodes_seen; | |||||
// Nodes reported as deleted by any pass (GetNodesDeleted); CheckNode skips them.
std::unordered_set<NodePtr> nodes_deleted; | |||||
// Nodes a pass asked to re-pass; they are queued at the start of the next outer round.
std::unordered_set<NodePtr> nodes_re_pass; | |||||
// Nodes a pass asked to re-pass immediately; AddRepassNodes pushes them to the queue front.
std::unordered_set<NodePtr> nodes_re_pass_immediately; | |||||
// Nodes with more than kMaxOneInNodes inputs; they are only queued after the regular queue has drained.
std::unordered_set<NodePtr> nodes_last; | |||||
// Nodes whose iteration was suspended by a pass (GetNodesSuspend); they are skipped until resumed.
std::unordered_set<NodePtr> nodes_suspend; | |||||
// Nodes a pass asked to resume (GetNodesResume); consumed and cleared by AddResumeNodes.
std::unordered_set<NodePtr> nodes_resume; | |||||
}; | |||||
void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, | |||||
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) { | |||||
nodes_last.clear(); | |||||
for (auto &node : graph->GetDirectNode()) { | |||||
if (node == nullptr) { | |||||
continue; | |||||
} | |||||
size_t in_nums = node->GetInNodes().size(); | |||||
if (in_nums == 0) { | |||||
input_edge_nodes.push_back(node); | |||||
nodes_seen.insert(node.get()); | |||||
} else if (in_nums > kMaxOneInNodes) { | |||||
nodes_last.insert(node); | |||||
} | |||||
} | |||||
} | |||||
// Returns true when none of `nodes` is contained in `nodes_suspend`.
bool IsAllInNodesAlive(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_suspend) {
  for (const NodePtr &in_node : nodes) {
    if (nodes_suspend.find(in_node) != nodes_suspend.end()) {
      return false;
    }
  }
  return true;
}
// Appends to `nodes_to_pass` every node of `nodes` that is ready for its first pass.
// A node is queued only if it is non-null, is not in `nodes_last`, is not suspended, has all
// its input nodes seen, has no suspended input node, and has not been seen before
// (it is marked as seen when queued).
void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
                      DuringPassNodeSets &during_pass_node_set) {
  auto &seen = during_pass_node_set.nodes_seen;
  const auto &deferred = during_pass_node_set.nodes_last;
  const auto &suspended = during_pass_node_set.nodes_suspend;
  for (auto &candidate : nodes) {
    if (candidate == nullptr || deferred.count(candidate) != 0) {
      continue;
    }
    if (suspended.count(candidate) > 0) {
      GELOGD("The node %s has suspend by pass, skip it.", candidate->GetName().c_str());
      continue;
    }
    const bool inputs_alive = IsAllInNodesAlive(candidate->GetInAllNodes(), suspended);
    const bool inputs_seen = candidate->IsAllInNodesSeen(seen);
    if (!inputs_seen || !inputs_alive) {
      continue;
    }
    // insert().second is false when the node was already seen, so it is queued at most once here.
    if (seen.insert(candidate.get()).second) {
      nodes_to_pass.push_back(candidate);
    }
  }
}
// Moves every node waiting in `nodes_re_pass_immediately` to the front of the queue `nodes`
// and then empties that set.
void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) {
  auto &immediate_nodes = during_pass_node_set.nodes_re_pass_immediately;
  for (const auto &immediate_node : immediate_nodes) {
    GELOGD("The node %s will be re-pass immediately.", immediate_node->GetName().c_str());
    nodes.push_front(immediate_node);
  }
  immediate_nodes.clear();
}
// Processes the pending resume requests in `nodes_resume`.
// A requested node that is currently suspended is removed from `nodes_suspend` and appended to
// the queue `nodes`; a request for a node that is not suspended is dropped with a warning.
// `nodes_resume` is emptied afterwards.
void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) {
  auto &suspended = during_pass_node_set.nodes_suspend;
  for (auto &resume_node : during_pass_node_set.nodes_resume) {
    // erase() returns the number of removed elements, i.e. 0 when the node was not suspended.
    if (suspended.erase(resume_node) == 0) {
      GELOGW("The node %s not suspend, drop from resumed", resume_node->GetName().c_str());
      continue;
    }
    GELOGD("The node %s resumed by pass.", resume_node->GetName().c_str());
    nodes.push_back(resume_node);
  }
  during_pass_node_set.nodes_resume.clear();
}
// Records the suspend and resume requests made by pass `pass_name`.
// Every node of `nodes_suspend` is added to the shared suspend set and every node of
// `nodes_resume` to the shared resume set. Null entries are skipped with a warning instead of
// being dereferenced for logging, matching the null handling in PushToRePassIfSeen.
void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name,
                        const std::unordered_set<NodePtr> &nodes_suspend,
                        const std::unordered_set<NodePtr> &nodes_resume) {
  for (const auto &node : nodes_suspend) {
    if (node == nullptr) {
      GELOGW("Found null suspend node when executing pass %s", pass_name.c_str());
      continue;
    }
    GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str());
    during_pass_node_set.nodes_suspend.emplace(node);
  }
  for (const auto &node : nodes_resume) {
    if (node == nullptr) {
      GELOGW("Found null resume node when executing pass %s", pass_name.c_str());
      continue;
    }
    GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str());
    during_pass_node_set.nodes_resume.emplace(node);
  }
}
// Filters the re-pass requests of one pass run on `node`.
// A requested node is inserted into `nodes_re_pass` only if it has been seen already or all of
// its input nodes have been seen; otherwise the request is dropped for this round.
// Null requests are skipped with a warning.
void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
                        std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass,
                        std::unordered_set<NodePtr> &nodes_re_pass) {
  for (const auto &requested : nodes_to_re_pass) {
    if (requested == nullptr) {
      GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(),
             node->GetName().c_str(), node->GetType().c_str());
      continue;
    }
    const bool can_re_pass = nodes_seen.count(requested.get()) > 0 || requested->IsAllInNodesSeen(nodes_seen);
    if (!can_re_pass) {
      GELOGD("The node %s are not all seen, don't set repass this time", requested->GetName().c_str());
      continue;
    }
    GELOGD("The node %s will be re-pass.", requested->GetName().c_str());
    nodes_re_pass.insert(requested);
  }
}
// Runs every pass of `names_to_passes` once on `node`, in order.
// After each successful pass its re-pass, immediate re-pass, suspend, resume and deleted node
// requests are merged into `during_pass_node_set`. Null pass pointers are logged and skipped.
// Returns FAILED for a null `node`, the failing pass's status if any pass fails, otherwise SUCCESS.
Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNodeSets &during_pass_node_set) {
  if (node == nullptr) {
    REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
    GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
    return FAILED;
  }
  GELOGD("Begin to run pass for node %s", node->GetName().c_str());
  for (const auto &name_to_pass : names_to_passes) {
    const auto &pass_name = name_to_pass.first;
    auto *const pass = name_to_pass.second;
    if (pass == nullptr) {
      GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s), skip it", pass_name.c_str());
      continue;
    }

    GELOGD("Begin to run pass %s for node %s", pass_name.c_str(), node->GetName().c_str());
    pass->init();
    auto result = pass->Run(node);
    if (result != SUCCESS) {
      REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u",
                        pass_name.c_str(), node->GetName().c_str(), result);
      GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result "
                             "%u, the passes will be terminated immediately.",
             pass_name.c_str(), node->GetName().c_str(), result);
      return result;
    }

    // Collect what the pass requested for the scheduler.
    const auto &re_pass_requests = pass->GetNodesNeedRePass();
    PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, re_pass_requests,
                       during_pass_node_set.nodes_re_pass);
    const auto &immediate_requests = pass->GetNodesNeedRePassImmediately();
    PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, immediate_requests,
                       during_pass_node_set.nodes_re_pass_immediately);
    PushToSuspendNodes(during_pass_node_set, pass_name, pass->GetNodesSuspend(), pass->GetNodesResume());

    const auto &deleted_by_pass = pass->GetNodesDeleted();
    during_pass_node_set.nodes_deleted.insert(deleted_by_pass.begin(), deleted_by_pass.end());
    // Once the current node itself is gone, the remaining passes have nothing to work on.
    if (deleted_by_pass.count(node) > 0) {
      GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
             pass_name.c_str());
      break;
    }
  }
  return SUCCESS;
}
void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) { | |||||
for (auto &name_to_pass : names_to_pass) { | |||||
name_to_pass.second->SetOption(option, ""); | |||||
} | |||||
} | |||||
void ClearOption(NamesToPass names_to_pass) { | |||||
for (auto &name_to_pass : names_to_pass) { | |||||
name_to_pass.second->ClearOptions(); | |||||
} | |||||
} | |||||
bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) { | |||||
if (node == nullptr) { | |||||
GELOGW("node is null"); | |||||
return false; | |||||
} | |||||
if (during_pass_node_set.nodes_deleted.count(node) > 0) { | |||||
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | |||||
return false; | |||||
} | |||||
if (during_pass_node_set.nodes_suspend.count(node) > 0) { | |||||
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
node->GetName().c_str()); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
} // namespace | |||||
// Isolates `node` according to `io_map` and removes it from its owner graph.
// Before the graph is changed, AddRePassNodesWithInOut(node) is called; after a successful
// removal the node is recorded via AddNodeDeleted.
// Returns FAILED when `node` is null, has no owner graph, or when isolating / removing fails.
Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map) {
  if (node == nullptr) {
    REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
    GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
    return FAILED;
  }
  GELOGI("Prepare to isolate and delete node, name:%s, type:%s.", node->GetName().c_str(),
         node->GetType().c_str());

  ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    REPORT_INNER_ERROR("E19999", "The owner graph of node:%s must not be null.", node->GetName().c_str());
    GELOGE(FAILED, "[Get][OwnerComputeGraph] failed, The owner graph of node:%s must not be null.",
           node->GetName().c_str());
    return FAILED;
  }

  // Must happen while the node is still linked to its neighbours.
  AddRePassNodesWithInOut(node);

  const auto isolate_ret = GraphUtils::IsolateNode(node, io_map);
  if (isolate_ret != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str());
    GELOGE(FAILED, "[Isolate][Node] %s failed.", node->GetName().c_str());
    return FAILED;
  }

  const auto remove_ret = GraphUtils::RemoveNodeWithoutRelink(owner_graph, node);
  if (remove_ret != SUCCESS) {
    REPORT_CALL_ERROR("E19999", "call RemoveNodeWithoutRelink for node:%s failed.", node->GetName().c_str());
    GELOGE(FAILED, "[Call][RemoveNodeWithoutRelink] for node:%s failed.", node->GetName().c_str());
    return FAILED;
  }

  AddNodeDeleted(node);
  return SUCCESS;
}
// Entry point: validates the pass state and then runs `names_to_passes` on `graph_`.
// Returns INTERNAL_ERROR when `graph_` is null or no passes are given, PARAM_INVALID when
// `depth_` exceeds kMaxRecursiveDepth, otherwise the result of RunPassesOneGraph.
Status GEPass::Run(const NamesToPass &names_to_passes) {
  if (graph_ == nullptr) {
    REPORT_INNER_ERROR("E19999", "graph_ is nullptr, check invalid.");
    GELOGE(INTERNAL_ERROR, "[Check][Param] The graph is nullptr");
    return INTERNAL_ERROR;
  }

  if (names_to_passes.empty()) {
    GELOGW("No passes input, the GEPass will do nothing");
    return INTERNAL_ERROR;
  }

  if (depth_ <= kMaxRecursiveDepth) {
    return RunPassesOneGraph(names_to_passes);
  }

  // Subgraphs are nested too deeply.
  GELOGE(PARAM_INVALID,
         "[Check][Param] The pass for root graph %s will be terminated because too many nesting"
         " levels(%d) of subgraphs, last subgraph is %s",
         root_graph_->GetName().c_str(), depth_, graph_->GetName().c_str());
  return PARAM_INVALID;
}
// Runs all passes over the nodes of `graph_`, starting from the nodes without inputs.
// Each outer round first queues the nodes requested for re-pass, then drains the queue: every
// valid node has its successors queued, gets all passes run on it and on its subgraphs, and, if
// it has subgraphs, gets the passes run a second time with kOptimizeAfterSubGraph set.
// After the queue drains, the deferred `nodes_last` nodes whose inputs are all seen are queued.
// Rounds repeat while there is pending work, up to kMaxRePassTimes.
// Returns the first non-SUCCESS status from a pass, otherwise SUCCESS.
Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
  GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size());
  std::deque<NodePtr> pass_queue;
  DuringPassNodeSets node_sets;
  GetAllNodesNoInputEdge(graph_, pass_queue, node_sets.nodes_seen, node_sets.nodes_last);
  GELOGD("Start points count %zu", pass_queue.size());

  int re_pass_times = 0;
  do {
    // Queue the re-pass requests collected during the previous round.
    for (const auto &re_pass_node : node_sets.nodes_re_pass) {
      pass_queue.push_back(re_pass_node);
      node_sets.nodes_seen.insert(re_pass_node.get());
    }
    node_sets.nodes_re_pass.clear();

    while (!pass_queue.empty()) {
      NodePtr cur_node = pass_queue.front();
      pass_queue.pop_front();

      // The node is about to be passed, so a pending re-pass request for it is satisfied.
      (void)node_sets.nodes_re_pass.erase(cur_node);
      if (!CheckNode(cur_node, node_sets)) {
        continue;
      }

      AddNextIterNodes(cur_node->GetOutNodes(), pass_queue, node_sets);

      auto status = RunPasses(cur_node, names_to_passes, node_sets);
      if (status != SUCCESS) {
        GELOGE(status, "[Process][Passes] on node %s type %s failed, error code:%u",
               cur_node->GetName().c_str(), cur_node->GetType().c_str(), status);
        return status;
      }

      bool has_sub_graph = false;
      status = RunPassesOnSubGraph(cur_node, names_to_passes, has_sub_graph);
      if (status != SUCCESS) {
        GELOGE(status, "[Run][Passes] on the sub graph of node %s failed", cur_node->GetName().c_str());
        return status;
      }

      if (has_sub_graph) {
        GELOGD("There are subgraphs on node %s, run passes for for the second time", cur_node->GetName().c_str());
        SetFlagOption(kOptimizeAfterSubGraph, names_to_passes);
        status = RunPasses(cur_node, names_to_passes, node_sets);
        if (status != SUCCESS) {
          GELOGE(status, "[Process][Passes] on node %s type %s failed, error code: %u",
                 cur_node->GetName().c_str(), cur_node->GetType().c_str(), status);
          return status;
        }
        // Only one scenario sets options today, so they are set and cleared right around the
        // second RunPasses call. With more scenarios, ClearOption would have to run at the
        // beginning of every iteration instead.
        ClearOption(names_to_passes);
      }

      AddRepassNodes(node_sets, pass_queue);
      AddResumeNodes(node_sets, pass_queue);
    }

    // The queue is empty: give the deferred nodes a chance.
    for (const auto &last_node : node_sets.nodes_last) {
      if (!last_node->IsAllInNodesSeen(node_sets.nodes_seen)) {
        continue;
      }
      if (node_sets.nodes_seen.insert(last_node.get()).second) {
        pass_queue.push_back(last_node);
      }
    }
    node_sets.nodes_last.clear();
  } while ((!node_sets.nodes_re_pass.empty() || !pass_queue.empty()) && ++re_pass_times < kMaxRePassTimes);

  if (re_pass_times == kMaxRePassTimes) {
    GELOGW("re_pass_times should not come to %d", kMaxRePassTimes);
  }
  GELOGD("All passes runs end");
  return SUCCESS;
}
// Runs `names_to_passes` on every subgraph instance referenced by `node`.
// Each subgraph is looked up by name in `root_graph_` and processed by a nested GEPass with
// `depth_ + 1`; names that cannot be resolved are skipped with a warning.
// `has_sub_graph` is set to true as soon as one subgraph is found.
// Returns the first non-SUCCESS status of a nested run, otherwise SUCCESS.
Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) {
  auto instance_names = node->GetOpDesc()->GetSubgraphInstanceNames();
  has_sub_graph = false;
  for (const auto &instance_name : instance_names) {
    auto sub_graph = root_graph_->GetSubgraph(instance_name);
    if (sub_graph == nullptr) {
      GELOGW("Can not find the sub graph %s from node %s, the pass-process will skip it",
             instance_name.c_str(), node->GetName().c_str());
      continue;
    }
    has_sub_graph = true;
    GELOGI("Begin to run passes on the sub graph %s of node %s", instance_name.c_str(), node->GetName().c_str());
    GEPass sub_graph_pass(sub_graph, root_graph_, depth_ + 1);
    auto sub_ret = sub_graph_pass.Run(names_to_passes);
    if (sub_ret != SUCCESS) {
      GELOGE(sub_ret, "[Run][Passes] for sub graph:%s from node:%s failed", instance_name.c_str(),
             node->GetName().c_str());
      return sub_ret;
    }
  }
  return SUCCESS;
}
} // namespace ge | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/base_pass.h" | |||||
#include <queue> | |||||
#include <unordered_set> | |||||
#include "common/debug/log.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
namespace ge { | |||||
namespace { | |||||
// Cap on `re_pass_times` in GEPass::RunPassesGraphRepass; a warning is logged when it is reached.
constexpr int kMaxRePassTimes = 10000; | |||||
// A node with more than this many input nodes is put into `nodes_last` by GetAllNodesNoInputEdge.
constexpr size_t kMaxOneInNodes = 1000; | |||||
// Largest `depth_` that GEPass::Run accepts; deeper subgraph nesting fails with PARAM_INVALID.
// Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later | |||||
constexpr int kMaxRecursiveDepth = 20; | |||||
void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, | |||||
GEPass::GraphLevelState &g_state) { | |||||
for (auto &node : graph->GetDirectNode()) { | |||||
if (node == nullptr) { | |||||
continue; | |||||
} | |||||
size_t in_nums = node->GetInNodes().size(); | |||||
if (in_nums == 0) { | |||||
g_state.AddNodeToQueueIfNotSeen(node); | |||||
} else if (in_nums > kMaxOneInNodes) { | |||||
g_state.nodes_last.insert(node); | |||||
} | |||||
} | |||||
} | |||||
// Returns true when at least one node of `nodes` is contained in `nodes_set`.
bool AnyNodesIn(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_set) {
  for (const NodePtr &candidate : nodes) {
    if (nodes_set.find(candidate) != nodes_set.end()) {
      return true;
    }
  }
  return false;
}
bool IsNodeReadyToQueue(const NodePtr &node, GEPass::GraphLevelState &g_state) { | |||||
if (node == nullptr) { | |||||
GELOGW("node is null"); | |||||
return false; | |||||
} | |||||
if (g_state.nodes_deleted.count(node) > 0) { | |||||
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | |||||
return false; | |||||
} | |||||
if (g_state.nodes_last.count(node) != 0) { | |||||
return false; | |||||
} | |||||
// all in_node seen && all in_node not suspend | |||||
if (!node->IsAllInNodesSeen(g_state.nodes_seen)) { | |||||
return false; | |||||
} | |||||
if (g_state.nodes_suspend.count(node) > 0) { | |||||
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
node->GetName().c_str()); | |||||
return false; | |||||
} | |||||
if (AnyNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) { | |||||
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
node->GetName().c_str()); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
void AddNextIterNodes(const NodePtr &cur_node, | |||||
std::unordered_set<NodePtr> &out_nodes_before_pass, | |||||
GEPass::GraphLevelState &g_state) { | |||||
for (auto &node : cur_node->GetOutNodes()) { | |||||
if (node == nullptr) { | |||||
continue; | |||||
} | |||||
if(out_nodes_before_pass.erase(node) == 0) { | |||||
// after pass node, new output node come up | |||||
GELOGI("New output node %s come up after pass %s.", | |||||
node->GetName().c_str(), cur_node->GetName().c_str()); | |||||
} | |||||
// all in_node seen && all in_node not suspend | |||||
if (IsNodeReadyToQueue(node, g_state)) { | |||||
g_state.AddNodeToQueueIfNotSeen(node); | |||||
} | |||||
} | |||||
// | |||||
for (const auto &node : out_nodes_before_pass) { | |||||
// A-->B-->C if B was | |||||
// unlink edge may happend, add these node to queue if needed | |||||
if (node->GetInAllNodes().empty() && IsNodeReadyToQueue(node, g_state)) { | |||||
GELOGI("Node %s may lost from cur node, add to queue if not seen.", | |||||
node->GetName().c_str(), cur_node->GetName().c_str()); | |||||
g_state.AddNodeToQueueIfNotSeen(node); | |||||
} | |||||
} | |||||
} | |||||
void AddImmediateRepassNodesToQueue(NodePtr &cur_node, | |||||
std::unordered_map<NodePtr, std::string> re_pass_imm_nodes_to_pass_names, | |||||
GEPass::GraphLevelState &g_state) { | |||||
for (const auto &node_2_pass_names : re_pass_imm_nodes_to_pass_names) { | |||||
auto imme_repass_node = node_2_pass_names.first; | |||||
if (imme_repass_node == nullptr) { | |||||
GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s", | |||||
node_2_pass_names.second.c_str(), | |||||
cur_node->GetName().c_str(), cur_node->GetType().c_str()); | |||||
continue; | |||||
} | |||||
if (g_state.nodes_passed.count(imme_repass_node) > 0) { | |||||
GELOGD("The node %s specified by pass %s has been passed, it will repass immediately", | |||||
imme_repass_node->GetName().c_str(), node_2_pass_names.second.c_str()); | |||||
g_state.AddNodeToQueueFront(imme_repass_node); | |||||
continue; | |||||
} | |||||
GELOGW("The node %s specified by pass %s has un-passed, it will not repass immediately", | |||||
node_2_pass_names.first->GetName().c_str(), node_2_pass_names.second.c_str()); | |||||
} | |||||
} | |||||
// Queues every deferred node of `nodes_last` whose input nodes have all been seen (and which has
// not been seen itself), then empties `nodes_last`.
void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) {
  for (auto &last_node : g_state.nodes_last) {
    const bool inputs_seen = last_node->IsAllInNodesSeen(g_state.nodes_seen);
    if (!inputs_seen) {
      continue;
    }
    g_state.AddNodeToQueueIfNotSeen(last_node);
  }
  g_state.nodes_last.clear();
}
void AddResumeNodesToQueue(const std::unordered_map<NodePtr, std::string> resume_node_2_pass_names, | |||||
GEPass::GraphLevelState &g_state) { | |||||
// Now base pass doesnt record the order of suspend & resume, so we dont know which one come first in a node pass. | |||||
// Here if one node pass suspend and resume a node ,consider it resume that node. | |||||
// Better way to record the order, and here suspend or resume in order. | |||||
for (const auto &node_2_pass_names : resume_node_2_pass_names) { | |||||
auto node = node_2_pass_names.first; | |||||
if (g_state.nodes_suspend.erase(node) > 0) { | |||||
if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) { | |||||
g_state.nodes.push_back(node); | |||||
GELOGD("Node %s has been resumed by pass %s, and add to pass queue", | |||||
node->GetName().c_str(), node_2_pass_names.second.c_str()); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
// Filters the re-pass requests of one pass run on `node`.
// A requested node is handed to `rp_state.AddNodeToRepass` only if it has been seen already or
// all of its input nodes have been seen; otherwise the request is dropped for this round.
// Null requests are skipped with a warning.
void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
                        std::unordered_set<Node *> &nodes_seen, const std::vector<NodePtr> &nodes_to_re_pass,
                        GEPass::RepassLevelState &rp_state) {
  for (const auto &requested : nodes_to_re_pass) {
    if (requested == nullptr) {
      GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(),
             node->GetName().c_str(), node->GetType().c_str());
      continue;
    }
    const bool can_re_pass = nodes_seen.count(requested.get()) > 0 || requested->IsAllInNodesSeen(nodes_seen);
    if (!can_re_pass) {
      GELOGD("The node %s are not all seen, don't set repass this time", requested->GetName().c_str());
      continue;
    }
    if (rp_state.AddNodeToRepass(requested)) {
      GELOGD("The node %s will be re-pass.", requested->GetName().c_str());
    } else {
      GELOGD("Node %s has been added to repass queue, no need to add again.", requested->GetName().c_str());
    }
  }
}
void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) { | |||||
for (auto &name_to_pass : names_to_pass) { | |||||
name_to_pass.second->SetOption(option, ""); | |||||
} | |||||
} | |||||
void ClearOption(NamesToPass names_to_pass) { | |||||
for (auto &name_to_pass : names_to_pass) { | |||||
name_to_pass.second->ClearOptions(); | |||||
} | |||||
} | |||||
} // namespace | |||||
// Isolates `node` according to `io_map` and removes it from its owner graph.
// Before the graph is changed, the node is registered together with its inputs/outputs for
// re-pass: via AddImmediateRePassNodesWithInOut when `is_repass_io_immediately` is true,
// otherwise via AddRePassNodesWithInOut. After a successful removal the node is recorded via
// AddNodeDeleted.
// Returns FAILED when `node` is null, has no owner graph, or when isolating / removing fails.
Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map,
                                          bool is_repass_io_immediately) {
  if (node == nullptr) {
    REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
    GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
    return FAILED;
  }
  GELOGI("Prepare to isolate and delete node, name:%s, type:%s.", node->GetName().c_str(),
         node->GetType().c_str());

  ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    REPORT_INNER_ERROR("E19999", "The owner graph of node:%s must not be null.", node->GetName().c_str());
    GELOGE(FAILED, "[Get][OwnerComputeGraph] failed, The owner graph of node:%s must not be null.",
           node->GetName().c_str());
    return FAILED;
  }

  // Must happen while the node is still linked to its neighbours.
  if (is_repass_io_immediately) {
    AddImmediateRePassNodesWithInOut(node);
  } else {
    AddRePassNodesWithInOut(node);
  }

  const auto isolate_ret = GraphUtils::IsolateNode(node, io_map);
  if (isolate_ret != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str());
    GELOGE(FAILED, "[Isolate][Node] %s failed.", node->GetName().c_str());
    return FAILED;
  }

  const auto remove_ret = GraphUtils::RemoveNodeWithoutRelink(owner_graph, node);
  if (remove_ret != SUCCESS) {
    REPORT_CALL_ERROR("E19999", "call RemoveNodeWithoutRelink for node:%s failed.", node->GetName().c_str());
    GELOGE(FAILED, "[Call][RemoveNodeWithoutRelink] for node:%s failed.", node->GetName().c_str());
    return FAILED;
  }

  AddNodeDeleted(node);
  return SUCCESS;
}
// Entry point: validates the pass state and then runs `names_to_passes` on `graph_`.
// Returns INTERNAL_ERROR when `graph_` is null, no passes are given, or a pass pointer is null;
// PARAM_INVALID when `depth_` exceeds kMaxRecursiveDepth; otherwise the result of
// RunPassesOneGraph.
Status GEPass::Run(const NamesToPass &names_to_passes) {
  if (graph_ == nullptr) {
    REPORT_INNER_ERROR("E19999", "graph_ is nullptr, check invalid.");
    GELOGE(INTERNAL_ERROR, "[Check][Param] The graph is nullptr");
    return INTERNAL_ERROR;
  }

  if (names_to_passes.empty()) {
    GELOGW("No passes input, the GEPass will do nothing");
    return INTERNAL_ERROR;
  }

  // Reject the whole list up front if any entry has no pass object.
  for (const auto &pass_entry : names_to_passes) {
    if (pass_entry.second != nullptr) {
      continue;
    }
    GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s)", pass_entry.first.c_str());
    return INTERNAL_ERROR;
  }

  if (depth_ <= kMaxRecursiveDepth) {
    // todo debug mode is on, find first node in topo order which is not passed. and give a warning
    return RunPassesOneGraph(names_to_passes);
  }

  // Subgraphs are nested too deeply.
  GELOGE(PARAM_INVALID,
         "[Check][Param] The pass for root graph %s will be terminated because too many nesting"
         " levels(%d) of subgraphs, last subgraph is %s",
         root_graph_->GetName().c_str(), depth_, graph_->GetName().c_str());
  return PARAM_INVALID;
}
void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names_to_pass) { | |||||
for (auto &name_to_pass : names_to_pass) { | |||||
name_to_pass.second->OnStartPassGraph(graph); | |||||
} | |||||
} | |||||
// Gives every pass the chance to resume nodes that are still suspended.
// Each pass is re-initialised and its OnSuspendNodesLeaked hook is called; the nodes it then
// reports via GetNodesResume are collected (node -> comma-terminated list of pass names) and
// handed to AddResumeNodesToQueue.
// Returns the first non-SUCCESS status of a hook, otherwise SUCCESS.
Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) {
  std::unordered_map<NodePtr, std::string> resumed_by;
  for (auto &pass_entry : names_to_passes) {
    auto *const pass = pass_entry.second;
    pass->init();
    auto hook_ret = pass->OnSuspendNodesLeaked();
    if (hook_ret != SUCCESS) {
      GELOGE(hook_ret, "Internal error with OnSuspendNodesLeaked on pass %s.", pass_entry.first.c_str());
      return hook_ret;
    }
    for (const auto &resume_node : pass->GetNodesResume()) {
      std::string &pass_names = resumed_by[resume_node];
      pass_names += pass_entry.first;
      pass_names += ",";
    }
  }
  AddResumeNodesToQueue(resumed_by, g_state);
  return SUCCESS;
}
// Runs all passes on `graph_`.
// After notifying the passes and seeding the queue, RunPassesGraphRepass is executed repeatedly
// for as long as suspended nodes remain. Before each repetition with suspended nodes pending,
// the passes are asked to resume them; if that leaves the queue empty, INTERNAL_ERROR is
// returned because nothing could make progress.
// Returns the first non-SUCCESS status encountered, otherwise SUCCESS.
Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
  GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size());
  NotifyPassGraphStart(graph_, names_to_passes);

  GraphLevelState graph_state;
  graph_state.re_pass_times = 0;
  GetAllNodesNoInputEdge(graph_, graph_state);
  GELOGD("Start points count %zu", graph_state.nodes.size());

  while (true) {
    if (!graph_state.nodes_suspend.empty()) {
      auto resume_ret = HandleLeakedSuspendNodes(names_to_passes, graph_state);
      if (resume_ret != SUCCESS) {
        // Already logged inside HandleLeakedSuspendNodes.
        return resume_ret;
      }
      if (graph_state.nodes.empty()) {
        GELOGE(INTERNAL_ERROR, "There are some suspended nodes leaked and no pass resume them.");
        return INTERNAL_ERROR;
      }
    }

    auto repass_ret = RunPassesGraphRepass(names_to_passes, graph_state);
    if (repass_ret != SUCCESS) {
      return repass_ret;
    }

    if (graph_state.nodes_suspend.empty()) {
      break;
    }
  }
  return SUCCESS;
}
// Drains the pass queue of `g_state`, repeating while nodes are requested for re-pass.
// Each round first queues the pending re-pass nodes, then pops nodes one by one: deleted nodes
// are skipped, every other node is marked as seen, passed via RunPassesNodeOnce, and its
// successors are queued via AddNextIterNodes. After the queue drains, the deferred `nodes_last`
// nodes are queued. Rounds stop once nothing is pending or `g_state.re_pass_times` reaches
// kMaxRePassTimes.
// Returns the first non-SUCCESS status of a node pass, otherwise SUCCESS.
Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state) {
  RepassLevelState rp_state;
  do {
    // Queue the re-pass requests that are still registered in the re-pass set.
    for (auto &re_pass_node : rp_state.nodes_re_pass) {
      if (rp_state.nodes_re_pass_set.count(re_pass_node) == 0) {
        continue;
      }
      GELOGD("Add node %s to queue for re-pass", re_pass_node->GetName().c_str());
      g_state.AddNodeToQueue(re_pass_node);
    }
    rp_state.ClearRepass();

    while (!g_state.nodes.empty()) {
      auto cur_node = g_state.PopFront();
      if (g_state.nodes_deleted.count(cur_node) > 0) {
        GELOGD("The node %s was deleted before, skip it.", cur_node->GetName().c_str());
        continue;
      }
      rp_state.EraseNodeFromRepass(cur_node);
      g_state.nodes_seen.insert(cur_node.get());

      // Snapshot the output nodes so that links changed by the passes can be detected afterwards.
      std::unordered_set<NodePtr> outputs_before;
      for (const auto &out_node : cur_node->GetOutNodes()) {
        outputs_before.insert(out_node);
      }

      auto pass_ret = RunPassesNodeOnce(cur_node, names_to_passes, g_state, rp_state);
      if (pass_ret != SUCCESS) {
        GELOGE(pass_ret, "[Process][Passes] on node %s type %s failed, error code:%u", cur_node->GetName().c_str(),
               cur_node->GetType().c_str(), pass_ret);
        return pass_ret;
      }
      AddNextIterNodes(cur_node, outputs_before, g_state);
    }

    AddLastNodesToQueue(g_state);
  } while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes);

  if (g_state.re_pass_times == kMaxRePassTimes) {
    GELOGW("re_pass_times should not come to %d", kMaxRePassTimes);
  }
  GELOGD("All passes runs end");
  return SUCCESS;
}
// Runs `names_to_passes` on every subgraph instance referenced by `node`.
// Each subgraph is looked up by name in `root_graph_` and processed by a nested GEPass with
// `depth_ + 1`; names that cannot be resolved are skipped with a warning.
// `has_sub_graph` is set to true as soon as one subgraph is found.
// Returns the first non-SUCCESS status of a nested run, otherwise SUCCESS.
Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) {
  auto instance_names = node->GetOpDesc()->GetSubgraphInstanceNames();
  has_sub_graph = false;
  for (const auto &instance_name : instance_names) {
    auto sub_graph = root_graph_->GetSubgraph(instance_name);
    if (sub_graph == nullptr) {
      GELOGW("Can not find the sub graph %s from node %s, the pass-process will skip it",
             instance_name.c_str(), node->GetName().c_str());
      continue;
    }
    has_sub_graph = true;
    GELOGI("Begin to run passes on the sub graph %s of node %s", instance_name.c_str(), node->GetName().c_str());
    GEPass sub_graph_pass(sub_graph, root_graph_, depth_ + 1);
    auto sub_ret = sub_graph_pass.Run(names_to_passes);
    if (sub_ret != SUCCESS) {
      GELOGE(sub_ret, "[Run][Passes] for sub graph:%s from node:%s failed", instance_name.c_str(),
             node->GetName().c_str());
      return sub_ret;
    }
  }
  return SUCCESS;
}
// Passes a single node: runs all passes on `node`, then on its subgraphs, and, if the node has
// at least one subgraph, runs all passes on `node` a second time with kOptimizeAfterSubGraph
// set (the options are cleared again afterwards).
// Returns the first non-SUCCESS status encountered, otherwise SUCCESS.
Status GEPass::RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes,
                                 GraphLevelState &g_state, RepassLevelState &rp_state) {
  auto status = RunPassesOnNode(node, names_to_passes, g_state, rp_state);
  if (status != SUCCESS) {
    GELOGE(status, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(),
           node->GetType().c_str(), status);
    return status;
  }

  bool has_sub_graph = false;
  status = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph);
  if (status != SUCCESS) {
    GELOGE(status, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str());
    return status;
  }
  if (!has_sub_graph) {
    return SUCCESS;
  }

  GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str());
  SetFlagOption(kOptimizeAfterSubGraph, names_to_passes);
  status = RunPassesOnNode(node, names_to_passes, g_state, rp_state);
  if (status != SUCCESS) {
    GELOGE(status, "[Process][Passes] on node %s type %s failed, error code: %u", node->GetName().c_str(),
           node->GetType().c_str(), status);
    return status;
  }
  // Only one scenario sets options today, so they are set and cleared right around the second
  // RunPassesOnNode call. With more scenarios, ClearOption would have to run at the beginning of
  // every iteration instead.
  ClearOption(names_to_passes);
  return SUCCESS;
}
///
/// Run every pass on `node` and fold the per-pass results (re-pass, immediate re-pass,
/// resume, suspend and deleted nodes) into the graph-level and re-pass-level state.
/// Passes after the one that deletes `node` are not run.
/// @param node node to process; must not be null
/// @param names_to_passes passes to run, in order
/// @param g_state graph-level state; receives passed, suspended and deleted nodes
/// @param rp_state re-pass state handed to PushToRePassIfSeen
/// @return SUCCESS; FAILED when `node` is null; otherwise the failing pass's status
///
Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state,
                               RepassLevelState &rp_state) {
  if (node == nullptr) {
    REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
    GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
    return FAILED;
  }
  GELOGD("Begin to run pass for node %s", node->GetName().c_str());
  // Reset every pass before running any of them. The collection loop below reads the result
  // sets of ALL passes, including those skipped by the early `break`; resetting only right
  // before Run() would leave skipped passes exposing the results of the previous node.
  for (const auto &name_to_pass : names_to_passes) {
    name_to_pass.second->init();
  }
  for (const auto &name_to_pass : names_to_passes) {
    GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str());
    auto result = name_to_pass.second->Run(node);
    if (result != SUCCESS) {
      REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u",
                        name_to_pass.first.c_str(), node->GetName().c_str(), result);
      GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result "
             "%u, the passes will be terminated immediately.",
             name_to_pass.first.c_str(), node->GetName().c_str(), result);
      return result;
    }
    if (name_to_pass.second->GetNodesDeleted().count(node) > 0) {
      GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
             name_to_pass.first.c_str());
      break;
    }
  }
  g_state.nodes_passed.insert(node);

  // If several passes request the same node it would be queued many times, so collect the
  // requests per node first (remembering which passes asked) and deduplicate.
  std::unordered_map<NodePtr, std::string> re_pass_imm_nodes_to_pass_names;
  std::unordered_map<NodePtr, std::string> resume_nodes_to_pass_names;
  for (const auto &name_to_pass : names_to_passes) {
    PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen,
                       name_to_pass.second->GetNodesNeedRePass(),
                       rp_state);
    // Collect immediate-re-pass and resume nodes among these passes.
    for (const auto &imm_node : name_to_pass.second->GetNodesNeedRePassImmediately()) {
      re_pass_imm_nodes_to_pass_names[imm_node].append(name_to_pass.first + ",");
    }
    for (const auto &resume_node : name_to_pass.second->GetNodesResume()) {
      resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ",");
    }
    for (const auto &suspend_node : name_to_pass.second->GetNodesSuspend()) {
      GELOGD("The iteration suspend of node %s has been set by pass %s", suspend_node->GetName().c_str(),
             name_to_pass.first.c_str());
      g_state.nodes_suspend.insert(suspend_node);
    }
    const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
    g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
  }

  AddImmediateRepassNodesToQueue(node, re_pass_imm_nodes_to_pass_names, g_state);
  AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state);
  return SUCCESS;
}
} // namespace ge |
@@ -22,7 +22,6 @@ | |||||
#include <unordered_set> | #include <unordered_set> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
@@ -40,6 +39,7 @@ enum NodePassOption { | |||||
}; | }; | ||||
class BaseNodePass { | class BaseNodePass { | ||||
// todo comments | |||||
public: | public: | ||||
/// | /// | ||||
/// Optimize on one node. the function can add nodes to the graph, change | /// Optimize on one node. the function can add nodes to the graph, change | ||||
@@ -51,7 +51,7 @@ class BaseNodePass { | |||||
virtual ~BaseNodePass() = default; | virtual ~BaseNodePass() = default; | ||||
const std::unordered_set<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; } | |||||
const std::vector<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; } | |||||
const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } | const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } | ||||
@@ -61,23 +61,32 @@ class BaseNodePass { | |||||
const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; } | const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; } | ||||
virtual Status OnSuspendNodesLeaked() { return SUCCESS; } | |||||
void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } | void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } | ||||
void ClearOptions() { options_.clear(); } | void ClearOptions() { options_.clear(); } | ||||
void init() { | void init() { | ||||
nodes_need_re_pass_.clear(); | nodes_need_re_pass_.clear(); | ||||
nodes_deleted_.clear(); | |||||
nodes_need_re_pass_immediately_.clear(); | nodes_need_re_pass_immediately_.clear(); | ||||
nodes_deleted_.clear(); | |||||
nodes_suspend_.clear(); | nodes_suspend_.clear(); | ||||
nodes_resume_.clear(); | nodes_resume_.clear(); | ||||
} | } | ||||
virtual void OnStartPassGraph(const ComputeGraphPtr &graph) { | |||||
current_graph_name_ = graph->GetName(); | |||||
} | |||||
protected: | protected: | ||||
Status IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map); | |||||
const string &GetCurrentGraphName() const { | |||||
return current_graph_name_; | |||||
} | |||||
Status IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map, bool is_repass_io_immediately = false); | |||||
Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list<int> &io_map) { | |||||
return IsolateAndDeleteNode(node, std::vector<int>(io_map)); | |||||
Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list<int> &io_map, bool is_repass_io_immediately = false) { | |||||
return IsolateAndDeleteNode(node, std::vector<int>(io_map), is_repass_io_immediately); | |||||
} | } | ||||
/// | /// | ||||
@@ -86,7 +95,7 @@ class BaseNodePass { | |||||
/// optimized by other passes, call this function. | /// optimized by other passes, call this function. | ||||
/// @param node | /// @param node | ||||
/// | /// | ||||
void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); } | |||||
void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.emplace_back(node); } | |||||
/// | /// | ||||
/// Add a node to be optimized immediately again. If you add a new node to the graph, or | /// Add a node to be optimized immediately again. If you add a new node to the graph, or | ||||
@@ -101,14 +110,30 @@ class BaseNodePass { | |||||
/// @param node | /// @param node | ||||
/// | /// | ||||
void AddRePassNodesWithInOut(const NodePtr &node) { | void AddRePassNodesWithInOut(const NodePtr &node) { | ||||
auto in_nodes = node->GetInNodes(); | |||||
for (auto &in_node : in_nodes) { | |||||
AddRePassNode(in_node); | |||||
} | |||||
AddRePassNode(node); | AddRePassNode(node); | ||||
auto out_nodes = node->GetOutNodes(); | auto out_nodes = node->GetOutNodes(); | ||||
for (auto &out_node : out_nodes) { | for (auto &out_node : out_nodes) { | ||||
AddRePassNode(out_node); | AddRePassNode(out_node); | ||||
} | } | ||||
} | |||||
/// | |||||
/// Add a node and it's input/output data nodes to be optimized immediately again. | |||||
/// @param node | |||||
/// | |||||
void AddImmediateRePassNodesWithInOut(const NodePtr &node) { | |||||
auto in_nodes = node->GetInNodes(); | auto in_nodes = node->GetInNodes(); | ||||
for (auto &in_node : in_nodes) { | for (auto &in_node : in_nodes) { | ||||
AddRePassNode(in_node); | |||||
AddImmediateRePassNode(in_node); | |||||
} | |||||
AddImmediateRePassNode(node); | |||||
auto out_nodes = node->GetOutNodes(); | |||||
for (auto &out_node : out_nodes) { | |||||
AddImmediateRePassNode(out_node); | |||||
} | } | ||||
} | } | ||||
@@ -123,34 +148,27 @@ class BaseNodePass { | |||||
void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } | void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } | ||||
/// | /// | ||||
/// If you suspend a node from the graph, especially following node. The remain | |||||
/// iterate passes will stop process on the suspend node(if it can be | |||||
/// If you postpone a node from the graph, especially following node. The remain | |||||
/// iterate passes will stop process on the postpone node(if it can be | |||||
/// reached by edge connections) till the last one. Obviously it is a waste of | /// reached by edge connections) till the last one. Obviously it is a waste of | ||||
/// time. You can add the suspend nodes by calling this function, to stop the | |||||
/// time. You can add the postpone nodes by calling this function, to stop the | |||||
/// next iterations. | /// next iterations. | ||||
/// @param node | /// @param node | ||||
/// | /// | ||||
void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } | void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } | ||||
/// | |||||
/// If you resume a node from the graph, especially following node. The remain | |||||
/// iterate passes will continue process on the resume node(if it can be | |||||
/// reached by edge connections) till the last one. | |||||
/// You can add the resume nodes by calling this function, to resume the | |||||
/// next iterations. | |||||
/// @param node | |||||
/// | |||||
void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } | void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } | ||||
bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } | bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } | ||||
private: | private: | ||||
std::unordered_set<NodePtr> nodes_need_re_pass_; | |||||
std::vector<NodePtr> nodes_need_re_pass_; | |||||
std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; | std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; | ||||
std::unordered_set<NodePtr> nodes_deleted_; | std::unordered_set<NodePtr> nodes_deleted_; | ||||
std::unordered_set<NodePtr> nodes_suspend_; | std::unordered_set<NodePtr> nodes_suspend_; | ||||
std::unordered_set<NodePtr> nodes_resume_; | std::unordered_set<NodePtr> nodes_resume_; | ||||
std::map<NodePassOption, std::string> options_; | std::map<NodePassOption, std::string> options_; | ||||
std::string current_graph_name_; | |||||
}; | }; | ||||
using NamesToPass = std::vector<std::pair<std::string, BaseNodePass *>>; | using NamesToPass = std::vector<std::pair<std::string, BaseNodePass *>>; | ||||
@@ -160,12 +178,75 @@ class GEPass { | |||||
explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {} | explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {} | ||||
virtual ~GEPass() = default; | virtual ~GEPass() = default; | ||||
Status Run(const NamesToPass &names_to_passes); | Status Run(const NamesToPass &names_to_passes); | ||||
/* | |||||
* todo | |||||
* OneGraph: nodes_deleted, nodes_seen, nodes_passed, nodes_suspended | |||||
* RePass: nodes_re_pass | |||||
* GraphOneTime: nodes_last | |||||
* NodeOneTime: nodes_re_pass_immediately, nodes_resume | |||||
*/ | |||||
struct GraphLevelState { | |||||
std::unordered_set<NodePtr> nodes_deleted; | |||||
std::unordered_set<Node *> nodes_seen; | |||||
std::unordered_set<NodePtr> nodes_passed; | |||||
std::unordered_set<NodePtr> nodes_suspend; | |||||
std::unordered_set<NodePtr> nodes_last; | |||||
std::deque<NodePtr> nodes; | |||||
int re_pass_times; | |||||
void AddNodeToQueueFront(NodePtr node) { | |||||
nodes_seen.insert(node.get()); | |||||
nodes.emplace_front(std::move(node)); | |||||
} | |||||
void AddNodeToQueue(NodePtr node) { | |||||
nodes_seen.insert(node.get()); | |||||
nodes.emplace_back(std::move(node)); | |||||
} | |||||
void AddNodeToQueueIfNotSeen(NodePtr node) { | |||||
if (nodes_seen.insert(node.get()).second) { | |||||
nodes.emplace_back(std::move(node)); | |||||
} | |||||
} | |||||
NodePtr PopFront() { | |||||
NodePtr node = nodes.front(); | |||||
nodes.pop_front(); | |||||
return node; | |||||
} | |||||
}; | |||||
struct RepassLevelState { | |||||
std::vector<NodePtr> nodes_re_pass; | |||||
std::unordered_set<NodePtr> nodes_re_pass_set; | |||||
bool AddNodeToRepass(NodePtr node) { | |||||
if (!nodes_re_pass_set.insert(node).second) { | |||||
return false; | |||||
} | |||||
nodes_re_pass.emplace_back(node); | |||||
return true; | |||||
} | |||||
void EraseNodeFromRepass(NodePtr node) { | |||||
nodes_re_pass_set.erase(node); | |||||
} | |||||
void ClearRepass() { | |||||
nodes_re_pass_set.clear(); | |||||
nodes_re_pass.clear(); | |||||
} | |||||
}; | |||||
// Holds a single set of nodes named `nodes_last`.
// NOTE(review): GraphLevelState declares its own `nodes_last`, and nothing in the visible part
// of this file refers to this struct - confirm whether it is still needed.
struct GraphOneTimeLevelState { | |||||
std::unordered_set<NodePtr> nodes_last; | |||||
}; | |||||
private: | private: | ||||
GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth) | GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth) | ||||
: graph_(graph), root_graph_(root_graph), depth_(depth) {} | : graph_(graph), root_graph_(root_graph), depth_(depth) {} | ||||
Status RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes, | |||||
GraphLevelState &g_state, RepassLevelState &rp_state); | |||||
Status RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state); | |||||
Status RunPassesOneGraph(const NamesToPass &names_to_passes); | Status RunPassesOneGraph(const NamesToPass &names_to_passes); | ||||
Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph); | Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph); | ||||
Status RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state, | |||||
RepassLevelState &rp_state); | |||||
Status HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state); | |||||
ComputeGraphPtr graph_; | ComputeGraphPtr graph_; | ||||
ComputeGraphPtr root_graph_; | ComputeGraphPtr root_graph_; | ||||
int depth_; | int depth_; | ||||
@@ -86,6 +86,9 @@ bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; } | |||||
void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) { | void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) { | ||||
// need passed_nodes set to solve the problem that multi-input operators do repass in advance. | // need passed_nodes set to solve the problem that multi-input operators do repass in advance. | ||||
// when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes. | // when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes. | ||||
for (const auto &node_ele : changed_nodes) { | |||||
AddImmediateRePassNode(node_ele); | |||||
} | |||||
} | } | ||||
graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) { | graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) { | ||||
@@ -1,175 +1,370 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/infershape_pass.h" | |||||
#include "common/util/error_manager/error_manager.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "analyzer/analyzer.h" | |||||
#include "framework/common/util.h" | |||||
#include "graph/shape_refiner.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/utils/node_utils.h" | |||||
#include "common/omg_util.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/utils/type_utils.h" | |||||
namespace ge { | |||||
// Appends a textual form of the shape range and the origin shape range of `desc` to `desc_str`.
// Shape range entries are written as "{first,second}," inside "[...]"; origin shape range
// entries follow the closing bracket, each written as ",{first,second},".
void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  const auto append_entries = [&desc_str](const std::vector<std::pair<int64_t, int64_t>> &entries,
                                          const char *opening) {
    for (const auto &bounds : entries) {
      desc_str.append(opening);
      desc_str.append(std::to_string(bounds.first));
      desc_str.append(",");
      desc_str.append(std::to_string(bounds.second));
      desc_str.append("},");
    }
  };

  std::vector<std::pair<int64_t, int64_t>> ranges;
  desc_str.append("[");
  (void)desc->GetShapeRange(ranges);
  append_entries(ranges, "{");
  desc_str.append("]");

  ranges.clear();
  (void)desc->GetOriginShapeRange(ranges);
  append_entries(ranges, ",{");
}
// Builds a one-line description of every non-null input tensor desc of `node`
// (shape, format, dtype, their origin counterparts and the shape ranges), wrapped in "{...}".
// Null input descs are skipped but still consume an input index.
std::string GetInTensorInfoWithString(const ge::NodePtr &node) {
  ge::OpDescPtr op_desc = node->GetOpDesc();
  std::stringstream ss;
  ss << "{";
  int32_t in_idx = 0;
  for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
    if (input_desc == nullptr) {
      in_idx++;
      continue;
    }
    if (in_idx > 0) {
      ss << "    ";
    }
    ss << "input_" << in_idx << " " << "tensor: [";
    ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
    ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
    ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
    ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
    ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
    ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
    std::string range_str;
    SerialShapeRange(input_desc, range_str);
    ss << "(shape_range:" << range_str << ")]";
    in_idx++;
  }
  // Close the brace opened above so the reported text is balanced.
  ss << "}";
  return ss.str();
}
// Infers shape and type for `node`. On failure the analyzer is asked to record the node and
// GE_GRAPH_INFERSHAPE_FAILED is returned. On success, RePassLoopNode is invoked and the
// ATTR_NAME_NEED_INFER_AGAIN attribute is evaluated (only when running after subgraphs).
Status InferShapePass::Run(NodePtr &node) {
  // The presence of kOptimizeAfterSubGraph means this run happens after the subgraphs.
  const bool after_subgraph = OptionExists(kOptimizeAfterSubGraph);
  const auto infer_status = ShapeRefiner::InferShapeAndType(node, !after_subgraph);
  if (infer_status != GRAPH_SUCCESS) {
    // Record the failed infershape for the analyzer.
    auto owner_graph = node->GetOwnerComputeGraph();
    GE_CHECK_NOTNULL(owner_graph);
    auto root_graph = ge::GraphUtils::FindRootGraph(owner_graph);
    GE_CHECK_NOTNULL(root_graph);
    analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(),
                                    analyzer::INFER_SHAPE, node, "InferShapeFailed!"};
    (void)Analyzer::GetInstance()->DoAnalyze(analyze_info);
    (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(),
                                                          root_graph->GetGraphID());

    const std::string input_info = GetInTensorInfoWithString(node);
    REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s",
                      node->GetName().c_str(), node->GetType().c_str(), input_info.c_str());
    GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s",
           node->GetName().c_str(), node->GetType().c_str(), input_info.c_str());
    return GE_GRAPH_INFERSHAPE_FAILED;
  }

  GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node));

  bool need_repass = false;
  const bool has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass);
  if (!has_attr || !after_subgraph) {
    return SUCCESS;
  }
  if (!need_repass) {
    // clear attr on while
    node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
    return SUCCESS;
  }
  AddImmediateRePassNode(node);
  GELOGD("Node %s need repass immediately.", node->GetName().c_str());
  return SUCCESS;
}
// Loop handling keyed on the original type of `node`:
//  - NextIteration types: matching Merge successors are queued for immediate re-pass;
//  - Merge types carrying ATTR_NAME_NEED_INFER_AGAIN: the attribute is removed and matching
//    Switch successors are queued for immediate re-pass;
//  - Switch types: matching Exit successors are resumed when the attribute was present
//    (it is removed first) and suspended otherwise.
// Any other type is left alone.
Status InferShapePass::RePassLoopNode(const NodePtr &node) {
  // Queue data successors of the wanted types for immediate re-pass and set their
  // ATTR_NAME_NEED_INFER_AGAIN attribute to false.
  const auto repass_successors = [&](const std::set<std::string> &wanted_types) {
    for (auto &successor : node->GetOutDataNodes()) {
      GE_CHECK_NOTNULL(successor);
      std::string successor_type;
      GE_CHK_STATUS_RET(GetOriginalType(successor, successor_type), "[Get][OriginalType] of node:%s failed.",
                        successor->GetName().c_str());
      if (wanted_types.count(successor_type) == 0) {
        continue;
      }
      AddImmediateRePassNode(successor);
      (void)AttrUtils::SetBool(successor->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false);
      GELOGD("Node %s need repass immediately after %s.", successor->GetName().c_str(), node->GetName().c_str());
    }
    return SUCCESS;
  };

  // Apply `action` to data successors of the wanted types; `what` is only used for logging.
  const auto apply_to_successors = [&](const std::set<std::string> &wanted_types,
                                       const std::function<void(InferShapePass *, NodePtr)> &action,
                                       const std::string &what) {
    for (auto &successor : node->GetOutDataNodes()) {
      GE_CHECK_NOTNULL(successor);
      std::string successor_type;
      GE_CHK_STATUS_RET(GetOriginalType(successor, successor_type), "[Get][OriginalType] of node:%s failed.",
                        successor->GetName().c_str());
      if (wanted_types.count(successor_type) == 0) {
        continue;
      }
      action(this, successor);
      GELOGD("Node %s %s after %s.", successor->GetName().c_str(), what.c_str(), node->GetName().c_str());
    }
    return SUCCESS;
  };

  std::string node_type;
  GE_CHK_STATUS_RET(GetOriginalType(node, node_type),
                    "[Get][OriginalType] of node:%s failed.", node->GetName().c_str());

  if (kNextIterationOpTypes.count(node_type) > 0) {
    return repass_successors(kMergeOpTypes);  // Re-Pass Merge
  }

  const bool is_merge = kMergeOpTypes.count(node_type) > 0;
  const bool is_switch = kSwitchOpTypes.count(node_type) > 0;
  if (!is_merge && !is_switch) {
    return SUCCESS;
  }

  // Both remaining cases consume the attribute when it is present.
  const bool infer_again = node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN);
  if (infer_again) {
    node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
  }

  if (is_merge) {
    if (infer_again) {
      return repass_successors(kSwitchOpTypes);  // Re-Pass Switch
    }
    return SUCCESS;
  }

  if (infer_again) {
    return apply_to_successors(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume");  // Resume Exit
  }
  return apply_to_successors(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend");  // Suspend Exit
}
} // namespace ge | |||||
/** | |||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
*+ | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/infershape_pass.h" | |||||
#include "common/util/error_manager/error_manager.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "analyzer/analyzer.h" | |||||
#include "framework/common/util.h" | |||||
#include "graph/shape_refiner.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/utils/node_utils.h" | |||||
#include "common/omg_util.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/utils/type_utils.h" | |||||
#include "external/graph/operator_factory.h" | |||||
namespace ge { | |||||
namespace { | |||||
constexpr int kSwitchExitAnchorIndex = 0; | |||||
constexpr int kSwitchPredAnchorIndex = 1; | |||||
// Appends the shape range of `desc` to `desc_str` as "[{a,b},{c,d},]", followed by the origin
// shape range with every entry written as ",{a,b},".
void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  using RangeList = std::vector<std::pair<int64_t, int64_t>>;

  // Writes each (first, second) pair after the given opening token.
  const auto write_pairs = [&desc_str](const RangeList &pairs, const std::string &opening) {
    for (auto it = pairs.cbegin(); it != pairs.cend(); ++it) {
      desc_str += opening;
      desc_str += std::to_string(it->first);
      desc_str += ",";
      desc_str += std::to_string(it->second);
      desc_str += "},";
    }
  };

  RangeList range_buffer;
  desc_str += "[";
  (void)desc->GetShapeRange(range_buffer);
  write_pairs(range_buffer, "{");
  desc_str += "]";

  // Reuse the buffer for the origin shape range.
  range_buffer.clear();
  (void)desc->GetOriginShapeRange(range_buffer);
  write_pairs(range_buffer, ",{");
}
void UpdateShapeAndDType(const GeTensorDescPtr &src, GeTensorDescPtr &dst) { | |||||
dst->SetOriginShape(src->GetOriginShape()); | |||||
dst->SetShape(src->GetShape()); | |||||
dst->SetDataType(src->GetDataType()); | |||||
dst->SetOriginDataType(src->GetOriginDataType()); | |||||
vector<pair<int64_t, int64_t>> src_shape_range; | |||||
src->GetShapeRange(src_shape_range); | |||||
dst->SetShapeRange(src_shape_range); | |||||
dst->SetOriginShapeRange(src_shape_range); | |||||
ge::TensorUtils::SetRealDimCnt(*dst, static_cast<uint32_t>(src->GetShape().GetDims().size())); | |||||
} | |||||
} // namespace | |||||
std::string InferShapePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const { | |||||
std::stringstream ss; | |||||
ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),"; | |||||
ss << "(format:" << TypeUtils::FormatToSerialString(tensor_desc->GetFormat()) << "),"; | |||||
ss << "(dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()) << "),"; | |||||
ss << "(origin_shape:" << tensor_desc->GetOriginShape().ToString() << "),"; | |||||
ss << "(origin_format:" << TypeUtils::FormatToSerialString(tensor_desc->GetOriginFormat()) << "),"; | |||||
ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetOriginDataType()) << "),"; | |||||
string range_str; | |||||
SerialShapeRange(tensor_desc, range_str); | |||||
ss << "(shape_range:" << range_str << ")"; | |||||
return ss.str(); | |||||
} | |||||
// For a SWITCH node whose input at kSwitchPredAnchorIndex comes from a LOOPCOND node, suspends
// every data successor attached to output kSwitchExitAnchorIndex. Each such node is recorded
// once per current graph name and reported through AddNodeSuspend. Other nodes are ignored.
Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) {
  const bool is_switch = (node->GetType() == SWITCH);
  if (!is_switch) {
    return SUCCESS;
  }
  const auto pred_node = NodeUtils::GetInDataNodeByIndex(*node, kSwitchPredAnchorIndex);
  GE_CHECK_NOTNULL(pred_node);
  if (pred_node->GetType() == LOOPCOND) {
    const auto exit_candidates = NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kSwitchExitAnchorIndex);
    for (const auto &anchor_and_node : exit_candidates) {
      const auto &exit_node = anchor_and_node.second;
      GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", exit_node->GetName().c_str(),
             exit_node->GetType().c_str());
      auto &suspended = graphs_2_suspend_nodes_[GetCurrentGraphName()];
      // The set rejects nodes already recorded, so each node is queued and suspended only once.
      const bool first_time = suspended.nodes_set.insert(exit_node).second;
      if (!first_time) {
        continue;
      }
      suspended.nodes.push(exit_node);
      AddNodeSuspend(exit_node);
    }
  }
  return SUCCESS;
}
// Runs InferShapeAndType on `node`. When it does not return GRAPH_SUCCESS, the failure is handed
// to the analyzer, reported and logged, and GE_GRAPH_INFERSHAPE_FAILED is returned.
Status InferShapePass::Infer(NodePtr &node) {
  if (InferShapeAndType(node) == GRAPH_SUCCESS) {
    return SUCCESS;
  }

  auto owner_graph = node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(owner_graph);
  auto root_graph = ge::GraphUtils::FindRootGraph(owner_graph);
  GE_CHECK_NOTNULL(root_graph);

  // Let the analyzer record the failing node and persist its data for this session/graph.
  analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(),
                                  analyzer::INFER_SHAPE, node, "InferShapeFailed!"};
  (void)Analyzer::GetInstance()->DoAnalyze(analyze_info);
  (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(),
                                                        root_graph->GetGraphID());

  REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed", node->GetName().c_str(),
                    node->GetType().c_str());
  GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed", node->GetName().c_str(),
         node->GetType().c_str());
  return GE_GRAPH_INFERSHAPE_FAILED;
}
// Infers shape and type of `node`:
//  1. suspends the exit nodes of a v1 loop when `node` is the matching Switch;
//  2. verifies the node;
//  3. calls the infer-shape function (with an inference context unless the owner graph is
//     flagged as unknown);
//  4. refreshes the origin fields of the output descs and stores the resulting context.
// Returns GRAPH_NODE_NEED_REPASS when the infer function asks for it, GRAPH_SUCCESS when it
// returns GRAPH_SUCCESS or GRAPH_PARAM_INVALID, and an error status otherwise.
graphStatus InferShapePass::InferShapeAndType(NodePtr &node) {
  auto ret = SuspendV1LoopExitNodes(node);
  if (ret != SUCCESS) {
    GELOGE(ret, "Suspend V1 loop exit nodes failed.");
    return ret;
  }
  // A node without an owner graph cannot be inferred; fail instead of dereferencing null.
  const auto owner_graph = node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(owner_graph);
  bool is_unknown_graph = owner_graph->GetGraphUnknownFlag();
  if (node->Verify() != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  Operator op = OpDescUtils::CreateOperatorFromNode(node);

  if (!is_unknown_graph) {
    auto inference_context = ShapeRefiner::CreateInferenceContext(node);
    GE_CHECK_NOTNULL(inference_context);
    GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
    op.SetInferenceContext(inference_context);
  }

  graphStatus status = CallInferShapeFunc(node, op);
  if (status != GRAPH_NODE_NEED_REPASS && status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) {
    // node like netoutput return param_invalid, but valid ?
    return GE_GRAPH_INFERSHAPE_FAILED;
  }
  UpdateCurNodeOutputDesc(node);
  if (!is_unknown_graph) {
    auto ctx_after_infer = op.GetInferenceContext();
    if (ctx_after_infer != nullptr) {
      GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
      // Only keep the context when it carries output handle shapes/types or marks.
      if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
        GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
               ctx_after_infer->GetMarks().size());
        ShapeRefiner::PushToContextMap(node, ctx_after_infer);
      }
    }
  }

  return (status == GRAPH_NODE_NEED_REPASS) ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS;
}
/// Refreshes the "origin" fields of every output tensor description of a node.
///
/// For each output data anchor with a non-null description: the origin shape is copied from the
/// current shape only when the current shape has no dims; the real dim count is set from the
/// origin shape; the origin data type and origin shape range are copied from the current ones.
///
/// @param node  node whose output descriptions are updated in place
void InferShapePass::UpdateCurNodeOutputDesc(NodePtr &node) {
  auto desc = node->GetOpDesc();
  for (const auto &anchor : node->GetAllOutDataAnchors()) {
    auto out_desc = desc->MutableOutputDesc(anchor->GetIdx());
    if (out_desc == nullptr) {
      continue;
    }
    if (out_desc->MutableShape().GetDims().empty()) {
      out_desc->SetOriginShape(out_desc->GetShape());
    }
    const auto origin_dim_cnt = static_cast<uint32_t>(out_desc->GetOriginShape().GetDims().size());
    ge::TensorUtils::SetRealDimCnt(*out_desc, origin_dim_cnt);
    out_desc->SetOriginDataType(out_desc->GetDataType());

    // Mirror the current shape range into the origin shape range.
    std::vector<std::pair<int64_t, int64_t>> shape_range;
    (void)out_desc->GetShapeRange(shape_range);
    out_desc->SetOriginShapeRange(shape_range);

    GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
           node->GetName().c_str(), out_desc->GetOriginShape().GetShapeSize(),
           TypeUtils::FormatToSerialString(out_desc->GetOriginFormat()).c_str(),
           TypeUtils::DataTypeToSerialString(out_desc->GetOriginDataType()).c_str());
  }
}
bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) { | |||||
// check shape range | |||||
vector<std::pair<int64_t, int64_t>> src_shape_range; | |||||
vector<std::pair<int64_t, int64_t>> dst_shape_range; | |||||
src->GetShapeRange(src_shape_range); | |||||
dst->GetShapeRange(dst_shape_range); | |||||
if (src_shape_range.size() != dst_shape_range.size()) { | |||||
GELOGI("Src shape range size is %zu, dst shape range size is %zu, not same.", src_shape_range.size(), | |||||
dst_shape_range.size()); | |||||
return false; | |||||
} | |||||
for (size_t i = 0; i < src_shape_range.size(); ++i) { | |||||
if (src_shape_range[i].first != dst_shape_range[i].first || | |||||
src_shape_range[i].second != dst_shape_range[i].second) { | |||||
GELOGI("Current dim %zu. Src shape range is [%lu-%lu], dst shape range is [%lu-%lu], not same.", | |||||
i, src_shape_range[i].first, src_shape_range[i].second, dst_shape_range[i].first, dst_shape_range[i].second); | |||||
return false; | |||||
} | |||||
} | |||||
// check shape | |||||
auto src_shape = src->GetShape(); | |||||
auto dst_shape = dst->GetShape(); | |||||
if (src_shape.GetDims() != dst_shape.GetDims() || src->GetOriginShape().GetDims() != dst->GetOriginShape().GetDims() || | |||||
src->GetDataType() != dst->GetDataType() || src->GetOriginDataType() != dst->GetOriginDataType()) { | |||||
GELOGD( | |||||
"Src shape is %s, origin_shape is %s, data_type is %s, origin data_type is %s; " | |||||
"Dst shape is %s, origin_shape is %s, data_type is %s, original data_type is %s, not same.", | |||||
src_shape.ToString().c_str(), src->GetOriginShape().ToString().c_str(), | |||||
TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), | |||||
TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst_shape.ToString().c_str(), | |||||
dst->GetOriginShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), | |||||
TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
/// Synchronises a peer tensor description with its source description.
///
/// The comparison against dst is made first; then src's origin shape, origin data type, real dim
/// count and origin shape range are refreshed from src's current values. dst is rewritten via
/// UpdateShapeAndDType only when the comparison found a difference.
///
/// @param src      source tensor description (its origin fields are refreshed in place)
/// @param dst      destination tensor description, updated when it differs from src
/// @param changed  set to true when dst differed from src, false otherwise
/// @return SUCCESS in all paths visible here
graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  changed = !SameTensorDesc(src, dst);
  // refresh src itself
  src->SetOriginShape(src->GetShape());
  src->SetOriginDataType(src->GetDataType());
  TensorUtils::SetRealDimCnt(*src, static_cast<uint32_t>(src->GetOriginShape().GetDims().size()));
  std::vector<std::pair<int64_t, int64_t>> src_shape_range;
  (void)src->GetShapeRange(src_shape_range);
  src->SetOriginShapeRange(src_shape_range);
  if (!changed) {
    GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update.");
    return SUCCESS;
  }

  UpdateShapeAndDType(src, dst);
  // The two literals are concatenated; the separating space keeps the sentences apart in the log.
  GELOGD(
      "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s. "
      "To dst Node: shape: [%s], datatype: %s, original datatype is %s.",
      src->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(),
      TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst->GetShape().ToString().c_str(),
      TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(),
      TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str());
  return SUCCESS;
}
/// Invokes the infer function of a node, falling back to the operator factory.
///
/// When the op desc's own infer function returns GRAPH_PARAM_INVALID, an operator of the same
/// type is created through OperatorFactory; its input/output names and infer function are copied
/// onto the node's op desc and inference is attempted a second time.
///
/// @param node  node being inferred
/// @param op    operator wrapper passed to the infer function
/// @return result of the (first or second) infer call, GRAPH_FAILED when the factory operator has
///         no op desc, or GRAPH_SUCCESS on the early exit after a failed input-name update
graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) {
  auto op_desc = node->GetOpDesc();
  const auto &op_type = op_desc->GetType();
  auto ret = op_desc->CallInferFunc(op);
  if (ret != GRAPH_PARAM_INVALID) {
    return ret;
  }

  // Op ir no infer func, try to get infer func from operator factory
  auto factory_op = ge::OperatorFactory::CreateOperator("node_op", op_type);
  if (factory_op.IsEmpty()) {
    GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
    return ret;
  }
  GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());

  auto factory_op_desc = ge::OpDescUtils::GetOpDescFromOperator(factory_op);
  factory_op.BreakConnect();
  if (factory_op_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr.");
    GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null");
    return GRAPH_FAILED;
  }

  if (!op_desc->UpdateInputName(factory_op_desc->GetAllInputName())) {
    GELOGW("InferShapeAndType UpdateInputName failed");
    // NOTE(review): only the first output is ever examined; the loop exits on its first iteration.
    // Inference continues only if that output is non-null with an empty dim list.
    for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) {
      const bool first_has_no_dims = (out_desc != nullptr) && out_desc->GetShape().GetDims().empty();
      if (!first_has_no_dims) {
        return GRAPH_SUCCESS;
      }
      break;
    }
  }
  if (!op_desc->UpdateOutputName(factory_op_desc->GetAllOutputName())) {
    GELOGW("InferShapeAndType UpdateOutputName failed");
  }

  op_desc->AddInferFunc(factory_op_desc->GetInferFunc());
  ret = op_desc->CallInferFunc(op);
  GELOGI("op CallInferFunc second. ret: %u", ret);
  return ret;
}
/// Merges the output descriptions coming from several subgraphs into the parent output.
///
/// The first entry of src is used as the reference and is modified in place: if any entry has a
/// different rank the reference shape becomes UNKNOWN_RANK; otherwise every dim that differs
/// between an entry and the reference is set to UNKNOWN_DIM. The reference is then applied to
/// dst through UpdateShapeAndDType.
///
/// @param src  output descriptions collected from the subgraphs
/// @param dst  parent output description to update
/// @return GRAPH_SUCCESS on success (including an empty src, which leaves dst untouched);
///         GRAPH_FAILED when the entries do not all share one data type
graphStatus InferShapePass::UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) {
  GELOGD("Enter update parent node shape for class branch op process");
  // Nothing to merge; same handling as UpdateOutputFromSubgraphsForMultiDims.
  // Without this guard src.at(0) would throw std::out_of_range.
  if (src.empty()) {
    GELOGI("Src subgraph shape is empty.");
    return GRAPH_SUCCESS;
  }
  // check sub_graph shape.If not same ,do unknown shape process
  auto ref_out_tensor = src.at(0);
  ge::GeShape &ref_out_tensor_shape = ref_out_tensor->MutableShape();
  for (auto &tensor : src) {
    if (ref_out_tensor->GetDataType() != tensor->GetDataType()) {
      REPORT_INNER_ERROR("E19999", "Does not support diff dtype among all ref output, shape:%s",
                         ref_out_tensor_shape.ToString().c_str());
      GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype output");
      return GRAPH_FAILED;
    }
    // Read-only view of this subgraph's shape; no copy is needed for the comparison.
    const auto &shape = tensor->MutableShape();
    const size_t ref_dim_num = ref_out_tensor_shape.GetDims().size();
    if (shape.GetDims().size() != ref_dim_num) {
      // GetShapeSize() returns a signed int64_t, hence %ld.
      GELOGD("Shape from subgraph size: %ld, ref_out_tensor_shape size: %ld", shape.GetShapeSize(),
             ref_out_tensor_shape.GetShapeSize());
      ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
      break;
    }
    for (size_t j = 0; j < ref_dim_num; j++) {
      if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
        continue;
      }
      GELOGD("j: %zu ,shape from subgraph size: %ld, ref_out_tensor_shape size: %ld", j, shape.GetShapeSize(),
             ref_out_tensor_shape.GetShapeSize());
      (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
    }
  }
  UpdateShapeAndDType(ref_out_tensor, dst);
  return GRAPH_SUCCESS;
}
/// Picks the subgraph output with the largest element count and applies it to the parent output.
///
/// The element count of each entry is the product of its dims, with a check that the product
/// stays within INT64_MAX. The first entry with the strictly largest count wins.
///
/// @param src  output descriptions collected from the subgraphs
/// @param dst  parent output description to update
/// @return SUCCESS when src is empty, GRAPH_SUCCESS after updating dst, GRAPH_FAILED when the
///         entries do not share one data type, PARAM_INVALID when the overflow check trips
graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
                                                                  GeTensorDescPtr &dst) {
  // check sub_graph shape. Get max for update.
  if (src.empty()) {
    GELOGI("Src subgraph shape is empty.");
    return SUCCESS;
  }

  auto &reference = src.at(0);
  int64_t largest_size = 0;
  size_t largest_index = 0;
  for (size_t idx = 0; idx < src.size(); ++idx) {
    auto &candidate = src.at(idx);
    if (reference->GetDataType() != candidate->GetDataType()) {
      REPORT_INNER_ERROR("E19999", "node does not support diff dtype among all ref output");
      GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype among all ref output");
      return GRAPH_FAILED;
    }

    auto shape = candidate->MutableShape();
    int64_t element_count = 1;
    for (auto dim : shape.GetDims()) {
      const bool would_overflow = (dim != 0) && (INT64_MAX / dim < element_count);
      if (would_overflow) {
        REPORT_INNER_ERROR("E19999", "The shape:%s size overflow", shape.ToString().c_str());
        GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow");
        return PARAM_INVALID;
      }
      element_count *= dim;
    }

    if (element_count > largest_size) {
      largest_size = element_count;
      largest_index = idx;
    }
  }

  UpdateShapeAndDType(src.at(largest_index), dst);
  return GRAPH_SUCCESS;
}
/// Resumes one suspended node of the current graph, if there is any.
///
/// Looks up the suspend record keyed by the current graph name; when its stack is not empty the
/// top node is popped and handed to AddNodeResume.
///
/// @return SUCCESS in every path
Status InferShapePass::OnSuspendNodesLeaked() {
  const auto found = graphs_2_suspend_nodes_.find(GetCurrentGraphName());
  if (found == graphs_2_suspend_nodes_.end()) {
    GELOGI("Current graph %s no suspend node.", GetCurrentGraphName().c_str());
    return SUCCESS;
  }

  auto &suspended = found->second;
  if (suspended.nodes.empty()) {
    return SUCCESS;
  }
  AddNodeResume(suspended.PopSuspendedNode());
  return SUCCESS;
}
} // namespace ge |
@@ -1,38 +1,56 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | |||||
#define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | |||||
#include "graph/passes/base_pass.h" | |||||
namespace ge { | |||||
class InferShapePass : public BaseNodePass { | |||||
public: | |||||
/// | |||||
/// Entry of the InferShapePass optimizer, invoked with a single node | |||||
/// @param node: node handed to the pass | |||||
/// @return SUCCESS: Execution succeed | |||||
/// @return OTHERS: Execution failed | |||||
/// | |||||
Status Run(ge::NodePtr &node) override; | |||||
private: | |||||
// NOTE(review): presumably schedules loop-related nodes for another pass; the definition is not visible here -- confirm. | |||||
Status RePassLoopNode(const NodePtr &node); | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | |||||
/** | |||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | |||||
#define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | |||||
#include "graph/passes/infer_base_pass.h" | |||||
#include <stack> | |||||
namespace ge { | |||||
class InferShapePass : public InferBasePass { | |||||
public: | |||||
std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override; | |||||
graphStatus Infer(NodePtr &node) override; | |||||
graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; | |||||
graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override; | |||||
graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src, | |||||
GeTensorDescPtr &dst) override; | |||||
Status OnSuspendNodesLeaked() override; | |||||
private: | |||||
graphStatus InferShapeAndType(NodePtr &node); | |||||
graphStatus CallInferShapeFunc(NodePtr &node, Operator &op); | |||||
bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst); | |||||
void UpdateCurNodeOutputDesc(NodePtr &node); | |||||
Status SuspendV1LoopExitNodes(const NodePtr &node); | |||||
struct SuspendNodes { | |||||
std::stack<NodePtr> nodes; | |||||
std::unordered_set<NodePtr> nodes_set; | |||||
NodePtr PopSuspendedNode() { | |||||
auto top_node = nodes.top(); | |||||
nodes.pop(); | |||||
nodes_set.erase(top_node); | |||||
return top_node; | |||||
} | |||||
}; | |||||
std::map<std::string, SuspendNodes> graphs_2_suspend_nodes_; | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ |
@@ -1999,6 +1999,22 @@ Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) { | |||||
Status GraphPrepare::InferShapeForPreprocess() { | Status GraphPrepare::InferShapeForPreprocess() { | ||||
GELOGI("Start infershape for preprocess."); | GELOGI("Start infershape for preprocess."); | ||||
// Prepare dummy_shape for v1 control_flow op before infershape | |||||
for (const auto &node : compute_graph_->GetAllNodes()) { | |||||
string type; | |||||
GetOriginalType(node, type); | |||||
if (type == MERGE || type == REFMERGE) { | |||||
for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { | |||||
GELOGD("Prepare for infershape: update %s input_shape as dummy.", node->GetName().c_str()); | |||||
NodeUtils::UpdateInputShape(*node, i, GeShape(DUMMY_SHAPE)); | |||||
} | |||||
} else if (type == WHILE) { | |||||
for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { | |||||
GELOGD("Prepare for infershape: update %s output_shape as dummy.", node->GetName().c_str()); | |||||
NodeUtils::UpdateOutputShape(*node, i, GeShape(DUMMY_SHAPE)); | |||||
} | |||||
} | |||||
} | |||||
GEPass ge_passes(compute_graph_); | GEPass ge_passes(compute_graph_); | ||||
NamesToPass names_to_passes; | NamesToPass names_to_passes; | ||||
AssertPass assert_pass; | AssertPass assert_pass; | ||||
@@ -72,7 +72,7 @@ TEST(UtestGraphPassesAddnPass, null_pass) { | |||||
AddNPass *addn_pass = nullptr; | AddNPass *addn_pass = nullptr; | ||||
NamesToPass names_to_pass; | NamesToPass names_to_pass; | ||||
names_to_pass.emplace_back("Test", addn_pass); | names_to_pass.emplace_back("Test", addn_pass); | ||||
EXPECT_EQ(pass.Run(names_to_pass), SUCCESS); | |||||
EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR); | |||||
} | } | ||||
TEST(UtestGraphPassesAddnPass, null_graph) { | TEST(UtestGraphPassesAddnPass, null_graph) { | ||||
@@ -1,161 +1,262 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/passes/infershape_pass.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/operator_factory.h" | |||||
#include "graph/operator_reg.h" | |||||
#include "graph_builder_utils.h" | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace ge { | |||||
class UtestGraphInfershapePass : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
/// Builds an op desc with `in_num` inputs and `out_num` outputs and adds it to `graph`.
///
/// Every input/output uses the same dim-less NCHW float tensor description (size 512) and offset
/// 1024. Stub infer, infer-format and verifier functions that return GRAPH_SUCCESS are attached.
/// Each call assigns the next value of a function-local counter as the op id.
///
/// @return the node returned by ComputeGraph::AddNode
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
  auto desc = std::make_shared<OpDesc>(name, type);
  desc->SetStreamId(0);
  static int32_t next_id = 0;
  desc->SetId(next_id++);

  GeTensorDesc tensor_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
  TensorUtils::SetSize(tensor_desc, 512);

  vector<int64_t> in_offsets;
  for (int n = 0; n < in_num; ++n) {
    desc->AddInputDesc(tensor_desc);
    in_offsets.emplace_back(1024);
  }
  desc->SetInputOffset(in_offsets);

  vector<int64_t> out_offsets;
  for (int n = 0; n < out_num; ++n) {
    desc->AddOutputDesc(tensor_desc);
    out_offsets.emplace_back(1024);
  }
  desc->SetOutputOffset(out_offsets);

  desc->SetWorkspace({});
  desc->SetWorkspaceBytes({});
  desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

  // One always-succeeding stub serves as infer, infer-format and verify function.
  const auto always_ok = [](Operator &op) { return GRAPH_SUCCESS; };
  desc->AddInferFunc(always_ok);
  desc->AddInferFormatFunc(always_ok);
  desc->AddVerifierFunc(always_ok);
  return graph.AddNode(desc);
}
// Running the pass on a stand-alone AddN node (no infer function attached in this test) is
// expected to report GE_GRAPH_INFERSHAPE_FAILED.
TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
  const string op_type = "AddN";
  GeTensorDesc tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);

  auto op_desc = std::make_shared<OpDesc>("AddN", op_type);
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddOutputDesc(tensor_desc);

  auto owner_graph = std::make_shared<ComputeGraph>("test");
  auto node = std::make_shared<Node>(op_desc, owner_graph);
  node->Init();

  InferShapePass pass;
  EXPECT_EQ(pass.Run(node), GE_GRAPH_INFERSHAPE_FAILED);
}
// A NoOp node carrying "_need_infer_again" = false is run with the kOptimizeAfterSubGraph option
// set to "yes"; the pass is expected to succeed.
TEST_F(UtestGraphInfershapePass, delete_need_infer_again) {
  auto owner_graph = std::make_shared<ComputeGraph>("test");
  auto op_desc = std::make_shared<OpDesc>("No", "NoOp");
  auto node = owner_graph->AddNode(op_desc);
  AttrUtils::SetBool(op_desc, "_need_infer_again", false);

  InferShapePass pass;
  pass.options_[kOptimizeAfterSubGraph] = "yes";
  EXPECT_EQ(pass.Run(node), SUCCESS);
}
// Runs InferShapePass through GEPass over a V1 while-loop graph and expects SUCCESS.
//
// Data edges (source output -> destination input):
//   data:0     -> enter:0
//   enter:0    -> merge:0          next:0   -> merge:1
//   merge:0    -> less:0           const:0  -> less:1
//   less:0     -> loopcond:0
//   loopcond:0 -> switch:0         merge:0  -> switch:1
//   switch:0   -> exit:0           switch:1 -> identity:0
//   identity:0 -> add:0            const:0  -> add:1
//   add:0      -> next:0
//   exit:0     -> net_output:0
TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) {
  auto graph = std::make_shared<ComputeGraph>("test_infer_shape");
  ComputeGraph &g = *graph;

  auto data_node = CreateNode(g, "data", DATA, 1, 1);
  auto enter_node = CreateNode(g, "enter", ENTER, 1, 1);
  auto merge_node = CreateNode(g, "merge", MERGE, 2, 2);
  auto less_node = CreateNode(g, "less", LESS, 2, 1);
  auto loop_cond_node = CreateNode(g, "loopcond", LOOPCOND, 1, 1);
  auto switch_node = CreateNode(g, "switch", SWITCH, 2, 2);
  auto identity_node = CreateNode(g, "identity", IDENTITY, 1, 1);
  auto add_node = CreateNode(g, "add", ADD, 2, 1);
  auto next_node = CreateNode(g, "next", NEXTITERATION, 1, 1);
  auto exit_node = CreateNode(g, "exit", EXIT, 1, 1);
  // Two constants share the name "const"; the first one is left unconnected.
  auto unused_const = CreateNode(g, "const", CONSTANT, 0, 1);
  auto const_node = CreateNode(g, "const", CONSTANT, 0, 1);
  auto net_output = CreateNode(g, "net_output", NETOUTPUT, 1, 1);

  const auto link = [](const NodePtr &from, int out_idx, const NodePtr &to, int in_idx) {
    GraphUtils::AddEdge(from->GetOutDataAnchor(out_idx), to->GetInDataAnchor(in_idx));
  };
  link(data_node, 0, enter_node, 0);
  link(enter_node, 0, merge_node, 0);
  link(merge_node, 0, less_node, 0);
  link(const_node, 0, less_node, 1);
  link(less_node, 0, loop_cond_node, 0);
  link(loop_cond_node, 0, switch_node, 0);
  link(merge_node, 0, switch_node, 1);
  link(switch_node, 0, exit_node, 0);
  link(switch_node, 1, identity_node, 0);
  link(identity_node, 0, add_node, 0);
  link(const_node, 0, add_node, 1);
  link(add_node, 0, next_node, 0);
  link(next_node, 0, merge_node, 1);
  link(exit_node, 0, net_output, 0);

  GEPass pass_runner(graph);
  NamesToPass passes;
  InferShapePass infer_shape_pass;
  passes.emplace_back("InferShapePass", &infer_shape_pass);
  EXPECT_EQ(pass_runner.Run(passes), SUCCESS);
}
} // namespace ge | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/passes/infershape_pass.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/operator_factory.h" | |||||
#include "graph/operator_reg.h" | |||||
#include "graph_builder_utils.h" | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace ge { | |||||
class UtestGraphInfershapePass : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
/// Builds an op desc with `in_num` inputs and `out_num` outputs and adds it to `graph`.
///
/// Every input/output uses the same dim-less NCHW float tensor description (size 512) and offset
/// 1024. Stub infer, infer-format and verifier functions that return GRAPH_SUCCESS are attached.
/// Each call assigns the next value of a function-local counter as the op id.
///
/// @return the node returned by ComputeGraph::AddNode
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
  auto desc = std::make_shared<OpDesc>(name, type);
  desc->SetStreamId(0);
  static int32_t next_id = 0;
  desc->SetId(next_id++);

  GeTensorDesc tensor_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
  TensorUtils::SetSize(tensor_desc, 512);

  vector<int64_t> in_offsets;
  for (int n = 0; n < in_num; ++n) {
    desc->AddInputDesc(tensor_desc);
    in_offsets.emplace_back(1024);
  }
  desc->SetInputOffset(in_offsets);

  vector<int64_t> out_offsets;
  for (int n = 0; n < out_num; ++n) {
    desc->AddOutputDesc(tensor_desc);
    out_offsets.emplace_back(1024);
  }
  desc->SetOutputOffset(out_offsets);

  desc->SetWorkspace({});
  desc->SetWorkspaceBytes({});
  desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

  // One always-succeeding stub serves as infer, infer-format and verify function.
  const auto always_ok = [](Operator &op) { return GRAPH_SUCCESS; };
  desc->AddInferFunc(always_ok);
  desc->AddInferFormatFunc(always_ok);
  desc->AddVerifierFunc(always_ok);
  return graph.AddNode(desc);
}
// Running the pass on a stand-alone AddN node (no infer function attached in this test) is
// expected to report GRAPH_FAILED.
TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
  const string op_type = "AddN";
  GeTensorDesc tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);

  auto op_desc = std::make_shared<OpDesc>("AddN", op_type);
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddOutputDesc(tensor_desc);

  auto owner_graph = std::make_shared<ComputeGraph>("test");
  auto node = std::make_shared<Node>(op_desc, owner_graph);
  node->Init();

  InferShapePass pass;
  EXPECT_EQ(pass.Run(node), GRAPH_FAILED);
}
// Runs InferShapePass directly on the Switch node of a V1 while-loop graph. Expects the Exit node
// to be reported as suspended, and as resumed after OnSuspendNodesLeaked().
//
// Data edges (source output -> destination input):
//   data:0     -> enter:0
//   enter:0    -> merge:0          next:0   -> merge:1
//   merge:0    -> less:0           const:0  -> less:1
//   less:0     -> loopcond:0
//   loopcond:0 -> switch:1         merge:0  -> switch:0
//   switch:0   -> exit:0           switch:1 -> identity:0
//   identity:0 -> add:0            const:0  -> add:1
//   add:0      -> next:0
//   exit:0     -> net_output:0
TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) {
  auto graph = std::make_shared<ComputeGraph>("test_infer_shape");
  ComputeGraph &g = *graph;

  auto data_node = CreateNode(g, "data", DATA, 1, 1);
  auto enter_node = CreateNode(g, "enter", ENTER, 1, 1);
  auto merge_node = CreateNode(g, "merge", MERGE, 2, 2);
  auto less_node = CreateNode(g, "less", LESS, 2, 1);
  auto loop_cond_node = CreateNode(g, "loopcond", LOOPCOND, 1, 1);
  auto switch_node = CreateNode(g, "switch", SWITCH, 2, 2);
  auto identity_node = CreateNode(g, "identity", IDENTITY, 1, 1);
  auto add_node = CreateNode(g, "add", ADD, 2, 1);
  auto next_node = CreateNode(g, "next", NEXTITERATION, 1, 1);
  auto exit_op = CreateNode(g, "exit", EXIT, 1, 1);
  // Two constants share the name "const"; the first one is left unconnected.
  auto unused_const = CreateNode(g, "const", CONSTANT, 0, 1);
  auto const_node = CreateNode(g, "const", CONSTANT, 0, 1);
  auto net_output = CreateNode(g, "net_output", NETOUTPUT, 1, 1);

  const auto link = [](const NodePtr &from, int out_idx, const NodePtr &to, int in_idx) {
    GraphUtils::AddEdge(from->GetOutDataAnchor(out_idx), to->GetInDataAnchor(in_idx));
  };
  link(data_node, 0, enter_node, 0);
  link(enter_node, 0, merge_node, 0);
  link(merge_node, 0, less_node, 0);
  link(const_node, 0, less_node, 1);
  link(less_node, 0, loop_cond_node, 0);
  link(loop_cond_node, 0, switch_node, 1);
  link(merge_node, 0, switch_node, 0);
  link(switch_node, 0, exit_op, 0);
  link(switch_node, 1, identity_node, 0);
  link(identity_node, 0, add_node, 0);
  link(const_node, 0, add_node, 1);
  link(add_node, 0, next_node, 0);
  link(next_node, 0, merge_node, 1);
  link(exit_op, 0, net_output, 0);

  GEPass pass_runner(graph);
  NamesToPass passes;
  InferShapePass infer_shape_pass;
  passes.emplace_back("InferShapePass", &infer_shape_pass);

  EXPECT_EQ(infer_shape_pass.Run(switch_node), SUCCESS);

  auto suspended = infer_shape_pass.GetNodesSuspend();
  auto exit_node = graph->FindNode("exit");
  EXPECT_EQ(suspended.count(exit_node), 1);

  infer_shape_pass.OnSuspendNodesLeaked();
  auto resumed = infer_shape_pass.GetNodesResume();
  EXPECT_EQ(resumed.count(exit_node), 1);
}
// src and dst differ in their first dim: UpdateTensorDesc must flag the change and give dst
// the shape of src.
TEST_F(UtestGraphInfershapePass, update_tensordesc_when_changed) {
  GeTensorDescPtr src_desc =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16));
  GeTensorDescPtr dst_desc =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16));

  InferShapePass pass;
  bool is_changed = false;
  pass.UpdateTensorDesc(src_desc, dst_desc, is_changed);

  EXPECT_EQ(is_changed, true);
  const std::vector<int64_t> expected_dims({1, 2, 3, 4});
  EXPECT_EQ(dst_desc->GetShape().GetDims(), expected_dims);
}
// src and dst are built from identical arguments: UpdateTensorDesc must report no change.
TEST_F(UtestGraphInfershapePass, update_tensordesc_when_not_changed) {
  GeTensorDescPtr src_desc =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16));
  GeTensorDescPtr dst_desc =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16));

  InferShapePass pass;
  bool is_changed = false;
  pass.UpdateTensorDesc(src_desc, dst_desc, is_changed);

  EXPECT_EQ(is_changed, false);
}
// The two subgraph outputs have different data types (DT_FLOAT16 vs DT_FLOAT), which
// UpdateOutputFromSubgraphs must reject with GRAPH_FAILED.
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_failed) {
  GeTensorDescPtr half_output =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16));
  GeTensorDescPtr float_output =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT));
  GeTensorDescPtr parent_output =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT));

  InferShapePass pass;
  const std::vector<GeTensorDescPtr> subgraph_outputs{half_output, float_output};
  const auto status = pass.UpdateOutputFromSubgraphs(subgraph_outputs, parent_output);
  EXPECT_EQ(status, GRAPH_FAILED);
}
// The subgraph outputs share a data type but differ in rank (3 dims vs 4 dims); the parent output
// is expected to end up with the UNKNOWN_RANK dims.
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_get_unknown_rank) {
  GeTensorDescPtr rank3_output =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({1, 2, 3}), ge::FORMAT_NCHW, DT_FLOAT));
  GeTensorDescPtr rank4_output =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT));
  GeTensorDescPtr parent_output =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT));

  InferShapePass pass;
  const std::vector<GeTensorDescPtr> subgraph_outputs{rank3_output, rank4_output};
  const auto status = pass.UpdateOutputFromSubgraphs(subgraph_outputs, parent_output);

  EXPECT_EQ(status, SUCCESS);
  EXPECT_EQ(parent_output->GetShape().GetDims(), UNKNOWN_RANK);
}
// The subgraph outputs share data type and rank but differ in the first dim (1 vs 2); that dim is
// expected to become -1 in the parent output while the others stay as they are.
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_get_unknown_shape) {
  GeTensorDescPtr first_output =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT));
  GeTensorDescPtr second_output =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT));
  GeTensorDescPtr parent_output =
      std::make_shared<GeTensorDesc>(GeTensorDesc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT));

  InferShapePass pass;
  const std::vector<GeTensorDescPtr> subgraph_outputs{first_output, second_output};
  const auto status = pass.UpdateOutputFromSubgraphs(subgraph_outputs, parent_output);

  EXPECT_EQ(status, SUCCESS);
  const std::vector<int64_t> expected_dims({-1, 2, 3, 4});
  EXPECT_EQ(parent_output->GetShape().GetDims(), expected_dims);
  // todo shape range?
}
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_failed) {
  // Multi-dims variant: the two subgraph outputs share a shape but differ in data type
  // (DT_FLOAT16 vs DT_FLOAT); the update is expected to return GRAPH_FAILED.
  InferShapePass pass;
  GeTensorDescPtr fp16_desc = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
  GeTensorDescPtr fp32_desc = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  GeTensorDescPtr dst_desc = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);

  const auto status = pass.UpdateOutputFromSubgraphsForMultiDims({fp16_desc, fp32_desc}, dst_desc);

  EXPECT_EQ(status, GRAPH_FAILED);
}
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_failed_shape_size_overflow) {
  // Multi-dims variant: both outputs are DT_FLOAT, but the second shape has INT64_MAX as its
  // first dim. The update is expected to reject it with PARAM_INVALID.
  // NOTE(review): presumably the element count of that shape overflows int64 - confirm against the pass.
  InferShapePass pass;
  GeTensorDescPtr normal_desc = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  GeTensorDescPtr huge_desc =
      std::make_shared<GeTensorDesc>(GeShape({INT64_MAX, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  GeTensorDescPtr dst_desc = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);

  const auto status = pass.UpdateOutputFromSubgraphsForMultiDims({normal_desc, huge_desc}, dst_desc);

  EXPECT_EQ(status, PARAM_INVALID);
}
TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_success) {
  // Multi-dims variant: both outputs are rank-4 DT_FLOAT and differ only in the first dim (1 vs 2).
  // The update is expected to succeed and leave the destination shape as {2, 2, 3, 4},
  // i.e. equal to the second input's shape.
  InferShapePass pass;
  GeTensorDescPtr first_desc = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  GeTensorDescPtr second_desc = std::make_shared<GeTensorDesc>(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
  GeTensorDescPtr dst_desc = std::make_shared<GeTensorDesc>(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);

  const auto status = pass.UpdateOutputFromSubgraphsForMultiDims({first_desc, second_desc}, dst_desc);

  EXPECT_EQ(status, SUCCESS);
  const std::vector<int64_t> expected_dims({2, 2, 3, 4});
  EXPECT_EQ(dst_desc->GetShape().GetDims(), expected_dims);
}
} // namespace ge |