
- fix trainer with validate_every > 0

- refine & fix Transformer Encoder
- refine & speed up biaffine parser
tags/v0.3.1^2
yunfan committed 5 years ago
commit 2e9e6c6c20
7 changed files with 138 additions and 89 deletions
  1. fastNLP/core/trainer.py  (+1, -1)
  2. fastNLP/models/biaffine_parser.py  (+38, -44)
  3. fastNLP/modules/aggregator/attention.py  (+58, -23)
  4. fastNLP/modules/encoder/transformer.py  (+31, -12)
  5. reproduction/Biaffine_parser/cfg.cfg  (+5, -6)
  6. reproduction/Biaffine_parser/run.py  (+3, -2)
  7. test/models/test_biaffine_parser.py  (+2, -1)

fastNLP/core/trainer.py  (+1, -1)

@@ -281,7 +281,7 @@ class Trainer(object):
                 self.callback_manager.after_batch()

                 if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
-                    (self.validate_every < 0 and self.step % len(data_iterator)) == 0) \
+                    (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                         and self.dev_data is not None:
                     eval_res = self._do_validation(epoch=epoch, step=self.step)
                     eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,


fastNLP/models/biaffine_parser.py  (+38, -44)

@@ -6,6 +6,7 @@ from torch import nn
 from torch.nn import functional as F
 from fastNLP.modules.utils import initial_parameter
 from fastNLP.modules.encoder.variational_rnn import VarLSTM
+from fastNLP.modules.encoder.transformer import TransformerEncoder
 from fastNLP.modules.dropout import TimestepDropout
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules.utils import seq_mask
@@ -197,53 +198,49 @@ class BiaffineParser(GraphParser):
                  pos_vocab_size,
                  pos_emb_dim,
                  num_label,
-                 word_hid_dim=100,
-                 pos_hid_dim=100,
                  rnn_layers=1,
                  rnn_hidden_size=200,
                  arc_mlp_size=100,
                  label_mlp_size=100,
                  dropout=0.3,
-                 use_var_lstm=False,
+                 encoder='lstm',
                  use_greedy_infer=False):

         super(BiaffineParser, self).__init__()
         rnn_out_size = 2 * rnn_hidden_size
+        word_hid_dim = pos_hid_dim = rnn_hidden_size
         self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
         self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
         self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
         self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
         self.word_norm = nn.LayerNorm(word_hid_dim)
         self.pos_norm = nn.LayerNorm(pos_hid_dim)
-        self.use_var_lstm = use_var_lstm
-        if use_var_lstm:
-            self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
-                                hidden_size=rnn_hidden_size,
-                                num_layers=rnn_layers,
-                                bias=True,
-                                batch_first=True,
-                                input_dropout=dropout,
-                                hidden_dropout=dropout,
-                                bidirectional=True)
+        self.encoder_name = encoder
+        if encoder == 'var-lstm':
+            self.encoder = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
+                                   hidden_size=rnn_hidden_size,
+                                   num_layers=rnn_layers,
+                                   bias=True,
+                                   batch_first=True,
+                                   input_dropout=dropout,
+                                   hidden_dropout=dropout,
+                                   bidirectional=True)
+        elif encoder == 'lstm':
+            self.encoder = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
+                                   hidden_size=rnn_hidden_size,
+                                   num_layers=rnn_layers,
+                                   bias=True,
+                                   batch_first=True,
+                                   dropout=dropout,
+                                   bidirectional=True)
         else:
-            self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
-                                hidden_size=rnn_hidden_size,
-                                num_layers=rnn_layers,
-                                bias=True,
-                                batch_first=True,
-                                dropout=dropout,
-                                bidirectional=True)
-
-        self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
-                                          nn.LayerNorm(arc_mlp_size),
+            raise ValueError('unsupported encoder type: {}'.format(encoder))
+
+        self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2),
                                           nn.ELU(),
                                           TimestepDropout(p=dropout),)
-        self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
-        self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
-                                            nn.LayerNorm(label_mlp_size),
-                                            nn.ELU(),
-                                            TimestepDropout(p=dropout),)
-        self.label_dep_mlp = copy.deepcopy(self.label_head_mlp)
+        self.arc_mlp_size = arc_mlp_size
+        self.label_mlp_size = label_mlp_size
         self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
         self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
         self.use_greedy_infer = use_greedy_infer
@@ -286,24 +283,22 @@ class BiaffineParser(GraphParser):
         word, pos = self.word_fc(word), self.pos_fc(pos)
         word, pos = self.word_norm(word), self.pos_norm(pos)
         x = torch.cat([word, pos], dim=2)  # -> [N,L,C]
-        del word, pos

-        # lstm, extract features
+        # encoder, extract features
         sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
         x = x[sort_idx]
         x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
-        feat, _ = self.lstm(x)  # -> [N,L,C]
+        feat, _ = self.encoder(x)  # -> [N,L,C]
         feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
         _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
         feat = feat[unsort_idx]

         # for arc biaffine
         # mlp, reduce dim
-        arc_dep = self.arc_dep_mlp(feat)
-        arc_head = self.arc_head_mlp(feat)
-        label_dep = self.label_dep_mlp(feat)
-        label_head = self.label_head_mlp(feat)
-        del feat
+        feat = self.mlp(feat)
+        arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
+        arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz]
+        label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:]

         # biaffine arc classifier
         arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]
@@ -349,7 +344,7 @@ class BiaffineParser(GraphParser):
         batch_size, seq_len, _ = arc_pred.shape
         flip_mask = (mask == 0)
         _arc_pred = arc_pred.clone()
-        _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
+        _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))
         arc_logits = F.log_softmax(_arc_pred, dim=2)
         label_logits = F.log_softmax(label_pred, dim=2)
         batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
@@ -357,12 +352,11 @@ class BiaffineParser(GraphParser):
         arc_loss = arc_logits[batch_index, child_index, arc_true]
         label_loss = label_logits[batch_index, child_index, label_true]

-        arc_loss = arc_loss[:, 1:]
-        label_loss = label_loss[:, 1:]
-
-        float_mask = mask[:, 1:].float()
-        arc_nll = -(arc_loss*float_mask).mean()
-        label_nll = -(label_loss*float_mask).mean()
+        byte_mask = flip_mask.byte()
+        arc_loss.masked_fill_(byte_mask, 0)
+        label_loss.masked_fill_(byte_mask, 0)
+        arc_nll = -arc_loss.mean()
+        label_nll = -label_loss.mean()
         return arc_nll + label_nll

     def predict(self, word_seq, pos_seq, seq_lens):
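
The parser speed-up comes from fusing the four separate arc/label MLPs into a single linear layer whose output is sliced per role (the per-projection LayerNorm is dropped along the way), so one large matmul replaces four smaller ones per batch. A rough sketch of the idea with made-up sizes (not the parser's exact code):

    import torch
    from torch import nn

    rnn_out_size, arc_sz, label_sz = 512, 500, 100     # illustrative sizes
    feat = torch.randn(2, 10, rnn_out_size)            # [N, L, C] encoder output

    # one fused projection instead of four separate MLPs
    fused = nn.Sequential(nn.Linear(rnn_out_size, arc_sz * 2 + label_sz * 2), nn.ELU())
    out = fused(feat)
    arc_dep, arc_head = out[:, :, :arc_sz], out[:, :, arc_sz:2 * arc_sz]
    label_dep, label_head = out[:, :, 2 * arc_sz:2 * arc_sz + label_sz], out[:, :, 2 * arc_sz + label_sz:]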


fastNLP/modules/aggregator/attention.py  (+58, -23)

@@ -5,6 +5,7 @@ import torch.nn.functional as F
 from torch import nn

 from fastNLP.modules.utils import mask_softmax
+from fastNLP.modules.dropout import TimestepDropout


 class Attention(torch.nn.Module):
@@ -23,47 +24,81 @@


 class DotAtte(nn.Module):
-    def __init__(self, key_size, value_size):
+    def __init__(self, key_size, value_size, dropout=0.1):
         super(DotAtte, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
         self.scale = math.sqrt(key_size)
+        self.drop = nn.Dropout(dropout)
+        self.softmax = nn.Softmax(dim=2)

-    def forward(self, Q, K, V, seq_mask=None):
+    def forward(self, Q, K, V, mask_out=None):
         """

         :param Q: [batch, seq_len, key_size]
         :param K: [batch, seq_len, key_size]
         :param V: [batch, seq_len, value_size]
-        :param seq_mask: [batch, seq_len]
+        :param mask_out: [batch, seq_len]
         """
         output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
-        if seq_mask is not None:
-            output.masked_fill_(seq_mask.lt(1), -float('inf'))
-        output = nn.functional.softmax(output, dim=2)
+        if mask_out is not None:
+            output.masked_fill_(mask_out, -float('inf'))
+        output = self.softmax(output)
+        output = self.drop(output)
         return torch.matmul(output, V)


 class MultiHeadAtte(nn.Module):
-    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+    def __init__(self, model_size, key_size, value_size, num_head, dropout=0.1):
         super(MultiHeadAtte, self).__init__()
-        self.in_linear = nn.ModuleList()
-        for i in range(num_atte * 3):
-            out_feat = key_size if (i % 3) != 2 else value_size
-            self.in_linear.append(nn.Linear(input_size, out_feat))
-        self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)])
-        self.out_linear = nn.Linear(value_size * num_atte, output_size)
-
-    def forward(self, Q, K, V, seq_mask=None):
-        heads = []
-        for i in range(len(self.attes)):
-            j = i * 3
-            qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V)
-            headi = self.attes[i](qi, ki, vi, seq_mask)
-            heads.append(headi)
-        output = torch.cat(heads, dim=2)
-        return self.out_linear(output)
+        self.input_size = model_size
+        self.key_size = key_size
+        self.value_size = value_size
+        self.num_head = num_head
+
+        in_size = key_size * num_head
+        self.q_in = nn.Linear(model_size, in_size)
+        self.k_in = nn.Linear(model_size, in_size)
+        self.v_in = nn.Linear(model_size, in_size)
+        self.attention = DotAtte(key_size=key_size, value_size=value_size)
+        self.out = nn.Linear(value_size * num_head, model_size)
+        self.drop = TimestepDropout(dropout)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        sqrt = math.sqrt
+        nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
+        nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
+        nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size)))
+        nn.init.xavier_normal_(self.out.weight)
+
+    def forward(self, Q, K, V, atte_mask_out=None):
+        """
+
+        :param Q: [batch, seq_len, model_size]
+        :param K: [batch, seq_len, model_size]
+        :param V: [batch, seq_len, model_size]
+        :param seq_mask: [batch, seq_len]
+        """
+        batch, seq_len, _ = Q.size()
+        d_k, d_v, n_head = self.key_size, self.value_size, self.num_head
+        # input linear
+        q = self.q_in(Q).view(batch, seq_len, n_head, d_k)
+        k = self.k_in(K).view(batch, seq_len, n_head, d_k)
+        v = self.v_in(V).view(batch, seq_len, n_head, d_k)
+
+        # transpose q, k and v to do batch attention
+        q = q.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k)
+        k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k)
+        v = v.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_v)
+        if atte_mask_out is not None:
+            atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
+        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, seq_len, d_v)
+
+        # concat all heads, do output linear
+        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, seq_len, -1)
+        output = self.drop(self.out(atte))
+        return output


 class Bi_Attention(nn.Module):
     def __init__(self):
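
The rewritten MultiHeadAtte drops the per-head Python loop: Q, K and V for all heads are computed by three fused projections, and the head dimension is folded into the batch dimension so a single batched matmul attends over every head at once. A shape walk-through with made-up sizes (illustrative, not fastNLP code):

    import torch

    batch, seq_len, n_head, d_k = 2, 5, 4, 16
    q = torch.randn(batch, seq_len, n_head * d_k)   # output of the fused q_in projection
    q = q.view(batch, seq_len, n_head, d_k)         # split the last dimension into heads
    q = q.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k)
    print(q.shape)                                  # torch.Size([8, 5, 16]) -> [n_head * batch, seq_len, d_k]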


fastNLP/modules/encoder/transformer.py  (+31, -12)

@@ -1,29 +1,48 @@
+import torch
 from torch import nn

 from ..aggregator.attention import MultiHeadAtte
-from ..other_modules import LayerNormalization
+from ..dropout import TimestepDropout


 class TransformerEncoder(nn.Module):
     class SubLayer(nn.Module):
-        def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+        def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
             super(TransformerEncoder.SubLayer, self).__init__()
-            self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
-            self.norm1 = LayerNormalization(output_size)
-            self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
+            self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout)
+            self.norm1 = nn.LayerNorm(model_size)
+            self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                      nn.ReLU(),
-                                     nn.Linear(output_size, output_size))
-            self.norm2 = LayerNormalization(output_size)
+                                     nn.Linear(inner_size, model_size),
+                                     TimestepDropout(dropout),)
+            self.norm2 = nn.LayerNorm(model_size)

-        def forward(self, input, seq_mask):
-            attention = self.atte(input)
+        def forward(self, input, seq_mask=None, atte_mask_out=None):
+            """
+
+            :param input: [batch, seq_len, model_size]
+            :param seq_mask: [batch, seq_len]
+            :return: [batch, seq_len, model_size]
+            """
+            attention = self.atte(input, input, input, atte_mask_out)
             norm_atte = self.norm1(attention + input)
+            attention *= seq_mask
             output = self.ffn(norm_atte)
-            return self.norm2(output + norm_atte)
+            output = self.norm2(output + norm_atte)
+            output *= seq_mask
+            return output

     def __init__(self, num_layers, **kargs):
         super(TransformerEncoder, self).__init__()
-        self.layers = nn.Sequential(*[self.SubLayer(**kargs) for _ in range(num_layers)])
+        self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])

     def forward(self, x, seq_mask=None):
-        return self.layers(x, seq_mask)
+        output = x
+        if seq_mask is None:
+            atte_mask_out = None
+        else:
+            atte_mask_out = (seq_mask < 1)[:,None,:]
+            seq_mask = seq_mask[:,:,None]
+        for layer in self.layers:
+            output = layer(output, seq_mask, atte_mask_out)
+        return output
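
Based on the refactored signatures above (TransformerEncoder takes num_layers plus the SubLayer keyword arguments), a hedged usage sketch with illustrative sizes; the sequence mask is passed as a 0/1 float tensor because each sub-layer multiplies its output by it elementwise:

    import torch
    from fastNLP.modules.encoder.transformer import TransformerEncoder

    encoder = TransformerEncoder(num_layers=2, model_size=256, inner_size=1024,
                                 key_size=64, value_size=64, num_head=4, dropout=0.1)
    x = torch.randn(8, 20, 256)     # [batch, seq_len, model_size]
    seq_mask = torch.ones(8, 20)    # 1.0 = real token, 0.0 = padding
    out = encoder(x, seq_mask)      # [8, 20, 256]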

reproduction/Biaffine_parser/cfg.cfg  (+5, -6)

@@ -2,7 +2,8 @@
 n_epochs = 40
 batch_size = 32
 use_cuda = true
-validate_every = 500
+use_tqdm=true
+validate_every = -1
 use_golden_train=true

 [test]
@@ -19,15 +20,13 @@ word_vocab_size = -1
 word_emb_dim = 100
 pos_vocab_size = -1
 pos_emb_dim = 100
-word_hid_dim = 100
-pos_hid_dim = 100
 rnn_layers = 3
-rnn_hidden_size = 400
+rnn_hidden_size = 256
 arc_mlp_size = 500
 label_mlp_size = 100
 num_label = -1
-dropout = 0.33
-use_var_lstm=true
+dropout = 0.3
+encoder="transformer"
 use_greedy_infer=false

 [optim]


reproduction/Biaffine_parser/run.py  (+3, -2)

@@ -141,7 +141,7 @@ model_args['pos_vocab_size'] = len(pos_v)
 model_args['num_label'] = len(tag_v)

 model = BiaffineParser(**model_args.data)
-model.reset_parameters()
+print(model)

 word_idxp = IndexerProcessor(word_v, 'words', 'word_seq')
 pos_idxp = IndexerProcessor(pos_v, 'pos', 'pos_seq')
@@ -209,7 +209,8 @@ def save_pipe(path):
     pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p])
     pipe.add_processor(ModelProcessor(model=model, batch_size=32))
     pipe.add_processor(label_toword_p)
-    torch.save(pipe, os.path.join(path, 'pipe.pkl'))
+    os.makedirs(path, exist_ok=True)
+    torch.save({'pipeline': pipe}, os.path.join(path, 'pipe.pkl'))


 def test(path):
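
Since save_pipe now writes the pipeline wrapped in a dict under the 'pipeline' key, any loading code has to unwrap it accordingly. A minimal loading sketch (the directory name is a placeholder, not a path from this repo):

    import os
    import torch

    saved = torch.load(os.path.join('save_dir', 'pipe.pkl'))   # 'save_dir' stands in for the real save path
    pipe = saved['pipeline']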


test/models/test_biaffine_parser.py  (+2, -1)

@@ -77,9 +77,10 @@ class TestBiaffineParser(unittest.TestCase):
         ds, v1, v2, v3 = init_data()
         model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30,
                                pos_vocab_size=len(v2), pos_emb_dim=30,
-                               num_label=len(v3), use_var_lstm=True)
+                               num_label=len(v3), encoder='var-lstm')
         trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds,
                                   loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
+                                  batch_size=1, validate_every=10,
                                   n_epochs=10, use_cuda=False, use_tqdm=False)
         trainer.train(load_best_model=False)



