|
|
@@ -24,7 +24,9 @@ using std::string; |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
namespace { |
|
|
|
const int64_t kLoopType = 1; |
|
|
|
constexpr int64_t kLoopType = 1; |
|
|
|
constexpr uint8_t kMaxTransOp = 3; |
|
|
|
constexpr uint8_t kTransOpIoSize = 1; |
|
|
|
} |
|
|
|
|
|
|
|
Status NextIterationPass::Run(ComputeGraphPtr graph) { |
|
|
@@ -287,18 +289,25 @@ void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, i |
|
|
|
std::string node_type; |
|
|
|
for (const auto &switch_node : loop_group.switch_nodes) { |
|
|
|
SetControlFlowGroup(switch_node, group_index); |
|
|
|
for (const auto &node : switch_node->GetOutDataNodes()) { |
|
|
|
(void)GetOriginalType(node, node_type); |
|
|
|
if (kExitOpTypes.count(node_type) > 0) { |
|
|
|
SetControlFlowGroup(node, group_index); |
|
|
|
} else { |
|
|
|
// For: Switch -> Cast -> Exit |
|
|
|
for (const auto &n : node->GetOutDataNodes()) { |
|
|
|
(void)GetOriginalType(n, node_type); |
|
|
|
if (kExitOpTypes.count(node_type) > 0) { |
|
|
|
SetControlFlowGroup(n, group_index); |
|
|
|
} |
|
|
|
for (auto node : switch_node->GetOutDataNodes()) { |
|
|
|
// Switch --> Exit |
|
|
|
// Switch --> Cast --> Exit |
|
|
|
// Switch --> TransData --> Cast --> Exit |
|
|
|
for (uint8_t i = 0; i < kMaxTransOp; ++i) { |
|
|
|
if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) { |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { |
|
|
|
SetControlFlowGroup(node, group_index); |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
const auto &all_nodes = node->GetOutAllNodes(); |
|
|
|
if (all_nodes.size() != kTransOpIoSize) { |
|
|
|
break; |
|
|
|
} |
|
|
|
node = all_nodes.at(0); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|