Browse Source

Fix base_pass nodes_seen

tags/v1.3.0
zhangxiaokun 3 years ago
parent
commit
5b9f2b4094
3 changed files with 40 additions and 18 deletions
  1. +36
    -15
      ge/graph/passes/base_pass.cc
  2. +4
    -0
      ge/hybrid/executor/subgraph_executor.cc
  3. +0
    -3
      ge/hybrid/executor/worker/execution_engine.cc

+ 36
- 15
ge/graph/passes/base_pass.cc View File

@@ -56,19 +56,29 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i
}
}

void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
void AddNextIterNodes(const std::vector<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
DuringPassNodeSets &during_pass_node_set) {
std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen;
const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last;
const std::unordered_set<NodePtr> &nodes_stopped = during_pass_node_set.nodes_stopped;
for (auto &node : nodes) {
if (node == nullptr) {
continue;
}
if (nodes_stopped.count(node) > 0) {
if (during_pass_node_set.nodes_stopped.count(node) > 0) {
GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str());
continue;
}

nodes_to_pass.push_back(node);
}
}

void GetNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::vector<NodePtr> &nodes_to_pass,
DuringPassNodeSets &during_pass_node_set) {
std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen;
const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last;
for (auto &node : nodes) {
if (node == nullptr) {
continue;
}
if (nodes_last.count(node) != 0) {
continue;
}
@@ -80,6 +90,20 @@ void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &n
}
}

// Folds one pass's stop/restore results into the shared stopped-node set.
//
// during_pass_node_set: bookkeeping whose `nodes_stopped` member is updated in place.
// pass_name:            name of the pass, used only in the debug log lines.
// nodes_stopped:        nodes the pass asked to stop; each one is added to the set.
// nodes_restored:       nodes the pass asked to restore; each one is removed from the set.
//
// Stops are applied before restores, so a node present in both inputs ends up
// not stopped.
void PushToStoppedNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name,
                        const std::unordered_set<NodePtr> &nodes_stopped,
                        const std::unordered_set<NodePtr> &nodes_restored) {
  auto &stopped_set = during_pass_node_set.nodes_stopped;
  const char *const pass_label = pass_name.c_str();

  for (const auto &stopped : nodes_stopped) {
    GELOGD("The node %s was stopped by pass %s", stopped->GetName().c_str(), pass_label);
    stopped_set.emplace(stopped);
  }

  for (const auto &restored : nodes_restored) {
    GELOGD("The node %s was restored by pass %s", restored->GetName().c_str(), pass_label);
    stopped_set.erase(restored);
  }
}

void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass,
std::unordered_set<NodePtr> &nodes_re_pass) {
@@ -105,6 +129,8 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo
}
GELOGD("Begin to run pass for node %s", node->GetName().c_str());
for (const auto &name_to_pass : names_to_passes) {
const std::string &pass_name = name_to_pass.first;
BaseNodePass *pass_node = name_to_pass.second;
if (name_to_pass.second == nullptr) {
GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str());
continue;
@@ -129,14 +155,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately,
during_pass_node_set.nodes_re_pass_immediately);

for (const auto &node : name_to_pass.second->GetNodesStopped()) {
GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), name_to_pass.first.c_str());
during_pass_node_set.nodes_stopped.emplace(node);
}
for (const auto &node : name_to_pass.second->GetNodesRestored()) {
GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), name_to_pass.first.c_str());
during_pass_node_set.nodes_stopped.erase(node);
}
PushToStoppedNodes(during_pass_node_set, pass_name, pass_node->GetNodesStopped(), pass_node->GetNodesRestored());

const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
@@ -239,7 +258,9 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
continue;
}

const auto all_out_nodes = node->GetOutNodes();
std::vector<NodePtr> nodes_to_pass;
GetNextIterNodes(node->GetOutAllNodes(), nodes_to_pass, during_pass_node_set);

auto ret = RunPasses(node, names_to_passes, during_pass_node_set);
if (ret != SUCCESS) {
GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u",
@@ -275,7 +296,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
}
during_pass_node_set.nodes_re_pass_immediately.clear();

AddNextIterNodes(all_out_nodes, nodes, during_pass_node_set);
AddNextIterNodes(nodes_to_pass, nodes, during_pass_node_set);
}

for (auto &node : during_pass_node_set.nodes_last) {


+ 4
- 0
ge/hybrid/executor/subgraph_executor.cc View File

@@ -295,6 +295,7 @@ Status SubgraphExecutor::PrepareNodes(int group) {
GE_CHK_STATUS_RET(PrepareNode(*node_item, group), "[%s] failed to prepare task.", node_item->NodeName().c_str());
RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] End");
}

GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str());
return SUCCESS;
}
@@ -348,6 +349,7 @@ Status SubgraphExecutor::NodeScheduled(NodeState *node_state) {
graph_item_->GetName().c_str(), node_state->GetName().c_str(),
node_state->GetNodeItem()->data_send_.size(), node_state->GetNodeItem()->ctrl_send_.size(),
node_state->GetSwitchIndex(), node_state->GetMergeIndex());

auto future = pre_run_pool_.commit([this, node_state]() -> Status {
RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] Start");
std::function<void(const NodeItem *)> callback = [&](const NodeItem *node_item) {
@@ -391,6 +393,7 @@ Status SubgraphExecutor::AfterPrepared(NodeState *node_state) {
if (node_state->IsShapeDependence()) {
return SUCCESS;
}

// Not control flow node, propagate state.
return NodeScheduled(node_state);
}
@@ -399,6 +402,7 @@ void SubgraphExecutor::AfterExecuted(NodeState *node_state) {
if (!node_state->IsShapeDependence()) {
return;
}

// For control flow node, propagate state.
auto error = NodeScheduled(node_state);
if (error != SUCCESS) {


+ 0
- 3
ge/hybrid/executor/worker/execution_engine.cc View File

@@ -16,9 +16,6 @@

#include "hybrid/executor/worker/execution_engine.h"
#include "graph/runtime_inference_context.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/tensor_adapter.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/load/model_manager/model_manager.h"
#include "hybrid/node_executor/node_executor.h"
#include "hybrid/executor//worker//shape_inference_engine.h"


Loading…
Cancel
Save