Browse Source

rollback base_pass

tags/v1.3.0
zhangxiaokun 3 years ago
parent
commit
503dc06074
3 changed files with 13 additions and 103 deletions
  1. +7
    -46
      ge/graph/passes/base_pass.cc
  2. +6
    -34
      ge/graph/passes/base_pass.h
  3. +0
    -23
      ge/graph/passes/infershape_pass.cc

+ 7
- 46
ge/graph/passes/base_pass.cc View File

@@ -36,7 +36,6 @@ struct DuringPassNodeSets {
std::unordered_set<NodePtr> nodes_re_pass; std::unordered_set<NodePtr> nodes_re_pass;
std::unordered_set<NodePtr> nodes_re_pass_immediately; std::unordered_set<NodePtr> nodes_re_pass_immediately;
std::unordered_set<NodePtr> nodes_last; std::unordered_set<NodePtr> nodes_last;
std::unordered_set<NodePtr> nodes_stopped;
}; };


void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes,
@@ -56,25 +55,8 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i
} }
} }


void AddNextIterNodes(const std::vector<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
DuringPassNodeSets &during_pass_node_set) {
for (auto &node : nodes) {
if (node == nullptr) {
continue;
}
if (during_pass_node_set.nodes_stopped.count(node) > 0) {
GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str());
continue;
}

nodes_to_pass.push_back(node);
}
}

void GetNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::vector<NodePtr> &nodes_to_pass,
DuringPassNodeSets &during_pass_node_set) {
std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen;
const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last;
void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) {
for (auto &node : nodes) { for (auto &node : nodes) {
if (node == nullptr) { if (node == nullptr) {
continue; continue;
@@ -90,22 +72,8 @@ void GetNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::vector<NodePtr> &
} }
} }


void PushToStoppedNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name,
const std::unordered_set<NodePtr> &nodes_stopped,
const std::unordered_set<NodePtr> &nodes_restored) {
for (const auto &node : nodes_stopped) {
GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), pass_name.c_str());
during_pass_node_set.nodes_stopped.emplace(node);
}

for (const auto &node : nodes_restored) {
GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), pass_name.c_str());
during_pass_node_set.nodes_stopped.erase(node);
}
}

void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass,
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_to_re_pass,
std::unordered_set<NodePtr> &nodes_re_pass) { std::unordered_set<NodePtr> &nodes_re_pass) {
for (const auto &node_to_re_pass : nodes_to_re_pass) { for (const auto &node_to_re_pass : nodes_to_re_pass) {
if (node_to_re_pass == nullptr) { if (node_to_re_pass == nullptr) {
@@ -129,8 +97,6 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo
} }
GELOGD("Begin to run pass for node %s", node->GetName().c_str()); GELOGD("Begin to run pass for node %s", node->GetName().c_str());
for (const auto &name_to_pass : names_to_passes) { for (const auto &name_to_pass : names_to_passes) {
const std::string &pass_name = name_to_pass.first;
BaseNodePass *pass_node = name_to_pass.second;
if (name_to_pass.second == nullptr) { if (name_to_pass.second == nullptr) {
GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str());
continue; continue;
@@ -147,17 +113,15 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo
return result; return result;
} }


const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass();
auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass,
during_pass_node_set.nodes_re_pass); during_pass_node_set.nodes_re_pass);


const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately();
auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately,
during_pass_node_set.nodes_re_pass_immediately); during_pass_node_set.nodes_re_pass_immediately);


PushToStoppedNodes(during_pass_node_set, pass_name, pass_node->GetNodesStopped(), pass_node->GetNodesRestored());

const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
if (nodes_deleted_by_pass.count(node) > 0) { if (nodes_deleted_by_pass.count(node) > 0) {
GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
@@ -258,8 +222,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
continue; continue;
} }


std::vector<NodePtr> nodes_to_pass;
GetNextIterNodes(node->GetOutNodes(), nodes_to_pass, during_pass_node_set);
AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last);


auto ret = RunPasses(node, names_to_passes, during_pass_node_set); auto ret = RunPasses(node, names_to_passes, during_pass_node_set);
if (ret != SUCCESS) { if (ret != SUCCESS) {
@@ -295,8 +258,6 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
nodes.push_front(node); nodes.push_front(node);
} }
during_pass_node_set.nodes_re_pass_immediately.clear(); during_pass_node_set.nodes_re_pass_immediately.clear();

AddNextIterNodes(nodes_to_pass, nodes, during_pass_node_set);
} }


for (auto &node : during_pass_node_set.nodes_last) { for (auto &node : during_pass_node_set.nodes_last) {


+ 6
- 34
ge/graph/passes/base_pass.h View File

@@ -51,15 +51,11 @@ class BaseNodePass {


virtual ~BaseNodePass() = default; virtual ~BaseNodePass() = default;


const std::unordered_set<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; }
std::unordered_set<NodePtr> GetNodesNeedRePass() { return nodes_need_re_pass_; }


const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; }
std::unordered_set<NodePtr> GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; }


const std::unordered_set<NodePtr> &GetNodesDeleted() { return nodes_deleted_; }

const std::unordered_set<NodePtr> &GetNodesStopped() { return nodes_stopped_; }

const std::unordered_set<NodePtr> &GetNodesRestored() { return nodes_restored_; }
std::unordered_set<NodePtr> GetNodesDeleted() { return nodes_deleted_; }


void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; }


@@ -69,8 +65,6 @@ class BaseNodePass {
nodes_need_re_pass_.clear(); nodes_need_re_pass_.clear();
nodes_deleted_.clear(); nodes_deleted_.clear();
nodes_need_re_pass_immediately_.clear(); nodes_need_re_pass_immediately_.clear();
nodes_stopped_.clear();
nodes_restored_.clear();
} }


protected: protected:
@@ -86,7 +80,7 @@ class BaseNodePass {
/// optimized by other passes, call this function. /// optimized by other passes, call this function.
/// @param node /// @param node
/// ///
void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); }
void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); }


/// ///
/// Add a node to be optimized immediately again. If you add a new node to the graph, or /// Add a node to be optimized immediately again. If you add a new node to the graph, or
@@ -94,13 +88,13 @@ class BaseNodePass {
/// optimized by other passes, call this function. /// optimized by other passes, call this function.
/// @param node /// @param node
/// ///
void AddImmediateRePassNode(const NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); }
void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); }


/// ///
/// Add a node and it's input/output data nodes to be optimized again. /// Add a node and it's input/output data nodes to be optimized again.
/// @param node /// @param node
/// ///
void AddRePassNodesWithInOut(const NodePtr &node) {
void AddRePassNodesWithInOut(NodePtr &node) {
AddRePassNode(node); AddRePassNode(node);
auto out_nodes = node->GetOutNodes(); auto out_nodes = node->GetOutNodes();
for (auto &out_node : out_nodes) { for (auto &out_node : out_nodes) {
@@ -122,34 +116,12 @@ class BaseNodePass {
/// ///
void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); }


///
/// If you stop a node from the graph, especially following node. The remain
/// iterate passes will stop process on the stopped node(if it can be
/// reached by edge connections) till the last one. Obviously it is a waste of
/// time. You can add the stopped nodes by calling this function, to stop the
/// next iterations.
/// @param node
///
void AddNodeStopped(const NodePtr &node) { nodes_stopped_.insert(node); }

///
/// If you restore a node from the graph, especially following node. The remain
/// iterate passes will continue process on the stopped node(if it can be
/// reached by edge connections) till the last one.
/// You can add the restored nodes by calling this function, to restore the
/// next iterations.
/// @param node
///
void AddNodeRestored(const NodePtr &node) { nodes_restored_.insert(node); }

bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } bool OptionExists(NodePassOption option) { return options_.count(option) > 0; }


private: private:
std::unordered_set<NodePtr> nodes_need_re_pass_; std::unordered_set<NodePtr> nodes_need_re_pass_;
std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; std::unordered_set<NodePtr> nodes_need_re_pass_immediately_;
std::unordered_set<NodePtr> nodes_deleted_; std::unordered_set<NodePtr> nodes_deleted_;
std::unordered_set<NodePtr> nodes_stopped_;
std::unordered_set<NodePtr> nodes_restored_;
std::map<NodePassOption, std::string> options_; std::map<NodePassOption, std::string> options_;
}; };




+ 0
- 23
ge/graph/passes/infershape_pass.cc View File

@@ -126,19 +126,6 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) {
return SUCCESS; return SUCCESS;
}; };


const auto ExProcNode = [&](const std::set<std::string> &proc_types,
const std::function<void(InferShapePass *, NodePtr)> &proc_func,
const std::string &info) {
for (auto &n : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(n);
if (proc_types.count(n->GetType()) > 0) {
proc_func(this, n);
GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str());
}
}
return SUCCESS;
};

if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) { if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) {
return RePassNode({MERGE, REFMERGE}); // Re-Pass Merge return RePassNode({MERGE, REFMERGE}); // Re-Pass Merge
} }
@@ -146,20 +133,10 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) {
if (node->GetType() == MERGE || node->GetType() == REFMERGE) { if (node->GetType() == MERGE || node->GetType() == REFMERGE) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
return RePassNode({SWITCH, REFSWITCH}); // Re-Pass Switch
} }
return SUCCESS; return SUCCESS;
} }


if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeRestored, "need restore"); // Restore Exit
} else {
return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeStopped, "need stop"); // Stop Exit
}
}

return SUCCESS; return SUCCESS;
} }
} // namespace ge } // namespace ge

Loading…
Cancel
Save