From 822aaf6286899e163a5162ba9b474ac13719b3eb Mon Sep 17 00:00:00 2001 From: yunfan Date: Mon, 12 Nov 2018 21:37:56 +0800 Subject: [PATCH] fix and update tester, trainer, seq_model, add parser pipeline builder --- fastNLP/core/metrics.py | 12 +-- fastNLP/core/tester.py | 22 ++--- fastNLP/core/trainer.py | 38 +++++--- fastNLP/models/biaffine_parser.py | 48 +++++----- fastNLP/models/sequence_modeling.py | 129 +++++++++++++------------- fastNLP/modules/utils.py | 10 +- reproduction/Biaffine_parser/infer.py | 80 ++++++++++++++++ 7 files changed, 208 insertions(+), 131 deletions(-) create mode 100644 reproduction/Biaffine_parser/infer.py diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 73203b1c..2e02c531 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -35,23 +35,21 @@ class SeqLabelEvaluator(Evaluator): def __init__(self): super(SeqLabelEvaluator, self).__init__() - def __call__(self, predict, truth): + def __call__(self, predict, truth, **_): """ :param predict: list of dict, the network outputs from all batches. :param truth: list of dict, the ground truths from all batch_y. :return accuracy: """ - truth = [item["truth"] for item in truth] - predict = [item["predict"] for item in predict] - total_correct, total_count = 0., 0. + total_correct, total_count = 0., 0. for x, y in zip(predict, truth): # x = torch.tensor(x) y = y.to(x) # make sure they are in the same device - mask = x.ge(1).long() - correct = torch.sum(x * mask == y * mask) - torch.sum(x.le(0)) + mask = (y > 0) + correct = torch.sum(((x == y) * mask).long()) total_correct += float(correct) - total_count += float(torch.sum(mask)) + total_count += float(torch.sum(mask.long())) accuracy = total_correct / total_count return {"accuracy": float(accuracy)} diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 51f84691..dfdd397d 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -1,4 +1,5 @@ import torch +from collections import defaultdict from fastNLP.core.batch import Batch from fastNLP.core.metrics import Evaluator @@ -71,17 +72,18 @@ class Tester(object): # turn on the testing mode; clean up the history self.mode(network, is_test=True) self.eval_history.clear() - output_list = [] - truth_list = [] - + output, truths = defaultdict(list), defaultdict(list) data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) with torch.no_grad(): for batch_x, batch_y in data_iterator: prediction = self.data_forward(network, batch_x) - output_list.append(prediction) - truth_list.append(batch_y) - eval_results = self.evaluate(output_list, truth_list) + assert isinstance(prediction, dict) + for k, v in prediction.items(): + output[k].append(v) + for k, v in batch_y.items(): + truths[k].append(v) + eval_results = self.evaluate(**output, **truths) print("[tester] {}".format(self.print_eval_results(eval_results))) logger.info("[tester] {}".format(self.print_eval_results(eval_results))) self.mode(network, is_test=False) @@ -105,14 +107,10 @@ class Tester(object): y = network(**x) return y - def evaluate(self, predict, truth): + def evaluate(self, **kwargs): """Compute evaluation metrics. - - :param predict: list of Tensor - :param truth: list of dict - :return eval_results: can be anything. It will be stored in self.eval_history """ - return self._evaluator(predict, truth) + return self._evaluator(**kwargs) def print_eval_results(self, results): """Override this method to support more print formats. 
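Note (editor's illustration, not part of the patch): the Tester now collects each batch's output and target tensors into per-field lists and forwards both dicts to the evaluator as keyword arguments, and SeqLabelEvaluator.__call__(predict, truth, **_) scores only positions whose gold label is non-zero, i.e. label id 0 is treated as padding. A minimal sketch of this contract, with made-up tensors:

from collections import defaultdict

import torch

output, truths = defaultdict(list), defaultdict(list)
batches = [  # two hypothetical (prediction, batch_y) pairs as yielded by Batch
    ({"predict": torch.tensor([[3, 1, 0]])}, {"truth": torch.tensor([[3, 2, 0]])}),
    ({"predict": torch.tensor([[2, 2, 0]])}, {"truth": torch.tensor([[2, 2, 0]])}),
]
for prediction, batch_y in batches:
    for k, v in prediction.items():   # accumulate network outputs by field name
        output[k].append(v)
    for k, v in batch_y.items():      # accumulate ground truths by field name
        truths[k].append(v)

# evaluator(**output, **truths) binds predict=output["predict"], truth=truths["truth"]
total_correct, total_count = 0., 0.
for x, y in zip(output["predict"], truths["truth"]):
    y = y.to(x)     # move the gold tags onto the predictions' device
    mask = (y > 0)  # gold label 0 is padding and is excluded from the score
    total_correct += float(torch.sum(((x == y) * mask).long()))
    total_count += float(torch.sum(mask.long()))
print({"accuracy": total_correct / total_count})  # {'accuracy': 0.75}: 3 of 4 real tokens match
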
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index aa2cd385..3f1525b7 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -47,7 +47,8 @@ class Trainer(object): "valid_step": 500, "eval_sort_key": 'acc', "loss": Loss(None), # used to pass type check "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), - "evaluator": Evaluator() + "eval_batch_size": 64, + "evaluator": Evaluator(), } """ "required_args" is the collection of arguments that users must pass to Trainer explicitly. @@ -78,6 +79,7 @@ class Trainer(object): self.n_epochs = int(default_args["epochs"]) self.batch_size = int(default_args["batch_size"]) + self.eval_batch_size = int(default_args['eval_batch_size']) self.pickle_path = default_args["pickle_path"] self.validate = default_args["validate"] self.save_best_dev = default_args["save_best_dev"] @@ -98,6 +100,8 @@ class Trainer(object): self._best_accuracy = 0.0 self.eval_sort_key = default_args['eval_sort_key'] self.validator = None + self.epoch = 0 + self.step = 0 def train(self, network, train_data, dev_data=None): """General Training Procedure @@ -118,7 +122,7 @@ class Trainer(object): # define Tester over dev data self.dev_data = None if self.validate: - default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, + default_valid_args = {"batch_size": self.eval_batch_size, "pickle_path": self.pickle_path, "use_cuda": self.use_cuda, "evaluator": self._evaluator} if self.validator is None: self.validator = self._create_validator(default_valid_args) @@ -139,9 +143,9 @@ class Trainer(object): self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) print("training epochs started " + self.start_time) logger.info("training epochs started " + self.start_time) - epoch, iters = 1, 0 - while epoch <= self.n_epochs: - logger.info("training epoch {}".format(epoch)) + self.epoch, self.step = 1, 0 + while self.epoch <= self.n_epochs: + logger.info("training epoch {}".format(self.epoch)) # prepare mini-batch iterator data_iterator = Batch(train_data, batch_size=self.batch_size, @@ -150,14 +154,13 @@ class Trainer(object): logger.info("prepared data iterator") # one forward and backward pass - iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, - step=iters, dev_data=dev_data) + self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, dev_data=dev_data) # validation if self.validate: self.valid_model() self.save_model(self._model, 'training_model_' + self.start_time) - epoch += 1 + self.epoch += 1 def _train_step(self, data_iterator, network, **kwargs): """Training process in one epoch. @@ -167,7 +170,6 @@ class Trainer(object): - start: time.time(), the starting time of this step. 
- epoch: int, """ - step = kwargs['step'] for batch_x, batch_y in data_iterator: prediction = self.data_forward(network, batch_x) @@ -177,25 +179,31 @@ self.grad_backward(loss) self.update() - self._summary_writer.add_scalar("loss", loss.item(), global_step=step) + self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) for name, param in self._model.named_parameters(): if param.requires_grad: - # self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=step) - # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=step) - # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=step) - pass - if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: + self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) + # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) + # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) + if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0: end = time.time() diff = timedelta(seconds=round(end - kwargs["start"])) print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( - kwargs["epoch"], step, loss.data, diff) + self.epoch, self.step, loss.data, diff) print(print_output) logger.info(print_output) - if self.validate and self.valid_step > 0 and step > 0 and step % self.valid_step == 0: + if self.validate and self.valid_step > 0 and self.step > 0 and self.step % self.valid_step == 0: self.valid_model() - step += 1 - return step + self.step += 1 def valid_model(self): if self.dev_data is None: @@ -203,6 +211,8 @@ "self.validate is True in trainer, but dev_data is None. Please provide the validation data.") logger.info("validation started") res = self.validator.test(self._model, self.dev_data) + for name, num in res.items(): + self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step) if self.save_best_dev and self.best_eval_result(res): logger.info('save best result! {}'.format(res)) print('save best result!
{}'.format(res)) diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py index 43239f8c..2a42116c 100644 --- a/fastNLP/models/biaffine_parser.py +++ b/fastNLP/models/biaffine_parser.py @@ -10,6 +10,7 @@ from fastNLP.modules.utils import initial_parameter from fastNLP.modules.encoder.variational_rnn import VarLSTM from fastNLP.modules.dropout import TimestepDropout from fastNLP.models.base_model import BaseModel +from fastNLP.modules.utils import seq_mask def mst(scores): """ @@ -123,31 +124,31 @@ class GraphParser(BaseModel): def forward(self, x): raise NotImplementedError - def _greedy_decoder(self, arc_matrix, seq_mask=None): + def _greedy_decoder(self, arc_matrix, mask=None): _, seq_len, _ = arc_matrix.shape matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) - flip_mask = (seq_mask == 0).byte() + flip_mask = (mask == 0).byte() matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf) _, heads = torch.max(matrix, dim=2) - if seq_mask is not None: - heads *= seq_mask.long() + if mask is not None: + heads *= mask.long() return heads - def _mst_decoder(self, arc_matrix, seq_mask=None): + def _mst_decoder(self, arc_matrix, mask=None): batch_size, seq_len, _ = arc_matrix.shape matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) ans = matrix.new_zeros(batch_size, seq_len).long() - lens = (seq_mask.long()).sum(1) if seq_mask is not None else torch.zeros(batch_size) + seq_len + lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device) - seq_mask[batch_idx, lens-1] = 0 + mask[batch_idx, lens-1] = 0 for i, graph in enumerate(matrix): len_i = lens[i] if len_i == seq_len: ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) else: ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device) - if seq_mask is not None: - ans *= seq_mask.long() + if mask is not None: + ans *= mask.long() return ans @@ -191,13 +192,6 @@ class LabelBilinear(nn.Module): output += self.lin(torch.cat([x1, x2], dim=2)) return output -def len2masks(origin_len, max_len): - if origin_len.dim() <= 1: - origin_len = origin_len.unsqueeze(1) # [batch_size, 1] - seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=origin_len.device) # [max_len,] - seq_mask = torch.gt(origin_len, seq_range.unsqueeze(0)) # [batch_size, max_len] - return seq_mask - class BiaffineParser(GraphParser): """Biaffine Dependency Parser implementation.
refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) @@ -277,12 +271,12 @@ class BiaffineParser(GraphParser): """ :param word_seq: [batch_size, seq_len] sequence of word's indices :param pos_seq: [batch_size, seq_len] sequence of pos tags' indices - :param seq_mask: [batch_size, seq_len] sequence of length masks + :param word_seq_origin_len: [batch_size] sequence of original lengths of the input sequences :param gold_heads: [batch_size, seq_len] sequence of golden heads :return dict: parsing results arc_pred: [batch_size, seq_len, seq_len] label_pred: [batch_size, seq_len, seq_len] - seq_mask: [batch_size, seq_len] + mask: [batch_size, seq_len] head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads """ # prepare embeddings @@ -294,7 +288,7 @@ # print('forward {} {}'.format(batch_size, seq_len)) # get sequence mask - seq_mask = len2masks(word_seq_origin_len, seq_len).long() + mask = seq_mask(word_seq_origin_len, seq_len).long() word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] @@ -327,14 +321,14 @@ if gold_heads is None or not self.training: # use greedy decoding in training if self.training or self.use_greedy_infer: - heads = self._greedy_decoder(arc_pred, seq_mask) + heads = self._greedy_decoder(arc_pred, mask) else: - heads = self._mst_decoder(arc_pred, seq_mask) + heads = self._mst_decoder(arc_pred, mask) head_pred = heads else: assert self.training # must be training mode if torch.rand(1).item() < self.explore_p: - heads = self._greedy_decoder(arc_pred, seq_mask) + heads = self._greedy_decoder(arc_pred, mask) head_pred = heads else: head_pred = None @@ -343,12 +337,12 @@ batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) label_head = label_head[batch_range, heads].contiguous() label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] - res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask} + res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} if head_pred is not None: res_dict['head_pred'] = head_pred return res_dict - def loss(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **_): + def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_): """ Compute loss.
@@ -356,12 +350,12 @@ class BiaffineParser(GraphParser): :param label_pred: [batch_size, seq_len, n_tags] :param head_indices: [batch_size, seq_len] :param head_labels: [batch_size, seq_len] - :param seq_mask: [batch_size, seq_len] + :param mask: [batch_size, seq_len] :return: loss value """ batch_size, seq_len, _ = arc_pred.shape - flip_mask = (seq_mask == 0) + flip_mask = (mask == 0) _arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred) _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) arc_logits = F.log_softmax(_arc_pred, dim=2) @@ -374,7 +368,7 @@ class BiaffineParser(GraphParser): arc_loss = arc_loss[:, 1:] label_loss = label_loss[:, 1:] - float_mask = seq_mask[:, 1:].float() + float_mask = mask[:, 1:].float() arc_nll = -(arc_loss*float_mask).mean() label_nll = -(label_loss*float_mask).mean() return arc_nll + label_nll diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py index 61a742b3..f9813144 100644 --- a/fastNLP/models/sequence_modeling.py +++ b/fastNLP/models/sequence_modeling.py @@ -4,20 +4,7 @@ import numpy as np from fastNLP.models.base_model import BaseModel from fastNLP.modules import decoder, encoder - - -def seq_mask(seq_len, max_len): - """Create a mask for the sequences. - - :param seq_len: list or torch.LongTensor - :param max_len: int - :return mask: torch.LongTensor - """ - if isinstance(seq_len, list): - seq_len = torch.LongTensor(seq_len) - mask = [torch.ge(seq_len, i + 1) for i in range(max_len)] - mask = torch.stack(mask, 1) - return mask +from fastNLP.modules.utils import seq_mask class SeqLabeling(BaseModel): @@ -82,7 +69,7 @@ class SeqLabeling(BaseModel): def make_mask(self, x, seq_len): batch_size, max_len = x.size(0), x.size(1) mask = seq_mask(seq_len, max_len) - mask = mask.byte().view(batch_size, max_len) + mask = mask.view(batch_size, max_len) mask = mask.to(x).float() return mask @@ -114,16 +101,20 @@ class AdvSeqLabel(SeqLabeling): word_emb_dim = args["word_emb_dim"] hidden_dim = args["rnn_hidden_units"] num_classes = args["num_classes"] + dropout = args['dropout'] self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) - self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=1, dropout=0.5, bidirectional=True) + self.norm1 = torch.nn.LayerNorm(word_emb_dim) + # self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) + self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True) self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) - self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) - self.relu = torch.nn.ReLU() - self.drop = torch.nn.Dropout(0.5) + self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) + # self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) + self.relu = torch.nn.LeakyReLU() + self.drop = torch.nn.Dropout(dropout) self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) - self.Crf = decoder.CRF.ConditionalRandomField(num_classes) + self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) def forward(self, word_seq, word_seq_origin_len, truth=None): """ @@ -135,12 +126,10 @@ class AdvSeqLabel(SeqLabeling): """ word_seq = word_seq.long() + word_seq_origin_len = word_seq_origin_len.long() self.mask = self.make_mask(word_seq, word_seq_origin_len) - word_seq_origin_len = word_seq_origin_len.cpu().numpy() - sent_len, idx_sort = 
np.sort(word_seq_origin_len)[::-1], np.argsort(-word_seq_origin_len) - idx_unsort = np.argsort(idx_sort) - idx_sort = torch.from_numpy(idx_sort) - idx_unsort = torch.from_numpy(idx_unsort) + sent_len, idx_sort = torch.sort(word_seq_origin_len, descending=True) + _, idx_unsort = torch.sort(idx_sort, descending=False) # word_seq_origin_len = word_seq_origin_len.long() truth = truth.long() if truth is not None else None @@ -155,26 +144,28 @@ class AdvSeqLabel(SeqLabeling): truth = truth.cuda() if truth is not None else None x = self.Embedding(word_seq) + x = self.norm1(x) # [batch_size, max_len, word_emb_dim] - sent_variable = x.index_select(0, idx_sort) + sent_variable = x[idx_sort] sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True) - x = self.Rnn(sent_packed) + x, _ = self.Rnn(sent_packed) # print(x) # [batch_size, max_len, hidden_size * direction] sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] - x = sent_output.index_select(0, idx_unsort) + x = sent_output[idx_unsort] x = x.contiguous() - x = x.view(batch_size * max_len, -1) + # x = x.view(batch_size * max_len, -1) x = self.Linear1(x) # x = self.batch_norm(x) + x = self.norm2(x) x = self.relu(x) x = self.drop(x) x = self.Linear2(x) - x = x.view(batch_size, max_len, -1) + # x = x.view(batch_size, max_len, -1) # [batch_size, max_len, num_classes] return {"loss": self._internal_loss(x, truth) if truth is not None else None, "predict": self.decode(x)} @@ -183,41 +174,45 @@ class AdvSeqLabel(SeqLabeling): out = self.forward(**x) return {"predict": out["predict"]} - -args = { - 'vocab_size': 20, - 'word_emb_dim': 100, - 'rnn_hidden_units': 100, - 'num_classes': 10, -} -model = AdvSeqLabel(args) -data = [] -for i in range(20): - word_seq = torch.randint(20, (15,)).long() - word_seq_len = torch.LongTensor([15]) - truth = torch.randint(10, (15,)).long() - data.append((word_seq, word_seq_len, truth)) -optimizer = torch.optim.Adam(model.parameters(), lr=0.01) -print(model) -curidx = 0 -for i in range(1000): - endidx = min(len(data), curidx + 5) - b_word, b_len, b_truth = [], [], [] - for word_seq, word_seq_len, truth in data[curidx: endidx]: - b_word.append(word_seq) - b_len.append(word_seq_len) - b_truth.append(truth) - word_seq = torch.stack(b_word, dim=0) - word_seq_len = torch.cat(b_len, dim=0) - truth = torch.stack(b_truth, dim=0) - res = model(word_seq, word_seq_len, truth) - loss = res['loss'] - pred = res['predict'] - print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) - optimizer.zero_grad() - loss.backward() - optimizer.step() - curidx = endidx - if curidx == len(data): - curidx = 0 + def loss(self, **kwargs): + assert 'loss' in kwargs + return kwargs['loss'] + +if __name__ == '__main__': + args = { + 'vocab_size': 20, + 'word_emb_dim': 100, + 'rnn_hidden_units': 100, + 'num_classes': 10, + } + model = AdvSeqLabel(args) + data = [] + for i in range(20): + word_seq = torch.randint(20, (15,)).long() + word_seq_len = torch.LongTensor([15]) + truth = torch.randint(10, (15,)).long() + data.append((word_seq, word_seq_len, truth)) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + print(model) + curidx = 0 + for i in range(1000): + endidx = min(len(data), curidx + 5) + b_word, b_len, b_truth = [], [], [] + for word_seq, word_seq_len, truth in data[curidx: endidx]: + b_word.append(word_seq) + b_len.append(word_seq_len) + b_truth.append(truth) + word_seq = torch.stack(b_word, dim=0) + word_seq_len 
= torch.cat(b_len, dim=0) + truth = torch.stack(b_truth, dim=0) + res = model(word_seq, word_seq_len, truth) + loss = res['loss'] + pred = res['predict'] + print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) + optimizer.zero_grad() + loss.backward() + optimizer.step() + curidx = endidx + if curidx == len(data): + curidx = 0 diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index 21497037..5056e181 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -77,11 +77,13 @@ def initial_parameter(net, initial_method=None): def seq_mask(seq_len, max_len): """Create sequence mask. - :param seq_len: list of int, the lengths of sequences in a batch. + :param seq_len: list or torch.Tensor, the lengths of sequences in a batch. :param max_len: int, the maximum sequence length in a batch. :return mask: torch.LongTensor, [batch_size, max_len] """ - mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] - mask = torch.stack(mask, 1) - return mask + if not isinstance(seq_len, torch.Tensor): + seq_len = torch.LongTensor(seq_len) + seq_len = seq_len.view(-1, 1).long() # [batch_size, 1] + seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len] + return torch.gt(seq_len, seq_range) # [batch_size, max_len] diff --git a/reproduction/Biaffine_parser/infer.py b/reproduction/Biaffine_parser/infer.py new file mode 100644 index 00000000..691c01d0 --- /dev/null +++ b/reproduction/Biaffine_parser/infer.py @@ -0,0 +1,80 @@ +import sys +import os + +sys.path.extend(['/home/yfshao/workdir/dev_fastnlp']) + +from fastNLP.api.processor import * +from fastNLP.api.pipeline import Pipeline +from fastNLP.core.dataset import DataSet +from fastNLP.models.biaffine_parser import BiaffineParser +from fastNLP.loader.config_loader import ConfigSection, ConfigLoader + +import _pickle as pickle +import torch + +def _load(path): + with open(path, 'rb') as f: + obj = pickle.load(f) + return obj + +def _load_all(src): + model_path = src + src = os.path.dirname(src) + + word_v = _load(src+'/word_v.pkl') + pos_v = _load(src+'/pos_v.pkl') + tag_v = _load(src+'/tag_v.pkl') + + model_args = ConfigSection() + ConfigLoader.load_config('cfg.cfg', {'model': model_args}) + model_args['word_vocab_size'] = len(word_v) + model_args['pos_vocab_size'] = len(pos_v) + model_args['num_label'] = len(tag_v) + + model = BiaffineParser(**model_args.data) + model.load_state_dict(torch.load(model_path)) + return { + 'word_v': word_v, + 'pos_v': pos_v, + 'tag_v': tag_v, + 'model': model, + } + +def build(load_path, save_path): + BOS = '<BOS>' + NUM = '<NUM>' + _dict = _load_all(load_path) + word_vocab = _dict['word_v'] + pos_vocab = _dict['pos_v'] + tag_vocab = _dict['tag_v'] + model = _dict['model'] + print('load model from {}'.format(load_path)) + word_seq = 'raw_word_seq' + pos_seq = 'raw_pos_seq' + + # build pipeline + pipe = Pipeline() + pipe.add_processor(Num2TagProcessor(NUM, 'sentence', word_seq)) + pipe.add_processor(PreAppendProcessor(BOS, word_seq)) + pipe.add_processor(PreAppendProcessor(BOS, 'sent_pos', pos_seq)) + pipe.add_processor(IndexerProcessor(word_vocab, word_seq, 'word_seq')) + pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, 'pos_seq')) + pipe.add_processor(SeqLenProcessor(word_seq, 'word_seq_origin_len')) + pipe.add_processor(SetTensorProcessor({'word_seq':True, 'pos_seq':True, 'word_seq_origin_len':True}, default=False)) + pipe.add_processor(ModelProcessor(model,
'word_seq_origin_len')) + pipe.add_processor(SliceProcessor(1, None, None, 'head_pred', 'heads')) + pipe.add_processor(SliceProcessor(1, None, None, 'label_pred', 'label_pred')) + pipe.add_processor(Index2WordProcessor(tag_vocab, 'label_pred', 'labels')) + if not os.path.exists(save_path): + os.makedirs(save_path) + with open(save_path+'/pipeline.pkl', 'wb') as f: + torch.save(pipe, f) + print('save pipeline in {}'.format(save_path)) + + +import argparse +parser = argparse.ArgumentParser(description='build pipeline for parser.') +parser.add_argument('--src', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/save') +parser.add_argument('--dst', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe') +args = parser.parse_args() +build(args.src, args.dst)
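
Note (editor's illustration, not part of the patch): the rewritten seq_mask in fastNLP/modules/utils.py replaces the old per-position torch.ge/torch.stack loop with a single broadcast comparison between a [batch_size, 1] column of lengths and a [1, max_len] row of positions; SeqLabeling.make_mask and BiaffineParser.forward both consume this shared helper and cast the result with .long()/.float() at the call site. A self-contained sketch of the expected behaviour, using made-up lengths:

import torch

def seq_mask(seq_len, max_len):
    # accepts a plain list of ints or a tensor of lengths
    if not isinstance(seq_len, torch.Tensor):
        seq_len = torch.LongTensor(seq_len)
    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long,
                             device=seq_len.device).view(1, -1)  # [1, max_len]
    return torch.gt(seq_len, seq_range)  # [batch_size, max_len], True on real tokens

print(seq_mask([3, 1], 4).long())
# tensor([[1, 1, 1, 0],
#         [1, 0, 0, 0]])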