Browse Source

Fix loop nesting

tags/v1.3.0
zhangxiaokun 3 years ago
parent
commit
04f27bc560
16 changed files with 260 additions and 76 deletions
  1. +12
    -2
      ge/graph/common/omg_util.cc
  2. +8
    -0
      ge/graph/common/omg_util.h
  3. +8
    -18
      ge/graph/passes/next_iteration_pass.cc
  4. +0
    -1
      ge/graph/passes/next_iteration_pass.h
  5. +75
    -34
      ge/hybrid/executor/node_state.cc
  6. +25
    -4
      ge/hybrid/executor/node_state.h
  7. +13
    -0
      ge/hybrid/executor/subgraph_context.cc
  8. +2
    -0
      ge/hybrid/executor/subgraph_context.h
  9. +56
    -0
      ge/hybrid/model/hybrid_model_builder.cc
  10. +2
    -0
      ge/hybrid/model/hybrid_model_builder.h
  11. +5
    -6
      ge/hybrid/model/node_item.cc
  12. +14
    -0
      ge/hybrid/model/node_item.h
  13. +2
    -4
      ge/hybrid/node_executor/rts/rts_node_task.cc
  14. +3
    -3
      tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc
  15. +31
    -4
      tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc
  16. +4
    -0
      tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc

+ 12
- 2
ge/graph/common/omg_util.cc View File

@@ -286,13 +286,23 @@ void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t grou
return; return;
} }


SetControlFlowGroup(node, group_index);
}

///
/// @brief Set Op _control_flow_group flag
/// @param [in] node
/// @param [in] group, condition group index of node.
/// @return
///
void SetControlFlowGroup(const NodePtr &node, int64_t group) {
GE_RT_VOID_CHECK_NOTNULL(node); GE_RT_VOID_CHECK_NOTNULL(node);
const auto &op_desc = node->GetOpDesc(); const auto &op_desc = node->GetOpDesc();
GE_RT_VOID_CHECK_NOTNULL(op_desc); GE_RT_VOID_CHECK_NOTNULL(op_desc);


// op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check. // op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check.
GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group_index);
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group);
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group)) {
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(),
node->GetName().c_str(), node->GetType().c_str()); node->GetName().c_str(), node->GetType().c_str());
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(),


+ 8
- 0
ge/graph/common/omg_util.h View File

@@ -133,6 +133,14 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc);
/// @return /// @return
/// ///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index);

///
/// @brief Set Op _control_flow_group flag
/// @param [in] node
/// @param [in] group, condition group index of node.
/// @return
///
void SetControlFlowGroup(const NodePtr &node, int64_t group);
} // namespace ge } // namespace ge


#endif // GE_GRAPH_COMMON_OMG_UTIL_H_ #endif // GE_GRAPH_COMMON_OMG_UTIL_H_

+ 8
- 18
ge/graph/passes/next_iteration_pass.cc View File

@@ -183,12 +183,6 @@ bool NextIterationPass::VerifyWhileGroup() {
frame_name.c_str()); frame_name.c_str());
return false; return false;
} }

// Mark loop as unknown shape If any merge has unknown shape output.
const auto &op_desc = pair_iter.first->GetOpDesc();
if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) {
loop_group_iter.second->is_unknown_shape = true; // under check loop, cannot break.
}
} }
} }


@@ -225,7 +219,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
enter_active->GetName().c_str()); enter_active->GetName().c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape, group_index);
SetControlFlowGroup(enter_node, group_index);
} }


for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { for (const auto &pair : loop_cond_iter.second->merge_next_pairs) {
@@ -255,8 +249,8 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


MarkForceUnknownShape(next_node, loop_group.is_unknown_shape, group_index);
MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape, group_index);
SetControlFlowGroup(next_node, group_index);
SetControlFlowGroup(merge_node, group_index);
} }


if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
@@ -265,9 +259,9 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape, group_index);
MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape, group_index);
MarkForceUnknownShape(next_active, loop_group.is_unknown_shape, group_index);
SetControlFlowGroup(loop_group.loop_cond, group_index);
SetControlFlowGroup(enter_active, group_index);
SetControlFlowGroup(next_active, group_index);
HandleSwitchExitNodes(loop_group, group_index); HandleSwitchExitNodes(loop_group, group_index);
} }


@@ -281,17 +275,13 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
/// @return void /// @return void
/// ///
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) {
if (!loop_group.is_unknown_shape) {
return;
}

for (const auto &switch_node : loop_group.switch_nodes) { for (const auto &switch_node : loop_group.switch_nodes) {
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape, group_index);
SetControlFlowGroup(switch_node, group_index);
for (const auto &node : switch_node->GetOutDataNodes()) { for (const auto &node : switch_node->GetOutDataNodes()) {
std::string node_type; std::string node_type;
(void)GetOriginalType(node, node_type); (void)GetOriginalType(node, node_type);
if (kExitOpTypes.count(node_type) > 0) { if (kExitOpTypes.count(node_type) > 0) {
MarkForceUnknownShape(node, loop_group.is_unknown_shape, group_index);
SetControlFlowGroup(node, group_index);
} }
} }
} }


+ 0
- 1
ge/graph/passes/next_iteration_pass.h View File

@@ -24,7 +24,6 @@ struct LoopCondGroup {
std::vector<ge::NodePtr> enter_nodes; // Enter nodes std::vector<ge::NodePtr> enter_nodes; // Enter nodes
std::vector<std::pair<ge::NodePtr, ge::NodePtr>> merge_next_pairs; // <Merge, NextIteration> std::vector<std::pair<ge::NodePtr, ge::NodePtr>> merge_next_pairs; // <Merge, NextIteration>
std::vector<ge::NodePtr> switch_nodes; // Switch nodes std::vector<ge::NodePtr> switch_nodes; // Switch nodes
bool is_unknown_shape{false};
}; };
using LoopCondGroupPtr = std::shared_ptr<LoopCondGroup>; using LoopCondGroupPtr = std::shared_ptr<LoopCondGroup>;




+ 75
- 34
ge/hybrid/executor/node_state.cc View File

@@ -22,6 +22,14 @@
#include "hybrid_execution_context.h" #include "hybrid_execution_context.h"
#include "subgraph_context.h" #include "subgraph_context.h"


#define INC_ITERATION_COUNT(iteration) \
do { \
++iteration; \
if (iteration == UINT64_MAX) { \
iteration = 1; \
} \
} while (0)

namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
namespace { namespace {
@@ -306,15 +314,45 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
return task_context_; return task_context_;
} }


void NodeState::ResetContext(uint64_t loop_count) {
loop_count_ = loop_count;

void NodeState::ResetContext(uint64_t iteration) {
switch_index_ = -1; switch_index_ = -1;
subgraph_context_->ResetContext(node_item_->node); subgraph_context_->ResetContext(node_item_->node);
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
GELOGD("[%s] in while loop, loop count: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d",
GetName().c_str(), loop_count_, data_scheduled_, ctrl_scheduled_, merge_index_);
if (iteration == 0) {
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
} else {
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size());
}

iteration_count_ = iteration;
GELOGD("[%s] in while loop, current iteration: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d",
GetName().c_str(), iteration_count_, data_scheduled_, ctrl_scheduled_, merge_index_);
}

void NodeState::ScheduleContext(const NodeState &node_state) {
if (node_state.node_item_->IsEnterOp()) {
GELOGD("[%s]{active: %lu, iteration: %lu}, frame{active: %lu, iteration: %lu} [%s]{active: %lu, iteration: %lu}",
GetName().c_str(), active_count_, iteration_count_, frame_state_->active_count_,
frame_state_->iteration_count_, node_state.GetName().c_str(), node_state.frame_state_->active_count_,
node_state.frame_state_->iteration_count_);
if (frame_state_->active_count_ != active_count_) {
ResetContext(0);
active_count_ = frame_state_->active_count_;
}
} else if (node_state.node_item_->IsExitOp()) {
GELOGD("[%s]{active: %lu, iteration: %lu} frame{active: %lu, iteration: %lu} "
"[%s]{active: %lu, iteration: %lu} parent{active: %lu, iteration: %lu}",
GetName().c_str(), active_count_, iteration_count_, frame_state_->active_count_,
frame_state_->iteration_count_, node_state.GetName().c_str(), node_state.frame_state_->active_count_,
node_state.frame_state_->iteration_count_, node_state.frame_state_->parent_frame_->active_count_,
node_state.frame_state_->parent_frame_->iteration_count_);
if (node_state.frame_state_->parent_frame_->iteration_count_ != iteration_count_) {
ResetContext(node_state.frame_state_->parent_frame_->iteration_count_);
}
} else if (node_state.iteration_count_ != iteration_count_) {
ResetContext(node_state.iteration_count_);
}
} }


Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const {
@@ -346,11 +384,11 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea
} }


bool NodeState::IsScheduleReady() const { bool NodeState::IsScheduleReady() const {
GELOGD("[%s] loop[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]",
GetName().c_str(), loop_count_, node_item_->data_recv_.size(), data_scheduled_,
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_);
GELOGD("[%s] iteration[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]",
GetName().c_str(), iteration_count_, node_item_->data_recv_.size(), data_scheduled_,
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_);
if (node_item_->IsMergeOp()) { if (node_item_->IsMergeOp()) {
if (ctrl_scheduled_ != node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) {
if (ctrl_scheduled_ != node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) {
return false; return false;
} }


@@ -366,15 +404,13 @@ bool NodeState::IsScheduleReady() const {
} }


void NodeState::SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { void NodeState::SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) {
GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]",
node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_,
GELOGD("[%s] schedule [%s], iteration[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]",
node_state.GetName().c_str(), GetName().c_str(), iteration_count_, node_state.iteration_count_,
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(),
node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_);
node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_);


std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
if (loop_count_ != node_state.loop_count_) {
ResetContext(node_state.loop_count_);
}
ScheduleContext(node_state);
++data_scheduled_; ++data_scheduled_;


if (node_item_->IsMergeOp()) { if (node_item_->IsMergeOp()) {
@@ -394,15 +430,13 @@ void NodeState::SetDataSchedule(const NodeState &node_state, const std::function
} }


void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) {
GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]",
node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_,
GELOGD("[%s] schedule [%s], iteration[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]",
node_state.GetName().c_str(), GetName().c_str(), iteration_count_, node_state.iteration_count_,
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(),
node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_);
node_item_->GetMergeCtrl(iteration_count_ == 0 ? 0 : 1), ctrl_scheduled_);


std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
if (loop_count_ != node_state.loop_count_) {
ResetContext(node_state.loop_count_);
}
ScheduleContext(node_state);
++ctrl_scheduled_; ++ctrl_scheduled_;


if (IsScheduleReady()) { if (IsScheduleReady()) {
@@ -410,21 +444,28 @@ void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function
} }
} }


void NodeState::RunLoopNext() {
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_);
void NodeState::RunNextIteration() {
std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
++loop_count_;
if (loop_count_ == UINT64_MAX) {
loop_count_ = 1;
}

ResetContext(loop_count_);
INC_ITERATION_COUNT(iteration_count_);
ResetContext(iteration_count_);
} }


void NodeState::RunLoopExit() {
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_);
void NodeState::RunStreamActive() {
std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
loop_count_ = 0;
if (node_item_->ctrl_send_.empty()) { // Not for Loop Enter or Loop Next.
return;
}
switch_index_ = 0;
data_scheduled_ = 0;
ctrl_scheduled_ = 0;
if (node_item_->is_enter_active_) {
frame_state_->iteration_count_ = 0;
INC_ITERATION_COUNT(frame_state_->active_count_);
} else {
INC_ITERATION_COUNT(frame_state_->iteration_count_);
}
GELOGD("Node[%s] current iteration: %lu, frame active: %lu, frame iteration: %lu",
GetName().c_str(), iteration_count_, frame_state_->active_count_, frame_state_->iteration_count_);
} }


void NodeState::SetScheduleFuture(std::future<Status> &&future) { void NodeState::SetScheduleFuture(std::future<Status> &&future) {


+ 25
- 4
ge/hybrid/executor/node_state.h View File

@@ -33,8 +33,10 @@ struct GraphExecutionContext;
class SubgraphContext; class SubgraphContext;
class TaskContext; class TaskContext;
struct NodeState; struct NodeState;
struct FrameState;


using NodeStatePtr = std::shared_ptr<NodeState>; using NodeStatePtr = std::shared_ptr<NodeState>;
using FrameStatePtr = std::shared_ptr<FrameState>;


class ShapeFuture { class ShapeFuture {
public: public:
@@ -80,6 +82,18 @@ struct ShapeInferenceState {
std::mutex mu_; std::mutex mu_;
}; };


struct FrameState {
public:
FrameState(int64_t id) : frame_id_(id) {}
~FrameState() = default;

int64_t frame_id_{0};
uint64_t active_count_{0};
uint64_t iteration_count_{0};

std::shared_ptr<FrameState> parent_frame_;
};

// saving sth. dynamic during execution // saving sth. dynamic during execution
struct NodeState { struct NodeState {
public: public:
@@ -112,8 +126,8 @@ struct NodeState {
return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE;
} }


void RunLoopNext();
void RunLoopExit();
void RunStreamActive();
void RunNextIteration();


Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const;


@@ -144,6 +158,10 @@ struct NodeState {
return group_; return group_;
} }


void SetFrameState(const shared_ptr<FrameState> &frame_state) {
frame_state_ = frame_state;
}

const shared_ptr<NodeTask> &GetKernelTask() const { const shared_ptr<NodeTask> &GetKernelTask() const {
return kernel_task_; return kernel_task_;
} }
@@ -167,7 +185,8 @@ struct NodeState {
bool IsScheduleReady() const; bool IsScheduleReady() const;
void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
void ResetContext(uint64_t loop_count);
void ResetContext(uint64_t iteration);
void ScheduleContext(const NodeState &node_state);


const NodeItem *node_item_ = nullptr; const NodeItem *node_item_ = nullptr;
std::shared_ptr<NodeTask> kernel_task_ = nullptr; std::shared_ptr<NodeTask> kernel_task_ = nullptr;
@@ -179,7 +198,9 @@ struct NodeState {
std::mutex mu_; std::mutex mu_;


std::future<Status> schedule_future_; std::future<Status> schedule_future_;
uint64_t loop_count_ = 0;
std::shared_ptr<FrameState> frame_state_;
uint64_t active_count_ = 0;
uint64_t iteration_count_ = 0;
uint32_t ctrl_scheduled_ = 0; uint32_t ctrl_scheduled_ = 0;
uint32_t data_scheduled_ = 0; uint32_t data_scheduled_ = 0;
int merge_index_ = -1; // Use for Execute (Reset after Executed). int merge_index_ = -1; // Use for Execute (Reset after Executed).


+ 13
- 0
ge/hybrid/executor/subgraph_context.cc View File

@@ -89,6 +89,7 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
if (node_state == nullptr) { if (node_state == nullptr) {
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
node_state.reset(new(std::nothrow)NodeState(*node_item, this)); node_state.reset(new(std::nothrow)NodeState(*node_item, this));
node_state->SetFrameState(GetOrCreateFrameState(*node_item));
node_state->SetGroup(group_); node_state->SetGroup(group_);
(void)guard; (void)guard;
} }
@@ -102,6 +103,18 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
return node_state; return node_state;
} }


FrameStatePtr SubgraphContext::GetOrCreateFrameState(const NodeItem &node_item) {
auto &frame_state = frame_states_[node_item.frame_index_];
if (frame_state == nullptr) {
frame_state.reset(new(std::nothrow)FrameState(node_item.frame_index_));
if (node_item.frame_index_ != -1) { // -1 is root frame.
frame_state->parent_frame_ = frame_states_[node_item.parent_frame_];
}
}

return frame_state;
}

Status SubgraphContext::SetInput(int index, const TensorValue &tensor) { Status SubgraphContext::SetInput(int index, const TensorValue &tensor) {
if (static_cast<size_t>(index) >= all_inputs_.size()) { if (static_cast<size_t>(index) >= all_inputs_.size()) {
GELOGE(INTERNAL_ERROR, GELOGE(INTERNAL_ERROR,


+ 2
- 0
ge/hybrid/executor/subgraph_context.h View File

@@ -51,6 +51,7 @@ class SubgraphContext {
void NodeDone(const NodePtr &node); void NodeDone(const NodePtr &node);


private: private:
FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock
friend class TaskContext; friend class TaskContext;
const GraphItem *graph_item_; const GraphItem *graph_item_;
const GraphExecutionContext *execution_context_; const GraphExecutionContext *execution_context_;
@@ -59,6 +60,7 @@ class SubgraphContext {
std::vector<TensorValue> all_outputs_; std::vector<TensorValue> all_outputs_;
NodeDoneManager node_done_manager_; NodeDoneManager node_done_manager_;
std::unordered_map<const NodeItem *, NodeStatePtr> node_states_; std::unordered_map<const NodeItem *, NodeStatePtr> node_states_;
std::unordered_map<int64_t, FrameStatePtr> frame_states_;
int group_ = -1; int group_ = -1;
}; };
} // namespace hybrid } // namespace hybrid


+ 56
- 0
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -1984,6 +1984,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root
GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item));
GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task


GE_CHK_STATUS_RET_NOLOG(BuildFrameGroupIndex(*node_item));
GE_CHK_STATUS_RET_NOLOG(BuildControlFlowGroup(*graph_item, node, node_item)); GE_CHK_STATUS_RET_NOLOG(BuildControlFlowGroup(*graph_item, node, node_item));
if (node->GetInAllNodes().empty()) { if (node->GetInAllNodes().empty()) {
graph_item->root_items_.emplace_back(node_item); graph_item->root_items_.emplace_back(node_item);
@@ -2347,6 +2348,60 @@ Status HybridModelBuilder::BuildProfilingControl(GraphItem &graph_item,
return SUCCESS; return SUCCESS;
} }


Status HybridModelBuilder::BuildFrameGroupIndex(NodeItem &node_item) {
if (node_item.is_root_node_) {
GELOGD("[%s] control flow frame group: %ld, parent frame: %ld",
node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_);
return SUCCESS;
}

int64_t ctrl_flow_group = -1;
if (node_item.IsEnterOp() && AttrUtils::GetInt(node_item.op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, ctrl_flow_group)) {
node_item.frame_index_ = ctrl_flow_group;
if (node_item.IsEnterOp()) {
const auto src_node = node_item.node->GetInDataNodes().at(0);
NodeItem *src_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item),
"[%s] failed to get or create node item", src_node->GetName().c_str());
if (!src_node_item->is_root_node_) {
parent_frame_group_[node_item.frame_index_] = src_node_item->frame_index_;
}
}

const auto it = parent_frame_group_.find(node_item.frame_index_);
node_item.parent_frame_ = (it != parent_frame_group_.end()) ? it->second : -1;
GELOGD("[%s] control flow frame group: %ld, parent frame: %ld",
node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_);
return SUCCESS;
}

for (const auto src_node : node_item.node->GetInAllNodes()) {
NodeItem *src_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item),
"[%s] failed to get or create node item", src_node->GetName().c_str());
if (src_node_item->is_root_node_) {
continue;
}

if (src_node_item->IsExitOp()) {
const auto it = parent_frame_group_.find(src_node_item->frame_index_);
node_item.frame_index_ = (it != parent_frame_group_.end()) ? it->second : -1;
} else {
node_item.frame_index_ = src_node_item->frame_index_;
}

const auto it = parent_frame_group_.find(src_node_item->frame_index_);
node_item.parent_frame_ = (it != parent_frame_group_.end()) ? it->second : -1;
GELOGD("[%s] control flow frame group: %ld, parent frame: %ld",
node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_);
return SUCCESS;
}

GELOGD("[%s] control flow frame group: %ld, parent frame: %ld",
node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_);
return SUCCESS;
}

Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item) { Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item) {
GELOGD("Build control flow for node %s", node->GetName().c_str()); GELOGD("Build control flow for node %s", node->GetName().c_str());
using GroupBuilder = std::function<Status(HybridModelBuilder *, const NodePtr &, NodeItem *)>; using GroupBuilder = std::function<Status(HybridModelBuilder *, const NodePtr &, NodeItem *)>;
@@ -2466,6 +2521,7 @@ Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem


if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) {
// Enter --> StreamActive --> StreamMerge // Enter --> StreamActive --> StreamMerge
node_item->is_enter_active_ = true;
return CreateMergeEnterGroup(node, node_item); return CreateMergeEnterGroup(node, node_item);
} else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { } else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) {
// NextIteration --> StreamActive {-->} StreamMerge // NextIteration --> StreamActive {-->} StreamMerge


+ 2
- 0
ge/hybrid/model/hybrid_model_builder.h View File

@@ -97,6 +97,7 @@ class HybridModelBuilder {


Status RelinkNextIteration(); Status RelinkNextIteration();
Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes);
Status BuildFrameGroupIndex(NodeItem &node_item);
Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item);
Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item);
Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item); Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item);
@@ -123,6 +124,7 @@ class HybridModelBuilder {
std::map<std::string, NodePtr> constant_op_nodes_; std::map<std::string, NodePtr> constant_op_nodes_;
std::map<std::string, NodePtr> stream_merge_op_nodes_; std::map<std::string, NodePtr> stream_merge_op_nodes_;
std::map<std::string, NodePtr> next_iteration_op_nodes_; std::map<std::string, NodePtr> next_iteration_op_nodes_;
std::map<int64_t, int64_t> parent_frame_group_;
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_;
std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_; std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_;




+ 5
- 6
ge/hybrid/model/node_item.cc View File

@@ -20,7 +20,6 @@
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"
#include "graph/compute_graph.h" #include "graph/compute_graph.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/utils/node_utils.h"
#include "hybrid/executor/worker/shape_inference_engine.h" #include "hybrid/executor/worker/shape_inference_engine.h"
#include "hybrid/node_executor/node_executor.h" #include "hybrid/node_executor/node_executor.h"


@@ -34,7 +33,7 @@ const std::set<std::string> kControlOpTypes{
}; };


const std::set<std::string> kControlFlowOpTypes{ const std::set<std::string> kControlFlowOpTypes{
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT,
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT,
LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX
}; };


@@ -402,8 +401,8 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
node_item->root_data_.emplace(this); node_item->root_data_.emplace(this);
} }
// If Enter feed Not Merge, take as root Node. // If Enter feed Not Merge, take as root Node.
if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) {
node_item->root_data_.emplace(this);
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) {
node_item->enter_data_.emplace(this);
node_item->enter_inside_.emplace(anchor_index); node_item->enter_inside_.emplace(anchor_index);
} }
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
@@ -422,8 +421,8 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
node_item->root_ctrl_.emplace(this); node_item->root_ctrl_.emplace(this);
} }
// If Enter feed control signal, take as root Node. // If Enter feed control signal, take as root Node.
if (kEnterOpTypes.count(node_type) > 0) {
node_item->root_ctrl_.emplace(this);
if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
node_item->enter_ctrl_.emplace(this);
} }
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
} }


+ 14
- 0
ge/hybrid/model/node_item.h View File

@@ -22,6 +22,7 @@
#include "external/ge/ge_api_error_codes.h" #include "external/ge/ge_api_error_codes.h"
#include "graph/node.h" #include "graph/node.h"
#include "graph/op_desc.h" #include "graph/op_desc.h"
#include "graph/utils/node_utils.h"
#include "framework/common/types.h" #include "framework/common/types.h"
#include "hybrid/common/tensor_value.h" #include "hybrid/common/tensor_value.h"


@@ -92,6 +93,14 @@ struct NodeItem {
return is_merge_op_; return is_merge_op_;
} }


bool IsEnterOp() const {
return kEnterOpTypes.count(node_type) > 0;
}

bool IsExitOp() const {
return kExitOpTypes.count(node_type) > 0;
}

bool IsHcclOp() const; bool IsHcclOp() const;


void SetToDynamic(); void SetToDynamic();
@@ -135,8 +144,13 @@ struct NodeItem {
bool is_ctrl_flow_v2_op_ = false; bool is_ctrl_flow_v2_op_ = false;
bool is_ctrl_flow_op_ = false; bool is_ctrl_flow_op_ = false;
bool is_merge_op_ = false; bool is_merge_op_ = false;
bool is_enter_active_ = false;
int64_t frame_index_ = -1;
int64_t parent_frame_ = -1;
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node
std::set<const NodeItem *> root_data_; // Recv data from root node std::set<const NodeItem *> root_data_; // Recv data from root node
std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node
std::set<const NodeItem *> enter_data_; // Recv data from Enter node
std::set<const NodeItem *> data_send_; // Send data notify to std::set<const NodeItem *> data_send_; // Send data notify to
std::map<const NodeItem *, int> data_recv_; // Recv data notify from std::map<const NodeItem *, int> data_recv_; // Recv data notify from
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to


+ 2
- 4
ge/hybrid/node_executor/rts/rts_node_task.cc View File

@@ -90,7 +90,7 @@ Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t inde
Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName()); GELOGD("[%s] Start to execute.", task_context.GetNodeName());
const auto &node_state = task_context.GetNodeState(); const auto &node_state = task_context.GetNodeState();
node_state->SetSwitchIndex(0);
node_state->RunStreamActive();
if (done_callback) { if (done_callback) {
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
} }
@@ -204,9 +204,7 @@ Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::functio


const auto &node_state = task_context.GetNodeState(); const auto &node_state = task_context.GetNodeState();
if (kNextIterationOpTypes.count(node_state->GetType()) > 0) { if (kNextIterationOpTypes.count(node_state->GetType()) > 0) {
node_state->RunLoopNext();
} else if (kExitOpTypes.count(node_state->GetType()) > 0) {
node_state->RunLoopExit();
node_state->RunNextIteration();
} }


if (done_callback) { if (done_callback) {


+ 3
- 3
tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc View File

@@ -127,7 +127,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight);
} }


const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); // Mock for less, just pass input0.
const auto less1 = CreateNode(graph, "less", IDENTITY, 2, 1); // Mock for less, just pass input0.


const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0);
switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0);
@@ -135,8 +135,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true.
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL);


const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); // Mock for add, just pass input0.
const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); // Mock for sub, just pass input0.
const auto add1 = CreateNode(graph, "add", IDENTITY, 2, 1); // Mock for add, just pass input0.
const auto sub1 = CreateNode(graph, "sub", IDENTITY, 2, 1); // Mock for sub, just pass input0.


const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2);
const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0);


+ 31
- 4
tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc View File

@@ -89,7 +89,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) {
* \ / \. * \ / \.
* Switch Add * Switch Add
* / | | * / | |
* / | |
* Active / | |
* / | | * / | |
* LoopCond | | * LoopCond | |
* \ | | * \ | |
@@ -98,9 +98,10 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) {
* Less | | * Less | |
* \ | NextIteration * \ | NextIteration
* \ | | * \ | |
* \ | |
* \ | | Active
* Merge <---------| * Merge <---------|
* | * |
* | Active
* | * |
* Enter * Enter
******************************************************************************/ ******************************************************************************/
@@ -110,6 +111,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) {
GeModelPtr ge_sub_model = make_shared<GeModel>(); GeModelPtr ge_sub_model = make_shared<GeModel>();
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);


auto data1 = CreateNode(*graph, "data", DATA, 1, 1);
auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1);
auto merge1 = CreateNode(*graph, "merge", STREAMMERGE, 2, 2); auto merge1 = CreateNode(*graph, "merge", STREAMMERGE, 2, 2);
auto less1 = CreateNode(*graph, "less", LESS, 2, 1); auto less1 = CreateNode(*graph, "less", LESS, 2, 1);
@@ -129,6 +131,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) {
auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0);
auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);


GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0));
GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0));
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1));
@@ -153,8 +156,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) {
GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor());


GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor());
//GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor());
SetNextIteration(merge1, next1);
SetNextIteration(merge1, next1); // for relink NextIteration --> StreamMerge


GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); // Test for not merge. GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); // Test for not merge.


@@ -169,6 +171,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) {
AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true);
AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true);


SetControlFlowGroup(enter1, loop1->GetOpDesc()->GetId());
SetControlFlowGroup(active1, loop1->GetOpDesc()->GetId());
SetControlFlowGroup(merge1, loop1->GetOpDesc()->GetId());
SetControlFlowGroup(loop1, loop1->GetOpDesc()->GetId());
SetControlFlowGroup(active2, switch_t->GetOpDesc()->GetId());
SetControlFlowGroup(switch_t, switch_t->GetOpDesc()->GetId());
SetControlFlowGroup(switch_f, switch_t->GetOpDesc()->GetId());
SetControlFlowGroup(next1, loop1->GetOpDesc()->GetId());
SetControlFlowGroup(active3, loop1->GetOpDesc()->GetId());
SetControlFlowGroup(exit1, loop1->GetOpDesc()->GetId());

// Build -> IndexSpecialNodes --> stream_merge_op_nodes_ // Build -> IndexSpecialNodes --> stream_merge_op_nodes_
// Build -> LoadGraph -> RelinkNextIteration // Build -> LoadGraph -> RelinkNextIteration
// Build -> LoadGraph -> LoadDynamicSubgraph --> BuildNodeItem --> NodeItem::SetDataSend // Build -> LoadGraph -> LoadDynamicSubgraph --> BuildNodeItem --> NodeItem::SetDataSend
@@ -190,9 +203,23 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) {
task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr<NodeExecutor>(new NodeExecutor())); task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr<NodeExecutor>(new NodeExecutor()));
task_executor.emplace(NodeExecutorManager::ExecutorType::HOST_CPU, std::unique_ptr<NodeExecutor>(new NodeExecutor())); task_executor.emplace(NodeExecutorManager::ExecutorType::HOST_CPU, std::unique_ptr<NodeExecutor>(new NodeExecutor()));


const auto control_group_index = loop1->GetOpDesc()->GetId();
HybridModel hybrid_model(ge_root_model); HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model); HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS);

const auto TestFrameGroup = [&hybrid_model](const NodePtr &n, int64_t index) {
const auto it = hybrid_model.node_items_.find(n);
ASSERT_NE(hybrid_model.node_items_.end(), it);
ASSERT_EQ(it->second->frame_index_, index);
ASSERT_EQ(it->second->parent_frame_, -1);
};
TestFrameGroup(enter1, control_group_index);
TestFrameGroup(active1, control_group_index);
TestFrameGroup(active2, control_group_index);
TestFrameGroup(active3, control_group_index);
TestFrameGroup(output1, -1);

engine_mapping.clear(); engine_mapping.clear();
task_executor.clear(); task_executor.clear();
} }


+ 4
- 0
tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc View File

@@ -166,6 +166,10 @@ TEST_F(UtestRtsNodeTask, test_stream_active_task) {
std::function<void()> done = []() {}; std::function<void()> done = []() {};
ASSERT_EQ(node_state->GetSwitchIndex(), -1); ASSERT_EQ(node_state->GetSwitchIndex(), -1);
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
ASSERT_EQ(node_state->GetSwitchIndex(), -1);

node_item->ctrl_send_.emplace(nullptr);
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
ASSERT_EQ(node_state->GetSwitchIndex(), 0); ASSERT_EQ(node_state->GetSwitchIndex(), 0);
} }




Loading…
Cancel
Save