From 04f27bc560d15981ea596f67f37e3aa7b11d93d8 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Mon, 7 Jun 2021 12:36:08 +0800 Subject: [PATCH] Fix loop nesting --- ge/graph/common/omg_util.cc | 14 ++- ge/graph/common/omg_util.h | 8 ++ ge/graph/passes/next_iteration_pass.cc | 26 ++--- ge/graph/passes/next_iteration_pass.h | 1 - ge/hybrid/executor/node_state.cc | 109 ++++++++++++------ ge/hybrid/executor/node_state.h | 29 ++++- ge/hybrid/executor/subgraph_context.cc | 13 +++ ge/hybrid/executor/subgraph_context.h | 2 + ge/hybrid/model/hybrid_model_builder.cc | 56 +++++++++ ge/hybrid/model/hybrid_model_builder.h | 2 + ge/hybrid/model/node_item.cc | 11 +- ge/hybrid/model/node_item.h | 14 +++ ge/hybrid/node_executor/rts/rts_node_task.cc | 6 +- .../executor/subgraph_executor_unittest.cc | 6 +- .../model/hybrid_model_builder_unittest.cc | 35 +++++- .../rts/rts_node_task_unittest.cc | 4 + 16 files changed, 260 insertions(+), 76 deletions(-) diff --git a/ge/graph/common/omg_util.cc b/ge/graph/common/omg_util.cc index 598677bd..52e6cb9c 100644 --- a/ge/graph/common/omg_util.cc +++ b/ge/graph/common/omg_util.cc @@ -286,13 +286,23 @@ void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t grou return; } + SetControlFlowGroup(node, group_index); +} + +/// +/// @brief Set Op _control_flow_group flag +/// @param [in] node +/// @param [in] group, condition group index of node. +/// @return +/// +void SetControlFlowGroup(const NodePtr &node, int64_t group) { GE_RT_VOID_CHECK_NOTNULL(node); const auto &op_desc = node->GetOpDesc(); GE_RT_VOID_CHECK_NOTNULL(op_desc); // op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check. - GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group_index); - if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { + GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group); + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group)) { REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), node->GetName().c_str(), node->GetType().c_str()); GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), diff --git a/ge/graph/common/omg_util.h b/ge/graph/common/omg_util.h index 91fcd29e..148e4102 100644 --- a/ge/graph/common/omg_util.h +++ b/ge/graph/common/omg_util.h @@ -133,6 +133,14 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); /// @return /// void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); + +/// +/// @brief Set Op _control_flow_group flag +/// @param [in] node +/// @param [in] group, condition group index of node. +/// @return +/// +void SetControlFlowGroup(const NodePtr &node, int64_t group); } // namespace ge #endif // GE_GRAPH_COMMON_OMG_UTIL_H_ diff --git a/ge/graph/passes/next_iteration_pass.cc b/ge/graph/passes/next_iteration_pass.cc index 71b9e621..f7c8a290 100644 --- a/ge/graph/passes/next_iteration_pass.cc +++ b/ge/graph/passes/next_iteration_pass.cc @@ -183,12 +183,6 @@ bool NextIterationPass::VerifyWhileGroup() { frame_name.c_str()); return false; } - - // Mark loop as unknown shape If any merge has unknown shape output. - const auto &op_desc = pair_iter.first->GetOpDesc(); - if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) { - loop_group_iter.second->is_unknown_shape = true; // under check loop, cannot break. - } } } @@ -225,7 +219,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { enter_active->GetName().c_str()); return INTERNAL_ERROR; } - MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape, group_index); + SetControlFlowGroup(enter_node, group_index); } for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { @@ -255,8 +249,8 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { return INTERNAL_ERROR; } - MarkForceUnknownShape(next_node, loop_group.is_unknown_shape, group_index); - MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape, group_index); + SetControlFlowGroup(next_node, group_index); + SetControlFlowGroup(merge_node, group_index); } if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || @@ -265,9 +259,9 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { return INTERNAL_ERROR; } - MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape, group_index); - MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape, group_index); - MarkForceUnknownShape(next_active, loop_group.is_unknown_shape, group_index); + SetControlFlowGroup(loop_group.loop_cond, group_index); + SetControlFlowGroup(enter_active, group_index); + SetControlFlowGroup(next_active, group_index); HandleSwitchExitNodes(loop_group, group_index); } @@ -281,17 +275,13 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { /// @return void /// void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { - if (!loop_group.is_unknown_shape) { - return; - } - for (const auto &switch_node : loop_group.switch_nodes) { - MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape, group_index); + SetControlFlowGroup(switch_node, group_index); for (const auto &node : switch_node->GetOutDataNodes()) { std::string node_type; (void)GetOriginalType(node, node_type); if (kExitOpTypes.count(node_type) > 0) { - MarkForceUnknownShape(node, loop_group.is_unknown_shape, group_index); + SetControlFlowGroup(node, group_index); } } } diff --git a/ge/graph/passes/next_iteration_pass.h b/ge/graph/passes/next_iteration_pass.h index b6a0846d..2143719c 100755 --- a/ge/graph/passes/next_iteration_pass.h +++ b/ge/graph/passes/next_iteration_pass.h @@ -24,7 +24,6 @@ struct LoopCondGroup { std::vector enter_nodes; // Enter nodes std::vector> merge_next_pairs; // std::vector switch_nodes; // Switch nodes - bool is_unknown_shape{false}; }; using LoopCondGroupPtr = std::shared_ptr; diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index fd47cfb2..313a2934 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -22,6 +22,14 @@ #include "hybrid_execution_context.h" #include "subgraph_context.h" +#define INC_ITERATION_COUNT(iteration) \ +do { \ + ++iteration; \ + if (iteration == UINT64_MAX) { \ + iteration = 1; \ + } \ +} while (0) + namespace ge { namespace hybrid { namespace { @@ -306,15 +314,45 @@ std::shared_ptr NodeState::GetTaskContext() { return task_context_; } -void NodeState::ResetContext(uint64_t loop_count) { - loop_count_ = loop_count; - +void NodeState::ResetContext(uint64_t iteration) { switch_index_ = -1; subgraph_context_->ResetContext(node_item_->node); - data_scheduled_ = static_cast(node_item_->root_data_.size()); - ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size()); - GELOGD("[%s] in while loop, loop count: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", - GetName().c_str(), loop_count_, data_scheduled_, ctrl_scheduled_, merge_index_); + if (iteration == 0) { + data_scheduled_ = static_cast(node_item_->root_data_.size()); + ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size()); + } else { + data_scheduled_ = static_cast(node_item_->root_data_.size() + node_item_->enter_data_.size()); + ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size()); + } + + iteration_count_ = iteration; + GELOGD("[%s] in while loop, current iteration: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", + GetName().c_str(), iteration_count_, data_scheduled_, ctrl_scheduled_, merge_index_); +} + +void NodeState::ScheduleContext(const NodeState &node_state) { + if (node_state.node_item_->IsEnterOp()) { + GELOGD("[%s]{active: %lu, iteration: %lu}, frame{active: %lu, iteration: %lu} [%s]{active: %lu, iteration: %lu}", + GetName().c_str(), active_count_, iteration_count_, frame_state_->active_count_, + frame_state_->iteration_count_, node_state.GetName().c_str(), node_state.frame_state_->active_count_, + node_state.frame_state_->iteration_count_); + if (frame_state_->active_count_ != active_count_) { + ResetContext(0); + active_count_ = frame_state_->active_count_; + } + } else if (node_state.node_item_->IsExitOp()) { + GELOGD("[%s]{active: %lu, iteration: %lu} frame{active: %lu, iteration: %lu} " + "[%s]{active: %lu, iteration: %lu} parent{active: %lu, iteration: %lu}", + GetName().c_str(), active_count_, iteration_count_, frame_state_->active_count_, + frame_state_->iteration_count_, node_state.GetName().c_str(), node_state.frame_state_->active_count_, + node_state.frame_state_->iteration_count_, node_state.frame_state_->parent_frame_->active_count_, + node_state.frame_state_->parent_frame_->iteration_count_); + if (node_state.frame_state_->parent_frame_->iteration_count_ != iteration_count_) { + ResetContext(node_state.frame_state_->parent_frame_->iteration_count_); + } + } else if (node_state.iteration_count_ != iteration_count_) { + ResetContext(node_state.iteration_count_); + } } Status NodeState::NodeScheduled(const std::function &ready) const { @@ -346,11 +384,11 @@ Status NodeState::NodeScheduled(const std::function &rea } bool NodeState::IsScheduleReady() const { - GELOGD("[%s] loop[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", - GetName().c_str(), loop_count_, node_item_->data_recv_.size(), data_scheduled_, - node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); + GELOGD("[%s] iteration[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", + GetName().c_str(), iteration_count_, node_item_->data_recv_.size(), data_scheduled_, + node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); if (node_item_->IsMergeOp()) { - if (ctrl_scheduled_ != node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { + if (ctrl_scheduled_ != node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { return false; } @@ -366,15 +404,13 @@ bool NodeState::IsScheduleReady() const { } void NodeState::SetDataSchedule(const NodeState &node_state, const std::function &ready) { - GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", - node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_, + GELOGD("[%s] schedule [%s], iteration[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", + node_state.GetName().c_str(), GetName().c_str(), iteration_count_, node_state.iteration_count_, node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), - node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); + node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); std::lock_guard lk(mu_); - if (loop_count_ != node_state.loop_count_) { - ResetContext(node_state.loop_count_); - } + ScheduleContext(node_state); ++data_scheduled_; if (node_item_->IsMergeOp()) { @@ -394,15 +430,13 @@ void NodeState::SetDataSchedule(const NodeState &node_state, const std::function } void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function &ready) { - GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", - node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_, + GELOGD("[%s] schedule [%s], iteration[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", + node_state.GetName().c_str(), GetName().c_str(), iteration_count_, node_state.iteration_count_, node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), - node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); + node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); std::lock_guard lk(mu_); - if (loop_count_ != node_state.loop_count_) { - ResetContext(node_state.loop_count_); - } + ScheduleContext(node_state); ++ctrl_scheduled_; if (IsScheduleReady()) { @@ -410,21 +444,28 @@ void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function } } -void NodeState::RunLoopNext() { - GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); +void NodeState::RunNextIteration() { std::lock_guard lk(mu_); - ++loop_count_; - if (loop_count_ == UINT64_MAX) { - loop_count_ = 1; - } - - ResetContext(loop_count_); + INC_ITERATION_COUNT(iteration_count_); + ResetContext(iteration_count_); } -void NodeState::RunLoopExit() { - GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); +void NodeState::RunStreamActive() { std::lock_guard lk(mu_); - loop_count_ = 0; + if (node_item_->ctrl_send_.empty()) { // Not for Loop Enter or Loop Next. + return; + } + switch_index_ = 0; + data_scheduled_ = 0; + ctrl_scheduled_ = 0; + if (node_item_->is_enter_active_) { + frame_state_->iteration_count_ = 0; + INC_ITERATION_COUNT(frame_state_->active_count_); + } else { + INC_ITERATION_COUNT(frame_state_->iteration_count_); + } + GELOGD("Node[%s] current iteration: %lu, frame active: %lu, frame iteration: %lu", + GetName().c_str(), iteration_count_, frame_state_->active_count_, frame_state_->iteration_count_); } void NodeState::SetScheduleFuture(std::future &&future) { diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index e4afdb9f..9dd29846 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -33,8 +33,10 @@ struct GraphExecutionContext; class SubgraphContext; class TaskContext; struct NodeState; +struct FrameState; using NodeStatePtr = std::shared_ptr; +using FrameStatePtr = std::shared_ptr; class ShapeFuture { public: @@ -80,6 +82,18 @@ struct ShapeInferenceState { std::mutex mu_; }; +struct FrameState { + public: + FrameState(int64_t id) : frame_id_(id) {} + ~FrameState() = default; + + int64_t frame_id_{0}; + uint64_t active_count_{0}; + uint64_t iteration_count_{0}; + + std::shared_ptr parent_frame_; +}; + // saving sth. dynamic during execution struct NodeState { public: @@ -112,8 +126,8 @@ struct NodeState { return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; } - void RunLoopNext(); - void RunLoopExit(); + void RunStreamActive(); + void RunNextIteration(); Status NodeScheduled(const std::function &ready) const; @@ -144,6 +158,10 @@ struct NodeState { return group_; } + void SetFrameState(const shared_ptr &frame_state) { + frame_state_ = frame_state; + } + const shared_ptr &GetKernelTask() const { return kernel_task_; } @@ -167,7 +185,8 @@ struct NodeState { bool IsScheduleReady() const; void SetDataSchedule(const NodeState &node_state, const std::function &ready); void SetCtrlSchedule(const NodeState &node_state, const std::function &ready); - void ResetContext(uint64_t loop_count); + void ResetContext(uint64_t iteration); + void ScheduleContext(const NodeState &node_state); const NodeItem *node_item_ = nullptr; std::shared_ptr kernel_task_ = nullptr; @@ -179,7 +198,9 @@ struct NodeState { std::mutex mu_; std::future schedule_future_; - uint64_t loop_count_ = 0; + std::shared_ptr frame_state_; + uint64_t active_count_ = 0; + uint64_t iteration_count_ = 0; uint32_t ctrl_scheduled_ = 0; uint32_t data_scheduled_ = 0; int merge_index_ = -1; // Use for Execute (Reset after Executed). diff --git a/ge/hybrid/executor/subgraph_context.cc b/ge/hybrid/executor/subgraph_context.cc index afd8ca79..99ea10f7 100644 --- a/ge/hybrid/executor/subgraph_context.cc +++ b/ge/hybrid/executor/subgraph_context.cc @@ -89,6 +89,7 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { if (node_state == nullptr) { const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); node_state.reset(new(std::nothrow)NodeState(*node_item, this)); + node_state->SetFrameState(GetOrCreateFrameState(*node_item)); node_state->SetGroup(group_); (void)guard; } @@ -102,6 +103,18 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { return node_state; } +FrameStatePtr SubgraphContext::GetOrCreateFrameState(const NodeItem &node_item) { + auto &frame_state = frame_states_[node_item.frame_index_]; + if (frame_state == nullptr) { + frame_state.reset(new(std::nothrow)FrameState(node_item.frame_index_)); + if (node_item.frame_index_ != -1) { // -1 is root frame. + frame_state->parent_frame_ = frame_states_[node_item.parent_frame_]; + } + } + + return frame_state; +} + Status SubgraphContext::SetInput(int index, const TensorValue &tensor) { if (static_cast(index) >= all_inputs_.size()) { GELOGE(INTERNAL_ERROR, diff --git a/ge/hybrid/executor/subgraph_context.h b/ge/hybrid/executor/subgraph_context.h index 303382c1..a43cd210 100644 --- a/ge/hybrid/executor/subgraph_context.h +++ b/ge/hybrid/executor/subgraph_context.h @@ -51,6 +51,7 @@ class SubgraphContext { void NodeDone(const NodePtr &node); private: + FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock friend class TaskContext; const GraphItem *graph_item_; const GraphExecutionContext *execution_context_; @@ -59,6 +60,7 @@ class SubgraphContext { std::vector all_outputs_; NodeDoneManager node_done_manager_; std::unordered_map node_states_; + std::unordered_map frame_states_; int group_ = -1; }; } // namespace hybrid diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 906dddae..bbde3ffa 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -1984,6 +1984,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task + GE_CHK_STATUS_RET_NOLOG(BuildFrameGroupIndex(*node_item)); GE_CHK_STATUS_RET_NOLOG(BuildControlFlowGroup(*graph_item, node, node_item)); if (node->GetInAllNodes().empty()) { graph_item->root_items_.emplace_back(node_item); @@ -2347,6 +2348,60 @@ Status HybridModelBuilder::BuildProfilingControl(GraphItem &graph_item, return SUCCESS; } +Status HybridModelBuilder::BuildFrameGroupIndex(NodeItem &node_item) { + if (node_item.is_root_node_) { + GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", + node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); + return SUCCESS; + } + + int64_t ctrl_flow_group = -1; + if (node_item.IsEnterOp() && AttrUtils::GetInt(node_item.op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, ctrl_flow_group)) { + node_item.frame_index_ = ctrl_flow_group; + if (node_item.IsEnterOp()) { + const auto src_node = node_item.node->GetInDataNodes().at(0); + NodeItem *src_node_item = nullptr; + GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item), + "[%s] failed to get or create node item", src_node->GetName().c_str()); + if (!src_node_item->is_root_node_) { + parent_frame_group_[node_item.frame_index_] = src_node_item->frame_index_; + } + } + + const auto it = parent_frame_group_.find(node_item.frame_index_); + node_item.parent_frame_ = (it != parent_frame_group_.end()) ? it->second : -1; + GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", + node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); + return SUCCESS; + } + + for (const auto src_node : node_item.node->GetInAllNodes()) { + NodeItem *src_node_item = nullptr; + GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item), + "[%s] failed to get or create node item", src_node->GetName().c_str()); + if (src_node_item->is_root_node_) { + continue; + } + + if (src_node_item->IsExitOp()) { + const auto it = parent_frame_group_.find(src_node_item->frame_index_); + node_item.frame_index_ = (it != parent_frame_group_.end()) ? it->second : -1; + } else { + node_item.frame_index_ = src_node_item->frame_index_; + } + + const auto it = parent_frame_group_.find(src_node_item->frame_index_); + node_item.parent_frame_ = (it != parent_frame_group_.end()) ? it->second : -1; + GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", + node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); + return SUCCESS; + } + + GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", + node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); + return SUCCESS; +} + Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item) { GELOGD("Build control flow for node %s", node->GetName().c_str()); using GroupBuilder = std::function; @@ -2466,6 +2521,7 @@ Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { // Enter --> StreamActive --> StreamMerge + node_item->is_enter_active_ = true; return CreateMergeEnterGroup(node, node_item); } else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { // NextIteration --> StreamActive {-->} StreamMerge diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index d0ee54ed..92974441 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -97,6 +97,7 @@ class HybridModelBuilder { Status RelinkNextIteration(); Status BuildProfilingControl(GraphItem &graph_item, const std::map> &nodes); + Status BuildFrameGroupIndex(NodeItem &node_item); Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item); @@ -123,6 +124,7 @@ class HybridModelBuilder { std::map constant_op_nodes_; std::map stream_merge_op_nodes_; std::map next_iteration_op_nodes_; + std::map parent_frame_group_; std::map> parallel_group_to_nodes_; std::map> node_to_parallel_groups_; diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index 7054fd46..b339e630 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -20,7 +20,6 @@ #include "graph/common/omg_util.h" #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" -#include "graph/utils/node_utils.h" #include "hybrid/executor/worker/shape_inference_engine.h" #include "hybrid/node_executor/node_executor.h" @@ -34,7 +33,7 @@ const std::set kControlOpTypes{ }; const std::set kControlFlowOpTypes{ - STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, + STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX }; @@ -402,8 +401,8 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { node_item->root_data_.emplace(this); } // If Enter feed Not Merge, take as root Node. - if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) { - node_item->root_data_.emplace(this); + if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { + node_item->enter_data_.emplace(this); node_item->enter_inside_.emplace(anchor_index); } GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); @@ -422,8 +421,8 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { node_item->root_ctrl_.emplace(this); } // If Enter feed control signal, take as root Node. - if (kEnterOpTypes.count(node_type) > 0) { - node_item->root_ctrl_.emplace(this); + if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { + node_item->enter_ctrl_.emplace(this); } GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); } diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 67f92868..8de15952 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -22,6 +22,7 @@ #include "external/ge/ge_api_error_codes.h" #include "graph/node.h" #include "graph/op_desc.h" +#include "graph/utils/node_utils.h" #include "framework/common/types.h" #include "hybrid/common/tensor_value.h" @@ -92,6 +93,14 @@ struct NodeItem { return is_merge_op_; } + bool IsEnterOp() const { + return kEnterOpTypes.count(node_type) > 0; + } + + bool IsExitOp() const { + return kExitOpTypes.count(node_type) > 0; + } + bool IsHcclOp() const; void SetToDynamic(); @@ -135,8 +144,13 @@ struct NodeItem { bool is_ctrl_flow_v2_op_ = false; bool is_ctrl_flow_op_ = false; bool is_merge_op_ = false; + bool is_enter_active_ = false; + int64_t frame_index_ = -1; + int64_t parent_frame_ = -1; std::set root_ctrl_; // Recv ctrl from root node std::set root_data_; // Recv data from root node + std::set enter_ctrl_; // Recv ctrl from Enter node + std::set enter_data_; // Recv data from Enter node std::set data_send_; // Send data notify to std::map data_recv_; // Recv data notify from std::set ctrl_send_; // Send ctrl notify to diff --git a/ge/hybrid/node_executor/rts/rts_node_task.cc b/ge/hybrid/node_executor/rts/rts_node_task.cc index 5ad8eaf4..104196ee 100644 --- a/ge/hybrid/node_executor/rts/rts_node_task.cc +++ b/ge/hybrid/node_executor/rts/rts_node_task.cc @@ -90,7 +90,7 @@ Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t inde Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { GELOGD("[%s] Start to execute.", task_context.GetNodeName()); const auto &node_state = task_context.GetNodeState(); - node_state->SetSwitchIndex(0); + node_state->RunStreamActive(); if (done_callback) { GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); } @@ -204,9 +204,7 @@ Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::functio const auto &node_state = task_context.GetNodeState(); if (kNextIterationOpTypes.count(node_state->GetType()) > 0) { - node_state->RunLoopNext(); - } else if (kExitOpTypes.count(node_state->GetType()) > 0) { - node_state->RunLoopExit(); + node_state->RunNextIteration(); } if (done_callback) { diff --git a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc index d97629cf..2dc3b639 100644 --- a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc @@ -127,7 +127,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); } - const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); // Mock for less, just pass input0. + const auto less1 = CreateNode(graph, "less", IDENTITY, 2, 1); // Mock for less, just pass input0. const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); @@ -135,8 +135,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); - const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); // Mock for add, just pass input0. - const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); // Mock for sub, just pass input0. + const auto add1 = CreateNode(graph, "add", IDENTITY, 2, 1); // Mock for add, just pass input0. + const auto sub1 = CreateNode(graph, "sub", IDENTITY, 2, 1); // Mock for sub, just pass input0. const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); diff --git a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc index 85264082..2ab82350 100644 --- a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc +++ b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc @@ -89,7 +89,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { * \ / \. * Switch Add * / | | - * / | | + * Active / | | * / | | * LoopCond | | * \ | | @@ -98,9 +98,10 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { * Less | | * \ | NextIteration * \ | | - * \ | | + * \ | | Active * Merge <---------| * | + * | Active * | * Enter ******************************************************************************/ @@ -110,6 +111,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { GeModelPtr ge_sub_model = make_shared(); ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); + auto data1 = CreateNode(*graph, "data", DATA, 1, 1); auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); auto merge1 = CreateNode(*graph, "merge", STREAMMERGE, 2, 2); auto less1 = CreateNode(*graph, "less", LESS, 2, 1); @@ -129,6 +131,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); @@ -153,8 +156,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); - //GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); - SetNextIteration(merge1, next1); + SetNextIteration(merge1, next1); // for relink NextIteration --> StreamMerge GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); // Test for not merge. @@ -169,6 +171,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); + SetControlFlowGroup(enter1, loop1->GetOpDesc()->GetId()); + SetControlFlowGroup(active1, loop1->GetOpDesc()->GetId()); + SetControlFlowGroup(merge1, loop1->GetOpDesc()->GetId()); + SetControlFlowGroup(loop1, loop1->GetOpDesc()->GetId()); + SetControlFlowGroup(active2, switch_t->GetOpDesc()->GetId()); + SetControlFlowGroup(switch_t, switch_t->GetOpDesc()->GetId()); + SetControlFlowGroup(switch_f, switch_t->GetOpDesc()->GetId()); + SetControlFlowGroup(next1, loop1->GetOpDesc()->GetId()); + SetControlFlowGroup(active3, loop1->GetOpDesc()->GetId()); + SetControlFlowGroup(exit1, loop1->GetOpDesc()->GetId()); + // Build -> IndexSpecialNodes --> stream_merge_op_nodes_ // Build -> LoadGraph -> RelinkNextIteration // Build -> LoadGraph -> LoadDynamicSubgraph --> BuildNodeItem --> NodeItem::SetDataSend @@ -190,9 +203,23 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr(new NodeExecutor())); task_executor.emplace(NodeExecutorManager::ExecutorType::HOST_CPU, std::unique_ptr(new NodeExecutor())); + const auto control_group_index = loop1->GetOpDesc()->GetId(); HybridModel hybrid_model(ge_root_model); HybridModelBuilder hybrid_model_builder(hybrid_model); ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); + + const auto TestFrameGroup = [&hybrid_model](const NodePtr &n, int64_t index) { + const auto it = hybrid_model.node_items_.find(n); + ASSERT_NE(hybrid_model.node_items_.end(), it); + ASSERT_EQ(it->second->frame_index_, index); + ASSERT_EQ(it->second->parent_frame_, -1); + }; + TestFrameGroup(enter1, control_group_index); + TestFrameGroup(active1, control_group_index); + TestFrameGroup(active2, control_group_index); + TestFrameGroup(active3, control_group_index); + TestFrameGroup(output1, -1); + engine_mapping.clear(); task_executor.clear(); } diff --git a/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc b/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc index c4c2c65b..44b2f37f 100644 --- a/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc @@ -166,6 +166,10 @@ TEST_F(UtestRtsNodeTask, test_stream_active_task) { std::function done = []() {}; ASSERT_EQ(node_state->GetSwitchIndex(), -1); ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); + ASSERT_EQ(node_state->GetSwitchIndex(), -1); + + node_item->ctrl_send_.emplace(nullptr); + ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); ASSERT_EQ(node_state->GetSwitchIndex(), 0); }