Browse Source

add suspend node in base_pass

tags/v1.3.0
chenyemeng 3 years ago
parent
commit
abde4f2f8c
4 changed files with 141 additions and 22 deletions
  1. +68
    -12
      ge/graph/passes/base_pass.cc
  2. +34
    -6
      ge/graph/passes/base_pass.h
  3. +35
    -4
      ge/graph/passes/infershape_pass.cc
  4. +4
    -0
      ge/hybrid/model/hybrid_model_builder.cc

+ 68
- 12
ge/graph/passes/base_pass.cc View File

@@ -36,6 +36,8 @@ struct DuringPassNodeSets {
std::unordered_set<NodePtr> nodes_re_pass; std::unordered_set<NodePtr> nodes_re_pass;
std::unordered_set<NodePtr> nodes_re_pass_immediately; std::unordered_set<NodePtr> nodes_re_pass_immediately;
std::unordered_set<NodePtr> nodes_last; std::unordered_set<NodePtr> nodes_last;
std::unordered_set<NodePtr> nodes_suspend;
std::unordered_set<NodePtr> nodes_resume;
}; };


void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes,
@@ -55,8 +57,15 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i
} }
} }


bool IsAllInNodesAlive(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_suspend) {
return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; });
}

void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) {
DuringPassNodeSets &during_pass_node_set) {
auto &nodes_seen = during_pass_node_set.nodes_seen;
const auto &nodes_last = during_pass_node_set.nodes_last;
const auto &nodes_suspend = during_pass_node_set.nodes_suspend;
for (auto &node : nodes) { for (auto &node : nodes) {
if (node == nullptr) { if (node == nullptr) {
continue; continue;
@@ -64,16 +73,57 @@ void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &n
if (nodes_last.count(node) != 0) { if (nodes_last.count(node) != 0) {
continue; continue;
} }
if (nodes_suspend.count(node) > 0) {
GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str());
continue;
}


bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend);
bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen);
if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) {
if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) {
nodes_to_pass.push_back(node); nodes_to_pass.push_back(node);
} }
} }
} }


void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) {
for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) {
GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str());
nodes.push_front(node);
}
during_pass_node_set.nodes_re_pass_immediately.clear();
}

void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) {
for (auto &node : during_pass_node_set.nodes_resume) {
const auto &it = during_pass_node_set.nodes_suspend.find(node);
if (it != during_pass_node_set.nodes_suspend.end()) {
during_pass_node_set.nodes_suspend.erase(node);
GELOGD("The node %s resumed by pass.", node->GetName().c_str());
nodes.push_back(node);
} else {
GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str());
}
}
during_pass_node_set.nodes_resume.clear();
}

void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name,
const std::unordered_set<NodePtr> &nodes_suspend,
const std::unordered_set<NodePtr> &nodes_resume) {
for (const auto &node : nodes_suspend) {
GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str());
during_pass_node_set.nodes_suspend.emplace(node);
}

for (const auto &node : nodes_resume) {
GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str());
during_pass_node_set.nodes_resume.emplace(node);
}
}

void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_to_re_pass,
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass,
std::unordered_set<NodePtr> &nodes_re_pass) { std::unordered_set<NodePtr> &nodes_re_pass) {
for (const auto &node_to_re_pass : nodes_to_re_pass) { for (const auto &node_to_re_pass : nodes_to_re_pass) {
if (node_to_re_pass == nullptr) { if (node_to_re_pass == nullptr) {
@@ -113,15 +163,18 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo
return result; return result;
} }


auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass();
const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass,
during_pass_node_set.nodes_re_pass); during_pass_node_set.nodes_re_pass);


auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately();
const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately,
during_pass_node_set.nodes_re_pass_immediately); during_pass_node_set.nodes_re_pass_immediately);


auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
PushToSuspendNodes(during_pass_node_set, name_to_pass.first,
name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume());

const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
if (nodes_deleted_by_pass.count(node) > 0) { if (nodes_deleted_by_pass.count(node) > 0) {
GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
@@ -221,8 +274,13 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
continue; continue;
} }
if (during_pass_node_set.nodes_suspend.count(node) > 0) {
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
node->GetName().c_str());
continue;
}


AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last);
AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set);


auto ret = RunPasses(node, names_to_passes, during_pass_node_set); auto ret = RunPasses(node, names_to_passes, during_pass_node_set);
if (ret != SUCCESS) { if (ret != SUCCESS) {
@@ -253,11 +311,9 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
// should be called each time at the begin of the iteration // should be called each time at the begin of the iteration
ClearOption(names_to_passes); ClearOption(names_to_passes);
} }
for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) {
GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str());
nodes.push_front(node);
}
during_pass_node_set.nodes_re_pass_immediately.clear();

AddRepassNodes(during_pass_node_set, nodes);
AddResumeNodes(during_pass_node_set, nodes);
} }


for (auto &node : during_pass_node_set.nodes_last) { for (auto &node : during_pass_node_set.nodes_last) {


+ 34
- 6
ge/graph/passes/base_pass.h View File

@@ -51,11 +51,15 @@ class BaseNodePass {


virtual ~BaseNodePass() = default; virtual ~BaseNodePass() = default;


std::unordered_set<NodePtr> GetNodesNeedRePass() { return nodes_need_re_pass_; }
const std::unordered_set<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; }


std::unordered_set<NodePtr> GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; }
const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; }


std::unordered_set<NodePtr> GetNodesDeleted() { return nodes_deleted_; }
const std::unordered_set<NodePtr> &GetNodesDeleted() { return nodes_deleted_; }

const std::unordered_set<NodePtr> &GetNodesSuspend() { return nodes_suspend_; }

const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; }


void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; }


@@ -65,6 +69,8 @@ class BaseNodePass {
nodes_need_re_pass_.clear(); nodes_need_re_pass_.clear();
nodes_deleted_.clear(); nodes_deleted_.clear();
nodes_need_re_pass_immediately_.clear(); nodes_need_re_pass_immediately_.clear();
nodes_suspend_.clear();
nodes_resume_.clear();
} }


protected: protected:
@@ -80,7 +86,7 @@ class BaseNodePass {
/// optimized by other passes, call this function. /// optimized by other passes, call this function.
/// @param node /// @param node
/// ///
void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); }
void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); }


/// ///
/// Add a node to be optimized immediately again. If you add a new node to the graph, or /// Add a node to be optimized immediately again. If you add a new node to the graph, or
@@ -88,13 +94,13 @@ class BaseNodePass {
/// optimized by other passes, call this function. /// optimized by other passes, call this function.
/// @param node /// @param node
/// ///
void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); }
void AddImmediateRePassNode(const NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); }


/// ///
/// Add a node and it's input/output data nodes to be optimized again. /// Add a node and it's input/output data nodes to be optimized again.
/// @param node /// @param node
/// ///
void AddRePassNodesWithInOut(NodePtr &node) {
void AddRePassNodesWithInOut(const NodePtr &node) {
AddRePassNode(node); AddRePassNode(node);
auto out_nodes = node->GetOutNodes(); auto out_nodes = node->GetOutNodes();
for (auto &out_node : out_nodes) { for (auto &out_node : out_nodes) {
@@ -116,12 +122,34 @@ class BaseNodePass {
/// ///
void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); }


///
/// If you suspend a node from the graph, especially following node. The remain
/// iterate passes will stop process on the suspend node(if it can be
/// reached by edge connections) till the last one. Obviously it is a waste of
/// time. You can add the suspend nodes by calling this function, to stop the
/// next iterations.
/// @param node
///
void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); }

///
/// If you resume a node from the graph, especially following node. The remain
/// iterate passes will continue process on the resume node(if it can be
/// reached by edge connections) till the last one.
/// You can add the resume nodes by calling this function, to resume the
/// next iterations.
/// @param node
///
void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); }

bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } bool OptionExists(NodePassOption option) { return options_.count(option) > 0; }


private: private:
std::unordered_set<NodePtr> nodes_need_re_pass_; std::unordered_set<NodePtr> nodes_need_re_pass_;
std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; std::unordered_set<NodePtr> nodes_need_re_pass_immediately_;
std::unordered_set<NodePtr> nodes_deleted_; std::unordered_set<NodePtr> nodes_deleted_;
std::unordered_set<NodePtr> nodes_suspend_;
std::unordered_set<NodePtr> nodes_resume_;
std::map<NodePassOption, std::string> options_; std::map<NodePassOption, std::string> options_;
}; };




+ 35
- 4
ge/graph/passes/infershape_pass.cc View File

@@ -21,6 +21,8 @@
#include "framework/common/util.h" #include "framework/common/util.h"
#include "graph/shape_refiner.h" #include "graph/shape_refiner.h"
#include "graph/utils/graph_utils.h" #include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/common/omg_util.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "utils/tensor_utils.h" #include "utils/tensor_utils.h"
#include "utils/type_utils.h" #include "utils/type_utils.h"
@@ -117,7 +119,9 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) {
const auto RePassNode = [&](const std::set<std::string> &re_pass_types) { const auto RePassNode = [&](const std::set<std::string> &re_pass_types) {
for (auto &n : node->GetOutDataNodes()) { for (auto &n : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(n); GE_CHECK_NOTNULL(n);
if (re_pass_types.count(n->GetType()) > 0) {
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "Get original node type failed.");
if (re_pass_types.count(node_type) > 0) {
AddImmediateRePassNode(n); AddImmediateRePassNode(n);
(void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false); (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false);
GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str()); GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str());
@@ -126,17 +130,44 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) {
return SUCCESS; return SUCCESS;
}; };


if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) {
return RePassNode({MERGE, REFMERGE}); // Re-Pass Merge
const auto ExProcNode = [&](const std::set<std::string> &proc_types,
const std::function<void(InferShapePass *, NodePtr)> &proc_func,
const std::string &info) {
for (auto &n : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(n);
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "Get original node type failed.");
if (proc_types.count(node_type) > 0) {
proc_func(this, n);
GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str());
}
}
return SUCCESS;
};

std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original node type failed.");
if (kNextIterationOpTypes.count(node_type) > 0) {
return RePassNode(kMergeOpTypes); // Re-Pass Merge
} }


if (node->GetType() == MERGE || node->GetType() == REFMERGE) {
if (kMergeOpTypes.count(node_type) > 0) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
return RePassNode(kSwitchOpTypes); // Re-Pass Switch
} }
return SUCCESS; return SUCCESS;
} }


if (kSwitchOpTypes.count(node_type) > 0) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit
} else {
return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit
}
}

return SUCCESS; return SUCCESS;
} }
} // namespace ge } // namespace ge

+ 4
- 0
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -260,6 +260,10 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n
return SUCCESS; return SUCCESS;
} }


if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity.
node->GetOpDesc()->SetType(IDENTITY);
}

std::unique_ptr<NodeItem> new_node; std::unique_ptr<NodeItem> new_node;
GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName());
GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor));


Loading…
Cancel
Save