From: @zhangxiaokun9 Reviewed-by: @xchu42,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
| @@ -286,13 +286,23 @@ void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t grou | |||
| return; | |||
| } | |||
| SetControlFlowGroup(node, group_index); | |||
| } | |||
| /// | |||
| /// @brief Set Op _control_flow_group flag | |||
| /// @param [in] node | |||
| /// @param [in] group, condition group index of node. | |||
| /// @return | |||
| /// | |||
| void SetControlFlowGroup(const NodePtr &node, int64_t group) { | |||
| GE_RT_VOID_CHECK_NOTNULL(node); | |||
| const auto &op_desc = node->GetOpDesc(); | |||
| GE_RT_VOID_CHECK_NOTNULL(op_desc); | |||
| // op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check. | |||
| GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group_index); | |||
| if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||
| GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group); | |||
| if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group)) { | |||
| REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | |||
| node->GetName().c_str(), node->GetType().c_str()); | |||
| GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | |||
| @@ -133,6 +133,14 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | |||
| /// @return | |||
| /// | |||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | |||
| /// | |||
| /// @brief Set Op _control_flow_group flag | |||
| /// @param [in] node | |||
| /// @param [in] group, condition group index of node. | |||
| /// @return | |||
| /// | |||
| void SetControlFlowGroup(const NodePtr &node, int64_t group); | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_COMMON_OMG_UTIL_H_ | |||
| @@ -186,12 +186,6 @@ bool NextIterationPass::VerifyWhileGroup() { | |||
| frame_name.c_str()); | |||
| return false; | |||
| } | |||
| // Mark loop as unknown shape If any merge has unknown shape output. | |||
| const auto &op_desc = pair_iter.first->GetOpDesc(); | |||
| if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) { | |||
| loop_group_iter.second->is_unknown_shape = true; // under check loop, cannot break. | |||
| } | |||
| } | |||
| } | |||
| @@ -229,7 +223,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
| enter_active->GetName().c_str(), enter_active->GetType().c_str()); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape, group_index); | |||
| SetControlFlowGroup(enter_node, group_index); | |||
| } | |||
| for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | |||
| @@ -264,8 +258,8 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
| return INTERNAL_ERROR; | |||
| } | |||
| MarkForceUnknownShape(next_node, loop_group.is_unknown_shape, group_index); | |||
| MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape, group_index); | |||
| SetControlFlowGroup(next_node, group_index); | |||
| SetControlFlowGroup(merge_node, group_index); | |||
| } | |||
| if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | |||
| @@ -274,9 +268,9 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
| return INTERNAL_ERROR; | |||
| } | |||
| MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape, group_index); | |||
| MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape, group_index); | |||
| MarkForceUnknownShape(next_active, loop_group.is_unknown_shape, group_index); | |||
| SetControlFlowGroup(loop_group.loop_cond, group_index); | |||
| SetControlFlowGroup(enter_active, group_index); | |||
| SetControlFlowGroup(next_active, group_index); | |||
| HandleSwitchExitNodes(loop_group, group_index); | |||
| } | |||
| @@ -290,17 +284,13 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
| /// @return void | |||
| /// | |||
| void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | |||
| if (!loop_group.is_unknown_shape) { | |||
| return; | |||
| } | |||
| for (const auto &switch_node : loop_group.switch_nodes) { | |||
| MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape, group_index); | |||
| SetControlFlowGroup(switch_node, group_index); | |||
| for (const auto &node : switch_node->GetOutDataNodes()) { | |||
| std::string node_type; | |||
| (void)GetOriginalType(node, node_type); | |||
| if (kExitOpTypes.count(node_type) > 0) { | |||
| MarkForceUnknownShape(node, loop_group.is_unknown_shape, group_index); | |||
| SetControlFlowGroup(node, group_index); | |||
| } | |||
| } | |||
| } | |||
| @@ -24,7 +24,6 @@ struct LoopCondGroup { | |||
| std::vector<ge::NodePtr> enter_nodes; // Enter nodes | |||
| std::vector<std::pair<ge::NodePtr, ge::NodePtr>> merge_next_pairs; // <Merge, NextIteration> | |||
| std::vector<ge::NodePtr> switch_nodes; // Switch nodes | |||
| bool is_unknown_shape{false}; | |||
| }; | |||
| using LoopCondGroupPtr = std::shared_ptr<LoopCondGroup>; | |||
| @@ -22,6 +22,14 @@ | |||
| #include "hybrid_execution_context.h" | |||
| #include "subgraph_context.h" | |||
| #define INC_ITERATION_COUNT(iteration) \ | |||
| do { \ | |||
| ++iteration; \ | |||
| if (iteration == UINT64_MAX) { \ | |||
| iteration = 1; \ | |||
| } \ | |||
| } while (0) | |||
| namespace ge { | |||
| namespace hybrid { | |||
| namespace { | |||
| @@ -306,15 +314,45 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||
| return task_context_; | |||
| } | |||
| void NodeState::ResetContext(uint64_t loop_count) { | |||
| loop_count_ = loop_count; | |||
| void NodeState::ResetContext(uint64_t iteration) { | |||
| switch_index_ = -1; | |||
| subgraph_context_->ResetContext(node_item_->node); | |||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
| GELOGD("[%s] in while loop, loop count: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", | |||
| GetName().c_str(), loop_count_, data_scheduled_, ctrl_scheduled_, merge_index_); | |||
| if (iteration == 0) { | |||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
| } else { | |||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size()); | |||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size()); | |||
| } | |||
| iteration_count_ = iteration; | |||
| GELOGD("[%s] in while loop, current iteration: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", | |||
| GetName().c_str(), iteration_count_, data_scheduled_, ctrl_scheduled_, merge_index_); | |||
| } | |||
| void NodeState::ScheduleContext(const NodeState &node_state) { | |||
| if (node_state.node_item_->IsEnterOp()) { | |||
| GELOGD("[%s]{active: %lu, iteration: %lu}, frame{active: %lu, iteration: %lu} [%s]{active: %lu, iteration: %lu}", | |||
| GetName().c_str(), active_count_, iteration_count_, frame_state_->active_count_, | |||
| frame_state_->iteration_count_, node_state.GetName().c_str(), node_state.frame_state_->active_count_, | |||
| node_state.frame_state_->iteration_count_); | |||
| if (frame_state_->active_count_ != active_count_) { | |||
| ResetContext(0); | |||
| active_count_ = frame_state_->active_count_; | |||
| } | |||
| } else if (node_state.node_item_->IsExitOp()) { | |||
| GELOGD("[%s]{active: %lu, iteration: %lu} frame{active: %lu, iteration: %lu} " | |||
| "[%s]{active: %lu, iteration: %lu} parent{active: %lu, iteration: %lu}", | |||
| GetName().c_str(), active_count_, iteration_count_, frame_state_->active_count_, | |||
| frame_state_->iteration_count_, node_state.GetName().c_str(), node_state.frame_state_->active_count_, | |||
| node_state.frame_state_->iteration_count_, node_state.frame_state_->parent_frame_->active_count_, | |||
| node_state.frame_state_->parent_frame_->iteration_count_); | |||
| if (node_state.frame_state_->parent_frame_->iteration_count_ != iteration_count_) { | |||
| ResetContext(node_state.frame_state_->parent_frame_->iteration_count_); | |||
| } | |||
| } else if (node_state.iteration_count_ != iteration_count_) { | |||
| ResetContext(node_state.iteration_count_); | |||
| } | |||
| } | |||
| Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { | |||
| @@ -346,11 +384,11 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||
| } | |||
| bool NodeState::IsScheduleReady() const { | |||
| GELOGD("[%s] loop[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", | |||
| GetName().c_str(), loop_count_, node_item_->data_recv_.size(), data_scheduled_, | |||
| node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
| GELOGD("[%s] iteration[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", | |||
| GetName().c_str(), iteration_count_, node_item_->data_recv_.size(), data_scheduled_, | |||
| node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
| if (node_item_->IsMergeOp()) { | |||
| if (ctrl_scheduled_ != node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { | |||
| if (ctrl_scheduled_ != node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { | |||
| return false; | |||
| } | |||
| @@ -366,15 +404,13 @@ bool NodeState::IsScheduleReady() const { | |||
| } | |||
| void NodeState::SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | |||
| GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||
| node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_, | |||
| GELOGD("[%s] schedule [%s], iteration[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||
| node_state.GetName().c_str(), GetName().c_str(), iteration_count_, node_state.iteration_count_, | |||
| node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), | |||
| node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
| node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
| std::lock_guard<std::mutex> lk(mu_); | |||
| if (loop_count_ != node_state.loop_count_) { | |||
| ResetContext(node_state.loop_count_); | |||
| } | |||
| ScheduleContext(node_state); | |||
| ++data_scheduled_; | |||
| if (node_item_->IsMergeOp()) { | |||
| @@ -394,15 +430,13 @@ void NodeState::SetDataSchedule(const NodeState &node_state, const std::function | |||
| } | |||
| void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | |||
| GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||
| node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_, | |||
| GELOGD("[%s] schedule [%s], iteration[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||
| node_state.GetName().c_str(), GetName().c_str(), iteration_count_, node_state.iteration_count_, | |||
| node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), | |||
| node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
| node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
| std::lock_guard<std::mutex> lk(mu_); | |||
| if (loop_count_ != node_state.loop_count_) { | |||
| ResetContext(node_state.loop_count_); | |||
| } | |||
| ScheduleContext(node_state); | |||
| ++ctrl_scheduled_; | |||
| if (IsScheduleReady()) { | |||
| @@ -410,21 +444,28 @@ void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function | |||
| } | |||
| } | |||
| void NodeState::RunLoopNext() { | |||
| GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||
| void NodeState::RunNextIteration() { | |||
| std::lock_guard<std::mutex> lk(mu_); | |||
| ++loop_count_; | |||
| if (loop_count_ == UINT64_MAX) { | |||
| loop_count_ = 1; | |||
| } | |||
| ResetContext(loop_count_); | |||
| INC_ITERATION_COUNT(iteration_count_); | |||
| ResetContext(iteration_count_); | |||
| } | |||
| void NodeState::RunLoopExit() { | |||
| GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||
| void NodeState::RunStreamActive() { | |||
| std::lock_guard<std::mutex> lk(mu_); | |||
| loop_count_ = 0; | |||
| if (node_item_->ctrl_send_.empty()) { // Not for Loop Enter or Loop Next. | |||
| return; | |||
| } | |||
| switch_index_ = 0; | |||
| data_scheduled_ = 0; | |||
| ctrl_scheduled_ = 0; | |||
| if (node_item_->is_enter_active_) { | |||
| frame_state_->iteration_count_ = 0; | |||
| INC_ITERATION_COUNT(frame_state_->active_count_); | |||
| } else { | |||
| INC_ITERATION_COUNT(frame_state_->iteration_count_); | |||
| } | |||
| GELOGD("Node[%s] current iteration: %lu, frame active: %lu, frame iteration: %lu", | |||
| GetName().c_str(), iteration_count_, frame_state_->active_count_, frame_state_->iteration_count_); | |||
| } | |||
| void NodeState::SetScheduleFuture(std::future<Status> &&future) { | |||
| @@ -33,8 +33,10 @@ struct GraphExecutionContext; | |||
| class SubgraphContext; | |||
| class TaskContext; | |||
| struct NodeState; | |||
| struct FrameState; | |||
| using NodeStatePtr = std::shared_ptr<NodeState>; | |||
| using FrameStatePtr = std::shared_ptr<FrameState>; | |||
| class ShapeFuture { | |||
| public: | |||
| @@ -80,6 +82,18 @@ struct ShapeInferenceState { | |||
| std::mutex mu_; | |||
| }; | |||
| struct FrameState { | |||
| public: | |||
| FrameState(int64_t id) : frame_id_(id) {} | |||
| ~FrameState() = default; | |||
| int64_t frame_id_{0}; | |||
| uint64_t active_count_{0}; | |||
| uint64_t iteration_count_{0}; | |||
| std::shared_ptr<FrameState> parent_frame_; | |||
| }; | |||
| // saving sth. dynamic during execution | |||
| struct NodeState { | |||
| public: | |||
| @@ -112,8 +126,8 @@ struct NodeState { | |||
| return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; | |||
| } | |||
| void RunLoopNext(); | |||
| void RunLoopExit(); | |||
| void RunStreamActive(); | |||
| void RunNextIteration(); | |||
| Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | |||
| @@ -144,6 +158,10 @@ struct NodeState { | |||
| return group_; | |||
| } | |||
| void SetFrameState(const shared_ptr<FrameState> &frame_state) { | |||
| frame_state_ = frame_state; | |||
| } | |||
| const shared_ptr<NodeTask> &GetKernelTask() const { | |||
| return kernel_task_; | |||
| } | |||
| @@ -167,7 +185,8 @@ struct NodeState { | |||
| bool IsScheduleReady() const; | |||
| void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||
| void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||
| void ResetContext(uint64_t loop_count); | |||
| void ResetContext(uint64_t iteration); | |||
| void ScheduleContext(const NodeState &node_state); | |||
| const NodeItem *node_item_ = nullptr; | |||
| std::shared_ptr<NodeTask> kernel_task_ = nullptr; | |||
| @@ -179,7 +198,9 @@ struct NodeState { | |||
| std::mutex mu_; | |||
| std::future<Status> schedule_future_; | |||
| uint64_t loop_count_ = 0; | |||
| std::shared_ptr<FrameState> frame_state_; | |||
| uint64_t active_count_ = 0; | |||
| uint64_t iteration_count_ = 0; | |||
| uint32_t ctrl_scheduled_ = 0; | |||
| uint32_t data_scheduled_ = 0; | |||
| int merge_index_ = -1; // Use for Execute (Reset after Executed). | |||
| @@ -89,6 +89,7 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||
| if (node_state == nullptr) { | |||
| const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | |||
| node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | |||
| node_state->SetFrameState(GetOrCreateFrameState(*node_item)); | |||
| node_state->SetGroup(group_); | |||
| (void)guard; | |||
| } | |||
| @@ -102,6 +103,20 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||
| return node_state; | |||
| } | |||
| FrameStatePtr SubgraphContext::GetOrCreateFrameState(const NodeItem &node_item) { | |||
| auto &frame_state = frame_states_[node_item.frame_index_]; | |||
| if (frame_state == nullptr) { | |||
| GELOGD("[%s] Create FrameState, frame index: %ld, parent frame index: %ld", | |||
| node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); | |||
| frame_state.reset(new(std::nothrow)FrameState(node_item.frame_index_)); | |||
| if (node_item.frame_index_ != -1) { // -1 is root frame. | |||
| frame_state->parent_frame_ = frame_states_[node_item.parent_frame_]; | |||
| } | |||
| } | |||
| return frame_state; | |||
| } | |||
| Status SubgraphContext::SetInput(int index, const TensorValue &tensor) { | |||
| if (static_cast<size_t>(index) >= all_inputs_.size()) { | |||
| GELOGE(INTERNAL_ERROR, | |||
| @@ -51,6 +51,7 @@ class SubgraphContext { | |||
| void NodeDone(const NodePtr &node); | |||
| private: | |||
| FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock | |||
| friend class TaskContext; | |||
| const GraphItem *graph_item_; | |||
| const GraphExecutionContext *execution_context_; | |||
| @@ -59,6 +60,7 @@ class SubgraphContext { | |||
| std::vector<TensorValue> all_outputs_; | |||
| NodeDoneManager node_done_manager_; | |||
| std::unordered_map<const NodeItem *, NodeStatePtr> node_states_; | |||
| std::unordered_map<int64_t, FrameStatePtr> frame_states_; | |||
| int group_ = -1; | |||
| }; | |||
| } // namespace hybrid | |||
| @@ -1945,6 +1945,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root | |||
| GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | |||
| GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | |||
| GE_CHK_STATUS_RET_NOLOG(BuildFrameGroupIndex(*node_item)); | |||
| GE_CHK_STATUS_RET_NOLOG(BuildControlFlowGroup(*graph_item, node, node_item)); | |||
| if (node->GetInAllNodes().empty()) { | |||
| graph_item->root_items_.emplace_back(node_item); | |||
| @@ -2308,6 +2309,62 @@ Status HybridModelBuilder::BuildProfilingControl(GraphItem &graph_item, | |||
| return SUCCESS; | |||
| } | |||
| Status HybridModelBuilder::BuildFrameGroupIndex(NodeItem &node_item) { | |||
| if (node_item.is_root_node_) { | |||
| GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", | |||
| node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); | |||
| return SUCCESS; | |||
| } | |||
| int64_t ctrl_flow_group = -1; | |||
| if (node_item.IsEnterOp() && AttrUtils::GetInt(node_item.op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, ctrl_flow_group)) { | |||
| node_item.frame_index_ = ctrl_flow_group; | |||
| for (const auto src_node : node_item.node->GetInAllNodes()) { | |||
| NodeItem *src_node_item = nullptr; | |||
| GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item), | |||
| "[%s] failed to get or create node item", src_node->GetName().c_str()); | |||
| if (!src_node_item->is_root_node_) { | |||
| GELOGD("[%s] frame index: %ld, [%s] parent frame index: %ld", node_item.node_name.c_str(), | |||
| node_item.frame_index_, src_node_item->node_name.c_str(), src_node_item->frame_index_); | |||
| parent_frame_group_[node_item.frame_index_] = src_node_item->frame_index_; | |||
| break; | |||
| } | |||
| } | |||
| const auto it = parent_frame_group_.find(node_item.frame_index_); | |||
| node_item.parent_frame_ = (it != parent_frame_group_.end()) ? it->second : -1; | |||
| GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", | |||
| node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); | |||
| return SUCCESS; | |||
| } | |||
| for (const auto src_node : node_item.node->GetInAllNodes()) { | |||
| NodeItem *src_node_item = nullptr; | |||
| GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item), | |||
| "[%s] failed to get or create node item", src_node->GetName().c_str()); | |||
| if (src_node_item->is_root_node_) { | |||
| continue; | |||
| } | |||
| if (src_node_item->IsExitOp()) { | |||
| const auto it = parent_frame_group_.find(src_node_item->frame_index_); | |||
| node_item.frame_index_ = (it != parent_frame_group_.end()) ? it->second : -1; | |||
| } else { | |||
| node_item.frame_index_ = src_node_item->frame_index_; | |||
| } | |||
| const auto it = parent_frame_group_.find(node_item.frame_index_); | |||
| node_item.parent_frame_ = (it != parent_frame_group_.end()) ? it->second : -1; | |||
| GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", | |||
| node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); | |||
| return SUCCESS; | |||
| } | |||
| GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", | |||
| node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); | |||
| return SUCCESS; | |||
| } | |||
| Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item) { | |||
| GELOGD("Build control flow for node %s", node->GetName().c_str()); | |||
| using GroupBuilder = std::function<Status(HybridModelBuilder *, const NodePtr &, NodeItem *)>; | |||
| @@ -2427,6 +2484,7 @@ Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem | |||
| if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { | |||
| // Enter --> StreamActive --> StreamMerge | |||
| node_item->is_enter_active_ = true; | |||
| return CreateMergeEnterGroup(node, node_item); | |||
| } else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { | |||
| // NextIteration --> StreamActive {-->} StreamMerge | |||
| @@ -97,6 +97,7 @@ class HybridModelBuilder { | |||
| Status RelinkNextIteration(); | |||
| Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); | |||
| Status BuildFrameGroupIndex(NodeItem &node_item); | |||
| Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); | |||
| Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); | |||
| Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item); | |||
| @@ -123,6 +124,7 @@ class HybridModelBuilder { | |||
| std::map<std::string, NodePtr> constant_op_nodes_; | |||
| std::map<std::string, NodePtr> stream_merge_op_nodes_; | |||
| std::map<std::string, NodePtr> next_iteration_op_nodes_; | |||
| std::map<int64_t, int64_t> parent_frame_group_; | |||
| std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | |||
| std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_; | |||
| @@ -20,7 +20,6 @@ | |||
| #include "graph/common/omg_util.h" | |||
| #include "graph/compute_graph.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/utils/node_utils.h" | |||
| #include "hybrid/executor/worker/shape_inference_engine.h" | |||
| #include "hybrid/node_executor/node_executor.h" | |||
| @@ -34,7 +33,7 @@ const std::set<std::string> kControlOpTypes{ | |||
| }; | |||
| const std::set<std::string> kControlFlowOpTypes{ | |||
| STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, | |||
| STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, | |||
| LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX | |||
| }; | |||
| @@ -402,8 +401,8 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||
| node_item->root_data_.emplace(this); | |||
| } | |||
| // If Enter feed Not Merge, take as root Node. | |||
| if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) { | |||
| node_item->root_data_.emplace(this); | |||
| if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | |||
| node_item->enter_data_.emplace(this); | |||
| node_item->enter_inside_.emplace(anchor_index); | |||
| } | |||
| GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | |||
| @@ -422,8 +421,8 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | |||
| node_item->root_ctrl_.emplace(this); | |||
| } | |||
| // If Enter feed control signal, take as root Node. | |||
| if (kEnterOpTypes.count(node_type) > 0) { | |||
| node_item->root_ctrl_.emplace(this); | |||
| if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { | |||
| node_item->enter_ctrl_.emplace(this); | |||
| } | |||
| GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | |||
| } | |||
| @@ -22,6 +22,7 @@ | |||
| #include "external/ge/ge_api_error_codes.h" | |||
| #include "graph/node.h" | |||
| #include "graph/op_desc.h" | |||
| #include "graph/utils/node_utils.h" | |||
| #include "framework/common/types.h" | |||
| #include "hybrid/common/tensor_value.h" | |||
| @@ -92,6 +93,14 @@ struct NodeItem { | |||
| return is_merge_op_; | |||
| } | |||
| bool IsEnterOp() const { | |||
| return kEnterOpTypes.count(node_type) > 0; | |||
| } | |||
| bool IsExitOp() const { | |||
| return kExitOpTypes.count(node_type) > 0; | |||
| } | |||
| bool IsHcclOp() const; | |||
| void SetToDynamic(); | |||
| @@ -135,8 +144,13 @@ struct NodeItem { | |||
| bool is_ctrl_flow_v2_op_ = false; | |||
| bool is_ctrl_flow_op_ = false; | |||
| bool is_merge_op_ = false; | |||
| bool is_enter_active_ = false; | |||
| int64_t frame_index_ = -1; | |||
| int64_t parent_frame_ = -1; | |||
| std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | |||
| std::set<const NodeItem *> root_data_; // Recv data from root node | |||
| std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | |||
| std::set<const NodeItem *> enter_data_; // Recv data from Enter node | |||
| std::set<const NodeItem *> data_send_; // Send data notify to | |||
| std::map<const NodeItem *, int> data_recv_; // Recv data notify from | |||
| std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | |||
| @@ -90,7 +90,7 @@ Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t inde | |||
| Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||
| GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | |||
| const auto &node_state = task_context.GetNodeState(); | |||
| node_state->SetSwitchIndex(0); | |||
| node_state->RunStreamActive(); | |||
| if (done_callback) { | |||
| GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||
| } | |||
| @@ -204,9 +204,7 @@ Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::functio | |||
| const auto &node_state = task_context.GetNodeState(); | |||
| if (kNextIterationOpTypes.count(node_state->GetType()) > 0) { | |||
| node_state->RunLoopNext(); | |||
| } else if (kExitOpTypes.count(node_state->GetType()) > 0) { | |||
| node_state->RunLoopExit(); | |||
| node_state->RunNextIteration(); | |||
| } | |||
| if (done_callback) { | |||
| @@ -127,7 +127,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||
| AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | |||
| } | |||
| const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); // Mock for less, just pass input0. | |||
| const auto less1 = CreateNode(graph, "less", IDENTITY, 2, 1); // Mock for less, just pass input0. | |||
| const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | |||
| switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | |||
| @@ -135,8 +135,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||
| AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. | |||
| AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); | |||
| const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); // Mock for add, just pass input0. | |||
| const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); // Mock for sub, just pass input0. | |||
| const auto add1 = CreateNode(graph, "add", IDENTITY, 2, 1); // Mock for add, just pass input0. | |||
| const auto sub1 = CreateNode(graph, "sub", IDENTITY, 2, 1); // Mock for sub, just pass input0. | |||
| const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); | |||
| const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); | |||
| @@ -89,7 +89,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
| * \ / \. | |||
| * Switch Add | |||
| * / | | | |||
| * / | | | |||
| * Active / | | | |||
| * / | | | |||
| * LoopCond | | | |||
| * \ | | | |||
| @@ -98,9 +98,10 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
| * Less | | | |||
| * \ | NextIteration | |||
| * \ | | | |||
| * \ | | | |||
| * \ | | Active | |||
| * Merge <---------| | |||
| * | | |||
| * | Active | |||
| * | | |||
| * Enter | |||
| ******************************************************************************/ | |||
| @@ -110,6 +111,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
| GeModelPtr ge_sub_model = make_shared<GeModel>(); | |||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||
| auto data1 = CreateNode(*graph, "data", DATA, 1, 1); | |||
| auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | |||
| auto merge1 = CreateNode(*graph, "merge", STREAMMERGE, 2, 2); | |||
| auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | |||
| @@ -129,6 +131,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
| auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); | |||
| auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | |||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | |||
| @@ -153,8 +156,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
| GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | |||
| //GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||
| SetNextIteration(merge1, next1); | |||
| SetNextIteration(merge1, next1); // for relink NextIteration --> StreamMerge | |||
| GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); // Test for not merge. | |||
| @@ -169,6 +171,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
| AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | |||
| AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | |||
| SetControlFlowGroup(enter1, loop1->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(active1, loop1->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(merge1, loop1->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(loop1, loop1->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(active2, switch_t->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(switch_t, switch_t->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(switch_f, switch_t->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(next1, loop1->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(active3, loop1->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(exit1, loop1->GetOpDesc()->GetId()); | |||
| // Build -> IndexSpecialNodes --> stream_merge_op_nodes_ | |||
| // Build -> LoadGraph -> RelinkNextIteration | |||
| // Build -> LoadGraph -> LoadDynamicSubgraph --> BuildNodeItem --> NodeItem::SetDataSend | |||
| @@ -190,9 +203,23 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
| task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | |||
| task_executor.emplace(NodeExecutorManager::ExecutorType::HOST_CPU, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | |||
| const auto control_group_index = loop1->GetOpDesc()->GetId(); | |||
| HybridModel hybrid_model(ge_root_model); | |||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | |||
| ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); | |||
| const auto TestFrameGroup = [&hybrid_model](const NodePtr &n, int64_t index) { | |||
| const auto it = hybrid_model.node_items_.find(n); | |||
| ASSERT_NE(hybrid_model.node_items_.end(), it); | |||
| ASSERT_EQ(it->second->frame_index_, index); | |||
| ASSERT_EQ(it->second->parent_frame_, -1); | |||
| }; | |||
| TestFrameGroup(enter1, control_group_index); | |||
| TestFrameGroup(active1, control_group_index); | |||
| TestFrameGroup(active2, control_group_index); | |||
| TestFrameGroup(active3, control_group_index); | |||
| TestFrameGroup(output1, -1); | |||
| engine_mapping.clear(); | |||
| task_executor.clear(); | |||
| } | |||
| @@ -166,6 +166,10 @@ TEST_F(UtestRtsNodeTask, test_stream_active_task) { | |||
| std::function<void()> done = []() {}; | |||
| ASSERT_EQ(node_state->GetSwitchIndex(), -1); | |||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||
| ASSERT_EQ(node_state->GetSwitchIndex(), -1); | |||
| node_item->ctrl_send_.emplace(nullptr); | |||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||
| ASSERT_EQ(node_state->GetSwitchIndex(), 0); | |||
| } | |||