| @@ -145,17 +145,63 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||
| /// @return | |||
| /// | |||
| void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | |||
| for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) { | |||
| const auto &op_node = it->first; | |||
| const auto &op_desc = op_node->GetOpDesc(); | |||
| if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
| continue; | |||
| // Step 0: no group assigned. such as: | |||
| // Merge1{id=0, group=} => {Switch1{id=1, group=}, Switch2{id=2, group=}} | |||
| // Merge2{id=3, group=} => {Switch1{id=1, group=}, Switch3{id=4, group=}} | |||
| // Merge3{id=5, group=} => {Switch4{id=6, group=}, Switch5{id=7, group=}} | |||
| // Merge4{id=8, group=} => {Switch1{id=1, group=}, Switch5{id=7, group=}} | |||
| std::map<int64_t, int64_t> unique_groups; | |||
| const auto GetGroupIndex = [&unique_groups](const NodePtr &merge, const std::vector<NodePtr> &switch_group) { | |||
| int64_t group_index = merge->GetOpDesc()->GetId(); | |||
| std::set<int64_t> group_ids{group_index}; | |||
| for (const auto &node : switch_group) { | |||
| if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||
| GELOGI("[%s] Get group from [%s], index[%ld]", merge->GetName().c_str(), node->GetName().c_str(), group_index); | |||
| group_ids.insert(group_index); | |||
| } | |||
| } | |||
| const auto it = unique_groups.find(group_index); | |||
| if (it != unique_groups.end()) { | |||
| group_index = it->second; | |||
| } | |||
| int64_t group_index = op_desc->GetId(); | |||
| SetControlFlowGroup(op_node, group_index); | |||
| for (const auto &n : it->second) { | |||
| SetControlFlowGroup(n, group_index); | |||
| for (auto id : group_ids) { | |||
| unique_groups[id] = group_index; | |||
| } | |||
| return group_index; | |||
| }; | |||
| const auto SetGroupIndex = [](const NodePtr &merge, const std::vector<NodePtr> &switch_group, int64_t group_index) { | |||
| SetControlFlowGroup(merge, group_index); | |||
| for (const auto &node : switch_group) { | |||
| SetControlFlowGroup(node, group_index); | |||
| } | |||
| }; | |||
| // Step 1: Set group index to merge, if switch already has group, use assigned group. | |||
| // Merge1{id=0, group=0} => {Switch1{id=1, group=0}, Switch2{id=2, group=0}} | |||
| // Merge2{id=3, group=0} => {Switch1{id=1, group=0}, Switch3{id=4, group=0}} | |||
| // Merge3{id=5, group=5} => {Switch4{id=6, group=5}, Switch5{id=7, group=5}} | |||
| // Merge4{id=8, group=0} => {Switch1{id=1, group=0}, Switch5{id=7, group=0}} | |||
| for (const auto group : switch_groups) { | |||
| int64_t group_index = GetGroupIndex(group.first, group.second); | |||
| SetGroupIndex(group.first, group.second, group_index); | |||
| } | |||
| // Step 2: Adjust crossed merge group for unique group. | |||
| // Merge1{id=0, group=0} => {Switch1{id=1, group=0}, Switch2{id=2, group=0}} | |||
| // Merge2{id=3, group=0} => {Switch1{id=1, group=0}, Switch3{id=4, group=0}} | |||
| // Merge3{id=5, group=0} => {Switch4{id=6, group=0}, Switch5{id=7, group=0}} | |||
| // Merge4{id=8, group=0} => {Switch1{id=1, group=0}, Switch5{id=7, group=0}} | |||
| for (const auto group : switch_groups) { | |||
| int64_t group_index = -1; | |||
| (void)AttrUtils::GetInt(group.first->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| const auto it = unique_groups.find(group_index); | |||
| if (it != unique_groups.end() && it->first != it->second) { | |||
| SetGroupIndex(group.first, group.second, it->second); | |||
| } | |||
| } | |||
| } | |||