refine & fix Transformer Encoder; refine & speed up biaffine parser (tags/v0.3.1^2)
@@ -281,7 +281,7 @@ class Trainer(object):
                 self.callback_manager.after_batch()

                 if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
-                    (self.validate_every < 0 and self.step % len(data_iterator)) == 0) \
+                    (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                         and self.dev_data is not None:
                     eval_res = self._do_validation(epoch=epoch, step=self.step)
                     eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
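This one-paren change fixes an operator-precedence bug: in the old line, `== 0` applied to the whole `and` expression, and since `False == 0` is `True` in Python, the clause fired on every step whenever `validate_every > 0`. With the fix, a negative `validate_every` (as the config change below sets) means "validate once per epoch", i.e. whenever `step` is a multiple of the number of batches. A minimal repro of the precedence issue, with made-up numbers:

    validate_every, step, n_batches = 500, 1, 1000

    # old grouping: `== 0` compares the whole `and` expression; since
    # `validate_every < 0` is False and `False == 0` is True, this is
    # True on every single step
    buggy = (validate_every < 0 and step % n_batches) == 0

    # fixed grouping: `== 0` binds to the modulo, so this only fires
    # when a full pass over the data iterator completes
    fixed = validate_every < 0 and step % n_batches == 0

    assert buggy is True and fixed is False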
@@ -6,6 +6,7 @@ from torch import nn
 from torch.nn import functional as F
 from fastNLP.modules.utils import initial_parameter
 from fastNLP.modules.encoder.variational_rnn import VarLSTM
+from fastNLP.modules.encoder.transformer import TransformerEncoder
 from fastNLP.modules.dropout import TimestepDropout
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules.utils import seq_mask
@@ -197,53 +198,49 @@ class BiaffineParser(GraphParser):
                  pos_vocab_size,
                  pos_emb_dim,
                  num_label,
-                 word_hid_dim=100,
-                 pos_hid_dim=100,
                  rnn_layers=1,
                  rnn_hidden_size=200,
                  arc_mlp_size=100,
                  label_mlp_size=100,
                  dropout=0.3,
-                 use_var_lstm=False,
+                 encoder='lstm',
                  use_greedy_infer=False):
         super(BiaffineParser, self).__init__()
         rnn_out_size = 2 * rnn_hidden_size
+        word_hid_dim = pos_hid_dim = rnn_hidden_size
         self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
         self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
         self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
         self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
         self.word_norm = nn.LayerNorm(word_hid_dim)
         self.pos_norm = nn.LayerNorm(pos_hid_dim)
-        self.use_var_lstm = use_var_lstm
-        if use_var_lstm:
-            self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
-                                hidden_size=rnn_hidden_size,
-                                num_layers=rnn_layers,
-                                bias=True,
-                                batch_first=True,
-                                input_dropout=dropout,
-                                hidden_dropout=dropout,
-                                bidirectional=True)
+        self.encoder_name = encoder
+        if encoder == 'var-lstm':
+            self.encoder = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
+                                   hidden_size=rnn_hidden_size,
+                                   num_layers=rnn_layers,
+                                   bias=True,
+                                   batch_first=True,
+                                   input_dropout=dropout,
+                                   hidden_dropout=dropout,
+                                   bidirectional=True)
+        elif encoder == 'lstm':
+            self.encoder = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
+                                   hidden_size=rnn_hidden_size,
+                                   num_layers=rnn_layers,
+                                   bias=True,
+                                   batch_first=True,
+                                   dropout=dropout,
+                                   bidirectional=True)
         else:
-            self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
-                                hidden_size=rnn_hidden_size,
-                                num_layers=rnn_layers,
-                                bias=True,
-                                batch_first=True,
-                                dropout=dropout,
-                                bidirectional=True)
-
-        self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
-                                          nn.LayerNorm(arc_mlp_size),
+            raise ValueError('unsupported encoder type: {}'.format(encoder))
+        self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2),
                                  nn.ELU(),
                                  TimestepDropout(p=dropout),)
-        self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
-        self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
-                                            nn.LayerNorm(label_mlp_size),
-                                            nn.ELU(),
-                                            TimestepDropout(p=dropout),)
-        self.label_dep_mlp = copy.deepcopy(self.label_head_mlp)
+        self.arc_mlp_size = arc_mlp_size
+        self.label_mlp_size = label_mlp_size
         self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
         self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
         self.use_greedy_infer = use_greedy_infer
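With this refactor the encoder is selected by name instead of a `use_var_lstm` flag, and `word_hid_dim`/`pos_hid_dim` are tied to `rnn_hidden_size` rather than passed in. A minimal construction sketch under the new signature; all sizes here are placeholders, not values from the diff, and note that the config below selects `encoder="transformer"`, a branch this hunk does not show:

    from fastNLP.models.biaffine_parser import BiaffineParser

    # placeholder sizes, for illustration only
    model = BiaffineParser(word_vocab_size=10000, word_emb_dim=100,
                           pos_vocab_size=50, pos_emb_dim=100,
                           num_label=40,
                           rnn_layers=3, rnn_hidden_size=256,
                           arc_mlp_size=500, label_mlp_size=100,
                           dropout=0.3,
                           encoder='var-lstm',  # 'lstm' and 'var-lstm' are the branches shown here
                           use_greedy_infer=False)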
@@ -286,24 +283,22 @@ class BiaffineParser(GraphParser):
         word, pos = self.word_fc(word), self.pos_fc(pos)
         word, pos = self.word_norm(word), self.pos_norm(pos)
         x = torch.cat([word, pos], dim=2)  # -> [N,L,C]
-        del word, pos

-        # lstm, extract features
+        # encoder, extract features
         sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
         x = x[sort_idx]
         x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
-        feat, _ = self.lstm(x)  # -> [N,L,C]
+        feat, _ = self.encoder(x)  # -> [N,L,C]
         feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
         _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
         feat = feat[unsort_idx]

         # for arc biaffine
         # mlp, reduce dim
-        arc_dep = self.arc_dep_mlp(feat)
-        arc_head = self.arc_head_mlp(feat)
-        label_dep = self.label_dep_mlp(feat)
-        label_head = self.label_head_mlp(feat)
-        del feat
+        feat = self.mlp(feat)
+        arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
+        arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz]
+        label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:]

         # biaffine arc classifier
         arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]
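The four per-role MLPs (`arc_dep`, `arc_head`, `label_dep`, `label_head`) are fused into one wide projection whose output is sliced, so a single large matmul replaces four small ones; this is likely the main "speed up" from the commit title. A self-contained sketch of the slicing trick, with toy sizes not taken from the diff:

    import torch
    from torch import nn

    # toy sizes, for illustration only
    N, L, C = 2, 5, 512          # batch, seq_len, encoder output size
    arc_sz, label_sz = 500, 100

    # one wide projection stands in for four separate (C -> size) MLPs
    mlp = nn.Sequential(nn.Linear(C, arc_sz * 2 + label_sz * 2), nn.ELU())

    feat = mlp(torch.randn(N, L, C))
    arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz]
    label_dep = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz]
    label_head = feat[:, :, 2 * arc_sz + label_sz:]
    assert label_head.shape == (N, L, label_sz)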
@@ -349,7 +344,7 @@ class BiaffineParser(GraphParser):
         batch_size, seq_len, _ = arc_pred.shape
         flip_mask = (mask == 0)
         _arc_pred = arc_pred.clone()
-        _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
+        _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))
         arc_logits = F.log_softmax(_arc_pred, dim=2)
         label_logits = F.log_softmax(label_pred, dim=2)
         batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
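Swapping `-np.inf` for `-float('inf')` drops a NumPy dependency from this code path without changing behaviour; the two are the same IEEE-754 value, as a one-line check confirms:

    import numpy as np
    assert -float('inf') == -np.inf  # identical value; the built-in avoids the numpy import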
@@ -357,12 +352,11 @@ class BiaffineParser(GraphParser):
         arc_loss = arc_logits[batch_index, child_index, arc_true]
         label_loss = label_logits[batch_index, child_index, label_true]

-        arc_loss = arc_loss[:, 1:]
-        label_loss = label_loss[:, 1:]
-
-        float_mask = mask[:, 1:].float()
-        arc_nll = -(arc_loss*float_mask).mean()
-        label_nll = -(label_loss*float_mask).mean()
+        byte_mask = flip_mask.byte()
+        arc_loss.masked_fill_(byte_mask, 0)
+        label_loss.masked_fill_(byte_mask, 0)
+        arc_nll = -arc_loss.mean()
+        label_nll = -label_loss.mean()
         return arc_nll + label_nll

     def predict(self, word_seq, pos_seq, seq_lens):
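The rewritten loss replaces the `[:, 1:]` slice and the float-mask multiply with an in-place `masked_fill_` followed by a plain `.mean()`, saving intermediate tensor copies per call. Zeroing via `masked_fill` and zeroing via mask multiplication give the same mean over the same tensor, as the toy check below shows; note, though, that the new code no longer slices off position 0, so the averaging differs slightly from the old version unless the mask already zeroes the root position:

    import torch

    # toy values, for illustration only
    loss = torch.tensor([[0.5, 0.7, 0.9],
                         [0.4, 0.6, 0.8]])
    mask = torch.tensor([[1, 1, 1],
                         [1, 1, 0]])   # last position of row 2 is padding
    flip_mask = (mask == 0)

    old_style = (loss * mask.float()).mean()           # multiply, then mean
    new_style = loss.masked_fill(flip_mask, 0).mean()  # fill, then mean
    assert torch.allclose(old_style, new_style)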
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 from torch import nn

 from fastNLP.modules.utils import mask_softmax
+from fastNLP.modules.dropout import TimestepDropout


 class Attention(torch.nn.Module):
@@ -23,47 +24,81 @@ class Attention(torch.nn.Module):

 class DotAtte(nn.Module):
-    def __init__(self, key_size, value_size):
+    def __init__(self, key_size, value_size, dropout=0.1):
         super(DotAtte, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
         self.scale = math.sqrt(key_size)
+        self.drop = nn.Dropout(dropout)
+        self.softmax = nn.Softmax(dim=2)

-    def forward(self, Q, K, V, seq_mask=None):
+    def forward(self, Q, K, V, mask_out=None):
         """
         :param Q: [batch, seq_len, key_size]
         :param K: [batch, seq_len, key_size]
         :param V: [batch, seq_len, value_size]
-        :param seq_mask: [batch, seq_len]
+        :param mask_out: broadcastable to [batch, seq_len, seq_len]; True at positions to mask out
         """
         output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
-        if seq_mask is not None:
-            output.masked_fill_(seq_mask.lt(1), -float('inf'))
-        output = nn.functional.softmax(output, dim=2)
+        if mask_out is not None:
+            output.masked_fill_(mask_out, -float('inf'))
+        output = self.softmax(output)
+        output = self.drop(output)
         return torch.matmul(output, V)


 class MultiHeadAtte(nn.Module):
-    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+    def __init__(self, model_size, key_size, value_size, num_head, dropout=0.1):
         super(MultiHeadAtte, self).__init__()
-        self.in_linear = nn.ModuleList()
-        for i in range(num_atte * 3):
-            out_feat = key_size if (i % 3) != 2 else value_size
-            self.in_linear.append(nn.Linear(input_size, out_feat))
-        self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)])
-        self.out_linear = nn.Linear(value_size * num_atte, output_size)
-
-    def forward(self, Q, K, V, seq_mask=None):
-        heads = []
-        for i in range(len(self.attes)):
-            j = i * 3
-            qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V)
-            headi = self.attes[i](qi, ki, vi, seq_mask)
-            heads.append(headi)
-        output = torch.cat(heads, dim=2)
-        return self.out_linear(output)
+        self.input_size = model_size
+        self.key_size = key_size
+        self.value_size = value_size
+        self.num_head = num_head
+        self.q_in = nn.Linear(model_size, key_size * num_head)
+        self.k_in = nn.Linear(model_size, key_size * num_head)
+        # the value projection must use value_size, not key_size
+        self.v_in = nn.Linear(model_size, value_size * num_head)
+        self.attention = DotAtte(key_size=key_size, value_size=value_size)
+        self.out = nn.Linear(value_size * num_head, model_size)
+        self.drop = TimestepDropout(dropout)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        sqrt = math.sqrt
+        nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
+        nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
+        nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size)))
+        nn.init.xavier_normal_(self.out.weight)
+
+    def forward(self, Q, K, V, atte_mask_out=None):
+        """
+        :param Q: [batch, seq_len, model_size]
+        :param K: [batch, seq_len, model_size]
+        :param V: [batch, seq_len, model_size]
+        :param atte_mask_out: [batch, 1, seq_len]; True at positions to mask out in every head
+        """
+        batch, seq_len, _ = Q.size()
+        d_k, d_v, n_head = self.key_size, self.value_size, self.num_head
+        # input linear
+        q = self.q_in(Q).view(batch, seq_len, n_head, d_k)
+        k = self.k_in(K).view(batch, seq_len, n_head, d_k)
+        v = self.v_in(V).view(batch, seq_len, n_head, d_v)
+        # transpose q, k and v to do batch attention
+        q = q.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k)
+        k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k)
+        v = v.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_v)
+        if atte_mask_out is not None:
+            atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
+        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, seq_len, d_v)
+        # concat all heads, do output linear
+        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, seq_len, -1)
+        output = self.drop(self.out(atte))
+        return output


 class Bi_Attention(nn.Module):
     def __init__(self):
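The rewrite computes all heads in one batched `DotAtte` call over `[n_head * batch, seq_len, d]` tensors instead of looping over per-head input linears, which is where most of the speed-up in this file comes from. A quick shape walk-through, assuming the `MultiHeadAtte` above is importable and using hypothetical sizes (model_size 512 split across 8 heads of 64):

    import torch

    atte = MultiHeadAtte(model_size=512, key_size=64, value_size=64, num_head=8)
    x = torch.randn(2, 10, 512)  # [batch, seq_len, model_size]
    # broadcastable mask over key positions; all-False here, i.e. nothing masked
    mask_out = torch.zeros(2, 1, 10, dtype=torch.bool)
    out = atte(x, x, x, mask_out)
    assert out.shape == (2, 10, 512)  # heads are merged back to model_size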
@@ -1,29 +1,48 @@
+import torch
 from torch import nn

 from ..aggregator.attention import MultiHeadAtte
-from ..other_modules import LayerNormalization
+from ..dropout import TimestepDropout


 class TransformerEncoder(nn.Module):
     class SubLayer(nn.Module):
-        def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+        def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
             super(TransformerEncoder.SubLayer, self).__init__()
-            self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
-            self.norm1 = LayerNormalization(output_size)
-            self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
+            self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout)
+            self.norm1 = nn.LayerNorm(model_size)
+            self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                      nn.ReLU(),
-                                     nn.Linear(output_size, output_size))
-            self.norm2 = LayerNormalization(output_size)
+                                     nn.Linear(inner_size, model_size),
+                                     TimestepDropout(dropout),)
+            self.norm2 = nn.LayerNorm(model_size)

-        def forward(self, input, seq_mask):
-            attention = self.atte(input)
+        def forward(self, input, seq_mask=None, atte_mask_out=None):
+            """
+            :param input: [batch, seq_len, model_size]
+            :param seq_mask: [batch, seq_len, 1], zero at padded positions
+            :return: [batch, seq_len, model_size]
+            """
+            attention = self.atte(input, input, input, atte_mask_out)
             norm_atte = self.norm1(attention + input)
             output = self.ffn(norm_atte)
-            return self.norm2(output + norm_atte)
+            output = self.norm2(output + norm_atte)
+            if seq_mask is not None:
+                # zero out the outputs at padded positions
+                output *= seq_mask
+            return output

     def __init__(self, num_layers, **kargs):
         super(TransformerEncoder, self).__init__()
-        self.layers = nn.Sequential(*[self.SubLayer(**kargs) for _ in range(num_layers)])
+        self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])

     def forward(self, x, seq_mask=None):
-        return self.layers(x, seq_mask)
+        output = x
+        if seq_mask is None:
+            atte_mask_out = None
+        else:
+            atte_mask_out = (seq_mask < 1)[:, None, :]
+            seq_mask = seq_mask[:, :, None]
+        for layer in self.layers:
+            output = layer(output, seq_mask, atte_mask_out)
+        return output
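End-to-end, the encoder now derives two masks from a single `[batch, seq_len]` `seq_mask`: a `[batch, 1, seq_len]` attention mask that blanks padded keys, and a `[batch, seq_len, 1]` multiplier that zeroes padded outputs. The old `nn.Sequential` also could not thread the mask through each layer, which the `nn.ModuleList` loop fixes. A minimal usage sketch with hypothetical hyper-parameters:

    import torch

    encoder = TransformerEncoder(num_layers=2,
                                 model_size=512, inner_size=1024,
                                 key_size=64, value_size=64, num_head=8)
    x = torch.randn(2, 10, 512)   # [batch, seq_len, model_size]
    seq_mask = torch.ones(2, 10)  # 1 = real token, 0 = padding
    out = encoder(x, seq_mask)
    assert out.shape == (2, 10, 512)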
@@ -2,7 +2,8 @@
 n_epochs = 40
 batch_size = 32
 use_cuda = true
-validate_every = 500
+use_tqdm=true
+validate_every = -1
 use_golden_train=true

 [test]
@@ -19,15 +20,13 @@ word_vocab_size = -1
 word_emb_dim = 100
 pos_vocab_size = -1
 pos_emb_dim = 100
-word_hid_dim = 100
-pos_hid_dim = 100
 rnn_layers = 3
-rnn_hidden_size = 400
+rnn_hidden_size = 256
 arc_mlp_size = 500
 label_mlp_size = 100
 num_label = -1
-dropout = 0.33
-use_var_lstm=true
+dropout = 0.3
+encoder="transformer"
 use_greedy_infer=false

 [optim]
@@ -141,7 +141,7 @@ model_args['pos_vocab_size'] = len(pos_v)
 model_args['num_label'] = len(tag_v)

 model = BiaffineParser(**model_args.data)
-model.reset_parameters()
+print(model)

 word_idxp = IndexerProcessor(word_v, 'words', 'word_seq')
 pos_idxp = IndexerProcessor(pos_v, 'pos', 'pos_seq')
@@ -209,7 +209,8 @@ def save_pipe(path):
     pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p])
     pipe.add_processor(ModelProcessor(model=model, batch_size=32))
     pipe.add_processor(label_toword_p)
-    torch.save(pipe, os.path.join(path, 'pipe.pkl'))
+    os.makedirs(path, exist_ok=True)
+    torch.save({'pipeline': pipe}, os.path.join(path, 'pipe.pkl'))


 def test(path):
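The pipeline is now saved inside a dict under a 'pipeline' key, and the target directory is created on demand. The matching loader therefore gains one lookup; a minimal sketch of the counterpart (the function name is hypothetical, not from the diff):

    import os
    import torch

    def load_pipe(path):
        # counterpart of save_pipe above: unwrap the dict saved under 'pipeline'
        return torch.load(os.path.join(path, 'pipe.pkl'))['pipeline']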
@@ -77,9 +77,10 @@ class TestBiaffineParser(unittest.TestCase):
         ds, v1, v2, v3 = init_data()
         model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30,
                                pos_vocab_size=len(v2), pos_emb_dim=30,
-                               num_label=len(v3), use_var_lstm=True)
+                               num_label=len(v3), encoder='var-lstm')
         trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds,
                                   loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
+                                  batch_size=1, validate_every=10,
                                   n_epochs=10, use_cuda=False, use_tqdm=False)
         trainer.train(load_best_model=False)