
Merge remote-tracking branch 'private/dev' into pr

# Conflicts:
#	fastNLP/core/callback.py
#	fastNLP/core/trainer.py
tags/v0.4.10
yunfan committed 5 years ago
commit e12041513f
19 changed files with 1508 additions and 68 deletions
  1. +0    -0   fastNLP/automl/__init__.py
  2. +223  -0   fastNLP/automl/enas_controller.py
  3. +388  -0   fastNLP/automl/enas_model.py
  4. +382  -0   fastNLP/automl/enas_trainer.py
  5. +53   -0   fastNLP/automl/enas_utils.py
  6. +0    -1   fastNLP/core/callback.py
  7. +14   -4   fastNLP/core/dataset.py
  8. +58   -41  fastNLP/core/fieldarray.py
  9. +152  -1   fastNLP/core/metrics.py
  10. +9   -4   fastNLP/core/trainer.py
  11. +6   -0   fastNLP/core/vocabulary.py
  12. +17  -11  fastNLP/modules/decoder/CRF.py
  13. +2   -1   fastNLP/modules/utils.py
  14. +111 -0   test/automl/test_enas.py
  15. +25  -0   test/core/test_callbacks.py
  16. +15  -2   test/core/test_dataset.py
  17. +18  -1   test/core/test_fieldarray.py
  18. +9   -0   test/core/test_vocabulary.py
  19. +26  -2   test/modules/decoder/test_CRF.py

+ 0  - 0   fastNLP/automl/__init__.py


+ 223  - 0   fastNLP/automl/enas_controller.py

@@ -0,0 +1,223 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch
"""A module with NAS controller-related code."""
import collections
import os

import torch
import torch.nn.functional as F

import fastNLP.automl.enas_utils as utils
from fastNLP.automl.enas_utils import Node


def _construct_dags(prev_nodes, activations, func_names, num_blocks):
"""Constructs a set of DAGs based on the actions, i.e., previous nodes and
activation functions, sampled from the controller/policy pi.

Args:
prev_nodes: Previous node actions from the policy.
activations: Activations sampled from the policy.
func_names: Mapping from activation function names to functions.
num_blocks: Number of blocks in the target RNN cell.

Returns:
A list of DAGs defined by the inputs.

RNN cell DAGs are represented in the following way:

1. Each element (node) in a DAG is a list of `Node`s.

2. The `Node`s in the list dag[i] correspond to the subsequent nodes
that take the output from node i as their own input.

3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}.
dag[-1] always feeds dag[0].
dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its
weights.

4. dag[N - 1] is the node that produces the hidden state passed to
the next timestep. dag[N - 1] is also always a leaf node, and therefore
is always averaged with the other leaf nodes and fed to the output
decoder.
"""
dags = []
for nodes, func_ids in zip(prev_nodes, activations):
dag = collections.defaultdict(list)

# add first node
dag[-1] = [Node(0, func_names[func_ids[0]])]
dag[-2] = [Node(0, func_names[func_ids[0]])]

# add following nodes
for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])):
dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id]))

leaf_nodes = set(range(num_blocks)) - dag.keys()

# merge with avg
for idx in leaf_nodes:
dag[idx] = [Node(num_blocks, 'avg')]

# This is actually y^{(t)}. h^{(t)} is node N - 1 in
# the graph, where N is the number of nodes. I.e., h^{(t)} takes
# only one other node as its input.
# last h[t] node
last_node = Node(num_blocks + 1, 'h[t]')
dag[num_blocks] = [last_node]
dags.append(dag)

return dags


class Controller(torch.nn.Module):
"""Based on
https://github.com/pytorch/examples/blob/master/word_language_model/model.py

RL controllers do not necessarily have much to do with
language models.

Base the controller RNN on the GRU from:
https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
"""
def __init__(self, num_blocks=4, controller_hid=100, cuda=False):
torch.nn.Module.__init__(self)

# `num_tokens` here is just the activation function
# for every even step,
self.shared_rnn_activations = ['tanh', 'ReLU', 'identity', 'sigmoid']
self.num_tokens = [len(self.shared_rnn_activations)]
self.controller_hid = controller_hid
self.use_cuda = cuda
self.num_blocks = num_blocks
for idx in range(num_blocks):
self.num_tokens += [idx + 1, len(self.shared_rnn_activations)]
self.func_names = self.shared_rnn_activations

num_total_tokens = sum(self.num_tokens)

self.encoder = torch.nn.Embedding(num_total_tokens,
controller_hid)
self.lstm = torch.nn.LSTMCell(controller_hid, controller_hid)

# Perhaps these weights in the decoder should be
# shared? At least for the activation functions, which all have the
# same size.
self.decoders = []
for idx, size in enumerate(self.num_tokens):
decoder = torch.nn.Linear(controller_hid, size)
self.decoders.append(decoder)

self._decoders = torch.nn.ModuleList(self.decoders)

self.reset_parameters()
self.static_init_hidden = utils.keydefaultdict(self.init_hidden)

def _get_default_hidden(key):
return utils.get_variable(
torch.zeros(key, self.controller_hid),
self.use_cuda,
requires_grad=False)

self.static_inputs = utils.keydefaultdict(_get_default_hidden)

def reset_parameters(self):
init_range = 0.1
for param in self.parameters():
param.data.uniform_(-init_range, init_range)
for decoder in self.decoders:
decoder.bias.data.fill_(0)

def forward(self, # pylint:disable=arguments-differ
inputs,
hidden,
block_idx,
is_embed):
if not is_embed:
embed = self.encoder(inputs)
else:
embed = inputs

hx, cx = self.lstm(embed, hidden)
logits = self.decoders[block_idx](hx)

logits /= 5.0

# # exploration
# if self.args.mode == 'train':
# logits = (2.5 * F.tanh(logits))

return logits, (hx, cx)

def sample(self, batch_size=1, with_details=False, save_dir=None):
"""Samples a set of `args.num_blocks` many computational nodes from the
controller, where each node is made up of an activation function, and
each node except the last also includes a previous node.
"""
if batch_size < 1:
raise Exception(f'Wrong batch_size: {batch_size} < 1')

# [B, L, H]
inputs = self.static_inputs[batch_size]
hidden = self.static_init_hidden[batch_size]

activations = []
entropies = []
log_probs = []
prev_nodes = []
# The RNN controller alternately outputs an activation,
# followed by a previous node, for each block except the last one,
# which only gets an activation function. The last node is the output
# node, and its previous node is the average of all leaf nodes.
for block_idx in range(2*(self.num_blocks - 1) + 1):
logits, hidden = self.forward(inputs,
hidden,
block_idx,
is_embed=(block_idx == 0))

probs = F.softmax(logits, dim=-1)
log_prob = F.log_softmax(logits, dim=-1)
# .mean() for entropy?
entropy = -(log_prob * probs).sum(1, keepdim=False)

action = probs.multinomial(num_samples=1).data
selected_log_prob = log_prob.gather(
1, utils.get_variable(action, requires_grad=False))

# why the [:, 0] here? Should it be .squeeze(), or
# .view()? Same below with `action`.
entropies.append(entropy)
log_probs.append(selected_log_prob[:, 0])

# 0: function, 1: previous node
mode = block_idx % 2
inputs = utils.get_variable(
action[:, 0] + sum(self.num_tokens[:mode]),
requires_grad=False)

if mode == 0:
activations.append(action[:, 0])
elif mode == 1:
prev_nodes.append(action[:, 0])

prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
activations = torch.stack(activations).transpose(0, 1)

dags = _construct_dags(prev_nodes,
activations,
self.func_names,
self.num_blocks)

if save_dir is not None:
for idx, dag in enumerate(dags):
utils.draw_network(dag,
os.path.join(save_dir, f'graph{idx}.png'))

if with_details:
return dags, torch.cat(log_probs), torch.cat(entropies)

return dags

def init_hidden(self, batch_size):
zeros = torch.zeros(batch_size, self.controller_hid)
return (utils.get_variable(zeros, self.use_cuda, requires_grad=False),
utils.get_variable(zeros.clone(), self.use_cuda, requires_grad=False))
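For readers unfamiliar with the structure returned by _construct_dags above, the following is a minimal illustration (not part of this commit) of what one sampled DAG looks like for num_blocks=4; the node ids and activation names are hypothetical sampled values.

import collections

Node = collections.namedtuple('Node', ['id', 'name'])

dag = collections.defaultdict(list)
dag[-1] = [Node(0, 'tanh')]                      # x^{(t)} feeds block 0
dag[-2] = [Node(0, 'tanh')]                      # h^{(t-1)} feeds block 0
dag[0] = [Node(1, 'ReLU'), Node(2, 'sigmoid')]   # block 0 feeds blocks 1 and 2
dag[1] = [Node(3, 'identity')]                   # block 1 feeds block 3
dag[2] = [Node(4, 'avg')]                        # blocks 2 and 3 are leaves:
dag[3] = [Node(4, 'avg')]                        # their outputs are averaged
dag[4] = [Node(5, 'h[t]')]                       # the averaged leaves become h^{(t)}

print(dict(dag))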

+ 388  - 0   fastNLP/automl/enas_model.py

@@ -0,0 +1,388 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

"""Module containing the shared RNN model."""
import collections

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

import fastNLP.automl.enas_utils as utils
from fastNLP.models.base_model import BaseModel


def _get_dropped_weights(w_raw, dropout_p, is_training):
"""Drops out weights to implement DropConnect.

Args:
w_raw: Full, pre-dropout, weights to be dropped out.
dropout_p: Proportion of weights to drop out.
is_training: True iff _shared_ model is training.

Returns:
The dropped weights.

Why does torch.nn.functional.dropout() return:
1. `torch.autograd.Variable()` on the training loop
2. `torch.nn.Parameter()` on the controller or eval loop, when
training = False...

Even though the call to `_setweights` in the Smerity repo's
`weight_drop.py` does not have this behaviour, and `F.dropout` always
returns `torch.autograd.Variable` there, even when `training=False`?

The above TODO is the reason for the hacky check for `torch.nn.Parameter`.
"""
dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training)

if isinstance(dropped_w, torch.nn.Parameter):
dropped_w = dropped_w.clone()

return dropped_w

class EmbeddingDropout(torch.nn.Embedding):
"""Class for dropping out embeddings by zero'ing out parameters in the
embedding matrix.

This is equivalent to dropping out particular words, e.g., in the sentence
'the quick brown fox jumps over the lazy dog', dropping out 'the' would
lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the
embedding vector space).

See 'A Theoretically Grounded Application of Dropout in Recurrent Neural
Networks', (Gal and Ghahramani, 2016).
"""
def __init__(self,
num_embeddings,
embedding_dim,
max_norm=None,
norm_type=2,
scale_grad_by_freq=False,
sparse=False,
dropout=0.1,
scale=None):
"""Embedding constructor.

Args:
dropout: Dropout probability.
scale: Used to scale parameters of embedding weight matrix that are
not dropped out. Note that this is _in addition_ to the
`1/(1 - dropout)` scaling.

See `torch.nn.Embedding` for remaining arguments.
"""
torch.nn.Embedding.__init__(self,
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
self.dropout = dropout
assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 '
'and < 1.0')
self.scale = scale

def forward(self, inputs): # pylint:disable=arguments-differ
"""Embeds `inputs` with the dropped out embedding weight matrix."""
if self.training:
dropout = self.dropout
else:
dropout = 0

if dropout:
mask = self.weight.data.new(self.weight.size(0), 1)
mask.bernoulli_(1 - dropout)
mask = mask.expand_as(self.weight)
mask = mask / (1 - dropout)
masked_weight = self.weight * Variable(mask)
else:
masked_weight = self.weight
if self.scale and self.scale != 1:
masked_weight = masked_weight * self.scale

return F.embedding(inputs,
masked_weight,
max_norm=self.max_norm,
norm_type=self.norm_type,
scale_grad_by_freq=self.scale_grad_by_freq,
sparse=self.sparse)


class LockedDropout(nn.Module):
# code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py
def __init__(self):
super().__init__()

def forward(self, x, dropout=0.5):
if not self.training or not dropout:
return x
m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
mask = Variable(m, requires_grad=False) / (1 - dropout)
mask = mask.expand_as(x)
return mask * x


class ENASModel(BaseModel):
"""Shared RNN model."""
def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000):
super(ENASModel, self).__init__()

self.use_cuda = cuda

self.shared_hid = shared_hid
self.num_blocks = num_blocks
self.decoder = nn.Linear(self.shared_hid, num_classes)
self.encoder = EmbeddingDropout(embed_num,
shared_embed,
dropout=0.1)
self.lockdrop = LockedDropout()
self.dag = None

# Tie weights
# self.decoder.weight = self.encoder.weight

# Since W^{x, c} and W^{h, c} are always summed, there
# is no point duplicating their bias offset parameter. Likewise for
# W^{x, h} and W^{h, h}.
self.w_xc = nn.Linear(shared_embed, self.shared_hid)
self.w_xh = nn.Linear(shared_embed, self.shared_hid)

# The raw weights are stored here because the hidden-to-hidden weights
# are weight dropped on the forward pass.
self.w_hc_raw = torch.nn.Parameter(
torch.Tensor(self.shared_hid, self.shared_hid))
self.w_hh_raw = torch.nn.Parameter(
torch.Tensor(self.shared_hid, self.shared_hid))
self.w_hc = None
self.w_hh = None

self.w_h = collections.defaultdict(dict)
self.w_c = collections.defaultdict(dict)

for idx in range(self.num_blocks):
for jdx in range(idx + 1, self.num_blocks):
self.w_h[idx][jdx] = nn.Linear(self.shared_hid,
self.shared_hid,
bias=False)
self.w_c[idx][jdx] = nn.Linear(self.shared_hid,
self.shared_hid,
bias=False)

self._w_h = nn.ModuleList([self.w_h[idx][jdx]
for idx in self.w_h
for jdx in self.w_h[idx]])
self._w_c = nn.ModuleList([self.w_c[idx][jdx]
for idx in self.w_c
for jdx in self.w_c[idx]])

self.batch_norm = None
# if args.mode == 'train':
# self.batch_norm = nn.BatchNorm1d(self.shared_hid)
# else:
# self.batch_norm = None

self.reset_parameters()
self.static_init_hidden = utils.keydefaultdict(self.init_hidden)

def setDAG(self, dag):
if self.dag is None:
self.dag = dag

def forward(self, word_seq, hidden=None):
inputs = torch.transpose(word_seq, 0, 1)

time_steps = inputs.size(0)
batch_size = inputs.size(1)


self.w_hh = _get_dropped_weights(self.w_hh_raw,
0.5,
self.training)
self.w_hc = _get_dropped_weights(self.w_hc_raw,
0.5,
self.training)

# hidden = self.static_init_hidden[batch_size] if hidden is None else hidden
hidden = self.static_init_hidden[batch_size]

embed = self.encoder(inputs)

embed = self.lockdrop(embed, 0.65 if self.training else 0)

# The norm of hidden states are clipped here because
# otherwise ENAS is especially prone to exploding activations on the
# forward pass. This could probably be fixed in a more elegant way, but
# it might be exposing a weakness in the ENAS algorithm as currently
# proposed.
#
# For more details, see
# https://github.com/carpedm20/ENAS-pytorch/issues/6
clipped_num = 0
max_clipped_norm = 0
h1tohT = []
logits = []
for step in range(time_steps):
x_t = embed[step]
logit, hidden = self.cell(x_t, hidden, self.dag)

hidden_norms = hidden.norm(dim=-1)
max_norm = 25.0
if hidden_norms.data.max() > max_norm:
# Just directly use the torch slice operations
# in PyTorch v0.4.
#
# This workaround for PyTorch v0.3.1 does everything in numpy,
# because the PyTorch slicing and slice assignment is too
# flaky.
hidden_norms = hidden_norms.data.cpu().numpy()

clipped_num += 1
if hidden_norms.max() > max_clipped_norm:
max_clipped_norm = hidden_norms.max()

clip_select = hidden_norms > max_norm
clip_norms = hidden_norms[clip_select]

mask = np.ones(hidden.size())
normalizer = max_norm/clip_norms
normalizer = normalizer[:, np.newaxis]

mask[clip_select] = normalizer

if self.use_cuda:
hidden *= torch.autograd.Variable(
torch.FloatTensor(mask).cuda(), requires_grad=False)
else:
hidden *= torch.autograd.Variable(
torch.FloatTensor(mask), requires_grad=False)
logits.append(logit)
h1tohT.append(hidden)

h1tohT = torch.stack(h1tohT)
output = torch.stack(logits)
raw_output = output

output = self.lockdrop(output, 0.4 if self.training else 0)

#Pooling
output = torch.mean(output, 0)

decoded = self.decoder(output)

extra_out = {'dropped': decoded,
'hiddens': h1tohT,
'raw': raw_output}
return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out}

def cell(self, x, h_prev, dag):
"""Computes a single pass through the discovered RNN cell."""
c = {}
h = {}
f = {}

f[0] = self.get_f(dag[-1][0].name)
c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None))
h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) +
(1 - c[0])*h_prev)

leaf_node_ids = []
q = collections.deque()
q.append(0)

# Computes connections from the parent nodes `node_id`
# to their child nodes `next_id` recursively, skipping leaf nodes. A
# leaf node is a node whose id == `self.num_blocks`.
#
# Connections between parent i and child j should be computed as
# h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i,
# where c_j = \sigmoid{(W^c_{ij}*h_i)}
#
# See Training details from Section 3.1 of the paper.
#
# The following algorithm does a breadth-first (since `q.popleft()` is
# used) search over the nodes and computes all the hidden states.
while True:
if len(q) == 0:
break

node_id = q.popleft()
nodes = dag[node_id]

for next_node in nodes:
next_id = next_node.id
if next_id == self.num_blocks:
leaf_node_ids.append(node_id)
assert len(nodes) == 1, ('parent of leaf node should have '
'only one child')
continue

w_h = self.w_h[node_id][next_id]
w_c = self.w_c[node_id][next_id]

f[next_id] = self.get_f(next_node.name)
c[next_id] = torch.sigmoid(w_c(h[node_id]))
h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) +
(1 - c[next_id])*h[node_id])

q.append(next_id)

# Instead of averaging loose ends, perhaps there should
# be a set of separate unshared weights for each "loose" connection
# between each node in a cell and the output.
#
# As it stands, all weights W^h_{ij} are doing double duty by
# connecting both from i to j, as well as from i to the output.

# average all the loose ends
leaf_nodes = [h[node_id] for node_id in leaf_node_ids]
output = torch.mean(torch.stack(leaf_nodes, 2), -1)

# stabilizing the Updates of omega
if self.batch_norm is not None:
output = self.batch_norm(output)

return output, h[self.num_blocks - 1]

def init_hidden(self, batch_size):
zeros = torch.zeros(batch_size, self.shared_hid)
return utils.get_variable(zeros, self.use_cuda, requires_grad=False)

def get_f(self, name):
name = name.lower()
if name == 'relu':
f = torch.relu
elif name == 'tanh':
f = torch.tanh
elif name == 'identity':
f = lambda x: x
elif name == 'sigmoid':
f = torch.sigmoid
return f

@property
def num_parameters(self):
def size(p):
return np.prod(p.size())
return sum([size(param) for param in self.parameters()])


def reset_parameters(self):
init_range = 0.025
# init_range = 0.025 if self.args.mode == 'train' else 0.04
for param in self.parameters():
param.data.uniform_(-init_range, init_range)
self.decoder.bias.data.fill_(0)

def predict(self, word_seq):
"""

:param word_seq: torch.LongTensor, [batch_size, seq_len]
:return predict: dict of torch.LongTensor, [batch_size, seq_len]
"""
output = self(word_seq)
_, predict = output['pred'].max(dim=1)
return {'pred': predict}
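As a side note on _get_dropped_weights above, the snippet below is a sketch (not part of this commit) of the DropConnect idea in isolation: dropout is applied to the weight matrix itself rather than to activations, mirroring how w_hh_raw/w_hc_raw are dropped before each forward pass.

import torch
import torch.nn.functional as F

w_raw = torch.nn.Parameter(torch.randn(5, 5))        # analogous to w_hh_raw
w_dropped = F.dropout(w_raw, p=0.5, training=True)   # ~half the connections zeroed, the rest scaled by 1/(1-p)

h = torch.randn(3, 5)
out = F.linear(h, w_dropped)                         # same pattern as F.linear(h_prev, self.w_hh, None) in cell()
print(out.shape)                                     # torch.Size([3, 5])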

+ 382  - 0   fastNLP/automl/enas_trainer.py

@@ -0,0 +1,382 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

import math
import time
from datetime import datetime
from datetime import timedelta

import numpy as np
import torch

try:
from tqdm.autonotebook import tqdm
except:
from fastNLP.core.utils import pseudo_tqdm as tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackException
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import _move_dict_value_to_device
import fastNLP
import fastNLP.automl.enas_utils as utils
from fastNLP.core.utils import _build_args

from torch.optim import Adam


def _get_no_grad_ctx_mgr():
"""Returns the `torch.no_grad` context manager for PyTorch version >=
0.4, or a no-op context manager otherwise.
"""
return torch.no_grad()


class ENASTrainer(fastNLP.Trainer):
"""A class to wrap training code."""
def __init__(self, train_data, model, controller, **kwargs):
"""Constructor for training algorithm.
:param DataSet train_data: the training data
:param torch.nn.modules.module model: a PyTorch model
:param torch.nn.modules.module controller: a PyTorch model
"""
self.final_epochs = kwargs['final_epochs']
kwargs.pop('final_epochs')
super(ENASTrainer, self).__init__(train_data, model, **kwargs)
self.controller_step = 0
self.shared_step = 0
self.max_length = 35

self.shared = model
self.controller = controller

self.shared_optim = Adam(
self.shared.parameters(),
lr=20.0,
weight_decay=1e-7)

self.controller_optim = Adam(
self.controller.parameters(),
lr=3.5e-4)

def train(self, load_best_model=True):
"""
:param bool load_best_model: only takes effect when dev_data was provided at initialization; if True, the trainer
reloads the parameters of the model that performed best on the dev set before returning.
:return results: a dict containing the following entries::

    seconds: float, the total training time
    The next three entries are only present when dev_data was provided.
    best_eval: Dict of Dict, the evaluation results
    best_epoch: int, the epoch in which the best value was obtained
    best_step: int, the step (batch) at which the best value was obtained

"""
results = {}
if self.n_epochs <= 0:
print(f"training epoch is {self.n_epochs}, nothing was done.")
results['seconds'] = 0.
return results
try:
if torch.cuda.is_available() and self.use_cuda:
self.model = self.model.cuda()
self._model_device = self.model.parameters().__next__().device
self._mode(self.model, is_test=False)

self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
start_time = time.time()
print("training epochs started " + self.start_time, flush=True)

try:
self.callback_manager.on_train_begin()
self._train()
self.callback_manager.on_train_end(self.model)
except (CallbackException, KeyboardInterrupt) as e:
self.callback_manager.on_exception(e, self.model)

if self.dev_data is not None:
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
self.tester._format_eval_results(self.best_dev_perf),)
results['best_eval'] = self.best_dev_perf
results['best_epoch'] = self.best_dev_epoch
results['best_step'] = self.best_dev_step
if load_best_model:
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
load_succeed = self._load_model(self.model, model_name)
if load_succeed:
print("Reloaded the best model.")
else:
print("Fail to reload best model.")
finally:
pass
results['seconds'] = round(time.time() - start_time, 2)

return results

def _train(self):
if not self.use_tqdm:
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
else:
inner_tqdm = tqdm
self.step = 0
start = time.time()
total_steps = (len(self.train_data) // self.batch_size + int(
len(self.train_data) % self.batch_size != 0)) * self.n_epochs
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
for epoch in range(1, self.n_epochs+1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
last_stage = (epoch > self.n_epochs + 1 - self.final_epochs)
if epoch == self.n_epochs + 1 - self.final_epochs:
print('Entering the final stage. (Only train the selected structure)')
# early stopping
self.callback_manager.on_epoch_begin(epoch, self.n_epochs)

# 1. Training the shared parameters omega of the child models
self.train_shared(pbar)

# 2. Training the controller parameters theta
if not last_stage:
self.train_controller()

if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
and self.dev_data is not None:
if not last_stage:
self.derive()
eval_res = self._do_validation(epoch=epoch, step=self.step)
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
total_steps) + \
self.tester._format_eval_results(eval_res)
pbar.write(eval_str)

# lr decay; early stopping
self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
# =============== epochs end =================== #
pbar.close()
# ============ tqdm end ============== #


def get_loss(self, inputs, targets, hidden, dags):
"""Computes the loss for the same batch for M models.

This amounts to an estimate of the loss, which is turned into an
estimate for the gradients of the shared model.
"""
if not isinstance(dags, list):
dags = [dags]

loss = 0
for dag in dags:
self.shared.setDAG(dag)
inputs = _build_args(self.shared.forward, **inputs)
inputs['hidden'] = hidden
result = self.shared(**inputs)
output, hidden, extra_out = result['pred'], result['hidden'], result['extra_out']

self.callback_manager.on_loss_begin(targets, result)
sample_loss = self._compute_loss(result, targets)
loss += sample_loss

assert len(dags) == 1, 'there are multiple `hidden` for multiple `dags`'
return loss, hidden, extra_out

def train_shared(self, pbar=None, max_step=None, dag=None):
"""Train the language model for 400 steps of minibatches of 64
examples.

Args:
max_step: Used to run extra training steps as a warm-up.
dag: If not None, is used instead of calling sample().

BPTT is truncated at 35 timesteps.

For each weight update, gradients are estimated by sampling M models
from the fixed controller policy, and averaging their gradients
computed on a batch of training data.
"""
model = self.shared
model.train()
self.controller.eval()

hidden = self.shared.init_hidden(self.batch_size)

abs_max_grad = 0
abs_max_hidden_norm = 0
step = 0
raw_total_loss = 0
total_loss = 0
train_idx = 0
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)

for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
indices = data_iterator.get_batch_indices()
# negative sampling; replace unknown; re-weight batch_y
self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
# prediction = self._data_forward(self.model, batch_x)

dags = self.controller.sample(1)
inputs, targets = batch_x, batch_y
# self.callback_manager.on_loss_begin(batch_y, prediction)
loss, hidden, extra_out = self.get_loss(inputs,
targets,
hidden,
dags)
hidden.detach_()
avg_loss += loss.item()

# Is loss NaN or inf? requires_grad = False
self.callback_manager.on_backward_begin(loss, self.model)
self._grad_backward(loss)
self.callback_manager.on_backward_end(self.model)

self._update()
self.callback_manager.on_step_end(self.optimizer)

if (self.step+1) % self.print_every == 0:
if self.use_tqdm:
print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
pbar.update(self.print_every)
else:
end = time.time()
diff = timedelta(seconds=round(end - start))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
epoch, self.step, avg_loss, diff)
pbar.set_postfix_str(print_output)
avg_loss = 0
self.step += 1
step += 1
self.shared_step += 1
self.callback_manager.on_batch_end()
# ================= mini-batch end ==================== #


def get_reward(self, dag, entropies, hidden, valid_idx=0):
"""Computes the perplexity of a single sampled model on a minibatch of
validation data.
"""
if not isinstance(entropies, np.ndarray):
entropies = entropies.data.cpu().numpy()

data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)

for inputs, targets in data_iterator:
valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
valid_loss = utils.to_item(valid_loss.data)

valid_ppl = math.exp(valid_loss)

R = 80 / valid_ppl

rewards = R + 1e-4 * entropies

return rewards, hidden

def train_controller(self):
"""Fixes the shared parameters and updates the controller parameters.

The controller is updated with a score function gradient estimator
(i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
is computed on a minibatch of validation data.

A moving average baseline is used.

The controller is trained for 2000 steps per epoch (i.e.,
first (Train Shared) phase -> second (Train Controller) phase).
"""
model = self.controller
model.train()
# Why can't we call shared.eval() here? Leads to loss
# being uniformly zero for the controller.
# self.shared.eval()

avg_reward_base = None
baseline = None
adv_history = []
entropy_history = []
reward_history = []

hidden = self.shared.init_hidden(self.batch_size)
total_loss = 0
valid_idx = 0
for step in range(20):
# sample models
dags, log_probs, entropies = self.controller.sample(
with_details=True)

# calculate reward
np_entropies = entropies.data.cpu().numpy()
# No gradients should be backpropagated to the
# shared model during controller training, obviously.
with _get_no_grad_ctx_mgr():
rewards, hidden = self.get_reward(dags,
np_entropies,
hidden,
valid_idx)


reward_history.extend(rewards)
entropy_history.extend(np_entropies)

# moving average baseline
if baseline is None:
baseline = rewards
else:
decay = 0.95
baseline = decay * baseline + (1 - decay) * rewards

adv = rewards - baseline
adv_history.extend(adv)

# policy loss
loss = -log_probs*utils.get_variable(adv,
self.use_cuda,
requires_grad=False)

loss = loss.sum() # or loss.mean()

# update
self.controller_optim.zero_grad()
loss.backward()

self.controller_optim.step()

total_loss += utils.to_item(loss.data)

if ((step % 50) == 0) and (step > 0):
reward_history, adv_history, entropy_history = [], [], []
total_loss = 0

self.controller_step += 1
# prev_valid_idx = valid_idx
# valid_idx = ((valid_idx + self.max_length) %
# (self.valid_data.size(0) - 1))
# # Whenever we wrap around to the beginning of the
# # validation data, we reset the hidden states.
# if prev_valid_idx > valid_idx:
# hidden = self.shared.init_hidden(self.batch_size)

def derive(self, sample_num=10, valid_idx=0):
"""We are always deriving based on the very first batch
of validation data? This seems wrong...
"""
hidden = self.shared.init_hidden(self.batch_size)

dags, _, entropies = self.controller.sample(sample_num,
with_details=True)

max_R = 0
best_dag = None
for dag in dags:
R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
if R.max() > max_R:
max_R = R.max()
best_dag = dag

self.model.setDAG(best_dag)
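The controller update in train_controller above is plain REINFORCE with a moving-average baseline. The snippet below is a self-contained sketch (not part of this commit) of that update with stand-in values; only the reward shaping (R + 1e-4 * entropy), the baseline decay of 0.95, and the summed policy loss are taken from the code above.

import torch

log_probs = torch.randn(7, requires_grad=True)   # stand-in for the controller's selected log-probs
entropies = torch.rand(7)                        # stand-in for the per-step entropies
R = 0.8                                          # stand-in for 80 / valid_ppl
rewards = R + 1e-4 * entropies

baseline = torch.full_like(rewards, 0.5)         # hypothetical running baseline from earlier steps
decay = 0.95
baseline = decay * baseline + (1 - decay) * rewards
adv = rewards - baseline

loss = (-log_probs * adv).sum()                  # score-function (REINFORCE) loss, summed as in the trainer
loss.backward()
print(log_probs.grad.shape)                      # torch.Size([7])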

+ 53  - 0   fastNLP/automl/enas_utils.py

@@ -0,0 +1,53 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

from __future__ import print_function

import collections
from collections import defaultdict

import numpy as np
import torch
from torch.autograd import Variable


def detach(h):
if type(h) == Variable:
return Variable(h.data)
else:
return tuple(detach(v) for v in h)

def get_variable(inputs, cuda=False, **kwargs):
if type(inputs) in [list, np.ndarray]:
inputs = torch.Tensor(inputs)
if cuda:
out = Variable(inputs.cuda(), **kwargs)
else:
out = Variable(inputs, **kwargs)
return out

def update_lr(optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr

Node = collections.namedtuple('Node', ['id', 'name'])


class keydefaultdict(defaultdict):
def __missing__(self, key):
if self.default_factory is None:
raise KeyError(key)
else:
ret = self[key] = self.default_factory(key)
return ret


def to_item(x):
"""Converts x, possibly scalar and possibly tensor, to a Python scalar."""
if isinstance(x, (float, int)):
return x

if float(torch.__version__[0:3]) < 0.4:
assert (x.dim() == 1) and (len(x) == 1)
return x[0]

return x.item()
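A quick usage sketch for the helpers above (not part of this commit; it assumes a fastNLP checkout containing this file): keydefaultdict passes the missing key to its factory, and to_item turns a one-element tensor or a plain number into a Python scalar.

import torch
from fastNLP.automl.enas_utils import keydefaultdict, to_item

hidden_cache = keydefaultdict(lambda batch_size: torch.zeros(batch_size, 100))
print(hidden_cache[4].shape)          # torch.Size([4, 100]); built on first access, cached afterwards

print(to_item(torch.tensor([3.0])))   # 3.0
print(to_item(7))                     # 7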

+ 0  - 1   fastNLP/core/callback.py

@@ -138,7 +138,6 @@ class CallbackManager(Callback):
"""
super(CallbackManager, self).__init__()
# set attribute of trainer environment
self.env = env

self.callbacks = []
if callbacks is not None:


+ 14  - 4   fastNLP/core/dataset.py

@@ -157,7 +157,7 @@ class DataSet(object):
assert name in self.field_arrays
self.field_arrays[name].append(field)

def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False):
def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False, ignore_type=False):
"""Add a new field to the DataSet.
:param str name: the name of the field.
@@ -165,13 +165,14 @@ class DataSet(object):
:param padder: a PadderBase object that controls how this field is padded. The default is fine in most cases.
:param bool is_input: whether this field is model input.
:param bool is_target: whether this field is label or target.
:param bool ignore_type: If True, do not perform type check. (Default: False)
"""
if len(self.field_arrays) != 0:
if len(self) != len(fields):
raise RuntimeError(f"The field to append must have the same size as dataset. "
f"Dataset size {len(self)} != field size {len(fields)}")
self.field_arrays[name] = FieldArray(name, fields, is_target=is_target, is_input=is_input,
padder=padder)
padder=padder, ignore_type=ignore_type)

def delete_field(self, name):
"""Delete a field based on the field name.
@@ -242,6 +243,8 @@ class DataSet(object):
:param padder: a PadderBase instance or None. Passing None removes the padder, i.e. no padding is applied to this field.
:return:
"""
if field_name not in self.field_arrays:
raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_padder(padder)

def set_pad_val(self, field_name, pad_val):
@@ -252,6 +255,8 @@ class DataSet(object):
:param pad_val: int, the padding index this field's padder will use.
:return:
"""
if field_name not in self.field_arrays:
raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_pad_val(pad_val)

def get_input_name(self):
@@ -287,6 +292,8 @@ class DataSet(object):
extra_param['is_input'] = kwargs['is_input']
if 'is_target' in kwargs:
extra_param['is_target'] = kwargs['is_target']
if 'ignore_type' in kwargs:
extra_param['ignore_type'] = kwargs['ignore_type']
if new_field_name is not None:
if new_field_name in self.field_arrays:
# overwrite the field, keep same attributes
@@ -295,11 +302,14 @@ class DataSet(object):
extra_param['is_input'] = old_field.is_input
if 'is_target' not in extra_param:
extra_param['is_target'] = old_field.is_target
if 'ignore_type' not in extra_param:
extra_param['ignore_type'] = old_field.ignore_type
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"],
is_target=extra_param["is_target"])
is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type'])
else:
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None),
is_target=extra_param.get("is_target", None))
is_target=extra_param.get("is_target", None),
ignore_type=extra_param.get("ignore_type", False))
else:
return results
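A minimal usage sketch of the new ignore_type flag (not part of this commit; it assumes fastNLP at this revision is importable): fields whose elements are not plain int/float/str can now be attached without triggering the type check.

from fastNLP import DataSet

ds = DataSet({"x": [[1, 2], [3, 4]]})
ds.add_field("meta", [(1, "a"), (2, "b")],
             is_target=True, ignore_type=True)   # mixed-type tuples would fail the type check without ignore_type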



+ 58  - 41   fastNLP/core/fieldarray.py

@@ -1,5 +1,5 @@
import numpy as np
from copy import deepcopy

class PadderBase:
"""
@@ -83,6 +83,8 @@ class AutoPadder(PadderBase):
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
for i, content in enumerate(contents):
array[i][:len(content)] = content
elif field_ele_dtype is None:
array = contents  # when ignore_type=True, return the contents unchanged
else: # should only be str
array = np.array([content for content in contents])
return array
@@ -96,10 +98,16 @@ class FieldArray(object):
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray.
:param bool is_target: If True, this FieldArray is used to compute loss.
:param bool is_input: If True, this FieldArray is used as the model input.
:param padder: a PadderBase instance. In most cases there is no need to set it, unless padding over more than one dimension is required (e.g. padding the characters of English words).
:param PadderBase padder: a PadderBase instance. The padder assigned to the FieldArray is deep-copied; to change its padding value afterwards, use
fieldarray.set_pad_val().
Defaults to None: (1) if the field holds scalars, no padding is applied; (2) if it is a 1-d list and the FieldArray dtype is float or int,
padding is applied; (3) otherwise no padding is applied.
Use a different padder if, for example, you need to pad the characters of English words,
or if ignore_type is True but padding is still required.
:param bool ignore_type: whether to ignore the type. If True, no type detection is performed for this FieldArray. (default: False)
"""

def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0)):
def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False):
"""When a DataSet is initialized, FieldArrays are constructed in two ways:
1) if the DataSet is initialized from a dict, add_field builds the FieldArray:
1.1) 2-d list: DataSet({"x": [[1, 2], [3, 4]]})
@@ -114,6 +122,7 @@ class FieldArray(object):
2.4) 2-d array: DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])

The type check (dtype check) happens when the field is set as is_input or is_target.
ignore_type controls whether that type check is performed; if True, it is skipped.

"""
self.name = name
@@ -135,7 +144,13 @@ class FieldArray(object):

self.content = content  # 1-d, 2-d or 3-d list; shapes may be ragged
self.content_dim = None  # dimensionality of the content list
if padder is None:
padder = AutoPadder(pad_val=0)
else:
assert isinstance(padder, PadderBase), "padder must be of type PadderBase."
padder = deepcopy(padder)
self.set_padder(padder)
self.ignore_type = ignore_type

self.BASIC_TYPES = (int, float, str)  # basic Python types accepted in content; np.ndarray is not listed here

@@ -149,8 +164,9 @@ class FieldArray(object):
self.is_target = is_target

def _set_dtype(self):
self.pytype = self._type_detection(self.content)
self.dtype = self._map_to_np_type(self.pytype)
if self.ignore_type is False:
self.pytype = self._type_detection(self.content)
self.dtype = self._map_to_np_type(self.pytype)

@property
def is_input(self):
@@ -190,7 +206,7 @@ class FieldArray(object):
if list in type_set:
if len(type_set) > 1:
# list mixed with non-list
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set))
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
# list with more than one dimension
inner_type_set = set()
for l in content:
@@ -213,7 +229,7 @@ class FieldArray(object):
return self._basic_type_detection(inner_inner_type_set)
else:
# list mixed with non-list
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set))
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set)))
else:
# 1-d list
for content_type in type_set:
@@ -237,17 +253,17 @@ class FieldArray(object):
return float
else:
# str mixed with int or float
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set))
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
else:
# str, int and float mixed together
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set))
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))

def _1d_list_check(self, val):
"""Raise an error if val is not a 1-d list.
"""
type_set = set((type(obj) for obj in val))
if any(obj not in self.BASIC_TYPES for obj in type_set):
raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set))
raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
self._basic_type_detection(type_set)
# otherwise: _basic_type_detection will raise error
return True
@@ -278,39 +294,40 @@ class FieldArray(object):

:param val: int, float, str, or a list of one.
"""
if isinstance(val, list):
pass
elif isinstance(val, tuple):  # make sure the outermost container is a list
val = list(val)
elif isinstance(val, np.ndarray):
val = val.tolist()
elif any((isinstance(val, t) for t in self.BASIC_TYPES)):
pass
else:
raise RuntimeError(
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))

if self.is_input is True or self.is_target is True:
if type(val) == list:
if len(val) == 0:
raise ValueError("Cannot append an empty list.")
if self.content_dim == 2 and self._1d_list_check(val):
# 1-d list check
pass
elif self.content_dim == 3 and self._2d_list_check(val):
# 2-d list check
pass
else:
raise RuntimeError(
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val))
elif type(val) in self.BASIC_TYPES and self.content_dim == 1:
# scalar check
if type(val) == float and self.pytype == int:
self.pytype = float
self.dtype = self._map_to_np_type(self.pytype)
if self.ignore_type is False:
if isinstance(val, list):
pass
elif isinstance(val, tuple):  # make sure the outermost container is a list
val = list(val)
elif isinstance(val, np.ndarray):
val = val.tolist()
elif any((isinstance(val, t) for t in self.BASIC_TYPES)):
pass
else:
raise RuntimeError(
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))

if self.is_input is True or self.is_target is True:
if type(val) == list:
if len(val) == 0:
raise ValueError("Cannot append an empty list.")
if self.content_dim == 2 and self._1d_list_check(val):
# 1-d list check
pass
elif self.content_dim == 3 and self._2d_list_check(val):
# 2-d list check
pass
else:
raise RuntimeError(
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val))
elif type(val) in self.BASIC_TYPES and self.content_dim == 1:
# scalar check
if type(val) == float and self.pytype == int:
self.pytype = float
self.dtype = self._map_to_np_type(self.pytype)
else:
raise RuntimeError(
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))
self.content.append(val)

def __getitem__(self, indices):
@@ -347,7 +364,7 @@ class FieldArray(object):
"""
if padder is not None:
assert isinstance(padder, PadderBase), "padder must be of type PadderBase."
self.padder = padder
self.padder = deepcopy(padder)

def set_pad_val(self, pad_val):
"""


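One behavioural consequence of the set_padder change above, shown as a short sketch (not part of this commit; it assumes fastNLP at this revision): the padder handed to a FieldArray is deep-copied, so the field no longer shares state with the caller's padder object.

from fastNLP.core.fieldarray import FieldArray, AutoPadder

padder = AutoPadder(pad_val=0)
fa = FieldArray("x", [[1, 2], [3, 4, 5]], is_input=True, padder=padder)
print(fa.padder is padder)   # False: the FieldArray keeps its own deep copy
fa.set_pad_val(-100)         # the supported way to change this field's padding value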
+ 152  - 1   fastNLP/core/metrics.py

@@ -157,7 +157,7 @@ class MetricBase(object):
fast_param = {}
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
fast_param['pred'] = list(pred_dict.values())[0]
fast_param['target'] = list(pred_dict.values())[0]
fast_param['target'] = list(target_dict.values())[0]
return fast_param
return fast_param

@@ -822,3 +822,154 @@ def pred_topk(y_prob, k=1):
(1, k))
y_prob_topk = y_prob[x_axis_index, y_pred_topk]
return y_pred_topk, y_prob_topk


class SQuADMetric(MetricBase):

def __init__(self, pred_start=None, pred_end=None, target_start=None, target_end=None,
beta=1, right_open=False, print_predict_stat=False):
"""
:param pred_start: [batch], index where the predicted answer starts; 0 if the answer is empty (SQuAD 2.0)
:param pred_end: [batch], index where the predicted answer ends; if the answer is empty (SQuAD 2.0), 0 for a closed interval or 1 for a half-open interval
:param target_start: [batch], index where the gold answer starts; 0 if the answer is empty (SQuAD 2.0)
:param target_end: [batch], index where the gold answer ends; if the answer is empty (SQuAD 2.0), 0 for a closed interval or 1 for a half-open interval
:param beta: float. The f_beta score, f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). Common values are beta=0.5, 1, 2:
with 0.5 precision is weighted more than recall, with 1 they are weighted equally, and with 2 recall is weighted more than precision.
:param right_open: boolean. If True, the start/end pointers describe a half-open interval [start, end); if False, a closed interval [start, end].
:param print_predict_stat: boolean. If True, print statistics on whether the predicted and gold answers are empty; if False, print nothing.
"""
super(SQuADMetric, self).__init__()

self._init_param_map(pred_start=pred_start, pred_end=pred_end, target_start=target_start, target_end=target_end)

self.print_predict_stat = print_predict_stat

self.no_ans_correct = 0
self.no_ans_wrong = 0

self.has_ans_correct = 0
self.has_ans_wrong = 0

self.has_ans_f = 0.

self.no2no = 0
self.no2yes = 0
self.yes2no = 0
self.yes2yes = 0

self.f_beta = beta

self.right_open = right_open

def evaluate(self, pred_start, pred_end, target_start, target_end):
"""

:param pred_start: [batch, seq_len]
:param pred_end: [batch, seq_len]
:param target_start: [batch]
:param target_end: [batch]
:param labels: [batch]
:return:
"""
start_inference = pred_start.max(dim=-1)[1].cpu().tolist()
end_inference = pred_end.max(dim=-1)[1].cpu().tolist()
start, end = [], []
max_len = pred_start.size(1)
t_start = target_start.cpu().tolist()
t_end = target_end.cpu().tolist()

for s, e in zip(start_inference, end_inference):
start.append(min(s, e))
end.append(max(s, e))
for s, e, ts, te in zip(start, end, t_start, t_end):
if not self.right_open:
e += 1
te += 1
if ts == 0 and te == int(not self.right_open):
if s == 0 and e == int(not self.right_open):
self.no_ans_correct += 1
self.no2no += 1
else:
self.no_ans_wrong += 1
self.no2yes += 1
else:
if s == 0 and e == int(not self.right_open):
self.yes2no += 1
else:
self.yes2yes += 1

if s == ts and e == te:
self.has_ans_correct += 1
else:
self.has_ans_wrong += 1
a = [0] * s + [1] * (e - s) + [0] * (max_len - e)
b = [0] * ts + [1] * (te - ts) + [0] * (max_len - te)
a, b = torch.tensor(a), torch.tensor(b)

TP = int(torch.sum(a * b))
pre = TP / int(torch.sum(a)) if int(torch.sum(a)) > 0 else 0
rec = TP / int(torch.sum(b)) if int(torch.sum(b)) > 0 else 0

if pre + rec > 0:
f = (1 + (self.f_beta**2)) * pre * rec / ((self.f_beta**2) * pre + rec)
else:
f = 0
self.has_ans_f += f

def get_metric(self, reset=True):
evaluate_result = {}

if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.no_ans_wrong <= 0:
return evaluate_result

evaluate_result['EM'] = 0
evaluate_result[f'f_{self.f_beta}'] = 0

flag = 0

if self.no_ans_correct + self.no_ans_wrong > 0:
evaluate_result[f'noAns-f_{self.f_beta}'] = \
round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3)
evaluate_result['noAns-EM'] = \
round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3)
evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'noAns-f_{self.f_beta}']
evaluate_result['EM'] += evaluate_result['noAns-EM']
flag += 1

if self.has_ans_correct + self.has_ans_wrong > 0:
evaluate_result[f'hasAns-f_{self.f_beta}'] = \
round(100 * self.has_ans_f / (self.has_ans_correct + self.has_ans_wrong), 3)
evaluate_result['hasAns-EM'] = \
round(100 * self.has_ans_correct / (self.has_ans_correct + self.has_ans_wrong), 3)
evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'hasAns-f_{self.f_beta}']
evaluate_result['EM'] += evaluate_result['hasAns-EM']
flag += 1

if self.print_predict_stat:
evaluate_result['no2no'] = self.no2no
evaluate_result['no2yes'] = self.no2yes
evaluate_result['yes2no'] = self.yes2no
evaluate_result['yes2yes'] = self.yes2yes

if flag <= 0:
return evaluate_result

evaluate_result[f'f_{self.f_beta}'] = round(evaluate_result[f'f_{self.f_beta}'] / flag, 3)
evaluate_result['EM'] = round(evaluate_result['EM'] / flag, 3)

if reset:
self.no_ans_correct = 0
self.no_ans_wrong = 0

self.has_ans_correct = 0
self.has_ans_wrong = 0

self.has_ans_f = 0.

self.no2no = 0
self.no2yes = 0
self.yes2no = 0
self.yes2yes = 0

return evaluate_result
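A minimal usage sketch for the new SQuADMetric (not part of this commit; it assumes fastNLP at this revision): pred_start/pred_end are [batch, seq_len] span logits, target_start/target_end are [batch] gold indices, with index 0 marking an empty answer as in SQuAD 2.0.

import torch
from fastNLP.core.metrics import SQuADMetric

metric = SQuADMetric(beta=1, right_open=False)

pred_start = torch.randn(2, 6)       # span-start logits for 2 samples of length 6
pred_end = torch.randn(2, 6)         # span-end logits
target_start = torch.tensor([1, 0])  # second sample has no answer
target_end = torch.tensor([3, 0])

metric.evaluate(pred_start, pred_end, target_start, target_end)
print(metric.get_metric(reset=True))  # e.g. {'EM': ..., 'f_1': ..., 'noAns-EM': ..., 'hasAns-EM': ...}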


+ 9  - 4   fastNLP/core/trainer.py

@@ -32,8 +32,8 @@ from fastNLP.core.utils import get_func_signature

class Trainer(object):
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
check_code_level=0, metric_key=None, sampler=RandomSampler(), prefetch=False, use_tqdm=True,
validate_every=-1, dev_data=None, save_path=None, optimizer=None,
check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True,
use_cuda=False, callbacks=None):
"""
:param DataSet train_data: the training data
@@ -96,7 +96,7 @@ class Trainer(object):
losser = _prepare_losser(loss)

# sampler check
if not isinstance(sampler, BaseSampler):
if sampler is not None and not isinstance(sampler, BaseSampler):
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))

if check_code_level > -1:
@@ -119,7 +119,7 @@ class Trainer(object):
self.best_dev_epoch = None
self.best_dev_step = None
self.best_dev_perf = None
self.sampler = sampler
self.sampler = sampler if sampler is not None else RandomSampler()
self.prefetch = prefetch
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
self.n_steps = (len(self.train_data) // self.batch_size + int(
@@ -128,6 +128,8 @@ class Trainer(object):
if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer
else:
if optimizer is None:
optimizer = Adam(lr=0.01, weight_decay=0)
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

self.use_tqdm = use_tqdm
@@ -145,6 +147,7 @@ class Trainer(object):
self.step = 0
self.start_time = None # start timestamp


def train(self, load_best_model=True):
"""

@@ -365,6 +368,8 @@ class Trainer(object):
"""
if self.save_path is not None:
model_path = os.path.join(self.save_path, model_name)
if not os.path.exists(self.save_path):
os.makedirs(self.save_path, exist_ok=True)
if only_param:
state_dict = model.state_dict()
for key in state_dict:
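The switch of the Trainer defaults from `optimizer=Adam(lr=0.01, weight_decay=0)` and `sampler=RandomSampler()` to `None` looks cosmetic, but presumably it avoids the usual Python pitfall of default arguments being evaluated once at definition time, which would make every Trainer share the same optimizer/sampler object. A generic sketch of that pitfall (not fastNLP code):

class Helper:
    pass

def bad(helper=Helper()):        # one Helper() created at definition time, shared by every call
    return helper

def good(helper=None):           # a fresh Helper() per call
    return helper if helper is not None else Helper()

print(bad() is bad())            # True  -- same shared instance
print(good() is good())          # False -- independent instances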


+ 6  - 0   fastNLP/core/vocabulary.py

@@ -196,3 +196,9 @@ class Vocabulary(object):
"""
self.__dict__.update(state)
self.build_reverse_vocab()

def __repr__(self):
return "Vocabulary({}...)".format(list(self.word_count.keys())[:5])

def __iter__(self):
return iter(list(self.word_count.keys()))
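A small usage sketch of the two new methods (not part of this commit; it assumes fastNLP at this revision):

from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.update(["fastNLP", "is", "fast"])
print(vocab)           # Vocabulary(['fastNLP', 'is', 'fast']...)
print(sorted(vocab))   # iteration yields the counted words: ['fast', 'fastNLP', 'is']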

+ 17  - 11   fastNLP/modules/decoder/CRF.py

@@ -192,20 +192,23 @@ class ConditionalRandomField(nn.Module):
seq_len, batch_size, n_tags = logits.size()
alpha = logits[0]
if self.include_start_end_trans:
alpha += self.start_scores.view(1, -1)
alpha = alpha + self.start_scores.view(1, -1)

flip_mask = mask.eq(0)

for i in range(1, seq_len):
emit_score = logits[i].view(batch_size, 1, n_tags)
trans_score = self.trans_m.view(1, n_tags, n_tags)
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1)
alpha = log_sum_exp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)

if self.include_start_end_trans:
alpha += self.end_scores.view(1, -1)
alpha = alpha + self.end_scores.view(1, -1)

return log_sum_exp(alpha, 1)

def _glod_score(self, logits, tags, mask):
def _gold_score(self, logits, tags, mask):
"""
Compute the score for the gold path.
:param logits: FloatTensor, max_len x batch_size x num_tags
@@ -218,17 +221,19 @@ class ConditionalRandomField(nn.Module):
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

# trans_score [L-1, B]
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :]
mask = mask.byte()
flip_mask = mask.eq(0)
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
# emit_score [L, B]
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags].masked_fill(flip_mask, 0)
# score [L-1, B]
score = trans_score + emit_score[:seq_len-1, :]
score = score.sum(0) + emit_score[-1] * mask[-1]
score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0)
if self.include_start_end_trans:
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
last_idx = mask.long().sum(0) - 1
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
score += st_scores + ed_scores
score = score + st_scores + ed_scores
# return [B,]
return score

@@ -244,7 +249,7 @@ class ConditionalRandomField(nn.Module):
tags = tags.transpose(0, 1).long()
mask = mask.transpose(0, 1).float()
all_path_score = self._normalizer_likelihood(feats, mask)
gold_path_score = self._glod_score(feats, tags, mask)
gold_path_score = self._gold_score(feats, tags, mask)

return all_path_score - gold_path_score

@@ -265,7 +270,7 @@ class ConditionalRandomField(nn.Module):
"""
batch_size, seq_len, n_tags = data.size()
data = data.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.float() # L, B
mask = mask.transpose(0, 1).data.byte() # L, B

# dp
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
@@ -284,7 +289,8 @@ class ConditionalRandomField(nn.Module):
score = prev_score + trans_score + cur_score
best_score, best_dst = score.max(1)
vpath[i] = best_dst
vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1)
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)

vscore += transitions[:n_tags, n_tags+1].view(1, -1)
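The recurring change in this file replaces `x * mask + y * (1 - mask)` on a float mask with masked_fill on a byte/bool mask. A standalone sketch (not part of this commit) showing that the two forms give the same result:

import torch

x = torch.tensor([[1., 2.], [3., 4.]])
y = torch.tensor([[10., 20.], [30., 40.]])
mask = torch.tensor([[True], [False]])        # True = take x, False = keep y (broadcast over the row)

old = x * mask.float() + y * (1 - mask.float())
new = x.masked_fill(mask.eq(0), 0) + y.masked_fill(mask, 0)
print(torch.equal(old, new))                  # True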



+ 2  - 1   fastNLP/modules/utils.py

@@ -60,7 +60,8 @@ def initial_parameter(net, initial_method=None):
init_method(w.data) # weight
else:
init.normal_(w.data) # bias
elif hasattr(m, 'weight') and m.weight.requires_grad:
elif m is not None and hasattr(m, 'weight') and \
hasattr(m.weight, "requires_grad"):
init_method(m.weight.data)
else:
for w in m.parameters():


+ 111  - 0   test/automl/test_enas.py

@@ -0,0 +1,111 @@
import unittest

from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import Vocabulary
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.core.metrics import AccuracyMetric


class TestENAS(unittest.TestCase):
def testENAS(self):
# read data from a csv file into a DataSet
sample_path = "tutorials/sample_data/tutorial_sample_dataset.csv"
dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'),
sep='\t')
print(len(dataset))
print(dataset[0])
print(dataset[-3])

dataset.append(Instance(raw_sentence='fake data', label='0'))
# lowercase the raw sentences
dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
# convert label to int
dataset.apply(lambda x: int(x['label']), new_field_name='label')

# split sentences on whitespace
def split_sent(ins):
return ins['raw_sentence'].split()

dataset.apply(split_sent, new_field_name='words')

# add sequence-length information
dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')
print(len(dataset))
print(dataset[0])

# DataSet.drop(func) filters out instances
dataset.drop(lambda x: x['seq_len'] <= 3)
print(len(dataset))

# specify which fields of the DataSet should be converted to tensors
# set target: the gold labels used by loss and evaluate, needed for computing the loss and evaluating the model
dataset.set_target("label")
# set input: used in the model's forward pass
dataset.set_input("words", "seq_len")

# split into a test set and a training set
test_data, train_data = dataset.split(0.5)
print(len(test_data))
print(len(train_data))

# build the vocabulary with Vocabulary.add(word)
vocab = Vocabulary(min_freq=2)
train_data.apply(lambda x: [vocab.add(word) for word in x['words']])
vocab.build_vocab()

# index the sentences with Vocabulary.to_index(word)
train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
print(test_data[0])

# these preprocessing tools can also be used for projects such as reinforcement learning or GANs
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler

batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())
for batch_x, batch_y in batch_iterator:
print("batch_x has: ", batch_x)
print("batch_y has: ", batch_y)
break

from fastNLP.automl.enas_model import ENASModel
from fastNLP.automl.enas_controller import Controller
model = ENASModel(embed_num=len(vocab), num_classes=5)
controller = Controller()

from fastNLP.automl.enas_trainer import ENASTrainer

# rename the DataSet fields so that they match the parameter names of the model's forward method
train_data.rename_field('words', 'word_seq')  # the input field name must match the forward parameter name
train_data.rename_field('label', 'label_seq')
test_data.rename_field('words', 'word_seq')
test_data.rename_field('label', 'label_seq')

loss = CrossEntropyLoss(pred="output", target="label_seq")
metric = AccuracyMetric(pred="predict", target="label_seq")

trainer = ENASTrainer(model=model, controller=controller, train_data=train_data, dev_data=test_data,
loss=CrossEntropyLoss(pred="output", target="label_seq"),
metrics=AccuracyMetric(pred="predict", target="label_seq"),
check_code_level=-1,
save_path=None,
batch_size=32,
print_every=1,
n_epochs=3,
final_epochs=1)
trainer.train()
print('Train finished!')

# use Tester to evaluate on test_data
from fastNLP import Tester

tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"),
batch_size=4)

acc = tester.test()
print(acc)


if __name__ == '__main__':
unittest.main()

+ 25  - 0   test/core/test_callbacks.py

@@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase):
metrics=AccuracyMetric(pred="predict", target="y"),
callbacks=[TensorboardCallback("loss", "metric")])
trainer.train()

def test_readonly_property(self):
from fastNLP.core.callback import Callback
class MyCallback(Callback):
def __init__(self):
super(MyCallback, self).__init__()

def on_epoch_begin(self, cur_epoch, total_epoch):
print(self.n_epochs, self.n_steps, self.batch_size)
print(self.model)
print(self.optimizer)

data_set, model = prepare_env()
trainer = Trainer(data_set, model,
loss=BCELoss(pred="predict", target="y"),
n_epochs=5,
batch_size=32,
print_every=50,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=False,
dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"),
callbacks=[MyCallback()])
trainer.train()

+ 15  - 2   test/core/test_dataset.py

@@ -52,7 +52,7 @@ class TestDataSetMethods(unittest.TestCase):
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)

def test_add_append(self):
def test_add_field(self):
dd = DataSet()
dd.add_field("x", [[1, 2, 3]] * 10)
dd.add_field("y", [[1, 2, 3, 4]] * 10)
@@ -65,6 +65,11 @@ class TestDataSetMethods(unittest.TestCase):
with self.assertRaises(RuntimeError):
dd.add_field("??", [[1, 2]] * 40)

def test_add_field_ignore_type(self):
dd = DataSet()
dd.add_field("x", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], ignore_type=True, is_target=True)
dd.add_field("y", [{1, "1"}, {2, "2"}, {3, "3"}, {4, "4"}], ignore_type=True, is_target=True)

def test_delete_field(self):
dd = DataSet()
dd.add_field("x", [[1, 2, 3]] * 10)
@@ -115,6 +120,9 @@ class TestDataSetMethods(unittest.TestCase):
self.assertTrue(isinstance(res, list) and len(res) > 0)
self.assertTrue(res[0], 4)

ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True)
# expect no exception raised

def test_drop(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
ds.drop(lambda ins: len(ins["y"]) < 3)
@@ -165,7 +173,7 @@ class TestDataSetMethods(unittest.TestCase):
dataset.apply(split_sent, new_field_name='words', is_input=True)
# print(dataset)

def test_add_field(self):
def test_add_field_v2(self):
ds = DataSet({"x": [3, 4]})
ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True)
# ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y')
@@ -208,6 +216,11 @@ class TestDataSetMethods(unittest.TestCase):
self.assertTrue(isinstance(ds, DataSet))
self.assertTrue(len(ds) > 0)

def test_add_null(self):
ds = DataSet()
ds.add_field('test', [])
ds.set_target('test')


class TestDataSetIter(unittest.TestCase):
def test__repr__(self):


+ 18  - 1   test/core/test_fieldarray.py

@@ -155,6 +155,13 @@ class TestFieldArray(unittest.TestCase):
self.assertEqual(len(fa), 3)
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])

def test_ignore_type(self):
# test the newly added ignore_type parameter, which skips the type check
fa = FieldArray("y", [[1.1, 2.2, "jin", {}, "hahah"], [int, 2, "$", 4, 5]], is_input=True, ignore_type=True)
fa.append([1.2, 2.3, str, 4.5, print])

fa = FieldArray("y", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], is_target=True, ignore_type=True)


class TestPadder(unittest.TestCase):

@@ -215,4 +222,14 @@ class TestPadder(unittest.TestCase):
[[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]],
[[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]],
padder(contents, None, np.int64).tolist()
)
)

def test_None_dtype(self):
from fastNLP.core.fieldarray import AutoPadder
padder = AutoPadder()
content = [
[[1, 2, 3], [4, 5], [7, 8, 9, 10]],
[[1]]
]
ans = padder(content, None, None)
self.assertListEqual(content, ans)

+ 9  - 0   test/core/test_vocabulary.py

@@ -60,6 +60,15 @@ class TestIndexing(unittest.TestCase):
vocab.update(text)
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])

def test_iteration(self):
vocab = Vocabulary()
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
"works", "well", "in", "most", "cases", "scales", "well"]
vocab.update(text)
text = set(text)
for word in vocab:
self.assertTrue(word in text)


class TestOther(unittest.TestCase):
def test_additional_update(self):


+ 26  - 2   test/modules/decoder/test_CRF.py

@@ -66,7 +66,7 @@ class TestCRF(unittest.TestCase):
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label))
# fast_CRF.trans_m = trans_m
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True)
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True)
# # score equal
# self.assertListEqual([score for _, score in allen_res], fast_res[1])
# # seq equal
@@ -95,10 +95,34 @@ class TestCRF(unittest.TestCase):
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
# encoding_type='BMES'))
# fast_CRF.trans_m = trans_m
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True)
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True)
# # score equal
# self.assertListEqual([score for _, score in allen_res], fast_res[1])
# # seq equal
# self.assertListEqual([_ for _, score in allen_res], fast_res[0])

def test_case3(self):
# test that the CRF loss never becomes negative
import torch
from fastNLP.modules.decoder.CRF import ConditionalRandomField
from fastNLP.core.utils import seq_lens_to_masks
from torch import optim
from torch import nn

num_tags, include_start_end_trans = 4, True
num_samples = 4
lengths = torch.randint(3, 50, size=(num_samples, )).long()
max_len = lengths.max()
tags = torch.randint(num_tags, size=(num_samples, max_len))
masks = seq_lens_to_masks(lengths)
feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
crf = ConditionalRandomField(num_tags, include_start_end_trans)
optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1)
for _ in range(10000):
loss = crf(feats, tags, masks).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
if _%1000==0:
print(loss)
assert loss.item()>0, "CRF loss cannot be less than 0."
