From c91696e1ee76b4164ba3170f6312367eac82c098 Mon Sep 17 00:00:00 2001 From: yunfan Date: Sat, 5 Jan 2019 21:02:24 +0800 Subject: [PATCH] update parser, optimize embed_loader --- fastNLP/core/dataset.py | 4 +- fastNLP/io/embed_loader.py | 13 +++- fastNLP/models/biaffine_parser.py | 94 ++++++++++++++++++++--------- test/models/test_biaffine_parser.py | 82 +++++++++++++++++++++++++ 4 files changed, 161 insertions(+), 32 deletions(-) create mode 100644 test/models/test_biaffine_parser.py diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index c798a422..844492dd 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -254,8 +254,6 @@ class DataSet(object): :return results: if new_field_name is not passed, returned values of the function over all instances. """ results = [func(ins) for ins in self._inner_iter()] - if len(list(filter(lambda x: x is not None, results))) == 0: # all None - raise ValueError("{} always return None.".format(get_func_signature(func=func))) extra_param = {} if 'is_input' in kwargs: @@ -263,6 +261,8 @@ class DataSet(object): if 'is_target' in kwargs: extra_param['is_target'] = kwargs['is_target'] if new_field_name is not None: + if len(list(filter(lambda x: x is not None, results))) == 0: # all None + raise ValueError("{} always return None.".format(get_func_signature(func=func))) if new_field_name in self.field_arrays: # overwrite the field, keep same attributes old_field = self.field_arrays[new_field_name] diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index 779b7fd0..2eb48f93 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -74,10 +74,18 @@ class EmbedLoader(BaseLoader): @staticmethod def parse_glove_line(line): - line = list(filter(lambda w: len(w) > 0, line.strip().split(" "))) + line = line.split() if len(line) <= 2: raise RuntimeError("something goes wrong in parsing glove embedding") - return line[0], torch.Tensor(list(map(float, line[1:]))) + return line[0], line[1:] + + @staticmethod + def str_list_2_vec(line): + try: + return torch.Tensor(list(map(float, line))) + except Exception: + raise RuntimeError("something goes wrong in parsing glove embedding") + @staticmethod def fast_load_embedding(emb_dim, emb_file, vocab): @@ -98,6 +106,7 @@ class EmbedLoader(BaseLoader): for line in f: word, vector = EmbedLoader.parse_glove_line(line) if word in vocab: + vector = EmbedLoader.str_list_2_vec(vector) if len(vector.shape) > 1 or emb_dim != vector.shape[0]: raise ValueError("Pre-trained embedding dim is {}. 
Expect {}.".format(vector.shape, (emb_dim,))) embedding_matrix[vocab[word]] = vector diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py index 2a42116c..efb07f34 100644 --- a/fastNLP/models/biaffine_parser.py +++ b/fastNLP/models/biaffine_parser.py @@ -1,5 +1,3 @@ -import sys, os -sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) import copy import numpy as np import torch @@ -11,6 +9,9 @@ from fastNLP.modules.encoder.variational_rnn import VarLSTM from fastNLP.modules.dropout import TimestepDropout from fastNLP.models.base_model import BaseModel from fastNLP.modules.utils import seq_mask +from fastNLP.core.losses import LossFunc +from fastNLP.core.metrics import MetricBase +from fastNLP.core.utils import seq_lens_to_masks def mst(scores): """ @@ -121,9 +122,6 @@ class GraphParser(BaseModel): def __init__(self): super(GraphParser, self).__init__() - def forward(self, x): - raise NotImplementedError - def _greedy_decoder(self, arc_matrix, mask=None): _, seq_len, _ = arc_matrix.shape matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) @@ -202,14 +200,14 @@ class BiaffineParser(GraphParser): word_emb_dim, pos_vocab_size, pos_emb_dim, - word_hid_dim, - pos_hid_dim, - rnn_layers, - rnn_hidden_size, - arc_mlp_size, - label_mlp_size, num_label, - dropout, + word_hid_dim=100, + pos_hid_dim=100, + rnn_layers=1, + rnn_hidden_size=200, + arc_mlp_size=100, + label_mlp_size=100, + dropout=0.3, use_var_lstm=False, use_greedy_infer=False): @@ -267,11 +265,11 @@ class BiaffineParser(GraphParser): for p in m.parameters(): nn.init.normal_(p, 0, 0.1) - def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_): + def forward(self, word_seq, pos_seq, seq_lens, gold_heads=None): """ :param word_seq: [batch_size, seq_len] sequence of word's indices :param pos_seq: [batch_size, seq_len] sequence of word's indices - :param word_seq_origin_len: [batch_size, seq_len] sequence of length masks + :param seq_lens: [batch_size, seq_len] sequence of length masks :param gold_heads: [batch_size, seq_len] sequence of golden heads :return dict: parsing results arc_pred: [batch_size, seq_len, seq_len] @@ -283,12 +281,12 @@ class BiaffineParser(GraphParser): device = self.parameters().__next__().device word_seq = word_seq.long().to(device) pos_seq = pos_seq.long().to(device) - word_seq_origin_len = word_seq_origin_len.long().to(device).view(-1) + seq_lens = seq_lens.long().to(device).view(-1) batch_size, seq_len = word_seq.shape # print('forward {} {}'.format(batch_size, seq_len)) # get sequence mask - mask = seq_mask(word_seq_origin_len, seq_len).long() + mask = seq_mask(seq_lens, seq_len).long() word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] @@ -298,7 +296,7 @@ class BiaffineParser(GraphParser): del word, pos # lstm, extract features - sort_lens, sort_idx = torch.sort(word_seq_origin_len, dim=0, descending=True) + sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) x = x[sort_idx] x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) feat, _ = self.lstm(x) # -> [N,L,C] @@ -342,14 +340,15 @@ class BiaffineParser(GraphParser): res_dict['head_pred'] = head_pred return res_dict - def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_): + @staticmethod + def loss(arc_pred, label_pred, arc_true, label_true, mask): """ Compute loss. 
:param arc_pred: [batch_size, seq_len, seq_len] :param label_pred: [batch_size, seq_len, n_tags] - :param head_indices: [batch_size, seq_len] - :param head_labels: [batch_size, seq_len] + :param arc_true: [batch_size, seq_len] + :param label_true: [batch_size, seq_len] :param mask: [batch_size, seq_len] :return: loss value """ @@ -362,8 +361,8 @@ class BiaffineParser(GraphParser): label_logits = F.log_softmax(label_pred, dim=2) batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0) - arc_loss = arc_logits[batch_index, child_index, head_indices] - label_loss = label_logits[batch_index, child_index, head_labels] + arc_loss = arc_logits[batch_index, child_index, arc_true] + label_loss = label_logits[batch_index, child_index, label_true] arc_loss = arc_loss[:, 1:] label_loss = label_loss[:, 1:] @@ -373,19 +372,58 @@ class BiaffineParser(GraphParser): label_nll = -(label_loss*float_mask).mean() return arc_nll + label_nll - def predict(self, word_seq, pos_seq, word_seq_origin_len): + def predict(self, word_seq, pos_seq, seq_lens): """ :param word_seq: :param pos_seq: - :param word_seq_origin_len: - :return: head_pred: [B, L] + :param seq_lens: + :return: arc_pred: [B, L] label_pred: [B, L] - seq_len: [B,] """ - res = self(word_seq, pos_seq, word_seq_origin_len) + res = self(word_seq, pos_seq, seq_lens) output = {} - output['head_pred'] = res.pop('head_pred') + output['arc_pred'] = res.pop('head_pred') _, label_pred = res.pop('label_pred').max(2) output['label_pred'] = label_pred return output + + +class ParserLoss(LossFunc): + def __init__(self, arc_pred=None, label_pred=None, arc_true=None, label_true=None): + super(ParserLoss, self).__init__(BiaffineParser.loss, + arc_pred=arc_pred, + label_pred=label_pred, + arc_true=arc_true, + label_true=label_true) + + +class ParserMetric(MetricBase): + def __init__(self, arc_pred=None, label_pred=None, + arc_true=None, label_true=None, seq_lens=None): + super().__init__() + self._init_param_map(arc_pred=arc_pred, label_pred=label_pred, + arc_true=arc_true, label_true=label_true, + seq_lens=seq_lens) + self.num_arc = 0 + self.num_label = 0 + self.num_sample = 0 + + def get_metric(self, reset=True): + res = {'UAS': self.num_arc*1.0 / self.num_sample, 'LAS': self.num_label*1.0 / self.num_sample} + if reset: + self.num_sample = self.num_label = self.num_arc = 0 + return res + + def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_lens=None): + """Evaluate the performance of prediction. 
+ """ + if seq_lens is None: + seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long) + else: + seq_mask = seq_lens_to_masks(seq_lens, float=False).long() + head_pred_correct = (arc_pred == arc_true).long() * seq_mask + label_pred_correct = (label_pred == label_true).long() * head_pred_correct + self.num_arc += head_pred_correct.sum().item() + self.num_label += label_pred_correct.sum().item() + self.num_sample += seq_mask.sum().item() diff --git a/test/models/test_biaffine_parser.py b/test/models/test_biaffine_parser.py new file mode 100644 index 00000000..8fafd00b --- /dev/null +++ b/test/models/test_biaffine_parser.py @@ -0,0 +1,82 @@ +from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric +import fastNLP + +import unittest + +data_file = """ +1 The _ DET DT _ 3 det _ _ +2 new _ ADJ JJ _ 3 amod _ _ +3 rate _ NOUN NN _ 6 nsubj _ _ +4 will _ AUX MD _ 6 aux _ _ +5 be _ VERB VB _ 6 cop _ _ +6 payable _ ADJ JJ _ 0 root _ _ +9 cents _ NOUN NNS _ 4 nmod _ _ +10 from _ ADP IN _ 12 case _ _ +11 seven _ NUM CD _ 12 nummod _ _ +12 cents _ NOUN NNS _ 4 nmod _ _ +13 a _ DET DT _ 14 det _ _ +14 share _ NOUN NN _ 12 nmod:npmod _ _ +15 . _ PUNCT . _ 4 punct _ _ + +1 The _ DET DT _ 3 det _ _ +2 new _ ADJ JJ _ 3 amod _ _ +3 rate _ NOUN NN _ 6 nsubj _ _ +4 will _ AUX MD _ 6 aux _ _ +5 be _ VERB VB _ 6 cop _ _ +6 payable _ ADJ JJ _ 0 root _ _ +7 Feb. _ PROPN NNP _ 6 nmod:tmod _ _ +8 15 _ NUM CD _ 7 nummod _ _ +9 . _ PUNCT . _ 6 punct _ _ + +1 A _ DET DT _ 3 det _ _ +2 record _ NOUN NN _ 3 compound _ _ +3 date _ NOUN NN _ 7 nsubjpass _ _ +4 has _ AUX VBZ _ 7 aux _ _ +5 n't _ PART RB _ 7 neg _ _ +6 been _ AUX VBN _ 7 auxpass _ _ +7 set _ VERB VBN _ 0 root _ _ +8 . _ PUNCT . _ 7 punct _ _ + +""" + +def init_data(): + ds = fastNLP.DataSet() + v = {'word_seq': fastNLP.Vocabulary(), + 'pos_seq': fastNLP.Vocabulary(), + 'label_true': fastNLP.Vocabulary()} + data = [] + for line in data_file.split('\n'): + line = line.split() + if len(line) == 0 and len(data) > 0: + data = list(zip(*data)) + ds.append(fastNLP.Instance(word_seq=data[1], + pos_seq=data[4], + arc_true=data[6], + label_true=data[7])) + data = [] + elif len(line) > 0: + data.append(line) + + for name in ['word_seq', 'pos_seq', 'label_true']: + ds.apply(lambda x: ['']+list(x[name])+[''], new_field_name=name) + ds.apply(lambda x: v[name].add_word_lst(x[name])) + + for name in ['word_seq', 'pos_seq', 'label_true']: + ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) + + ds.apply(lambda x: [0]+list(map(int, x['arc_true']))+[1], new_field_name='arc_true') + ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') + ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) + ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) + return ds, v['word_seq'], v['pos_seq'], v['label_true'] + +class TestBiaffineParser(unittest.TestCase): + def test_train(self): + ds, v1, v2, v3 = init_data() + model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30, + pos_vocab_size=len(v2), pos_emb_dim=30, + num_label=len(v3)) + trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, + loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', + n_epochs=10, use_cuda=False, use_tqdm=False) + trainer.train(load_best_model=False)
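
A minimal usage sketch of the refactored EmbedLoader (parse_glove_line now returns the raw string fields and str_list_2_vec converts them to a tensor only for in-vocabulary words). The file path and toy vocabulary below are placeholders for illustration, not taken from the patch; a whitespace-separated GloVe-style file ("word v1 v2 ... vN" per line) is assumed.

    from fastNLP import Vocabulary
    from fastNLP.io.embed_loader import EmbedLoader

    # placeholder vocabulary; in practice this is built from your DataSet
    vocab = Vocabulary()
    vocab.add_word_lst(["the", "rate", "payable"])
    vocab.build_vocab()

    # returns a [len(vocab), 50] matrix; rows for words found in the file are
    # overwritten with their pre-trained vectors, other rows keep the default init
    embedding = EmbedLoader.fast_load_embedding(emb_dim=50,
                                                emb_file="glove.6B.50d.txt",  # placeholder path
                                                vocab=vocab)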
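
A hypothetical inference sketch for the renamed predict() interface (word_seq_origin_len becomes seq_lens, and the output key head_pred becomes arc_pred); the vocabulary sizes and index tensors below are invented.

    import torch
    from fastNLP.models.biaffine_parser import BiaffineParser

    model = BiaffineParser(word_vocab_size=10, word_emb_dim=30,
                           pos_vocab_size=10, pos_emb_dim=30,
                           num_label=5)
    out = model.predict(word_seq=torch.LongTensor([[1, 4, 5, 2]]),
                        pos_seq=torch.LongTensor([[1, 3, 3, 2]]),
                        seq_lens=torch.LongTensor([4]))
    # out['arc_pred']:   [batch, seq_len] predicted head position for each token
    # out['label_pred']: [batch, seq_len] predicted dependency label id for each token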
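
A toy check of the new ParserMetric bookkeeping: UAS counts tokens whose head is predicted correctly, LAS additionally requires the correct label on a correctly attached token. All tensor values below are made up for illustration.

    import torch
    from fastNLP.models.biaffine_parser import ParserMetric

    metric = ParserMetric()
    arc_pred   = torch.LongTensor([[0, 3, 3, 6, 6]])
    arc_true   = torch.LongTensor([[0, 3, 3, 6, 5]])   # last head is wrong
    label_pred = torch.LongTensor([[0, 1, 2, 3, 4]])
    label_true = torch.LongTensor([[0, 1, 2, 9, 4]])   # fourth label is wrong
    seq_lens   = torch.LongTensor([5])

    metric.evaluate(arc_pred, label_pred, arc_true, label_true, seq_lens)
    print(metric.get_metric())   # expect roughly {'UAS': 0.8, 'LAS': 0.6}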