Browse Source

Pre Merge pull request !1907 from zhaoxinxin/master

pull/1907/MERGE
zhaoxinxin Gitee 4 years ago
parent
commit
1dea357b63
9 changed files with 1278 additions and 334 deletions
  1. +300
    -184
      ge/graph/passes/base_pass.cc
  2. +81
    -16
      ge/graph/passes/base_pass.h
  3. +5
    -2
      ge/graph/passes/infer_base_pass.cc
  4. +278
    -102
      ge/graph/passes/infershape_pass.cc
  5. +30
    -12
      ge/graph/passes/infershape_pass.h
  6. +16
    -0
      ge/graph/preprocess/graph_preprocess.cc
  7. +1
    -1
      tests/ut/ge/graph/passes/addn_pass_unittest.cc
  8. +463
    -11
      tests/ut/ge/graph/passes/base_pass_unittest.cc
  9. +104
    -6
      tests/ut/ge/graph/passes/infershape_pass_unittest.cc

+ 300
- 184
ge/graph/passes/base_pass.cc View File

@@ -19,9 +19,7 @@
#include <queue>
#include <unordered_set>

#include "framework/common/debug/log.h"
#include "framework/common/debug/ge_log.h"
#include "graph/compute_graph.h"
#include "common/debug/log.h"
#include "graph/utils/graph_utils.h"

namespace ge {
@@ -30,95 +28,161 @@ constexpr int kMaxRePassTimes = 10000;
constexpr size_t kMaxOneInNodes = 1000;
// Each iteration takes about 0.3k of stack memory; the recursion should be changed to a loop later.
constexpr int kMaxRecursiveDepth = 20;
// Bookkeeping sets shared by the helpers in this file while passes run over one graph.
struct DuringPassNodeSets {
// Raw pointers of nodes that have already been pushed to the pass queue.
std::unordered_set<Node *> nodes_seen;
// Nodes reported as deleted by a pass (filled from GetNodesDeleted()).
std::unordered_set<NodePtr> nodes_deleted;
// Nodes to be queued again at the start of the next outer iteration.
std::unordered_set<NodePtr> nodes_re_pass;
// Nodes to be pushed to the front of the queue right after the current node.
std::unordered_set<NodePtr> nodes_re_pass_immediately;
// Nodes with more than kMaxOneInNodes inputs; they are queued at the end of an iteration.
std::unordered_set<NodePtr> nodes_last;
// Nodes whose iteration has been suspended by a pass.
std::unordered_set<NodePtr> nodes_suspend;
// Nodes a pass asked to resume; only those present in nodes_suspend are re-queued.
std::unordered_set<NodePtr> nodes_resume;
};

void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes,
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) {
nodes_last.clear();

void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph,
GEPass::GraphLevelState &g_state) {
for (auto &node : graph->GetDirectNode()) {
if (node == nullptr) {
continue;
}
size_t in_nums = node->GetInNodes().size();
if (in_nums == 0) {
input_edge_nodes.push_back(node);
nodes_seen.insert(node.get());
g_state.AddNodeToQueueIfNotSeen(node);
} else if (in_nums > kMaxOneInNodes) {
nodes_last.insert(node);
g_state.nodes_last.insert(node);
}
}
}

bool IsAllInNodesAlive(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_suspend) {
return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; });
// Returns true when every node of `nodes` is a member of `nodes_set`
// (an empty `nodes` therefore yields true).
bool AllNodesIn(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_set) {
  const auto is_member = [&nodes_set](const NodePtr &candidate) { return nodes_set.count(candidate) != 0; };
  return std::all_of(nodes.begin(), nodes.end(), is_member);
}

// Returns true as soon as one node of `nodes` is found in `nodes_set`;
// returns false when none is (including when `nodes` is empty).
bool AnyNodesIn(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_set) {
  for (const NodePtr &candidate : nodes) {
    if (nodes_set.find(candidate) != nodes_set.end()) {
      return true;
    }
  }
  return false;
}

void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
DuringPassNodeSets &during_pass_node_set) {
auto &nodes_seen = during_pass_node_set.nodes_seen;
const auto &nodes_last = during_pass_node_set.nodes_last;
const auto &nodes_suspend = during_pass_node_set.nodes_suspend;
for (auto &node : nodes) {
// Decides whether `node` may be pushed to the pass queue right now.
// @param node     candidate node, may be null
// @param g_state  graph-level bookkeeping (deleted / last / seen / suspended node sets)
// @return false when the node is null, was deleted, is deferred to the end of the iteration (nodes_last),
//         still has unseen input nodes, or is suspended directly or through one of its input nodes;
//         true otherwise.
bool IsNodeReadyToQueue(const NodePtr &node, GEPass::GraphLevelState &g_state) {
  if (node == nullptr) {
    GELOGW("node is null");
    return false;
  }

  if (g_state.nodes_deleted.count(node) > 0) {
    GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
    return false;
  }

  // Nodes in nodes_last are queued separately at the end of the iteration.
  if (g_state.nodes_last.count(node) != 0) {
    return false;
  }

  if (!node->IsAllInNodesSeen(g_state.nodes_seen)) {
    return false;
  }

  // Because the output nodes of a node are added to the queue before the node is passed, if an output node
  // is suspended while the node is being passed, the subsequent logic is the same as above.
  // TODO: note that this guarantee is only best-effort. If, while a node is being passed, an earlier node `A`
  // is added to the suspend set, the passing of the direct and indirect successors of `A` is not affected
  // by that suspend.
  // In theory, if the output nodes were collected before passing the node and, after the pass, the suspended
  // and deleted ones were filtered out before queueing, this extra check would not be needed here.
  if (g_state.nodes_suspend.count(node) > 0) {
    GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
           node->GetName().c_str());
    return false;
  }
  // A suspended input node blocks this node as well; say so in the log instead of repeating the message above.
  if (AnyNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) {
    GELOGD("An input node of node %s is in the suspend-iteration nodes list, the iteration of it will be suspend.",
           node->GetName().c_str());
    return false;
  }
  return true;
}

// Adds every current output node of `node` to `out_nodes_before_pass`.
// Existing entries of the set are kept.
void CollectOutNodesBeforePass(const NodePtr &node, std::unordered_set<NodePtr> &out_nodes_before_pass) {
  // Keep the visitor alive in a local so both iterators refer to the same container.
  const auto current_outputs = node->GetOutNodes();
  out_nodes_before_pass.insert(current_outputs.begin(), current_outputs.end());
}

void AddNextIterNodes(const NodePtr &cur_node, std::unordered_set<NodePtr> &out_nodes_before_pass,
GEPass::GraphLevelState &g_state) {
for (auto &node : cur_node->GetOutNodes()) {
if (node == nullptr) {
continue;
}
if (nodes_last.count(node) != 0) {
continue;
if (out_nodes_before_pass.erase(node) == 0) {
// after pass node , new output node come up
GELOGD("New output nodes %s come up after pass %s.", node->GetName().c_str(), cur_node->GetName().c_str());
}

if (IsNodeReadyToQueue(node, g_state)) {
g_state.AddNodeToQueueIfNotSeen(node);
}
if (nodes_suspend.count(node) > 0) {
GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str());
}
// A-->B-->C
// \
// D--->E
// If B has been deleted after the pass, two cases need to be considered:
// 1. A & C & E have been added to re-pass by B. Good choice.
// 2. A & C & E have not been added to re-pass: C will not be passed because nothing triggers it,
// while E will be passed because D will trigger it.
// So here we need to add the nodes which have no input node to the queue.
for (const auto &node : out_nodes_before_pass) {
if (!node->GetInAllNodes().empty()) {
GELOGD("Node %s used to be output of node %s, but after pass it doesnt. "
"It may triggered by other node, so no need add to queue now.");
continue;
}

bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend);
bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen);
if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) {
nodes_to_pass.push_back(node);
if (IsNodeReadyToQueue(node, g_state)) {
// unlink edge may happen, add these node to queue otherwise they can not pass
GELOGI("Node %s may lost from cur node, add to queue if not seen.",
node->GetName().c_str(), cur_node->GetName().c_str());
g_state.AddNodeToQueueIfNotSeen(node);
}
}
}

void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) {
for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) {
GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str());
nodes.push_front(node);
void AddImmediateRepassNodesToQueue(NodePtr &cur_node,
const std::unordered_map<NodePtr, std::string> re_pass_imm_nodes_to_pass_names,
GEPass::GraphLevelState &g_state) {
for (const auto &node_2_pass_names : re_pass_imm_nodes_to_pass_names) {
auto repass_imm_node = node_2_pass_names.first;
if (repass_imm_node == nullptr) {
GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s",
node_2_pass_names.second.c_str(),
cur_node->GetName().c_str(), cur_node->GetType().c_str());
continue;
}
if (g_state.nodes_passed.count(repass_imm_node) > 0) {
GELOGD("The node %s specified by pass %s has been passed, it will repass immediately",
repass_imm_node->GetName().c_str(), node_2_pass_names.second.c_str());
g_state.AddNodeToQueueFront(repass_imm_node);
continue;
}
GELOGW("The node %s specified by pass %s has un-passed in_nodes, it will not repass immediately",
repass_imm_node->GetName().c_str(), node_2_pass_names.second.c_str());
}
during_pass_node_set.nodes_re_pass_immediately.clear();
}

void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) {
for (auto &node : during_pass_node_set.nodes_resume) {
const auto &it = during_pass_node_set.nodes_suspend.find(node);
if (it != during_pass_node_set.nodes_suspend.end()) {
during_pass_node_set.nodes_suspend.erase(node);
GELOGD("The node %s resumed by pass.", node->GetName().c_str());
nodes.push_back(node);
} else {
GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str());
void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) {
for (auto &node : g_state.nodes_last) {
// TODO: why would a node from nodes_last already show up in nodes_seen? Check the merge history (git blame).
if (node->IsAllInNodesSeen(g_state.nodes_seen)) {
g_state.AddNodeToQueueIfNotSeen(node);
}
}
during_pass_node_set.nodes_resume.clear();
g_state.nodes_last.clear();
}

void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name,
const std::unordered_set<NodePtr> &nodes_suspend,
const std::unordered_set<NodePtr> &nodes_resume) {
for (const auto &node : nodes_suspend) {
GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str());
during_pass_node_set.nodes_suspend.emplace(node);
}

for (const auto &node : nodes_resume) {
GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str());
during_pass_node_set.nodes_resume.emplace(node);
// Moves resumed nodes from the suspend set back to the end of the pass queue.
// @param resume_nodes_to_pass_names  resumed node -> names of the passes that resumed it (used for logging);
//                                    taken by const reference so the map is not copied on every call
// @param g_state                     graph-level state holding the suspend set, the seen set and the queue
// A node is re-queued only if it was actually suspended and either has been seen before or has all of its
// input nodes seen. A node that was suspended but does not meet that condition is still removed from the
// suspend set without being queued.
void AddResumeNodesToQueue(const std::unordered_map<NodePtr, std::string> &resume_nodes_to_pass_names,
                           GEPass::GraphLevelState &g_state) {
  // Currently we dont keep the order of suspend nodes and resume nodes, so its hard to know
  // which one comes first. Simple way : if a node both have suspend & resume state, we will resume it.
  // Better way: keep the order when suspend/resume a node, and in this func suspend/resume in order.
  for (const auto &node_2_pass_names : resume_nodes_to_pass_names) {
    const auto &node = node_2_pass_names.first;
    if (g_state.nodes_suspend.erase(node) > 0) {
      if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) {
        g_state.nodes.push_back(node);
        GELOGD("Node %s has been resumed by pass %s, add to queue.",
               node->GetName().c_str(), node_2_pass_names.second.c_str());
      }
    }
  }
}

@@ -140,54 +204,6 @@ void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass
}
}

// Runs every pass in `names_to_passes` on `node`, in order, and folds each pass's reported node sets
// (re-pass, immediate re-pass, suspend/resume, deleted) into `during_pass_node_set`.
// Returns FAILED for a null node, the failing pass's status if a pass fails, SUCCESS otherwise.
Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNodeSets &during_pass_node_set) {
if (node == nullptr) {
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
return FAILED;
}
GELOGD("Begin to run pass for node %s", node->GetName().c_str());
for (const auto &name_to_pass : names_to_passes) {
// A null pass is logged and skipped rather than treated as a failure.
if (name_to_pass.second == nullptr) {
GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s), skip it", name_to_pass.first.c_str());
continue;
}

GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str());
// init() clears the pass's per-run node sets so the getters below only report this run's results.
name_to_pass.second->init();
auto result = name_to_pass.second->Run(node);
if (result != SUCCESS) {
REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u",
name_to_pass.first.c_str(), node->GetName().c_str(), result);
GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result "
"%u, the passes will be terminated immediately.",
name_to_pass.first.c_str(), node->GetName().c_str(), result);
return result;
}

// Collect the nodes this pass wants re-passed in the next outer iteration.
const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass,
during_pass_node_set.nodes_re_pass);

// Collect the nodes this pass wants re-passed immediately.
const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately,
during_pass_node_set.nodes_re_pass_immediately);

PushToSuspendNodes(during_pass_node_set, name_to_pass.first,
name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume());

const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
// Once the current node itself is deleted, the remaining passes are not run on it.
if (nodes_deleted_by_pass.count(node) > 0) {
GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
name_to_pass.first.c_str());
break;
}
}

return SUCCESS;
}

void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) {
for (auto &name_to_pass : names_to_pass) {
name_to_pass.second->SetOption(option, "");
@@ -199,27 +215,10 @@ void ClearOption(NamesToPass names_to_pass) {
name_to_pass.second->ClearOptions();
}
}

// Tells whether `node` should be processed: it must be non-null, must not have been recorded as
// deleted and must not be in the suspend set of `during_pass_node_set`.
bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) {
  if (node == nullptr) {
    GELOGW("node is null");
    return false;
  }

  const auto &deleted_nodes = during_pass_node_set.nodes_deleted;
  if (deleted_nodes.find(node) != deleted_nodes.end()) {
    GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
    return false;
  }

  const auto &suspended_nodes = during_pass_node_set.nodes_suspend;
  const bool is_suspended = suspended_nodes.find(node) != suspended_nodes.end();
  if (is_suspended) {
    GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
           node->GetName().c_str());
  }
  return !is_suspended;
}
} // namespace

Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map) {
Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map,
bool is_repass_io_immediately) {
if (node == nullptr) {
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
@@ -235,7 +234,7 @@ Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int>
return FAILED;
}

AddRePassNodesWithInOut(node);
is_repass_io_immediately ? AddImmediateRePassNodesWithInOut(node) : AddRePassNodesWithInOut(node);

if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str());
@@ -263,6 +262,12 @@ Status GEPass::Run(const NamesToPass &names_to_passes) {
GELOGW("No passes input, the GEPass will do nothing");
return INTERNAL_ERROR;
}
for (const auto &name_to_pass : names_to_passes) {
if (name_to_pass.second == nullptr) {
GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s)", name_to_pass.first.c_str());
return INTERNAL_ERROR;
}
}

if (depth_ > kMaxRecursiveDepth) {
GELOGE(PARAM_INVALID,
@@ -275,81 +280,101 @@ Status GEPass::Run(const NamesToPass &names_to_passes) {
return RunPassesOneGraph(names_to_passes);
}

void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names_to_pass) {
for (auto &name_to_pass : names_to_pass) {
name_to_pass.second->OnStartPassGraph(graph);
}
}

// Gives every pass a chance to resume nodes that are still suspended.
// Each pass is reset with init() and asked via OnSuspendNodesLeaked(); the nodes it reports through
// GetNodesResume() are gathered (node -> comma-terminated list of pass names) and handed to
// AddResumeNodesToQueue. Returns the first non-SUCCESS status reported by a pass, SUCCESS otherwise.
Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) {
  std::unordered_map<NodePtr, std::string> resumed_by;
  for (auto &entry : names_to_passes) {
    const auto &pass_name = entry.first;
    auto *pass = entry.second;

    pass->init();
    auto status = pass->OnSuspendNodesLeaked();
    if (status != SUCCESS) {
      GELOGE(status, "Internal Error happened when pass %s handle on suspend nodes leaked.",
             pass_name.c_str());
      return status;
    }

    for (const auto &resumed : pass->GetNodesResume()) {
      auto &names = resumed_by[resumed];
      names += pass_name;
      names += ",";
    }
  }
  AddResumeNodesToQueue(resumed_by, g_state);
  return SUCCESS;
}

Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size());
std::deque<NodePtr> nodes;
DuringPassNodeSets during_pass_node_set;
GetAllNodesNoInputEdge(graph_, nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last);
GELOGD("Start points count %zu", nodes.size());
int re_pass_times = 0;
NotifyPassGraphStart(graph_, names_to_passes);
GraphLevelState g_state;
g_state.re_pass_times = 0;
GetAllNodesNoInputEdge(graph_, g_state);
GELOGD("Start points count %zu", g_state.nodes.size());

do {
for (auto &node : during_pass_node_set.nodes_re_pass) {
nodes.push_back(node);
during_pass_node_set.nodes_seen.insert(node.get());
if (!g_state.nodes_suspend.empty()) {
auto ret = HandleLeakedSuspendNodes(names_to_passes, g_state);
if (ret != SUCCESS) {
GELOGE(ret, "Failed to handle leaked suspend nodes, break base pass.");
return ret;
}
if (g_state.nodes.empty()) {
// There are suspend nodes leaked, but no pass resume it
GELOGE(INTERNAL_ERROR, "There are suspend nodes but no pass resume, which means"
"some nodes in this graph never pass.");
return INTERNAL_ERROR;
}
}
auto ret = RunPassesGraphRepass(names_to_passes, g_state);
if (ret != SUCCESS) {
return ret;
}
during_pass_node_set.nodes_re_pass.clear();
} while (!g_state.nodes_suspend.empty());

while (!nodes.empty()) {
NodePtr node = nodes.front();
nodes.pop_front();
return SUCCESS;
}

(void)during_pass_node_set.nodes_re_pass.erase(node);
if (!CheckNode(node, during_pass_node_set)) {
continue;
}
AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set);

auto ret = RunPasses(node, names_to_passes, during_pass_node_set);
if (ret != SUCCESS) {
GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u",
node->GetName().c_str(), node->GetType().c_str(), ret);
return ret;
}
Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state) {
RepassLevelState rp_state;
do {
for (auto &node : rp_state.nodes_re_pass) {
GELOGD("Add node %s to queue for re-pass.", node->GetName().c_str());
g_state.AddNodeToQueue(node);
}
rp_state.nodes_re_pass.clear();

bool has_sub_graph = false;
ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph);
if (ret != SUCCESS) {
GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str());
return ret;
}
while (!g_state.nodes.empty()) {
auto node = g_state.PopFront();

if (has_sub_graph) {
GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str());
SetFlagOption(kOptimizeAfterSubGraph, names_to_passes);
ret = RunPasses(node, names_to_passes, during_pass_node_set);
if (ret != SUCCESS) {
GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u",
node->GetName().c_str(), node->GetType().c_str(), ret);
return ret;
}

// There is only one option scene, so set and clear options around the `RunPasses` func.
// if there are more than one scene to set options, the `ClearOption` function
// should be called each time at the begin of the iteration
ClearOption(names_to_passes);
if (g_state.nodes_deleted.count(node) > 0) {
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
}
(void)rp_state.nodes_re_pass.erase(node);// todo why
g_state.nodes_seen.insert(node.get()); // todo 为什么这里seen

AddRepassNodes(during_pass_node_set, nodes);
AddResumeNodes(during_pass_node_set, nodes);
}
std::unordered_set<NodePtr> out_nodes_before_pass;
CollectOutNodesBeforePass(node, out_nodes_before_pass);

for (auto &node : during_pass_node_set.nodes_last) {
bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen);
if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) {
nodes.push_back(node);
auto ret = RunPassesNodeOnce(node, names_to_passes, g_state, rp_state);
if (ret != SUCCESS) {
GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(),
node->GetType().c_str(), ret);
return ret;
}
AddNextIterNodes(node, out_nodes_before_pass, g_state);
}
during_pass_node_set.nodes_last.clear();
} while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes);
AddLastNodesToQueue(g_state);
} while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes);

if (re_pass_times == kMaxRePassTimes) {
if (g_state.re_pass_times == kMaxRePassTimes) {
GELOGW("re_pass_times should not come to %d", kMaxRePassTimes);
}
GELOGD("All passes runs end");

return SUCCESS;
}

Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) {
auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames();
has_sub_graph = false;
@@ -371,4 +396,95 @@ Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names
}
return SUCCESS;
}

// Runs the passes on one node, then on the node's subgraphs, and, if the node has subgraphs,
// runs the passes on the node a second time with the kOptimizeAfterSubGraph option set.
// @param node             node to process
// @param names_to_passes  passes to run, in order
// @param g_state          graph-level bookkeeping updated by the passes
// @param rp_state         re-pass bookkeeping updated by the passes
// @return SUCCESS, or the first failing status from the node or subgraph runs.
Status GEPass::RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes,
                                 GraphLevelState &g_state, RepassLevelState &rp_state) {
  auto ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state);
  if (ret != SUCCESS) {
    GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(),
           node->GetType().c_str(), ret);
    return ret;
  }

  bool has_sub_graph = false;
  ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph);
  if (ret != SUCCESS) {
    GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str());
    return ret;
  }

  if (has_sub_graph) {
    GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str());
    SetFlagOption(kOptimizeAfterSubGraph, names_to_passes);
    ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state);

    // There is only one option scene, so set and clear options around the `RunPasses` func.
    // if there are more than one scene to set options, the `ClearOption` function
    // should be called each time at the begin of the iteration.
    // The options are cleared before the status check so that a failed second run does not leave
    // kOptimizeAfterSubGraph set on the pass objects.
    ClearOption(names_to_passes);

    if (ret != SUCCESS) {
      GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", node->GetName().c_str(),
             node->GetType().c_str(), ret);
      return ret;
    }
  }
  return SUCCESS;
}

// Runs the passes on `node` in order and then applies what they reported to the traversal state.
// @param node             node to process; a null node is rejected with FAILED
// @param names_to_passes  passes to run; entries are dereferenced without a null check here
// @param g_state          graph-level state (seen / passed / suspended / deleted sets and the queue)
// @param rp_state         receives the nodes that must be re-passed in the next outer iteration
// @return FAILED for a null node, the failing pass's status if a pass fails, SUCCESS otherwise.
//
// Running stops at the first pass that reports `node` itself as deleted. Only the passes that actually
// ran are consulted afterwards: a pass that was skipped has not been init()-ed for this node, so its
// node sets still describe the previously processed node and must not be applied again.
Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state,
                               RepassLevelState &rp_state) {
  if (node == nullptr) {
    REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
    GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
    return FAILED;
  }
  GELOGD("Begin to run pass for node %s", node->GetName().c_str());

  // One past the last pass that ran on this node.
  auto executed_end = names_to_passes.end();
  for (auto it = names_to_passes.begin(); it != names_to_passes.end(); ++it) {
    const auto &name_to_pass = *it;
    GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str());
    // init() clears the pass's per-run node sets before the run.
    name_to_pass.second->init();
    auto result = name_to_pass.second->Run(node);
    if (result != SUCCESS) {
      REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", name_to_pass.first.c_str(),
                        node->GetName().c_str(), result);
      GELOGE(INTERNAL_ERROR,
             "[Process][Pass] %s on node %s failed, result "
             "%u, the passes will be terminated immediately.",
             name_to_pass.first.c_str(), node->GetName().c_str(), result);
      return result;
    }
    if (name_to_pass.second->GetNodesDeleted().count(node) > 0) {
      GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
             name_to_pass.first.c_str());
      executed_end = it + 1;
      break;
    }
  }

  g_state.nodes_passed.insert(node);

  std::unordered_map<NodePtr, std::string> repass_imm_nodes_to_pass_names;
  std::unordered_map<NodePtr, std::string> resume_nodes_to_pass_names;
  // if multi pass add one node to repass immediately, here need to remove duplication
  for (auto it = names_to_passes.begin(); it != executed_end; ++it) {
    const auto &name_to_pass = *it;
    PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen, name_to_pass.second->GetNodesNeedRePass(),
                       rp_state.nodes_re_pass);
    // collect imm_node && resume_node among these passes
    for (const auto &imm_node : name_to_pass.second->GetNodesNeedRePassImmediately()) {
      repass_imm_nodes_to_pass_names[imm_node].append(name_to_pass.first + ",");
    }
    for (const auto &resume_node : name_to_pass.second->GetNodesResume()) {
      resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ",");
    }

    for (const auto &suspend_node : name_to_pass.second->GetNodesSuspend()) {
      GELOGD("The iteration suspend of node %s has been set by pass %s", suspend_node->GetName().c_str(),
             name_to_pass.first.c_str());
      g_state.nodes_suspend.insert(suspend_node);
    }
    const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
    g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
  }
  // Immediate re-pass nodes go to the front of the queue; resumed nodes leave the suspend set.
  AddImmediateRepassNodesToQueue(node, repass_imm_nodes_to_pass_names, g_state);
  AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state);
  return SUCCESS;
}
} // namespace ge

+ 81
- 16
ge/graph/passes/base_pass.h View File

@@ -22,7 +22,6 @@
#include <unordered_set>
#include <utility>
#include <vector>

#include "framework/common/ge_inner_error_codes.h"
#include "framework/common/types.h"
#include "graph/compute_graph.h"
@@ -61,23 +60,32 @@ class BaseNodePass {

const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; }

virtual Status OnSuspendNodesLeaked() { return SUCCESS; }

void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; }

void ClearOptions() { options_.clear(); }

// Resets the per-run node sets (re-pass, immediate re-pass, deleted, suspend, resume) to empty.
// Options set through SetOption() are not touched.
void init() {
  nodes_need_re_pass_.clear();
  nodes_need_re_pass_immediately_.clear();
  nodes_deleted_.clear();
  nodes_suspend_.clear();
  nodes_resume_.clear();
}

// Hook called with the graph that is about to be processed; the default implementation
// remembers the graph's name. `graph` is dereferenced without a null check.
virtual void OnStartPassGraph(const ComputeGraphPtr &graph) {
  current_graph_name_.assign(graph->GetName());
}

protected:
Status IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map);
const string &GetCurrentGraphName() const {
return current_graph_name_;
}
Status IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map, bool is_repass_io_immediately = false);

Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list<int> &io_map) {
return IsolateAndDeleteNode(node, std::vector<int>(io_map));
Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list<int> &io_map, bool is_repass_io_immediately = false) {
return IsolateAndDeleteNode(node, std::vector<int>(io_map), is_repass_io_immediately);
}

///
@@ -112,6 +120,22 @@ class BaseNodePass {
}
}

///
/// Add a node and it's input/output data nodes to be optimized immediately again.
/// @param node
///
void AddImmediateRePassNodesWithInOut(const NodePtr &node) {
  // Register order: input nodes first, then the node itself, then its output nodes.
  for (const auto &upstream : node->GetInNodes()) {
    AddImmediateRePassNode(upstream);
  }

  AddImmediateRePassNode(node);

  for (const auto &downstream : node->GetOutNodes()) {
    AddImmediateRePassNode(downstream);
  }
}

///
/// If you deleted a node from the graph, especially current node. The remain
/// iterate passes will continue process on the deleted node(if it can be
@@ -123,23 +147,15 @@ class BaseNodePass {
void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); }

///
/// If you suspend a node from the graph, especially following node. The remain
/// iterate passes will stop process on the suspend node(if it can be
/// If you postpone a node from the graph, especially following node. The remain
/// iterate passes will stop process on the postpone node(if it can be
/// reached by edge connections) till the last one. Obviously it is a waste of
/// time. You can add the suspend nodes by calling this function, to stop the
/// time. You can add the postpone nodes by calling this function, to stop the
/// next iterations.
/// @param node
///
void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); }

///
/// If you resume a node from the graph, especially following node. The remain
/// iterate passes will continue process on the resume node(if it can be
/// reached by edge connections) till the last one.
/// You can add the resume nodes by calling this function, to resume the
/// next iterations.
/// @param node
///
void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); }

bool OptionExists(NodePassOption option) { return options_.count(option) > 0; }
@@ -151,6 +167,7 @@ class BaseNodePass {
std::unordered_set<NodePtr> nodes_suspend_;
std::unordered_set<NodePtr> nodes_resume_;
std::map<NodePassOption, std::string> options_;
std::string current_graph_name_;
};

using NamesToPass = std::vector<std::pair<std::string, BaseNodePass *>>;
@@ -160,12 +177,60 @@ class GEPass {
explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {}
virtual ~GEPass() = default;
Status Run(const NamesToPass &names_to_passes);
/*
* todo
* OneGraph: nodes_deleted, nodes_seen, nodes_passed, nodes_suspended
* RePass: nodes_re_pass
* GraphOneTime: nodes_last
* NodeOneTime: nodes_re_pass_immediately, nodes_resume
*/
// Traversal state kept for one graph while passes run over it.
struct GraphLevelState {
  // Nodes reported as deleted by passes.
  std::unordered_set<NodePtr> nodes_deleted;
  // Raw pointers of nodes that have been pushed to the queue through the AddNodeToQueue* helpers.
  std::unordered_set<Node *> nodes_seen;
  // Nodes on which the passes have been run.
  std::unordered_set<NodePtr> nodes_passed;
  // Nodes whose iteration is currently suspended by a pass.
  std::unordered_set<NodePtr> nodes_suspend;
  // Nodes deferred to the end of an iteration.
  std::unordered_set<NodePtr> nodes_last;
  // Work queue of nodes waiting to be passed.
  std::deque<NodePtr> nodes;
  // Number of re-pass rounds done so far; zero-initialized so a freshly constructed state never
  // carries an indeterminate counter.
  int re_pass_times = 0;

  // Marks `node` as seen and puts it at the front of the queue, even if it was seen before.
  void AddNodeToQueueFront(NodePtr node) {
    nodes_seen.insert(node.get());
    nodes.emplace_front(std::move(node));
  }

  // Marks `node` as seen and appends it to the queue, even if it was seen before.
  void AddNodeToQueue(NodePtr node) {
    nodes_seen.insert(node.get());
    nodes.emplace_back(std::move(node));
  }

  // Appends `node` to the queue only when it has not been seen yet; marks it as seen in that case.
  void AddNodeToQueueIfNotSeen(NodePtr node) {
    if (nodes_seen.insert(node.get()).second) {
      nodes.emplace_back(std::move(node));
    }
  }

  // Removes and returns the first queued node. The queue must not be empty.
  NodePtr PopFront() {
    NodePtr node = nodes.front();
    nodes.pop_front();
    return node;
  }
};
// State kept for one re-pass run over a graph.
struct RepassLevelState {
// Nodes that passes asked to be processed again in the next round.
std::unordered_set<NodePtr> nodes_re_pass;
};
// NOTE(review): no use of this struct is visible in this part of the change; GraphLevelState
// carries its own nodes_last member. Confirm whether this type is still needed.
struct GraphOneTimeLevelState {
// Presumably the nodes deferred to the end of one graph iteration - TODO confirm.
std::unordered_set<NodePtr> nodes_last;
};

private:
GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth)
: graph_(graph), root_graph_(root_graph), depth_(depth) {}
Status RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes,
GraphLevelState &g_state, RepassLevelState &rp_state);
Status RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state);
Status RunPassesOneGraph(const NamesToPass &names_to_passes);
Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph);
Status RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state,
RepassLevelState &rp_state);
Status HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state);
ComputeGraphPtr graph_;
ComputeGraphPtr root_graph_;
int depth_;


+ 5
- 2
ge/graph/passes/infer_base_pass.cc View File

@@ -84,8 +84,11 @@ Status InferBasePass::Run(NodePtr &node) {

bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; }
// Registers every node in `changed_nodes` for immediate re-pass.
void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) {
// A passed_nodes set is needed to solve the problem that multi-input operators are re-passed in advance.
// When there is a passed_nodes set, we should call AddImmediateRePassNode for all nodes in changed_nodes.
for (const auto &node : changed_nodes) {
AddImmediateRePassNode(node);
}
}

graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) {


+ 278
- 102
ge/graph/passes/infershape_pass.cc View File

@@ -1,6 +1,6 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*+
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@@ -22,13 +22,16 @@
#include "graph/shape_refiner.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/common/omg_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"

namespace ge {
#include "external/graph/operator_factory.h"

namespace ge {
namespace {
constexpr int kSwitchExitAnchorIndex = 0;
constexpr int kSwitchPredAnchorIndex = 1;
void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
desc_str += "[";
std::vector<std::pair<int64_t, int64_t>> shape_range;
@@ -47,129 +50,302 @@ void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
desc_str += "},";
}
}
// Copies shape information and data types from src onto dst.
// The (origin) shape, (origin) data type and shape range are taken from src; the origin
// shape range of dst is set to src's shape range, and the real dim count of dst is set
// to the number of dims in src's shape.
void UpdateShapeAndDType(const GeTensorDescPtr &src, GeTensorDescPtr &dst) {
  // Shapes first, then data types.
  dst->SetOriginShape(src->GetOriginShape());
  dst->SetShape(src->GetShape());
  dst->SetDataType(src->GetDataType());
  dst->SetOriginDataType(src->GetOriginDataType());

  // The same range feeds both the current and the origin shape range of dst.
  std::vector<std::pair<int64_t, int64_t>> range_of_src;
  src->GetShapeRange(range_of_src);
  dst->SetShapeRange(range_of_src);
  dst->SetOriginShapeRange(range_of_src);

  const auto real_dim_cnt = static_cast<uint32_t>(src->GetShape().GetDims().size());
  ge::TensorUtils::SetRealDimCnt(*dst, real_dim_cnt);
}
} // namespace

std::string GetInTensorInfoWithString(const ge::NodePtr &node) {
ge::OpDescPtr op_desc = node->GetOpDesc();
std::string InferShapePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const {
std::stringstream ss;
ss << "{";
int32_t in_idx = 0;
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
if (input_desc == nullptr) {
in_idx++;
ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),";
ss << "(format:" << TypeUtils::FormatToSerialString(tensor_desc->GetFormat()) << "),";
ss << "(dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()) << "),";
ss << "(origin_shape:" << tensor_desc->GetOriginShape().ToString() << "),";
ss << "(origin_format:" << TypeUtils::FormatToSerialString(tensor_desc->GetOriginFormat()) << "),";
ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetOriginDataType()) << "),";
string range_str;
SerialShapeRange(tensor_desc, range_str);
ss << "(shape_range:" << range_str << ")";
return ss.str();
}
Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) {
if (node->GetType() != SWITCH) {
return SUCCESS;
}
auto pred_node = NodeUtils::GetInDataNodeByIndex(*node, kSwitchPredAnchorIndex);
GE_CHECK_NOTNULL(pred_node);
if (pred_node->GetType() != LOOPCOND) {
return SUCCESS;
}

for (const auto &anchor_2_node : NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kSwitchExitAnchorIndex)) {
GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", anchor_2_node.second->GetName().c_str(),
anchor_2_node.second->GetType().c_str());
auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName());
if (iter == graphs_2_suspend_nodes_.end()) {
continue;
}
if (in_idx > 0) {
ss << " ";
auto &suspend_nodes = graphs_2_suspend_nodes_[GetCurrentGraphName()];
if (suspend_nodes.nodes_set.insert(anchor_2_node.second).second) {
suspend_nodes.nodes.push(anchor_2_node.second);
AddNodeSuspend(anchor_2_node.second);
}
ss << "input_" << in_idx << " " << "tensor: [";
ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
string range_str;
SerialShapeRange(input_desc, range_str);
ss << "(shape_range:" << range_str << ")]";
in_idx++;
}
return ss.str();
return SUCCESS;
}

Status InferShapePass::Run(NodePtr &node) {
// kOptimizeAfterSubGraph exist means after subgraph
auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph));
if (ret != GRAPH_SUCCESS) {
// select INFERSHAPE failed info
auto graph = node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(graph);
auto root_graph = ge::GraphUtils::FindRootGraph(graph);
GE_CHECK_NOTNULL(root_graph);
analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(),
analyzer::INFER_SHAPE, node, "InferShapeFailed!"};
(void)Analyzer::GetInstance()->DoAnalyze(analyze_info);
(void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(),
root_graph->GetGraphID());

REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s",
node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str());
GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s",
node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str());
return GE_GRAPH_INFERSHAPE_FAILED;
}

GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node));
bool need_repass = false;
auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass);
if (has_attr) {
if (!OptionExists(kOptimizeAfterSubGraph)) {
return SUCCESS;
Status InferShapePass::Infer(NodePtr &node) {
auto ret = SuspendV1LoopExitNodes(node);
if (ret != SUCCESS) {
GELOGE(ret, "Failed to suspend exit node in v1 control flow loop.");
return ret;
}
bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
auto opdesc = node->GetOpDesc();
if (node->Verify() != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str());
GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str());
return GRAPH_FAILED;
}
Operator op = OpDescUtils::CreateOperatorFromNode(node);

if (!is_unknown_graph) {
auto inference_context = ShapeRefiner::CreateInferenceContext(node);
GE_CHECK_NOTNULL(inference_context);
vector<AscendString> marks;
inference_context->GetMarks(marks);
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), marks.size());
op.SetInferenceContext(inference_context);
}

graphStatus status = CallInferShapeFunc(node, op);
if (status != GRAPH_NODE_NEED_REPASS && status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) {
// node like netoutput return param_invalid, but valid ?
REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str());
GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str());
return GRAPH_FAILED;
}
UpdateCurNodeOutputDesc(node);
if (!is_unknown_graph) {
auto ctx_after_infer = op.GetInferenceContext();
if (ctx_after_infer != nullptr) {
vector<AscendString> marks;
ctx_after_infer->GetMarks(marks);
GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), marks.size());
if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !marks.empty()) {
GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
marks.size());
ShapeRefiner::PushToContextMap(node, ctx_after_infer);
}
}
if (need_repass) {
AddImmediateRePassNode(node);
GELOGD("Node %s need repass immediately.", node->GetName().c_str());
} else {
// clear attr on while
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
}
return (status == GRAPH_NODE_NEED_REPASS) ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS;
}

/// Checks whether two tensor descriptions agree on shape range, shape, origin shape,
/// data type and origin data type.
/// @param [in] src: first tensor description
/// @param [in] dst: second tensor description
/// @return true if every compared property is equal, false otherwise
bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) {
  // check shape range
  std::vector<std::pair<int64_t, int64_t>> src_shape_range;
  std::vector<std::pair<int64_t, int64_t>> dst_shape_range;
  src->GetShapeRange(src_shape_range);
  dst->GetShapeRange(dst_shape_range);
  if (src_shape_range.size() != dst_shape_range.size()) {
    GELOGI("Src shape range size is %zu, dst shape range size is %zu, not same.", src_shape_range.size(),
           dst_shape_range.size());
    return false;
  }
  for (size_t i = 0; i < src_shape_range.size(); ++i) {
    if (src_shape_range[i].first != dst_shape_range[i].first ||
        src_shape_range[i].second != dst_shape_range[i].second) {
      // Range bounds are int64_t (signed), so they are printed with %ld rather than %lu.
      GELOGI("Current dim %zu. Src shape range is [%ld-%ld], dst shape range is [%ld-%ld], not same.",
             i, src_shape_range[i].first, src_shape_range[i].second, dst_shape_range[i].first,
             dst_shape_range[i].second);
      return false;
    }
  }

  // check shape
  auto src_shape = src->GetShape();
  auto dst_shape = dst->GetShape();
  if (src_shape.GetDims() != dst_shape.GetDims() || src->GetOriginShape().GetDims() != dst->GetOriginShape().GetDims() ||
      src->GetDataType() != dst->GetDataType() || src->GetOriginDataType() != dst->GetOriginDataType()) {
    GELOGD(
        "Src shape is %s, origin_shape is %s, data_type is %s, origin data_type is %s; "
        "Dst shape is %s, origin_shape is %s, data_type is %s, original data_type is %s, not same.",
        src_shape.ToString().c_str(), src->GetOriginShape().ToString().c_str(),
        TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(),
        TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst_shape.ToString().c_str(),
        dst->GetOriginShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(),
        TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str());
    return false;
  }
  return true;
}

// Refreshes the "origin" fields (origin shape, real dim count, origin data type, origin shape range)
// of every output tensor description of the given node from its current values.
void InferShapePass::UpdateCurNodeOutputDesc(NodePtr &node) {
auto op_desc = node->GetOpDesc();
for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
// NOTE(review): GE_IF_BOOL_EXEC presumably runs its second argument when the first is true - confirm.
// Outputs without a tensor description are skipped.
GE_IF_BOOL_EXEC(output_tensor == nullptr, continue);
// When the current shape has no dims, the origin shape is overwritten with the current shape.
GE_IF_BOOL_EXEC(output_tensor->MutableShape().GetDims().empty(),
output_tensor->SetOriginShape(output_tensor->GetShape()));

// The real dim count is taken from the number of dims of the origin shape.
ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims()
.size()));
output_tensor->SetOriginDataType(output_tensor->GetDataType());
// set output origin shape range
std::vector<std::pair<int64_t, int64_t>> range;
// The return value of GetShapeRange is deliberately ignored; range is used as filled.
(void)output_tensor->GetShapeRange(range);
output_tensor->SetOriginShapeRange(range);
GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
}
}

// Brings dst in line with src when the two tensor descriptions differ.
// `changed` is set to true only when dst had to be rewritten; SUCCESS is returned either way.
graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  changed = !SameTensorDesc(src, dst);
  if (!changed) {
    GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update.");
    return SUCCESS;
  }

  UpdateShapeAndDType(src, dst);

  // Both descriptions are logged after the copy has been applied.
  const auto src_shape_str = src->GetShape().ToString();
  const auto src_dtype_str = TypeUtils::DataTypeToSerialString(src->GetDataType());
  const auto src_origin_dtype_str = TypeUtils::DataTypeToSerialString(src->GetOriginDataType());
  const auto dst_shape_str = dst->GetShape().ToString();
  const auto dst_dtype_str = TypeUtils::DataTypeToSerialString(dst->GetDataType());
  const auto dst_origin_dtype_str = TypeUtils::DataTypeToSerialString(dst->GetOriginDataType());
  GELOGD(
      "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s."
      "To dst Node: shape: [%s], datatype: %s, original datatype is %s.",
      src_shape_str.c_str(), src_dtype_str.c_str(), src_origin_dtype_str.c_str(), dst_shape_str.c_str(),
      dst_dtype_str.c_str(), dst_origin_dtype_str.c_str());
  return SUCCESS;
}

Status InferShapePass::RePassLoopNode(const NodePtr &node) {
const auto RePassNode = [&](const std::set<std::string> &re_pass_types) {
for (auto &n : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(n);
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str());
if (re_pass_types.count(node_type) > 0) {
AddImmediateRePassNode(n);
(void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false);
GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str());
}
graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) {
auto op_desc = node->GetOpDesc();
const auto &op_type = op_desc->GetType();
auto ret = op_desc->CallInferFunc(op);
if (ret == GRAPH_PARAM_INVALID) {
// Op ir no infer func, try to get infer func from operator factory
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType().c_str());
if (node_op.IsEmpty()) {
GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
return ret;
}
return SUCCESS;
};

const auto ExProcNode = [&](const std::set<std::string> &proc_types,
const std::function<void(InferShapePass *, NodePtr)> &proc_func,
const std::string &info) {
for (auto &n : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(n);
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str());
if (proc_types.count(node_type) > 0) {
proc_func(this, n);
GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str());

GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());
auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
node_op.BreakConnect();
if (temp_op_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr.");
GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null");
return GRAPH_FAILED;
}
if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
GELOGW("InferShapeAndType UpdateInputName failed");
for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) {
if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) {
break;
}
return GRAPH_SUCCESS;
}
}
return SUCCESS;
};

std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type),
"[Get][OriginalType] of node:%s failed.", node->GetName().c_str());
if (kNextIterationOpTypes.count(node_type) > 0) {
return RePassNode(kMergeOpTypes); // Re-Pass Merge
if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
GELOGW("InferShapeAndType UpdateOutputName failed");
}
op_desc->AddInferFunc(temp_op_desc->GetInferFunc());
ret = op_desc->CallInferFunc(op);
GELOGI("op CallInferFunc second. ret: %u", ret);
}
return ret;
}

if (kMergeOpTypes.count(node_type) > 0) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
return RePassNode(kSwitchOpTypes); // Re-Pass Switch
/// Merges the output tensor descriptions produced by several subgraphs into dst.
/// The first entry of src acts as the reference: if any entry has a different rank the
/// reference shape becomes UNKNOWN_RANK, otherwise every dim that differs between the
/// reference and another entry is set to UNKNOWN_DIM. The (possibly adjusted) reference
/// is then copied to dst. Note that the reference shape is adjusted in place on src[0].
/// @param [in] src: tensor descriptions collected from the subgraphs
/// @param [out] dst: tensor description to update
/// @return GRAPH_SUCCESS on success (dst is left untouched when src is empty),
///         GRAPH_FAILED when the entries of src do not share one data type
graphStatus InferShapePass::UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) {
  GELOGD("Enter update parent node shape for class branch op process");
  // Nothing to merge: bail out instead of calling src.at(0) on an empty vector.
  // This mirrors the empty-input handling of UpdateOutputFromSubgraphsForMultiDims.
  if (src.empty()) {
    GELOGW("No output tensor desc from subgraphs, skip update.");
    return GRAPH_SUCCESS;
  }
  // check sub_graph shape.If not same ,do unknown shape process
  auto ref_out_tensor = src.at(0);
  ge::GeShape &ref_out_tensor_shape = ref_out_tensor->MutableShape();
  for (auto &tensor : src) {
    if (ref_out_tensor->GetDataType() != tensor->GetDataType()) {
      REPORT_INNER_ERROR("E19999", "Does not support diff dtype among all ref output, shape:%s",
                         ref_out_tensor_shape.ToString().c_str());
      GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype output");
      return GRAPH_FAILED;
    }
    // Read-only view of this subgraph's shape; no copy is needed for the comparison.
    const auto &shape = tensor->MutableShape();
    if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
      // GetShapeSize() is logged with %ld, as elsewhere in this file.
      GELOGD("Shape from subgraph size: %ld, ref_out_tensor_shape size: %ld", shape.GetShapeSize(),
             ref_out_tensor_shape.GetShapeSize());
      ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
      break;
    }
    for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
      if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
        continue;
      }
      GELOGD("j: %zu ,shape from subgraph size: %ld, ref_out_tensor_shape size: %ld", j, shape.GetShapeSize(),
             ref_out_tensor_shape.GetShapeSize());
      (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
    }
  }
  UpdateShapeAndDType(ref_out_tensor, dst);
  return GRAPH_SUCCESS;
}
graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
GeTensorDescPtr &dst) {
// check sub_graph shape. Get max for update.
if (src.empty()) {
// TODO LOG
return SUCCESS;
}

if (kSwitchOpTypes.count(node_type) > 0) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit
} else {
return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit
int64_t max_size = 0;
size_t max_shape_index = 0;
auto &ref_out_tensor = src.at(0);
for (size_t j = 0; j < src.size(); ++j) {
auto &tensor = src.at(j);
if (ref_out_tensor->GetDataType() != tensor->GetDataType()) {
REPORT_INNER_ERROR("E19999", "node does not support diff dtype among all ref output");
GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype among all ref output");
return GRAPH_FAILED;
}

auto shape = tensor->MutableShape();
int64_t size = 1;
for (auto dim : shape.GetDims()) {
if (dim != 0 && INT64_MAX / dim < size) {
REPORT_INNER_ERROR("E19999", "The shape:%s size overflow", shape.ToString().c_str());
GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow");
return PARAM_INVALID;
}
size *= dim;
}
}

if (size > max_size) {
max_size = size;
max_shape_index = j;
}
}
UpdateShapeAndDType(src.at(max_shape_index), dst);
return GRAPH_SUCCESS;
}
// Leak handler for suspended nodes of the current graph.
// If the current graph has an entry in graphs_2_suspend_nodes_ and that entry still holds
// nodes, the most recently pushed one is popped and resumed. Always returns SUCCESS.
Status InferShapePass::OnSuspendNodesLeaked() {
  const auto entry = graphs_2_suspend_nodes_.find(GetCurrentGraphName());
  const bool graph_is_tracked = (entry != graphs_2_suspend_nodes_.end());
  if (!graph_is_tracked) {
    GELOGW("There is no suspend nodes on graph %s", GetCurrentGraphName().c_str());
    return SUCCESS;
  }

  auto &suspended = entry->second;
  if (suspended.nodes.empty()) {
    return SUCCESS;
  }
  // Resume one node per call: the one on top of the stack.
  AddNodeResume(suspended.PopSuspendedNode());
  return SUCCESS;
}
} // namespace ge

+ 30
- 12
ge/graph/passes/infershape_pass.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,22 +17,40 @@
#ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_
#define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_

#include "graph/passes/base_pass.h"
#include "graph/passes/infer_base_pass.h"
#include <stack>

namespace ge {
class InferShapePass : public BaseNodePass {
class InferShapePass : public InferBasePass {
public:
///
/// Entry of the InferShapePass optimizer
/// @param [in] graph: Input ComputeGraph
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status Run(ge::NodePtr &node) override;
std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override;
graphStatus Infer(NodePtr &node) override;

graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override;
graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override;
graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
GeTensorDescPtr &dst) override;

Status OnSuspendNodesLeaked() override;

private:
Status RePassLoopNode(const NodePtr &node);
graphStatus CallInferShapeFunc(NodePtr &node, Operator &op);
bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst);
void UpdateCurNodeOutputDesc(NodePtr &node);
Status SuspendV1LoopExitNodes(const NodePtr &node);
// Bookkeeping for the nodes suspended on one graph.
// `nodes` keeps them in LIFO order; `nodes_set` mirrors the stack content so that
// membership can be tested and duplicates avoided by the code that pushes.
struct SuspendNodes {
  std::stack<NodePtr> nodes;
  std::unordered_set<NodePtr> nodes_set;

  // Removes the most recently pushed node from both containers and returns it.
  // Returns nullptr when nothing is suspended, because calling top()/pop() on an
  // empty std::stack is undefined behaviour.
  NodePtr PopSuspendedNode() {
    if (nodes.empty()) {
      return nullptr;
    }
    auto top_node = nodes.top();
    nodes.pop();
    nodes_set.erase(top_node);
    return top_node;
  }
};
std::map<std::string, SuspendNodes> graphs_2_suspend_nodes_;

};
} // namespace ge
#endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_

+ 16
- 0
ge/graph/preprocess/graph_preprocess.cc View File

@@ -1999,6 +1999,22 @@ Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) {

Status GraphPrepare::InferShapeForPreprocess() {
GELOGI("Start infershape for preprocess.");
// Prepare dummy_shape for v1 control_flow op before infershape
for (const auto &node : compute_graph_->GetAllNodes()) {
string type;
GetOriginalType(node, type);
if (type == MERGE || type == REFMERGE) {
for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) {
GELOGD("Prepare for infershape: update %s input_shape as dummy.", node->GetName().c_str());
NodeUtils::UpdateInputShape(*node, i, GeShape(DUMMY_SHAPE));
}
} else if (type == WHILE) {
for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) {
GELOGD("Prepare for infershape: update %s output_shape as dummy.", node->GetName().c_str());
NodeUtils::UpdateOutputShape(*node, i, GeShape(DUMMY_SHAPE));
}
}
}
GEPass ge_passes(compute_graph_);
NamesToPass names_to_passes;
AssertPass assert_pass;


+ 1
- 1
tests/ut/ge/graph/passes/addn_pass_unittest.cc View File

@@ -72,7 +72,7 @@ TEST(UtestGraphPassesAddnPass, null_pass) {
AddNPass *addn_pass = nullptr;
NamesToPass names_to_pass;
names_to_pass.emplace_back("Test", addn_pass);
EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
EXPECT_NE(pass.Run(names_to_pass), SUCCESS);
}

TEST(UtestGraphPassesAddnPass, null_graph) {


+ 463
- 11
tests/ut/ge/graph/passes/base_pass_unittest.cc View File

@@ -17,7 +17,6 @@
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

#include "gtest/gtest.h"
@@ -26,8 +25,6 @@
#include "graph/passes/base_pass.h"
#undef protected

#include "external/graph/ge_error_codes.h"
#include "framework/common/ge_inner_error_codes.h"
#include "framework/common/types.h"
#include "graph/node.h"
#include "graph/utils/graph_utils.h"
@@ -67,6 +64,54 @@ class UtestTestPass : public BaseNodePass {
names_to_add_repass_.erase(iter);
}
}

iter = names_to_add_repass_immediate_.find(node->GetName());
if (iter != names_to_add_repass_immediate_.end()) {
auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
for (const auto &node_name : iter->second) {
for (auto &node_re_pass : all_nodes) {
if (node_re_pass->GetName() == node_name) {
AddImmediateRePassNode(node_re_pass);
break;
}
}
}
if (!dead_loop_) {
names_to_add_repass_immediate_.erase(iter);
}
}

iter = names_to_add_suspend_.find(node->GetName());
if (iter != names_to_add_suspend_.end()) {
auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
for (const auto &node_name : iter->second) {
for (auto &node_re_pass : all_nodes) {
if (node_re_pass->GetName() == node_name) {
AddNodeSuspend(node_re_pass);
break;
}
}
}
if (!dead_loop_) {
names_to_add_suspend_.erase(iter);
}
}

iter = names_to_add_resume_.find(node->GetName());
if (iter != names_to_add_resume_.end()) {
auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
for (const auto &node_name : iter->second) {
for (auto &node_re_pass : all_nodes) {
if (node_re_pass->GetName() == node_name) {
AddNodeResume(node_re_pass);
break;
}
}
}
if (!dead_loop_) {
names_to_add_resume_.erase(iter);
}
}
// simulate infershape pass
if(node->GetType() == WHILE){
bool need_repass = false;
@@ -85,6 +130,20 @@ class UtestTestPass : public BaseNodePass {
}
return SUCCESS;
}

// Test override of the leak hook: resumes every node whose name was registered through
// AddResumeNodeNameOnLeaked. Always returns SUCCESS.
Status OnSuspendNodesLeaked() override {
// resume all node remain in suspend_nodes when leaked
// The owning graph is taken from the first node this pass has visited; with no visited node there is nothing to do.
auto compute_graph = (iter_nodes_.size() > 0) ? iter_nodes_[0]->GetOwnerComputeGraph() : nullptr;
if (compute_graph == nullptr) {
return SUCCESS;
}

for (const auto &node_name : names_to_add_resume_onleaked_) {
// NOTE(review): the FindNode result is passed on unchecked - confirm AddNodeResume tolerates a missing node.
auto node_to_resume = compute_graph->FindNode(node_name);
AddNodeResume(node_to_resume);
}
return SUCCESS;
}
void clear() { iter_nodes_.clear(); }
std::vector<NodePtr> GetIterNodes() { return iter_nodes_; }

@@ -94,12 +153,31 @@ class UtestTestPass : public BaseNodePass {
void AddDelNodeName(const std::string &iter_node, const std::string &del_node) {
names_to_add_del_[iter_node].insert(del_node);
}
// While iter_node is being run, request an immediate re-pass of the node named re_pass_node.
void AddRePassImmediateNodeName(const std::string &iter_node, const std::string &re_pass_node) {
names_to_add_repass_immediate_[iter_node].insert(re_pass_node);
}

// While iter_node is being run, suspend the node named suspend_node.
void AddSuspendNodeName(const std::string &iter_node, const std::string &suspend_node) {
names_to_add_suspend_[iter_node].insert(suspend_node);
}
// While iter_node is being run, resume the node named resume_node.
void AddResumeNodeName(const std::string &iter_node, const std::string &resume_node) {
names_to_add_resume_[iter_node].insert(resume_node);
}
// Resume the node named resume_node when OnSuspendNodesLeaked is invoked.
void AddResumeNodeNameOnLeaked(const std::string &resume_node) {
names_to_add_resume_onleaked_.insert(resume_node);
}

unsigned int GetRunTimes() { return run_times_; }

private:
std::vector<NodePtr> iter_nodes_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_del_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_immediate_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_suspend_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_resume_;
std::unordered_set<std::string> names_to_add_resume_onleaked_;

bool dead_loop_;
unsigned int run_times_;
};
@@ -200,6 +278,26 @@ ComputeGraphPtr BuildGraph3() {
return builder.GetGraph();
}

/// Builds graph "g1" with one data node feeding two independent branches:
/// cast1--shape1
/// /
/// data1
/// \
/// transdata1--shape2
ComputeGraphPtr BuildGraph4() {
auto builder = ut::GraphBuilder("g1");
auto data1 = builder.AddNode("data1", DATA, 0, 1);
auto cast1 = builder.AddNode("cast1", CAST, 1, 1);
auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1);
auto transdata1 = builder.AddNode("transdata1", TRANSDATA, 1, 1);
auto shape2 = builder.AddNode("shape2", SHAPE, 1, 1);

// Output 0 of data1 fans out to both branches.
builder.AddDataEdge(data1, 0, cast1, 0);
builder.AddDataEdge(data1, 0, transdata1, 0);
builder.AddDataEdge(cast1, 0, shape1, 0);
builder.AddDataEdge(transdata1, 0, shape2, 0);
return builder.GetGraph();
}

void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::string>> &nodes_layers) {
std::unordered_set<std::string> layer_nodes;
size_t layer_index = 0;
@@ -509,15 +607,369 @@ ComputeGraphPtr BuildWhileGraph1() {
}

TEST_F(UTESTGraphPassesBasePass, while_infershape) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));

auto graph = BuildWhileGraph1();
auto ge_pass = GEPass(graph);
auto while_node = graph->FindNode("while");
EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
}

// While reshape1 runs it requests an immediate re-pass of add1, a node listed in an earlier layer.
// The run is expected to succeed, with add1 appearing again in the same layer as reshape1.
TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// repass pre_node immediately
test_pass->AddRePassImmediateNodeName("reshape1", "add1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);

EXPECT_EQ(test_pass->GetIterNodes().size(), 9);// TODO(review): confirm the expected visit count
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "add1", "sum1"});
CheckIterOrder(test_pass, layers);
}

// While reshape1 runs it requests an immediate re-pass of itself.
// reshape1 is expected to be visited twice: once in a layer of its own and once more alongside sum1.
TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// repass cur_node immediately
test_pass->AddRePassImmediateNodeName("reshape1", "reshape1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);

EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(test_pass, layers);
}

// Both reshape1 and add1 request an immediate re-pass of sum1, which is listed in a later layer.
// The expected visit count (8) and layers contain sum1 only once, i.e. no extra visit is expected.
TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// repass next_node immediately
test_pass->AddRePassImmediateNodeName("reshape1", "sum1");
// repass node after next_node immediately
test_pass->AddRePassImmediateNodeName("add1", "sum1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);

EXPECT_EQ(test_pass->GetIterNodes().size(), 8);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(test_pass, layers);
}
/**
* A->B->C
* Node B suspends its pre_node A. A is listed in an earlier layer than B, so the suspension
* does not change the normal iteration order.
* When C later resumes A, A is passed again (expected as a final extra layer).
*/
TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_C_resume_A) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// add1->reshape1->sum1
test_pass->AddSuspendNodeName("reshape1", "add1");
test_pass->AddResumeNodeName("sum1", "add1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
layers.push_back({"add1"});
CheckIterOrder(test_pass, layers);
}

/**
* A->B->C
* Node B suspends its pre_node A and resumes it within the same run of B.
* A is listed in an earlier layer than B, so the suspension does not change the normal
* iteration order. When B resumes A, A is passed again (expected in the same layer as B and C).
*/
TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_B_resume_A) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// add1->reshape1->sum1
test_pass->AddSuspendNodeName("reshape1", "add1");
test_pass->AddResumeNodeName("reshape1", "add1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1", "add1"});
CheckIterOrder(test_pass, layers);
}

/**
* A->B->C
* Node B resumes C, which was never suspended. The resume is expected to have no effect:
* C is not passed an extra time (8 visits in total, normal layer order).
*/
TEST_F(UTESTGraphPassesBasePass, B_resume_node_not_suspended) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// add1->reshape1->sum1
test_pass->AddResumeNodeName("reshape1", "sum1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
EXPECT_EQ(test_pass->GetIterNodes().size(), 8);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(test_pass, layers);
}

auto graph = BuildWhileGraph1();
auto ge_pass = GEPass(graph);
auto while_node = graph->FindNode("while");
EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
/**
* A->B->C
* Node B suspends its pre_node A and nobody ever resumes it - neither another node nor
* the OnSuspendNodesLeaked hook. A is therefore a leaked suspended node and the run is
* expected to fail with INTERNAL_ERROR.
*/
TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_nobody_resume_it_return_failed) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));
// reshape1 suspends its pre_node add1; no resume is registered anywhere
test_pass.AddSuspendNodeName("reshape1", "add1");
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
EXPECT_EQ(ge_pass.Run(names_to_pass), INTERNAL_ERROR);
}

/**
* A->B->C
* Node B suspends its pre_node A. A is listed in an earlier layer than B, so the normal
* iteration order is kept.
* A is resumed by the OnSuspendNodesLeaked hook, so A is passed again as a final layer.
*/
TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_resume_it_onleaked) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// reshape1 suspends its pre_node add1; add1 is resumed only when the leak hook fires
test_pass->AddSuspendNodeName("reshape1", "add1");
test_pass->AddResumeNodeNameOnLeaked("add1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
layers.push_back({"add1"});
CheckIterOrder(test_pass, layers);
}


/// cast1--shape1
/// /
/// data1
/// \
/// transdata1--shape2
/**
* Suspend the current node.
* cast1 suspends itself, shape2 resumes cast1.
* Expected layers: data1; cast1, transdata1; shape2; cast1, shape1
*/
TEST_F(UTESTGraphPassesBasePass, cast1_suspend_cur_node_shape2_resume_cast1) {
auto graph = BuildGraph4();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// cast1 suspends itself; it is resumed while shape2 (on the other branch) runs
test_pass->AddSuspendNodeName("cast1", "cast1");
test_pass->AddResumeNodeName("shape2", "cast1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1"});
layers.push_back({"cast1","transdata1"});
layers.push_back({"shape2"});
layers.push_back({"cast1", "shape1"});
CheckIterOrder(test_pass, layers);
}
/**
* suspend cur node
* cast1 suspend itself, then resume cast1
* iter order follows : data1; cast1,cast1,transdata1; shape2; shape1.
*/
TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_itself) {
  auto graph = BuildGraph4();
  auto ge_pass = GEPass(graph);
  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  // Stop here rather than dereference a null pointer if the fixture's first pass is not a UtestTestPass.
  ASSERT_TRUE(test_pass != nullptr);
  // cast1 suspends itself (the current node) and also resumes itself
  test_pass->AddSuspendNodeName("cast1", "cast1");
  test_pass->AddResumeNodeName("cast1", "cast1");
  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  // cast1 is visited twice, hence 6 iterated nodes.
  EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
  std::vector<std::unordered_set<std::string>> layers;
  layers.push_back({"data1"});
  layers.push_back({"cast1","transdata1","cast1","shape1", "shape2"});
  CheckIterOrder(test_pass, layers);
}
/**
 * suspend cur node
 * cast1 suspends itself, then cast1 is resumed on leaked
 * iter order follows : data1; cast1, transdata1, shape2; cast1, shape1.
 */
TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_onleaked) {
  auto graph = BuildGraph4();
  auto ge_pass = GEPass(graph);
  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  // Stop here rather than dereference a null pointer if the fixture's first pass is not a UtestTestPass.
  ASSERT_TRUE(test_pass != nullptr);
  // cast1 suspends itself (the current node); it is resumed only once it is reported as leaked
  test_pass->AddSuspendNodeName("cast1", "cast1");
  test_pass->AddResumeNodeNameOnLeaked("cast1");
  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  // cast1 is visited twice, hence 6 iterated nodes.
  EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
  std::vector<std::unordered_set<std::string>> layers;
  layers.push_back({"data1"});
  layers.push_back({"cast1","transdata1", "shape2"});
  layers.push_back({"cast1","shape1"});
  CheckIterOrder(test_pass, layers);
}
/**
* suspend next node
* data1 suspend cast1, then resume cast1 on leaked
* iter order follows : data1; transdata1, shape2; cast1, shape1.
*/
TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_resume_cast1_onleaked) {
  auto graph = BuildGraph4();
  auto ge_pass = GEPass(graph);
  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  // Stop here rather than dereference a null pointer if the fixture's first pass is not a UtestTestPass.
  ASSERT_TRUE(test_pass != nullptr);
  // data1 suspends its next node cast1; cast1 is resumed only once it is reported as leaked
  test_pass->AddSuspendNodeName("data1", "cast1");
  test_pass->AddResumeNodeNameOnLeaked("cast1");
  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  // Every node is visited exactly once, hence 5 iterated nodes.
  EXPECT_EQ(test_pass->GetIterNodes().size(), 5);
  std::vector<std::unordered_set<std::string>> layers;
  layers.push_back({"data1"});
  layers.push_back({"transdata1", "shape2"});
  layers.push_back({"cast1","shape1"});
  CheckIterOrder(test_pass, layers);
}

/**
* suspend next node
* data1 suspend cast1, nobody resume it
* iter order follows : data1; transdata1, shape2;
* run ret is failed ,because node leaked
*/
TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_nobody_resume) {
  auto graph = BuildGraph4();
  auto ge_pass = GEPass(graph);
  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  // Stop here rather than dereference a null pointer if the fixture's first pass is not a UtestTestPass.
  ASSERT_TRUE(test_pass != nullptr);
  // data1 suspends its next node cast1 and nothing resumes it
  test_pass->AddSuspendNodeName("data1", "cast1");
  // The suspended node is never resumed, so the run is expected to fail.
  EXPECT_EQ(ge_pass.Run(names_to_pass_), INTERNAL_ERROR);
  // Three nodes are expected to have been visited before the failure is reported.
  EXPECT_EQ(test_pass->GetIterNodes().size(), 3);
}


TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) {
  // Pass under test: while visiting "reshape1" it asks for its predecessor "add1" to be re-passed immediately.
  UtestTestPass repass_pass{};
  NamesToPass pass_list;
  pass_list.push_back(std::make_pair("test", &repass_pass));
  repass_pass.AddRePassImmediateNodeName("reshape1", "add1");

  auto compute_graph = BuildGraph2();
  GEPass pass_runner(compute_graph);
  EXPECT_EQ(pass_runner.Run(pass_list), SUCCESS);
  EXPECT_EQ(repass_pass.GetIterNodes().size(), 9);  // todo

  // "add1" shows up a second time, in the same layer as the node that requested the re-pass.
  std::vector<std::unordered_set<std::string>> expected_layers = {
      {"data1", "const1", "const2"},
      {"shape1"},
      {"add1", "addn1"},
      {"reshape1", "add1", "sum1"},
  };
  CheckIterOrder(&repass_pass, expected_layers);
}
/// sum1
/// / \.
/// / \.
/// / \.
/// reshape1 addn1
/// | c |
/// add1 <--- shape1
/// / \ |
/// | | |
/// data1 const1 const2
TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) {
  // Pass under test: "reshape1" asks for itself to be re-passed immediately.
  UtestTestPass repass_pass{};
  NamesToPass pass_list;
  pass_list.push_back(std::make_pair("test", &repass_pass));
  repass_pass.AddRePassImmediateNodeName("reshape1", "reshape1");

  auto compute_graph = BuildGraph2();
  GEPass pass_runner(compute_graph);
  EXPECT_EQ(pass_runner.Run(pass_list), SUCCESS);
  EXPECT_EQ(repass_pass.GetIterNodes().size(), 9);  // todo

  // "reshape1" is expected twice: once on its own and once more alongside "sum1".
  std::vector<std::unordered_set<std::string>> expected_layers = {
      {"data1", "const1", "const2"},
      {"shape1"},
      {"add1", "addn1"},
      {"reshape1"},
      {"reshape1", "sum1"},
  };
  CheckIterOrder(&repass_pass, expected_layers);
}

TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) {
  UtestTestPass repass_pass{};
  NamesToPass pass_list;
  pass_list.push_back(std::make_pair("test", &repass_pass));

  // "reshape1" requests an immediate re-pass of its next node "sum1" ...
  repass_pass.AddRePassImmediateNodeName("reshape1", "sum1");
  // ... and "add1" requests the same for "sum1", the node after its next node.
  repass_pass.AddRePassImmediateNodeName("add1", "sum1");

  auto compute_graph = BuildGraph2();
  GEPass pass_runner(compute_graph);
  EXPECT_EQ(pass_runner.Run(pass_list), SUCCESS);
  EXPECT_EQ(repass_pass.GetIterNodes().size(), 8);  // todo

  // The expected order has no repeated node: 8 iterations in the usual layer order.
  std::vector<std::unordered_set<std::string>> expected_layers = {
      {"data1", "const1", "const2"},
      {"shape1"},
      {"add1", "addn1"},
      {"reshape1", "sum1"},
  };
  CheckIterOrder(&repass_pass, expected_layers);
}
/*
TEST_F(UTESTGraphPassesBasePass, suspend_pre_node) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));

// repass next_node immediately
test_pass.AddRePassNodeName("reshape1", "sum1");
// repass node after next_node immediately
test_pass.AddRePassNodeName("add1", "sum1");

auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(&test_pass, layers);
}*/
} // namespace ge

+ 104
- 6
tests/ut/ge/graph/passes/infershape_pass_unittest.cc View File

@@ -29,13 +29,77 @@
using namespace std;
using namespace testing;
namespace ge {
namespace {
// No-op infer stub: ignores the operator it is given and always reports success.
const auto stub_func = [](Operator &) { return GRAPH_SUCCESS; };
// infer from input to output stub infer_func (input size == output size)
const auto stub_mapping_func = [](Operator &op) {
size_t in_num = op.GetInputsSize();
for (size_t i = 0; i < in_num; ++i) {
auto in_desc = op.GetInputDesc(i);
auto out_desc = op.GetOutputDesc(i);
out_desc.SetShape(in_desc.GetShape());
out_desc.SetDataType(in_desc.GetDataType());
op.UpdateOutputDesc(out_desc.GetName(), out_desc);
}
return GRAPH_SUCCESS;
};
// merge infer_func

// while infer_func
const auto while_infer_func = [](Operator &op) {
size_t in_num = op.GetInputsSize();
size_t out_num = op.GetOutputsSize();
if (in_num != out_num) {
return GRAPH_FAILED;
}
bool need_infer_again = false;
for (size_t i = 0; i < in_num; ++i) {
auto in_desc = op.GetDynamicInputDesc("input", i);
auto out_desc = op.GetDynamicOutputDesc("output", i);
auto data_shape = in_desc.GetShape();
auto out_shape = out_desc.GetShape();
if(out_shape.GetDims() == DUMMY_SHAPE){
return GRAPH_SUCCESS;
}
// check datatype between output and input
if (in_desc.GetDataType() != out_desc.GetDataType()) {
return GRAPH_FAILED;
}

if (data_shape.GetDims() != out_shape.GetDims()) {
need_infer_again = true;
if (data_shape.GetDimNum() != out_shape.GetDimNum()) {
in_desc.SetUnknownDimNumShape();
} else {
size_t data_dim_num = data_shape.GetDimNum();
std::vector<std::pair<int64_t, int64_t>> data_shape_range = {data_dim_num, std::make_pair(1, UNKNOWN_DIM)};
for (size_t j = 0; j < data_dim_num; ++j) {
if (data_shape.GetDim(j) != out_shape.GetDim(j)) {
data_shape.SetDim(j, UNKNOWN_DIM);
}
if (data_shape.GetDim(j) != UNKNOWN_DIM) {
data_shape_range[j] = std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j));
}
}
in_desc.SetShape(data_shape);
in_desc.SetShapeRange(data_shape_range);
}
op.UpdateDynamicOutputDesc("output", i, in_desc);
op.UpdateDynamicInputDesc("input", i, in_desc);
}
}
return need_infer_again ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS;
};
}
class UtestGraphInfershapePass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};

static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num,
std::function<graphStatus(Operator &)> infer_func = stub_func) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
op_desc->SetStreamId(0);
static int32_t index = 0;
@@ -61,14 +125,11 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string
op_desc->SetWorkspaceBytes({});
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; };
op_desc->AddInferFunc(stub_func);
op_desc->AddInferFormatFunc(stub_func);
op_desc->AddVerifierFunc(stub_func);

op_desc->AddInferFunc(infer_func);
return graph.AddNode(op_desc);
}

/*
TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
string type = "AddN";
@@ -82,6 +143,7 @@ TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
InferShapePass infershape_pass;
EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED);
}
*/

TEST_F(UtestGraphInfershapePass, delete_need_infer_again) {
auto graph = std::make_shared<ComputeGraph>("test");
@@ -94,7 +156,43 @@ TEST_F(UtestGraphInfershapePass, delete_need_infer_again) {
infershape_pass.options_[kOptimizeAfterSubGraph] = "yes";
EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS);
}
TEST_F(UtestGraphInfershapePass, infer_from_pre_to_next) {
  /*
   * data -> cast1 -> transdata1
   * Running the pass on cast1 should push cast1's shape and data type onto transdata1's input.
   */
  auto graph = std::make_shared<ComputeGraph>("test_infer_shape");
  auto data1 = CreateNode(*graph, "dataq", DATA, 0, 1);
  auto cast1 = CreateNode(*graph, "cast1", CAST, 1, 1, stub_mapping_func);
  auto cast_in_desc = cast1->GetOpDesc()->MutableInputDesc(0);
  cast_in_desc->SetShape(GeShape({1,2,3}));
  cast_in_desc->SetDataType(DT_INT32);
  auto transdata1 = CreateNode(*graph, "transdata1", TRANSDATA, 1, 1, stub_mapping_func);
  GraphUtils::AddEdge(data1->GetOutDataAnchor(0), cast1->GetInDataAnchor(0));
  GraphUtils::AddEdge(cast1->GetOutDataAnchor(0), transdata1->GetInDataAnchor(0));

  // check before infer cast1
  auto cast_before = graph->FindNode("cast1");
  ASSERT_TRUE(cast_before != nullptr);
  vector<int64_t> expect_cast1_shape_dim = {1,2,3};
  auto real_cast1_before_shape_dim = cast_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims();
  auto transdata1_before = graph->FindNode("transdata1");
  ASSERT_TRUE(transdata1_before != nullptr);
  vector<int64_t> expect_transdata1_shape_dim = {};
  auto real_transdata1_before_shape_dim = transdata1_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims();
  EXPECT_EQ(real_cast1_before_shape_dim, expect_cast1_shape_dim);
  EXPECT_EQ(real_transdata1_before_shape_dim, expect_transdata1_shape_dim);
  // run infershape pass; a failed run would make the checks below meaningless, so verify the status
  InferShapePass infer_shape_pass;
  EXPECT_EQ(infer_shape_pass.Run(cast_before), SUCCESS);
  // check cast1 add transdata1 to repass_immediately
  EXPECT_TRUE(!infer_shape_pass.GetNodesNeedRePassImmediately().empty());
  // check transdata input_shape & datatype after infer (node looked up again after the pass ran)
  auto transdata1_after = graph->FindNode("transdata1");
  ASSERT_TRUE(transdata1_after != nullptr);
  auto transdata1_opdesc = transdata1_after->GetOpDesc();
  auto real_transdata1_after_shape_dim = transdata1_opdesc->GetInputDesc(0).GetShape().GetDims();
  EXPECT_EQ(real_transdata1_after_shape_dim, expect_cast1_shape_dim);
  auto transdata1_datatype_after = transdata1_opdesc->GetInputDesc(0).GetDataType();
  EXPECT_EQ(transdata1_datatype_after, DT_INT32);
}
TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) {
/*******************************************************************************
* Exit Identify


Loading…
Cancel
Save