From: @zhangxiaokun9 Reviewed-by: @xchu42,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
@@ -286,13 +286,23 @@ void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t grou | |||
return; | |||
} | |||
SetControlFlowGroup(node, group_index); | |||
} | |||
/// | |||
/// @brief Set Op _control_flow_group flag | |||
/// @param [in] node | |||
/// @param [in] group, condition group index of node. | |||
/// @return | |||
/// | |||
void SetControlFlowGroup(const NodePtr &node, int64_t group) { | |||
GE_RT_VOID_CHECK_NOTNULL(node); | |||
const auto &op_desc = node->GetOpDesc(); | |||
GE_RT_VOID_CHECK_NOTNULL(op_desc); | |||
// op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check. | |||
GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group_index); | |||
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||
GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group); | |||
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group)) { | |||
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | |||
node->GetName().c_str(), node->GetType().c_str()); | |||
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | |||
@@ -133,6 +133,14 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | |||
/// @return | |||
/// | |||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | |||
/// | |||
/// @brief Set Op _control_flow_group flag | |||
/// @param [in] node | |||
/// @param [in] group, condition group index of node. | |||
/// @return | |||
/// | |||
void SetControlFlowGroup(const NodePtr &node, int64_t group); | |||
} // namespace ge | |||
#endif // GE_GRAPH_COMMON_OMG_UTIL_H_ |
@@ -186,12 +186,6 @@ bool NextIterationPass::VerifyWhileGroup() { | |||
frame_name.c_str()); | |||
return false; | |||
} | |||
// Mark loop as unknown shape If any merge has unknown shape output. | |||
const auto &op_desc = pair_iter.first->GetOpDesc(); | |||
if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) { | |||
loop_group_iter.second->is_unknown_shape = true; // under check loop, cannot break. | |||
} | |||
} | |||
} | |||
@@ -229,7 +223,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
enter_active->GetName().c_str(), enter_active->GetType().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape, group_index); | |||
SetControlFlowGroup(enter_node, group_index); | |||
} | |||
for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | |||
@@ -264,8 +258,8 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
return INTERNAL_ERROR; | |||
} | |||
MarkForceUnknownShape(next_node, loop_group.is_unknown_shape, group_index); | |||
MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape, group_index); | |||
SetControlFlowGroup(next_node, group_index); | |||
SetControlFlowGroup(merge_node, group_index); | |||
} | |||
if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | |||
@@ -274,9 +268,9 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
return INTERNAL_ERROR; | |||
} | |||
MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape, group_index); | |||
MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape, group_index); | |||
MarkForceUnknownShape(next_active, loop_group.is_unknown_shape, group_index); | |||
SetControlFlowGroup(loop_group.loop_cond, group_index); | |||
SetControlFlowGroup(enter_active, group_index); | |||
SetControlFlowGroup(next_active, group_index); | |||
HandleSwitchExitNodes(loop_group, group_index); | |||
} | |||
@@ -290,17 +284,13 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
/// @return void | |||
/// | |||
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | |||
if (!loop_group.is_unknown_shape) { | |||
return; | |||
} | |||
for (const auto &switch_node : loop_group.switch_nodes) { | |||
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape, group_index); | |||
SetControlFlowGroup(switch_node, group_index); | |||
for (const auto &node : switch_node->GetOutDataNodes()) { | |||
std::string node_type; | |||
(void)GetOriginalType(node, node_type); | |||
if (kExitOpTypes.count(node_type) > 0) { | |||
MarkForceUnknownShape(node, loop_group.is_unknown_shape, group_index); | |||
SetControlFlowGroup(node, group_index); | |||
} | |||
} | |||
} | |||
@@ -24,7 +24,6 @@ struct LoopCondGroup { | |||
std::vector<ge::NodePtr> enter_nodes; // Enter nodes | |||
std::vector<std::pair<ge::NodePtr, ge::NodePtr>> merge_next_pairs; // <Merge, NextIteration> | |||
std::vector<ge::NodePtr> switch_nodes; // Switch nodes | |||
bool is_unknown_shape{false}; | |||
}; | |||
using LoopCondGroupPtr = std::shared_ptr<LoopCondGroup>; | |||
@@ -22,6 +22,14 @@ | |||
#include "hybrid_execution_context.h" | |||
#include "subgraph_context.h" | |||
#define INC_ITERATION_COUNT(iteration) \ | |||
do { \ | |||
++iteration; \ | |||
if (iteration == UINT64_MAX) { \ | |||
iteration = 1; \ | |||
} \ | |||
} while (0) | |||
namespace ge { | |||
namespace hybrid { | |||
namespace { | |||
@@ -306,15 +314,45 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||
return task_context_; | |||
} | |||
void NodeState::ResetContext(uint64_t loop_count) { | |||
loop_count_ = loop_count; | |||
void NodeState::ResetContext(uint64_t iteration) { | |||
switch_index_ = -1; | |||
subgraph_context_->ResetContext(node_item_->node); | |||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
GELOGD("[%s] in while loop, loop count: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", | |||
GetName().c_str(), loop_count_, data_scheduled_, ctrl_scheduled_, merge_index_); | |||
if (iteration == 0) { | |||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
} else { | |||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size()); | |||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size()); | |||
} | |||
iteration_count_ = iteration; | |||
GELOGD("[%s] in while loop, current iteration: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", | |||
GetName().c_str(), iteration_count_, data_scheduled_, ctrl_scheduled_, merge_index_); | |||
} | |||
void NodeState::ScheduleContext(const NodeState &node_state) { | |||
if (node_state.node_item_->IsEnterOp()) { | |||
GELOGD("[%s]{active: %lu, iteration: %lu}, frame{active: %lu, iteration: %lu} [%s]{active: %lu, iteration: %lu}", | |||
GetName().c_str(), active_count_, iteration_count_, frame_state_->active_count_, | |||
frame_state_->iteration_count_, node_state.GetName().c_str(), node_state.frame_state_->active_count_, | |||
node_state.frame_state_->iteration_count_); | |||
if (frame_state_->active_count_ != active_count_) { | |||
ResetContext(0); | |||
active_count_ = frame_state_->active_count_; | |||
} | |||
} else if (node_state.node_item_->IsExitOp()) { | |||
GELOGD("[%s]{active: %lu, iteration: %lu} frame{active: %lu, iteration: %lu} " | |||
"[%s]{active: %lu, iteration: %lu} parent{active: %lu, iteration: %lu}", | |||
GetName().c_str(), active_count_, iteration_count_, frame_state_->active_count_, | |||
frame_state_->iteration_count_, node_state.GetName().c_str(), node_state.frame_state_->active_count_, | |||
node_state.frame_state_->iteration_count_, node_state.frame_state_->parent_frame_->active_count_, | |||
node_state.frame_state_->parent_frame_->iteration_count_); | |||
if (node_state.frame_state_->parent_frame_->iteration_count_ != iteration_count_) { | |||
ResetContext(node_state.frame_state_->parent_frame_->iteration_count_); | |||
} | |||
} else if (node_state.iteration_count_ != iteration_count_) { | |||
ResetContext(node_state.iteration_count_); | |||
} | |||
} | |||
Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { | |||
@@ -346,11 +384,11 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||
} | |||
bool NodeState::IsScheduleReady() const { | |||
GELOGD("[%s] loop[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", | |||
GetName().c_str(), loop_count_, node_item_->data_recv_.size(), data_scheduled_, | |||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
GELOGD("[%s] iteration[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", | |||
GetName().c_str(), iteration_count_, node_item_->data_recv_.size(), data_scheduled_, | |||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
if (node_item_->IsMergeOp()) { | |||
if (ctrl_scheduled_ != node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { | |||
if (ctrl_scheduled_ != node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { | |||
return false; | |||
} | |||
@@ -366,15 +404,13 @@ bool NodeState::IsScheduleReady() const { | |||
} | |||
void NodeState::SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | |||
GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||
node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_, | |||
GELOGD("[%s] schedule [%s], iteration[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||
node_state.GetName().c_str(), GetName().c_str(), iteration_count_, node_state.iteration_count_, | |||
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), | |||
node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
std::lock_guard<std::mutex> lk(mu_); | |||
if (loop_count_ != node_state.loop_count_) { | |||
ResetContext(node_state.loop_count_); | |||
} | |||
ScheduleContext(node_state); | |||
++data_scheduled_; | |||
if (node_item_->IsMergeOp()) { | |||
@@ -394,15 +430,13 @@ void NodeState::SetDataSchedule(const NodeState &node_state, const std::function | |||
} | |||
void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | |||
GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||
node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_, | |||
GELOGD("[%s] schedule [%s], iteration[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||
node_state.GetName().c_str(), GetName().c_str(), iteration_count_, node_state.iteration_count_, | |||
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), | |||
node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||
std::lock_guard<std::mutex> lk(mu_); | |||
if (loop_count_ != node_state.loop_count_) { | |||
ResetContext(node_state.loop_count_); | |||
} | |||
ScheduleContext(node_state); | |||
++ctrl_scheduled_; | |||
if (IsScheduleReady()) { | |||
@@ -410,21 +444,28 @@ void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function | |||
} | |||
} | |||
void NodeState::RunLoopNext() { | |||
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||
void NodeState::RunNextIteration() { | |||
std::lock_guard<std::mutex> lk(mu_); | |||
++loop_count_; | |||
if (loop_count_ == UINT64_MAX) { | |||
loop_count_ = 1; | |||
} | |||
ResetContext(loop_count_); | |||
INC_ITERATION_COUNT(iteration_count_); | |||
ResetContext(iteration_count_); | |||
} | |||
void NodeState::RunLoopExit() { | |||
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||
void NodeState::RunStreamActive() { | |||
std::lock_guard<std::mutex> lk(mu_); | |||
loop_count_ = 0; | |||
if (node_item_->ctrl_send_.empty()) { // Not for Loop Enter or Loop Next. | |||
return; | |||
} | |||
switch_index_ = 0; | |||
data_scheduled_ = 0; | |||
ctrl_scheduled_ = 0; | |||
if (node_item_->is_enter_active_) { | |||
frame_state_->iteration_count_ = 0; | |||
INC_ITERATION_COUNT(frame_state_->active_count_); | |||
} else { | |||
INC_ITERATION_COUNT(frame_state_->iteration_count_); | |||
} | |||
GELOGD("Node[%s] current iteration: %lu, frame active: %lu, frame iteration: %lu", | |||
GetName().c_str(), iteration_count_, frame_state_->active_count_, frame_state_->iteration_count_); | |||
} | |||
void NodeState::SetScheduleFuture(std::future<Status> &&future) { | |||
@@ -33,8 +33,10 @@ struct GraphExecutionContext; | |||
class SubgraphContext; | |||
class TaskContext; | |||
struct NodeState; | |||
struct FrameState; | |||
using NodeStatePtr = std::shared_ptr<NodeState>; | |||
using FrameStatePtr = std::shared_ptr<FrameState>; | |||
class ShapeFuture { | |||
public: | |||
@@ -80,6 +82,18 @@ struct ShapeInferenceState { | |||
std::mutex mu_; | |||
}; | |||
struct FrameState { | |||
public: | |||
FrameState(int64_t id) : frame_id_(id) {} | |||
~FrameState() = default; | |||
int64_t frame_id_{0}; | |||
uint64_t active_count_{0}; | |||
uint64_t iteration_count_{0}; | |||
std::shared_ptr<FrameState> parent_frame_; | |||
}; | |||
// saving sth. dynamic during execution | |||
struct NodeState { | |||
public: | |||
@@ -112,8 +126,8 @@ struct NodeState { | |||
return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; | |||
} | |||
void RunLoopNext(); | |||
void RunLoopExit(); | |||
void RunStreamActive(); | |||
void RunNextIteration(); | |||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | |||
@@ -144,6 +158,10 @@ struct NodeState { | |||
return group_; | |||
} | |||
void SetFrameState(const shared_ptr<FrameState> &frame_state) { | |||
frame_state_ = frame_state; | |||
} | |||
const shared_ptr<NodeTask> &GetKernelTask() const { | |||
return kernel_task_; | |||
} | |||
@@ -167,7 +185,8 @@ struct NodeState { | |||
bool IsScheduleReady() const; | |||
void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||
void ResetContext(uint64_t loop_count); | |||
void ResetContext(uint64_t iteration); | |||
void ScheduleContext(const NodeState &node_state); | |||
const NodeItem *node_item_ = nullptr; | |||
std::shared_ptr<NodeTask> kernel_task_ = nullptr; | |||
@@ -179,7 +198,9 @@ struct NodeState { | |||
std::mutex mu_; | |||
std::future<Status> schedule_future_; | |||
uint64_t loop_count_ = 0; | |||
std::shared_ptr<FrameState> frame_state_; | |||
uint64_t active_count_ = 0; | |||
uint64_t iteration_count_ = 0; | |||
uint32_t ctrl_scheduled_ = 0; | |||
uint32_t data_scheduled_ = 0; | |||
int merge_index_ = -1; // Use for Execute (Reset after Executed). | |||
@@ -89,6 +89,7 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||
if (node_state == nullptr) { | |||
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | |||
node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | |||
node_state->SetFrameState(GetOrCreateFrameState(*node_item)); | |||
node_state->SetGroup(group_); | |||
(void)guard; | |||
} | |||
@@ -102,6 +103,20 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||
return node_state; | |||
} | |||
FrameStatePtr SubgraphContext::GetOrCreateFrameState(const NodeItem &node_item) { | |||
auto &frame_state = frame_states_[node_item.frame_index_]; | |||
if (frame_state == nullptr) { | |||
GELOGD("[%s] Create FrameState, frame index: %ld, parent frame index: %ld", | |||
node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); | |||
frame_state.reset(new(std::nothrow)FrameState(node_item.frame_index_)); | |||
if (node_item.frame_index_ != -1) { // -1 is root frame. | |||
frame_state->parent_frame_ = frame_states_[node_item.parent_frame_]; | |||
} | |||
} | |||
return frame_state; | |||
} | |||
Status SubgraphContext::SetInput(int index, const TensorValue &tensor) { | |||
if (static_cast<size_t>(index) >= all_inputs_.size()) { | |||
GELOGE(INTERNAL_ERROR, | |||
@@ -51,6 +51,7 @@ class SubgraphContext { | |||
void NodeDone(const NodePtr &node); | |||
private: | |||
FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock | |||
friend class TaskContext; | |||
const GraphItem *graph_item_; | |||
const GraphExecutionContext *execution_context_; | |||
@@ -59,6 +60,7 @@ class SubgraphContext { | |||
std::vector<TensorValue> all_outputs_; | |||
NodeDoneManager node_done_manager_; | |||
std::unordered_map<const NodeItem *, NodeStatePtr> node_states_; | |||
std::unordered_map<int64_t, FrameStatePtr> frame_states_; | |||
int group_ = -1; | |||
}; | |||
} // namespace hybrid | |||
@@ -1945,6 +1945,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root | |||
GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | |||
GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | |||
GE_CHK_STATUS_RET_NOLOG(BuildFrameGroupIndex(*node_item)); | |||
GE_CHK_STATUS_RET_NOLOG(BuildControlFlowGroup(*graph_item, node, node_item)); | |||
if (node->GetInAllNodes().empty()) { | |||
graph_item->root_items_.emplace_back(node_item); | |||
@@ -2308,6 +2309,62 @@ Status HybridModelBuilder::BuildProfilingControl(GraphItem &graph_item, | |||
return SUCCESS; | |||
} | |||
Status HybridModelBuilder::BuildFrameGroupIndex(NodeItem &node_item) { | |||
if (node_item.is_root_node_) { | |||
GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", | |||
node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); | |||
return SUCCESS; | |||
} | |||
int64_t ctrl_flow_group = -1; | |||
if (node_item.IsEnterOp() && AttrUtils::GetInt(node_item.op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, ctrl_flow_group)) { | |||
node_item.frame_index_ = ctrl_flow_group; | |||
for (const auto src_node : node_item.node->GetInAllNodes()) { | |||
NodeItem *src_node_item = nullptr; | |||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item), | |||
"[%s] failed to get or create node item", src_node->GetName().c_str()); | |||
if (!src_node_item->is_root_node_) { | |||
GELOGD("[%s] frame index: %ld, [%s] parent frame index: %ld", node_item.node_name.c_str(), | |||
node_item.frame_index_, src_node_item->node_name.c_str(), src_node_item->frame_index_); | |||
parent_frame_group_[node_item.frame_index_] = src_node_item->frame_index_; | |||
break; | |||
} | |||
} | |||
const auto it = parent_frame_group_.find(node_item.frame_index_); | |||
node_item.parent_frame_ = (it != parent_frame_group_.end()) ? it->second : -1; | |||
GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", | |||
node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); | |||
return SUCCESS; | |||
} | |||
for (const auto src_node : node_item.node->GetInAllNodes()) { | |||
NodeItem *src_node_item = nullptr; | |||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item), | |||
"[%s] failed to get or create node item", src_node->GetName().c_str()); | |||
if (src_node_item->is_root_node_) { | |||
continue; | |||
} | |||
if (src_node_item->IsExitOp()) { | |||
const auto it = parent_frame_group_.find(src_node_item->frame_index_); | |||
node_item.frame_index_ = (it != parent_frame_group_.end()) ? it->second : -1; | |||
} else { | |||
node_item.frame_index_ = src_node_item->frame_index_; | |||
} | |||
const auto it = parent_frame_group_.find(node_item.frame_index_); | |||
node_item.parent_frame_ = (it != parent_frame_group_.end()) ? it->second : -1; | |||
GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", | |||
node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); | |||
return SUCCESS; | |||
} | |||
GELOGD("[%s] control flow frame group: %ld, parent frame: %ld", | |||
node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_); | |||
return SUCCESS; | |||
} | |||
Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item) { | |||
GELOGD("Build control flow for node %s", node->GetName().c_str()); | |||
using GroupBuilder = std::function<Status(HybridModelBuilder *, const NodePtr &, NodeItem *)>; | |||
@@ -2427,6 +2484,7 @@ Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem | |||
if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { | |||
// Enter --> StreamActive --> StreamMerge | |||
node_item->is_enter_active_ = true; | |||
return CreateMergeEnterGroup(node, node_item); | |||
} else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { | |||
// NextIteration --> StreamActive {-->} StreamMerge | |||
@@ -97,6 +97,7 @@ class HybridModelBuilder { | |||
Status RelinkNextIteration(); | |||
Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); | |||
Status BuildFrameGroupIndex(NodeItem &node_item); | |||
Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); | |||
Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); | |||
Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item); | |||
@@ -123,6 +124,7 @@ class HybridModelBuilder { | |||
std::map<std::string, NodePtr> constant_op_nodes_; | |||
std::map<std::string, NodePtr> stream_merge_op_nodes_; | |||
std::map<std::string, NodePtr> next_iteration_op_nodes_; | |||
std::map<int64_t, int64_t> parent_frame_group_; | |||
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | |||
std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_; | |||
@@ -20,7 +20,6 @@ | |||
#include "graph/common/omg_util.h" | |||
#include "graph/compute_graph.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "hybrid/executor/worker/shape_inference_engine.h" | |||
#include "hybrid/node_executor/node_executor.h" | |||
@@ -34,7 +33,7 @@ const std::set<std::string> kControlOpTypes{ | |||
}; | |||
const std::set<std::string> kControlFlowOpTypes{ | |||
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, | |||
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, | |||
LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX | |||
}; | |||
@@ -402,8 +401,8 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||
node_item->root_data_.emplace(this); | |||
} | |||
// If Enter feed Not Merge, take as root Node. | |||
if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) { | |||
node_item->root_data_.emplace(this); | |||
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | |||
node_item->enter_data_.emplace(this); | |||
node_item->enter_inside_.emplace(anchor_index); | |||
} | |||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | |||
@@ -422,8 +421,8 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | |||
node_item->root_ctrl_.emplace(this); | |||
} | |||
// If Enter feed control signal, take as root Node. | |||
if (kEnterOpTypes.count(node_type) > 0) { | |||
node_item->root_ctrl_.emplace(this); | |||
if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { | |||
node_item->enter_ctrl_.emplace(this); | |||
} | |||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | |||
} | |||
@@ -22,6 +22,7 @@ | |||
#include "external/ge/ge_api_error_codes.h" | |||
#include "graph/node.h" | |||
#include "graph/op_desc.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "framework/common/types.h" | |||
#include "hybrid/common/tensor_value.h" | |||
@@ -92,6 +93,14 @@ struct NodeItem { | |||
return is_merge_op_; | |||
} | |||
bool IsEnterOp() const { | |||
return kEnterOpTypes.count(node_type) > 0; | |||
} | |||
bool IsExitOp() const { | |||
return kExitOpTypes.count(node_type) > 0; | |||
} | |||
bool IsHcclOp() const; | |||
void SetToDynamic(); | |||
@@ -135,8 +144,13 @@ struct NodeItem { | |||
bool is_ctrl_flow_v2_op_ = false; | |||
bool is_ctrl_flow_op_ = false; | |||
bool is_merge_op_ = false; | |||
bool is_enter_active_ = false; | |||
int64_t frame_index_ = -1; | |||
int64_t parent_frame_ = -1; | |||
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | |||
std::set<const NodeItem *> root_data_; // Recv data from root node | |||
std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | |||
std::set<const NodeItem *> enter_data_; // Recv data from Enter node | |||
std::set<const NodeItem *> data_send_; // Send data notify to | |||
std::map<const NodeItem *, int> data_recv_; // Recv data notify from | |||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | |||
@@ -90,7 +90,7 @@ Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t inde | |||
Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||
GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | |||
const auto &node_state = task_context.GetNodeState(); | |||
node_state->SetSwitchIndex(0); | |||
node_state->RunStreamActive(); | |||
if (done_callback) { | |||
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||
} | |||
@@ -204,9 +204,7 @@ Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::functio | |||
const auto &node_state = task_context.GetNodeState(); | |||
if (kNextIterationOpTypes.count(node_state->GetType()) > 0) { | |||
node_state->RunLoopNext(); | |||
} else if (kExitOpTypes.count(node_state->GetType()) > 0) { | |||
node_state->RunLoopExit(); | |||
node_state->RunNextIteration(); | |||
} | |||
if (done_callback) { | |||
@@ -127,7 +127,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | |||
} | |||
const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); // Mock for less, just pass input0. | |||
const auto less1 = CreateNode(graph, "less", IDENTITY, 2, 1); // Mock for less, just pass input0. | |||
const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | |||
switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | |||
@@ -135,8 +135,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. | |||
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); | |||
const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); // Mock for add, just pass input0. | |||
const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); // Mock for sub, just pass input0. | |||
const auto add1 = CreateNode(graph, "add", IDENTITY, 2, 1); // Mock for add, just pass input0. | |||
const auto sub1 = CreateNode(graph, "sub", IDENTITY, 2, 1); // Mock for sub, just pass input0. | |||
const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); | |||
const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); | |||
@@ -89,7 +89,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
* \ / \. | |||
* Switch Add | |||
* / | | | |||
* / | | | |||
* Active / | | | |||
* / | | | |||
* LoopCond | | | |||
* \ | | | |||
@@ -98,9 +98,10 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
* Less | | | |||
* \ | NextIteration | |||
* \ | | | |||
* \ | | | |||
* \ | | Active | |||
* Merge <---------| | |||
* | | |||
* | Active | |||
* | | |||
* Enter | |||
******************************************************************************/ | |||
@@ -110,6 +111,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
GeModelPtr ge_sub_model = make_shared<GeModel>(); | |||
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||
auto data1 = CreateNode(*graph, "data", DATA, 1, 1); | |||
auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | |||
auto merge1 = CreateNode(*graph, "merge", STREAMMERGE, 2, 2); | |||
auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | |||
@@ -129,6 +131,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); | |||
auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | |||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | |||
@@ -153,8 +156,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||
GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | |||
//GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||
SetNextIteration(merge1, next1); | |||
SetNextIteration(merge1, next1); // for relink NextIteration --> StreamMerge | |||
GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); // Test for not merge. | |||
@@ -169,6 +171,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | |||
AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | |||
SetControlFlowGroup(enter1, loop1->GetOpDesc()->GetId()); | |||
SetControlFlowGroup(active1, loop1->GetOpDesc()->GetId()); | |||
SetControlFlowGroup(merge1, loop1->GetOpDesc()->GetId()); | |||
SetControlFlowGroup(loop1, loop1->GetOpDesc()->GetId()); | |||
SetControlFlowGroup(active2, switch_t->GetOpDesc()->GetId()); | |||
SetControlFlowGroup(switch_t, switch_t->GetOpDesc()->GetId()); | |||
SetControlFlowGroup(switch_f, switch_t->GetOpDesc()->GetId()); | |||
SetControlFlowGroup(next1, loop1->GetOpDesc()->GetId()); | |||
SetControlFlowGroup(active3, loop1->GetOpDesc()->GetId()); | |||
SetControlFlowGroup(exit1, loop1->GetOpDesc()->GetId()); | |||
// Build -> IndexSpecialNodes --> stream_merge_op_nodes_ | |||
// Build -> LoadGraph -> RelinkNextIteration | |||
// Build -> LoadGraph -> LoadDynamicSubgraph --> BuildNodeItem --> NodeItem::SetDataSend | |||
@@ -190,9 +203,23 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||
task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | |||
task_executor.emplace(NodeExecutorManager::ExecutorType::HOST_CPU, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | |||
const auto control_group_index = loop1->GetOpDesc()->GetId(); | |||
HybridModel hybrid_model(ge_root_model); | |||
HybridModelBuilder hybrid_model_builder(hybrid_model); | |||
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); | |||
const auto TestFrameGroup = [&hybrid_model](const NodePtr &n, int64_t index) { | |||
const auto it = hybrid_model.node_items_.find(n); | |||
ASSERT_NE(hybrid_model.node_items_.end(), it); | |||
ASSERT_EQ(it->second->frame_index_, index); | |||
ASSERT_EQ(it->second->parent_frame_, -1); | |||
}; | |||
TestFrameGroup(enter1, control_group_index); | |||
TestFrameGroup(active1, control_group_index); | |||
TestFrameGroup(active2, control_group_index); | |||
TestFrameGroup(active3, control_group_index); | |||
TestFrameGroup(output1, -1); | |||
engine_mapping.clear(); | |||
task_executor.clear(); | |||
} | |||
@@ -166,6 +166,10 @@ TEST_F(UtestRtsNodeTask, test_stream_active_task) { | |||
std::function<void()> done = []() {}; | |||
ASSERT_EQ(node_state->GetSwitchIndex(), -1); | |||
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||
ASSERT_EQ(node_state->GetSwitchIndex(), -1); | |||
node_item->ctrl_send_.emplace(nullptr); | |||
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||
ASSERT_EQ(node_state->GetSwitchIndex(), 0); | |||
} | |||