import copy
import logging
from collections import OrderedDict

import numpy as np
import torch
from torch import nn

from pytorch.darts.dartsmutator import DartsMutator
from pytorch.mutables import LayerChoice, InputChoice

logger = logging.getLogger(__name__)


class PdartsMutator(DartsMutator):
- """
- It works with PdartsTrainer to calculate ops weights,
- and drop weights in different PDARTS epochs.
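
    A minimal usage sketch (hypothetical epoch loop and drop schedule;
    PdartsTrainer normally drives this)::

        switches = {}
        for epoch_index in range(3):
            mutator = PdartsMutator(model, epoch_index, [3, 2, 2], switches)
            # ... search with this mutator for one P-DARTS epoch ...
            switches = mutator.drop_paths()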
- """
    def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
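        """
        Parameters
        ----------
        model : nn.Module
            The model whose mutables define the search space.
        pdarts_epoch_index : int
            Index of the current P-DARTS epoch.
        pdarts_num_to_drop : list of int
            How many candidate operations to drop in each P-DARTS epoch.
        switches : dict or None
            Maps each mutable key to a list of booleans marking which
            candidate operations are still active; None means all are active.
        """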
        self.pdarts_epoch_index = pdarts_epoch_index
        self.pdarts_num_to_drop = pdarts_num_to_drop
        # save the last two switches and choices for restricting skip connections
        self.last_two_switches = None
        self.last_two_choices = None

        if switches is None:
            self.switches = {}
        else:
            self.switches = switches

        super(PdartsMutator, self).__init__(model)
        # this loop goes through mutables with different keys;
        # it mainly updates the length of choices.
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                switches = self.switches.get(mutable.key, [True] * len(mutable))

                operations_count = np.sum(switches)
                # the +1 here (and the matching [:-1] in drop_paths) accounts for
                # the zero operation in the DARTS network: it is not a candidate in
                # the choices list (switches), but it still owns an architecture
                # weight, so one extra weight is allocated for it.
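                # e.g. with 8 candidates of which 2 were already dropped,
                # switches has 8 entries (6 True) and this parameter holds
                # 6 + 1 = 7 weights, the extra one being the zero operation's.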
                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))
                self.switches[mutable.key] = switches

        # update the LayerChoice instances in the model;
        # this physically removes the dropped candidate operations.
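        # e.g. if switches[key] == [True, False, True, False, True], the
        # modules at indexes 3 and 1 are deleted (last to first), leaving
        # 3 operations to match the 3 surviving switch entries.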
        for module in self.model.modules():
            if isinstance(module, LayerChoice):
                switches = self.switches.get(module.key)
                choices = self.choices[module.key]
                if len(module) > len(choices):
                    # iterate from last to first so that deletions do not
                    # shift the indexes still to be visited.
                    for index in range(len(switches) - 1, -1, -1):
                        if not switches[index]:
                            del module[index]
                    assert len(module) <= len(choices), "Failed to remove dropped choices."

    def export(self, last, switches):
        # In the last P-DARTS epoch, the number of skip connections must be
        # restricted. super().export() cannot be used directly, because
        # P-DARTS has deleted some of the choices, so the lengths are misaligned.
        if last:
            # restrict to at most 2 skip connections (normal cells only)
            name = "normal"
            max_num = 2
            skip_num = self.check_skip_num(name, switches)
            logger.info("Initially, the number of skipconnect is %s.", skip_num)
            while skip_num > max_num:
                logger.info("Restricting %s skipconnect to %s.", skip_num, max_num)
                logger.info("Original normal_switch is %s.", switches)
                # update self.choices (set the skip weight to 0) and
                # self.switches (set the skip switch to False)
                switches = self.delete_min_sk(name, switches)
                logger.info("Restricted normal_switch is %s.", switches)
                skip_num = self.check_skip_num(name, switches)

        # convert the boolean results into the human-readable form of Mutator.export()
        results = super().sample_final()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # some operations were physically dropped, so False entries must
                # be filled back in to keep track of the dropped operations.
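                # e.g. switches [True, False, True] with trained result
                # [True, False] expands to [True, False, False]: dropped
                # slots stay False, surviving slots take the trained
                # decisions in order.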
                trained_result = results[mutable.key]
                trained_index = 0
                switches = self.switches[mutable.key]
                result = torch.Tensor(switches).bool()
                for index in range(len(result)):
                    if result[index]:
                        result[index] = trained_result[trained_index]
                        trained_index += 1
                results[mutable.key] = result
        return results

    def drop_paths(self):
        """
        This method is called when a P-DARTS epoch finishes.
        It prepares the switches for the next epoch: candidate operations
        whose switch is False will be dropped in the next epoch.
        """
        all_switches = copy.deepcopy(self.switches)
        for key in all_switches:
            switches = all_switches[key]
            idxs = []
            for j in range(len(switches)):
                if switches[j]:
                    idxs.append(j)
            # exclude the zero operation's weight (the last entry), then pick
            # the indexes of the smallest remaining weights to drop
            op_weights = self.choices[key].data.cpu().numpy()[:-1]
            drop = np.argsort(op_weights)[:self.pdarts_num_to_drop[self.pdarts_epoch_index]]
            for idx in drop:
                switches[idxs[idx]] = False
        return all_switches

    def check_skip_num(self, name, switches):
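        # Counts how many "normal" edges still keep their skip connection.
        # Index 2 assumes the usual DARTS candidate order with "none"
        # excluded: max_pool_3x3, avg_pool_3x3, skip_connect, ...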
        counter = 0
        for key in switches:
            if name in key:
                # the zero operation is not in switches, so skipconnect sits at index 2
                if switches[key][2]:
                    counter += 1
        return counter

    def delete_min_sk(self, name, switches):
        def _get_sk_idx(key, switches):
            # return skipconnect's index within self.choices[key],
            # or -1 if this edge has already dropped its skip connection
            if not switches[key][2]:
                idx = -1
            else:
                idx = 0
                # self.choices[key] only holds weights for operations still
                # active in self.switches, so skipconnect's position equals
                # the number of surviving ops before it (indexes 0 and 1)
                for i in range(2):
                    if self.switches[key][i]:
                        idx += 1
            return idx
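        # e.g. for _get_sk_idx: if avg_pool_3x3 (index 1) was dropped in an
        # earlier epoch, self.switches[key][:2] == [True, False], so the
        # skipconnect weight sits at index 1 of self.choices[key].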
        sk_choices = [1.0 for i in range(14)]  # a DARTS cell has 14 edges (2+3+4+5)
        sk_keys = [None for i in range(14)]    # keys of edges that keep a skip connection
        sk_choices_idx = -1
        for key in switches:
            if name in key:
                # keys arrive in their default (edge) order
                sk_choices_idx += 1
                idx = _get_sk_idx(key, switches)
                if idx != -1:
                    sk_keys[sk_choices_idx] = key
                    sk_choices[sk_choices_idx] = self.choices[key][idx]
        min_sk_idx = np.argmin(sk_choices)
        idx = _get_sk_idx(sk_keys[min_sk_idx], switches)
        # zero out the weakest skip connection's architecture weight
        self.choices[sk_keys[min_sk_idx]][idx] = 0.0
        # self.switches holds the persistent switches, while the passed-in
        # switches are the current copy; the skip must be turned off in both
        self.switches[sk_keys[min_sk_idx]][2] = False
        switches[sk_keys[min_sk_idx]][2] = False
        return switches

    def _generate_search_space(self):
        """
        Generate a search space from the mutables.

        The search space format is::

            { key_name: {"_type": "layer_choice",
                         "_value": ["conv1", "conv2"]} }
            { key_name: {"_type": "input_choice",
                         "_value": {"candidates": ["in1", "in2"],
                                    "n_chosen": 1}} }

        Returns
        -------
        dict
            The generated search space.
        """
        res = OrderedDict()
        res["op_list"] = OrderedDict()
        res["search_space"] = {"reduction_cell": OrderedDict(), "normal_cell": OrderedDict()}
        keys = []
        for mutable in self.mutables:
            # for now we only generate a flattened search space; 36 entries cover
            # both cell types: 2 x (14 LayerChoice edges + 4 InputChoice nodes)
            if (len(res["search_space"]["reduction_cell"]) + len(res["search_space"]["normal_cell"])) >= 36:
                break

            if isinstance(mutable, LayerChoice):
                key = mutable.key
                if key not in keys:
                    val = mutable.names
                    if not res["op_list"]:
                        res["op_list"] = {"_type": "layer_choice", "_value": val + ["none"]}
                    node_type = "normal_cell" if "normal" in key else "reduction_cell"
                    res["search_space"][node_type][key] = "op_list"
                    keys.append(key)
            elif isinstance(mutable, InputChoice):
                key = mutable.key
                if key not in keys:
                    node_type = "normal_cell" if "normal" in key else "reduction_cell"
                    res["search_space"][node_type][key] = {"_type": "input_choice",
                                                           "_value": {"candidates": mutable.choose_from,
                                                                      "n_chosen": mutable.n_chosen}}
                    keys.append(key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))

        return res