@@ -35,23 +35,21 @@ class SeqLabelEvaluator(Evaluator): | |||
def __init__(self): | |||
super(SeqLabelEvaluator, self).__init__() | |||
def __call__(self, predict, truth): | |||
def __call__(self, predict, truth, **_): | |||
""" | |||
:param predict: list of dict, the network outputs from all batches. | |||
:param truth: list of dict, the ground truths from all batch_y. | |||
:return accuracy: | |||
""" | |||
truth = [item["truth"] for item in truth] | |||
predict = [item["predict"] for item in predict] | |||
total_correct, total_count = 0., 0. | |||
total_correct, total_count = 0., 0. | |||
for x, y in zip(predict, truth): | |||
# x = torch.tensor(x) | |||
y = y.to(x) # make sure they are in the same device | |||
mask = x.ge(1).long() | |||
correct = torch.sum(x * mask == y * mask) - torch.sum(x.le(0)) | |||
mask = (y > 0) | |||
correct = torch.sum(((x == y) * mask).long()) | |||
total_correct += float(correct) | |||
total_count += float(torch.sum(mask)) | |||
total_count += float(torch.sum(mask.long())) | |||
accuracy = total_correct / total_count | |||
return {"accuracy": float(accuracy)} | |||
@@ -1,4 +1,5 @@ | |||
import torch | |||
from collections import defaultdict | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.metrics import Evaluator | |||
@@ -71,17 +72,18 @@ class Tester(object): | |||
# turn on the testing mode; clean up the history | |||
self.mode(network, is_test=True) | |||
self.eval_history.clear() | |||
output_list = [] | |||
truth_list = [] | |||
output, truths = defaultdict(list), defaultdict(list) | |||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) | |||
with torch.no_grad(): | |||
for batch_x, batch_y in data_iterator: | |||
prediction = self.data_forward(network, batch_x) | |||
output_list.append(prediction) | |||
truth_list.append(batch_y) | |||
eval_results = self.evaluate(output_list, truth_list) | |||
assert isinstance(prediction, dict) | |||
for k, v in prediction.items(): | |||
output[k].append(v) | |||
for k, v in batch_y.items(): | |||
truths[k].append(v) | |||
eval_results = self.evaluate(**output, **truths) | |||
print("[tester] {}".format(self.print_eval_results(eval_results))) | |||
logger.info("[tester] {}".format(self.print_eval_results(eval_results))) | |||
self.mode(network, is_test=False) | |||
@@ -105,14 +107,10 @@ class Tester(object): | |||
y = network(**x) | |||
return y | |||
def evaluate(self, predict, truth): | |||
def evaluate(self, **kwargs): | |||
"""Compute evaluation metrics. | |||
:param predict: list of Tensor | |||
:param truth: list of dict | |||
:return eval_results: can be anything. It will be stored in self.eval_history | |||
""" | |||
return self._evaluator(predict, truth) | |||
return self._evaluator(**kwargs) | |||
def print_eval_results(self, results): | |||
"""Override this method to support more print formats. | |||
@@ -47,7 +47,8 @@ class Trainer(object): | |||
"valid_step": 500, "eval_sort_key": 'acc', | |||
"loss": Loss(None), # used to pass type check | |||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | |||
"evaluator": Evaluator() | |||
"eval_batch_size": 64, | |||
"evaluator": Evaluator(), | |||
} | |||
""" | |||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||
@@ -78,6 +79,7 @@ class Trainer(object): | |||
self.n_epochs = int(default_args["epochs"]) | |||
self.batch_size = int(default_args["batch_size"]) | |||
self.eval_batch_size = int(default_args['eval_batch_size']) | |||
self.pickle_path = default_args["pickle_path"] | |||
self.validate = default_args["validate"] | |||
self.save_best_dev = default_args["save_best_dev"] | |||
@@ -98,6 +100,8 @@ class Trainer(object): | |||
self._best_accuracy = 0.0 | |||
self.eval_sort_key = default_args['eval_sort_key'] | |||
self.validator = None | |||
self.epoch = 0 | |||
self.step = 0 | |||
def train(self, network, train_data, dev_data=None): | |||
"""General Training Procedure | |||
@@ -118,7 +122,7 @@ class Trainer(object): | |||
# define Tester over dev data | |||
self.dev_data = None | |||
if self.validate: | |||
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, | |||
default_valid_args = {"batch_size": self.eval_batch_size, "pickle_path": self.pickle_path, | |||
"use_cuda": self.use_cuda, "evaluator": self._evaluator} | |||
if self.validator is None: | |||
self.validator = self._create_validator(default_valid_args) | |||
@@ -139,9 +143,9 @@ class Trainer(object): | |||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||
print("training epochs started " + self.start_time) | |||
logger.info("training epochs started " + self.start_time) | |||
epoch, iters = 1, 0 | |||
while epoch <= self.n_epochs: | |||
logger.info("training epoch {}".format(epoch)) | |||
self.epoch, self.step = 1, 0 | |||
while self.epoch <= self.n_epochs: | |||
logger.info("training epoch {}".format(self.epoch)) | |||
# prepare mini-batch iterator | |||
data_iterator = Batch(train_data, batch_size=self.batch_size, | |||
@@ -150,14 +154,13 @@ class Trainer(object): | |||
logger.info("prepared data iterator") | |||
# one forward and backward pass | |||
iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, | |||
step=iters, dev_data=dev_data) | |||
self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, dev_data=dev_data) | |||
# validation | |||
if self.validate: | |||
self.valid_model() | |||
self.save_model(self._model, 'training_model_' + self.start_time) | |||
epoch += 1 | |||
self.epoch += 1 | |||
def _train_step(self, data_iterator, network, **kwargs): | |||
"""Training process in one epoch. | |||
@@ -167,7 +170,6 @@ class Trainer(object): | |||
- start: time.time(), the starting time of this step. | |||
- epoch: int, | |||
""" | |||
step = kwargs['step'] | |||
for batch_x, batch_y in data_iterator: | |||
prediction = self.data_forward(network, batch_x) | |||
@@ -177,25 +179,31 @@ class Trainer(object): | |||
self.grad_backward(loss) | |||
self.update() | |||
self._summary_writer.add_scalar("loss", loss.item(), global_step=step) | |||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||
for name, param in self._model.named_parameters(): | |||
if param.requires_grad: | |||
<<<<<<< HEAD | |||
# self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=step) | |||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=step) | |||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=step) | |||
pass | |||
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: | |||
======= | |||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) | |||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | |||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | |||
if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0: | |||
>>>>>>> 5924fe0... fix and update tester, trainer, seq_model, add parser pipeline builder | |||
end = time.time() | |||
diff = timedelta(seconds=round(end - kwargs["start"])) | |||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||
kwargs["epoch"], step, loss.data, diff) | |||
self.epoch, self.step, loss.data, diff) | |||
print(print_output) | |||
logger.info(print_output) | |||
if self.validate and self.valid_step > 0 and step > 0 and step % self.valid_step == 0: | |||
if self.validate and self.valid_step > 0 and self.step > 0 and self.step % self.valid_step == 0: | |||
self.valid_model() | |||
step += 1 | |||
return step | |||
self.step += 1 | |||
def valid_model(self): | |||
if self.dev_data is None: | |||
@@ -203,6 +211,8 @@ class Trainer(object): | |||
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.") | |||
logger.info("validation started") | |||
res = self.validator.test(self._model, self.dev_data) | |||
for name, num in res.items(): | |||
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step) | |||
if self.save_best_dev and self.best_eval_result(res): | |||
logger.info('save best result! {}'.format(res)) | |||
print('save best result! {}'.format(res)) | |||
@@ -10,6 +10,7 @@ from fastNLP.modules.utils import initial_parameter | |||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||
from fastNLP.modules.dropout import TimestepDropout | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.modules.utils import seq_mask | |||
def mst(scores): | |||
""" | |||
@@ -123,31 +124,31 @@ class GraphParser(BaseModel): | |||
def forward(self, x): | |||
raise NotImplementedError | |||
def _greedy_decoder(self, arc_matrix, seq_mask=None): | |||
def _greedy_decoder(self, arc_matrix, mask=None): | |||
_, seq_len, _ = arc_matrix.shape | |||
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) | |||
flip_mask = (seq_mask == 0).byte() | |||
flip_mask = (mask == 0).byte() | |||
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||
_, heads = torch.max(matrix, dim=2) | |||
if seq_mask is not None: | |||
heads *= seq_mask.long() | |||
if mask is not None: | |||
heads *= mask.long() | |||
return heads | |||
def _mst_decoder(self, arc_matrix, seq_mask=None): | |||
def _mst_decoder(self, arc_matrix, mask=None): | |||
batch_size, seq_len, _ = arc_matrix.shape | |||
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | |||
ans = matrix.new_zeros(batch_size, seq_len).long() | |||
lens = (seq_mask.long()).sum(1) if seq_mask is not None else torch.zeros(batch_size) + seq_len | |||
lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device) | |||
seq_mask[batch_idx, lens-1] = 0 | |||
mask[batch_idx, lens-1] = 0 | |||
for i, graph in enumerate(matrix): | |||
len_i = lens[i] | |||
if len_i == seq_len: | |||
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) | |||
else: | |||
ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device) | |||
if seq_mask is not None: | |||
ans *= seq_mask.long() | |||
if mask is not None: | |||
ans *= mask.long() | |||
return ans | |||
@@ -191,13 +192,6 @@ class LabelBilinear(nn.Module): | |||
output += self.lin(torch.cat([x1, x2], dim=2)) | |||
return output | |||
def len2masks(origin_len, max_len): | |||
if origin_len.dim() <= 1: | |||
origin_len = origin_len.unsqueeze(1) # [batch_size, 1] | |||
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=origin_len.device) # [max_len,] | |||
seq_mask = torch.gt(origin_len, seq_range.unsqueeze(0)) # [batch_size, max_len] | |||
return seq_mask | |||
class BiaffineParser(GraphParser): | |||
"""Biaffine Dependency Parser implemantation. | |||
refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) | |||
@@ -277,12 +271,12 @@ class BiaffineParser(GraphParser): | |||
""" | |||
:param word_seq: [batch_size, seq_len] sequence of word's indices | |||
:param pos_seq: [batch_size, seq_len] sequence of word's indices | |||
:param seq_mask: [batch_size, seq_len] sequence of length masks | |||
:param word_seq_origin_len: [batch_size, seq_len] sequence of length masks | |||
:param gold_heads: [batch_size, seq_len] sequence of golden heads | |||
:return dict: parsing results | |||
arc_pred: [batch_size, seq_len, seq_len] | |||
label_pred: [batch_size, seq_len, seq_len] | |||
seq_mask: [batch_size, seq_len] | |||
mask: [batch_size, seq_len] | |||
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | |||
""" | |||
# prepare embeddings | |||
@@ -294,7 +288,7 @@ class BiaffineParser(GraphParser): | |||
# print('forward {} {}'.format(batch_size, seq_len)) | |||
# get sequence mask | |||
seq_mask = len2masks(word_seq_origin_len, seq_len).long() | |||
mask = seq_mask(word_seq_origin_len, seq_len).long() | |||
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | |||
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | |||
@@ -327,14 +321,14 @@ class BiaffineParser(GraphParser): | |||
if gold_heads is None or not self.training: | |||
# use greedy decoding in training | |||
if self.training or self.use_greedy_infer: | |||
heads = self._greedy_decoder(arc_pred, seq_mask) | |||
heads = self._greedy_decoder(arc_pred, mask) | |||
else: | |||
heads = self._mst_decoder(arc_pred, seq_mask) | |||
heads = self._mst_decoder(arc_pred, mask) | |||
head_pred = heads | |||
else: | |||
assert self.training # must be training mode | |||
if torch.rand(1).item() < self.explore_p: | |||
heads = self._greedy_decoder(arc_pred, seq_mask) | |||
heads = self._greedy_decoder(arc_pred, mask) | |||
head_pred = heads | |||
else: | |||
head_pred = None | |||
@@ -343,12 +337,12 @@ class BiaffineParser(GraphParser): | |||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) | |||
label_head = label_head[batch_range, heads].contiguous() | |||
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | |||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask} | |||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} | |||
if head_pred is not None: | |||
res_dict['head_pred'] = head_pred | |||
return res_dict | |||
def loss(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||
def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_): | |||
""" | |||
Compute loss. | |||
@@ -356,12 +350,12 @@ class BiaffineParser(GraphParser): | |||
:param label_pred: [batch_size, seq_len, n_tags] | |||
:param head_indices: [batch_size, seq_len] | |||
:param head_labels: [batch_size, seq_len] | |||
:param seq_mask: [batch_size, seq_len] | |||
:param mask: [batch_size, seq_len] | |||
:return: loss value | |||
""" | |||
batch_size, seq_len, _ = arc_pred.shape | |||
flip_mask = (seq_mask == 0) | |||
flip_mask = (mask == 0) | |||
_arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred) | |||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||
arc_logits = F.log_softmax(_arc_pred, dim=2) | |||
@@ -374,7 +368,7 @@ class BiaffineParser(GraphParser): | |||
arc_loss = arc_loss[:, 1:] | |||
label_loss = label_loss[:, 1:] | |||
float_mask = seq_mask[:, 1:].float() | |||
float_mask = mask[:, 1:].float() | |||
arc_nll = -(arc_loss*float_mask).mean() | |||
label_nll = -(label_loss*float_mask).mean() | |||
return arc_nll + label_nll | |||
@@ -4,20 +4,7 @@ import numpy as np | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.modules import decoder, encoder | |||
def seq_mask(seq_len, max_len): | |||
"""Create a mask for the sequences. | |||
:param seq_len: list or torch.LongTensor | |||
:param max_len: int | |||
:return mask: torch.LongTensor | |||
""" | |||
if isinstance(seq_len, list): | |||
seq_len = torch.LongTensor(seq_len) | |||
mask = [torch.ge(seq_len, i + 1) for i in range(max_len)] | |||
mask = torch.stack(mask, 1) | |||
return mask | |||
from fastNLP.modules.utils import seq_mask | |||
class SeqLabeling(BaseModel): | |||
@@ -82,7 +69,7 @@ class SeqLabeling(BaseModel): | |||
def make_mask(self, x, seq_len): | |||
batch_size, max_len = x.size(0), x.size(1) | |||
mask = seq_mask(seq_len, max_len) | |||
mask = mask.byte().view(batch_size, max_len) | |||
mask = mask.view(batch_size, max_len) | |||
mask = mask.to(x).float() | |||
return mask | |||
@@ -114,16 +101,20 @@ class AdvSeqLabel(SeqLabeling): | |||
word_emb_dim = args["word_emb_dim"] | |||
hidden_dim = args["rnn_hidden_units"] | |||
num_classes = args["num_classes"] | |||
dropout = args['dropout'] | |||
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | |||
self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=1, dropout=0.5, bidirectional=True) | |||
self.norm1 = torch.nn.LayerNorm(word_emb_dim) | |||
# self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) | |||
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True) | |||
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | |||
self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | |||
self.relu = torch.nn.ReLU() | |||
self.drop = torch.nn.Dropout(0.5) | |||
self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) | |||
# self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | |||
self.relu = torch.nn.LeakyReLU() | |||
self.drop = torch.nn.Dropout(dropout) | |||
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | |||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | |||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||
def forward(self, word_seq, word_seq_origin_len, truth=None): | |||
""" | |||
@@ -135,12 +126,10 @@ class AdvSeqLabel(SeqLabeling): | |||
""" | |||
word_seq = word_seq.long() | |||
word_seq_origin_len = word_seq_origin_len.long() | |||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||
word_seq_origin_len = word_seq_origin_len.cpu().numpy() | |||
sent_len, idx_sort = np.sort(word_seq_origin_len)[::-1], np.argsort(-word_seq_origin_len) | |||
idx_unsort = np.argsort(idx_sort) | |||
idx_sort = torch.from_numpy(idx_sort) | |||
idx_unsort = torch.from_numpy(idx_unsort) | |||
sent_len, idx_sort = torch.sort(word_seq_origin_len, descending=True) | |||
_, idx_unsort = torch.sort(idx_sort, descending=False) | |||
# word_seq_origin_len = word_seq_origin_len.long() | |||
truth = truth.long() if truth is not None else None | |||
@@ -155,26 +144,28 @@ class AdvSeqLabel(SeqLabeling): | |||
truth = truth.cuda() if truth is not None else None | |||
x = self.Embedding(word_seq) | |||
x = self.norm1(x) | |||
# [batch_size, max_len, word_emb_dim] | |||
sent_variable = x.index_select(0, idx_sort) | |||
sent_variable = x[idx_sort] | |||
sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True) | |||
x = self.Rnn(sent_packed) | |||
x, _ = self.Rnn(sent_packed) | |||
# print(x) | |||
# [batch_size, max_len, hidden_size * direction] | |||
sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] | |||
x = sent_output.index_select(0, idx_unsort) | |||
x = sent_output[idx_unsort] | |||
x = x.contiguous() | |||
x = x.view(batch_size * max_len, -1) | |||
# x = x.view(batch_size * max_len, -1) | |||
x = self.Linear1(x) | |||
# x = self.batch_norm(x) | |||
x = self.norm2(x) | |||
x = self.relu(x) | |||
x = self.drop(x) | |||
x = self.Linear2(x) | |||
x = x.view(batch_size, max_len, -1) | |||
# x = x.view(batch_size, max_len, -1) | |||
# [batch_size, max_len, num_classes] | |||
return {"loss": self._internal_loss(x, truth) if truth is not None else None, | |||
"predict": self.decode(x)} | |||
@@ -183,41 +174,45 @@ class AdvSeqLabel(SeqLabeling): | |||
out = self.forward(**x) | |||
return {"predict": out["predict"]} | |||
args = { | |||
'vocab_size': 20, | |||
'word_emb_dim': 100, | |||
'rnn_hidden_units': 100, | |||
'num_classes': 10, | |||
} | |||
model = AdvSeqLabel(args) | |||
data = [] | |||
for i in range(20): | |||
word_seq = torch.randint(20, (15,)).long() | |||
word_seq_len = torch.LongTensor([15]) | |||
truth = torch.randint(10, (15,)).long() | |||
data.append((word_seq, word_seq_len, truth)) | |||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) | |||
print(model) | |||
curidx = 0 | |||
for i in range(1000): | |||
endidx = min(len(data), curidx + 5) | |||
b_word, b_len, b_truth = [], [], [] | |||
for word_seq, word_seq_len, truth in data[curidx: endidx]: | |||
b_word.append(word_seq) | |||
b_len.append(word_seq_len) | |||
b_truth.append(truth) | |||
word_seq = torch.stack(b_word, dim=0) | |||
word_seq_len = torch.cat(b_len, dim=0) | |||
truth = torch.stack(b_truth, dim=0) | |||
res = model(word_seq, word_seq_len, truth) | |||
loss = res['loss'] | |||
pred = res['predict'] | |||
print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||
optimizer.zero_grad() | |||
loss.backward() | |||
optimizer.step() | |||
curidx = endidx | |||
if curidx == len(data): | |||
curidx = 0 | |||
def loss(self, **kwargs): | |||
assert 'loss' in kwargs | |||
return kwargs['loss'] | |||
if __name__ == '__main__': | |||
args = { | |||
'vocab_size': 20, | |||
'word_emb_dim': 100, | |||
'rnn_hidden_units': 100, | |||
'num_classes': 10, | |||
} | |||
model = AdvSeqLabel(args) | |||
data = [] | |||
for i in range(20): | |||
word_seq = torch.randint(20, (15,)).long() | |||
word_seq_len = torch.LongTensor([15]) | |||
truth = torch.randint(10, (15,)).long() | |||
data.append((word_seq, word_seq_len, truth)) | |||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) | |||
print(model) | |||
curidx = 0 | |||
for i in range(1000): | |||
endidx = min(len(data), curidx + 5) | |||
b_word, b_len, b_truth = [], [], [] | |||
for word_seq, word_seq_len, truth in data[curidx: endidx]: | |||
b_word.append(word_seq) | |||
b_len.append(word_seq_len) | |||
b_truth.append(truth) | |||
word_seq = torch.stack(b_word, dim=0) | |||
word_seq_len = torch.cat(b_len, dim=0) | |||
truth = torch.stack(b_truth, dim=0) | |||
res = model(word_seq, word_seq_len, truth) | |||
loss = res['loss'] | |||
pred = res['predict'] | |||
print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||
optimizer.zero_grad() | |||
loss.backward() | |||
optimizer.step() | |||
curidx = endidx | |||
if curidx == len(data): | |||
curidx = 0 | |||
@@ -77,11 +77,13 @@ def initial_parameter(net, initial_method=None): | |||
def seq_mask(seq_len, max_len): | |||
"""Create sequence mask. | |||
:param seq_len: list of int, the lengths of sequences in a batch. | |||
:param seq_len: list or torch.Tensor, the lengths of sequences in a batch. | |||
:param max_len: int, the maximum sequence length in a batch. | |||
:return mask: torch.LongTensor, [batch_size, max_len] | |||
""" | |||
mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] | |||
mask = torch.stack(mask, 1) | |||
return mask | |||
if not isinstance(seq_len, torch.Tensor): | |||
seq_len = torch.LongTensor(seq_len) | |||
seq_len = seq_len.view(-1, 1).long() # [batch_size, 1] | |||
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len] | |||
return torch.gt(seq_len, seq_range) # [batch_size, max_len] |
@@ -0,0 +1,80 @@ | |||
import sys | |||
import os | |||
sys.path.extend(['/home/yfshao/workdir/dev_fastnlp']) | |||
from fastNLP.api.processor import * | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.models.biaffine_parser import BiaffineParser | |||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||
import _pickle as pickle | |||
import torch | |||
def _load(path): | |||
with open(path, 'rb') as f: | |||
obj = pickle.load(f) | |||
return obj | |||
def _load_all(src): | |||
model_path = src | |||
src = os.path.dirname(src) | |||
word_v = _load(src+'/word_v.pkl') | |||
pos_v = _load(src+'/pos_v.pkl') | |||
tag_v = _load(src+'/tag_v.pkl') | |||
model_args = ConfigSection() | |||
ConfigLoader.load_config('cfg.cfg', {'model': model_args}) | |||
model_args['word_vocab_size'] = len(word_v) | |||
model_args['pos_vocab_size'] = len(pos_v) | |||
model_args['num_label'] = len(tag_v) | |||
model = BiaffineParser(**model_args.data) | |||
model.load_state_dict(torch.load(model_path)) | |||
return { | |||
'word_v': word_v, | |||
'pos_v': pos_v, | |||
'tag_v': tag_v, | |||
'model': model, | |||
} | |||
def build(load_path, save_path): | |||
BOS = '<BOS>' | |||
NUM = '<NUM>' | |||
_dict = _load_all(load_path) | |||
word_vocab = _dict['word_v'] | |||
pos_vocab = _dict['pos_v'] | |||
tag_vocab = _dict['tag_v'] | |||
model = _dict['model'] | |||
print('load model from {}'.format(load_path)) | |||
word_seq = 'raw_word_seq' | |||
pos_seq = 'raw_pos_seq' | |||
# build pipeline | |||
pipe = Pipeline() | |||
pipe.add_processor(Num2TagProcessor(NUM, 'sentence', word_seq)) | |||
pipe.add_processor(PreAppendProcessor(BOS, word_seq)) | |||
pipe.add_processor(PreAppendProcessor(BOS, 'sent_pos', pos_seq)) | |||
pipe.add_processor(IndexerProcessor(word_vocab, word_seq, 'word_seq')) | |||
pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, 'pos_seq')) | |||
pipe.add_processor(SeqLenProcessor(word_seq, 'word_seq_origin_len')) | |||
pipe.add_processor(SetTensorProcessor({'word_seq':True, 'pos_seq':True, 'word_seq_origin_len':True}, default=False)) | |||
pipe.add_processor(ModelProcessor(model, 'word_seq_origin_len')) | |||
pipe.add_processor(SliceProcessor(1, None, None, 'head_pred', 'heads')) | |||
pipe.add_processor(SliceProcessor(1, None, None, 'label_pred', 'label_pred')) | |||
pipe.add_processor(Index2WordProcessor(tag_vocab, 'label_pred', 'labels')) | |||
if not os.path.exists(save_path): | |||
os.makedirs(save_path) | |||
with open(save_path+'/pipeline.pkl', 'wb') as f: | |||
torch.save(pipe, f) | |||
print('save pipeline in {}'.format(save_path)) | |||
import argparse | |||
parser = argparse.ArgumentParser(description='build pipeline for parser.') | |||
parser.add_argument('--src', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/save') | |||
parser.add_argument('--dst', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe') | |||
args = parser.parse_args() | |||
build(args.src, args.dst) |