
update parser, optimize embed_loader

commit c91696e1ee (tags/v0.3.0)
yunfan, 5 years ago
4 changed files with 161 additions and 32 deletions:
  1. fastNLP/core/dataset.py (+2, -2)
  2. fastNLP/io/embed_loader.py (+11, -2)
  3. fastNLP/models/biaffine_parser.py (+66, -28)
  4. test/models/test_biaffine_parser.py (+82, -0)

fastNLP/core/dataset.py (+2, -2)

@@ -254,8 +254,6 @@ class DataSet(object):
         :return results: if new_field_name is not passed, returned values of the function over all instances.
         """
         results = [func(ins) for ins in self._inner_iter()]
-        if len(list(filter(lambda x: x is not None, results))) == 0: # all None
-            raise ValueError("{} always return None.".format(get_func_signature(func=func)))
 
         extra_param = {}
         if 'is_input' in kwargs:
@@ -263,6 +261,8 @@ class DataSet(object):
         if 'is_target' in kwargs:
             extra_param['is_target'] = kwargs['is_target']
         if new_field_name is not None:
+            if len(list(filter(lambda x: x is not None, results))) == 0: # all None
+                raise ValueError("{} always return None.".format(get_func_signature(func=func)))
             if new_field_name in self.field_arrays:
                 # overwrite the field, keep same attributes
                 old_field = self.field_arrays[new_field_name]

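Note on this change: the all-None check on `results` now runs only inside the `new_field_name is not None` branch, so calling `apply` purely for side effects no longer raises. A minimal sketch of the two call patterns (the data values are illustrative):

    from fastNLP import DataSet, Instance

    ds = DataSet()
    ds.append(Instance(word_seq=['the', 'new', 'rate']))

    # Side-effect only: every call returns None, which is now accepted.
    ds.apply(lambda ins: print(ins['word_seq']))

    # Writing results back as a field: all-None would still raise ValueError.
    ds.apply(lambda ins: len(ins['word_seq']), new_field_name='seq_lens')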

fastNLP/io/embed_loader.py (+11, -2)

@@ -74,10 +74,18 @@ class EmbedLoader(BaseLoader):
 
     @staticmethod
    def parse_glove_line(line):
-        line = list(filter(lambda w: len(w) > 0, line.strip().split(" ")))
+        line = line.split()
         if len(line) <= 2:
             raise RuntimeError("something goes wrong in parsing glove embedding")
-        return line[0], torch.Tensor(list(map(float, line[1:])))
+        return line[0], line[1:]
+
+    @staticmethod
+    def str_list_2_vec(line):
+        try:
+            return torch.Tensor(list(map(float, line)))
+        except Exception:
+            raise RuntimeError("something goes wrong in parsing glove embedding")
+
 
     @staticmethod
     def fast_load_embedding(emb_dim, emb_file, vocab):
@@ -98,6 +106,7 @@ class EmbedLoader(BaseLoader):
             for line in f:
                 word, vector = EmbedLoader.parse_glove_line(line)
                 if word in vocab:
+                    vector = EmbedLoader.str_list_2_vec(vector)
                     if len(vector.shape) > 1 or emb_dim != vector.shape[0]:
                         raise ValueError("Pre-trained embedding dim is {}. Expect {}.".format(vector.shape, (emb_dim,)))
                     embedding_matrix[vocab[word]] = vector

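The optimization here splits line parsing in two: `parse_glove_line` now only tokenizes, and the costly float conversion in `str_list_2_vec` is deferred until the word is known to be in the vocabulary, so out-of-vocabulary lines are skipped cheaply. A rough sketch of the resulting flow (the sample line and the dict standing in for a Vocabulary are made up):

    import torch
    from fastNLP.io.embed_loader import EmbedLoader

    vocab = {'the': 0}  # stand-in for a fastNLP Vocabulary
    line = "the 0.1 0.2 0.3"
    word, str_vec = EmbedLoader.parse_glove_line(line)  # cheap: split only
    if word in vocab:  # most lines in a large GloVe file fail this test
        vec = EmbedLoader.str_list_2_vec(str_vec)  # costly: float parsing + Tensor
        assert vec.shape == (3,)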

fastNLP/models/biaffine_parser.py (+66, -28)

@@ -1,5 +1,3 @@
-import sys, os
-sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
 import copy
 import numpy as np
 import torch
@@ -11,6 +9,9 @@ from fastNLP.modules.encoder.variational_rnn import VarLSTM
 from fastNLP.modules.dropout import TimestepDropout
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules.utils import seq_mask
+from fastNLP.core.losses import LossFunc
+from fastNLP.core.metrics import MetricBase
+from fastNLP.core.utils import seq_lens_to_masks
 
 def mst(scores):
     """
@@ -121,9 +122,6 @@ class GraphParser(BaseModel):
     def __init__(self):
         super(GraphParser, self).__init__()
 
-    def forward(self, x):
-        raise NotImplementedError
-
     def _greedy_decoder(self, arc_matrix, mask=None):
         _, seq_len, _ = arc_matrix.shape
         matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
@@ -202,14 +200,14 @@ class BiaffineParser(GraphParser):
                  word_emb_dim,
                  pos_vocab_size,
                  pos_emb_dim,
-                 word_hid_dim,
-                 pos_hid_dim,
-                 rnn_layers,
-                 rnn_hidden_size,
-                 arc_mlp_size,
-                 label_mlp_size,
-                 num_label,
-                 dropout,
+                 num_label,
+                 word_hid_dim=100,
+                 pos_hid_dim=100,
+                 rnn_layers=1,
+                 rnn_hidden_size=200,
+                 arc_mlp_size=100,
+                 label_mlp_size=100,
+                 dropout=0.3,
                  use_var_lstm=False,
                  use_greedy_infer=False):
 
@@ -267,11 +265,11 @@ class BiaffineParser(GraphParser):
             for p in m.parameters():
                 nn.init.normal_(p, 0, 0.1)
 
-    def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
+    def forward(self, word_seq, pos_seq, seq_lens, gold_heads=None):
         """
         :param word_seq: [batch_size, seq_len] sequence of word's indices
         :param pos_seq: [batch_size, seq_len] sequence of word's indices
-        :param word_seq_origin_len: [batch_size, seq_len] sequence of length masks
+        :param seq_lens: [batch_size, seq_len] sequence of length masks
         :param gold_heads: [batch_size, seq_len] sequence of golden heads
         :return dict: parsing results
             arc_pred: [batch_size, seq_len, seq_len]
@@ -283,12 +281,12 @@ class BiaffineParser(GraphParser):
         device = self.parameters().__next__().device
         word_seq = word_seq.long().to(device)
         pos_seq = pos_seq.long().to(device)
-        word_seq_origin_len = word_seq_origin_len.long().to(device).view(-1)
+        seq_lens = seq_lens.long().to(device).view(-1)
         batch_size, seq_len = word_seq.shape
         # print('forward {} {}'.format(batch_size, seq_len))
 
         # get sequence mask
-        mask = seq_mask(word_seq_origin_len, seq_len).long()
+        mask = seq_mask(seq_lens, seq_len).long()
 
         word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0]
         pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1]
@@ -298,7 +296,7 @@ class BiaffineParser(GraphParser):
         del word, pos
 
         # lstm, extract features
-        sort_lens, sort_idx = torch.sort(word_seq_origin_len, dim=0, descending=True)
+        sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
         x = x[sort_idx]
         x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
         feat, _ = self.lstm(x) # -> [N,L,C]
@@ -342,14 +340,15 @@ class BiaffineParser(GraphParser):
             res_dict['head_pred'] = head_pred
         return res_dict
 
-    def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_):
+    @staticmethod
+    def loss(arc_pred, label_pred, arc_true, label_true, mask):
         """
         Compute loss.
 
         :param arc_pred: [batch_size, seq_len, seq_len]
         :param label_pred: [batch_size, seq_len, n_tags]
-        :param head_indices: [batch_size, seq_len]
-        :param head_labels: [batch_size, seq_len]
+        :param arc_true: [batch_size, seq_len]
+        :param label_true: [batch_size, seq_len]
         :param mask: [batch_size, seq_len]
         :return: loss value
         """
@@ -362,8 +361,8 @@ class BiaffineParser(GraphParser):
         label_logits = F.log_softmax(label_pred, dim=2)
         batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
         child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
-        arc_loss = arc_logits[batch_index, child_index, head_indices]
-        label_loss = label_logits[batch_index, child_index, head_labels]
+        arc_loss = arc_logits[batch_index, child_index, arc_true]
+        label_loss = label_logits[batch_index, child_index, label_true]
 
         arc_loss = arc_loss[:, 1:]
         label_loss = label_loss[:, 1:]
@@ -373,19 +372,58 @@ class BiaffineParser(GraphParser):
         label_nll = -(label_loss*float_mask).mean()
         return arc_nll + label_nll
 
-    def predict(self, word_seq, pos_seq, word_seq_origin_len):
+    def predict(self, word_seq, pos_seq, seq_lens):
         """
 
         :param word_seq:
         :param pos_seq:
-        :param word_seq_origin_len:
-        :return: head_pred: [B, L]
+        :param seq_lens:
+        :return: arc_pred: [B, L]
                  label_pred: [B, L]
                  seq_len: [B,]
         """
-        res = self(word_seq, pos_seq, word_seq_origin_len)
+        res = self(word_seq, pos_seq, seq_lens)
         output = {}
-        output['head_pred'] = res.pop('head_pred')
+        output['arc_pred'] = res.pop('head_pred')
         _, label_pred = res.pop('label_pred').max(2)
         output['label_pred'] = label_pred
         return output
+
+
+class ParserLoss(LossFunc):
+    def __init__(self, arc_pred=None, label_pred=None, arc_true=None, label_true=None):
+        super(ParserLoss, self).__init__(BiaffineParser.loss,
+                                         arc_pred=arc_pred,
+                                         label_pred=label_pred,
+                                         arc_true=arc_true,
+                                         label_true=label_true)
+
+
+class ParserMetric(MetricBase):
+    def __init__(self, arc_pred=None, label_pred=None,
+                 arc_true=None, label_true=None, seq_lens=None):
+        super().__init__()
+        self._init_param_map(arc_pred=arc_pred, label_pred=label_pred,
+                             arc_true=arc_true, label_true=label_true,
+                             seq_lens=seq_lens)
+        self.num_arc = 0
+        self.num_label = 0
+        self.num_sample = 0
+
+    def get_metric(self, reset=True):
+        res = {'UAS': self.num_arc*1.0 / self.num_sample, 'LAS': self.num_label*1.0 / self.num_sample}
+        if reset:
+            self.num_sample = self.num_label = self.num_arc = 0
+        return res
+
+    def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_lens=None):
+        """Evaluate the performance of prediction.
+        """
+        if seq_lens is None:
+            seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long)
+        else:
+            seq_mask = seq_lens_to_masks(seq_lens, float=False).long()
+        head_pred_correct = (arc_pred == arc_true).long() * seq_mask
+        label_pred_correct = (label_pred == label_true).long() * head_pred_correct
+        self.num_arc += head_pred_correct.sum().item()
+        self.num_label += label_pred_correct.sum().item()
+        self.num_sample += seq_mask.sum().item()

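In the renamed loss, the advanced-indexing trick deserves a note: `arc_logits[batch_index, child_index, arc_true]` picks, for each batch element and each child position, the log-probability assigned to that child's gold head. A toy sketch of the same gather (shapes and values are illustrative):

    import torch

    batch_size, seq_len = 2, 3
    arc_logits = torch.randn(batch_size, seq_len, seq_len).log_softmax(dim=2)
    arc_true = torch.randint(seq_len, (batch_size, seq_len))  # gold head per child
    batch_index = torch.arange(batch_size).unsqueeze(1)  # [B, 1]
    child_index = torch.arange(seq_len).unsqueeze(0)     # [1, L]
    arc_loss = arc_logits[batch_index, child_index, arc_true]  # broadcasts to [B, L]
    assert arc_loss.shape == (batch_size, seq_len)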
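ParserMetric accumulates the two standard parsing scores: a token counts toward UAS when its predicted head matches the gold head, and toward LAS when the label matches as well (a correct label on a wrong head does not count, hence the multiplication by head_pred_correct). A standalone check with hand-made tensors (values are illustrative only):

    import torch
    from fastNLP.models.biaffine_parser import ParserMetric

    metric = ParserMetric()
    arc_pred   = torch.tensor([[0, 2, 0]])  # predicted head per token
    arc_true   = torch.tensor([[0, 2, 1]])  # heads correct at positions 0 and 1
    label_pred = torch.tensor([[0, 6, 3]])
    label_true = torch.tensor([[0, 5, 3]])  # labels correct at positions 0 and 2
    seq_lens = torch.tensor([3])
    metric.evaluate(arc_pred, label_pred, arc_true, label_true, seq_lens)
    print(metric.get_metric())  # {'UAS': 0.666..., 'LAS': 0.333...}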
test/models/test_biaffine_parser.py (+82, -0)

@@ -0,0 +1,82 @@
+from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric
+import fastNLP
+
+import unittest
+
+data_file = """
+1 The _ DET DT _ 3 det _ _
+2 new _ ADJ JJ _ 3 amod _ _
+3 rate _ NOUN NN _ 6 nsubj _ _
+4 will _ AUX MD _ 6 aux _ _
+5 be _ VERB VB _ 6 cop _ _
+6 payable _ ADJ JJ _ 0 root _ _
+9 cents _ NOUN NNS _ 4 nmod _ _
+10 from _ ADP IN _ 12 case _ _
+11 seven _ NUM CD _ 12 nummod _ _
+12 cents _ NOUN NNS _ 4 nmod _ _
+13 a _ DET DT _ 14 det _ _
+14 share _ NOUN NN _ 12 nmod:npmod _ _
+15 . _ PUNCT . _ 4 punct _ _
+
+1 The _ DET DT _ 3 det _ _
+2 new _ ADJ JJ _ 3 amod _ _
+3 rate _ NOUN NN _ 6 nsubj _ _
+4 will _ AUX MD _ 6 aux _ _
+5 be _ VERB VB _ 6 cop _ _
+6 payable _ ADJ JJ _ 0 root _ _
+7 Feb. _ PROPN NNP _ 6 nmod:tmod _ _
+8 15 _ NUM CD _ 7 nummod _ _
+9 . _ PUNCT . _ 6 punct _ _
+
+1 A _ DET DT _ 3 det _ _
+2 record _ NOUN NN _ 3 compound _ _
+3 date _ NOUN NN _ 7 nsubjpass _ _
+4 has _ AUX VBZ _ 7 aux _ _
+5 n't _ PART RB _ 7 neg _ _
+6 been _ AUX VBN _ 7 auxpass _ _
+7 set _ VERB VBN _ 0 root _ _
+8 . _ PUNCT . _ 7 punct _ _
+
+"""
+
+def init_data():
+    ds = fastNLP.DataSet()
+    v = {'word_seq': fastNLP.Vocabulary(),
+         'pos_seq': fastNLP.Vocabulary(),
+         'label_true': fastNLP.Vocabulary()}
+    data = []
+    for line in data_file.split('\n'):
+        line = line.split()
+        if len(line) == 0 and len(data) > 0:
+            data = list(zip(*data))
+            ds.append(fastNLP.Instance(word_seq=data[1],
+                                       pos_seq=data[4],
+                                       arc_true=data[6],
+                                       label_true=data[7]))
+            data = []
+        elif len(line) > 0:
+            data.append(line)
+
+    for name in ['word_seq', 'pos_seq', 'label_true']:
+        ds.apply(lambda x: ['<st>']+list(x[name])+['<ed>'], new_field_name=name)
+        ds.apply(lambda x: v[name].add_word_lst(x[name]))
+
+    for name in ['word_seq', 'pos_seq', 'label_true']:
+        ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name)
+
+    ds.apply(lambda x: [0]+list(map(int, x['arc_true']))+[1], new_field_name='arc_true')
+    ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens')
+    ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True)
+    ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True)
+    return ds, v['word_seq'], v['pos_seq'], v['label_true']
+
+class TestBiaffineParser(unittest.TestCase):
+    def test_train(self):
+        ds, v1, v2, v3 = init_data()
+        model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30,
+                               pos_vocab_size=len(v2), pos_emb_dim=30,
+                               num_label=len(v3))
+        trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds,
+                                  loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
+                                  n_epochs=10, use_cuda=False, use_tqdm=False)
+        trainer.train(load_best_model=False)