# Conflicts:
#	fastNLP/core/callback.py
#	fastNLP/core/trainer.py
@@ -0,0 +1,223 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
"""A module with NAS controller-related code.""" | |||||
import collections | |||||
import os | |||||
import torch | |||||
import torch.nn.functional as F | |||||
import fastNLP.automl.enas_utils as utils | |||||
from fastNLP.automl.enas_utils import Node | |||||
def _construct_dags(prev_nodes, activations, func_names, num_blocks): | |||||
"""Constructs a set of DAGs based on the actions, i.e., previous nodes and | |||||
activation functions, sampled from the controller/policy pi. | |||||
Args: | |||||
prev_nodes: Previous node actions from the policy. | |||||
activations: Activations sampled from the policy. | |||||
func_names: List of activation function names; the sampled activation ids index into this list.
num_blocks: Number of blocks in the target RNN cell. | |||||
Returns: | |||||
A list of DAGs defined by the inputs. | |||||
RNN cell DAGs are represented in the following way: | |||||
1. Each element (node) in a DAG is a list of `Node`s. | |||||
2. The `Node`s in the list dag[i] correspond to the subsequent nodes | |||||
that take the output from node i as their own input. | |||||
3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}. | |||||
dag[-1] always feeds dag[0]. | |||||
dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its | |||||
weights. | |||||
4. dag[N - 1] is the node that produces the hidden state passed to | |||||
the next timestep. dag[N - 1] is also always a leaf node, and therefore | |||||
is always averaged with the other leaf nodes and fed to the output | |||||
decoder. | |||||
""" | |||||
dags = [] | |||||
for nodes, func_ids in zip(prev_nodes, activations): | |||||
dag = collections.defaultdict(list) | |||||
# add first node | |||||
dag[-1] = [Node(0, func_names[func_ids[0]])] | |||||
dag[-2] = [Node(0, func_names[func_ids[0]])] | |||||
# add following nodes | |||||
for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])): | |||||
dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id])) | |||||
leaf_nodes = set(range(num_blocks)) - dag.keys() | |||||
# merge with avg | |||||
for idx in leaf_nodes: | |||||
dag[idx] = [Node(num_blocks, 'avg')] | |||||
# This is actually y^{(t)}. h^{(t)} is node N - 1 in | |||||
# the graph, where N is the number of nodes. I.e., h^{(t)} takes
# only one other node as its input. | |||||
# last h[t] node | |||||
last_node = Node(num_blocks + 1, 'h[t]') | |||||
dag[num_blocks] = [last_node] | |||||
dags.append(dag) | |||||
return dags | |||||
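# Illustrative sketch (not part of the original module; hypothetical sampled
# values): how one controller sample maps onto a DAG for num_blocks=4.
# >>> funcs = ['tanh', 'ReLU', 'identity', 'sigmoid']
# >>> dags = _construct_dags(prev_nodes=[[0, 1, 1]], activations=[[1, 0, 3, 2]],
# ...                        func_names=funcs, num_blocks=4)
# >>> dags[0][-1]      # the x^{(t)}/h^{(t-1)} node, with activation 'ReLU'
# [Node(id=0, name='ReLU')]
# >>> dags[0][1]       # node 1 feeds nodes 2 and 3
# [Node(id=2, name='sigmoid'), Node(id=3, name='identity')]
# >>> dags[0][4]       # the averaged leaf output becomes h[t]
# [Node(id=5, name='h[t]')]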
class Controller(torch.nn.Module): | |||||
"""Based on | |||||
https://github.com/pytorch/examples/blob/master/word_language_model/model.py | |||||
RL controllers do not necessarily have much to do with | |||||
language models. | |||||
Base the controller RNN on the GRU from: | |||||
https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py | |||||
""" | |||||
def __init__(self, num_blocks=4, controller_hid=100, cuda=False): | |||||
torch.nn.Module.__init__(self) | |||||
# `num_tokens` holds the number of choices the decoder has at each step:
# an activation function at even steps, a previous-node index at odd steps.
self.shared_rnn_activations = ['tanh', 'ReLU', 'identity', 'sigmoid'] | |||||
self.num_tokens = [len(self.shared_rnn_activations)] | |||||
self.controller_hid = controller_hid | |||||
self.use_cuda = cuda | |||||
self.num_blocks = num_blocks | |||||
for idx in range(num_blocks): | |||||
self.num_tokens += [idx + 1, len(self.shared_rnn_activations)] | |||||
self.func_names = self.shared_rnn_activations | |||||
num_total_tokens = sum(self.num_tokens) | |||||
self.encoder = torch.nn.Embedding(num_total_tokens, | |||||
controller_hid) | |||||
self.lstm = torch.nn.LSTMCell(controller_hid, controller_hid) | |||||
# Perhaps these weights in the decoder should be | |||||
# shared? At least for the activation functions, which all have the | |||||
# same size. | |||||
self.decoders = [] | |||||
for idx, size in enumerate(self.num_tokens): | |||||
decoder = torch.nn.Linear(controller_hid, size) | |||||
self.decoders.append(decoder) | |||||
self._decoders = torch.nn.ModuleList(self.decoders) | |||||
self.reset_parameters() | |||||
self.static_init_hidden = utils.keydefaultdict(self.init_hidden) | |||||
def _get_default_hidden(key): | |||||
return utils.get_variable( | |||||
torch.zeros(key, self.controller_hid), | |||||
self.use_cuda, | |||||
requires_grad=False) | |||||
self.static_inputs = utils.keydefaultdict(_get_default_hidden) | |||||
def reset_parameters(self): | |||||
init_range = 0.1 | |||||
for param in self.parameters(): | |||||
param.data.uniform_(-init_range, init_range) | |||||
for decoder in self.decoders: | |||||
decoder.bias.data.fill_(0) | |||||
def forward(self, # pylint:disable=arguments-differ | |||||
inputs, | |||||
hidden, | |||||
block_idx, | |||||
is_embed): | |||||
if not is_embed: | |||||
embed = self.encoder(inputs) | |||||
else: | |||||
embed = inputs | |||||
hx, cx = self.lstm(embed, hidden) | |||||
logits = self.decoders[block_idx](hx) | |||||
logits /= 5.0 | |||||
# # exploration | |||||
# if self.args.mode == 'train': | |||||
# logits = (2.5 * F.tanh(logits)) | |||||
return logits, (hx, cx) | |||||
def sample(self, batch_size=1, with_details=False, save_dir=None): | |||||
"""Samples a set of `args.num_blocks` many computational nodes from the | |||||
controller, where each node is made up of an activation function, and | |||||
each node except the last also includes a previous node. | |||||
""" | |||||
if batch_size < 1: | |||||
raise Exception(f'Wrong batch_size: {batch_size} < 1') | |||||
# [B, L, H] | |||||
inputs = self.static_inputs[batch_size] | |||||
hidden = self.static_init_hidden[batch_size] | |||||
activations = [] | |||||
entropies = [] | |||||
log_probs = [] | |||||
prev_nodes = [] | |||||
# The RNN controller alternately outputs an activation, | |||||
# followed by a previous node, for each block except the last one, | |||||
# which only gets an activation function. The last node is the output | |||||
# node, and its previous node is the average of all leaf nodes. | |||||
for block_idx in range(2*(self.num_blocks - 1) + 1): | |||||
logits, hidden = self.forward(inputs, | |||||
hidden, | |||||
block_idx, | |||||
is_embed=(block_idx == 0)) | |||||
probs = F.softmax(logits, dim=-1) | |||||
log_prob = F.log_softmax(logits, dim=-1) | |||||
# .mean() for entropy? | |||||
entropy = -(log_prob * probs).sum(1, keepdim=False) | |||||
action = probs.multinomial(num_samples=1).data | |||||
selected_log_prob = log_prob.gather( | |||||
1, utils.get_variable(action, requires_grad=False)) | |||||
# why the [:, 0] here? Should it be .squeeze(), or | |||||
# .view()? Same below with `action`. | |||||
entropies.append(entropy) | |||||
log_probs.append(selected_log_prob[:, 0]) | |||||
# 0: function, 1: previous node | |||||
mode = block_idx % 2 | |||||
inputs = utils.get_variable( | |||||
action[:, 0] + sum(self.num_tokens[:mode]), | |||||
requires_grad=False) | |||||
if mode == 0: | |||||
activations.append(action[:, 0]) | |||||
elif mode == 1: | |||||
prev_nodes.append(action[:, 0]) | |||||
prev_nodes = torch.stack(prev_nodes).transpose(0, 1) | |||||
activations = torch.stack(activations).transpose(0, 1) | |||||
dags = _construct_dags(prev_nodes, | |||||
activations, | |||||
self.func_names, | |||||
self.num_blocks) | |||||
if save_dir is not None: | |||||
for idx, dag in enumerate(dags): | |||||
utils.draw_network(dag, | |||||
os.path.join(save_dir, f'graph{idx}.png')) | |||||
if with_details: | |||||
return dags, torch.cat(log_probs), torch.cat(entropies) | |||||
return dags | |||||
def init_hidden(self, batch_size): | |||||
zeros = torch.zeros(batch_size, self.controller_hid) | |||||
return (utils.get_variable(zeros, self.use_cuda, requires_grad=False), | |||||
utils.get_variable(zeros.clone(), self.use_cuda, requires_grad=False)) |
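# Usage sketch (an assumption, not part of the original file): sampling two
# architectures from a freshly initialised controller on CPU.
# controller = Controller(num_blocks=4, controller_hid=100, cuda=False)
# dags, log_probs, entropies = controller.sample(batch_size=2, with_details=True)
# len(dags)        -> 2, one DAG per sampled architecture
# log_probs.shape  -> torch.Size([14]), i.e. 2 samples * (2*(4-1)+1) decisions
# entropies.shape  -> torch.Size([14])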
@@ -0,0 +1,388 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
"""Module containing the shared RNN model.""" | |||||
import collections | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
from torch.autograd import Variable | |||||
import fastNLP.automl.enas_utils as utils | |||||
from fastNLP.models.base_model import BaseModel | |||||
def _get_dropped_weights(w_raw, dropout_p, is_training): | |||||
"""Drops out weights to implement DropConnect. | |||||
Args: | |||||
w_raw: Full, pre-dropout, weights to be dropped out. | |||||
dropout_p: Proportion of weights to drop out. | |||||
is_training: True iff _shared_ model is training. | |||||
Returns: | |||||
The dropped weights. | |||||
TODO: why does torch.nn.functional.dropout() return
1. a `torch.autograd.Variable` in the training loop, but
2. a `torch.nn.Parameter` in the controller/eval loop (training=False),
even though the call to `_setweights` in the Smerity repo's
`weight_drop.py` does not show this behaviour, and `F.dropout` always
returns a `torch.autograd.Variable` there, even when `training=False`?
This TODO is the reason for the hacky `torch.nn.Parameter` check below.
""" | |||||
dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training) | |||||
if isinstance(dropped_w, torch.nn.Parameter): | |||||
dropped_w = dropped_w.clone() | |||||
return dropped_w | |||||
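# Sketch of the effect (hypothetical shapes, not from the original repo):
# DropConnect zeroes individual weight entries and rescales the survivors by
# 1/(1 - p), in contrast to ordinary dropout applied to activations.
# w_raw = torch.nn.Parameter(torch.ones(3, 3))
# w = _get_dropped_weights(w_raw, dropout_p=0.5, is_training=True)
# roughly half of w's entries are 0.0 and the surviving ones equal 2.0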
class EmbeddingDropout(torch.nn.Embedding): | |||||
"""Class for dropping out embeddings by zero'ing out parameters in the | |||||
embedding matrix. | |||||
This is equivalent to dropping out particular words, e.g., in the sentence | |||||
'the quick brown fox jumps over the lazy dog', dropping out 'the' would | |||||
lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the | |||||
embedding vector space). | |||||
See 'A Theoretically Grounded Application of Dropout in Recurrent Neural | |||||
Networks', (Gal and Ghahramani, 2016). | |||||
""" | |||||
def __init__(self, | |||||
num_embeddings, | |||||
embedding_dim, | |||||
max_norm=None, | |||||
norm_type=2, | |||||
scale_grad_by_freq=False, | |||||
sparse=False, | |||||
dropout=0.1, | |||||
scale=None): | |||||
"""Embedding constructor. | |||||
Args: | |||||
dropout: Dropout probability. | |||||
scale: Used to scale parameters of embedding weight matrix that are | |||||
not dropped out. Note that this is _in addition_ to the | |||||
`1/(1 - dropout)` scaling. | |||||
See `torch.nn.Embedding` for remaining arguments. | |||||
""" | |||||
torch.nn.Embedding.__init__(self, | |||||
num_embeddings=num_embeddings, | |||||
embedding_dim=embedding_dim, | |||||
max_norm=max_norm, | |||||
norm_type=norm_type, | |||||
scale_grad_by_freq=scale_grad_by_freq, | |||||
sparse=sparse) | |||||
self.dropout = dropout | |||||
assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 ' | |||||
'and < 1.0') | |||||
self.scale = scale | |||||
def forward(self, inputs): # pylint:disable=arguments-differ | |||||
"""Embeds `inputs` with the dropped out embedding weight matrix.""" | |||||
if self.training: | |||||
dropout = self.dropout | |||||
else: | |||||
dropout = 0 | |||||
if dropout: | |||||
mask = self.weight.data.new(self.weight.size(0), 1) | |||||
mask.bernoulli_(1 - dropout) | |||||
mask = mask.expand_as(self.weight) | |||||
mask = mask / (1 - dropout) | |||||
masked_weight = self.weight * Variable(mask) | |||||
else: | |||||
masked_weight = self.weight | |||||
if self.scale and self.scale != 1: | |||||
masked_weight = masked_weight * self.scale | |||||
return F.embedding(inputs, | |||||
masked_weight, | |||||
max_norm=self.max_norm, | |||||
norm_type=self.norm_type, | |||||
scale_grad_by_freq=self.scale_grad_by_freq, | |||||
sparse=self.sparse) | |||||
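# Sketch (hypothetical sizes): rows of the embedding matrix are dropped as a
# whole, so every occurrence of a dropped word id maps to the zero vector for
# this forward pass, and surviving rows are rescaled by 1/(1 - dropout).
# emb = EmbeddingDropout(num_embeddings=10, embedding_dim=4, dropout=0.5)
# emb.train()
# out = emb(torch.LongTensor([[1, 2, 1]]))
# if word id 1 was dropped, out[0, 0] and out[0, 2] are both all-zero rows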
class LockedDropout(nn.Module): | |||||
# code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py | |||||
def __init__(self): | |||||
super().__init__() | |||||
def forward(self, x, dropout=0.5): | |||||
if not self.training or not dropout: | |||||
return x | |||||
m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) | |||||
mask = Variable(m, requires_grad=False) / (1 - dropout) | |||||
mask = mask.expand_as(x) | |||||
return mask * x | |||||
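# Sketch (hypothetical sizes): LockedDropout ("variational" dropout) samples a
# single mask of shape (1, batch, hidden) and reuses it at every time step, so
# a unit is either kept (and rescaled) or zeroed for the whole sequence.
# drop = LockedDropout()
# drop.train()
# x = torch.ones(35, 64, 1000)          # (time, batch, hidden)
# y = drop(x, dropout=0.5)
# (y[0] == y[1]).all() evaluates to True: the mask is shared across time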
class ENASModel(BaseModel): | |||||
"""Shared RNN model.""" | |||||
def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000): | |||||
super(ENASModel, self).__init__() | |||||
self.use_cuda = cuda | |||||
self.shared_hid = shared_hid | |||||
self.num_blocks = num_blocks | |||||
self.decoder = nn.Linear(self.shared_hid, num_classes) | |||||
self.encoder = EmbeddingDropout(embed_num, | |||||
shared_embed, | |||||
dropout=0.1) | |||||
self.lockdrop = LockedDropout() | |||||
self.dag = None | |||||
# Tie weights | |||||
# self.decoder.weight = self.encoder.weight | |||||
# Since W^{x, c} and W^{h, c} are always summed, there | |||||
# is no point duplicating their bias offset parameter. Likewise for | |||||
# W^{x, h} and W^{h, h}. | |||||
self.w_xc = nn.Linear(shared_embed, self.shared_hid) | |||||
self.w_xh = nn.Linear(shared_embed, self.shared_hid) | |||||
# The raw weights are stored here because the hidden-to-hidden weights | |||||
# are weight dropped on the forward pass. | |||||
self.w_hc_raw = torch.nn.Parameter( | |||||
torch.Tensor(self.shared_hid, self.shared_hid)) | |||||
self.w_hh_raw = torch.nn.Parameter( | |||||
torch.Tensor(self.shared_hid, self.shared_hid)) | |||||
self.w_hc = None | |||||
self.w_hh = None | |||||
self.w_h = collections.defaultdict(dict) | |||||
self.w_c = collections.defaultdict(dict) | |||||
for idx in range(self.num_blocks): | |||||
for jdx in range(idx + 1, self.num_blocks): | |||||
self.w_h[idx][jdx] = nn.Linear(self.shared_hid, | |||||
self.shared_hid, | |||||
bias=False) | |||||
self.w_c[idx][jdx] = nn.Linear(self.shared_hid, | |||||
self.shared_hid, | |||||
bias=False) | |||||
self._w_h = nn.ModuleList([self.w_h[idx][jdx] | |||||
for idx in self.w_h | |||||
for jdx in self.w_h[idx]]) | |||||
self._w_c = nn.ModuleList([self.w_c[idx][jdx] | |||||
for idx in self.w_c | |||||
for jdx in self.w_c[idx]]) | |||||
self.batch_norm = None | |||||
# if args.mode == 'train': | |||||
# self.batch_norm = nn.BatchNorm1d(self.shared_hid) | |||||
# else: | |||||
# self.batch_norm = None | |||||
self.reset_parameters() | |||||
self.static_init_hidden = utils.keydefaultdict(self.init_hidden) | |||||
def setDAG(self, dag): | |||||
if self.dag is None: | |||||
self.dag = dag | |||||
def forward(self, word_seq, hidden=None): | |||||
inputs = torch.transpose(word_seq, 0, 1) | |||||
time_steps = inputs.size(0) | |||||
batch_size = inputs.size(1) | |||||
self.w_hh = _get_dropped_weights(self.w_hh_raw, | |||||
0.5, | |||||
self.training) | |||||
self.w_hc = _get_dropped_weights(self.w_hc_raw, | |||||
0.5, | |||||
self.training) | |||||
# hidden = self.static_init_hidden[batch_size] if hidden is None else hidden | |||||
hidden = self.static_init_hidden[batch_size] | |||||
embed = self.encoder(inputs) | |||||
embed = self.lockdrop(embed, 0.65 if self.training else 0) | |||||
# The norms of the hidden states are clipped here because
# otherwise ENAS is especially prone to exploding activations on the | |||||
# forward pass. This could probably be fixed in a more elegant way, but | |||||
# it might be exposing a weakness in the ENAS algorithm as currently | |||||
# proposed. | |||||
# | |||||
# For more details, see | |||||
# https://github.com/carpedm20/ENAS-pytorch/issues/6 | |||||
clipped_num = 0 | |||||
max_clipped_norm = 0 | |||||
h1tohT = [] | |||||
logits = [] | |||||
for step in range(time_steps): | |||||
x_t = embed[step] | |||||
logit, hidden = self.cell(x_t, hidden, self.dag) | |||||
hidden_norms = hidden.norm(dim=-1) | |||||
max_norm = 25.0 | |||||
if hidden_norms.data.max() > max_norm: | |||||
# Just directly use the torch slice operations | |||||
# in PyTorch v0.4. | |||||
# | |||||
# This workaround for PyTorch v0.3.1 does everything in numpy, | |||||
# because the PyTorch slicing and slice assignment is too | |||||
# flaky. | |||||
hidden_norms = hidden_norms.data.cpu().numpy() | |||||
clipped_num += 1 | |||||
if hidden_norms.max() > max_clipped_norm: | |||||
max_clipped_norm = hidden_norms.max() | |||||
clip_select = hidden_norms > max_norm | |||||
clip_norms = hidden_norms[clip_select] | |||||
mask = np.ones(hidden.size()) | |||||
normalizer = max_norm/clip_norms | |||||
normalizer = normalizer[:, np.newaxis] | |||||
mask[clip_select] = normalizer | |||||
if self.use_cuda: | |||||
hidden *= torch.autograd.Variable( | |||||
torch.FloatTensor(mask).cuda(), requires_grad=False) | |||||
else: | |||||
hidden *= torch.autograd.Variable( | |||||
torch.FloatTensor(mask), requires_grad=False) | |||||
logits.append(logit) | |||||
h1tohT.append(hidden) | |||||
h1tohT = torch.stack(h1tohT) | |||||
output = torch.stack(logits) | |||||
raw_output = output | |||||
output = self.lockdrop(output, 0.4 if self.training else 0) | |||||
#Pooling | |||||
output = torch.mean(output, 0) | |||||
decoded = self.decoder(output) | |||||
extra_out = {'dropped': decoded, | |||||
'hiddens': h1tohT, | |||||
'raw': raw_output} | |||||
return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out} | |||||
def cell(self, x, h_prev, dag): | |||||
"""Computes a single pass through the discovered RNN cell.""" | |||||
c = {} | |||||
h = {} | |||||
f = {} | |||||
f[0] = self.get_f(dag[-1][0].name) | |||||
c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None)) | |||||
h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) + | |||||
(1 - c[0])*h_prev) | |||||
leaf_node_ids = [] | |||||
q = collections.deque() | |||||
q.append(0) | |||||
# Computes connections from the parent nodes `node_id` | |||||
# to their child nodes `next_id` recursively, skipping leaf nodes. A | |||||
# leaf node is a node whose id == `self.num_blocks`. | |||||
# | |||||
# Connections between parent i and child j should be computed as | |||||
# h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i, | |||||
# where c_j = \sigmoid{(W^c_{ij}*h_i)} | |||||
# | |||||
# See Training details from Section 3.1 of the paper. | |||||
# | |||||
# The following algorithm does a breadth-first (since `q.popleft()` is | |||||
# used) search over the nodes and computes all the hidden states. | |||||
while True: | |||||
if len(q) == 0: | |||||
break | |||||
node_id = q.popleft() | |||||
nodes = dag[node_id] | |||||
for next_node in nodes: | |||||
next_id = next_node.id | |||||
if next_id == self.num_blocks: | |||||
leaf_node_ids.append(node_id) | |||||
assert len(nodes) == 1, ('parent of leaf node should have ' | |||||
'only one child') | |||||
continue | |||||
w_h = self.w_h[node_id][next_id] | |||||
w_c = self.w_c[node_id][next_id] | |||||
f[next_id] = self.get_f(next_node.name) | |||||
c[next_id] = torch.sigmoid(w_c(h[node_id])) | |||||
h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) + | |||||
(1 - c[next_id])*h[node_id]) | |||||
q.append(next_id) | |||||
# Instead of averaging loose ends, perhaps there should | |||||
# be a set of separate unshared weights for each "loose" connection | |||||
# between each node in a cell and the output. | |||||
# | |||||
# As it stands, all weights W^h_{ij} are doing double duty by | |||||
# connecting both from i to j, as well as from i to the output. | |||||
# average all the loose ends | |||||
leaf_nodes = [h[node_id] for node_id in leaf_node_ids] | |||||
output = torch.mean(torch.stack(leaf_nodes, 2), -1) | |||||
# stabilizing the Updates of omega | |||||
if self.batch_norm is not None: | |||||
output = self.batch_norm(output) | |||||
return output, h[self.num_blocks - 1] | |||||
def init_hidden(self, batch_size): | |||||
zeros = torch.zeros(batch_size, self.shared_hid) | |||||
return utils.get_variable(zeros, self.use_cuda, requires_grad=False) | |||||
def get_f(self, name): | |||||
name = name.lower() | |||||
if name == 'relu': | |||||
f = torch.relu | |||||
elif name == 'tanh': | |||||
f = torch.tanh | |||||
elif name == 'identity': | |||||
f = lambda x: x | |||||
elif name == 'sigmoid': | |||||
f = torch.sigmoid | |||||
return f | |||||
@property | |||||
def num_parameters(self): | |||||
def size(p): | |||||
return np.prod(p.size()) | |||||
return sum([size(param) for param in self.parameters()]) | |||||
def reset_parameters(self): | |||||
init_range = 0.025 | |||||
# init_range = 0.025 if self.args.mode == 'train' else 0.04 | |||||
for param in self.parameters(): | |||||
param.data.uniform_(-init_range, init_range) | |||||
self.decoder.bias.data.fill_(0) | |||||
def predict(self, word_seq): | |||||
""" | |||||
:param word_seq: torch.LongTensor, [batch_size, seq_len] | |||||
:return predict: dict of torch.LongTensor, [batch_size, seq_len] | |||||
""" | |||||
output = self(word_seq) | |||||
_, predict = output['pred'].max(dim=1) | |||||
return {'pred': predict} |
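# Usage sketch (hypothetical sizes; assumes the Controller class from the
# companion enas_controller module): wire a sampled DAG into the shared model.
# shared = ENASModel(embed_num=5000, num_classes=2, num_blocks=4)
# controller = Controller(num_blocks=4)
# shared.setDAG(controller.sample(1)[0])
# word_seq = torch.LongTensor(32, 20).random_(0, 5000)   # [batch, seq_len]
# shared(word_seq)['pred'].shape  -> torch.Size([32, 2])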
@@ -0,0 +1,382 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
import math | |||||
import time | |||||
from datetime import datetime | |||||
from datetime import timedelta | |||||
import numpy as np | |||||
import torch | |||||
try: | |||||
from tqdm.autonotebook import tqdm | |||||
except: | |||||
from fastNLP.core.utils import pseudo_tqdm as tqdm | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.callback import CallbackException | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.utils import _move_dict_value_to_device | |||||
import fastNLP | |||||
import fastNLP.automl.enas_utils as utils | |||||
from fastNLP.core.utils import _build_args | |||||
from torch.optim import Adam | |||||
def _get_no_grad_ctx_mgr(): | |||||
"""Returns a the `torch.no_grad` context manager for PyTorch version >= | |||||
0.4, or a no-op context manager otherwise. | |||||
""" | |||||
return torch.no_grad() | |||||
class ENASTrainer(fastNLP.Trainer): | |||||
"""A class to wrap training code.""" | |||||
def __init__(self, train_data, model, controller, **kwargs): | |||||
"""Constructor for training algorithm. | |||||
:param DataSet train_data: the training data | |||||
:param torch.nn.modules.module model: a PyTorch model | |||||
:param torch.nn.modules.module controller: a PyTorch model | |||||
""" | |||||
self.final_epochs = kwargs['final_epochs'] | |||||
kwargs.pop('final_epochs') | |||||
super(ENASTrainer, self).__init__(train_data, model, **kwargs) | |||||
self.controller_step = 0 | |||||
self.shared_step = 0 | |||||
self.max_length = 35 | |||||
self.shared = model | |||||
self.controller = controller | |||||
self.shared_optim = Adam( | |||||
self.shared.parameters(), | |||||
lr=20.0, | |||||
weight_decay=1e-7) | |||||
self.controller_optim = Adam( | |||||
self.controller.parameters(), | |||||
lr=3.5e-4) | |||||
def train(self, load_best_model=True): | |||||
""" | |||||
:param bool load_best_model: only takes effect when dev_data was provided at initialisation; if True, the trainer
reloads the parameters of the model that performed best on dev before returning.
:return results: a dict with the following contents::
seconds: float, total training time
the following three entries are only present when dev_data was provided.
best_eval: Dict of Dict, the evaluation results
best_epoch: int, the epoch in which the best value was obtained
best_step: int, the step (batch update) at which the best value was obtained
""" | |||||
results = {} | |||||
if self.n_epochs <= 0: | |||||
print(f"training epoch is {self.n_epochs}, nothing was done.") | |||||
results['seconds'] = 0. | |||||
return results | |||||
try: | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
self.model = self.model.cuda() | |||||
self._model_device = self.model.parameters().__next__().device | |||||
self._mode(self.model, is_test=False) | |||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||||
start_time = time.time() | |||||
print("training epochs started " + self.start_time, flush=True) | |||||
try: | |||||
self.callback_manager.on_train_begin() | |||||
self._train() | |||||
self.callback_manager.on_train_end(self.model) | |||||
except (CallbackException, KeyboardInterrupt) as e: | |||||
self.callback_manager.on_exception(e, self.model) | |||||
if self.dev_data is not None: | |||||
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||||
self.tester._format_eval_results(self.best_dev_perf),) | |||||
results['best_eval'] = self.best_dev_perf | |||||
results['best_epoch'] = self.best_dev_epoch | |||||
results['best_step'] = self.best_dev_step | |||||
if load_best_model: | |||||
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) | |||||
load_succeed = self._load_model(self.model, model_name) | |||||
if load_succeed: | |||||
print("Reloaded the best model.") | |||||
else: | |||||
print("Fail to reload best model.") | |||||
finally: | |||||
pass | |||||
results['seconds'] = round(time.time() - start_time, 2) | |||||
return results | |||||
def _train(self): | |||||
if not self.use_tqdm: | |||||
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm | |||||
else: | |||||
inner_tqdm = tqdm | |||||
self.step = 0 | |||||
start = time.time() | |||||
total_steps = (len(self.train_data) // self.batch_size + int( | |||||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||||
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||||
avg_loss = 0 | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for epoch in range(1, self.n_epochs+1): | |||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||||
last_stage = (epoch > self.n_epochs + 1 - self.final_epochs) | |||||
if epoch == self.n_epochs + 1 - self.final_epochs: | |||||
print('Entering the final stage. (Only train the selected structure)') | |||||
# early stopping | |||||
self.callback_manager.on_epoch_begin(epoch, self.n_epochs) | |||||
# 1. Training the shared parameters omega of the child models | |||||
self.train_shared(pbar) | |||||
# 2. Training the controller parameters theta | |||||
if not last_stage: | |||||
self.train_controller() | |||||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||||
and self.dev_data is not None: | |||||
if not last_stage: | |||||
self.derive() | |||||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||||
total_steps) + \ | |||||
self.tester._format_eval_results(eval_res) | |||||
pbar.write(eval_str) | |||||
# lr decay; early stopping | |||||
self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer) | |||||
# =============== epochs end =================== # | |||||
pbar.close() | |||||
# ============ tqdm end ============== # | |||||
def get_loss(self, inputs, targets, hidden, dags): | |||||
"""Computes the loss for the same batch for M models. | |||||
This amounts to an estimate of the loss, which is turned into an | |||||
estimate for the gradients of the shared model. | |||||
""" | |||||
if not isinstance(dags, list): | |||||
dags = [dags] | |||||
loss = 0 | |||||
for dag in dags: | |||||
self.shared.setDAG(dag) | |||||
inputs = _build_args(self.shared.forward, **inputs) | |||||
inputs['hidden'] = hidden | |||||
result = self.shared(**inputs) | |||||
output, hidden, extra_out = result['pred'], result['hidden'], result['extra_out'] | |||||
self.callback_manager.on_loss_begin(targets, result) | |||||
sample_loss = self._compute_loss(result, targets) | |||||
loss += sample_loss | |||||
assert len(dags) == 1, 'there are multiple `hidden` for multiple `dags`'
return loss, hidden, extra_out | |||||
def train_shared(self, pbar=None, max_step=None, dag=None): | |||||
"""Train the language model for 400 steps of minibatches of 64 | |||||
examples. | |||||
Args: | |||||
max_step: Used to run extra training steps as a warm-up. | |||||
dag: If not None, is used instead of calling sample(). | |||||
BPTT is truncated at 35 timesteps. | |||||
For each weight update, gradients are estimated by sampling M models | |||||
from the fixed controller policy, and averaging their gradients | |||||
computed on a batch of training data. | |||||
""" | |||||
model = self.shared | |||||
model.train() | |||||
self.controller.eval() | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
abs_max_grad = 0 | |||||
abs_max_hidden_norm = 0 | |||||
step = 0 | |||||
raw_total_loss = 0 | |||||
total_loss = 0 | |||||
train_idx = 0 | |||||
avg_loss = 0
start = time.time()  # start time used by the non-tqdm progress message below
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for batch_x, batch_y in data_iterator: | |||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
indices = data_iterator.get_batch_indices() | |||||
# negative sampling; replace unknown; re-weight batch_y | |||||
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||||
# prediction = self._data_forward(self.model, batch_x) | |||||
dags = self.controller.sample(1) | |||||
inputs, targets = batch_x, batch_y | |||||
# self.callback_manager.on_loss_begin(batch_y, prediction) | |||||
loss, hidden, extra_out = self.get_loss(inputs, | |||||
targets, | |||||
hidden, | |||||
dags) | |||||
hidden.detach_() | |||||
avg_loss += loss.item() | |||||
# Is loss NaN or inf? requires_grad = False | |||||
self.callback_manager.on_backward_begin(loss, self.model) | |||||
self._grad_backward(loss) | |||||
self.callback_manager.on_backward_end(self.model) | |||||
self._update() | |||||
self.callback_manager.on_step_end(self.optimizer) | |||||
if (self.step+1) % self.print_every == 0: | |||||
if self.use_tqdm: | |||||
print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) | |||||
pbar.update(self.print_every) | |||||
else: | |||||
end = time.time() | |||||
diff = timedelta(seconds=round(end - start)) | |||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||||
epoch, self.step, avg_loss, diff) | |||||
pbar.set_postfix_str(print_output) | |||||
avg_loss = 0 | |||||
self.step += 1 | |||||
step += 1 | |||||
self.shared_step += 1 | |||||
self.callback_manager.on_batch_end() | |||||
# ================= mini-batch end ==================== # | |||||
def get_reward(self, dag, entropies, hidden, valid_idx=0): | |||||
"""Computes the perplexity of a single sampled model on a minibatch of | |||||
validation data. | |||||
""" | |||||
if not isinstance(entropies, np.ndarray): | |||||
entropies = entropies.data.cpu().numpy() | |||||
data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for inputs, targets in data_iterator: | |||||
valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag) | |||||
valid_loss = utils.to_item(valid_loss.data) | |||||
valid_ppl = math.exp(valid_loss) | |||||
R = 80 / valid_ppl | |||||
rewards = R + 1e-4 * entropies | |||||
return rewards, hidden | |||||
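# Sketch of the reward above (hypothetical numbers): a validation perplexity of
# 40 gives a base reward R = 80 / 40 = 2.0, and every sampled decision then
# receives a small exploration bonus of 1e-4 times its entropy, so `rewards`
# has the same length as `entropies` (one value per controller decision).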
def train_controller(self): | |||||
"""Fixes the shared parameters and updates the controller parameters. | |||||
The controller is updated with a score function gradient estimator | |||||
(i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl | |||||
is computed on a minibatch of validation data. | |||||
A moving average baseline is used. | |||||
The controller is trained for a fixed number of steps per epoch (20 here;
2000 in the original ENAS setup), i.e. first (Train Shared) phase -> second (Train Controller) phase.
""" | |||||
model = self.controller | |||||
model.train() | |||||
# Why can't we call shared.eval() here? Leads to loss | |||||
# being uniformly zero for the controller. | |||||
# self.shared.eval() | |||||
avg_reward_base = None | |||||
baseline = None | |||||
adv_history = [] | |||||
entropy_history = [] | |||||
reward_history = [] | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
total_loss = 0 | |||||
valid_idx = 0 | |||||
for step in range(20): | |||||
# sample models | |||||
dags, log_probs, entropies = self.controller.sample( | |||||
with_details=True) | |||||
# calculate reward | |||||
np_entropies = entropies.data.cpu().numpy() | |||||
# No gradients should be backpropagated to the | |||||
# shared model during controller training, obviously. | |||||
with _get_no_grad_ctx_mgr(): | |||||
rewards, hidden = self.get_reward(dags, | |||||
np_entropies, | |||||
hidden, | |||||
valid_idx) | |||||
reward_history.extend(rewards) | |||||
entropy_history.extend(np_entropies) | |||||
# moving average baseline | |||||
if baseline is None: | |||||
baseline = rewards | |||||
else: | |||||
decay = 0.95 | |||||
baseline = decay * baseline + (1 - decay) * rewards | |||||
adv = rewards - baseline | |||||
adv_history.extend(adv) | |||||
# policy loss | |||||
loss = -log_probs*utils.get_variable(adv, | |||||
self.use_cuda, | |||||
requires_grad=False) | |||||
loss = loss.sum() # or loss.mean() | |||||
# update | |||||
self.controller_optim.zero_grad() | |||||
loss.backward() | |||||
self.controller_optim.step() | |||||
total_loss += utils.to_item(loss.data) | |||||
if ((step % 50) == 0) and (step > 0): | |||||
reward_history, adv_history, entropy_history = [], [], [] | |||||
total_loss = 0 | |||||
self.controller_step += 1 | |||||
# prev_valid_idx = valid_idx | |||||
# valid_idx = ((valid_idx + self.max_length) % | |||||
# (self.valid_data.size(0) - 1)) | |||||
# # Whenever we wrap around to the beginning of the | |||||
# # validation data, we reset the hidden states. | |||||
# if prev_valid_idx > valid_idx: | |||||
# hidden = self.shared.init_hidden(self.batch_size) | |||||
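# Worked sketch of the REINFORCE update above (hypothetical numbers, ignoring
# the small 1e-4 entropy bonus): a sampled architecture with valid_ppl = 100
# gets reward 80 / 100 = 0.8. If the moving-average baseline is currently 0.6,
# the advantage is adv = 0.8 - 0.6 = 0.2 and the policy loss is
# -sum(log_probs) * 0.2, so the update raises the log-probability of that
# architecture. The baseline then moves to 0.95 * 0.6 + 0.05 * 0.8 = 0.61.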
def derive(self, sample_num=10, valid_idx=0): | |||||
"""We are always deriving based on the very first batch | |||||
of validation data? This seems wrong... | |||||
""" | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
dags, _, entropies = self.controller.sample(sample_num, | |||||
with_details=True) | |||||
max_R = 0 | |||||
best_dag = None | |||||
for dag in dags: | |||||
R, _ = self.get_reward(dag, entropies, hidden, valid_idx) | |||||
if R.max() > max_R: | |||||
max_R = R.max() | |||||
best_dag = dag | |||||
self.model.setDAG(best_dag) |
@@ -0,0 +1,53 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
from __future__ import print_function | |||||
import collections | |||||
from collections import defaultdict | |||||
import numpy as np | |||||
import torch | |||||
from torch.autograd import Variable | |||||
def detach(h): | |||||
if type(h) == Variable: | |||||
return Variable(h.data) | |||||
else: | |||||
return tuple(detach(v) for v in h) | |||||
def get_variable(inputs, cuda=False, **kwargs): | |||||
if type(inputs) in [list, np.ndarray]: | |||||
inputs = torch.Tensor(inputs) | |||||
if cuda: | |||||
out = Variable(inputs.cuda(), **kwargs) | |||||
else: | |||||
out = Variable(inputs, **kwargs) | |||||
return out | |||||
def update_lr(optimizer, lr): | |||||
for param_group in optimizer.param_groups: | |||||
param_group['lr'] = lr | |||||
Node = collections.namedtuple('Node', ['id', 'name']) | |||||
class keydefaultdict(defaultdict): | |||||
def __missing__(self, key): | |||||
if self.default_factory is None: | |||||
raise KeyError(key) | |||||
else: | |||||
ret = self[key] = self.default_factory(key) | |||||
return ret | |||||
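# Sketch: unlike collections.defaultdict, the missing key itself is passed to
# the factory; Controller and ENASModel use this to cache per-batch-size
# hidden states.
# >>> d = keydefaultdict(lambda k: [0] * k)
# >>> d[3]
# [0, 0, 0]
# >>> 3 in d
# True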
def to_item(x): | |||||
"""Converts x, possibly scalar and possibly tensor, to a Python scalar.""" | |||||
if isinstance(x, (float, int)): | |||||
return x | |||||
if float(torch.__version__[0:3]) < 0.4: | |||||
assert (x.dim() == 1) and (len(x) == 1) | |||||
return x[0] | |||||
return x.item() |
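# Sketch (assuming PyTorch >= 0.4 for the second call): to_item() papers over
# the 0.3 / 0.4 scalar API change.
# >>> to_item(3.5)
# 3.5
# >>> to_item(torch.ones(1).sum())    # a 0-dim tensor on PyTorch >= 0.4
# 1.0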
@@ -138,7 +138,6 @@ class CallbackManager(Callback): | |||||
""" | """ | ||||
super(CallbackManager, self).__init__() | super(CallbackManager, self).__init__() | ||||
# set attribute of trainer environment | # set attribute of trainer environment | ||||
self.env = env | |||||
self.callbacks = [] | self.callbacks = [] | ||||
if callbacks is not None: | if callbacks is not None: | ||||
@@ -157,7 +157,7 @@ class DataSet(object): | |||||
assert name in self.field_arrays | assert name in self.field_arrays | ||||
self.field_arrays[name].append(field) | self.field_arrays[name].append(field) | ||||
def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False): | |||||
def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False, ignore_type=False): | |||||
"""Add a new field to the DataSet. | """Add a new field to the DataSet. | ||||
:param str name: the name of the field. | :param str name: the name of the field. | ||||
@@ -165,13 +165,14 @@ class DataSet(object): | |||||
:param int padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可 | :param int padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可 | ||||
:param bool is_input: whether this field is model input. | :param bool is_input: whether this field is model input. | ||||
:param bool is_target: whether this field is label or target. | :param bool is_target: whether this field is label or target. | ||||
:param bool ignore_type: If True, do not perform type check. (Default: False) | |||||
""" | """ | ||||
if len(self.field_arrays) != 0: | if len(self.field_arrays) != 0: | ||||
if len(self) != len(fields): | if len(self) != len(fields): | ||||
raise RuntimeError(f"The field to append must have the same size as dataset. " | raise RuntimeError(f"The field to append must have the same size as dataset. " | ||||
f"Dataset size {len(self)} != field size {len(fields)}") | f"Dataset size {len(self)} != field size {len(fields)}") | ||||
self.field_arrays[name] = FieldArray(name, fields, is_target=is_target, is_input=is_input, | self.field_arrays[name] = FieldArray(name, fields, is_target=is_target, is_input=is_input, | ||||
padder=padder) | |||||
padder=padder, ignore_type=ignore_type) | |||||
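# Usage sketch of the new argument (hypothetical field; with ignore_type=True
# the dtype check is skipped, so arbitrary Python objects can be stored):
# ds = DataSet({"x": [[1, 2], [3, 4]]})
# ds.add_field("meta", [{"id": 1}, {"id": 2}], ignore_type=True)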
def delete_field(self, name): | def delete_field(self, name): | ||||
"""Delete a field based on the field name. | """Delete a field based on the field name. | ||||
@@ -242,6 +243,8 @@ class DataSet(object): | |||||
:param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | :param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | ||||
:return: | :return: | ||||
""" | """ | ||||
if field_name not in self.field_arrays: | |||||
raise KeyError("There is no field named {}.".format(field_name)) | |||||
self.field_arrays[field_name].set_padder(padder) | self.field_arrays[field_name].set_padder(padder) | ||||
def set_pad_val(self, field_name, pad_val): | def set_pad_val(self, field_name, pad_val): | ||||
@@ -252,6 +255,8 @@ class DataSet(object): | |||||
:param pad_val: int,该field的padder会以pad_val作为padding index | :param pad_val: int,该field的padder会以pad_val作为padding index | ||||
:return: | :return: | ||||
""" | """ | ||||
if field_name not in self.field_arrays: | |||||
raise KeyError("There is no field named {}.".format(field_name)) | |||||
self.field_arrays[field_name].set_pad_val(pad_val) | self.field_arrays[field_name].set_pad_val(pad_val) | ||||
def get_input_name(self): | def get_input_name(self): | ||||
@@ -287,6 +292,8 @@ class DataSet(object): | |||||
extra_param['is_input'] = kwargs['is_input'] | extra_param['is_input'] = kwargs['is_input'] | ||||
if 'is_target' in kwargs: | if 'is_target' in kwargs: | ||||
extra_param['is_target'] = kwargs['is_target'] | extra_param['is_target'] = kwargs['is_target'] | ||||
if 'ignore_type' in kwargs: | |||||
extra_param['ignore_type'] = kwargs['ignore_type'] | |||||
if new_field_name is not None: | if new_field_name is not None: | ||||
if new_field_name in self.field_arrays: | if new_field_name in self.field_arrays: | ||||
# overwrite the field, keep same attributes | # overwrite the field, keep same attributes | ||||
@@ -295,11 +302,14 @@ class DataSet(object): | |||||
extra_param['is_input'] = old_field.is_input | extra_param['is_input'] = old_field.is_input | ||||
if 'is_target' not in extra_param: | if 'is_target' not in extra_param: | ||||
extra_param['is_target'] = old_field.is_target | extra_param['is_target'] = old_field.is_target | ||||
if 'ignore_type' not in extra_param: | |||||
extra_param['ignore_type'] = old_field.ignore_type | |||||
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | ||||
is_target=extra_param["is_target"]) | |||||
is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) | |||||
else: | else: | ||||
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | ||||
is_target=extra_param.get("is_target", None)) | |||||
is_target=extra_param.get("is_target", None), | |||||
ignore_type=extra_param.get("ignore_type", False)) | |||||
else: | else: | ||||
return results | return results | ||||
@@ -1,5 +1,5 @@ | |||||
import numpy as np | import numpy as np | ||||
from copy import deepcopy | |||||
class PadderBase: | class PadderBase: | ||||
""" | """ | ||||
@@ -83,6 +83,8 @@ class AutoPadder(PadderBase): | |||||
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) | array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) | ||||
for i, content in enumerate(contents): | for i, content in enumerate(contents): | ||||
array[i][:len(content)] = content | array[i][:len(content)] = content | ||||
elif field_ele_dtype is None: | |||||
array = contents  # when ignore_type=True, return contents unchanged
else: # should only be str | else: # should only be str | ||||
array = np.array([content for content in contents]) | array = np.array([content for content in contents]) | ||||
return array | return array | ||||
@@ -96,10 +98,16 @@ class FieldArray(object): | |||||
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. | :param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. | ||||
:param bool is_target: If True, this FieldArray is used to compute loss. | :param bool is_target: If True, this FieldArray is used to compute loss. | ||||
:param bool is_input: If True, this FieldArray is used to the model input. | :param bool is_input: If True, this FieldArray is used to the model input. | ||||
:param padder: PadderBase. In most cases the default is fine, unless padding over multiple dimensions is needed (e.g. padding the characters of English words)
:param PadderBase padder: the padder object assigned to a FieldArray is deepcopied; to change padder parameters you must go through
fieldarray.set_pad_val().
Defaults to None: (1) if the field holds scalars, no padding is applied; (2) if it is a 1-d list and the FieldArray dtype is float or int,
padding is applied; (3) otherwise no padding is applied.
If you need to pad, e.g., the characters inside English words, a different padder is required,
likewise when ignore_type is True but padding is still needed.
:param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray. (default: False) | |||||
""" | """ | ||||
def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0)): | |||||
def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): | |||||
"""DataSet在初始化时会有两类方法对FieldArray操作: | """DataSet在初始化时会有两类方法对FieldArray操作: | ||||
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: | 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: | ||||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | ||||
@@ -114,6 +122,7 @@ class FieldArray(object): | |||||
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | ||||
类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 | 类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 | ||||
ignore_type controls whether type checking is performed; if True, no check is done.
""" | """ | ||||
self.name = name | self.name = name | ||||
@@ -135,7 +144,13 @@ class FieldArray(object): | |||||
self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 | self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 | ||||
self.content_dim = None # 表示content是多少维的list | self.content_dim = None # 表示content是多少维的list | ||||
if padder is None: | |||||
padder = AutoPadder(pad_val=0) | |||||
else: | |||||
assert isinstance(padder, PadderBase), "padder must be of type PadderBase." | |||||
padder = deepcopy(padder) | |||||
self.set_padder(padder) | self.set_padder(padder) | ||||
self.ignore_type = ignore_type | |||||
self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array | self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array | ||||
@@ -149,8 +164,9 @@ class FieldArray(object): | |||||
self.is_target = is_target | self.is_target = is_target | ||||
def _set_dtype(self): | def _set_dtype(self): | ||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
if self.ignore_type is False: | |||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
@property | @property | ||||
def is_input(self): | def is_input(self): | ||||
@@ -190,7 +206,7 @@ class FieldArray(object): | |||||
if list in type_set: | if list in type_set: | ||||
if len(type_set) > 1: | if len(type_set) > 1: | ||||
# list 跟 非list 混在一起 | # list 跟 非list 混在一起 | ||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||||
# >1维list | # >1维list | ||||
inner_type_set = set() | inner_type_set = set() | ||||
for l in content: | for l in content: | ||||
@@ -213,7 +229,7 @@ class FieldArray(object): | |||||
return self._basic_type_detection(inner_inner_type_set) | return self._basic_type_detection(inner_inner_type_set) | ||||
else: | else: | ||||
# list 跟 非list 混在一起 | # list 跟 非list 混在一起 | ||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set)) | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) | |||||
else: | else: | ||||
# 一维list | # 一维list | ||||
for content_type in type_set: | for content_type in type_set: | ||||
@@ -237,17 +253,17 @@ class FieldArray(object): | |||||
return float | return float | ||||
else: | else: | ||||
# str 跟 int 或者 float 混在一起 | # str 跟 int 或者 float 混在一起 | ||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||||
else: | else: | ||||
# str, int, float混在一起 | # str, int, float混在一起 | ||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||||
def _1d_list_check(self, val): | def _1d_list_check(self, val): | ||||
"""如果不是1D list就报错 | """如果不是1D list就报错 | ||||
""" | """ | ||||
type_set = set((type(obj) for obj in val)) | type_set = set((type(obj) for obj in val)) | ||||
if any(obj not in self.BASIC_TYPES for obj in type_set): | if any(obj not in self.BASIC_TYPES for obj in type_set): | ||||
raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||||
self._basic_type_detection(type_set) | self._basic_type_detection(type_set) | ||||
# otherwise: _basic_type_detection will raise error | # otherwise: _basic_type_detection will raise error | ||||
return True | return True | ||||
@@ -278,39 +294,40 @@ class FieldArray(object): | |||||
:param val: int, float, str, or a list of one. | :param val: int, float, str, or a list of one. | ||||
""" | """ | ||||
if isinstance(val, list): | |||||
pass | |||||
elif isinstance(val, tuple): # 确保最外层是list | |||||
val = list(val) | |||||
elif isinstance(val, np.ndarray): | |||||
val = val.tolist() | |||||
elif any((isinstance(val, t) for t in self.BASIC_TYPES)): | |||||
pass | |||||
else: | |||||
raise RuntimeError( | |||||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||||
if self.is_input is True or self.is_target is True: | |||||
if type(val) == list: | |||||
if len(val) == 0: | |||||
raise ValueError("Cannot append an empty list.") | |||||
if self.content_dim == 2 and self._1d_list_check(val): | |||||
# 1维list检查 | |||||
pass | |||||
elif self.content_dim == 3 and self._2d_list_check(val): | |||||
# 2维list检查 | |||||
pass | |||||
else: | |||||
raise RuntimeError( | |||||
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) | |||||
elif type(val) in self.BASIC_TYPES and self.content_dim == 1: | |||||
# scalar检查 | |||||
if type(val) == float and self.pytype == int: | |||||
self.pytype = float | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
if self.ignore_type is False: | |||||
if isinstance(val, list): | |||||
pass | |||||
elif isinstance(val, tuple): # 确保最外层是list | |||||
val = list(val) | |||||
elif isinstance(val, np.ndarray): | |||||
val = val.tolist() | |||||
elif any((isinstance(val, t) for t in self.BASIC_TYPES)): | |||||
pass | |||||
else: | else: | ||||
raise RuntimeError( | raise RuntimeError( | ||||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | ||||
if self.is_input is True or self.is_target is True: | |||||
if type(val) == list: | |||||
if len(val) == 0: | |||||
raise ValueError("Cannot append an empty list.") | |||||
if self.content_dim == 2 and self._1d_list_check(val): | |||||
# 1维list检查 | |||||
pass | |||||
elif self.content_dim == 3 and self._2d_list_check(val): | |||||
# 2维list检查 | |||||
pass | |||||
else: | |||||
raise RuntimeError( | |||||
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) | |||||
elif type(val) in self.BASIC_TYPES and self.content_dim == 1: | |||||
# scalar检查 | |||||
if type(val) == float and self.pytype == int: | |||||
self.pytype = float | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
else: | |||||
raise RuntimeError( | |||||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||||
self.content.append(val) | self.content.append(val) | ||||
def __getitem__(self, indices): | def __getitem__(self, indices): | ||||
@@ -347,7 +364,7 @@ class FieldArray(object): | |||||
""" | """ | ||||
if padder is not None: | if padder is not None: | ||||
assert isinstance(padder, PadderBase), "padder must be of type PadderBase." | assert isinstance(padder, PadderBase), "padder must be of type PadderBase." | ||||
self.padder = padder | |||||
self.padder = deepcopy(padder) | |||||
def set_pad_val(self, pad_val): | def set_pad_val(self, pad_val): | ||||
""" | """ | ||||
@@ -157,7 +157,7 @@ class MetricBase(object): | |||||
fast_param = {} | fast_param = {} | ||||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | ||||
fast_param['pred'] = list(pred_dict.values())[0] | fast_param['pred'] = list(pred_dict.values())[0] | ||||
fast_param['target'] = list(pred_dict.values())[0] | |||||
fast_param['target'] = list(target_dict.values())[0] | |||||
return fast_param | return fast_param | ||||
return fast_param | return fast_param | ||||
@@ -822,3 +822,154 @@ def pred_topk(y_prob, k=1): | |||||
(1, k)) | (1, k)) | ||||
y_prob_topk = y_prob[x_axis_index, y_pred_topk] | y_prob_topk = y_prob[x_axis_index, y_pred_topk] | ||||
return y_pred_topk, y_prob_topk | return y_pred_topk, y_prob_topk | ||||
class SQuADMetric(MetricBase): | |||||
def __init__(self, pred_start=None, pred_end=None, target_start=None, target_end=None, | |||||
beta=1, right_open=False, print_predict_stat=False): | |||||
""" | |||||
:param pred_start: [batch], predicted start index of the answer; 0 if the SQuAD 2.0 answer is empty
:param pred_end: [batch], predicted end index of the answer; 0 (closed interval) or 1 (half-open interval) if the SQuAD 2.0 answer is empty
:param target_start: [batch], gold start index of the answer; 0 if the SQuAD 2.0 answer is empty
:param target_end: [batch], gold end index of the answer; 0 (closed interval) or 1 (half-open interval) if the SQuAD 2.0 answer is empty
:param beta: float. The f_beta score, f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). Common choices are beta=0.5, 1, 2:
0.5 weights precision above recall, 1 weights them equally, 2 weights recall above precision.
:param right_open: boolean. If True, the start and end pointers delimit a half-open interval [start, end); if False, a closed interval [start, end].
:param print_predict_stat: boolean. If True, print statistics on whether predicted and gold answers are empty; if False, do not.
""" | |||||
super(SQuADMetric, self).__init__() | |||||
self._init_param_map(pred_start=pred_start, pred_end=pred_end, target_start=target_start, target_end=target_end) | |||||
self.print_predict_stat = print_predict_stat | |||||
self.no_ans_correct = 0 | |||||
self.no_ans_wrong = 0 | |||||
self.has_ans_correct = 0 | |||||
self.has_ans_wrong = 0 | |||||
self.has_ans_f = 0. | |||||
self.no2no = 0 | |||||
self.no2yes = 0 | |||||
self.yes2no = 0 | |||||
self.yes2yes = 0 | |||||
self.f_beta = beta | |||||
self.right_open = right_open | |||||
def evaluate(self, pred_start, pred_end, target_start, target_end): | |||||
""" | |||||
:param pred_start: [batch, seq_len] | |||||
:param pred_end: [batch, seq_len] | |||||
:param target_start: [batch] | |||||
:param target_end: [batch] | |||||
:return: | |||||
""" | |||||
start_inference = pred_start.max(dim=-1)[1].cpu().tolist() | |||||
end_inference = pred_end.max(dim=-1)[1].cpu().tolist() | |||||
start, end = [], [] | |||||
max_len = pred_start.size(1) | |||||
t_start = target_start.cpu().tolist() | |||||
t_end = target_end.cpu().tolist() | |||||
for s, e in zip(start_inference, end_inference): | |||||
start.append(min(s, e)) | |||||
end.append(max(s, e)) | |||||
for s, e, ts, te in zip(start, end, t_start, t_end): | |||||
if not self.right_open: | |||||
e += 1 | |||||
te += 1 | |||||
if ts == 0 and te == int(not self.right_open): | |||||
if s == 0 and e == int(not self.right_open): | |||||
self.no_ans_correct += 1 | |||||
self.no2no += 1 | |||||
else: | |||||
self.no_ans_wrong += 1 | |||||
self.no2yes += 1 | |||||
else: | |||||
if s == 0 and e == int(not self.right_open): | |||||
self.yes2no += 1 | |||||
else: | |||||
self.yes2yes += 1 | |||||
if s == ts and e == te: | |||||
self.has_ans_correct += 1 | |||||
else: | |||||
self.has_ans_wrong += 1 | |||||
a = [0] * s + [1] * (e - s) + [0] * (max_len - e) | |||||
b = [0] * ts + [1] * (te - ts) + [0] * (max_len - te) | |||||
a, b = torch.tensor(a), torch.tensor(b) | |||||
TP = int(torch.sum(a * b)) | |||||
pre = TP / int(torch.sum(a)) if int(torch.sum(a)) > 0 else 0 | |||||
rec = TP / int(torch.sum(b)) if int(torch.sum(b)) > 0 else 0 | |||||
if pre + rec > 0: | |||||
f = (1 + (self.f_beta**2)) * pre * rec / ((self.f_beta**2) * pre + rec) | |||||
else: | |||||
f = 0 | |||||
self.has_ans_f += f | |||||
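# Worked sketch of the span overlap above (hypothetical spans, right_open=False,
# beta=1): prediction [2, 5] and gold [3, 6] become half-open [2, 6) and [3, 7),
# so TP = |{3, 4, 5}| = 3, pre = rec = 3/4, and
# f = 2 * pre * rec / (pre + rec) = 0.75, which is accumulated into has_ans_f.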
def get_metric(self, reset=True): | |||||
evaluate_result = {} | |||||
if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.has_ans_wrong <= 0:
return evaluate_result | |||||
evaluate_result['EM'] = 0 | |||||
evaluate_result[f'f_{self.f_beta}'] = 0 | |||||
flag = 0 | |||||
if self.no_ans_correct + self.no_ans_wrong > 0: | |||||
evaluate_result[f'noAns-f_{self.f_beta}'] = \ | |||||
round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3) | |||||
evaluate_result['noAns-EM'] = \ | |||||
round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3) | |||||
evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'noAns-f_{self.f_beta}'] | |||||
evaluate_result['EM'] += evaluate_result['noAns-EM'] | |||||
flag += 1 | |||||
if self.has_ans_correct + self.has_ans_wrong > 0: | |||||
evaluate_result[f'hasAns-f_{self.f_beta}'] = \ | |||||
round(100 * self.has_ans_f / (self.has_ans_correct + self.has_ans_wrong), 3) | |||||
evaluate_result['hasAns-EM'] = \ | |||||
round(100 * self.has_ans_correct / (self.has_ans_correct + self.has_ans_wrong), 3) | |||||
evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'hasAns-f_{self.f_beta}'] | |||||
evaluate_result['EM'] += evaluate_result['hasAns-EM'] | |||||
flag += 1 | |||||
if self.print_predict_stat: | |||||
evaluate_result['no2no'] = self.no2no | |||||
evaluate_result['no2yes'] = self.no2yes | |||||
evaluate_result['yes2no'] = self.yes2no | |||||
evaluate_result['yes2yes'] = self.yes2yes | |||||
if flag <= 0: | |||||
return evaluate_result | |||||
evaluate_result[f'f_{self.f_beta}'] = round(evaluate_result[f'f_{self.f_beta}'] / flag, 3) | |||||
evaluate_result['EM'] = round(evaluate_result['EM'] / flag, 3) | |||||
if reset: | |||||
self.no_ans_correct = 0 | |||||
self.no_ans_wrong = 0 | |||||
self.has_ans_correct = 0 | |||||
self.has_ans_wrong = 0 | |||||
self.has_ans_f = 0. | |||||
self.no2no = 0 | |||||
self.no2yes = 0 | |||||
self.yes2no = 0 | |||||
self.yes2yes = 0 | |||||
return evaluate_result | |||||
@@ -32,8 +32,8 @@ from fastNLP.core.utils import get_func_signature | |||||
class Trainer(object): | class Trainer(object): | ||||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | ||||
validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), | |||||
check_code_level=0, metric_key=None, sampler=RandomSampler(), prefetch=False, use_tqdm=True, | |||||
validate_every=-1, dev_data=None, save_path=None, optimizer=None, | |||||
check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True, | |||||
use_cuda=False, callbacks=None): | use_cuda=False, callbacks=None): | ||||
""" | """ | ||||
:param DataSet train_data: the training data | :param DataSet train_data: the training data | ||||
@@ -96,7 +96,7 @@ class Trainer(object): | |||||
losser = _prepare_losser(loss) | losser = _prepare_losser(loss) | ||||
# sampler check | # sampler check | ||||
if not isinstance(sampler, BaseSampler): | |||||
if sampler is not None and not isinstance(sampler, BaseSampler): | |||||
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | ||||
if check_code_level > -1: | if check_code_level > -1: | ||||
@@ -119,7 +119,7 @@ class Trainer(object): | |||||
self.best_dev_epoch = None | self.best_dev_epoch = None | ||||
self.best_dev_step = None | self.best_dev_step = None | ||||
self.best_dev_perf = None | self.best_dev_perf = None | ||||
self.sampler = sampler | |||||
self.sampler = sampler if sampler is not None else RandomSampler() | |||||
self.prefetch = prefetch | self.prefetch = prefetch | ||||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | ||||
self.n_steps = (len(self.train_data) // self.batch_size + int( | self.n_steps = (len(self.train_data) // self.batch_size + int( | ||||
@@ -128,6 +128,8 @@ class Trainer(object): | |||||
if isinstance(optimizer, torch.optim.Optimizer): | if isinstance(optimizer, torch.optim.Optimizer): | ||||
self.optimizer = optimizer | self.optimizer = optimizer | ||||
else: | else: | ||||
if optimizer is None: | |||||
optimizer = Adam(lr=0.01, weight_decay=0) | |||||
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | ||||
self.use_tqdm = use_tqdm | self.use_tqdm = use_tqdm | ||||
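Moving the Adam and RandomSampler defaults from the signature into __init__ is consistent with avoiding Python's shared-mutable-default pitfall: a keyword default is evaluated once, so every Trainer built without an explicit optimizer or sampler would otherwise have received the very same instance. A tiny generic illustration of that pitfall (not fastNLP-specific):

```python
class Widget:
    def __init__(self, items=[]):      # anti-pattern: default evaluated once at def-time
        self.items = items

a, b = Widget(), Widget()
a.items.append(1)
print(b.items)                          # [1] -- state leaks between instances

class SaferWidget:
    def __init__(self, items=None):     # the pattern adopted in the patch above
        self.items = items if items is not None else []
```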
@@ -145,6 +147,7 @@ class Trainer(object): | |||||
self.step = 0 | self.step = 0 | ||||
self.start_time = None # start timestamp | self.start_time = None # start timestamp | ||||
def train(self, load_best_model=True): | def train(self, load_best_model=True): | ||||
""" | """ | ||||
@@ -365,6 +368,8 @@ class Trainer(object): | |||||
""" | """ | ||||
if self.save_path is not None: | if self.save_path is not None: | ||||
model_path = os.path.join(self.save_path, model_name) | model_path = os.path.join(self.save_path, model_name) | ||||
if not os.path.exists(self.save_path): | |||||
os.makedirs(self.save_path, exist_ok=True) | |||||
if only_param: | if only_param: | ||||
state_dict = model.state_dict() | state_dict = model.state_dict() | ||||
for key in state_dict: | for key in state_dict: | ||||
@@ -196,3 +196,9 @@ class Vocabulary(object): | |||||
""" | """ | ||||
self.__dict__.update(state) | self.__dict__.update(state) | ||||
self.build_reverse_vocab() | self.build_reverse_vocab() | ||||
def __repr__(self): | |||||
return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) | |||||
def __iter__(self): | |||||
return iter(list(self.word_count.keys())) |
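A minimal usage sketch of the two new dunder methods (the word list is made up; the import follows the one used in the tests below):

```python
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.update(["fastNLP", "works", "well", "fastNLP"])  # made-up tokens
vocab.build_vocab()

print(vocab)         # __repr__ shows the first few counted words
for word in vocab:   # __iter__ walks the counted words
    print(word)
```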
@@ -192,20 +192,23 @@ class ConditionalRandomField(nn.Module): | |||||
seq_len, batch_size, n_tags = logits.size() | seq_len, batch_size, n_tags = logits.size() | ||||
alpha = logits[0] | alpha = logits[0] | ||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha += self.start_scores.view(1, -1) | |||||
alpha = alpha + self.start_scores.view(1, -1) | |||||
flip_mask = mask.eq(0) | |||||
for i in range(1, seq_len): | for i in range(1, seq_len): | ||||
emit_score = logits[i].view(batch_size, 1, n_tags) | emit_score = logits[i].view(batch_size, 1, n_tags) | ||||
trans_score = self.trans_m.view(1, n_tags, n_tags) | trans_score = self.trans_m.view(1, n_tags, n_tags) | ||||
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score | tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score | ||||
alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1) | |||||
alpha = log_sum_exp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ | |||||
alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha += self.end_scores.view(1, -1) | |||||
alpha = alpha + self.end_scores.view(1, -1) | |||||
return log_sum_exp(alpha, 1) | return log_sum_exp(alpha, 1) | ||||
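The switch above from multiplying by the float mask to masked_fill implements the same select-by-mask update for alpha; a small self-contained sketch of the pattern (shapes and values are made up):

```python
import torch

batch_size, n_tags = 2, 3
old_alpha = torch.randn(batch_size, n_tags)
new_alpha = torch.randn(batch_size, n_tags)   # stands in for log_sum_exp(tmp, 1)
mask_i = torch.tensor([1, 0]).bool()          # second sequence is padding at step i
flip_mask_i = ~mask_i

# take new_alpha where mask == 1, keep old_alpha where mask == 0
merged = new_alpha.masked_fill(flip_mask_i.view(batch_size, 1), 0) + \
         old_alpha.masked_fill(mask_i.view(batch_size, 1), 0)
# row 0 of merged equals new_alpha[0]; row 1 stays equal to old_alpha[1]
```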
def _glod_score(self, logits, tags, mask): | |||||
def _gold_score(self, logits, tags, mask): | |||||
""" | """ | ||||
Compute the score for the gold path. | Compute the score for the gold path. | ||||
:param logits: FloatTensor, max_len x batch_size x num_tags | :param logits: FloatTensor, max_len x batch_size x num_tags | ||||
@@ -218,17 +221,19 @@ class ConditionalRandomField(nn.Module): | |||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | ||||
# trans_score [L-1, B] | # trans_score [L-1, B] | ||||
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :] | |||||
mask = mask.byte() | |||||
flip_mask = mask.eq(0) | |||||
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]].masked_fill(flip_mask[1:, :], 0) | |||||
# emit_score [L, B] | # emit_score [L, B] | ||||
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask | |||||
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags].masked_fill(flip_mask, 0) | |||||
# score [L-1, B] | # score [L-1, B] | ||||
score = trans_score + emit_score[:seq_len-1, :] | score = trans_score + emit_score[:seq_len-1, :] | ||||
score = score.sum(0) + emit_score[-1] * mask[-1] | |||||
score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | ||||
last_idx = mask.long().sum(0) - 1 | last_idx = mask.long().sum(0) - 1 | ||||
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | ||||
score += st_scores + ed_scores | |||||
score = score + st_scores + ed_scores | |||||
# return [B,] | # return [B,] | ||||
return score | return score | ||||
@@ -244,7 +249,7 @@ class ConditionalRandomField(nn.Module): | |||||
tags = tags.transpose(0, 1).long() | tags = tags.transpose(0, 1).long() | ||||
mask = mask.transpose(0, 1).float() | mask = mask.transpose(0, 1).float() | ||||
all_path_score = self._normalizer_likelihood(feats, mask) | all_path_score = self._normalizer_likelihood(feats, mask) | ||||
gold_path_score = self._glod_score(feats, tags, mask) | |||||
gold_path_score = self._gold_score(feats, tags, mask) | |||||
return all_path_score - gold_path_score | return all_path_score - gold_path_score | ||||
@@ -265,7 +270,7 @@ class ConditionalRandomField(nn.Module): | |||||
""" | """ | ||||
batch_size, seq_len, n_tags = data.size() | batch_size, seq_len, n_tags = data.size() | ||||
data = data.transpose(0, 1).data # L, B, H | data = data.transpose(0, 1).data # L, B, H | ||||
mask = mask.transpose(0, 1).data.float() # L, B | |||||
mask = mask.transpose(0, 1).data.byte() # L, B | |||||
# dp | # dp | ||||
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | ||||
@@ -284,7 +289,8 @@ class ConditionalRandomField(nn.Module): | |||||
score = prev_score + trans_score + cur_score | score = prev_score + trans_score + cur_score | ||||
best_score, best_dst = score.max(1) | best_score, best_dst = score.max(1) | ||||
vpath[i] = best_dst | vpath[i] = best_dst | ||||
vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1) | |||||
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | |||||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | vscore += transitions[:n_tags, n_tags+1].view(1, -1) | ||||
@@ -60,7 +60,8 @@ def initial_parameter(net, initial_method=None): | |||||
init_method(w.data) # weight | init_method(w.data) # weight | ||||
else: | else: | ||||
init.normal_(w.data) # bias | init.normal_(w.data) # bias | ||||
elif hasattr(m, 'weight') and m.weight.requires_grad: | |||||
elif m is not None and hasattr(m, 'weight') and \ | |||||
hasattr(m.weight, "requires_grad"): | |||||
init_method(m.weight.data) | init_method(m.weight.data) | ||||
else: | else: | ||||
for w in m.parameters(): | for w in m.parameters(): | ||||
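For context, a minimal sketch of calling initial_parameter on a small module; the import path is assumed from this repository's layout and initial_method is left at its default:

```python
import torch.nn as nn
from fastNLP.modules.utils import initial_parameter  # assumed import path

net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
# Walks the sub-modules and re-initializes their parameters in place.
initial_parameter(net, initial_method=None)
```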
@@ -0,0 +1,111 @@ | |||||
import unittest | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.core.losses import CrossEntropyLoss | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
class TestENAS(unittest.TestCase): | |||||
def testENAS(self): | |||||
# Read data from a csv file into a DataSet | |||||
sample_path = "tutorials/sample_data/tutorial_sample_dataset.csv" | |||||
dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), | |||||
sep='\t') | |||||
print(len(dataset)) | |||||
print(dataset[0]) | |||||
print(dataset[-3]) | |||||
dataset.append(Instance(raw_sentence='fake data', label='0')) | |||||
# Lowercase the raw sentences | |||||
dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') | |||||
# Convert label to int | |||||
dataset.apply(lambda x: int(x['label']), new_field_name='label') | |||||
# Split sentences on whitespace | |||||
def split_sent(ins): | |||||
return ins['raw_sentence'].split() | |||||
dataset.apply(split_sent, new_field_name='words') | |||||
# Add sequence length information | |||||
dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') | |||||
print(len(dataset)) | |||||
print(dataset[0]) | |||||
# Filter out instances with DataSet.drop(func) | |||||
dataset.drop(lambda x: x['seq_len'] <= 3) | |||||
print(len(dataset)) | |||||
# Specify which fields in the DataSet should be converted to tensors | |||||
# set target: gold labels used when computing the loss and during evaluation | |||||
dataset.set_target("label") | |||||
# set input: fields fed to the model's forward | |||||
dataset.set_input("words", "seq_len") | |||||
# Split into a test set and a training set | |||||
test_data, train_data = dataset.split(0.5) | |||||
print(len(test_data)) | |||||
print(len(train_data)) | |||||
# Build the vocabulary with Vocabulary.add(word) | |||||
vocab = Vocabulary(min_freq=2) | |||||
train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) | |||||
vocab.build_vocab() | |||||
# Index the sentences with Vocabulary.to_index(word) | |||||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||||
test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||||
print(test_data[0]) | |||||
# These preprocessing utilities can also be used for projects such as reinforcement learning or GANs | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.sampler import RandomSampler | |||||
batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler()) | |||||
for batch_x, batch_y in batch_iterator: | |||||
print("batch_x has: ", batch_x) | |||||
print("batch_y has: ", batch_y) | |||||
break | |||||
from fastNLP.automl.enas_model import ENASModel | |||||
from fastNLP.automl.enas_controller import Controller | |||||
model = ENASModel(embed_num=len(vocab), num_classes=5) | |||||
controller = Controller() | |||||
from fastNLP.automl.enas_trainer import ENASTrainer | |||||
# Rename DataSet fields so they match the parameter names of the model's forward | |||||
train_data.rename_field('words', 'word_seq') # input field names must match forward's parameters | |||||
train_data.rename_field('label', 'label_seq') | |||||
test_data.rename_field('words', 'word_seq') | |||||
test_data.rename_field('label', 'label_seq') | |||||
loss = CrossEntropyLoss(pred="output", target="label_seq") | |||||
metric = AccuracyMetric(pred="predict", target="label_seq") | |||||
trainer = ENASTrainer(model=model, controller=controller, train_data=train_data, dev_data=test_data, | |||||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||||
metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||||
check_code_level=-1, | |||||
save_path=None, | |||||
batch_size=32, | |||||
print_every=1, | |||||
n_epochs=3, | |||||
final_epochs=1) | |||||
trainer.train() | |||||
print('Train finished!') | |||||
# Use Tester to evaluate performance on test_data | |||||
from fastNLP import Tester | |||||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||||
batch_size=4) | |||||
acc = tester.test() | |||||
print(acc) | |||||
if __name__ == '__main__': | |||||
unittest.main() |
@@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | metrics=AccuracyMetric(pred="predict", target="y"), | ||||
callbacks=[TensorboardCallback("loss", "metric")]) | callbacks=[TensorboardCallback("loss", "metric")]) | ||||
trainer.train() | trainer.train() | ||||
def test_readonly_property(self): | |||||
from fastNLP.core.callback import Callback | |||||
class MyCallback(Callback): | |||||
def __init__(self): | |||||
super(MyCallback, self).__init__() | |||||
def on_epoch_begin(self, cur_epoch, total_epoch): | |||||
print(self.n_epochs, self.n_steps, self.batch_size) | |||||
print(self.model) | |||||
print(self.optimizer) | |||||
data_set, model = prepare_env() | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=5, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[MyCallback()]) | |||||
trainer.train() |
@@ -52,7 +52,7 @@ class TestDataSetMethods(unittest.TestCase): | |||||
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) | self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) | ||||
self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) | self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) | ||||
def test_add_append(self): | |||||
def test_add_field(self): | |||||
dd = DataSet() | dd = DataSet() | ||||
dd.add_field("x", [[1, 2, 3]] * 10) | dd.add_field("x", [[1, 2, 3]] * 10) | ||||
dd.add_field("y", [[1, 2, 3, 4]] * 10) | dd.add_field("y", [[1, 2, 3, 4]] * 10) | ||||
@@ -65,6 +65,11 @@ class TestDataSetMethods(unittest.TestCase): | |||||
with self.assertRaises(RuntimeError): | with self.assertRaises(RuntimeError): | ||||
dd.add_field("??", [[1, 2]] * 40) | dd.add_field("??", [[1, 2]] * 40) | ||||
def test_add_field_ignore_type(self): | |||||
dd = DataSet() | |||||
dd.add_field("x", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], ignore_type=True, is_target=True) | |||||
dd.add_field("y", [{1, "1"}, {2, "2"}, {3, "3"}, {4, "4"}], ignore_type=True, is_target=True) | |||||
def test_delete_field(self): | def test_delete_field(self): | ||||
dd = DataSet() | dd = DataSet() | ||||
dd.add_field("x", [[1, 2, 3]] * 10) | dd.add_field("x", [[1, 2, 3]] * 10) | ||||
@@ -115,6 +120,9 @@ class TestDataSetMethods(unittest.TestCase): | |||||
self.assertTrue(isinstance(res, list) and len(res) > 0) | self.assertTrue(isinstance(res, list) and len(res) > 0) | ||||
self.assertTrue(res[0], 4) | self.assertTrue(res[0], 4) | ||||
ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True) | |||||
# expect no exception raised | |||||
def test_drop(self): | def test_drop(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | ||||
ds.drop(lambda ins: len(ins["y"]) < 3) | ds.drop(lambda ins: len(ins["y"]) < 3) | ||||
@@ -165,7 +173,7 @@ class TestDataSetMethods(unittest.TestCase): | |||||
dataset.apply(split_sent, new_field_name='words', is_input=True) | dataset.apply(split_sent, new_field_name='words', is_input=True) | ||||
# print(dataset) | # print(dataset) | ||||
def test_add_field(self): | |||||
def test_add_field_v2(self): | |||||
ds = DataSet({"x": [3, 4]}) | ds = DataSet({"x": [3, 4]}) | ||||
ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True) | ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True) | ||||
# ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y') | # ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y') | ||||
@@ -208,6 +216,11 @@ class TestDataSetMethods(unittest.TestCase): | |||||
self.assertTrue(isinstance(ds, DataSet)) | self.assertTrue(isinstance(ds, DataSet)) | ||||
self.assertTrue(len(ds) > 0) | self.assertTrue(len(ds) > 0) | ||||
def test_add_null(self): | |||||
ds = DataSet() | |||||
ds.add_field('test', []) | |||||
ds.set_target('test') | |||||
class TestDataSetIter(unittest.TestCase): | class TestDataSetIter(unittest.TestCase): | ||||
def test__repr__(self): | def test__repr__(self): | ||||
@@ -155,6 +155,13 @@ class TestFieldArray(unittest.TestCase): | |||||
self.assertEqual(len(fa), 3) | self.assertEqual(len(fa), 3) | ||||
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) | self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) | ||||
def test_ignore_type(self): | |||||
# Test the newly added ignore_type parameter, which skips type checking | |||||
fa = FieldArray("y", [[1.1, 2.2, "jin", {}, "hahah"], [int, 2, "$", 4, 5]], is_input=True, ignore_type=True) | |||||
fa.append([1.2, 2.3, str, 4.5, print]) | |||||
fa = FieldArray("y", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], is_target=True, ignore_type=True) | |||||
class TestPadder(unittest.TestCase): | class TestPadder(unittest.TestCase): | ||||
@@ -215,4 +222,14 @@ class TestPadder(unittest.TestCase): | |||||
[[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]], | [[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]], | ||||
[[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]], | [[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]], | ||||
padder(contents, None, np.int64).tolist() | padder(contents, None, np.int64).tolist() | ||||
) | |||||
) | |||||
def test_None_dtype(self): | |||||
from fastNLP.core.fieldarray import AutoPadder | |||||
padder = AutoPadder() | |||||
content = [ | |||||
[[1, 2, 3], [4, 5], [7, 8, 9, 10]], | |||||
[[1]] | |||||
] | |||||
ans = padder(content, None, None) | |||||
self.assertListEqual(content, ans) |
@@ -60,6 +60,15 @@ class TestIndexing(unittest.TestCase): | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | ||||
def test_iteration(self): | |||||
vocab = Vocabulary() | |||||
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | |||||
"works", "well", "in", "most", "cases", "scales", "well"] | |||||
vocab.update(text) | |||||
text = set(text) | |||||
for word in vocab: | |||||
self.assertTrue(word in text) | |||||
class TestOther(unittest.TestCase): | class TestOther(unittest.TestCase): | ||||
def test_additional_update(self): | def test_additional_update(self): | ||||
@@ -66,7 +66,7 @@ class TestCRF(unittest.TestCase): | |||||
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | # from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | ||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) | # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) | ||||
# fast_CRF.trans_m = trans_m | # fast_CRF.trans_m = trans_m | ||||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) | |||||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) | |||||
# # score equal | # # score equal | ||||
# self.assertListEqual([score for _, score in allen_res], fast_res[1]) | # self.assertListEqual([score for _, score in allen_res], fast_res[1]) | ||||
# # seq equal | # # seq equal | ||||
@@ -95,10 +95,34 @@ class TestCRF(unittest.TestCase): | |||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | ||||
# encoding_type='BMES')) | # encoding_type='BMES')) | ||||
# fast_CRF.trans_m = trans_m | # fast_CRF.trans_m = trans_m | ||||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) | |||||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) | |||||
# # score equal | # # score equal | ||||
# self.assertListEqual([score for _, score in allen_res], fast_res[1]) | # self.assertListEqual([score for _, score in allen_res], fast_res[1]) | ||||
# # seq equal | # # seq equal | ||||
# self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | # self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | ||||
def test_case3(self): | |||||
# Check that the CRF loss never becomes negative | |||||
import torch | |||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField | |||||
from fastNLP.core.utils import seq_lens_to_masks | |||||
from torch import optim | |||||
from torch import nn | |||||
num_tags, include_start_end_trans = 4, True | |||||
num_samples = 4 | |||||
lengths = torch.randint(3, 50, size=(num_samples, )).long() | |||||
max_len = lengths.max() | |||||
tags = torch.randint(num_tags, size=(num_samples, max_len)) | |||||
masks = seq_lens_to_masks(lengths) | |||||
feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags)) | |||||
crf = ConditionalRandomField(num_tags, include_start_end_trans) | |||||
optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1) | |||||
for _ in range(10000): | |||||
loss = crf(feats, tags, masks).mean() | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
if _ % 1000 == 0: | |||||
print(loss) | |||||
assert loss.item() > 0, "CRF loss cannot be less than 0." |
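For reference, the non-negativity this test asserts follows from the forward pass shown earlier, which returns all_path_score - gold_path_score, i.e. the negative log-likelihood of the gold tag sequence (standard CRF algebra, not tied to this implementation):

```latex
\mathcal{L}(x, y^{*}) = \log \sum_{y} \exp\big(s(x, y)\big) - s(x, y^{*})
                      \ge \log \exp\big(s(x, y^{*})\big) - s(x, y^{*}) = 0
```

The inequality is strict whenever more than one tag path has finite score, which is why the assertion above can safely use a strict > 0.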