diff --git a/fastNLP/automl/__init__.py b/fastNLP/automl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastNLP/automl/enas_controller.py b/fastNLP/automl/enas_controller.py new file mode 100644 index 00000000..6ddbb211 --- /dev/null +++ b/fastNLP/automl/enas_controller.py @@ -0,0 +1,223 @@ +# Code Modified from https://github.com/carpedm20/ENAS-pytorch +"""A module with NAS controller-related code.""" +import collections +import os + +import torch +import torch.nn.functional as F + +import fastNLP.automl.enas_utils as utils +from fastNLP.automl.enas_utils import Node + + +def _construct_dags(prev_nodes, activations, func_names, num_blocks): + """Constructs a set of DAGs based on the actions, i.e., previous nodes and + activation functions, sampled from the controller/policy pi. + + Args: + prev_nodes: Previous node actions from the policy. + activations: Activations sampled from the policy. + func_names: Mapping from activation function names to functions. + num_blocks: Number of blocks in the target RNN cell. + + Returns: + A list of DAGs defined by the inputs. + + RNN cell DAGs are represented in the following way: + + 1. Each element (node) in a DAG is a list of `Node`s. + + 2. The `Node`s in the list dag[i] correspond to the subsequent nodes + that take the output from node i as their own input. + + 3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}. + dag[-1] always feeds dag[0]. + dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its + weights. + + 4. dag[N - 1] is the node that produces the hidden state passed to + the next timestep. dag[N - 1] is also always a leaf node, and therefore + is always averaged with the other leaf nodes and fed to the output + decoder. + """ + dags = [] + for nodes, func_ids in zip(prev_nodes, activations): + dag = collections.defaultdict(list) + + # add first node + dag[-1] = [Node(0, func_names[func_ids[0]])] + dag[-2] = [Node(0, func_names[func_ids[0]])] + + # add following nodes + for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])): + dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id])) + + leaf_nodes = set(range(num_blocks)) - dag.keys() + + # merge with avg + for idx in leaf_nodes: + dag[idx] = [Node(num_blocks, 'avg')] + + # This is actually y^{(t)}. h^{(t)} is node N - 1 in + # the graph, where N Is the number of nodes. I.e., h^{(t)} takes + # only one other node as its input. + # last h[t] node + last_node = Node(num_blocks + 1, 'h[t]') + dag[num_blocks] = [last_node] + dags.append(dag) + + return dags + + +class Controller(torch.nn.Module): + """Based on + https://github.com/pytorch/examples/blob/master/word_language_model/model.py + + RL controllers do not necessarily have much to do with + language models. + + Base the controller RNN on the GRU from: + https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py + """ + def __init__(self, num_blocks=4, controller_hid=100, cuda=False): + torch.nn.Module.__init__(self) + + # `num_tokens` here is just the activation function + # for every even step, + self.shared_rnn_activations = ['tanh', 'ReLU', 'identity', 'sigmoid'] + self.num_tokens = [len(self.shared_rnn_activations)] + self.controller_hid = controller_hid + self.use_cuda = cuda + self.num_blocks = num_blocks + for idx in range(num_blocks): + self.num_tokens += [idx + 1, len(self.shared_rnn_activations)] + self.func_names = self.shared_rnn_activations + + num_total_tokens = sum(self.num_tokens) + + self.encoder = torch.nn.Embedding(num_total_tokens, + controller_hid) + self.lstm = torch.nn.LSTMCell(controller_hid, controller_hid) + + # Perhaps these weights in the decoder should be + # shared? At least for the activation functions, which all have the + # same size. + self.decoders = [] + for idx, size in enumerate(self.num_tokens): + decoder = torch.nn.Linear(controller_hid, size) + self.decoders.append(decoder) + + self._decoders = torch.nn.ModuleList(self.decoders) + + self.reset_parameters() + self.static_init_hidden = utils.keydefaultdict(self.init_hidden) + + def _get_default_hidden(key): + return utils.get_variable( + torch.zeros(key, self.controller_hid), + self.use_cuda, + requires_grad=False) + + self.static_inputs = utils.keydefaultdict(_get_default_hidden) + + def reset_parameters(self): + init_range = 0.1 + for param in self.parameters(): + param.data.uniform_(-init_range, init_range) + for decoder in self.decoders: + decoder.bias.data.fill_(0) + + def forward(self, # pylint:disable=arguments-differ + inputs, + hidden, + block_idx, + is_embed): + if not is_embed: + embed = self.encoder(inputs) + else: + embed = inputs + + hx, cx = self.lstm(embed, hidden) + logits = self.decoders[block_idx](hx) + + logits /= 5.0 + + # # exploration + # if self.args.mode == 'train': + # logits = (2.5 * F.tanh(logits)) + + return logits, (hx, cx) + + def sample(self, batch_size=1, with_details=False, save_dir=None): + """Samples a set of `args.num_blocks` many computational nodes from the + controller, where each node is made up of an activation function, and + each node except the last also includes a previous node. + """ + if batch_size < 1: + raise Exception(f'Wrong batch_size: {batch_size} < 1') + + # [B, L, H] + inputs = self.static_inputs[batch_size] + hidden = self.static_init_hidden[batch_size] + + activations = [] + entropies = [] + log_probs = [] + prev_nodes = [] + # The RNN controller alternately outputs an activation, + # followed by a previous node, for each block except the last one, + # which only gets an activation function. The last node is the output + # node, and its previous node is the average of all leaf nodes. + for block_idx in range(2*(self.num_blocks - 1) + 1): + logits, hidden = self.forward(inputs, + hidden, + block_idx, + is_embed=(block_idx == 0)) + + probs = F.softmax(logits, dim=-1) + log_prob = F.log_softmax(logits, dim=-1) + # .mean() for entropy? + entropy = -(log_prob * probs).sum(1, keepdim=False) + + action = probs.multinomial(num_samples=1).data + selected_log_prob = log_prob.gather( + 1, utils.get_variable(action, requires_grad=False)) + + # why the [:, 0] here? Should it be .squeeze(), or + # .view()? Same below with `action`. + entropies.append(entropy) + log_probs.append(selected_log_prob[:, 0]) + + # 0: function, 1: previous node + mode = block_idx % 2 + inputs = utils.get_variable( + action[:, 0] + sum(self.num_tokens[:mode]), + requires_grad=False) + + if mode == 0: + activations.append(action[:, 0]) + elif mode == 1: + prev_nodes.append(action[:, 0]) + + prev_nodes = torch.stack(prev_nodes).transpose(0, 1) + activations = torch.stack(activations).transpose(0, 1) + + dags = _construct_dags(prev_nodes, + activations, + self.func_names, + self.num_blocks) + + if save_dir is not None: + for idx, dag in enumerate(dags): + utils.draw_network(dag, + os.path.join(save_dir, f'graph{idx}.png')) + + if with_details: + return dags, torch.cat(log_probs), torch.cat(entropies) + + return dags + + def init_hidden(self, batch_size): + zeros = torch.zeros(batch_size, self.controller_hid) + return (utils.get_variable(zeros, self.use_cuda, requires_grad=False), + utils.get_variable(zeros.clone(), self.use_cuda, requires_grad=False)) diff --git a/fastNLP/automl/enas_model.py b/fastNLP/automl/enas_model.py new file mode 100644 index 00000000..4f9fb449 --- /dev/null +++ b/fastNLP/automl/enas_model.py @@ -0,0 +1,388 @@ +# Code Modified from https://github.com/carpedm20/ENAS-pytorch + +"""Module containing the shared RNN model.""" +import collections + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.autograd import Variable + +import fastNLP.automl.enas_utils as utils +from fastNLP.models.base_model import BaseModel + + +def _get_dropped_weights(w_raw, dropout_p, is_training): + """Drops out weights to implement DropConnect. + + Args: + w_raw: Full, pre-dropout, weights to be dropped out. + dropout_p: Proportion of weights to drop out. + is_training: True iff _shared_ model is training. + + Returns: + The dropped weights. + + Why does torch.nn.functional.dropout() return: + 1. `torch.autograd.Variable()` on the training loop + 2. `torch.nn.Parameter()` on the controller or eval loop, when + training = False... + + Even though the call to `_setweights` in the Smerity repo's + `weight_drop.py` does not have this behaviour, and `F.dropout` always + returns `torch.autograd.Variable` there, even when `training=False`? + + The above TODO is the reason for the hacky check for `torch.nn.Parameter`. + """ + dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training) + + if isinstance(dropped_w, torch.nn.Parameter): + dropped_w = dropped_w.clone() + + return dropped_w + +class EmbeddingDropout(torch.nn.Embedding): + """Class for dropping out embeddings by zero'ing out parameters in the + embedding matrix. + + This is equivalent to dropping out particular words, e.g., in the sentence + 'the quick brown fox jumps over the lazy dog', dropping out 'the' would + lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the + embedding vector space). + + See 'A Theoretically Grounded Application of Dropout in Recurrent Neural + Networks', (Gal and Ghahramani, 2016). + """ + def __init__(self, + num_embeddings, + embedding_dim, + max_norm=None, + norm_type=2, + scale_grad_by_freq=False, + sparse=False, + dropout=0.1, + scale=None): + """Embedding constructor. + + Args: + dropout: Dropout probability. + scale: Used to scale parameters of embedding weight matrix that are + not dropped out. Note that this is _in addition_ to the + `1/(1 - dropout)` scaling. + + See `torch.nn.Embedding` for remaining arguments. + """ + torch.nn.Embedding.__init__(self, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + self.dropout = dropout + assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 ' + 'and < 1.0') + self.scale = scale + + def forward(self, inputs): # pylint:disable=arguments-differ + """Embeds `inputs` with the dropped out embedding weight matrix.""" + if self.training: + dropout = self.dropout + else: + dropout = 0 + + if dropout: + mask = self.weight.data.new(self.weight.size(0), 1) + mask.bernoulli_(1 - dropout) + mask = mask.expand_as(self.weight) + mask = mask / (1 - dropout) + masked_weight = self.weight * Variable(mask) + else: + masked_weight = self.weight + if self.scale and self.scale != 1: + masked_weight = masked_weight * self.scale + + return F.embedding(inputs, + masked_weight, + max_norm=self.max_norm, + norm_type=self.norm_type, + scale_grad_by_freq=self.scale_grad_by_freq, + sparse=self.sparse) + + +class LockedDropout(nn.Module): + # code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py + def __init__(self): + super().__init__() + + def forward(self, x, dropout=0.5): + if not self.training or not dropout: + return x + m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) + mask = Variable(m, requires_grad=False) / (1 - dropout) + mask = mask.expand_as(x) + return mask * x + + +class ENASModel(BaseModel): + """Shared RNN model.""" + def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000): + super(ENASModel, self).__init__() + + self.use_cuda = cuda + + self.shared_hid = shared_hid + self.num_blocks = num_blocks + self.decoder = nn.Linear(self.shared_hid, num_classes) + self.encoder = EmbeddingDropout(embed_num, + shared_embed, + dropout=0.1) + self.lockdrop = LockedDropout() + self.dag = None + + # Tie weights + # self.decoder.weight = self.encoder.weight + + # Since W^{x, c} and W^{h, c} are always summed, there + # is no point duplicating their bias offset parameter. Likewise for + # W^{x, h} and W^{h, h}. + self.w_xc = nn.Linear(shared_embed, self.shared_hid) + self.w_xh = nn.Linear(shared_embed, self.shared_hid) + + # The raw weights are stored here because the hidden-to-hidden weights + # are weight dropped on the forward pass. + self.w_hc_raw = torch.nn.Parameter( + torch.Tensor(self.shared_hid, self.shared_hid)) + self.w_hh_raw = torch.nn.Parameter( + torch.Tensor(self.shared_hid, self.shared_hid)) + self.w_hc = None + self.w_hh = None + + self.w_h = collections.defaultdict(dict) + self.w_c = collections.defaultdict(dict) + + for idx in range(self.num_blocks): + for jdx in range(idx + 1, self.num_blocks): + self.w_h[idx][jdx] = nn.Linear(self.shared_hid, + self.shared_hid, + bias=False) + self.w_c[idx][jdx] = nn.Linear(self.shared_hid, + self.shared_hid, + bias=False) + + self._w_h = nn.ModuleList([self.w_h[idx][jdx] + for idx in self.w_h + for jdx in self.w_h[idx]]) + self._w_c = nn.ModuleList([self.w_c[idx][jdx] + for idx in self.w_c + for jdx in self.w_c[idx]]) + + self.batch_norm = None + # if args.mode == 'train': + # self.batch_norm = nn.BatchNorm1d(self.shared_hid) + # else: + # self.batch_norm = None + + self.reset_parameters() + self.static_init_hidden = utils.keydefaultdict(self.init_hidden) + + def setDAG(self, dag): + if self.dag is None: + self.dag = dag + + def forward(self, word_seq, hidden=None): + inputs = torch.transpose(word_seq, 0, 1) + + time_steps = inputs.size(0) + batch_size = inputs.size(1) + + + self.w_hh = _get_dropped_weights(self.w_hh_raw, + 0.5, + self.training) + self.w_hc = _get_dropped_weights(self.w_hc_raw, + 0.5, + self.training) + + # hidden = self.static_init_hidden[batch_size] if hidden is None else hidden + hidden = self.static_init_hidden[batch_size] + + embed = self.encoder(inputs) + + embed = self.lockdrop(embed, 0.65 if self.training else 0) + + # The norm of hidden states are clipped here because + # otherwise ENAS is especially prone to exploding activations on the + # forward pass. This could probably be fixed in a more elegant way, but + # it might be exposing a weakness in the ENAS algorithm as currently + # proposed. + # + # For more details, see + # https://github.com/carpedm20/ENAS-pytorch/issues/6 + clipped_num = 0 + max_clipped_norm = 0 + h1tohT = [] + logits = [] + for step in range(time_steps): + x_t = embed[step] + logit, hidden = self.cell(x_t, hidden, self.dag) + + hidden_norms = hidden.norm(dim=-1) + max_norm = 25.0 + if hidden_norms.data.max() > max_norm: + # Just directly use the torch slice operations + # in PyTorch v0.4. + # + # This workaround for PyTorch v0.3.1 does everything in numpy, + # because the PyTorch slicing and slice assignment is too + # flaky. + hidden_norms = hidden_norms.data.cpu().numpy() + + clipped_num += 1 + if hidden_norms.max() > max_clipped_norm: + max_clipped_norm = hidden_norms.max() + + clip_select = hidden_norms > max_norm + clip_norms = hidden_norms[clip_select] + + mask = np.ones(hidden.size()) + normalizer = max_norm/clip_norms + normalizer = normalizer[:, np.newaxis] + + mask[clip_select] = normalizer + + if self.use_cuda: + hidden *= torch.autograd.Variable( + torch.FloatTensor(mask).cuda(), requires_grad=False) + else: + hidden *= torch.autograd.Variable( + torch.FloatTensor(mask), requires_grad=False) + logits.append(logit) + h1tohT.append(hidden) + + h1tohT = torch.stack(h1tohT) + output = torch.stack(logits) + raw_output = output + + output = self.lockdrop(output, 0.4 if self.training else 0) + + #Pooling + output = torch.mean(output, 0) + + decoded = self.decoder(output) + + extra_out = {'dropped': decoded, + 'hiddens': h1tohT, + 'raw': raw_output} + return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out} + + def cell(self, x, h_prev, dag): + """Computes a single pass through the discovered RNN cell.""" + c = {} + h = {} + f = {} + + f[0] = self.get_f(dag[-1][0].name) + c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None)) + h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) + + (1 - c[0])*h_prev) + + leaf_node_ids = [] + q = collections.deque() + q.append(0) + + # Computes connections from the parent nodes `node_id` + # to their child nodes `next_id` recursively, skipping leaf nodes. A + # leaf node is a node whose id == `self.num_blocks`. + # + # Connections between parent i and child j should be computed as + # h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i, + # where c_j = \sigmoid{(W^c_{ij}*h_i)} + # + # See Training details from Section 3.1 of the paper. + # + # The following algorithm does a breadth-first (since `q.popleft()` is + # used) search over the nodes and computes all the hidden states. + while True: + if len(q) == 0: + break + + node_id = q.popleft() + nodes = dag[node_id] + + for next_node in nodes: + next_id = next_node.id + if next_id == self.num_blocks: + leaf_node_ids.append(node_id) + assert len(nodes) == 1, ('parent of leaf node should have ' + 'only one child') + continue + + w_h = self.w_h[node_id][next_id] + w_c = self.w_c[node_id][next_id] + + f[next_id] = self.get_f(next_node.name) + c[next_id] = torch.sigmoid(w_c(h[node_id])) + h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) + + (1 - c[next_id])*h[node_id]) + + q.append(next_id) + + # Instead of averaging loose ends, perhaps there should + # be a set of separate unshared weights for each "loose" connection + # between each node in a cell and the output. + # + # As it stands, all weights W^h_{ij} are doing double duty by + # connecting both from i to j, as well as from i to the output. + + # average all the loose ends + leaf_nodes = [h[node_id] for node_id in leaf_node_ids] + output = torch.mean(torch.stack(leaf_nodes, 2), -1) + + # stabilizing the Updates of omega + if self.batch_norm is not None: + output = self.batch_norm(output) + + return output, h[self.num_blocks - 1] + + def init_hidden(self, batch_size): + zeros = torch.zeros(batch_size, self.shared_hid) + return utils.get_variable(zeros, self.use_cuda, requires_grad=False) + + def get_f(self, name): + name = name.lower() + if name == 'relu': + f = torch.relu + elif name == 'tanh': + f = torch.tanh + elif name == 'identity': + f = lambda x: x + elif name == 'sigmoid': + f = torch.sigmoid + return f + + + @property + def num_parameters(self): + def size(p): + return np.prod(p.size()) + return sum([size(param) for param in self.parameters()]) + + + def reset_parameters(self): + init_range = 0.025 + # init_range = 0.025 if self.args.mode == 'train' else 0.04 + for param in self.parameters(): + param.data.uniform_(-init_range, init_range) + self.decoder.bias.data.fill_(0) + + def predict(self, word_seq): + """ + + :param word_seq: torch.LongTensor, [batch_size, seq_len] + :return predict: dict of torch.LongTensor, [batch_size, seq_len] + """ + output = self(word_seq) + _, predict = output['pred'].max(dim=1) + return {'pred': predict} diff --git a/fastNLP/automl/enas_trainer.py b/fastNLP/automl/enas_trainer.py new file mode 100644 index 00000000..7c0da752 --- /dev/null +++ b/fastNLP/automl/enas_trainer.py @@ -0,0 +1,382 @@ +# Code Modified from https://github.com/carpedm20/ENAS-pytorch + +import math +import time +from datetime import datetime +from datetime import timedelta + +import numpy as np +import torch + +try: + from tqdm.autonotebook import tqdm +except: + from fastNLP.core.utils import pseudo_tqdm as tqdm + +from fastNLP.core.batch import Batch +from fastNLP.core.callback import CallbackException +from fastNLP.core.dataset import DataSet +from fastNLP.core.utils import _move_dict_value_to_device +import fastNLP +import fastNLP.automl.enas_utils as utils +from fastNLP.core.utils import _build_args + +from torch.optim import Adam + + +def _get_no_grad_ctx_mgr(): + """Returns a the `torch.no_grad` context manager for PyTorch version >= + 0.4, or a no-op context manager otherwise. + """ + return torch.no_grad() + + +class ENASTrainer(fastNLP.Trainer): + """A class to wrap training code.""" + def __init__(self, train_data, model, controller, **kwargs): + """Constructor for training algorithm. + :param DataSet train_data: the training data + :param torch.nn.modules.module model: a PyTorch model + :param torch.nn.modules.module controller: a PyTorch model + """ + self.final_epochs = kwargs['final_epochs'] + kwargs.pop('final_epochs') + super(ENASTrainer, self).__init__(train_data, model, **kwargs) + self.controller_step = 0 + self.shared_step = 0 + self.max_length = 35 + + self.shared = model + self.controller = controller + + self.shared_optim = Adam( + self.shared.parameters(), + lr=20.0, + weight_decay=1e-7) + + self.controller_optim = Adam( + self.controller.parameters(), + lr=3.5e-4) + + def train(self, load_best_model=True): + """ + :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 + 最好的模型参数。 + :return results: 返回一个字典类型的数据, 内含以下内容:: + + seconds: float, 表示训练时长 + 以下三个内容只有在提供了dev_data的情况下会有。 + best_eval: Dict of Dict, 表示evaluation的结果 + best_epoch: int,在第几个epoch取得的最佳值 + best_step: int, 在第几个step(batch)更新取得的最佳值 + + """ + results = {} + if self.n_epochs <= 0: + print(f"training epoch is {self.n_epochs}, nothing was done.") + results['seconds'] = 0. + return results + try: + if torch.cuda.is_available() and self.use_cuda: + self.model = self.model.cuda() + self._model_device = self.model.parameters().__next__().device + self._mode(self.model, is_test=False) + + self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) + start_time = time.time() + print("training epochs started " + self.start_time, flush=True) + + try: + self.callback_manager.on_train_begin() + self._train() + self.callback_manager.on_train_end(self.model) + except (CallbackException, KeyboardInterrupt) as e: + self.callback_manager.on_exception(e, self.model) + + if self.dev_data is not None: + print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + + self.tester._format_eval_results(self.best_dev_perf),) + results['best_eval'] = self.best_dev_perf + results['best_epoch'] = self.best_dev_epoch + results['best_step'] = self.best_dev_step + if load_best_model: + model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) + load_succeed = self._load_model(self.model, model_name) + if load_succeed: + print("Reloaded the best model.") + else: + print("Fail to reload best model.") + finally: + pass + results['seconds'] = round(time.time() - start_time, 2) + + return results + + def _train(self): + if not self.use_tqdm: + from fastNLP.core.utils import pseudo_tqdm as inner_tqdm + else: + inner_tqdm = tqdm + self.step = 0 + start = time.time() + total_steps = (len(self.train_data) // self.batch_size + int( + len(self.train_data) % self.batch_size != 0)) * self.n_epochs + with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: + avg_loss = 0 + data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, + prefetch=self.prefetch) + for epoch in range(1, self.n_epochs+1): + pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) + last_stage = (epoch > self.n_epochs + 1 - self.final_epochs) + if epoch == self.n_epochs + 1 - self.final_epochs: + print('Entering the final stage. (Only train the selected structure)') + # early stopping + self.callback_manager.on_epoch_begin(epoch, self.n_epochs) + + # 1. Training the shared parameters omega of the child models + self.train_shared(pbar) + + # 2. Training the controller parameters theta + if not last_stage: + self.train_controller() + + if ((self.validate_every > 0 and self.step % self.validate_every == 0) or + (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ + and self.dev_data is not None: + if not last_stage: + self.derive() + eval_res = self._do_validation(epoch=epoch, step=self.step) + eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, + total_steps) + \ + self.tester._format_eval_results(eval_res) + pbar.write(eval_str) + + # lr decay; early stopping + self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer) + # =============== epochs end =================== # + pbar.close() + # ============ tqdm end ============== # + + + def get_loss(self, inputs, targets, hidden, dags): + """Computes the loss for the same batch for M models. + + This amounts to an estimate of the loss, which is turned into an + estimate for the gradients of the shared model. + """ + if not isinstance(dags, list): + dags = [dags] + + loss = 0 + for dag in dags: + self.shared.setDAG(dag) + inputs = _build_args(self.shared.forward, **inputs) + inputs['hidden'] = hidden + result = self.shared(**inputs) + output, hidden, extra_out = result['pred'], result['hidden'], result['extra_out'] + + self.callback_manager.on_loss_begin(targets, result) + sample_loss = self._compute_loss(result, targets) + loss += sample_loss + + assert len(dags) == 1, 'there are multiple `hidden` for multple `dags`' + return loss, hidden, extra_out + + def train_shared(self, pbar=None, max_step=None, dag=None): + """Train the language model for 400 steps of minibatches of 64 + examples. + + Args: + max_step: Used to run extra training steps as a warm-up. + dag: If not None, is used instead of calling sample(). + + BPTT is truncated at 35 timesteps. + + For each weight update, gradients are estimated by sampling M models + from the fixed controller policy, and averaging their gradients + computed on a batch of training data. + """ + model = self.shared + model.train() + self.controller.eval() + + hidden = self.shared.init_hidden(self.batch_size) + + abs_max_grad = 0 + abs_max_hidden_norm = 0 + step = 0 + raw_total_loss = 0 + total_loss = 0 + train_idx = 0 + avg_loss = 0 + data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, + prefetch=self.prefetch) + + for batch_x, batch_y in data_iterator: + _move_dict_value_to_device(batch_x, batch_y, device=self._model_device) + indices = data_iterator.get_batch_indices() + # negative sampling; replace unknown; re-weight batch_y + self.callback_manager.on_batch_begin(batch_x, batch_y, indices) + # prediction = self._data_forward(self.model, batch_x) + + dags = self.controller.sample(1) + inputs, targets = batch_x, batch_y + # self.callback_manager.on_loss_begin(batch_y, prediction) + loss, hidden, extra_out = self.get_loss(inputs, + targets, + hidden, + dags) + hidden.detach_() + + avg_loss += loss.item() + + # Is loss NaN or inf? requires_grad = False + self.callback_manager.on_backward_begin(loss, self.model) + self._grad_backward(loss) + self.callback_manager.on_backward_end(self.model) + + self._update() + self.callback_manager.on_step_end(self.optimizer) + + if (self.step+1) % self.print_every == 0: + if self.use_tqdm: + print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) + pbar.update(self.print_every) + else: + end = time.time() + diff = timedelta(seconds=round(end - start)) + print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( + epoch, self.step, avg_loss, diff) + pbar.set_postfix_str(print_output) + avg_loss = 0 + self.step += 1 + step += 1 + self.shared_step += 1 + self.callback_manager.on_batch_end() + # ================= mini-batch end ==================== # + + + def get_reward(self, dag, entropies, hidden, valid_idx=0): + """Computes the perplexity of a single sampled model on a minibatch of + validation data. + """ + if not isinstance(entropies, np.ndarray): + entropies = entropies.data.cpu().numpy() + + data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, + prefetch=self.prefetch) + + for inputs, targets in data_iterator: + valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag) + valid_loss = utils.to_item(valid_loss.data) + + valid_ppl = math.exp(valid_loss) + + R = 80 / valid_ppl + + rewards = R + 1e-4 * entropies + + return rewards, hidden + + def train_controller(self): + """Fixes the shared parameters and updates the controller parameters. + + The controller is updated with a score function gradient estimator + (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl + is computed on a minibatch of validation data. + + A moving average baseline is used. + + The controller is trained for 2000 steps per epoch (i.e., + first (Train Shared) phase -> second (Train Controller) phase). + """ + model = self.controller + model.train() + # Why can't we call shared.eval() here? Leads to loss + # being uniformly zero for the controller. + # self.shared.eval() + + avg_reward_base = None + baseline = None + adv_history = [] + entropy_history = [] + reward_history = [] + + hidden = self.shared.init_hidden(self.batch_size) + total_loss = 0 + valid_idx = 0 + for step in range(20): + # sample models + dags, log_probs, entropies = self.controller.sample( + with_details=True) + + # calculate reward + np_entropies = entropies.data.cpu().numpy() + # No gradients should be backpropagated to the + # shared model during controller training, obviously. + with _get_no_grad_ctx_mgr(): + rewards, hidden = self.get_reward(dags, + np_entropies, + hidden, + valid_idx) + + + reward_history.extend(rewards) + entropy_history.extend(np_entropies) + + # moving average baseline + if baseline is None: + baseline = rewards + else: + decay = 0.95 + baseline = decay * baseline + (1 - decay) * rewards + + adv = rewards - baseline + adv_history.extend(adv) + + # policy loss + loss = -log_probs*utils.get_variable(adv, + self.use_cuda, + requires_grad=False) + + loss = loss.sum() # or loss.mean() + + # update + self.controller_optim.zero_grad() + loss.backward() + + self.controller_optim.step() + + total_loss += utils.to_item(loss.data) + + if ((step % 50) == 0) and (step > 0): + reward_history, adv_history, entropy_history = [], [], [] + total_loss = 0 + + self.controller_step += 1 + # prev_valid_idx = valid_idx + # valid_idx = ((valid_idx + self.max_length) % + # (self.valid_data.size(0) - 1)) + # # Whenever we wrap around to the beginning of the + # # validation data, we reset the hidden states. + # if prev_valid_idx > valid_idx: + # hidden = self.shared.init_hidden(self.batch_size) + + def derive(self, sample_num=10, valid_idx=0): + """We are always deriving based on the very first batch + of validation data? This seems wrong... + """ + hidden = self.shared.init_hidden(self.batch_size) + + dags, _, entropies = self.controller.sample(sample_num, + with_details=True) + + max_R = 0 + best_dag = None + for dag in dags: + R, _ = self.get_reward(dag, entropies, hidden, valid_idx) + if R.max() > max_R: + max_R = R.max() + best_dag = dag + + self.model.setDAG(best_dag) diff --git a/fastNLP/automl/enas_utils.py b/fastNLP/automl/enas_utils.py new file mode 100644 index 00000000..7a53dd12 --- /dev/null +++ b/fastNLP/automl/enas_utils.py @@ -0,0 +1,53 @@ +# Code Modified from https://github.com/carpedm20/ENAS-pytorch + +from __future__ import print_function + +import collections +from collections import defaultdict + +import numpy as np +import torch +from torch.autograd import Variable + + +def detach(h): + if type(h) == Variable: + return Variable(h.data) + else: + return tuple(detach(v) for v in h) + +def get_variable(inputs, cuda=False, **kwargs): + if type(inputs) in [list, np.ndarray]: + inputs = torch.Tensor(inputs) + if cuda: + out = Variable(inputs.cuda(), **kwargs) + else: + out = Variable(inputs, **kwargs) + return out + +def update_lr(optimizer, lr): + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +Node = collections.namedtuple('Node', ['id', 'name']) + + +class keydefaultdict(defaultdict): + def __missing__(self, key): + if self.default_factory is None: + raise KeyError(key) + else: + ret = self[key] = self.default_factory(key) + return ret + + +def to_item(x): + """Converts x, possibly scalar and possibly tensor, to a Python scalar.""" + if isinstance(x, (float, int)): + return x + + if float(torch.__version__[0:3]) < 0.4: + assert (x.dim() == 1) and (len(x) == 1) + return x[0] + + return x.item() diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 1bda1f93..437b647a 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -138,7 +138,6 @@ class CallbackManager(Callback): """ super(CallbackManager, self).__init__() # set attribute of trainer environment - self.env = env self.callbacks = [] if callbacks is not None: diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 601fa589..24376a72 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -157,7 +157,7 @@ class DataSet(object): assert name in self.field_arrays self.field_arrays[name].append(field) - def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False): + def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False, ignore_type=False): """Add a new field to the DataSet. :param str name: the name of the field. @@ -165,13 +165,14 @@ class DataSet(object): :param int padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可 :param bool is_input: whether this field is model input. :param bool is_target: whether this field is label or target. + :param bool ignore_type: If True, do not perform type check. (Default: False) """ if len(self.field_arrays) != 0: if len(self) != len(fields): raise RuntimeError(f"The field to append must have the same size as dataset. " f"Dataset size {len(self)} != field size {len(fields)}") self.field_arrays[name] = FieldArray(name, fields, is_target=is_target, is_input=is_input, - padder=padder) + padder=padder, ignore_type=ignore_type) def delete_field(self, name): """Delete a field based on the field name. @@ -242,6 +243,8 @@ class DataSet(object): :param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. :return: """ + if field_name not in self.field_arrays: + raise KeyError("There is no field named {}.".format(field_name)) self.field_arrays[field_name].set_padder(padder) def set_pad_val(self, field_name, pad_val): @@ -252,6 +255,8 @@ class DataSet(object): :param pad_val: int,该field的padder会以pad_val作为padding index :return: """ + if field_name not in self.field_arrays: + raise KeyError("There is no field named {}.".format(field_name)) self.field_arrays[field_name].set_pad_val(pad_val) def get_input_name(self): @@ -287,6 +292,8 @@ class DataSet(object): extra_param['is_input'] = kwargs['is_input'] if 'is_target' in kwargs: extra_param['is_target'] = kwargs['is_target'] + if 'ignore_type' in kwargs: + extra_param['ignore_type'] = kwargs['ignore_type'] if new_field_name is not None: if new_field_name in self.field_arrays: # overwrite the field, keep same attributes @@ -295,11 +302,14 @@ class DataSet(object): extra_param['is_input'] = old_field.is_input if 'is_target' not in extra_param: extra_param['is_target'] = old_field.is_target + if 'ignore_type' not in extra_param: + extra_param['ignore_type'] = old_field.ignore_type self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], - is_target=extra_param["is_target"]) + is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) else: self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), - is_target=extra_param.get("is_target", None)) + is_target=extra_param.get("is_target", None), + ignore_type=extra_param.get("ignore_type", False)) else: return results diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index f3fcb3c8..72bb30b5 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -1,5 +1,5 @@ import numpy as np - +from copy import deepcopy class PadderBase: """ @@ -83,6 +83,8 @@ class AutoPadder(PadderBase): array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) for i, content in enumerate(contents): array[i][:len(content)] = content + elif field_ele_dtype is None: + array = contents # 当ignore_type=True时,直接返回contents else: # should only be str array = np.array([content for content in contents]) return array @@ -96,10 +98,16 @@ class FieldArray(object): :param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. :param bool is_target: If True, this FieldArray is used to compute loss. :param bool is_input: If True, this FieldArray is used to the model input. - :param padder: PadderBase类型。大多数情况下都不需要设置该值,除非需要在多个维度上进行padding(比如英文中对character进行padding) + :param PadderBase padder: PadderBase类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过 + fieldarray.set_pad_val()。 + 默认为None,(1)如果某个field是scalar,则不进行任何padding;(2)如果为一维list, 且fieldarray的dtype为float或int类型 + 则会进行padding;(3)其它情况不进行padder。 + 假设需要对English word中character进行padding,则需要使用其他的padder。 + 或ignore_type为True但是需要进行padding。 + :param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray. (default: False) """ - def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0)): + def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): """DataSet在初始化时会有两类方法对FieldArray操作: 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) @@ -114,6 +122,7 @@ class FieldArray(object): 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) 类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 + ignore_type用来控制是否进行类型检查,如果为True,则不检查。 """ self.name = name @@ -135,7 +144,13 @@ class FieldArray(object): self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 self.content_dim = None # 表示content是多少维的list + if padder is None: + padder = AutoPadder(pad_val=0) + else: + assert isinstance(padder, PadderBase), "padder must be of type PadderBase." + padder = deepcopy(padder) self.set_padder(padder) + self.ignore_type = ignore_type self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array @@ -149,8 +164,9 @@ class FieldArray(object): self.is_target = is_target def _set_dtype(self): - self.pytype = self._type_detection(self.content) - self.dtype = self._map_to_np_type(self.pytype) + if self.ignore_type is False: + self.pytype = self._type_detection(self.content) + self.dtype = self._map_to_np_type(self.pytype) @property def is_input(self): @@ -190,7 +206,7 @@ class FieldArray(object): if list in type_set: if len(type_set) > 1: # list 跟 非list 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) # >1维list inner_type_set = set() for l in content: @@ -213,7 +229,7 @@ class FieldArray(object): return self._basic_type_detection(inner_inner_type_set) else: # list 跟 非list 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set)) + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) else: # 一维list for content_type in type_set: @@ -237,17 +253,17 @@ class FieldArray(object): return float else: # str 跟 int 或者 float 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) else: # str, int, float混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) def _1d_list_check(self, val): """如果不是1D list就报错 """ type_set = set((type(obj) for obj in val)) if any(obj not in self.BASIC_TYPES for obj in type_set): - raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) + raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) self._basic_type_detection(type_set) # otherwise: _basic_type_detection will raise error return True @@ -278,39 +294,40 @@ class FieldArray(object): :param val: int, float, str, or a list of one. """ - if isinstance(val, list): - pass - elif isinstance(val, tuple): # 确保最外层是list - val = list(val) - elif isinstance(val, np.ndarray): - val = val.tolist() - elif any((isinstance(val, t) for t in self.BASIC_TYPES)): - pass - else: - raise RuntimeError( - "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) - - if self.is_input is True or self.is_target is True: - if type(val) == list: - if len(val) == 0: - raise ValueError("Cannot append an empty list.") - if self.content_dim == 2 and self._1d_list_check(val): - # 1维list检查 - pass - elif self.content_dim == 3 and self._2d_list_check(val): - # 2维list检查 - pass - else: - raise RuntimeError( - "Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) - elif type(val) in self.BASIC_TYPES and self.content_dim == 1: - # scalar检查 - if type(val) == float and self.pytype == int: - self.pytype = float - self.dtype = self._map_to_np_type(self.pytype) + if self.ignore_type is False: + if isinstance(val, list): + pass + elif isinstance(val, tuple): # 确保最外层是list + val = list(val) + elif isinstance(val, np.ndarray): + val = val.tolist() + elif any((isinstance(val, t) for t in self.BASIC_TYPES)): + pass else: raise RuntimeError( "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) + + if self.is_input is True or self.is_target is True: + if type(val) == list: + if len(val) == 0: + raise ValueError("Cannot append an empty list.") + if self.content_dim == 2 and self._1d_list_check(val): + # 1维list检查 + pass + elif self.content_dim == 3 and self._2d_list_check(val): + # 2维list检查 + pass + else: + raise RuntimeError( + "Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) + elif type(val) in self.BASIC_TYPES and self.content_dim == 1: + # scalar检查 + if type(val) == float and self.pytype == int: + self.pytype = float + self.dtype = self._map_to_np_type(self.pytype) + else: + raise RuntimeError( + "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) self.content.append(val) def __getitem__(self, indices): @@ -347,7 +364,7 @@ class FieldArray(object): """ if padder is not None: assert isinstance(padder, PadderBase), "padder must be of type PadderBase." - self.padder = padder + self.padder = deepcopy(padder) def set_pad_val(self, pad_val): """ diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 54fde815..64555e12 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -157,7 +157,7 @@ class MetricBase(object): fast_param = {} if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: fast_param['pred'] = list(pred_dict.values())[0] - fast_param['target'] = list(pred_dict.values())[0] + fast_param['target'] = list(target_dict.values())[0] return fast_param return fast_param @@ -822,3 +822,154 @@ def pred_topk(y_prob, k=1): (1, k)) y_prob_topk = y_prob[x_axis_index, y_pred_topk] return y_pred_topk, y_prob_topk + + +class SQuADMetric(MetricBase): + + def __init__(self, pred_start=None, pred_end=None, target_start=None, target_end=None, + beta=1, right_open=False, print_predict_stat=False): + """ + :param pred_start: [batch], 预测答案开始的index, 如果SQuAD2.0中答案为空则为0 + :param pred_end: [batch], 预测答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间) + :param target_start: [batch], 正确答案开始的index, 如果SQuAD2.0中答案为空则为0 + :param target_end: [batch], 正确答案结束的index, 如果SQuAD2.0中答案为空则为0(左闭右闭区间)或者1(左闭右开区间) + :param beta: float. f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 + 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 + :param right_open: boolean. right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 + :param print_predict_stat: boolean. True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 + """ + super(SQuADMetric, self).__init__() + + self._init_param_map(pred_start=pred_start, pred_end=pred_end, target_start=target_start, target_end=target_end) + + self.print_predict_stat = print_predict_stat + + self.no_ans_correct = 0 + self.no_ans_wrong = 0 + + self.has_ans_correct = 0 + self.has_ans_wrong = 0 + + self.has_ans_f = 0. + + self.no2no = 0 + self.no2yes = 0 + self.yes2no = 0 + self.yes2yes = 0 + + self.f_beta = beta + + self.right_open = right_open + + def evaluate(self, pred_start, pred_end, target_start, target_end): + """ + + :param pred_start: [batch, seq_len] + :param pred_end: [batch, seq_len] + :param target_start: [batch] + :param target_end: [batch] + :param labels: [batch] + :return: + """ + start_inference = pred_start.max(dim=-1)[1].cpu().tolist() + end_inference = pred_end.max(dim=-1)[1].cpu().tolist() + start, end = [], [] + max_len = pred_start.size(1) + t_start = target_start.cpu().tolist() + t_end = target_end.cpu().tolist() + + for s, e in zip(start_inference, end_inference): + start.append(min(s, e)) + end.append(max(s, e)) + for s, e, ts, te in zip(start, end, t_start, t_end): + if not self.right_open: + e += 1 + te += 1 + if ts == 0 and te == int(not self.right_open): + if s == 0 and e == int(not self.right_open): + self.no_ans_correct += 1 + self.no2no += 1 + else: + self.no_ans_wrong += 1 + self.no2yes += 1 + else: + if s == 0 and e == int(not self.right_open): + self.yes2no += 1 + else: + self.yes2yes += 1 + + if s == ts and e == te: + self.has_ans_correct += 1 + else: + self.has_ans_wrong += 1 + a = [0] * s + [1] * (e - s) + [0] * (max_len - e) + b = [0] * ts + [1] * (te - ts) + [0] * (max_len - te) + a, b = torch.tensor(a), torch.tensor(b) + + TP = int(torch.sum(a * b)) + pre = TP / int(torch.sum(a)) if int(torch.sum(a)) > 0 else 0 + rec = TP / int(torch.sum(b)) if int(torch.sum(b)) > 0 else 0 + + if pre + rec > 0: + f = (1 + (self.f_beta**2)) * pre * rec / ((self.f_beta**2) * pre + rec) + else: + f = 0 + self.has_ans_f += f + + def get_metric(self, reset=True): + evaluate_result = {} + + if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.no_ans_wrong <= 0: + return evaluate_result + + evaluate_result['EM'] = 0 + evaluate_result[f'f_{self.f_beta}'] = 0 + + flag = 0 + + if self.no_ans_correct + self.no_ans_wrong > 0: + evaluate_result[f'noAns-f_{self.f_beta}'] = \ + round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3) + evaluate_result['noAns-EM'] = \ + round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3) + evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'noAns-f_{self.f_beta}'] + evaluate_result['EM'] += evaluate_result['noAns-EM'] + flag += 1 + + if self.has_ans_correct + self.has_ans_wrong > 0: + evaluate_result[f'hasAns-f_{self.f_beta}'] = \ + round(100 * self.has_ans_f / (self.has_ans_correct + self.has_ans_wrong), 3) + evaluate_result['hasAns-EM'] = \ + round(100 * self.has_ans_correct / (self.has_ans_correct + self.has_ans_wrong), 3) + evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'hasAns-f_{self.f_beta}'] + evaluate_result['EM'] += evaluate_result['hasAns-EM'] + flag += 1 + + if self.print_predict_stat: + evaluate_result['no2no'] = self.no2no + evaluate_result['no2yes'] = self.no2yes + evaluate_result['yes2no'] = self.yes2no + evaluate_result['yes2yes'] = self.yes2yes + + if flag <= 0: + return evaluate_result + + evaluate_result[f'f_{self.f_beta}'] = round(evaluate_result[f'f_{self.f_beta}'] / flag, 3) + evaluate_result['EM'] = round(evaluate_result['EM'] / flag, 3) + + if reset: + self.no_ans_correct = 0 + self.no_ans_wrong = 0 + + self.has_ans_correct = 0 + self.has_ans_wrong = 0 + + self.has_ans_f = 0. + + self.no2no = 0 + self.no2yes = 0 + self.yes2no = 0 + self.yes2yes = 0 + + return evaluate_result + diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 6d9f6c68..d9aa520f 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -32,8 +32,8 @@ from fastNLP.core.utils import get_func_signature class Trainer(object): def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, - validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), - check_code_level=0, metric_key=None, sampler=RandomSampler(), prefetch=False, use_tqdm=True, + validate_every=-1, dev_data=None, save_path=None, optimizer=None, + check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True, use_cuda=False, callbacks=None): """ :param DataSet train_data: the training data @@ -96,7 +96,7 @@ class Trainer(object): losser = _prepare_losser(loss) # sampler check - if not isinstance(sampler, BaseSampler): + if sampler is not None and not isinstance(sampler, BaseSampler): raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) if check_code_level > -1: @@ -119,7 +119,7 @@ class Trainer(object): self.best_dev_epoch = None self.best_dev_step = None self.best_dev_perf = None - self.sampler = sampler + self.sampler = sampler if sampler is not None else RandomSampler() self.prefetch = prefetch self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) self.n_steps = (len(self.train_data) // self.batch_size + int( @@ -128,6 +128,8 @@ class Trainer(object): if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer else: + if optimizer is None: + optimizer = Adam(lr=0.01, weight_decay=0) self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) self.use_tqdm = use_tqdm @@ -145,6 +147,7 @@ class Trainer(object): self.step = 0 self.start_time = None # start timestamp + def train(self, load_best_model=True): """ @@ -365,6 +368,8 @@ class Trainer(object): """ if self.save_path is not None: model_path = os.path.join(self.save_path, model_name) + if not os.path.exists(self.save_path): + os.makedirs(self.save_path, exist_ok=True) if only_param: state_dict = model.state_dict() for key in state_dict: diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 987a3527..a1c8e678 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -196,3 +196,9 @@ class Vocabulary(object): """ self.__dict__.update(state) self.build_reverse_vocab() + + def __repr__(self): + return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) + + def __iter__(self): + return iter(list(self.word_count.keys())) diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index e1b68e7a..df004224 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -192,20 +192,23 @@ class ConditionalRandomField(nn.Module): seq_len, batch_size, n_tags = logits.size() alpha = logits[0] if self.include_start_end_trans: - alpha += self.start_scores.view(1, -1) + alpha = alpha + self.start_scores.view(1, -1) + + flip_mask = mask.eq(0) for i in range(1, seq_len): emit_score = logits[i].view(batch_size, 1, n_tags) trans_score = self.trans_m.view(1, n_tags, n_tags) tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score - alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1) + alpha = log_sum_exp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ + alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) if self.include_start_end_trans: - alpha += self.end_scores.view(1, -1) + alpha = alpha + self.end_scores.view(1, -1) return log_sum_exp(alpha, 1) - def _glod_score(self, logits, tags, mask): + def _gold_score(self, logits, tags, mask): """ Compute the score for the gold path. :param logits: FloatTensor, max_len x batch_size x num_tags @@ -218,17 +221,19 @@ class ConditionalRandomField(nn.Module): seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) # trans_socre [L-1, B] - trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :] + mask = mask.byte() + flip_mask = mask.eq(0) + trans_score = self.trans_m[tags[:seq_len-1], tags[1:]].masked_fill(flip_mask[1:, :], 0) # emit_score [L, B] - emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask + emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags].masked_fill(flip_mask, 0) # score [L-1, B] score = trans_score + emit_score[:seq_len-1, :] - score = score.sum(0) + emit_score[-1] * mask[-1] + score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0) if self.include_start_end_trans: st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] last_idx = mask.long().sum(0) - 1 ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] - score += st_scores + ed_scores + score = score + st_scores + ed_scores # return [B,] return score @@ -244,7 +249,7 @@ class ConditionalRandomField(nn.Module): tags = tags.transpose(0, 1).long() mask = mask.transpose(0, 1).float() all_path_score = self._normalizer_likelihood(feats, mask) - gold_path_score = self._glod_score(feats, tags, mask) + gold_path_score = self._gold_score(feats, tags, mask) return all_path_score - gold_path_score @@ -265,7 +270,7 @@ class ConditionalRandomField(nn.Module): """ batch_size, seq_len, n_tags = data.size() data = data.transpose(0, 1).data # L, B, H - mask = mask.transpose(0, 1).data.float() # L, B + mask = mask.transpose(0, 1).data.byte() # L, B # dp vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) @@ -284,7 +289,8 @@ class ConditionalRandomField(nn.Module): score = prev_score + trans_score + cur_score best_score, best_dst = score.max(1) vpath[i] = best_dst - vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1) + vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ + vscore.masked_fill(mask[i].view(batch_size, 1), 0) vscore += transitions[:n_tags, n_tags+1].view(1, -1) diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index 5287bca4..4ae15b18 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -60,7 +60,8 @@ def initial_parameter(net, initial_method=None): init_method(w.data) # weight else: init.normal_(w.data) # bias - elif hasattr(m, 'weight') and m.weight.requires_grad: + elif m is not None and hasattr(m, 'weight') and \ + hasattr(m.weight, "requires_grad"): init_method(m.weight.data) else: for w in m.parameters(): diff --git a/test/automl/test_enas.py b/test/automl/test_enas.py new file mode 100644 index 00000000..d2d3af05 --- /dev/null +++ b/test/automl/test_enas.py @@ -0,0 +1,111 @@ +import unittest + +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import Vocabulary +from fastNLP.core.losses import CrossEntropyLoss +from fastNLP.core.metrics import AccuracyMetric + + +class TestENAS(unittest.TestCase): + def testENAS(self): + # 从csv读取数据到DataSet + sample_path = "tutorials/sample_data/tutorial_sample_dataset.csv" + dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), + sep='\t') + print(len(dataset)) + print(dataset[0]) + print(dataset[-3]) + + dataset.append(Instance(raw_sentence='fake data', label='0')) + # 将所有数字转为小写 + dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') + # label转int + dataset.apply(lambda x: int(x['label']), new_field_name='label') + + # 使用空格分割句子 + def split_sent(ins): + return ins['raw_sentence'].split() + + dataset.apply(split_sent, new_field_name='words') + + # 增加长度信息 + dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') + print(len(dataset)) + print(dataset[0]) + + # DataSet.drop(func)筛除数据 + dataset.drop(lambda x: x['seq_len'] <= 3) + print(len(dataset)) + + # 设置DataSet中,哪些field要转为tensor + # set target,loss或evaluate中的golden,计算loss,模型评估时使用 + dataset.set_target("label") + # set input,模型forward时使用 + dataset.set_input("words", "seq_len") + + # 分出测试集、训练集 + test_data, train_data = dataset.split(0.5) + print(len(test_data)) + print(len(train_data)) + + # 构建词表, Vocabulary.add(word) + vocab = Vocabulary(min_freq=2) + train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) + vocab.build_vocab() + + # index句子, Vocabulary.to_index(word) + train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') + test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') + print(test_data[0]) + + # 如果你们需要做强化学习或者GAN之类的项目,你们也可以使用这些数据预处理的工具 + from fastNLP.core.batch import Batch + from fastNLP.core.sampler import RandomSampler + + batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler()) + for batch_x, batch_y in batch_iterator: + print("batch_x has: ", batch_x) + print("batch_y has: ", batch_y) + break + + from fastNLP.automl.enas_model import ENASModel + from fastNLP.automl.enas_controller import Controller + model = ENASModel(embed_num=len(vocab), num_classes=5) + controller = Controller() + + from fastNLP.automl.enas_trainer import ENASTrainer + + # 更改DataSet中对应field的名称,要以模型的forward等参数名一致 + train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 + train_data.rename_field('label', 'label_seq') + test_data.rename_field('words', 'word_seq') + test_data.rename_field('label', 'label_seq') + + loss = CrossEntropyLoss(pred="output", target="label_seq") + metric = AccuracyMetric(pred="predict", target="label_seq") + + trainer = ENASTrainer(model=model, controller=controller, train_data=train_data, dev_data=test_data, + loss=CrossEntropyLoss(pred="output", target="label_seq"), + metrics=AccuracyMetric(pred="predict", target="label_seq"), + check_code_level=-1, + save_path=None, + batch_size=32, + print_every=1, + n_epochs=3, + final_epochs=1) + trainer.train() + print('Train finished!') + + # 调用Tester在test_data上评价效果 + from fastNLP import Tester + + tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), + batch_size=4) + + acc = tester.test() + print(acc) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index 74ce4876..7d66620c 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), callbacks=[TensorboardCallback("loss", "metric")]) trainer.train() + + def test_readonly_property(self): + from fastNLP.core.callback import Callback + class MyCallback(Callback): + def __init__(self): + super(MyCallback, self).__init__() + + def on_epoch_begin(self, cur_epoch, total_epoch): + print(self.n_epochs, self.n_steps, self.batch_size) + print(self.model) + print(self.optimizer) + + data_set, model = prepare_env() + trainer = Trainer(data_set, model, + loss=BCELoss(pred="predict", target="y"), + n_epochs=5, + batch_size=32, + print_every=50, + optimizer=SGD(lr=0.1), + check_code_level=2, + use_tqdm=False, + dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), + callbacks=[MyCallback()]) + trainer.train() diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 72ced912..607f9a13 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -52,7 +52,7 @@ class TestDataSetMethods(unittest.TestCase): self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) - def test_add_append(self): + def test_add_field(self): dd = DataSet() dd.add_field("x", [[1, 2, 3]] * 10) dd.add_field("y", [[1, 2, 3, 4]] * 10) @@ -65,6 +65,11 @@ class TestDataSetMethods(unittest.TestCase): with self.assertRaises(RuntimeError): dd.add_field("??", [[1, 2]] * 40) + def test_add_field_ignore_type(self): + dd = DataSet() + dd.add_field("x", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], ignore_type=True, is_target=True) + dd.add_field("y", [{1, "1"}, {2, "2"}, {3, "3"}, {4, "4"}], ignore_type=True, is_target=True) + def test_delete_field(self): dd = DataSet() dd.add_field("x", [[1, 2, 3]] * 10) @@ -115,6 +120,9 @@ class TestDataSetMethods(unittest.TestCase): self.assertTrue(isinstance(res, list) and len(res) > 0) self.assertTrue(res[0], 4) + ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True) + # expect no exception raised + def test_drop(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) ds.drop(lambda ins: len(ins["y"]) < 3) @@ -165,7 +173,7 @@ class TestDataSetMethods(unittest.TestCase): dataset.apply(split_sent, new_field_name='words', is_input=True) # print(dataset) - def test_add_field(self): + def test_add_field_v2(self): ds = DataSet({"x": [3, 4]}) ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True) # ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y') @@ -208,6 +216,11 @@ class TestDataSetMethods(unittest.TestCase): self.assertTrue(isinstance(ds, DataSet)) self.assertTrue(len(ds) > 0) + def test_add_null(self): + ds = DataSet() + ds.add_field('test', []) + ds.set_target('test') + class TestDataSetIter(unittest.TestCase): def test__repr__(self): diff --git a/test/core/test_fieldarray.py b/test/core/test_fieldarray.py index 151d9335..ff1a8314 100644 --- a/test/core/test_fieldarray.py +++ b/test/core/test_fieldarray.py @@ -155,6 +155,13 @@ class TestFieldArray(unittest.TestCase): self.assertEqual(len(fa), 3) self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) + def test_ignore_type(self): + # 测试新添加的参数ignore_type,用来跳过类型检查 + fa = FieldArray("y", [[1.1, 2.2, "jin", {}, "hahah"], [int, 2, "$", 4, 5]], is_input=True, ignore_type=True) + fa.append([1.2, 2.3, str, 4.5, print]) + + fa = FieldArray("y", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], is_target=True, ignore_type=True) + class TestPadder(unittest.TestCase): @@ -215,4 +222,14 @@ class TestPadder(unittest.TestCase): [[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]], [[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]], padder(contents, None, np.int64).tolist() - ) \ No newline at end of file + ) + + def test_None_dtype(self): + from fastNLP.core.fieldarray import AutoPadder + padder = AutoPadder() + content = [ + [[1, 2, 3], [4, 5], [7, 8, 9, 10]], + [[1]] + ] + ans = padder(content, None, None) + self.assertListEqual(content, ans) diff --git a/test/core/test_vocabulary.py b/test/core/test_vocabulary.py index af2c493b..2f9cd3b1 100644 --- a/test/core/test_vocabulary.py +++ b/test/core/test_vocabulary.py @@ -60,6 +60,15 @@ class TestIndexing(unittest.TestCase): vocab.update(text) self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) + def test_iteration(self): + vocab = Vocabulary() + text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", + "works", "well", "in", "most", "cases", "scales", "well"] + vocab.update(text) + text = set(text) + for word in vocab: + self.assertTrue(word in text) + class TestOther(unittest.TestCase): def test_additional_update(self): diff --git a/test/modules/decoder/test_CRF.py b/test/modules/decoder/test_CRF.py index 0fc331dc..a176348f 100644 --- a/test/modules/decoder/test_CRF.py +++ b/test/modules/decoder/test_CRF.py @@ -66,7 +66,7 @@ class TestCRF(unittest.TestCase): # from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) # fast_CRF.trans_m = trans_m - # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) + # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) # # score equal # self.assertListEqual([score for _, score in allen_res], fast_res[1]) # # seq equal @@ -95,10 +95,34 @@ class TestCRF(unittest.TestCase): # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, # encoding_type='BMES')) # fast_CRF.trans_m = trans_m - # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) + # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) # # score equal # self.assertListEqual([score for _, score in allen_res], fast_res[1]) # # seq equal # self.assertListEqual([_ for _, score in allen_res], fast_res[0]) + def test_case3(self): + # 测试crf的loss不会出现负数 + import torch + from fastNLP.modules.decoder.CRF import ConditionalRandomField + from fastNLP.core.utils import seq_lens_to_masks + from torch import optim + from torch import nn + num_tags, include_start_end_trans = 4, True + num_samples = 4 + lengths = torch.randint(3, 50, size=(num_samples, )).long() + max_len = lengths.max() + tags = torch.randint(num_tags, size=(num_samples, max_len)) + masks = seq_lens_to_masks(lengths) + feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags)) + crf = ConditionalRandomField(num_tags, include_start_end_trans) + optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1) + for _ in range(10000): + loss = crf(feats, tags, masks).mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() + if _%1000==0: + print(loss) + assert loss.item()>0, "CRF loss cannot be less than 0."