# Code Modified from https://github.com/carpedm20/ENAS-pytorch
"""A module with NAS controller-related code."""
import collections
import os

import torch
import torch.nn.functional as F

import fastNLP.automl.enas_utils as utils
from fastNLP.automl.enas_utils import Node


def _construct_dags(prev_nodes, activations, func_names, num_blocks):
    """Constructs a set of DAGs based on the actions, i.e., previous nodes and
    activation functions, sampled from the controller/policy pi.

    Args:
        prev_nodes: Previous node actions from the policy.
        activations: Activations sampled from the policy.
        func_names: List of activation function names, indexed by the sampled
            function id.
        num_blocks: Number of blocks in the target RNN cell.

    Returns:
        A list of DAGs defined by the inputs.

    RNN cell DAGs are represented in the following way:

    1. Each element (node) in a DAG is a list of `Node`s.

    2. The `Node`s in the list dag[i] correspond to the subsequent nodes
       that take the output from node i as their own input.

    3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}.
       dag[-1] always feeds dag[0].
       dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its
       weights.

    4. dag[N - 1] is the node that produces the hidden state passed to
       the next timestep. dag[N - 1] is also always a leaf node, and
       therefore is always averaged with the other leaf nodes and fed to
       the output decoder.
    """
    dags = []
    for nodes, func_ids in zip(prev_nodes, activations):
        dag = collections.defaultdict(list)

        # add first node
        dag[-1] = [Node(0, func_names[func_ids[0]])]
        dag[-2] = [Node(0, func_names[func_ids[0]])]

        # add following nodes
        for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])):
            dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id]))

        leaf_nodes = set(range(num_blocks)) - dag.keys()

        # merge with avg
        for idx in leaf_nodes:
            dag[idx] = [Node(num_blocks, 'avg')]

        # This is actually y^{(t)}. h^{(t)} is node N - 1 in
        # the graph, where N is the number of nodes. I.e., h^{(t)} takes
        # only one other node as its input.
        # last h[t] node
        last_node = Node(num_blocks + 1, 'h[t]')
        dag[num_blocks] = [last_node]
        dags.append(dag)

    return dags

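
# Illustrative sketch (not from the original source) of what `_construct_dags`
# produces, assuming `Node` is the plain (id, name) pair imported from
# `fastNLP.automl.enas_utils`. With num_blocks=2,
# func_names=['tanh', 'ReLU', 'identity', 'sigmoid'], prev_nodes=[[0]] and
# activations=[[1, 0]], the single returned dag maps:
#
#     dag[-1] == dag[-2] == [Node(0, 'ReLU')]   # x^{(t)} and h^{(t-1)} feed node 0
#     dag[0]  == [Node(1, 'tanh')]              # node 0 feeds node 1
#     dag[1]  == [Node(2, 'avg')]               # leaf node 1 is averaged
#     dag[2]  == [Node(3, 'h[t]')]              # the averaged output feeds the final 'h[t]' node

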
class Controller(torch.nn.Module):
    """Based on
    https://github.com/pytorch/examples/blob/master/word_language_model/model.py

    RL controllers do not necessarily have much to do with
    language models.

    Base the controller RNN on the GRU from:
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
    """
    def __init__(self, num_blocks=4, controller_hid=100, cuda=False):
        torch.nn.Module.__init__(self)

        # `num_tokens` alternates between the number of activation choices and
        # the number of possible previous nodes; e.g. with num_blocks=4 it is
        # [4, 1, 4, 2, 4, 3, 4, 4, 4].
        self.shared_rnn_activations = ['tanh', 'ReLU', 'identity', 'sigmoid']
        self.num_tokens = [len(self.shared_rnn_activations)]
        self.controller_hid = controller_hid
        self.use_cuda = cuda
        self.num_blocks = num_blocks
        for idx in range(num_blocks):
            self.num_tokens += [idx + 1, len(self.shared_rnn_activations)]
        self.func_names = self.shared_rnn_activations

        num_total_tokens = sum(self.num_tokens)

        self.encoder = torch.nn.Embedding(num_total_tokens,
                                          controller_hid)
        self.lstm = torch.nn.LSTMCell(controller_hid, controller_hid)

        # Perhaps these weights in the decoder should be
        # shared? At least for the activation functions, which all have the
        # same size.
        self.decoders = []
        for idx, size in enumerate(self.num_tokens):
            decoder = torch.nn.Linear(controller_hid, size)
            self.decoders.append(decoder)

        self._decoders = torch.nn.ModuleList(self.decoders)

        self.reset_parameters()
        self.static_init_hidden = utils.keydefaultdict(self.init_hidden)

        def _get_default_hidden(key):
            return utils.get_variable(
                torch.zeros(key, self.controller_hid),
                self.use_cuda,
                requires_grad=False)

        self.static_inputs = utils.keydefaultdict(_get_default_hidden)

    def reset_parameters(self):
        init_range = 0.1
        for param in self.parameters():
            param.data.uniform_(-init_range, init_range)
        for decoder in self.decoders:
            decoder.bias.data.fill_(0)

    def forward(self,  # pylint:disable=arguments-differ
                inputs,
                hidden,
                block_idx,
                is_embed):
        if not is_embed:
            embed = self.encoder(inputs)
        else:
            embed = inputs

        hx, cx = self.lstm(embed, hidden)
        logits = self.decoders[block_idx](hx)

        # Soften the output distribution (softmax temperature of 5.0).
        logits /= 5.0

        # # exploration
        # if self.args.mode == 'train':
        #     logits = (2.5 * F.tanh(logits))

        return logits, (hx, cx)

    def sample(self, batch_size=1, with_details=False, save_dir=None):
        """Samples a set of `self.num_blocks` many computational nodes from the
        controller, where each node is made up of an activation function, and
        each node except the last also includes a previous node.
        """
        if batch_size < 1:
            raise ValueError(f'Wrong batch_size: {batch_size} < 1')

        # [B, L, H]
        inputs = self.static_inputs[batch_size]
        hidden = self.static_init_hidden[batch_size]

        activations = []
        entropies = []
        log_probs = []
        prev_nodes = []
        # The RNN controller alternately outputs an activation,
        # followed by a previous node, for each block except the last one,
        # which only gets an activation function. The last node is the output
        # node, and its previous node is the average of all leaf nodes.
        # E.g. with num_blocks=4 this loop runs 7 steps: 4 activation choices
        # and 3 previous-node choices.
        for block_idx in range(2*(self.num_blocks - 1) + 1):
            logits, hidden = self.forward(inputs,
                                          hidden,
                                          block_idx,
                                          is_embed=(block_idx == 0))

            probs = F.softmax(logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
            # .mean() for entropy?
            entropy = -(log_prob * probs).sum(1, keepdim=False)

            action = probs.multinomial(num_samples=1).data
            selected_log_prob = log_prob.gather(
                1, utils.get_variable(action, requires_grad=False))

            # why the [:, 0] here? Should it be .squeeze(), or
            # .view()? Same below with `action`.
            entropies.append(entropy)
            log_probs.append(selected_log_prob[:, 0])

            # 0: function, 1: previous node
            mode = block_idx % 2
            inputs = utils.get_variable(
                action[:, 0] + sum(self.num_tokens[:mode]),
                requires_grad=False)

            if mode == 0:
                activations.append(action[:, 0])
            elif mode == 1:
                prev_nodes.append(action[:, 0])

        prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
        activations = torch.stack(activations).transpose(0, 1)

        dags = _construct_dags(prev_nodes,
                               activations,
                               self.func_names,
                               self.num_blocks)

        if save_dir is not None:
            for idx, dag in enumerate(dags):
                utils.draw_network(dag,
                                   os.path.join(save_dir, f'graph{idx}.png'))

        if with_details:
            return dags, torch.cat(log_probs), torch.cat(entropies)

        return dags

    def init_hidden(self, batch_size):
        zeros = torch.zeros(batch_size, self.controller_hid)
        return (utils.get_variable(zeros, self.use_cuda, requires_grad=False),
                utils.get_variable(zeros.clone(), self.use_cuda,
                                   requires_grad=False))
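

# A minimal usage sketch (not part of the original module). It assumes the
# package layout above (`fastNLP.automl.enas_utils`) is importable and uses
# only illustrative hyper-parameters.
if __name__ == '__main__':
    controller = Controller(num_blocks=4, controller_hid=100, cuda=False)
    # Sample two candidate RNN-cell architectures; `with_details` also returns
    # the log-probabilities and entropies a REINFORCE-style trainer would use
    # to update the controller.
    dags, log_probs, entropies = controller.sample(batch_size=2,
                                                   with_details=True)
    print(len(dags), log_probs.size(), entropies.size())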