[new] Add ENAS (Efficient Neural Architecture Search)tags/v0.4.0
| @@ -0,0 +1,223 @@ | |||
| # Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||
| """A module with NAS controller-related code.""" | |||
| import collections | |||
| import os | |||
| import torch | |||
| import torch.nn.functional as F | |||
| import fastNLP | |||
| import fastNLP.models.enas_utils as utils | |||
| from fastNLP.models.enas_utils import Node | |||
| def _construct_dags(prev_nodes, activations, func_names, num_blocks): | |||
| """Constructs a set of DAGs based on the actions, i.e., previous nodes and | |||
| activation functions, sampled from the controller/policy pi. | |||
| Args: | |||
| prev_nodes: Previous node actions from the policy. | |||
| activations: Activations sampled from the policy. | |||
| func_names: Mapping from activation function names to functions. | |||
| num_blocks: Number of blocks in the target RNN cell. | |||
| Returns: | |||
| A list of DAGs defined by the inputs. | |||
| RNN cell DAGs are represented in the following way: | |||
| 1. Each element (node) in a DAG is a list of `Node`s. | |||
| 2. The `Node`s in the list dag[i] correspond to the subsequent nodes | |||
| that take the output from node i as their own input. | |||
| 3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}. | |||
| dag[-1] always feeds dag[0]. | |||
| dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its | |||
| weights. | |||
| 4. dag[N - 1] is the node that produces the hidden state passed to | |||
| the next timestep. dag[N - 1] is also always a leaf node, and therefore | |||
| is always averaged with the other leaf nodes and fed to the output | |||
| decoder. | |||
| """ | |||
| dags = [] | |||
| for nodes, func_ids in zip(prev_nodes, activations): | |||
| dag = collections.defaultdict(list) | |||
| # add first node | |||
| dag[-1] = [Node(0, func_names[func_ids[0]])] | |||
| dag[-2] = [Node(0, func_names[func_ids[0]])] | |||
| # add following nodes | |||
| for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])): | |||
| dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id])) | |||
| leaf_nodes = set(range(num_blocks)) - dag.keys() | |||
| # merge with avg | |||
| for idx in leaf_nodes: | |||
| dag[idx] = [Node(num_blocks, 'avg')] | |||
| # This is actually y^{(t)}. h^{(t)} is node N - 1 in | |||
| # the graph, where N Is the number of nodes. I.e., h^{(t)} takes | |||
| # only one other node as its input. | |||
| # last h[t] node | |||
| last_node = Node(num_blocks + 1, 'h[t]') | |||
| dag[num_blocks] = [last_node] | |||
| dags.append(dag) | |||
| return dags | |||
| class Controller(torch.nn.Module): | |||
| """Based on | |||
| https://github.com/pytorch/examples/blob/master/word_language_model/model.py | |||
| RL controllers do not necessarily have much to do with | |||
| language models. | |||
| Base the controller RNN on the GRU from: | |||
| https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py | |||
| """ | |||
| def __init__(self, num_blocks=4, controller_hid=100, cuda=False): | |||
| torch.nn.Module.__init__(self) | |||
| # `num_tokens` here is just the activation function | |||
| # for every even step, | |||
| self.shared_rnn_activations = ['tanh', 'ReLU', 'identity', 'sigmoid'] | |||
| self.num_tokens = [len(self.shared_rnn_activations)] | |||
| self.controller_hid = controller_hid | |||
| self.use_cuda = cuda | |||
| self.num_blocks = num_blocks | |||
| for idx in range(num_blocks): | |||
| self.num_tokens += [idx + 1, len(self.shared_rnn_activations)] | |||
| self.func_names = self.shared_rnn_activations | |||
| num_total_tokens = sum(self.num_tokens) | |||
| self.encoder = torch.nn.Embedding(num_total_tokens, | |||
| controller_hid) | |||
| self.lstm = torch.nn.LSTMCell(controller_hid, controller_hid) | |||
| # Perhaps these weights in the decoder should be | |||
| # shared? At least for the activation functions, which all have the | |||
| # same size. | |||
| self.decoders = [] | |||
| for idx, size in enumerate(self.num_tokens): | |||
| decoder = torch.nn.Linear(controller_hid, size) | |||
| self.decoders.append(decoder) | |||
| self._decoders = torch.nn.ModuleList(self.decoders) | |||
| self.reset_parameters() | |||
| self.static_init_hidden = utils.keydefaultdict(self.init_hidden) | |||
| def _get_default_hidden(key): | |||
| return utils.get_variable( | |||
| torch.zeros(key, self.controller_hid), | |||
| self.use_cuda, | |||
| requires_grad=False) | |||
| self.static_inputs = utils.keydefaultdict(_get_default_hidden) | |||
| def reset_parameters(self): | |||
| init_range = 0.1 | |||
| for param in self.parameters(): | |||
| param.data.uniform_(-init_range, init_range) | |||
| for decoder in self.decoders: | |||
| decoder.bias.data.fill_(0) | |||
| def forward(self, # pylint:disable=arguments-differ | |||
| inputs, | |||
| hidden, | |||
| block_idx, | |||
| is_embed): | |||
| if not is_embed: | |||
| embed = self.encoder(inputs) | |||
| else: | |||
| embed = inputs | |||
| hx, cx = self.lstm(embed, hidden) | |||
| logits = self.decoders[block_idx](hx) | |||
| logits /= 5.0 | |||
| # # exploration | |||
| # if self.args.mode == 'train': | |||
| # logits = (2.5 * F.tanh(logits)) | |||
| return logits, (hx, cx) | |||
| def sample(self, batch_size=1, with_details=False, save_dir=None): | |||
| """Samples a set of `args.num_blocks` many computational nodes from the | |||
| controller, where each node is made up of an activation function, and | |||
| each node except the last also includes a previous node. | |||
| """ | |||
| if batch_size < 1: | |||
| raise Exception(f'Wrong batch_size: {batch_size} < 1') | |||
| # [B, L, H] | |||
| inputs = self.static_inputs[batch_size] | |||
| hidden = self.static_init_hidden[batch_size] | |||
| activations = [] | |||
| entropies = [] | |||
| log_probs = [] | |||
| prev_nodes = [] | |||
| # The RNN controller alternately outputs an activation, | |||
| # followed by a previous node, for each block except the last one, | |||
| # which only gets an activation function. The last node is the output | |||
| # node, and its previous node is the average of all leaf nodes. | |||
| for block_idx in range(2*(self.num_blocks - 1) + 1): | |||
| logits, hidden = self.forward(inputs, | |||
| hidden, | |||
| block_idx, | |||
| is_embed=(block_idx == 0)) | |||
| probs = F.softmax(logits, dim=-1) | |||
| log_prob = F.log_softmax(logits, dim=-1) | |||
| # .mean() for entropy? | |||
| entropy = -(log_prob * probs).sum(1, keepdim=False) | |||
| action = probs.multinomial(num_samples=1).data | |||
| selected_log_prob = log_prob.gather( | |||
| 1, utils.get_variable(action, requires_grad=False)) | |||
| # why the [:, 0] here? Should it be .squeeze(), or | |||
| # .view()? Same below with `action`. | |||
| entropies.append(entropy) | |||
| log_probs.append(selected_log_prob[:, 0]) | |||
| # 0: function, 1: previous node | |||
| mode = block_idx % 2 | |||
| inputs = utils.get_variable( | |||
| action[:, 0] + sum(self.num_tokens[:mode]), | |||
| requires_grad=False) | |||
| if mode == 0: | |||
| activations.append(action[:, 0]) | |||
| elif mode == 1: | |||
| prev_nodes.append(action[:, 0]) | |||
| prev_nodes = torch.stack(prev_nodes).transpose(0, 1) | |||
| activations = torch.stack(activations).transpose(0, 1) | |||
| dags = _construct_dags(prev_nodes, | |||
| activations, | |||
| self.func_names, | |||
| self.num_blocks) | |||
| if save_dir is not None: | |||
| for idx, dag in enumerate(dags): | |||
| utils.draw_network(dag, | |||
| os.path.join(save_dir, f'graph{idx}.png')) | |||
| if with_details: | |||
| return dags, torch.cat(log_probs), torch.cat(entropies) | |||
| return dags | |||
| def init_hidden(self, batch_size): | |||
| zeros = torch.zeros(batch_size, self.controller_hid) | |||
| return (utils.get_variable(zeros, self.use_cuda, requires_grad=False), | |||
| utils.get_variable(zeros.clone(), self.use_cuda, requires_grad=False)) | |||
| @@ -0,0 +1,388 @@ | |||
| # Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||
| """Module containing the shared RNN model.""" | |||
| import numpy as np | |||
| import collections | |||
| import torch | |||
| from torch import nn | |||
| import torch.nn.functional as F | |||
| from torch.autograd import Variable | |||
| import fastNLP.models.enas_utils as utils | |||
| from fastNLP.models.base_model import BaseModel | |||
| import fastNLP.modules.encoder as encoder | |||
| def _get_dropped_weights(w_raw, dropout_p, is_training): | |||
| """Drops out weights to implement DropConnect. | |||
| Args: | |||
| w_raw: Full, pre-dropout, weights to be dropped out. | |||
| dropout_p: Proportion of weights to drop out. | |||
| is_training: True iff _shared_ model is training. | |||
| Returns: | |||
| The dropped weights. | |||
| Why does torch.nn.functional.dropout() return: | |||
| 1. `torch.autograd.Variable()` on the training loop | |||
| 2. `torch.nn.Parameter()` on the controller or eval loop, when | |||
| training = False... | |||
| Even though the call to `_setweights` in the Smerity repo's | |||
| `weight_drop.py` does not have this behaviour, and `F.dropout` always | |||
| returns `torch.autograd.Variable` there, even when `training=False`? | |||
| The above TODO is the reason for the hacky check for `torch.nn.Parameter`. | |||
| """ | |||
| dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training) | |||
| if isinstance(dropped_w, torch.nn.Parameter): | |||
| dropped_w = dropped_w.clone() | |||
| return dropped_w | |||
| class EmbeddingDropout(torch.nn.Embedding): | |||
| """Class for dropping out embeddings by zero'ing out parameters in the | |||
| embedding matrix. | |||
| This is equivalent to dropping out particular words, e.g., in the sentence | |||
| 'the quick brown fox jumps over the lazy dog', dropping out 'the' would | |||
| lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the | |||
| embedding vector space). | |||
| See 'A Theoretically Grounded Application of Dropout in Recurrent Neural | |||
| Networks', (Gal and Ghahramani, 2016). | |||
| """ | |||
| def __init__(self, | |||
| num_embeddings, | |||
| embedding_dim, | |||
| max_norm=None, | |||
| norm_type=2, | |||
| scale_grad_by_freq=False, | |||
| sparse=False, | |||
| dropout=0.1, | |||
| scale=None): | |||
| """Embedding constructor. | |||
| Args: | |||
| dropout: Dropout probability. | |||
| scale: Used to scale parameters of embedding weight matrix that are | |||
| not dropped out. Note that this is _in addition_ to the | |||
| `1/(1 - dropout)` scaling. | |||
| See `torch.nn.Embedding` for remaining arguments. | |||
| """ | |||
| torch.nn.Embedding.__init__(self, | |||
| num_embeddings=num_embeddings, | |||
| embedding_dim=embedding_dim, | |||
| max_norm=max_norm, | |||
| norm_type=norm_type, | |||
| scale_grad_by_freq=scale_grad_by_freq, | |||
| sparse=sparse) | |||
| self.dropout = dropout | |||
| assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 ' | |||
| 'and < 1.0') | |||
| self.scale = scale | |||
| def forward(self, inputs): # pylint:disable=arguments-differ | |||
| """Embeds `inputs` with the dropped out embedding weight matrix.""" | |||
| if self.training: | |||
| dropout = self.dropout | |||
| else: | |||
| dropout = 0 | |||
| if dropout: | |||
| mask = self.weight.data.new(self.weight.size(0), 1) | |||
| mask.bernoulli_(1 - dropout) | |||
| mask = mask.expand_as(self.weight) | |||
| mask = mask / (1 - dropout) | |||
| masked_weight = self.weight * Variable(mask) | |||
| else: | |||
| masked_weight = self.weight | |||
| if self.scale and self.scale != 1: | |||
| masked_weight = masked_weight * self.scale | |||
| return F.embedding(inputs, | |||
| masked_weight, | |||
| max_norm=self.max_norm, | |||
| norm_type=self.norm_type, | |||
| scale_grad_by_freq=self.scale_grad_by_freq, | |||
| sparse=self.sparse) | |||
| class LockedDropout(nn.Module): | |||
| # code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py | |||
| def __init__(self): | |||
| super().__init__() | |||
| def forward(self, x, dropout=0.5): | |||
| if not self.training or not dropout: | |||
| return x | |||
| m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) | |||
| mask = Variable(m, requires_grad=False) / (1 - dropout) | |||
| mask = mask.expand_as(x) | |||
| return mask * x | |||
| class ENASModel(BaseModel): | |||
| """Shared RNN model.""" | |||
| def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000): | |||
| super(ENASModel, self).__init__() | |||
| self.use_cuda = cuda | |||
| self.shared_hid = shared_hid | |||
| self.num_blocks = num_blocks | |||
| self.decoder = nn.Linear(self.shared_hid, num_classes) | |||
| self.encoder = EmbeddingDropout(embed_num, | |||
| shared_embed, | |||
| dropout=0.1) | |||
| self.lockdrop = LockedDropout() | |||
| self.dag = None | |||
| # Tie weights | |||
| # self.decoder.weight = self.encoder.weight | |||
| # Since W^{x, c} and W^{h, c} are always summed, there | |||
| # is no point duplicating their bias offset parameter. Likewise for | |||
| # W^{x, h} and W^{h, h}. | |||
| self.w_xc = nn.Linear(shared_embed, self.shared_hid) | |||
| self.w_xh = nn.Linear(shared_embed, self.shared_hid) | |||
| # The raw weights are stored here because the hidden-to-hidden weights | |||
| # are weight dropped on the forward pass. | |||
| self.w_hc_raw = torch.nn.Parameter( | |||
| torch.Tensor(self.shared_hid, self.shared_hid)) | |||
| self.w_hh_raw = torch.nn.Parameter( | |||
| torch.Tensor(self.shared_hid, self.shared_hid)) | |||
| self.w_hc = None | |||
| self.w_hh = None | |||
| self.w_h = collections.defaultdict(dict) | |||
| self.w_c = collections.defaultdict(dict) | |||
| for idx in range(self.num_blocks): | |||
| for jdx in range(idx + 1, self.num_blocks): | |||
| self.w_h[idx][jdx] = nn.Linear(self.shared_hid, | |||
| self.shared_hid, | |||
| bias=False) | |||
| self.w_c[idx][jdx] = nn.Linear(self.shared_hid, | |||
| self.shared_hid, | |||
| bias=False) | |||
| self._w_h = nn.ModuleList([self.w_h[idx][jdx] | |||
| for idx in self.w_h | |||
| for jdx in self.w_h[idx]]) | |||
| self._w_c = nn.ModuleList([self.w_c[idx][jdx] | |||
| for idx in self.w_c | |||
| for jdx in self.w_c[idx]]) | |||
| self.batch_norm = None | |||
| # if args.mode == 'train': | |||
| # self.batch_norm = nn.BatchNorm1d(self.shared_hid) | |||
| # else: | |||
| # self.batch_norm = None | |||
| self.reset_parameters() | |||
| self.static_init_hidden = utils.keydefaultdict(self.init_hidden) | |||
| def setDAG(self, dag): | |||
| if self.dag is None: | |||
| self.dag = dag | |||
| def forward(self, word_seq, hidden=None): | |||
| inputs = torch.transpose(word_seq, 0, 1) | |||
| time_steps = inputs.size(0) | |||
| batch_size = inputs.size(1) | |||
| self.w_hh = _get_dropped_weights(self.w_hh_raw, | |||
| 0.5, | |||
| self.training) | |||
| self.w_hc = _get_dropped_weights(self.w_hc_raw, | |||
| 0.5, | |||
| self.training) | |||
| # hidden = self.static_init_hidden[batch_size] if hidden is None else hidden | |||
| hidden = self.static_init_hidden[batch_size] | |||
| embed = self.encoder(inputs) | |||
| embed = self.lockdrop(embed, 0.65 if self.training else 0) | |||
| # The norm of hidden states are clipped here because | |||
| # otherwise ENAS is especially prone to exploding activations on the | |||
| # forward pass. This could probably be fixed in a more elegant way, but | |||
| # it might be exposing a weakness in the ENAS algorithm as currently | |||
| # proposed. | |||
| # | |||
| # For more details, see | |||
| # https://github.com/carpedm20/ENAS-pytorch/issues/6 | |||
| clipped_num = 0 | |||
| max_clipped_norm = 0 | |||
| h1tohT = [] | |||
| logits = [] | |||
| for step in range(time_steps): | |||
| x_t = embed[step] | |||
| logit, hidden = self.cell(x_t, hidden, self.dag) | |||
| hidden_norms = hidden.norm(dim=-1) | |||
| max_norm = 25.0 | |||
| if hidden_norms.data.max() > max_norm: | |||
| # Just directly use the torch slice operations | |||
| # in PyTorch v0.4. | |||
| # | |||
| # This workaround for PyTorch v0.3.1 does everything in numpy, | |||
| # because the PyTorch slicing and slice assignment is too | |||
| # flaky. | |||
| hidden_norms = hidden_norms.data.cpu().numpy() | |||
| clipped_num += 1 | |||
| if hidden_norms.max() > max_clipped_norm: | |||
| max_clipped_norm = hidden_norms.max() | |||
| clip_select = hidden_norms > max_norm | |||
| clip_norms = hidden_norms[clip_select] | |||
| mask = np.ones(hidden.size()) | |||
| normalizer = max_norm/clip_norms | |||
| normalizer = normalizer[:, np.newaxis] | |||
| mask[clip_select] = normalizer | |||
| if self.use_cuda: | |||
| hidden *= torch.autograd.Variable( | |||
| torch.FloatTensor(mask).cuda(), requires_grad=False) | |||
| else: | |||
| hidden *= torch.autograd.Variable( | |||
| torch.FloatTensor(mask), requires_grad=False) | |||
| logits.append(logit) | |||
| h1tohT.append(hidden) | |||
| h1tohT = torch.stack(h1tohT) | |||
| output = torch.stack(logits) | |||
| raw_output = output | |||
| output = self.lockdrop(output, 0.4 if self.training else 0) | |||
| #Pooling | |||
| output = torch.mean(output, 0) | |||
| decoded = self.decoder(output) | |||
| extra_out = {'dropped': decoded, | |||
| 'hiddens': h1tohT, | |||
| 'raw': raw_output} | |||
| return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out} | |||
| def cell(self, x, h_prev, dag): | |||
| """Computes a single pass through the discovered RNN cell.""" | |||
| c = {} | |||
| h = {} | |||
| f = {} | |||
| f[0] = self.get_f(dag[-1][0].name) | |||
| c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None)) | |||
| h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) + | |||
| (1 - c[0])*h_prev) | |||
| leaf_node_ids = [] | |||
| q = collections.deque() | |||
| q.append(0) | |||
| # Computes connections from the parent nodes `node_id` | |||
| # to their child nodes `next_id` recursively, skipping leaf nodes. A | |||
| # leaf node is a node whose id == `self.num_blocks`. | |||
| # | |||
| # Connections between parent i and child j should be computed as | |||
| # h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i, | |||
| # where c_j = \sigmoid{(W^c_{ij}*h_i)} | |||
| # | |||
| # See Training details from Section 3.1 of the paper. | |||
| # | |||
| # The following algorithm does a breadth-first (since `q.popleft()` is | |||
| # used) search over the nodes and computes all the hidden states. | |||
| while True: | |||
| if len(q) == 0: | |||
| break | |||
| node_id = q.popleft() | |||
| nodes = dag[node_id] | |||
| for next_node in nodes: | |||
| next_id = next_node.id | |||
| if next_id == self.num_blocks: | |||
| leaf_node_ids.append(node_id) | |||
| assert len(nodes) == 1, ('parent of leaf node should have ' | |||
| 'only one child') | |||
| continue | |||
| w_h = self.w_h[node_id][next_id] | |||
| w_c = self.w_c[node_id][next_id] | |||
| f[next_id] = self.get_f(next_node.name) | |||
| c[next_id] = torch.sigmoid(w_c(h[node_id])) | |||
| h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) + | |||
| (1 - c[next_id])*h[node_id]) | |||
| q.append(next_id) | |||
| # Instead of averaging loose ends, perhaps there should | |||
| # be a set of separate unshared weights for each "loose" connection | |||
| # between each node in a cell and the output. | |||
| # | |||
| # As it stands, all weights W^h_{ij} are doing double duty by | |||
| # connecting both from i to j, as well as from i to the output. | |||
| # average all the loose ends | |||
| leaf_nodes = [h[node_id] for node_id in leaf_node_ids] | |||
| output = torch.mean(torch.stack(leaf_nodes, 2), -1) | |||
| # stabilizing the Updates of omega | |||
| if self.batch_norm is not None: | |||
| output = self.batch_norm(output) | |||
| return output, h[self.num_blocks - 1] | |||
| def init_hidden(self, batch_size): | |||
| zeros = torch.zeros(batch_size, self.shared_hid) | |||
| return utils.get_variable(zeros, self.use_cuda, requires_grad=False) | |||
| def get_f(self, name): | |||
| name = name.lower() | |||
| if name == 'relu': | |||
| f = torch.relu | |||
| elif name == 'tanh': | |||
| f = torch.tanh | |||
| elif name == 'identity': | |||
| f = lambda x: x | |||
| elif name == 'sigmoid': | |||
| f = torch.sigmoid | |||
| return f | |||
| @property | |||
| def num_parameters(self): | |||
| def size(p): | |||
| return np.prod(p.size()) | |||
| return sum([size(param) for param in self.parameters()]) | |||
| def reset_parameters(self): | |||
| init_range = 0.025 | |||
| # init_range = 0.025 if self.args.mode == 'train' else 0.04 | |||
| for param in self.parameters(): | |||
| param.data.uniform_(-init_range, init_range) | |||
| self.decoder.bias.data.fill_(0) | |||
| def predict(self, word_seq): | |||
| """ | |||
| :param word_seq: torch.LongTensor, [batch_size, seq_len] | |||
| :return predict: dict of torch.LongTensor, [batch_size, seq_len] | |||
| """ | |||
| output = self(word_seq) | |||
| _, predict = output['pred'].max(dim=1) | |||
| return {'pred': predict} | |||
| @@ -0,0 +1,385 @@ | |||
| # Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||
| import os | |||
| import time | |||
| from datetime import datetime | |||
| from datetime import timedelta | |||
| import numpy as np | |||
| import torch | |||
| import math | |||
| from torch import nn | |||
| try: | |||
| from tqdm.autonotebook import tqdm | |||
| except: | |||
| from fastNLP.core.utils import pseudo_tqdm as tqdm | |||
| from fastNLP.core.batch import Batch | |||
| from fastNLP.core.callback import CallbackManager, CallbackException | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.utils import CheckError | |||
| from fastNLP.core.utils import _move_dict_value_to_device | |||
| import fastNLP | |||
| import fastNLP.models.enas_utils as utils | |||
| from fastNLP.core.utils import _build_args | |||
| from torch.optim import Adam | |||
| def _get_no_grad_ctx_mgr(): | |||
| """Returns a the `torch.no_grad` context manager for PyTorch version >= | |||
| 0.4, or a no-op context manager otherwise. | |||
| """ | |||
| return torch.no_grad() | |||
| class ENASTrainer(fastNLP.Trainer): | |||
| """A class to wrap training code.""" | |||
| def __init__(self, train_data, model, controller, **kwargs): | |||
| """Constructor for training algorithm. | |||
| :param DataSet train_data: the training data | |||
| :param torch.nn.modules.module model: a PyTorch model | |||
| :param torch.nn.modules.module controller: a PyTorch model | |||
| """ | |||
| self.final_epochs = kwargs['final_epochs'] | |||
| kwargs.pop('final_epochs') | |||
| super(ENASTrainer, self).__init__(train_data, model, **kwargs) | |||
| self.controller_step = 0 | |||
| self.shared_step = 0 | |||
| self.max_length = 35 | |||
| self.shared = model | |||
| self.controller = controller | |||
| self.shared_optim = Adam( | |||
| self.shared.parameters(), | |||
| lr=20.0, | |||
| weight_decay=1e-7) | |||
| self.controller_optim = Adam( | |||
| self.controller.parameters(), | |||
| lr=3.5e-4) | |||
| def train(self, load_best_model=True): | |||
| """ | |||
| :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||
| 最好的模型参数。 | |||
| :return results: 返回一个字典类型的数据, 内含以下内容:: | |||
| seconds: float, 表示训练时长 | |||
| 以下三个内容只有在提供了dev_data的情况下会有。 | |||
| best_eval: Dict of Dict, 表示evaluation的结果 | |||
| best_epoch: int,在第几个epoch取得的最佳值 | |||
| best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
| """ | |||
| results = {} | |||
| if self.n_epochs <= 0: | |||
| print(f"training epoch is {self.n_epochs}, nothing was done.") | |||
| results['seconds'] = 0. | |||
| return results | |||
| try: | |||
| if torch.cuda.is_available() and self.use_cuda: | |||
| self.model = self.model.cuda() | |||
| self._model_device = self.model.parameters().__next__().device | |||
| self._mode(self.model, is_test=False) | |||
| self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||
| start_time = time.time() | |||
| print("training epochs started " + self.start_time, flush=True) | |||
| try: | |||
| self.callback_manager.on_train_begin() | |||
| self._train() | |||
| self.callback_manager.on_train_end(self.model) | |||
| except (CallbackException, KeyboardInterrupt) as e: | |||
| self.callback_manager.on_exception(e, self.model) | |||
| if self.dev_data is not None: | |||
| print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||
| self.tester._format_eval_results(self.best_dev_perf),) | |||
| results['best_eval'] = self.best_dev_perf | |||
| results['best_epoch'] = self.best_dev_epoch | |||
| results['best_step'] = self.best_dev_step | |||
| if load_best_model: | |||
| model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) | |||
| load_succeed = self._load_model(self.model, model_name) | |||
| if load_succeed: | |||
| print("Reloaded the best model.") | |||
| else: | |||
| print("Fail to reload best model.") | |||
| finally: | |||
| pass | |||
| results['seconds'] = round(time.time() - start_time, 2) | |||
| return results | |||
| def _train(self): | |||
| if not self.use_tqdm: | |||
| from fastNLP.core.utils import pseudo_tqdm as inner_tqdm | |||
| else: | |||
| inner_tqdm = tqdm | |||
| self.step = 0 | |||
| start = time.time() | |||
| total_steps = (len(self.train_data) // self.batch_size + int( | |||
| len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||
| with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||
| avg_loss = 0 | |||
| data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||
| prefetch=self.prefetch) | |||
| for epoch in range(1, self.n_epochs+1): | |||
| pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||
| last_stage = (epoch > self.n_epochs + 1 - self.final_epochs) | |||
| if epoch == self.n_epochs + 1 - self.final_epochs: | |||
| print('Entering the final stage. (Only train the selected structure)') | |||
| # early stopping | |||
| self.callback_manager.on_epoch_begin(epoch, self.n_epochs) | |||
| # 1. Training the shared parameters omega of the child models | |||
| self.train_shared(pbar) | |||
| # 2. Training the controller parameters theta | |||
| if not last_stage: | |||
| self.train_controller() | |||
| if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||
| (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||
| and self.dev_data is not None: | |||
| if not last_stage: | |||
| self.derive() | |||
| eval_res = self._do_validation(epoch=epoch, step=self.step) | |||
| eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||
| total_steps) + \ | |||
| self.tester._format_eval_results(eval_res) | |||
| pbar.write(eval_str) | |||
| # lr decay; early stopping | |||
| self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer) | |||
| # =============== epochs end =================== # | |||
| pbar.close() | |||
| # ============ tqdm end ============== # | |||
| def get_loss(self, inputs, targets, hidden, dags): | |||
| """Computes the loss for the same batch for M models. | |||
| This amounts to an estimate of the loss, which is turned into an | |||
| estimate for the gradients of the shared model. | |||
| """ | |||
| if not isinstance(dags, list): | |||
| dags = [dags] | |||
| loss = 0 | |||
| for dag in dags: | |||
| self.shared.setDAG(dag) | |||
| inputs = _build_args(self.shared.forward, **inputs) | |||
| inputs['hidden'] = hidden | |||
| result = self.shared(**inputs) | |||
| output, hidden, extra_out = result['pred'], result['hidden'], result['extra_out'] | |||
| self.callback_manager.on_loss_begin(targets, result) | |||
| sample_loss = self._compute_loss(result, targets) | |||
| loss += sample_loss | |||
| assert len(dags) == 1, 'there are multiple `hidden` for multple `dags`' | |||
| return loss, hidden, extra_out | |||
| def train_shared(self, pbar=None, max_step=None, dag=None): | |||
| """Train the language model for 400 steps of minibatches of 64 | |||
| examples. | |||
| Args: | |||
| max_step: Used to run extra training steps as a warm-up. | |||
| dag: If not None, is used instead of calling sample(). | |||
| BPTT is truncated at 35 timesteps. | |||
| For each weight update, gradients are estimated by sampling M models | |||
| from the fixed controller policy, and averaging their gradients | |||
| computed on a batch of training data. | |||
| """ | |||
| model = self.shared | |||
| model.train() | |||
| self.controller.eval() | |||
| hidden = self.shared.init_hidden(self.batch_size) | |||
| abs_max_grad = 0 | |||
| abs_max_hidden_norm = 0 | |||
| step = 0 | |||
| raw_total_loss = 0 | |||
| total_loss = 0 | |||
| train_idx = 0 | |||
| avg_loss = 0 | |||
| data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||
| prefetch=self.prefetch) | |||
| for batch_x, batch_y in data_iterator: | |||
| _move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||
| indices = data_iterator.get_batch_indices() | |||
| # negative sampling; replace unknown; re-weight batch_y | |||
| self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||
| # prediction = self._data_forward(self.model, batch_x) | |||
| dags = self.controller.sample(1) | |||
| inputs, targets = batch_x, batch_y | |||
| # self.callback_manager.on_loss_begin(batch_y, prediction) | |||
| loss, hidden, extra_out = self.get_loss(inputs, | |||
| targets, | |||
| hidden, | |||
| dags) | |||
| hidden.detach_() | |||
| avg_loss += loss.item() | |||
| # Is loss NaN or inf? requires_grad = False | |||
| self.callback_manager.on_backward_begin(loss, self.model) | |||
| self._grad_backward(loss) | |||
| self.callback_manager.on_backward_end(self.model) | |||
| self._update() | |||
| self.callback_manager.on_step_end(self.optimizer) | |||
| if (self.step+1) % self.print_every == 0: | |||
| if self.use_tqdm: | |||
| print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) | |||
| pbar.update(self.print_every) | |||
| else: | |||
| end = time.time() | |||
| diff = timedelta(seconds=round(end - start)) | |||
| print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||
| epoch, self.step, avg_loss, diff) | |||
| pbar.set_postfix_str(print_output) | |||
| avg_loss = 0 | |||
| self.step += 1 | |||
| step += 1 | |||
| self.shared_step += 1 | |||
| self.callback_manager.on_batch_end() | |||
| # ================= mini-batch end ==================== # | |||
| def get_reward(self, dag, entropies, hidden, valid_idx=0): | |||
| """Computes the perplexity of a single sampled model on a minibatch of | |||
| validation data. | |||
| """ | |||
| if not isinstance(entropies, np.ndarray): | |||
| entropies = entropies.data.cpu().numpy() | |||
| data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||
| prefetch=self.prefetch) | |||
| for inputs, targets in data_iterator: | |||
| valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag) | |||
| valid_loss = utils.to_item(valid_loss.data) | |||
| valid_ppl = math.exp(valid_loss) | |||
| R = 80 / valid_ppl | |||
| rewards = R + 1e-4 * entropies | |||
| return rewards, hidden | |||
| def train_controller(self): | |||
| """Fixes the shared parameters and updates the controller parameters. | |||
| The controller is updated with a score function gradient estimator | |||
| (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl | |||
| is computed on a minibatch of validation data. | |||
| A moving average baseline is used. | |||
| The controller is trained for 2000 steps per epoch (i.e., | |||
| first (Train Shared) phase -> second (Train Controller) phase). | |||
| """ | |||
| model = self.controller | |||
| model.train() | |||
| # Why can't we call shared.eval() here? Leads to loss | |||
| # being uniformly zero for the controller. | |||
| # self.shared.eval() | |||
| avg_reward_base = None | |||
| baseline = None | |||
| adv_history = [] | |||
| entropy_history = [] | |||
| reward_history = [] | |||
| hidden = self.shared.init_hidden(self.batch_size) | |||
| total_loss = 0 | |||
| valid_idx = 0 | |||
| for step in range(20): | |||
| # sample models | |||
| dags, log_probs, entropies = self.controller.sample( | |||
| with_details=True) | |||
| # calculate reward | |||
| np_entropies = entropies.data.cpu().numpy() | |||
| # No gradients should be backpropagated to the | |||
| # shared model during controller training, obviously. | |||
| with _get_no_grad_ctx_mgr(): | |||
| rewards, hidden = self.get_reward(dags, | |||
| np_entropies, | |||
| hidden, | |||
| valid_idx) | |||
| reward_history.extend(rewards) | |||
| entropy_history.extend(np_entropies) | |||
| # moving average baseline | |||
| if baseline is None: | |||
| baseline = rewards | |||
| else: | |||
| decay = 0.95 | |||
| baseline = decay * baseline + (1 - decay) * rewards | |||
| adv = rewards - baseline | |||
| adv_history.extend(adv) | |||
| # policy loss | |||
| loss = -log_probs*utils.get_variable(adv, | |||
| self.use_cuda, | |||
| requires_grad=False) | |||
| loss = loss.sum() # or loss.mean() | |||
| # update | |||
| self.controller_optim.zero_grad() | |||
| loss.backward() | |||
| self.controller_optim.step() | |||
| total_loss += utils.to_item(loss.data) | |||
| if ((step % 50) == 0) and (step > 0): | |||
| reward_history, adv_history, entropy_history = [], [], [] | |||
| total_loss = 0 | |||
| self.controller_step += 1 | |||
| # prev_valid_idx = valid_idx | |||
| # valid_idx = ((valid_idx + self.max_length) % | |||
| # (self.valid_data.size(0) - 1)) | |||
| # # Whenever we wrap around to the beginning of the | |||
| # # validation data, we reset the hidden states. | |||
| # if prev_valid_idx > valid_idx: | |||
| # hidden = self.shared.init_hidden(self.batch_size) | |||
| def derive(self, sample_num=10, valid_idx=0): | |||
| """We are always deriving based on the very first batch | |||
| of validation data? This seems wrong... | |||
| """ | |||
| hidden = self.shared.init_hidden(self.batch_size) | |||
| dags, _, entropies = self.controller.sample(sample_num, | |||
| with_details=True) | |||
| max_R = 0 | |||
| best_dag = None | |||
| for dag in dags: | |||
| R, _ = self.get_reward(dag, entropies, hidden, valid_idx) | |||
| if R.max() > max_R: | |||
| max_R = R.max() | |||
| best_dag = dag | |||
| self.model.setDAG(best_dag) | |||
| @@ -0,0 +1,56 @@ | |||
| # Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||
| from __future__ import print_function | |||
| from collections import defaultdict | |||
| import collections | |||
| from datetime import datetime | |||
| import os | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from torch.autograd import Variable | |||
| def detach(h): | |||
| if type(h) == Variable: | |||
| return Variable(h.data) | |||
| else: | |||
| return tuple(detach(v) for v in h) | |||
| def get_variable(inputs, cuda=False, **kwargs): | |||
| if type(inputs) in [list, np.ndarray]: | |||
| inputs = torch.Tensor(inputs) | |||
| if cuda: | |||
| out = Variable(inputs.cuda(), **kwargs) | |||
| else: | |||
| out = Variable(inputs, **kwargs) | |||
| return out | |||
| def update_lr(optimizer, lr): | |||
| for param_group in optimizer.param_groups: | |||
| param_group['lr'] = lr | |||
| Node = collections.namedtuple('Node', ['id', 'name']) | |||
| class keydefaultdict(defaultdict): | |||
| def __missing__(self, key): | |||
| if self.default_factory is None: | |||
| raise KeyError(key) | |||
| else: | |||
| ret = self[key] = self.default_factory(key) | |||
| return ret | |||
| def to_item(x): | |||
| """Converts x, possibly scalar and possibly tensor, to a Python scalar.""" | |||
| if isinstance(x, (float, int)): | |||
| return x | |||
| if float(torch.__version__[0:3]) < 0.4: | |||
| assert (x.dim() == 1) and (len(x) == 1) | |||
| return x[0] | |||
| return x.item() | |||
| @@ -0,0 +1,112 @@ | |||
| import unittest | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.core.losses import CrossEntropyLoss | |||
| from fastNLP.core.metrics import AccuracyMetric | |||
| class TestENAS(unittest.TestCase): | |||
| def testENAS(self): | |||
| # 从csv读取数据到DataSet | |||
| sample_path = "tutorials/sample_data/tutorial_sample_dataset.csv" | |||
| dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), | |||
| sep='\t') | |||
| print(len(dataset)) | |||
| print(dataset[0]) | |||
| print(dataset[-3]) | |||
| dataset.append(Instance(raw_sentence='fake data', label='0')) | |||
| # 将所有数字转为小写 | |||
| dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') | |||
| # label转int | |||
| dataset.apply(lambda x: int(x['label']), new_field_name='label') | |||
| # 使用空格分割句子 | |||
| def split_sent(ins): | |||
| return ins['raw_sentence'].split() | |||
| dataset.apply(split_sent, new_field_name='words') | |||
| # 增加长度信息 | |||
| dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') | |||
| print(len(dataset)) | |||
| print(dataset[0]) | |||
| # DataSet.drop(func)筛除数据 | |||
| dataset.drop(lambda x: x['seq_len'] <= 3) | |||
| print(len(dataset)) | |||
| # 设置DataSet中,哪些field要转为tensor | |||
| # set target,loss或evaluate中的golden,计算loss,模型评估时使用 | |||
| dataset.set_target("label") | |||
| # set input,模型forward时使用 | |||
| dataset.set_input("words", "seq_len") | |||
| # 分出测试集、训练集 | |||
| test_data, train_data = dataset.split(0.5) | |||
| print(len(test_data)) | |||
| print(len(train_data)) | |||
| # 构建词表, Vocabulary.add(word) | |||
| vocab = Vocabulary(min_freq=2) | |||
| train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) | |||
| vocab.build_vocab() | |||
| # index句子, Vocabulary.to_index(word) | |||
| train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||
| test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||
| print(test_data[0]) | |||
| # 如果你们需要做强化学习或者GAN之类的项目,你们也可以使用这些数据预处理的工具 | |||
| from fastNLP.core.batch import Batch | |||
| from fastNLP.core.sampler import RandomSampler | |||
| batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler()) | |||
| for batch_x, batch_y in batch_iterator: | |||
| print("batch_x has: ", batch_x) | |||
| print("batch_y has: ", batch_y) | |||
| break | |||
| from fastNLP.models.enas_model import ENASModel | |||
| from fastNLP.models.enas_controller import Controller | |||
| model = ENASModel(embed_num=len(vocab), num_classes=5) | |||
| controller = Controller() | |||
| from fastNLP.models.enas_trainer import ENASTrainer | |||
| from copy import deepcopy | |||
| # 更改DataSet中对应field的名称,要以模型的forward等参数名一致 | |||
| train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 | |||
| train_data.rename_field('label', 'label_seq') | |||
| test_data.rename_field('words', 'word_seq') | |||
| test_data.rename_field('label', 'label_seq') | |||
| loss = CrossEntropyLoss(pred="output", target="label_seq") | |||
| metric = AccuracyMetric(pred="predict", target="label_seq") | |||
| trainer = ENASTrainer(model=model, controller=controller, train_data=train_data, dev_data=test_data, | |||
| loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||
| metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||
| check_code_level=-1, | |||
| save_path=None, | |||
| batch_size=32, | |||
| print_every=1, | |||
| n_epochs=3, | |||
| final_epochs=1) | |||
| trainer.train() | |||
| print('Train finished!') | |||
| # 调用Tester在test_data上评价效果 | |||
| from fastNLP import Tester | |||
| tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||
| batch_size=4) | |||
| acc = tester.test() | |||
| print(acc) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||