Browse Source

Fix base_pass nodes_seen

tags/v1.3.0
zhangxiaokun 3 years ago
parent
commit
5b9f2b4094
3 changed files with 40 additions and 18 deletions
  1. +36
    -15
      ge/graph/passes/base_pass.cc
  2. +4
    -0
      ge/hybrid/executor/subgraph_executor.cc
  3. +0
    -3
      ge/hybrid/executor/worker/execution_engine.cc

+ 36
- 15
ge/graph/passes/base_pass.cc View File

@@ -56,19 +56,29 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i
} }
} }


void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
void AddNextIterNodes(const std::vector<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
DuringPassNodeSets &during_pass_node_set) { DuringPassNodeSets &during_pass_node_set) {
std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen;
const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last;
const std::unordered_set<NodePtr> &nodes_stopped = during_pass_node_set.nodes_stopped;
for (auto &node : nodes) { for (auto &node : nodes) {
if (node == nullptr) { if (node == nullptr) {
continue; continue;
} }
if (nodes_stopped.count(node) > 0) {
if (during_pass_node_set.nodes_stopped.count(node) > 0) {
GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str()); GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str());
continue; continue;
} }

nodes_to_pass.push_back(node);
}
}

void GetNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::vector<NodePtr> &nodes_to_pass,
DuringPassNodeSets &during_pass_node_set) {
std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen;
const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last;
for (auto &node : nodes) {
if (node == nullptr) {
continue;
}
if (nodes_last.count(node) != 0) { if (nodes_last.count(node) != 0) {
continue; continue;
} }
@@ -80,6 +90,20 @@ void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &n
} }
} }


void PushToStoppedNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name,
const std::unordered_set<NodePtr> &nodes_stopped,
const std::unordered_set<NodePtr> &nodes_restored) {
for (const auto &node : nodes_stopped) {
GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), pass_name.c_str());
during_pass_node_set.nodes_stopped.emplace(node);
}

for (const auto &node : nodes_restored) {
GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), pass_name.c_str());
during_pass_node_set.nodes_stopped.erase(node);
}
}

void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass, std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass,
std::unordered_set<NodePtr> &nodes_re_pass) { std::unordered_set<NodePtr> &nodes_re_pass) {
@@ -105,6 +129,8 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo
} }
GELOGD("Begin to run pass for node %s", node->GetName().c_str()); GELOGD("Begin to run pass for node %s", node->GetName().c_str());
for (const auto &name_to_pass : names_to_passes) { for (const auto &name_to_pass : names_to_passes) {
const std::string &pass_name = name_to_pass.first;
BaseNodePass *pass_node = name_to_pass.second;
if (name_to_pass.second == nullptr) { if (name_to_pass.second == nullptr) {
GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str());
continue; continue;
@@ -129,14 +155,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately,
during_pass_node_set.nodes_re_pass_immediately); during_pass_node_set.nodes_re_pass_immediately);


for (const auto &node : name_to_pass.second->GetNodesStopped()) {
GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), name_to_pass.first.c_str());
during_pass_node_set.nodes_stopped.emplace(node);
}
for (const auto &node : name_to_pass.second->GetNodesRestored()) {
GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), name_to_pass.first.c_str());
during_pass_node_set.nodes_stopped.erase(node);
}
PushToStoppedNodes(during_pass_node_set, pass_name, pass_node->GetNodesStopped(), pass_node->GetNodesRestored());


const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
@@ -239,7 +258,9 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
continue; continue;
} }


const auto all_out_nodes = node->GetOutNodes();
std::vector<NodePtr> nodes_to_pass;
GetNextIterNodes(node->GetOutAllNodes(), nodes_to_pass, during_pass_node_set);

auto ret = RunPasses(node, names_to_passes, during_pass_node_set); auto ret = RunPasses(node, names_to_passes, during_pass_node_set);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u",
@@ -275,7 +296,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
} }
during_pass_node_set.nodes_re_pass_immediately.clear(); during_pass_node_set.nodes_re_pass_immediately.clear();


AddNextIterNodes(all_out_nodes, nodes, during_pass_node_set);
AddNextIterNodes(nodes_to_pass, nodes, during_pass_node_set);
} }


for (auto &node : during_pass_node_set.nodes_last) { for (auto &node : during_pass_node_set.nodes_last) {


+ 4
- 0
ge/hybrid/executor/subgraph_executor.cc View File

@@ -295,6 +295,7 @@ Status SubgraphExecutor::PrepareNodes(int group) {
GE_CHK_STATUS_RET(PrepareNode(*node_item, group), "[%s] failed to prepare task.", node_item->NodeName().c_str()); GE_CHK_STATUS_RET(PrepareNode(*node_item, group), "[%s] failed to prepare task.", node_item->NodeName().c_str());
RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] End"); RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] End");
} }

GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str()); GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str());
return SUCCESS; return SUCCESS;
} }
@@ -348,6 +349,7 @@ Status SubgraphExecutor::NodeScheduled(NodeState *node_state) {
graph_item_->GetName().c_str(), node_state->GetName().c_str(), graph_item_->GetName().c_str(), node_state->GetName().c_str(),
node_state->GetNodeItem()->data_send_.size(), node_state->GetNodeItem()->ctrl_send_.size(), node_state->GetNodeItem()->data_send_.size(), node_state->GetNodeItem()->ctrl_send_.size(),
node_state->GetSwitchIndex(), node_state->GetMergeIndex()); node_state->GetSwitchIndex(), node_state->GetMergeIndex());

auto future = pre_run_pool_.commit([this, node_state]() -> Status { auto future = pre_run_pool_.commit([this, node_state]() -> Status {
RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] Start"); RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] Start");
std::function<void(const NodeItem *)> callback = [&](const NodeItem *node_item) { std::function<void(const NodeItem *)> callback = [&](const NodeItem *node_item) {
@@ -391,6 +393,7 @@ Status SubgraphExecutor::AfterPrepared(NodeState *node_state) {
if (node_state->IsShapeDependence()) { if (node_state->IsShapeDependence()) {
return SUCCESS; return SUCCESS;
} }

// Not control flow node, propagate state. // Not control flow node, propagate state.
return NodeScheduled(node_state); return NodeScheduled(node_state);
} }
@@ -399,6 +402,7 @@ void SubgraphExecutor::AfterExecuted(NodeState *node_state) {
if (!node_state->IsShapeDependence()) { if (!node_state->IsShapeDependence()) {
return; return;
} }

// For control flow node, propagate state. // For control flow node, propagate state.
auto error = NodeScheduled(node_state); auto error = NodeScheduled(node_state);
if (error != SUCCESS) { if (error != SUCCESS) {


+ 0
- 3
ge/hybrid/executor/worker/execution_engine.cc View File

@@ -16,9 +16,6 @@


#include "hybrid/executor/worker/execution_engine.h" #include "hybrid/executor/worker/execution_engine.h"
#include "graph/runtime_inference_context.h" #include "graph/runtime_inference_context.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/tensor_adapter.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/load/model_manager/model_manager.h" #include "graph/load/model_manager/model_manager.h"
#include "hybrid/node_executor/node_executor.h" #include "hybrid/node_executor/node_executor.h"
#include "hybrid/executor//worker//shape_inference_engine.h" #include "hybrid/executor//worker//shape_inference_engine.h"


Loading…
Cancel
Save