diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index 16b65f20..a2480d42 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -56,19 +56,29 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &i } } -void AddNextIterNodes(const Node::Vistor &nodes, std::deque &nodes_to_pass, +void AddNextIterNodes(const std::vector &nodes, std::deque &nodes_to_pass, DuringPassNodeSets &during_pass_node_set) { - std::unordered_set &nodes_seen = during_pass_node_set.nodes_seen; - const std::unordered_set &nodes_last = during_pass_node_set.nodes_last; - const std::unordered_set &nodes_stopped = during_pass_node_set.nodes_stopped; for (auto &node : nodes) { if (node == nullptr) { continue; } - if (nodes_stopped.count(node) > 0) { + if (during_pass_node_set.nodes_stopped.count(node) > 0) { GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str()); continue; } + + nodes_to_pass.push_back(node); + } +} + +void GetNextIterNodes(const Node::Vistor &nodes, std::vector &nodes_to_pass, + DuringPassNodeSets &during_pass_node_set) { + std::unordered_set &nodes_seen = during_pass_node_set.nodes_seen; + const std::unordered_set &nodes_last = during_pass_node_set.nodes_last; + for (auto &node : nodes) { + if (node == nullptr) { + continue; + } if (nodes_last.count(node) != 0) { continue; } @@ -80,6 +90,20 @@ void AddNextIterNodes(const Node::Vistor &nodes, std::deque &n } } +void PushToStoppedNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, + const std::unordered_set &nodes_stopped, + const std::unordered_set &nodes_restored) { + for (const auto &node : nodes_stopped) { + GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), pass_name.c_str()); + during_pass_node_set.nodes_stopped.emplace(node); + } + + for (const auto &node : nodes_restored) { + GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), pass_name.c_str()); + during_pass_node_set.nodes_stopped.erase(node); + } +} + void PushToRePassIfSeen(NodePtr &node, const std::pair &name_to_pass, std::unordered_set &nodes_seen, const std::unordered_set &nodes_to_re_pass, std::unordered_set &nodes_re_pass) { @@ -105,6 +129,8 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo } GELOGD("Begin to run pass for node %s", node->GetName().c_str()); for (const auto &name_to_pass : names_to_passes) { + const std::string &pass_name = name_to_pass.first; + BaseNodePass *pass_node = name_to_pass.second; if (name_to_pass.second == nullptr) { GELOGE(INTERNAL_ERROR, "There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); continue; @@ -129,14 +155,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, during_pass_node_set.nodes_re_pass_immediately); - for (const auto &node : name_to_pass.second->GetNodesStopped()) { - GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), name_to_pass.first.c_str()); - during_pass_node_set.nodes_stopped.emplace(node); - } - for (const auto &node : name_to_pass.second->GetNodesRestored()) { - GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), name_to_pass.first.c_str()); - during_pass_node_set.nodes_stopped.erase(node); - } + PushToStoppedNodes(during_pass_node_set, pass_name, pass_node->GetNodesStopped(), pass_node->GetNodesRestored()); const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); @@ -239,7 +258,9 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { continue; } - const auto all_out_nodes = node->GetOutNodes(); + std::vector nodes_to_pass; + GetNextIterNodes(node->GetOutAllNodes(), nodes_to_pass, during_pass_node_set); + auto ret = RunPasses(node, names_to_passes, during_pass_node_set); if (ret != SUCCESS) { GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", @@ -275,7 +296,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { } during_pass_node_set.nodes_re_pass_immediately.clear(); - AddNextIterNodes(all_out_nodes, nodes, during_pass_node_set); + AddNextIterNodes(nodes_to_pass, nodes, during_pass_node_set); } for (auto &node : during_pass_node_set.nodes_last) { diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index 8685498f..c9c6f768 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -295,6 +295,7 @@ Status SubgraphExecutor::PrepareNodes(int group) { GE_CHK_STATUS_RET(PrepareNode(*node_item, group), "[%s] failed to prepare task.", node_item->NodeName().c_str()); RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] End"); } + GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str()); return SUCCESS; } @@ -348,6 +349,7 @@ Status SubgraphExecutor::NodeScheduled(NodeState *node_state) { graph_item_->GetName().c_str(), node_state->GetName().c_str(), node_state->GetNodeItem()->data_send_.size(), node_state->GetNodeItem()->ctrl_send_.size(), node_state->GetSwitchIndex(), node_state->GetMergeIndex()); + auto future = pre_run_pool_.commit([this, node_state]() -> Status { RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] Start"); std::function callback = [&](const NodeItem *node_item) { @@ -391,6 +393,7 @@ Status SubgraphExecutor::AfterPrepared(NodeState *node_state) { if (node_state->IsShapeDependence()) { return SUCCESS; } + // Not control flow node, propagate state. return NodeScheduled(node_state); } @@ -399,6 +402,7 @@ void SubgraphExecutor::AfterExecuted(NodeState *node_state) { if (!node_state->IsShapeDependence()) { return; } + // For control flow node, propagate state. auto error = NodeScheduled(node_state); if (error != SUCCESS) { diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc index 486e0a83..4d77d0f0 100755 --- a/ge/hybrid/executor/worker/execution_engine.cc +++ b/ge/hybrid/executor/worker/execution_engine.cc @@ -16,9 +16,6 @@ #include "hybrid/executor/worker/execution_engine.h" #include "graph/runtime_inference_context.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/tensor_adapter.h" -#include "graph/debug/ge_attr_define.h" #include "graph/load/model_manager/model_manager.h" #include "hybrid/node_executor/node_executor.h" #include "hybrid/executor//worker//shape_inference_engine.h"