@@ -286,13 +286,23 @@ void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t grou | |||||
return; | return; | ||||
} | } | ||||
SetControlFlowGroup(node, group_index); | |||||
} | |||||
/// | |||||
/// @brief Set Op _control_flow_group flag | |||||
/// @param [in] node | |||||
/// @param [in] group, condition group index of node. | |||||
/// @return | |||||
/// | |||||
void SetControlFlowGroup(const NodePtr &node, int64_t group) { | |||||
GE_RT_VOID_CHECK_NOTNULL(node); | GE_RT_VOID_CHECK_NOTNULL(node); | ||||
const auto &op_desc = node->GetOpDesc(); | const auto &op_desc = node->GetOpDesc(); | ||||
GE_RT_VOID_CHECK_NOTNULL(op_desc); | GE_RT_VOID_CHECK_NOTNULL(op_desc); | ||||
// op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check. | // op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check. | ||||
GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group_index); | |||||
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||||
GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group); | |||||
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group)) { | |||||
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | ||||
node->GetName().c_str(), node->GetType().c_str()); | node->GetName().c_str(), node->GetType().c_str()); | ||||
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | ||||
@@ -133,6 +133,14 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | |||||
/// @return | /// @return | ||||
/// | /// | ||||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | ||||
/// | |||||
/// @brief Set Op _control_flow_group flag | |||||
/// @param [in] node | |||||
/// @param [in] group, condition group index of node. | |||||
/// @return | |||||
/// | |||||
void SetControlFlowGroup(const NodePtr &node, int64_t group); | |||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_COMMON_OMG_UTIL_H_ | #endif // GE_GRAPH_COMMON_OMG_UTIL_H_ |
@@ -183,12 +183,6 @@ bool NextIterationPass::VerifyWhileGroup() { | |||||
frame_name.c_str()); | frame_name.c_str()); | ||||
return false; | return false; | ||||
} | } | ||||
// Mark loop as unknown shape If any merge has unknown shape output. | |||||
const auto &op_desc = pair_iter.first->GetOpDesc(); | |||||
if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) { | |||||
loop_group_iter.second->is_unknown_shape = true; // under check loop, cannot break. | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -225,7 +219,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
enter_active->GetName().c_str()); | enter_active->GetName().c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape, group_index); | |||||
SetControlFlowGroup(enter_node, group_index); | |||||
} | } | ||||
for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | ||||
@@ -255,8 +249,8 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
MarkForceUnknownShape(next_node, loop_group.is_unknown_shape, group_index); | |||||
MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape, group_index); | |||||
SetControlFlowGroup(next_node, group_index); | |||||
SetControlFlowGroup(merge_node, group_index); | |||||
} | } | ||||
if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | ||||
@@ -265,9 +259,9 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape, group_index); | |||||
MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape, group_index); | |||||
MarkForceUnknownShape(next_active, loop_group.is_unknown_shape, group_index); | |||||
SetControlFlowGroup(loop_group.loop_cond, group_index); | |||||
SetControlFlowGroup(enter_active, group_index); | |||||
SetControlFlowGroup(next_active, group_index); | |||||
HandleSwitchExitNodes(loop_group, group_index); | HandleSwitchExitNodes(loop_group, group_index); | ||||
} | } | ||||
@@ -281,17 +275,13 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
/// @return void | /// @return void | ||||
/// | /// | ||||
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | ||||
if (!loop_group.is_unknown_shape) { | |||||
return; | |||||
} | |||||
for (const auto &switch_node : loop_group.switch_nodes) { | for (const auto &switch_node : loop_group.switch_nodes) { | ||||
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape, group_index); | |||||
SetControlFlowGroup(switch_node, group_index); | |||||
for (const auto &node : switch_node->GetOutDataNodes()) { | for (const auto &node : switch_node->GetOutDataNodes()) { | ||||
std::string node_type; | std::string node_type; | ||||
(void)GetOriginalType(node, node_type); | (void)GetOriginalType(node, node_type); | ||||
if (kExitOpTypes.count(node_type) > 0) { | if (kExitOpTypes.count(node_type) > 0) { | ||||
MarkForceUnknownShape(node, loop_group.is_unknown_shape, group_index); | |||||
SetControlFlowGroup(node, group_index); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -24,7 +24,6 @@ struct LoopCondGroup { | |||||
std::vector<ge::NodePtr> enter_nodes; // Enter nodes | std::vector<ge::NodePtr> enter_nodes; // Enter nodes | ||||
std::vector<std::pair<ge::NodePtr, ge::NodePtr>> merge_next_pairs; // <Merge, NextIteration> | std::vector<std::pair<ge::NodePtr, ge::NodePtr>> merge_next_pairs; // <Merge, NextIteration> | ||||
std::vector<ge::NodePtr> switch_nodes; // Switch nodes | std::vector<ge::NodePtr> switch_nodes; // Switch nodes | ||||
bool is_unknown_shape{false}; | |||||
}; | }; | ||||
using LoopCondGroupPtr = std::shared_ptr<LoopCondGroup>; | using LoopCondGroupPtr = std::shared_ptr<LoopCondGroup>; | ||||
@@ -22,6 +22,14 @@ | |||||
#include "hybrid_execution_context.h" | #include "hybrid_execution_context.h" | ||||
#include "subgraph_context.h" | #include "subgraph_context.h" | ||||
// Advance an unsigned 64-bit iteration counter in place.
// The counter skips from UINT64_MAX back to 1, so it never wraps to 0; callers such as
// ResetContext/IsScheduleReady treat 0 as "first pass" and non-zero as "later iteration".
// The argument must be a modifiable lvalue; it is evaluated exactly once, so expressions
// with side effects (e.g. `counters[i++]`) are safe.
#define INC_ITERATION_COUNT(iteration)                 \
  do {                                                 \
    auto &inc_iteration_count_ref_ = (iteration);      \
    ++inc_iteration_count_ref_;                        \
    if (inc_iteration_count_ref_ == UINT64_MAX) {      \
      inc_iteration_count_ref_ = 1;                    \
    }                                                  \
  } while (0)
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
namespace { | namespace { | ||||
@@ -306,15 +314,45 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
return task_context_; | return task_context_; | ||||
} | } | ||||
void NodeState::ResetContext(uint64_t loop_count) { | |||||
loop_count_ = loop_count; | |||||
void NodeState::ResetContext(uint64_t iteration) { | |||||
switch_index_ = -1; | switch_index_ = -1; | ||||
subgraph_context_->ResetContext(node_item_->node); | subgraph_context_->ResetContext(node_item_->node); | ||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||||
GELOGD("[%s] in while loop, loop count: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", | |||||
GetName().c_str(), loop_count_, data_scheduled_, ctrl_scheduled_, merge_index_); | |||||
if (iteration == 0) { | |||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||||
} else { | |||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size()); | |||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size()); | |||||
} | |||||
iteration_count_ = iteration; | |||||
GELOGD("[%s] in while loop, current iteration: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", | |||||
GetName().c_str(), iteration_count_, data_scheduled_, ctrl_scheduled_, merge_index_); | |||||
} | |||||
void NodeState::ScheduleContext(const NodeState &node_state) { | |||||
if (node_state.node_item_->IsEnterOp()) { | |||||
GELOGD("[%s]{active: %lu, iteration: %lu}, frame{active: %lu, iteration: %lu} [%s]{active: %lu, iteration: %lu}", | |||||
GetName().c_str(), active_count_, iteration_count_, frame_state_->active_count_, | |||||
frame_state_->iteration_count_, node_state.GetName().c_str(), node_state.frame_state_->active_count_, | |||||
node_state.frame_state_->iteration_count_); | |||||
if (frame_state_->active_count_ != active_count_) { | |||||
ResetContext(0); | |||||
active_count_ = frame_state_->active_count_; | |||||
} | |||||
} else if (node_state.node_item_->IsExitOp()) { | |||||
GELOGD("[%s]{active: %lu, iteration: %lu} frame{active: %lu, iteration: %lu} " | |||||
"[%s]{active: %lu, iteration: %lu} parent{active: %lu, iteration: %lu}", | |||||
GetName().c_str(), active_count_, iteration_count_, frame_state_->active_count_, | |||||
frame_state_->iteration_count_, node_state.GetName().c_str(), node_state.frame_state_->active_count_, | |||||
node_state.frame_state_->iteration_count_, node_state.frame_state_->parent_frame_->active_count_, | |||||
node_state.frame_state_->parent_frame_->iteration_count_); | |||||
if (node_state.frame_state_->parent_frame_->iteration_count_ != iteration_count_) { | |||||
ResetContext(node_state.frame_state_->parent_frame_->iteration_count_); | |||||
} | |||||
} else if (node_state.iteration_count_ != iteration_count_) { | |||||
ResetContext(node_state.iteration_count_); | |||||
} | |||||
} | } | ||||
Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { | Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { | ||||
@@ -346,11 +384,11 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||||
} | } | ||||
bool NodeState::IsScheduleReady() const { | bool NodeState::IsScheduleReady() const { | ||||
GELOGD("[%s] loop[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", | |||||
GetName().c_str(), loop_count_, node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
GELOGD("[%s] iteration[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", | |||||
GetName().c_str(), iteration_count_, node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
if (node_item_->IsMergeOp()) { | if (node_item_->IsMergeOp()) { | ||||
if (ctrl_scheduled_ != node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { | |||||
if (ctrl_scheduled_ != node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { | |||||
return false; | return false; | ||||
} | } | ||||
@@ -366,15 +404,13 @@ bool NodeState::IsScheduleReady() const { | |||||
} | } | ||||
void NodeState::SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | void NodeState::SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | ||||
GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||||
node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_, | |||||
GELOGD("[%s] schedule [%s], iteration[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||||
node_state.GetName().c_str(), GetName().c_str(), iteration_count_, node_state.iteration_count_, | |||||
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), | node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), | ||||
node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
if (loop_count_ != node_state.loop_count_) { | |||||
ResetContext(node_state.loop_count_); | |||||
} | |||||
ScheduleContext(node_state); | |||||
++data_scheduled_; | ++data_scheduled_; | ||||
if (node_item_->IsMergeOp()) { | if (node_item_->IsMergeOp()) { | ||||
@@ -394,15 +430,13 @@ void NodeState::SetDataSchedule(const NodeState &node_state, const std::function | |||||
} | } | ||||
void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | ||||
GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||||
node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_, | |||||
GELOGD("[%s] schedule [%s], iteration[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", | |||||
node_state.GetName().c_str(), GetName().c_str(), iteration_count_, node_state.iteration_count_, | |||||
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), | node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), | ||||
node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
if (loop_count_ != node_state.loop_count_) { | |||||
ResetContext(node_state.loop_count_); | |||||
} | |||||
ScheduleContext(node_state); | |||||
++ctrl_scheduled_; | ++ctrl_scheduled_; | ||||
if (IsScheduleReady()) { | if (IsScheduleReady()) { | ||||
@@ -410,21 +444,28 @@ void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function | |||||
} | } | ||||
} | } | ||||
void NodeState::RunLoopNext() { | |||||
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||||
void NodeState::RunNextIteration() { | |||||
std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
++loop_count_; | |||||
if (loop_count_ == UINT64_MAX) { | |||||
loop_count_ = 1; | |||||
} | |||||
ResetContext(loop_count_); | |||||
INC_ITERATION_COUNT(iteration_count_); | |||||
ResetContext(iteration_count_); | |||||
} | } | ||||
void NodeState::RunLoopExit() { | |||||
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||||
void NodeState::RunStreamActive() { | |||||
std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
loop_count_ = 0; | |||||
if (node_item_->ctrl_send_.empty()) { // Not for Loop Enter or Loop Next. | |||||
return; | |||||
} | |||||
switch_index_ = 0; | |||||
data_scheduled_ = 0; | |||||
ctrl_scheduled_ = 0; | |||||
if (node_item_->is_enter_active_) { | |||||
frame_state_->iteration_count_ = 0; | |||||
INC_ITERATION_COUNT(frame_state_->active_count_); | |||||
} else { | |||||
INC_ITERATION_COUNT(frame_state_->iteration_count_); | |||||
} | |||||
GELOGD("Node[%s] current iteration: %lu, frame active: %lu, frame iteration: %lu", | |||||
GetName().c_str(), iteration_count_, frame_state_->active_count_, frame_state_->iteration_count_); | |||||
} | } | ||||
void NodeState::SetScheduleFuture(std::future<Status> &&future) { | void NodeState::SetScheduleFuture(std::future<Status> &&future) { | ||||
@@ -33,8 +33,10 @@ struct GraphExecutionContext; | |||||
class SubgraphContext; | class SubgraphContext; | ||||
class TaskContext; | class TaskContext; | ||||
struct NodeState; | struct NodeState; | ||||
struct FrameState; | |||||
using NodeStatePtr = std::shared_ptr<NodeState>; | using NodeStatePtr = std::shared_ptr<NodeState>; | ||||
using FrameStatePtr = std::shared_ptr<FrameState>; | |||||
class ShapeFuture { | class ShapeFuture { | ||||
public: | public: | ||||
@@ -80,6 +82,18 @@ struct ShapeInferenceState { | |||||
std::mutex mu_; | std::mutex mu_; | ||||
}; | }; | ||||
// Counters of one control-flow frame, held through shared_ptr by the NodeStates of that frame.
struct FrameState {
  // explicit: a bare integer must never silently convert into a frame.
  explicit FrameState(int64_t id) : frame_id_(id) {}

  int64_t frame_id_{0};                       // Frame index this state was created for.
  uint64_t active_count_{0};                  // Advanced by the Enter StreamActive of the frame.
  uint64_t iteration_count_{0};               // Reset on Enter activation, advanced by other StreamActive.
  std::shared_ptr<FrameState> parent_frame_;  // Enclosing frame; left null unless assigned by the owner.
};
// saving sth. dynamic during execution | // saving sth. dynamic during execution | ||||
struct NodeState { | struct NodeState { | ||||
public: | public: | ||||
@@ -112,8 +126,8 @@ struct NodeState { | |||||
return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; | return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; | ||||
} | } | ||||
void RunLoopNext(); | |||||
void RunLoopExit(); | |||||
void RunStreamActive(); | |||||
void RunNextIteration(); | |||||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | ||||
@@ -144,6 +158,10 @@ struct NodeState { | |||||
return group_; | return group_; | ||||
} | } | ||||
void SetFrameState(const shared_ptr<FrameState> &frame_state) { | |||||
frame_state_ = frame_state; | |||||
} | |||||
const shared_ptr<NodeTask> &GetKernelTask() const { | const shared_ptr<NodeTask> &GetKernelTask() const { | ||||
return kernel_task_; | return kernel_task_; | ||||
} | } | ||||
@@ -167,7 +185,8 @@ struct NodeState { | |||||
bool IsScheduleReady() const; | bool IsScheduleReady() const; | ||||
void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | ||||
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | ||||
void ResetContext(uint64_t loop_count); | |||||
void ResetContext(uint64_t iteration); | |||||
void ScheduleContext(const NodeState &node_state); | |||||
const NodeItem *node_item_ = nullptr; | const NodeItem *node_item_ = nullptr; | ||||
std::shared_ptr<NodeTask> kernel_task_ = nullptr; | std::shared_ptr<NodeTask> kernel_task_ = nullptr; | ||||
@@ -179,7 +198,9 @@ struct NodeState { | |||||
std::mutex mu_; | std::mutex mu_; | ||||
std::future<Status> schedule_future_; | std::future<Status> schedule_future_; | ||||
uint64_t loop_count_ = 0; | |||||
std::shared_ptr<FrameState> frame_state_; | |||||
uint64_t active_count_ = 0; | |||||
uint64_t iteration_count_ = 0; | |||||
uint32_t ctrl_scheduled_ = 0; | uint32_t ctrl_scheduled_ = 0; | ||||
uint32_t data_scheduled_ = 0; | uint32_t data_scheduled_ = 0; | ||||
int merge_index_ = -1; // Use for Execute (Reset after Executed). | int merge_index_ = -1; // Use for Execute (Reset after Executed). | ||||
@@ -89,6 +89,7 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||||
if (node_state == nullptr) { | if (node_state == nullptr) { | ||||
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | ||||
node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | ||||
node_state->SetFrameState(GetOrCreateFrameState(*node_item)); | |||||
node_state->SetGroup(group_); | node_state->SetGroup(group_); | ||||
(void)guard; | (void)guard; | ||||
} | } | ||||
@@ -102,6 +103,18 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||||
return node_state; | return node_state; | ||||
} | } | ||||
///
/// @brief Get the FrameState of the frame the node item belongs to, creating it on first use.
///        Takes no lock itself (see the declaration).
/// @param [in] node_item, supplies frame_index_ and parent_frame_.
/// @return the frame state, or nullptr when allocation fails.
///
FrameStatePtr SubgraphContext::GetOrCreateFrameState(const NodeItem &node_item) {
  auto &frame_state = frame_states_[node_item.frame_index_];
  if (frame_state != nullptr) {
    return frame_state;
  }

  frame_state.reset(new (std::nothrow) FrameState(node_item.frame_index_));
  if (frame_state == nullptr) {
    // nothrow allocation failed: report it instead of dereferencing a null pointer below.
    GELOGE(INTERNAL_ERROR, "Failed to create frame state, frame index: %ld", node_item.frame_index_);
    return nullptr;
  }

  if (node_item.frame_index_ != -1) {  // -1 is root frame.
    // operator[] may insert an empty slot; element references of unordered_map stay valid across that.
    // NOTE(review): the parent stays null if its frame state has not been created yet - confirm creation order.
    frame_state->parent_frame_ = frame_states_[node_item.parent_frame_];
  }
  return frame_state;
}
Status SubgraphContext::SetInput(int index, const TensorValue &tensor) { | Status SubgraphContext::SetInput(int index, const TensorValue &tensor) { | ||||
if (static_cast<size_t>(index) >= all_inputs_.size()) { | if (static_cast<size_t>(index) >= all_inputs_.size()) { | ||||
GELOGE(INTERNAL_ERROR, | GELOGE(INTERNAL_ERROR, | ||||
@@ -51,6 +51,7 @@ class SubgraphContext { | |||||
void NodeDone(const NodePtr &node); | void NodeDone(const NodePtr &node); | ||||
private: | private: | ||||
FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock | |||||
friend class TaskContext; | friend class TaskContext; | ||||
const GraphItem *graph_item_; | const GraphItem *graph_item_; | ||||
const GraphExecutionContext *execution_context_; | const GraphExecutionContext *execution_context_; | ||||
@@ -59,6 +60,7 @@ class SubgraphContext { | |||||
std::vector<TensorValue> all_outputs_; | std::vector<TensorValue> all_outputs_; | ||||
NodeDoneManager node_done_manager_; | NodeDoneManager node_done_manager_; | ||||
std::unordered_map<const NodeItem *, NodeStatePtr> node_states_; | std::unordered_map<const NodeItem *, NodeStatePtr> node_states_; | ||||
std::unordered_map<int64_t, FrameStatePtr> frame_states_; | |||||
int group_ = -1; | int group_ = -1; | ||||
}; | }; | ||||
} // namespace hybrid | } // namespace hybrid | ||||
@@ -1984,6 +1984,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root | |||||
GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | ||||
GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | ||||
GE_CHK_STATUS_RET_NOLOG(BuildFrameGroupIndex(*node_item)); | |||||
GE_CHK_STATUS_RET_NOLOG(BuildControlFlowGroup(*graph_item, node, node_item)); | GE_CHK_STATUS_RET_NOLOG(BuildControlFlowGroup(*graph_item, node, node_item)); | ||||
if (node->GetInAllNodes().empty()) { | if (node->GetInAllNodes().empty()) { | ||||
graph_item->root_items_.emplace_back(node_item); | graph_item->root_items_.emplace_back(node_item); | ||||
@@ -2347,6 +2348,60 @@ Status HybridModelBuilder::BuildProfilingControl(GraphItem &graph_item, | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
///
/// @brief Assign the control-flow frame index and parent frame index of a node item.
///        - Root nodes keep their current values.
///        - Enter nodes carrying ATTR_NAME_CONTROL_FLOW_GROUP open the frame named by that attribute;
///          the frame of their first data input (when not a root node) is recorded as its parent.
///        - Any other node inherits from its first non-root input: the parent of that input's frame
///          when the input is an Exit op, otherwise the input's own frame.
/// @param [in,out] node_item
/// @return SUCCESS, or the status of GetOrCreateNodeItem on failure.
///
Status HybridModelBuilder::BuildFrameGroupIndex(NodeItem &node_item) {
  // Parent of a frame as recorded from Enter nodes so far; -1 (root frame) when unknown.
  const auto find_parent_frame = [this](int64_t frame_index) -> int64_t {
    const auto it = parent_frame_group_.find(frame_index);
    return (it != parent_frame_group_.end()) ? it->second : -1;
  };

  if (node_item.is_root_node_) {
    GELOGD("[%s] control flow frame group: %ld, parent frame: %ld",
           node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_);
    return SUCCESS;
  }

  int64_t ctrl_flow_group = -1;
  if (node_item.IsEnterOp() && AttrUtils::GetInt(node_item.op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, ctrl_flow_group)) {
    node_item.frame_index_ = ctrl_flow_group;
    const auto in_data_nodes = node_item.node->GetInDataNodes();
    // Guard the lookup: an Enter without data input must not trip at(0).
    if (!in_data_nodes.empty()) {
      const auto src_node = in_data_nodes.at(0);
      NodeItem *src_node_item = nullptr;
      GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item),
                        "[%s] failed to get or create node item", src_node->GetName().c_str());
      if (!src_node_item->is_root_node_) {
        parent_frame_group_[node_item.frame_index_] = src_node_item->frame_index_;
      }
    }

    node_item.parent_frame_ = find_parent_frame(node_item.frame_index_);
    GELOGD("[%s] control flow frame group: %ld, parent frame: %ld",
           node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_);
    return SUCCESS;
  }

  for (const auto &src_node : node_item.node->GetInAllNodes()) {
    NodeItem *src_node_item = nullptr;
    GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item),
                      "[%s] failed to get or create node item", src_node->GetName().c_str());
    if (src_node_item->is_root_node_) {
      continue;
    }

    // An Exit hands its value to the enclosing frame, so step one level up.
    node_item.frame_index_ = src_node_item->IsExitOp() ? find_parent_frame(src_node_item->frame_index_)
                                                       : src_node_item->frame_index_;
    // Resolve the parent from the frame just assigned; using the source's frame here would
    // make a node behind an Exit its own parent frame.
    node_item.parent_frame_ = find_parent_frame(node_item.frame_index_);
    GELOGD("[%s] control flow frame group: %ld, parent frame: %ld",
           node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_);
    return SUCCESS;
  }

  // Every input is a root node (or there is none): keep the current values.
  GELOGD("[%s] control flow frame group: %ld, parent frame: %ld",
         node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_);
  return SUCCESS;
}
Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item) { | Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item) { | ||||
GELOGD("Build control flow for node %s", node->GetName().c_str()); | GELOGD("Build control flow for node %s", node->GetName().c_str()); | ||||
using GroupBuilder = std::function<Status(HybridModelBuilder *, const NodePtr &, NodeItem *)>; | using GroupBuilder = std::function<Status(HybridModelBuilder *, const NodePtr &, NodeItem *)>; | ||||
@@ -2466,6 +2521,7 @@ Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem | |||||
if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { | if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { | ||||
// Enter --> StreamActive --> StreamMerge | // Enter --> StreamActive --> StreamMerge | ||||
node_item->is_enter_active_ = true; | |||||
return CreateMergeEnterGroup(node, node_item); | return CreateMergeEnterGroup(node, node_item); | ||||
} else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { | } else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { | ||||
// NextIteration --> StreamActive {-->} StreamMerge | // NextIteration --> StreamActive {-->} StreamMerge | ||||
@@ -97,6 +97,7 @@ class HybridModelBuilder { | |||||
Status RelinkNextIteration(); | Status RelinkNextIteration(); | ||||
Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); | Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); | ||||
Status BuildFrameGroupIndex(NodeItem &node_item); | |||||
Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); | Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); | ||||
Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); | Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); | ||||
Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item); | Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item); | ||||
@@ -123,6 +124,7 @@ class HybridModelBuilder { | |||||
std::map<std::string, NodePtr> constant_op_nodes_; | std::map<std::string, NodePtr> constant_op_nodes_; | ||||
std::map<std::string, NodePtr> stream_merge_op_nodes_; | std::map<std::string, NodePtr> stream_merge_op_nodes_; | ||||
std::map<std::string, NodePtr> next_iteration_op_nodes_; | std::map<std::string, NodePtr> next_iteration_op_nodes_; | ||||
std::map<int64_t, int64_t> parent_frame_group_; | |||||
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | ||||
std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_; | std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_; | ||||
@@ -20,7 +20,6 @@ | |||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/utils/node_utils.h" | |||||
#include "hybrid/executor/worker/shape_inference_engine.h" | #include "hybrid/executor/worker/shape_inference_engine.h" | ||||
#include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
@@ -34,7 +33,7 @@ const std::set<std::string> kControlOpTypes{ | |||||
}; | }; | ||||
const std::set<std::string> kControlFlowOpTypes{ | const std::set<std::string> kControlFlowOpTypes{ | ||||
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, | |||||
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, | |||||
LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX | LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX | ||||
}; | }; | ||||
@@ -402,8 +401,8 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
node_item->root_data_.emplace(this); | node_item->root_data_.emplace(this); | ||||
} | } | ||||
// If an Enter node feeds data to a node that is not a StreamMerge, record it as an Enter data input of that node.
if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) { | |||||
node_item->root_data_.emplace(this); | |||||
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | |||||
node_item->enter_data_.emplace(this); | |||||
node_item->enter_inside_.emplace(anchor_index); | node_item->enter_inside_.emplace(anchor_index); | ||||
} | } | ||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
@@ -422,8 +421,8 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | |||||
node_item->root_ctrl_.emplace(this); | node_item->root_ctrl_.emplace(this); | ||||
} | } | ||||
// If an Enter node feeds a control signal to a node that is neither StreamMerge nor StreamActive, record it as an Enter ctrl input of that node.
if (kEnterOpTypes.count(node_type) > 0) { | |||||
node_item->root_ctrl_.emplace(this); | |||||
if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { | |||||
node_item->enter_ctrl_.emplace(this); | |||||
} | } | ||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
} | } | ||||
@@ -22,6 +22,7 @@ | |||||
#include "external/ge/ge_api_error_codes.h" | #include "external/ge/ge_api_error_codes.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "graph/utils/node_utils.h" | |||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "hybrid/common/tensor_value.h" | #include "hybrid/common/tensor_value.h" | ||||
@@ -92,6 +93,14 @@ struct NodeItem { | |||||
return is_merge_op_; | return is_merge_op_; | ||||
} | } | ||||
// Returns true when this item's node_type is a member of kEnterOpTypes. | |||||
bool IsEnterOp() const { | |||||
return kEnterOpTypes.count(node_type) > 0; | |||||
} | |||||
// Returns true when this item's node_type is a member of kExitOpTypes. | |||||
bool IsExitOp() const { | |||||
return kExitOpTypes.count(node_type) > 0; | |||||
} | |||||
bool IsHcclOp() const; | bool IsHcclOp() const; | ||||
void SetToDynamic(); | void SetToDynamic(); | ||||
@@ -135,8 +144,13 @@ struct NodeItem { | |||||
bool is_ctrl_flow_v2_op_ = false; | bool is_ctrl_flow_v2_op_ = false; | ||||
bool is_ctrl_flow_op_ = false; | bool is_ctrl_flow_op_ = false; | ||||
bool is_merge_op_ = false; | bool is_merge_op_ = false; | ||||
bool is_enter_active_ = false; | |||||
int64_t frame_index_ = -1; | |||||
int64_t parent_frame_ = -1; | |||||
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | ||||
std::set<const NodeItem *> root_data_; // Recv data from root node | std::set<const NodeItem *> root_data_; // Recv data from root node | ||||
std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | |||||
std::set<const NodeItem *> enter_data_; // Recv data from Enter node | |||||
std::set<const NodeItem *> data_send_; // Send data notify to | std::set<const NodeItem *> data_send_; // Send data notify to | ||||
std::map<const NodeItem *, int> data_recv_; // Recv data notify from | std::map<const NodeItem *, int> data_recv_; // Recv data notify from | ||||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | ||||
@@ -90,7 +90,7 @@ Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t inde | |||||
Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | ||||
GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | ||||
const auto &node_state = task_context.GetNodeState(); | const auto &node_state = task_context.GetNodeState(); | ||||
node_state->SetSwitchIndex(0); | |||||
node_state->RunStreamActive(); | |||||
if (done_callback) { | if (done_callback) { | ||||
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | ||||
} | } | ||||
@@ -204,9 +204,7 @@ Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::functio | |||||
const auto &node_state = task_context.GetNodeState(); | const auto &node_state = task_context.GetNodeState(); | ||||
if (kNextIterationOpTypes.count(node_state->GetType()) > 0) { | if (kNextIterationOpTypes.count(node_state->GetType()) > 0) { | ||||
node_state->RunLoopNext(); | |||||
} else if (kExitOpTypes.count(node_state->GetType()) > 0) { | |||||
node_state->RunLoopExit(); | |||||
node_state->RunNextIteration(); | |||||
} | } | ||||
if (done_callback) { | if (done_callback) { | ||||
@@ -127,7 +127,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | ||||
} | } | ||||
const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); // Mock for less, just pass input0. | |||||
const auto less1 = CreateNode(graph, "less", IDENTITY, 2, 1); // Mock for less, just pass input0. | |||||
const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | ||||
switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | ||||
@@ -135,8 +135,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. | AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. | ||||
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); | AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); | ||||
const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); // Mock for add, just pass input0. | |||||
const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); // Mock for sub, just pass input0. | |||||
const auto add1 = CreateNode(graph, "add", IDENTITY, 2, 1); // Mock for add, just pass input0. | |||||
const auto sub1 = CreateNode(graph, "sub", IDENTITY, 2, 1); // Mock for sub, just pass input0. | |||||
const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); | const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); | ||||
const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); | const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); | ||||
@@ -89,7 +89,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
* \ / \. | * \ / \. | ||||
* Switch Add | * Switch Add | ||||
* / | | | * / | | | ||||
* / | | | |||||
* Active / | | | |||||
* / | | | * / | | | ||||
* LoopCond | | | * LoopCond | | | ||||
* \ | | | * \ | | | ||||
@@ -98,9 +98,10 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
* Less | | | * Less | | | ||||
* \ | NextIteration | * \ | NextIteration | ||||
* \ | | | * \ | | | ||||
* \ | | | |||||
* \ | | Active | |||||
* Merge <---------| | * Merge <---------| | ||||
* | | * | | ||||
* | Active | |||||
* | | * | | ||||
* Enter | * Enter | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
@@ -110,6 +111,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
GeModelPtr ge_sub_model = make_shared<GeModel>(); | GeModelPtr ge_sub_model = make_shared<GeModel>(); | ||||
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | ||||
auto data1 = CreateNode(*graph, "data", DATA, 1, 1); | |||||
auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | ||||
auto merge1 = CreateNode(*graph, "merge", STREAMMERGE, 2, 2); | auto merge1 = CreateNode(*graph, "merge", STREAMMERGE, 2, 2); | ||||
auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | ||||
@@ -129,6 +131,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); | auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); | ||||
auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | ||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | ||||
@@ -153,8 +156,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); | GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | ||||
//GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||||
SetNextIteration(merge1, next1); | |||||
SetNextIteration(merge1, next1); // for relink NextIteration --> StreamMerge | |||||
GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); // Test for not merge. | GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); // Test for not merge. | ||||
@@ -169,6 +171,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | ||||
AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | ||||
SetControlFlowGroup(enter1, loop1->GetOpDesc()->GetId()); | |||||
SetControlFlowGroup(active1, loop1->GetOpDesc()->GetId()); | |||||
SetControlFlowGroup(merge1, loop1->GetOpDesc()->GetId()); | |||||
SetControlFlowGroup(loop1, loop1->GetOpDesc()->GetId()); | |||||
SetControlFlowGroup(active2, switch_t->GetOpDesc()->GetId()); | |||||
SetControlFlowGroup(switch_t, switch_t->GetOpDesc()->GetId()); | |||||
SetControlFlowGroup(switch_f, switch_t->GetOpDesc()->GetId()); | |||||
SetControlFlowGroup(next1, loop1->GetOpDesc()->GetId()); | |||||
SetControlFlowGroup(active3, loop1->GetOpDesc()->GetId()); | |||||
SetControlFlowGroup(exit1, loop1->GetOpDesc()->GetId()); | |||||
// Build -> IndexSpecialNodes --> stream_merge_op_nodes_ | // Build -> IndexSpecialNodes --> stream_merge_op_nodes_ | ||||
// Build -> LoadGraph -> RelinkNextIteration | // Build -> LoadGraph -> RelinkNextIteration | ||||
// Build -> LoadGraph -> LoadDynamicSubgraph --> BuildNodeItem --> NodeItem::SetDataSend | // Build -> LoadGraph -> LoadDynamicSubgraph --> BuildNodeItem --> NodeItem::SetDataSend | ||||
@@ -190,9 +203,23 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | ||||
task_executor.emplace(NodeExecutorManager::ExecutorType::HOST_CPU, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | task_executor.emplace(NodeExecutorManager::ExecutorType::HOST_CPU, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | ||||
const auto control_group_index = loop1->GetOpDesc()->GetId(); | |||||
HybridModel hybrid_model(ge_root_model); | HybridModel hybrid_model(ge_root_model); | ||||
HybridModelBuilder hybrid_model_builder(hybrid_model); | HybridModelBuilder hybrid_model_builder(hybrid_model); | ||||
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); | ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); | ||||
// Test helper: looks up node `n` in hybrid_model.node_items_, requires that it is present, | |||||
// and checks that its frame_index_ equals `index` and its parent_frame_ equals -1. | |||||
// NOTE(review): a fatal ASSERT_* inside this lambda returns from the lambda only, not from | |||||
// the enclosing test body, so later statements of the test still run after a failure here. | |||||
const auto TestFrameGroup = [&hybrid_model](const NodePtr &n, int64_t index) { | |||||
const auto it = hybrid_model.node_items_.find(n); | |||||
ASSERT_NE(hybrid_model.node_items_.end(), it); | |||||
ASSERT_EQ(it->second->frame_index_, index); | |||||
ASSERT_EQ(it->second->parent_frame_, -1); | |||||
}; | |||||
TestFrameGroup(enter1, control_group_index); | |||||
TestFrameGroup(active1, control_group_index); | |||||
TestFrameGroup(active2, control_group_index); | |||||
TestFrameGroup(active3, control_group_index); | |||||
TestFrameGroup(output1, -1); | |||||
engine_mapping.clear(); | engine_mapping.clear(); | ||||
task_executor.clear(); | task_executor.clear(); | ||||
} | } | ||||
@@ -166,6 +166,10 @@ TEST_F(UtestRtsNodeTask, test_stream_active_task) { | |||||
std::function<void()> done = []() {}; | std::function<void()> done = []() {}; | ||||
ASSERT_EQ(node_state->GetSwitchIndex(), -1); | ASSERT_EQ(node_state->GetSwitchIndex(), -1); | ||||
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | ||||
ASSERT_EQ(node_state->GetSwitchIndex(), -1); | |||||
node_item->ctrl_send_.emplace(nullptr); | |||||
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
ASSERT_EQ(node_state->GetSwitchIndex(), 0); | ASSERT_EQ(node_state->GetSwitchIndex(), 0); | ||||
} | } | ||||