Browse Source

modified: ge/graph/passes/base_pass.cc

modified:   ge/graph/passes/base_pass.h
	modified:   ge/graph/passes/folding_pass.cc
	modified:   ge/graph/passes/infershape_pass.cc
	modified:   ge/graph/passes/infershape_pass.h
	modified:   ge/graph/passes/switch_dead_branch_elimination.cc
	modified:   ge/graph/preprocess/graph_preprocess.cc

	modified:   ge/graph/passes/base_pass.cc
	modified:   ge/graph/passes/base_pass.h
	modified:   ge/graph/passes/folding_pass.cc
	modified:   ge/graph/passes/infer_base_pass.h
	modified:   ge/graph/passes/infer_value_range_pass.cc
	modified:   ge/graph/passes/infer_value_range_pass.h
	modified:   ge/graph/passes/infershape_pass.cc
	modified:   ge/graph/passes/infershape_pass.h
	modified:   ge/graph/passes/switch_dead_branch_elimination.cc
	modified:   ge/graph/preprocess/graph_preprocess.cc

	modified:   ge/graph/passes/base_pass.cc
	modified:   ge/graph/passes/base_pass.h
	modified:   ge/graph/passes/folding_pass.cc
	modified:   ge/graph/passes/infer_base_pass.h
	modified:   ge/graph/passes/infer_value_range_pass.cc
	modified:   ge/graph/passes/infer_value_range_pass.h
	modified:   ge/graph/passes/infershape_pass.cc
	modified:   ge/graph/passes/infershape_pass.h
	modified:   ge/graph/passes/merge_pass.cc
	modified:   ge/graph/passes/switch_dead_branch_elimination.cc
	modified:   ge/graph/preprocess/graph_preprocess.cc

	modified:   ge/graph/passes/base_pass.cc
	modified:   ge/graph/passes/base_pass.h
	modified:   ge/graph/passes/folding_pass.cc
	modified:   ge/graph/passes/infer_base_pass.h
	modified:   ge/graph/passes/infer_value_range_pass.cc
	modified:   ge/graph/passes/infer_value_range_pass.h
	modified:   ge/graph/passes/infershape_pass.cc
	modified:   ge/graph/passes/infershape_pass.h
	modified:   ge/graph/passes/merge_pass.cc
	modified:   ge/graph/passes/switch_dead_branch_elimination.cc
	modified:   ge/graph/preprocess/graph_preprocess.cc

	modified:   ge/graph/passes/base_pass.cc
	modified:   ge/graph/passes/base_pass.h
	modified:   ge/graph/passes/folding_pass.cc
	modified:   ge/graph/passes/infer_base_pass.h
        modified:   ge/graph/passes/infer_value_range_pass.cc
	modified:   ge/graph/passes/infer_value_range_pass.h
	modified:   ge/graph/passes/infershape_pass.cc
	modified:   ge/graph/passes/infershape_pass.h
	modified:   ge/graph/passes/merge_pass.cc
	modified:   ge/graph/passes/switch_dead_branch_elimination.cc
	modified:   ge/graph/preprocess/graph_preprocess.cc
pull/1907/head
zhaoxinxin 4 years ago
parent
commit
00c4b026bd
11 changed files with 624 additions and 301 deletions
  1. +231
    -162
      ge/graph/passes/base_pass.cc
  2. +81
    -16
      ge/graph/passes/base_pass.h
  3. +1
    -1
      ge/graph/passes/folding_pass.cc
  4. +1
    -1
      ge/graph/passes/infer_base_pass.h
  5. +1
    -1
      ge/graph/passes/infer_value_range_pass.cc
  6. +1
    -1
      ge/graph/passes/infer_value_range_pass.h
  7. +258
    -104
      ge/graph/passes/infershape_pass.cc
  8. +29
    -12
      ge/graph/passes/infershape_pass.h
  9. +3
    -2
      ge/graph/passes/merge_pass.cc
  10. +2
    -1
      ge/graph/passes/switch_dead_branch_elimination.cc
  11. +16
    -0
      ge/graph/preprocess/graph_preprocess.cc

+ 231
- 162
ge/graph/passes/base_pass.cc View File

@@ -19,9 +19,7 @@
#include <queue> #include <queue>
#include <unordered_set> #include <unordered_set>


#include "framework/common/debug/log.h"
#include "framework/common/debug/ge_log.h"
#include "graph/compute_graph.h"
#include "common/debug/log.h"
#include "graph/utils/graph_utils.h" #include "graph/utils/graph_utils.h"


namespace ge { namespace ge {
@@ -30,47 +28,35 @@ constexpr int kMaxRePassTimes = 10000;
constexpr size_t kMaxOneInNodes = 1000; constexpr size_t kMaxOneInNodes = 1000;
// Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later // Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later
constexpr int kMaxRecursiveDepth = 20; constexpr int kMaxRecursiveDepth = 20;
struct DuringPassNodeSets {
std::unordered_set<Node *> nodes_seen;
std::unordered_set<NodePtr> nodes_deleted;
std::unordered_set<NodePtr> nodes_re_pass;
std::unordered_set<NodePtr> nodes_re_pass_immediately;
std::unordered_set<NodePtr> nodes_last;
std::unordered_set<NodePtr> nodes_suspend;
std::unordered_set<NodePtr> nodes_resume;
};

void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes,
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) {
nodes_last.clear();

void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph,
GEPass::GraphLevelState &g_state) {
for (auto &node : graph->GetDirectNode()) { for (auto &node : graph->GetDirectNode()) {
if (node == nullptr) { if (node == nullptr) {
continue; continue;
} }
size_t in_nums = node->GetInNodes().size(); size_t in_nums = node->GetInNodes().size();
if (in_nums == 0) { if (in_nums == 0) {
input_edge_nodes.push_back(node);
nodes_seen.insert(node.get());
g_state.AddNodeToQueueIfNotSeen(node);
} else if (in_nums > kMaxOneInNodes) { } else if (in_nums > kMaxOneInNodes) {
nodes_last.insert(node);
g_state.nodes_last.insert(node);
} }
} }
} }


bool IsAllInNodesAlive(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_suspend) {
return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; });
bool AllNodesIn(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_set) {
return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) {
return nodes_set.count(n) > 0;
});
} }


void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
DuringPassNodeSets &during_pass_node_set) {
auto &nodes_seen = during_pass_node_set.nodes_seen;
const auto &nodes_last = during_pass_node_set.nodes_last;
const auto &nodes_suspend = during_pass_node_set.nodes_suspend;
for (auto &node : nodes) {
void AddNextIterNodes(const NodePtr &cur_node, GEPass::GraphLevelState &g_state) {
const auto &nodes_suspend = g_state.nodes_suspend;
for (auto &node : cur_node->GetOutNodes()) {
if (node == nullptr) { if (node == nullptr) {
continue; continue;
} }
if (nodes_last.count(node) != 0) {
if (g_state.nodes_last.count(node) != 0) {
continue; continue;
} }
if (nodes_suspend.count(node) > 0) { if (nodes_suspend.count(node) > 0) {
@@ -78,47 +64,64 @@ void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &n
continue; continue;
} }


bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend);
bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen);
if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) {
nodes_to_pass.push_back(node);
if (node->IsAllInNodesSeen(g_state.nodes_seen) && AllNodesIn(node->GetInAllNodes(), nodes_suspend)) {
g_state.AddNodeToQueueIfNotSeen(node);
} }
} }
} }


void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) {
for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) {
GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str());
nodes.push_front(node);
void AddImmediateRepassNodesToQueue(NodePtr &cur_node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
const std::unordered_set<NodePtr> &nodes_im_re_pass,
GEPass::GraphLevelState &g_state) {
for (const auto &node : nodes_im_re_pass) {
if (node == nullptr) {
GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s", name_to_pass.first.c_str(),
cur_node->GetName().c_str(), cur_node->GetType().c_str());
continue;
}
if (g_state.nodes_passed.count(node) > 0) {
g_state.AddNodeToQueueFront(node);
continue;
}
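// e.g. constant folding adds a new const node that needs to be re-passed immediately.
// NOTE(review): AllNodesIn() returns true when NONE of the given nodes is in the set, so the check below
// holds only if no in-node has been passed yet; the warning further down suggests the intent was
// "every in-node has already been passed" -- confirm.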
if (AllNodesIn(node->GetInAllNodes(), g_state.nodes_passed)) {
g_state.AddNodeToQueueFront(node);
continue;
}
GELOGW("The node %s specified by pass %s has un-passed in_nodes, it will not repass immediately",
node->GetName().c_str(), name_to_pass.first.c_str());
} }
during_pass_node_set.nodes_re_pass_immediately.clear();
} }


void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) {
for (auto &node : during_pass_node_set.nodes_resume) {
const auto &it = during_pass_node_set.nodes_suspend.find(node);
if (it != during_pass_node_set.nodes_suspend.end()) {
during_pass_node_set.nodes_suspend.erase(node);
GELOGD("The node %s resumed by pass.", node->GetName().c_str());
nodes.push_back(node);
} else {
GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str());
void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) {
for (auto &node : g_state.nodes_last) {
// TODO: why would a node from nodes_last already show up in nodes_seen? Check git blame / the merge history.
if (node->IsAllInNodesSeen(g_state.nodes_seen)) {
g_state.AddNodeToQueueIfNotSeen(node);
} }
} }
during_pass_node_set.nodes_resume.clear();
g_state.nodes_last.clear();
} }


void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name,
const std::unordered_set<NodePtr> &nodes_suspend,
const std::unordered_set<NodePtr> &nodes_resume) {
void SuspendAndResume(const std::string &pass_name,
const std::unordered_set<NodePtr> &nodes_suspend,
const std::unordered_set<NodePtr> &nodes_resume,
GEPass::GraphLevelState &g_state) {
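// TODO: the order of suspend and resume calls inside a NodePass is currently not recorded, so it is
// impossible to tell whether the NodePass suspended first or resumed first.
// The simple handling here is therefore: if both suspend and resume were triggered during a NodePass,
// the framework lets resume win.
// A better approach would be to record the order when the NodePass calls suspend/resume and replay the
// calls in that order in this function.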
for (const auto &node : nodes_suspend) { for (const auto &node : nodes_suspend) {
GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str()); GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str());
during_pass_node_set.nodes_suspend.emplace(node);
g_state.nodes_suspend.insert(node);
} }


for (const auto &node : nodes_resume) { for (const auto &node : nodes_resume) {
GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str());
during_pass_node_set.nodes_resume.emplace(node);
if (g_state.nodes_suspend.erase(node) > 0) {
if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) {
g_state.nodes.push_back(node);
GELOGD("Node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str());
}
}
} }
} }


@@ -140,54 +143,6 @@ void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass
} }
} }


Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNodeSets &during_pass_node_set) {
if (node == nullptr) {
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
return FAILED;
}
GELOGD("Begin to run pass for node %s", node->GetName().c_str());
for (const auto &name_to_pass : names_to_passes) {
if (name_to_pass.second == nullptr) {
GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s), skip it", name_to_pass.first.c_str());
continue;
}

GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str());
name_to_pass.second->init();
auto result = name_to_pass.second->Run(node);
if (result != SUCCESS) {
REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u",
name_to_pass.first.c_str(), node->GetName().c_str(), result);
GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result "
"%u, the passes will be terminated immediately.",
name_to_pass.first.c_str(), node->GetName().c_str(), result);
return result;
}

const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass,
during_pass_node_set.nodes_re_pass);

const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately,
during_pass_node_set.nodes_re_pass_immediately);

PushToSuspendNodes(during_pass_node_set, name_to_pass.first,
name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume());

const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
if (nodes_deleted_by_pass.count(node) > 0) {
GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
name_to_pass.first.c_str());
break;
}
}

return SUCCESS;
}

void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) { void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) {
for (auto &name_to_pass : names_to_pass) { for (auto &name_to_pass : names_to_pass) {
name_to_pass.second->SetOption(option, ""); name_to_pass.second->SetOption(option, "");
@@ -200,26 +155,39 @@ void ClearOption(NamesToPass names_to_pass) {
} }
} }


bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) {
bool ShouldNodePassActually(const NodePtr &node, const GEPass::GraphLevelState &g_state) {
if (node == nullptr) { if (node == nullptr) {
GELOGW("node is null"); GELOGW("node is null");
return false; return false;
} }
if (during_pass_node_set.nodes_deleted.count(node) > 0) {
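// The output nodes of a node are added to the queue before the node itself is passed. If passing the node
// deletes one of its output nodes, an already-deleted node ends up in the queue and gets popped.
// So check here: if the node has already been deleted, skip passing it.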
if (g_state.nodes_deleted.count(node) > 0) {
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
return false; return false;
} }
if (during_pass_node_set.nodes_suspend.count(node) > 0) {

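// The output nodes of a node are added to the queue before the node itself is passed. If passing the node
// suspends one of its output nodes, the reasoning is the same as for deleted nodes above.
// TODO: note that this guarantee is only "best effort": if passing the node adds a node `A` that precedes it
// to the suspend set, the passes of `A`'s direct and indirect successors are not affected by the suspend.
// In theory, if the output nodes were collected before passing the node, and the suspended/deleted ones were
// filtered out after the pass and before being queued, this extra check would not be needed here.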
if (g_state.nodes_suspend.count(node) > 0) {
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
node->GetName().c_str());
return false;
}
if (!AllNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) {
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
node->GetName().c_str()); node->GetName().c_str());
return false; return false;
} }

return true; return true;
} }
} // namespace } // namespace


Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map) {
Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map,
bool is_repass_io_immediately) {
if (node == nullptr) { if (node == nullptr) {
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
@@ -235,7 +203,7 @@ Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int>
return FAILED; return FAILED;
} }


AddRePassNodesWithInOut(node);
is_repass_io_immediately ? AddImmediateRePassNodesWithInOut(node) : AddRePassNodesWithInOut(node);


if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) { if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str()); REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str());
@@ -263,6 +231,12 @@ Status GEPass::Run(const NamesToPass &names_to_passes) {
GELOGW("No passes input, the GEPass will do nothing"); GELOGW("No passes input, the GEPass will do nothing");
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
for (const auto &name_to_pass : names_to_passes) {
if (name_to_pass.second == nullptr) {
GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s)", name_to_pass.first.c_str());
return INTERNAL_ERROR;
}
}


if (depth_ > kMaxRecursiveDepth) { if (depth_ > kMaxRecursiveDepth) {
GELOGE(PARAM_INVALID, GELOGE(PARAM_INVALID,
@@ -275,81 +249,93 @@ Status GEPass::Run(const NamesToPass &names_to_passes) {
return RunPassesOneGraph(names_to_passes); return RunPassesOneGraph(names_to_passes);
} }


Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size());
std::deque<NodePtr> nodes;
DuringPassNodeSets during_pass_node_set;
GetAllNodesNoInputEdge(graph_, nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last);
GELOGD("Start points count %zu", nodes.size());
int re_pass_times = 0;
void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names_to_pass) {
for (auto &name_to_pass : names_to_pass) {
name_to_pass.second->OnStartPassGraph(graph);
}
}


do {
for (auto &node : during_pass_node_set.nodes_re_pass) {
nodes.push_back(node);
during_pass_node_set.nodes_seen.insert(node.get());
Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) {
for (auto &name_to_pass : names_to_passes) {
name_to_pass.second->init();
auto ret = name_to_pass.second->OnSuspendNodesLeaked();
if (ret != SUCCESS) {
// todo error
return ret;
} }
during_pass_node_set.nodes_re_pass.clear();

while (!nodes.empty()) {
NodePtr node = nodes.front();
nodes.pop_front();
SuspendAndResume(name_to_pass.first,
name_to_pass.second->GetNodesSuspend(),
name_to_pass.second->GetNodesResume(),
g_state);
}
return SUCCESS;
}


(void)during_pass_node_set.nodes_re_pass.erase(node);
if (!CheckNode(node, during_pass_node_set)) {
continue;
}
AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set);
Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size());
NotifyPassGraphStart(graph_, names_to_passes);
GraphLevelState g_state;
g_state.re_pass_times = 0;
GetAllNodesNoInputEdge(graph_, g_state);
GELOGD("Start points count %zu", g_state.nodes.size());


auto ret = RunPasses(node, names_to_passes, during_pass_node_set);
do {
if (!g_state.nodes_suspend.empty()) {
auto ret = HandleLeakedSuspendNodes(names_to_passes, g_state);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u",
node->GetName().c_str(), node->GetType().c_str(), ret);
// todo log
return ret; return ret;
} }

bool has_sub_graph = false;
ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph);
if (ret != SUCCESS) {
GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str());
return ret;
if (g_state.nodes.empty()) {
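// TODO: report an error here. In the leaked-suspend case no subclass resumed anything further, so the
// suspended nodes may be leaked for good and this must be reported.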
return INTERNAL_ERROR;
} }
}
auto ret = RunPassesGraphRepass(names_to_passes, g_state);
if (ret != SUCCESS) {
return ret;
}
} while (!g_state.nodes_suspend.empty());

return SUCCESS;
}


if (has_sub_graph) {
GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str());
SetFlagOption(kOptimizeAfterSubGraph, names_to_passes);
ret = RunPasses(node, names_to_passes, during_pass_node_set);
if (ret != SUCCESS) {
GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u",
node->GetName().c_str(), node->GetType().c_str(), ret);
return ret;
}

// There is only one option scene, so set and clear options around the `RunPasses` func.
// if there are more than one scene to set options, the `ClearOption` function
// should be called each time at the begin of the iteration
ClearOption(names_to_passes);
}


AddRepassNodes(during_pass_node_set, nodes);
AddResumeNodes(during_pass_node_set, nodes);
Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state) {
RepassLevelState rp_state;
do {
for (auto &node : rp_state.nodes_re_pass) {
g_state.AddNodeToQueue(node);
} }
rp_state.nodes_re_pass.clear();

while (!g_state.nodes.empty()) {
auto node = g_state.PopFront();


for (auto &node : during_pass_node_set.nodes_last) {
bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen);
if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) {
nodes.push_back(node);
(void)rp_state.nodes_re_pass.erase(node); // todo 回忆一下为什么
if (!ShouldNodePassActually(node, g_state)) {
continue;
}
g_state.nodes_seen.insert(node.get()); // todo 为什么这里seen
AddNextIterNodes(node, g_state);

auto ret = RunPassesNodeOnce(node, names_to_passes, g_state, rp_state);
if (ret != SUCCESS) {
GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(),
node->GetType().c_str(), ret);
return ret;
} }
} }
during_pass_node_set.nodes_last.clear();
} while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes);
AddLastNodesToQueue(g_state);
} while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes);


if (re_pass_times == kMaxRePassTimes) {
if (g_state.re_pass_times == kMaxRePassTimes) {
GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); GELOGW("re_pass_times should not come to %d", kMaxRePassTimes);
} }
GELOGD("All passes runs end"); GELOGD("All passes runs end");

return SUCCESS; return SUCCESS;
} }

Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) {
auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames();
has_sub_graph = false; has_sub_graph = false;
@@ -371,4 +357,87 @@ Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names
} }
return SUCCESS; return SUCCESS;
} }

///
/// Run all passes on `node`, then run them on the node's subgraphs; if the node has subgraphs, run all
/// passes on `node` a second time with the kOptimizeAfterSubGraph option set.
/// @param node the node to process
/// @param names_to_passes the (name, pass) pairs to run
/// @param g_state graph-level state, forwarded to RunPassesOnNode
/// @param rp_state re-pass state, forwarded to RunPassesOnNode
/// @return SUCCESS, or the first non-SUCCESS status returned by RunPassesOnNode / RunPassesOnSubGraph
///
Status GEPass::RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes,
                                 GraphLevelState &g_state, RepassLevelState &rp_state) {
  auto ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state);
  if (ret != SUCCESS) {
    GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(),
           node->GetType().c_str(), ret);
    return ret;
  }

  bool has_sub_graph = false;
  ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph);
  if (ret != SUCCESS) {
    GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str());
    return ret;
  }

  if (has_sub_graph) {
    GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str());
    SetFlagOption(kOptimizeAfterSubGraph, names_to_passes);
    ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state);

    // There is only one option scene, so set and clear options around the `RunPassesOnNode` call.
    // If there are more than one scene to set options, the `ClearOption` function
    // should be called each time at the begin of the iteration.
    // The options are cleared before `ret` is checked so that a failing second run does not leave
    // kOptimizeAfterSubGraph set on the passes.
    ClearOption(names_to_passes);

    if (ret != SUCCESS) {
      GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", node->GetName().c_str(),
             node->GetType().c_str(), ret);
      return ret;
    }
  }
  return SUCCESS;
}

///
/// Run every pass on `node` in order, then merge what each pass reported (re-pass, immediate re-pass,
/// suspend/resume and deleted nodes) into the graph-level and re-pass-level state.
/// Running stops early, without an error, as soon as a pass reports `node` itself as deleted.
/// @param node the node to process, must not be null
/// @param names_to_passes the (name, pass) pairs to run; GEPass::Run rejects null pass pointers beforehand
/// @param g_state graph-level state: receives passed/deleted/suspended nodes and immediate re-pass nodes
/// @param rp_state re-pass state: receives the nodes that need a (non-immediate) re-pass
/// @return SUCCESS; FAILED if `node` is null; otherwise the non-SUCCESS status returned by a pass
///
Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state,
                               RepassLevelState &rp_state) {
  if (node == nullptr) {
    REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid.");
    GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
    return FAILED;
  }

  // Reset every pass up front. The run loop below may break early when the node gets deleted; the passes
  // after the break would otherwise still hold the results of a previous node, and the harvest loop at the
  // end of this function would apply those stale results again.
  for (const auto &name_to_pass : names_to_passes) {
    name_to_pass.second->init();
  }

  GELOGD("Begin to run pass for node %s", node->GetName().c_str());
  for (const auto &name_to_pass : names_to_passes) {
    GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str());
    // Kept in addition to the reset above so each Run() still starts from a clean pass state.
    name_to_pass.second->init();
    auto result = name_to_pass.second->Run(node);
    if (result != SUCCESS) {
      REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u",
                        name_to_pass.first.c_str(), node->GetName().c_str(), result);
      GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result "
             "%u, the passes will be terminated immediately.",
             name_to_pass.first.c_str(), node->GetName().c_str(), result);
      return result;
    }
    if (name_to_pass.second->GetNodesDeleted().count(node) > 0) {
      GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
             name_to_pass.first.c_str());
      break;
    }
  }

  g_state.nodes_passed.insert(node);

  // Harvest the results of all passes only after the node is recorded as passed, because
  // AddImmediateRepassNodesToQueue consults g_state.nodes_passed.
  for (const auto &name_to_pass : names_to_passes) {
    PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen,
                       name_to_pass.second->GetNodesNeedRePass(),
                       rp_state.nodes_re_pass);

    AddImmediateRepassNodesToQueue(node, name_to_pass,
                                   name_to_pass.second->GetNodesNeedRePassImmediately(),
                                   g_state);
    SuspendAndResume(name_to_pass.first,
                     name_to_pass.second->GetNodesSuspend(),
                     name_to_pass.second->GetNodesResume(),
                     g_state);

    const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
    g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
  }

  return SUCCESS;
}
} // namespace ge } // namespace ge

+ 81
- 16
ge/graph/passes/base_pass.h View File

@@ -22,7 +22,6 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>

#include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_inner_error_codes.h"
#include "framework/common/types.h" #include "framework/common/types.h"
#include "graph/compute_graph.h" #include "graph/compute_graph.h"
@@ -61,23 +60,32 @@ class BaseNodePass {


const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; } const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; }


virtual Status OnSuspendNodesLeaked() { return SUCCESS; }

void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; }


void ClearOptions() { options_.clear(); } void ClearOptions() { options_.clear(); }


void init() { void init() {
nodes_need_re_pass_.clear(); nodes_need_re_pass_.clear();
nodes_deleted_.clear();
nodes_need_re_pass_immediately_.clear(); nodes_need_re_pass_immediately_.clear();
nodes_deleted_.clear();
nodes_suspend_.clear(); nodes_suspend_.clear();
nodes_resume_.clear(); nodes_resume_.clear();
} }


virtual void OnStartPassGraph(const ComputeGraphPtr &graph) {
current_graph_name_ = graph->GetName();
}

protected: protected:
Status IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map);
const string &GetCurrentGraphName() const {
return current_graph_name_;
}
Status IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map, bool is_repass_io_immediately = false);


Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list<int> &io_map) {
return IsolateAndDeleteNode(node, std::vector<int>(io_map));
Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list<int> &io_map, bool is_repass_io_immediately = false) {
return IsolateAndDeleteNode(node, std::vector<int>(io_map), is_repass_io_immediately);
} }


/// ///
@@ -112,6 +120,22 @@ class BaseNodePass {
} }
} }


///
/// Register a node together with its neighbours for an immediate re-pass.
/// The nodes returned by GetInNodes() are registered first, then the node itself,
/// then the nodes returned by GetOutNodes(); each one goes through AddImmediateRePassNode.
/// @param node the node whose neighbourhood should be optimized again; it is dereferenced without a null check
///
void AddImmediateRePassNodesWithInOut(const NodePtr &node) {
  for (auto &input : node->GetInNodes()) {
    AddImmediateRePassNode(input);
  }

  AddImmediateRePassNode(node);

  for (auto &output : node->GetOutNodes()) {
    AddImmediateRePassNode(output);
  }
}

/// ///
/// If you deleted a node from the graph, especially current node. The remain /// If you deleted a node from the graph, especially current node. The remain
/// iterate passes will continue process on the deleted node(if it can be /// iterate passes will continue process on the deleted node(if it can be
@@ -123,23 +147,15 @@ class BaseNodePass {
void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); }


/// ///
/// If you suspend a node from the graph, especially following node. The remain
/// iterate passes will stop process on the suspend node(if it can be
/// If you postpone a node from the graph, especially following node. The remain
/// iterate passes will stop process on the postpone node(if it can be
/// reached by edge connections) till the last one. Obviously it is a waste of /// reached by edge connections) till the last one. Obviously it is a waste of
/// time. You can add the suspend nodes by calling this function, to stop the
/// time. You can add the postpone nodes by calling this function, to stop the
/// next iterations. /// next iterations.
/// @param node /// @param node
/// ///
void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); }


///
/// If you resume a node from the graph, especially following node. The remain
/// iterate passes will continue process on the resume node(if it can be
/// reached by edge connections) till the last one.
/// You can add the resume nodes by calling this function, to resume the
/// next iterations.
/// @param node
///
void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); }


bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } bool OptionExists(NodePassOption option) { return options_.count(option) > 0; }
@@ -151,6 +167,7 @@ class BaseNodePass {
std::unordered_set<NodePtr> nodes_suspend_; std::unordered_set<NodePtr> nodes_suspend_;
std::unordered_set<NodePtr> nodes_resume_; std::unordered_set<NodePtr> nodes_resume_;
std::map<NodePassOption, std::string> options_; std::map<NodePassOption, std::string> options_;
std::string current_graph_name_;
}; };


using NamesToPass = std::vector<std::pair<std::string, BaseNodePass *>>; using NamesToPass = std::vector<std::pair<std::string, BaseNodePass *>>;
@@ -160,12 +177,60 @@ class GEPass {
explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {} explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {}
virtual ~GEPass() = default; virtual ~GEPass() = default;
Status Run(const NamesToPass &names_to_passes); Status Run(const NamesToPass &names_to_passes);
/*
* todo
* OneGraph: nodes_deleted, nodes_seen, nodes_passed, nodes_suspended
* RePass: nodes_re_pass
* GraphOneTime: nodes_last
* NodeOneTime: nodes_re_pass_immediately, nodes_resume
*/
/// Bookkeeping shared by all iterations while the passes run over one graph.
struct GraphLevelState {
  /// Nodes reported as deleted by any pass; popped nodes found here are skipped.
  std::unordered_set<NodePtr> nodes_deleted;
  /// Raw pointers of the nodes that have been put into the queue at least once.
  std::unordered_set<Node *> nodes_seen;
  /// Nodes on which the passes have already been run.
  std::unordered_set<NodePtr> nodes_passed;
  /// Nodes whose iteration is currently suspended.
  std::unordered_set<NodePtr> nodes_last_placeholder_unused_guard_removed_;
  std::unordered_set<NodePtr> nodes_suspend;
  /// Nodes postponed until the queue has drained (nodes with more than kMaxOneInNodes inputs).
  std::unordered_set<NodePtr> nodes_last;
  /// Work queue of nodes waiting to be passed.
  std::deque<NodePtr> nodes;
  /// Number of re-pass rounds done so far; initialized so a default-constructed state is never indeterminate.
  int re_pass_times = 0;

  /// Mark `node` as seen and put it at the FRONT of the queue, even if it was queued before.
  void AddNodeToQueueFront(NodePtr node) {
    nodes_seen.insert(node.get());
    nodes.emplace_front(std::move(node));
  }

  /// Mark `node` as seen and append it to the queue, even if it was queued before.
  void AddNodeToQueue(NodePtr node) {
    nodes_seen.insert(node.get());
    nodes.emplace_back(std::move(node));
  }

  /// Append `node` to the queue only if it has not been seen yet; marks it as seen in that case.
  void AddNodeToQueueIfNotSeen(NodePtr node) {
    if (nodes_seen.insert(node.get()).second) {
      nodes.emplace_back(std::move(node));
    }
  }

  /// Remove and return the first node of the queue.
  /// Precondition: the queue is not empty (calling front() on an empty deque is undefined behavior).
  NodePtr PopFront() {
    // Move out of the slot instead of copying; the slot is destroyed by pop_front() right after.
    NodePtr node = std::move(nodes.front());
    nodes.pop_front();
    return node;
  }
};
/// State that lives for one RunPassesGraphRepass call.
struct RepassLevelState {
// Nodes that passes asked to be re-passed; RunPassesGraphRepass moves them into the work queue at the
// start of each round and then clears this set.
std::unordered_set<NodePtr> nodes_re_pass;
};
// State for the "GraphOneTime" level listed in the comment above GraphLevelState: a single set of nodes.
// NOTE(review): GraphLevelState declares a nodes_last member as well; which of the two is actually
// used is not visible here — confirm in base_pass.cc.
struct GraphOneTimeLevelState {
std::unordered_set<NodePtr> nodes_last;
};


private: private:
GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth) GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth)
: graph_(graph), root_graph_(root_graph), depth_(depth) {} : graph_(graph), root_graph_(root_graph), depth_(depth) {}
Status RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes,
GraphLevelState &g_state, RepassLevelState &rp_state);
Status RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state);
Status RunPassesOneGraph(const NamesToPass &names_to_passes); Status RunPassesOneGraph(const NamesToPass &names_to_passes);
Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph); Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph);
Status RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state,
RepassLevelState &rp_state);
Status HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state);
ComputeGraphPtr graph_; ComputeGraphPtr graph_;
ComputeGraphPtr root_graph_; ComputeGraphPtr root_graph_;
int depth_; int depth_;


+ 1
- 1
ge/graph/passes/folding_pass.cc View File

@@ -363,7 +363,7 @@ Status FoldingPass::ConnectNodeToInAnchor(InDataAnchorPtr &in_anchor, NodePtr &n
in_anchor->GetIdx()); in_anchor->GetIdx());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
AddRePassNodesWithInOut(node);
AddImmediateRePassNode(node);
return SUCCESS; return SUCCESS;
} }
} // namespace ge } // namespace ge

+ 1
- 1
ge/graph/passes/infer_base_pass.h View File

@@ -36,7 +36,7 @@ class InferBasePass : public BaseNodePass {
* @param dst, output TensorDesc to be updated * @param dst, output TensorDesc to be updated
* @return * @return
*/ */
virtual graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0;
virtual graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0;


/** /**
* Update the output TensorDesc for nodes which contain subgraphs. * Update the output TensorDesc for nodes which contain subgraphs.


+ 1
- 1
ge/graph/passes/infer_value_range_pass.cc View File

@@ -207,7 +207,7 @@ bool InferValueRangePass::InputHasUnknownValueRange(const NodePtr &node) const {
return has_unknown_value_range; return has_unknown_value_range;
} }


graphStatus InferValueRangePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
graphStatus InferValueRangePass::UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
if (src == nullptr || dst == nullptr) { if (src == nullptr || dst == nullptr) {
REPORT_CALL_ERROR("E19999", "While updating tensor desc, input desc is null."); REPORT_CALL_ERROR("E19999", "While updating tensor desc, input desc is null.");
GELOGE(GRAPH_FAILED, "[Param][check] While updating tensor desc, input desc is null."); GELOGE(GRAPH_FAILED, "[Param][check] While updating tensor desc, input desc is null.");


+ 1
- 1
ge/graph/passes/infer_value_range_pass.h View File

@@ -26,7 +26,7 @@ class InferValueRangePass : public InferBasePass {


private: private:
std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override; std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override;
graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override;
graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override;
graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override; graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override;
graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src, graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
GeTensorDescPtr &dst) override; GeTensorDescPtr &dst) override;


+ 258
- 104
ge/graph/passes/infershape_pass.cc View File

@@ -1,6 +1,6 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
@@ -22,13 +22,16 @@
#include "graph/shape_refiner.h" #include "graph/shape_refiner.h"
#include "graph/utils/graph_utils.h" #include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h" #include "graph/utils/node_utils.h"
#include "graph/common/omg_util.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"


namespace ge {
#include "external/graph/operator_factory.h"


namespace ge {
namespace {
constexpr int kSwitchExitAnchorIndex = 0;
constexpr int kSwitchPredAnchorIndex = 1;
void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
desc_str += "["; desc_str += "[";
std::vector<std::pair<int64_t, int64_t>> shape_range; std::vector<std::pair<int64_t, int64_t>> shape_range;
@@ -47,129 +50,280 @@ void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
desc_str += "},"; desc_str += "},";
} }
} }
// Mirrors the shape information of `src` onto `dst`:
//   - origin shape, shape, data type and origin data type are copied one to one;
//   - the shape range of `src` is written to both the shape range and the origin
//     shape range of `dst`;
//   - the real dim count of `dst` becomes the number of dims in the shape of `src`.
// Neither pointer is checked for null here.
void UpdateShapeAndDType(const GeTensorDescPtr &src, GeTensorDescPtr &dst) {
  // Shapes and data types.
  dst->SetOriginShape(src->GetOriginShape());
  dst->SetShape(src->GetShape());
  dst->SetDataType(src->GetDataType());
  dst->SetOriginDataType(src->GetOriginDataType());

  // One range read from src feeds both range fields of dst.
  std::vector<std::pair<int64_t, int64_t>> range_of_src;
  src->GetShapeRange(range_of_src);
  dst->SetShapeRange(range_of_src);
  dst->SetOriginShapeRange(range_of_src);

  const auto dim_count = static_cast<uint32_t>(src->GetShape().GetDims().size());
  ge::TensorUtils::SetRealDimCnt(*dst, dim_count);
}
} // namespace


std::string GetInTensorInfoWithString(const ge::NodePtr &node) {
ge::OpDescPtr op_desc = node->GetOpDesc();
// Renders one tensor descriptor as text for log output: shape, format, dtype, the matching
// origin_* fields, and finally the shape range produced by SerialShapeRange.
// NOTE(review): tensor_desc is dereferenced without a null check — presumably callers guarantee
// a non-null descriptor; confirm.
std::string InferShapePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const {
std::stringstream ss; std::stringstream ss;
// NOTE(review): from here to the closing brace of the for loop the lines read `op_desc` and
// `input_desc`, neither of which is declared by this signature. They look like the removed side
// of the diff (the old per-input loop of GetInTensorInfoWithString) — TODO confirm.
ss << "{";
int32_t in_idx = 0;
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
if (input_desc == nullptr) {
in_idx++;
continue;
}
if (in_idx > 0) {
ss << " ";
}
ss << "input_" << in_idx << " " << "tensor: [";
ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
string range_str;
SerialShapeRange(input_desc, range_str);
ss << "(shape_range:" << range_str << ")]";
in_idx++;
}
// Fields of the single descriptor handed to this function, each wrapped as "(name:value)".
ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),";
ss << "(format:" << TypeUtils::FormatToSerialString(tensor_desc->GetFormat()) << "),";
ss << "(dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()) << "),";
ss << "(origin_shape:" << tensor_desc->GetOriginShape().ToString() << "),";
ss << "(origin_format:" << TypeUtils::FormatToSerialString(tensor_desc->GetOriginFormat()) << "),";
ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetOriginDataType()) << "),";
// The range is serialized into a scratch string first, then appended as the last field.
string range_str;
SerialShapeRange(tensor_desc, range_str);
ss << "(shape_range:" << range_str << ")";
return ss.str(); return ss.str();
} }
// Suspends the nodes fed by output kSwitchExitAnchorIndex of a Switch whose predicate input
// (input index kSwitchPredAnchorIndex) is a LoopCond node. Each such node is recorded once per
// current graph name in graphs_2_suspend_nodes_ and handed to AddNodeSuspend.
// Returns SUCCESS for every node that is not such a Switch; GE_CHECK_NOTNULL handles a missing
// predicate input.
Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) {
// Anything that is not a Switch is left alone.
if (node->GetType() != SWITCH) {
return SUCCESS;
}
auto pred_node = NodeUtils::GetInDataNodeByIndex(*node, kSwitchPredAnchorIndex);
GE_CHECK_NOTNULL(pred_node);
// Only a Switch driven by LoopCond is handled.
if (pred_node->GetType() != LOOPCOND) {
return SUCCESS;
}


// NOTE(review): the lines from this Run signature down to the DelAttr call look like the removed
// side of the diff (the old InferShapePass::Run body) interleaved into this function — TODO confirm.
Status InferShapePass::Run(NodePtr &node) {
// kOptimizeAfterSubGraph exist means after subgraph
auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph));
if (ret != GRAPH_SUCCESS) {
// select INFERSHAPE failed info
auto graph = node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(graph);
auto root_graph = ge::GraphUtils::FindRootGraph(graph);
GE_CHECK_NOTNULL(root_graph);
analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(),
analyzer::INFER_SHAPE, node, "InferShapeFailed!"};
(void)Analyzer::GetInstance()->DoAnalyze(analyze_info);
(void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(),
root_graph->GetGraphID());

REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s",
node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str());
GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s",
node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str());
return GE_GRAPH_INFERSHAPE_FAILED;
}

GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node));
bool need_repass = false;
auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass);
if (has_attr) {
if (!OptionExists(kOptimizeAfterSubGraph)) {
return SUCCESS;
}
if (need_repass) {
AddImmediateRePassNode(node);
GELOGD("Node %s need repass immediately.", node->GetName().c_str());
} else {
// clear attr on while
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
// Walk every node connected to the Switch output at index kSwitchExitAnchorIndex.
for (const auto &anchor_2_node : NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kSwitchExitAnchorIndex)) {
GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", anchor_2_node.second->GetName().c_str(),
anchor_2_node.second->GetType().c_str());
// NOTE(review): `iter` is assigned but never read in the visible lines; the operator[] on the
// next line performs its own lookup and creates the entry when it is missing.
auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName());
auto &suspend_nodes = graphs_2_suspend_nodes_[GetCurrentGraphName()];
// insert().second is true only for a node not recorded yet, so each node is pushed and suspended once.
if (suspend_nodes.nodes_set.insert(anchor_2_node.second).second) {
suspend_nodes.nodes.push(anchor_2_node.second);
AddNodeSuspend(anchor_2_node.second);
} }
} }
return SUCCESS; return SUCCESS;
} }


Status InferShapePass::RePassLoopNode(const NodePtr &node) {
const auto RePassNode = [&](const std::set<std::string> &re_pass_types) {
for (auto &n : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(n);
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str());
if (re_pass_types.count(node_type) > 0) {
AddImmediateRePassNode(n);
(void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false);
GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str());
// Runs shape inference for one node:
//   1. SuspendV1LoopExitNodes(node) — any non-SUCCESS result is returned unchanged;
//   2. node->Verify() — a failure is reported and GRAPH_FAILED is returned;
//   3. an Operator is created from the node and, when the owner graph is not flagged as unknown,
//      given a fresh inference context;
//   4. CallInferShapeFunc — GRAPH_SUCCESS, GRAPH_PARAM_INVALID and GRAPH_NODE_NEED_REPASS are
//      accepted, every other status is reported and turned into GRAPH_FAILED;
//   5. for a graph not flagged as unknown, the context left on the operator is pushed to the
//      ShapeRefiner context map when it carries output handle shapes/types or marks.
// Returns GRAPH_NODE_NEED_REPASS when the infer function asked for it, otherwise GRAPH_SUCCESS.
Status InferShapePass::Infer(NodePtr &node) {
auto ret = SuspendV1LoopExitNodes(node);
if (ret != SUCCESS) {
//todo LOG
return ret;
}
// NOTE(review): the result of GetOwnerComputeGraph() is dereferenced without a null check.
bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
// NOTE(review): `opdesc` is not read anywhere in the visible lines of this function.
auto opdesc = node->GetOpDesc();
if (node->Verify() != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str());
GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str());
return GRAPH_FAILED;
}
Operator op = OpDescUtils::CreateOperatorFromNode(node);

// An inference context is only attached when the owner graph is not flagged as unknown.
if (!is_unknown_graph) {
auto inference_context = ShapeRefiner::CreateInferenceContext(node);
GE_CHECK_NOTNULL(inference_context);
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
op.SetInferenceContext(inference_context);
}

graphStatus status = CallInferShapeFunc(node, op);
// GRAPH_PARAM_INVALID is deliberately let through here (see the comment inside the branch).
if (status != GRAPH_NODE_NEED_REPASS && status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) {
// node like netoutput return param_invalid, but valid ?
REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str());
GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str());
return GRAPH_FAILED;
}
if (!is_unknown_graph) {
auto ctx_after_infer = op.GetInferenceContext();
if (ctx_after_infer != nullptr) {
GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
// Only a context that carries output handle shapes/types or marks is stored.
if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
ctx_after_infer->GetMarks().size());
ShapeRefiner::PushToContextMap(node, ctx_after_infer);
} }
} }
// NOTE(review): the lines from here down to the closing brace of the ExProcNode loop look like the
// removed side of the diff (tail of the old RePassLoopNode lambdas) — TODO confirm.
return SUCCESS;
};

const auto ExProcNode = [&](const std::set<std::string> &proc_types,
const std::function<void(InferShapePass *, NodePtr)> &proc_func,
const std::string &info) {
for (auto &n : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(n);
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str());
if (proc_types.count(node_type) > 0) {
proc_func(this, n);
GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str());
}
}
// GRAPH_PARAM_INVALID is folded into GRAPH_SUCCESS here; only the repass request is propagated.
return (status == GRAPH_NODE_NEED_REPASS) ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS;
}

bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) {
// check shape range
vector<std::pair<int64_t, int64_t>> src_shape_range;
vector<std::pair<int64_t, int64_t>> dst_shape_range;
src->GetShapeRange(src_shape_range);
dst->GetShapeRange(dst_shape_range);
if (src_shape_range.size() != dst_shape_range.size()) {
GELOGI("Src shape range size is %zu, dst shape range size is %zu, not same.", src_shape_range.size(),
dst_shape_range.size());
return false;
}
for (size_t i = 0; i < src_shape_range.size(); ++i) {
if (src_shape_range[i].first != dst_shape_range[i].first ||
src_shape_range[i].second != dst_shape_range[i].second) {
GELOGI("Current dim %zu. Src shape range is [%lu-%lu], dst shape range is [%lu-%lu], not same.",
i, src_shape_range[i].first, src_shape_range[i].second, dst_shape_range[i].first, dst_shape_range[i].second);
return false;
} }
}

// check shape
auto src_shape = src->GetShape();
auto dst_shape = dst->GetShape();
if (src_shape.GetDims() != dst_shape.GetDims() || src->GetOriginShape().GetDims() != dst->GetOriginShape().GetDims() ||
src->GetDataType() != dst->GetDataType() || src->GetOriginDataType() != dst->GetOriginDataType()) {
GELOGD(
"Src shape is %s, origin_shape is %s, data_type is %s, origin data_type is %s; "
"Dst shape is %s, origin_shape is %s, data_type is %s, original data_type is %s, not same.",
src_shape.ToString().c_str(), src->GetOriginShape().ToString().c_str(),
TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(),
TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst_shape.ToString().c_str(),
dst->GetOriginShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(),
TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str());
return false;
}
return true;
}

/// Propagates the inferred shape information of `src` to `dst`.
///
/// `src` is normalized first: its origin shape, origin data type, real dim count and origin
/// shape range are overwritten with its current shape, data type, shape dim count and shape
/// range. `dst` is then updated through UpdateShapeAndDType unless SameTensorDesc reports that
/// it already matches.
///
/// @param src      source descriptor; refreshed in place
/// @param dst      destination descriptor
/// @param changed  set to true only when `dst` was rewritten, false otherwise
/// @return SUCCESS on success, GRAPH_FAILED when either descriptor is null
graphStatus InferShapePass::UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  changed = false;
  // Reject null descriptors up front instead of dereferencing them below.
  if (src == nullptr || dst == nullptr) {
    REPORT_CALL_ERROR("E19999", "While updating tensor desc, input desc is null.");
    GELOGE(GRAPH_FAILED, "[Param][check] While updating tensor desc, input desc is null.");
    return GRAPH_FAILED;
  }

  // refresh src itself
  src->SetOriginShape(src->GetShape());
  src->SetOriginDataType(src->GetDataType());
  TensorUtils::SetRealDimCnt(*src, static_cast<uint32_t>(src->GetOriginShape().GetDims().size()));
  std::vector<std::pair<int64_t, int64_t>> src_shape_range;
  src->GetShapeRange(src_shape_range);
  src->SetOriginShapeRange(src_shape_range);

  if (SameTensorDesc(src, dst)) {
    GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update.");
    return SUCCESS;
  }

  changed = true;
  UpdateShapeAndDType(src, dst);
  GELOGD(
      "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s."
      "To dst Node: shape: [%s], datatype: %s, original datatype is %s.",
      src->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(),
      TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst->GetShape().ToString().c_str(),
      TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(),
      TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str());
  return SUCCESS;
}

// Calls the infer function registered on the node's OpDesc. When that call returns
// GRAPH_PARAM_INVALID, an operator of the same type is created through OperatorFactory, its
// input/output names and infer function are copied onto the OpDesc, and the infer function is
// called a second time.
// Returns the status of the last CallInferFunc; the first status when the factory yields an empty
// operator; GRAPH_FAILED when no OpDesc can be obtained from the factory operator.
graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) {
auto op_desc = node->GetOpDesc();
const auto &op_type = op_desc->GetType();
auto ret = op_desc->CallInferFunc(op);
if (ret == GRAPH_PARAM_INVALID) {
// Op ir no infer func, try to get infer func from operator factory
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
if (node_op.IsEmpty()) {
GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
return ret;
}


// NOTE(review): the next five lines (through the RePassNode call) read `node_type`, `RePassNode`
// and `kNextIterationOpTypes`; they look like the removed side of the diff — TODO confirm.
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type),
"[Get][OriginalType] of node:%s failed.", node->GetName().c_str());
if (kNextIterationOpTypes.count(node_type) > 0) {
return RePassNode(kMergeOpTypes); // Re-Pass Merge
GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());
auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
// The factory operator is only needed for its OpDesc.
node_op.BreakConnect();
if (temp_op_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr.");
GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null");
return GRAPH_FAILED;
}
if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
GELOGW("InferShapeAndType UpdateInputName failed");
// NOTE(review): the loop body ends in an unconditional return, so at most the first output desc
// is inspected: a non-null desc with empty dims breaks out and processing continues, anything
// else returns GRAPH_SUCCESS without calling the infer function again. With no output descs the
// loop body never runs. Confirm this is the intended behaviour.
for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) {
if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) {
break;
}
return GRAPH_SUCCESS;
}
}
if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
GELOGW("InferShapeAndType UpdateOutputName failed");
}
// Second attempt, now with the infer function taken from the factory operator.
op_desc->AddInferFunc(temp_op_desc->GetInferFunc());
ret = op_desc->CallInferFunc(op);
GELOGI("op CallInferFunc second. ret: %u", ret);
} }
return ret;
}


if (kMergeOpTypes.count(node_type) > 0) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
return RePassNode(kSwitchOpTypes); // Re-Pass Switch
/// Merges the output descriptors produced by several subgraphs into the parent node's output.
///
/// The first descriptor in `src` is used as the reference and is modified in place:
///   - a tensor with a different rank turns the reference shape into UNKNOWN_RANK and ends
///     the comparison;
///   - otherwise every dim that differs from the reference is set to UNKNOWN_DIM.
/// The reference is finally copied onto `dst` through UpdateShapeAndDType.
///
/// @param src  output descriptors collected from the subgraphs
/// @param dst  output descriptor of the parent node
/// @return GRAPH_SUCCESS on success (including an empty `src`, which leaves `dst` untouched),
///         GRAPH_FAILED when the descriptors disagree on data type
graphStatus InferShapePass::UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) {
  GELOGD("Enter update parent node shape for class branch op process");
  // Nothing to merge: handled the same way as in UpdateOutputFromSubgraphsForMultiDims instead of
  // letting src.at(0) throw.
  if (src.empty()) {
    GELOGW("No subgraph output tensor desc is given, skip updating parent node output.");
    return GRAPH_SUCCESS;
  }
  // check sub_graph shape.If not same ,do unknown shape process
  auto ref_out_tensor = src.at(0);
  ge::GeShape &ref_out_tensor_shape = ref_out_tensor->MutableShape();
  for (auto &tensor : src) {
    if (ref_out_tensor->GetDataType() != tensor->GetDataType()) {
      REPORT_INNER_ERROR("E19999", "Does not support diff dtype among all ref output, shape:%s",
                         ref_out_tensor_shape.ToString().c_str());
      GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype output");
      return GRAPH_FAILED;
    }
    // Read-only view of the subgraph shape; only the reference shape is ever written.
    const ge::GeShape &shape = tensor->MutableShape();
    if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
      GELOGD("Shape from subgraph size: %ld, ref_out_tensor_shape size: %ld", shape.GetShapeSize(),
             ref_out_tensor_shape.GetShapeSize());
      ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
      break;
    }
    for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
      if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
        continue;
      }
      GELOGD("j: %zu ,shape from subgraph size: %ld, ref_out_tensor_shape size: %ld", j, shape.GetShapeSize(),
             ref_out_tensor_shape.GetShapeSize());
      (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
    }
  }
  UpdateShapeAndDType(ref_out_tensor, dst);
  return GRAPH_SUCCESS;
}
// Picks, among the subgraph output descriptors in `src`, the one whose dims multiply to the
// largest value and copies its shape/dtype information onto `dst` via UpdateShapeAndDType.
// Ties keep the earliest descriptor (the comparison is a strict `>`); when no product exceeds 0
// the descriptor at index 0 is used.
// Returns SUCCESS for an empty `src`, GRAPH_FAILED when data types differ, PARAM_INVALID when the
// overflow check fires, GRAPH_SUCCESS otherwise.
graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
GeTensorDescPtr &dst) {
// check sub_graph shape. Get max for update.
if (src.empty()) {
// TODO LOG
return SUCCESS; return SUCCESS;
} }


// NOTE(review): the next six lines (the kSwitchOpTypes branch) read names this function does not
// declare; they look like the removed side of the diff — TODO confirm.
if (kSwitchOpTypes.count(node_type) > 0) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit
} else {
return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit
int64_t max_size = 0;
size_t max_shape_index = 0;
auto &ref_out_tensor = src.at(0);
for (size_t j = 0; j < src.size(); ++j) {
auto &tensor = src.at(j);
// Every descriptor must share the data type of the first one.
if (ref_out_tensor->GetDataType() != tensor->GetDataType()) {
REPORT_INNER_ERROR("E19999", "node does not support diff dtype among all ref output");
GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype among all ref output");
return GRAPH_FAILED;
}

auto shape = tensor->MutableShape();
// Product of all dims, guarded against exceeding INT64_MAX before each multiplication.
int64_t size = 1;
for (auto dim : shape.GetDims()) {
// NOTE(review): for a negative dim the quotient INT64_MAX / dim is negative, so this condition
// is true for any positive running size and the shape is reported as an overflow — confirm that
// negative dims cannot reach this function.
if (dim != 0 && INT64_MAX / dim < size) {
REPORT_INNER_ERROR("E19999", "The shape:%s size overflow", shape.ToString().c_str());
GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow");
return PARAM_INVALID;
}
size *= dim;
} }
}


// Remember the index of the largest product seen so far.
if (size > max_size) {
max_size = size;
max_shape_index = j;
}
}
UpdateShapeAndDType(src.at(max_shape_index), dst);
return GRAPH_SUCCESS;
}
// Resumes one node recorded as suspended for the current graph, if any is left.
// Looks up the entry stored under GetCurrentGraphName(); when it exists and its stack is not
// empty, the top node is popped and passed to AddNodeResume. Always returns SUCCESS.
Status InferShapePass::OnSuspendNodesLeaked() {
  const auto found = graphs_2_suspend_nodes_.find(GetCurrentGraphName());
  // todo log warn (when no entry exists for the current graph)
  const bool has_suspended = (found != graphs_2_suspend_nodes_.end()) && !found->second.nodes.empty();
  if (has_suspended) {
    AddNodeResume(found->second.PopSuspendedNode());
  }
  return SUCCESS;
}
} // namespace ge } // namespace ge

+ 29
- 12
ge/graph/passes/infershape_pass.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -17,22 +17,39 @@
#ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ #ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_
#define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ #define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_


#include "graph/passes/base_pass.h"
#include "graph/passes/infer_base_pass.h"
#include <stack>


namespace ge { namespace ge {
class InferShapePass : public BaseNodePass {
class InferShapePass : public InferBasePass {
public: public:
///
/// Entry of the InferShapePass optimizer
/// @param [in] graph: Input ComputeGraph
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status Run(ge::NodePtr &node) override;
std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override;
graphStatus Infer(NodePtr &node) override;

graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override;
graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override;
graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
GeTensorDescPtr &dst) override;

Status OnSuspendNodesLeaked() override;


private: private:
Status RePassLoopNode(const NodePtr &node);
graphStatus CallInferShapeFunc(NodePtr &node, Operator &op);
bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst);
Status SuspendV1LoopExitNodes(const NodePtr &node);
// Suspended nodes of one graph, kept twice: a stack for last-in-first-out retrieval and a set
// holding the same nodes.
struct SuspendNodes {
  std::stack<NodePtr> nodes;
  std::unordered_set<NodePtr> nodes_set;

  // Takes the node on top of the stack out of both containers and returns it.
  // Precondition: `nodes` is not empty.
  NodePtr PopSuspendedNode() {
    NodePtr latest = nodes.top();
    nodes.pop();
    nodes_set.erase(latest);
    return latest;
  }
};
std::map<std::string, SuspendNodes> graphs_2_suspend_nodes_;

}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ #endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_

+ 3
- 2
ge/graph/passes/merge_pass.cc View File

@@ -31,6 +31,7 @@ namespace ge {
const int kValueIndexOutputIndex = 1; const int kValueIndexOutputIndex = 1;
const size_t kCaseNoInput = 0; const size_t kCaseNoInput = 0;
const size_t kCaseOneInput = 1; const size_t kCaseOneInput = 1;
const bool kWillRepassImmediately = true;


Status MergePass::Run(NodePtr &node) { Status MergePass::Run(NodePtr &node) {
GELOGD("MergePass running"); GELOGD("MergePass running");
@@ -82,14 +83,14 @@ Status MergePass::Run(NodePtr &node) {
} }
auto in_node = in_data_nodes.at(0); auto in_node = in_data_nodes.at(0);
if (IsMergeInputNeedOptimized(in_node)) { if (IsMergeInputNeedOptimized(in_node)) {
if (IsolateAndDeleteNode(in_node, {0}) != SUCCESS) {
if (IsolateAndDeleteNode(in_node, {0}, kWillRepassImmediately) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed", REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed",
in_node->GetName().c_str(), in_node->GetType().c_str()); in_node->GetName().c_str(), in_node->GetType().c_str());
GELOGE(FAILED, "[Remove][Node] %s failed.", in_node->GetName().c_str()); GELOGE(FAILED, "[Remove][Node] %s failed.", in_node->GetName().c_str());
return FAILED; return FAILED;
} }
} }
return IsolateAndDeleteNode(node, merge_io_map);
return IsolateAndDeleteNode(node, merge_io_map, kWillRepassImmediately);
} }
default: { default: {
// Case C: input_count > 1, the merge node can not be optimized // Case C: input_count > 1, the merge node can not be optimized


+ 2
- 1
ge/graph/passes/switch_dead_branch_elimination.cc View File

@@ -28,6 +28,7 @@ namespace {
const std::vector<int>::size_type kDataInputIndex = 0; const std::vector<int>::size_type kDataInputIndex = 0;
const std::vector<int>::size_type kPredInputIndex = 1; const std::vector<int>::size_type kPredInputIndex = 1;
const int kDefaultInputIndex = -1; const int kDefaultInputIndex = -1;
const bool kWillRepassImmediately = true;


bool ParsePred(const ConstGeTensorPtr &tensor) { bool ParsePred(const ConstGeTensorPtr &tensor) {
if (tensor == nullptr) { if (tensor == nullptr) {
@@ -134,7 +135,7 @@ Status SwitchDeadBranchElimination::DeleteSwitchNode(NodePtr &node, NodePtr &pre
return FAILED; return FAILED;
} }
switch_io_map[out_index] = kDataInputIndex; switch_io_map[out_index] = kDataInputIndex;
return IsolateAndDeleteNode(node, switch_io_map);
return IsolateAndDeleteNode(node, switch_io_map, kWillRepassImmediately);
} }


Status SwitchDeadBranchElimination::Run(NodePtr &node) { Status SwitchDeadBranchElimination::Run(NodePtr &node) {


+ 16
- 0
ge/graph/preprocess/graph_preprocess.cc View File

@@ -1999,6 +1999,22 @@ Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) {


Status GraphPrepare::InferShapeForPreprocess() { Status GraphPrepare::InferShapeForPreprocess() {
GELOGI("Start infershape for preprocess."); GELOGI("Start infershape for preprocess.");
// Prepare dummy_shape for v1 control_flow op before infershape
for (const auto &node : compute_graph_->GetAllNodes()) {
string type;
GetOriginalType(node, type);
if (type == MERGE || type == REFMERGE) {
for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) {
GELOGD("Prepare for infershape: update %s input_shape as dummy.", node->GetName().c_str());
NodeUtils::UpdateInputShape(*node, i, GeShape(DUMMY_SHAPE));
}
} else if (type == WHILE) {
for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) {
GELOGD("Prepare for infershape: update %s output_shape as dummy.", node->GetName().c_str());
NodeUtils::UpdateOutputShape(*node, i, GeShape(DUMMY_SHAPE));
}
}
}
GEPass ge_passes(compute_graph_); GEPass ge_passes(compute_graph_);
NamesToPass names_to_passes; NamesToPass names_to_passes;
AssertPass assert_pass; AssertPass assert_pass;


Loading…
Cancel
Save