Browse Source

DSP: Switch -> TransData -> Cast -> Exit

tags/v1.5.1
zhangxiaokun 3 years ago
parent
commit
b9715a1458
2 changed files with 22 additions and 13 deletions
  1. +21
    -12
      ge/graph/passes/next_iteration_pass.cc
  2. +1
    -1
      ge/hybrid/executor/worker/execution_engine.cc

+ 21
- 12
ge/graph/passes/next_iteration_pass.cc View File

@@ -24,7 +24,9 @@ using std::string;

namespace ge {
namespace {
const int64_t kLoopType = 1;
constexpr int64_t kLoopType = 1;
constexpr uint8_t kMaxTransOp = 3;
constexpr uint8_t kTransOpIoSize = 1;
}

Status NextIterationPass::Run(ComputeGraphPtr graph) {
@@ -287,18 +289,25 @@ void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, i
std::string node_type;
for (const auto &switch_node : loop_group.switch_nodes) {
SetControlFlowGroup(switch_node, group_index);
for (const auto &node : switch_node->GetOutDataNodes()) {
(void)GetOriginalType(node, node_type);
if (kExitOpTypes.count(node_type) > 0) {
SetControlFlowGroup(node, group_index);
} else {
// For: Switch -> Cast -> Exit
for (const auto &n : node->GetOutDataNodes()) {
(void)GetOriginalType(n, node_type);
if (kExitOpTypes.count(node_type) > 0) {
SetControlFlowGroup(n, group_index);
}
for (auto node : switch_node->GetOutDataNodes()) {
// Switch --> Exit
// Switch --> Cast --> Exit
// Switch --> TransData --> Cast --> Exit
for (uint8_t i = 0; i < kMaxTransOp; ++i) {
if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) {
break;
}

if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
SetControlFlowGroup(node, group_index);
break;
}

const auto &all_nodes = node->GetOutAllNodes();
if (all_nodes.size() != kTransOpIoSize) {
break;
}
node = all_nodes.at(0);
}
}
}


+ 1
- 1
ge/hybrid/executor/worker/execution_engine.cc View File

@@ -373,9 +373,9 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state,
auto executor = node_item.node_executor;
GE_CHECK_NOTNULL(executor);
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start");
node_state.UpdatePersistTensor();
GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.",
node_state.GetName().c_str());
node_state.UpdatePersistTensor();
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End");
GELOGD("[%s] Done task preparation successfully.", node_state.GetName().c_str());



Loading…
Cancel
Save