|
|
@@ -1,5 +1,3 @@ |
|
|
|
import sys, os |
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) |
|
|
|
import copy |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
@@ -11,6 +9,9 @@ from fastNLP.modules.encoder.variational_rnn import VarLSTM |
|
|
|
from fastNLP.modules.dropout import TimestepDropout |
|
|
|
from fastNLP.models.base_model import BaseModel |
|
|
|
from fastNLP.modules.utils import seq_mask |
|
|
|
from fastNLP.core.losses import LossFunc |
|
|
|
from fastNLP.core.metrics import MetricBase |
|
|
|
from fastNLP.core.utils import seq_lens_to_masks |
|
|
|
|
|
|
|
def mst(scores): |
|
|
|
""" |
|
|
@@ -121,9 +122,6 @@ class GraphParser(BaseModel): |
|
|
|
def __init__(self): |
|
|
|
super(GraphParser, self).__init__() |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def _greedy_decoder(self, arc_matrix, mask=None): |
|
|
|
_, seq_len, _ = arc_matrix.shape |
|
|
|
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) |
|
|
@@ -202,14 +200,14 @@ class BiaffineParser(GraphParser): |
|
|
|
word_emb_dim, |
|
|
|
pos_vocab_size, |
|
|
|
pos_emb_dim, |
|
|
|
word_hid_dim, |
|
|
|
pos_hid_dim, |
|
|
|
rnn_layers, |
|
|
|
rnn_hidden_size, |
|
|
|
arc_mlp_size, |
|
|
|
label_mlp_size, |
|
|
|
num_label, |
|
|
|
dropout, |
|
|
|
word_hid_dim=100, |
|
|
|
pos_hid_dim=100, |
|
|
|
rnn_layers=1, |
|
|
|
rnn_hidden_size=200, |
|
|
|
arc_mlp_size=100, |
|
|
|
label_mlp_size=100, |
|
|
|
dropout=0.3, |
|
|
|
use_var_lstm=False, |
|
|
|
use_greedy_infer=False): |
|
|
|
|
|
|
@@ -267,11 +265,11 @@ class BiaffineParser(GraphParser): |
|
|
|
for p in m.parameters(): |
|
|
|
nn.init.normal_(p, 0, 0.1) |
|
|
|
|
|
|
|
def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_): |
|
|
|
def forward(self, word_seq, pos_seq, seq_lens, gold_heads=None): |
|
|
|
""" |
|
|
|
:param word_seq: [batch_size, seq_len] sequence of word's indices |
|
|
|
:param pos_seq: [batch_size, seq_len] sequence of word's indices |
|
|
|
:param word_seq_origin_len: [batch_size, seq_len] sequence of length masks |
|
|
|
:param seq_lens: [batch_size, seq_len] sequence of length masks |
|
|
|
:param gold_heads: [batch_size, seq_len] sequence of golden heads |
|
|
|
:return dict: parsing results |
|
|
|
arc_pred: [batch_size, seq_len, seq_len] |
|
|
@@ -283,12 +281,12 @@ class BiaffineParser(GraphParser): |
|
|
|
device = self.parameters().__next__().device |
|
|
|
word_seq = word_seq.long().to(device) |
|
|
|
pos_seq = pos_seq.long().to(device) |
|
|
|
word_seq_origin_len = word_seq_origin_len.long().to(device).view(-1) |
|
|
|
seq_lens = seq_lens.long().to(device).view(-1) |
|
|
|
batch_size, seq_len = word_seq.shape |
|
|
|
# print('forward {} {}'.format(batch_size, seq_len)) |
|
|
|
|
|
|
|
# get sequence mask |
|
|
|
mask = seq_mask(word_seq_origin_len, seq_len).long() |
|
|
|
mask = seq_mask(seq_lens, seq_len).long() |
|
|
|
|
|
|
|
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] |
|
|
|
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] |
|
|
@@ -298,7 +296,7 @@ class BiaffineParser(GraphParser): |
|
|
|
del word, pos |
|
|
|
|
|
|
|
# lstm, extract features |
|
|
|
sort_lens, sort_idx = torch.sort(word_seq_origin_len, dim=0, descending=True) |
|
|
|
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) |
|
|
|
x = x[sort_idx] |
|
|
|
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) |
|
|
|
feat, _ = self.lstm(x) # -> [N,L,C] |
|
|
@@ -342,14 +340,15 @@ class BiaffineParser(GraphParser): |
|
|
|
res_dict['head_pred'] = head_pred |
|
|
|
return res_dict |
|
|
|
|
|
|
|
def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_): |
|
|
|
@staticmethod |
|
|
|
def loss(arc_pred, label_pred, arc_true, label_true, mask): |
|
|
|
""" |
|
|
|
Compute loss. |
|
|
|
|
|
|
|
:param arc_pred: [batch_size, seq_len, seq_len] |
|
|
|
:param label_pred: [batch_size, seq_len, n_tags] |
|
|
|
:param head_indices: [batch_size, seq_len] |
|
|
|
:param head_labels: [batch_size, seq_len] |
|
|
|
:param arc_true: [batch_size, seq_len] |
|
|
|
:param label_true: [batch_size, seq_len] |
|
|
|
:param mask: [batch_size, seq_len] |
|
|
|
:return: loss value |
|
|
|
""" |
|
|
@@ -362,8 +361,8 @@ class BiaffineParser(GraphParser): |
|
|
|
label_logits = F.log_softmax(label_pred, dim=2) |
|
|
|
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) |
|
|
|
child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0) |
|
|
|
arc_loss = arc_logits[batch_index, child_index, head_indices] |
|
|
|
label_loss = label_logits[batch_index, child_index, head_labels] |
|
|
|
arc_loss = arc_logits[batch_index, child_index, arc_true] |
|
|
|
label_loss = label_logits[batch_index, child_index, label_true] |
|
|
|
|
|
|
|
arc_loss = arc_loss[:, 1:] |
|
|
|
label_loss = label_loss[:, 1:] |
|
|
@@ -373,19 +372,58 @@ class BiaffineParser(GraphParser): |
|
|
|
label_nll = -(label_loss*float_mask).mean() |
|
|
|
return arc_nll + label_nll |
|
|
|
|
|
|
|
def predict(self, word_seq, pos_seq, word_seq_origin_len): |
|
|
|
def predict(self, word_seq, pos_seq, seq_lens): |
|
|
|
""" |
|
|
|
|
|
|
|
:param word_seq: |
|
|
|
:param pos_seq: |
|
|
|
:param word_seq_origin_len: |
|
|
|
:return: head_pred: [B, L] |
|
|
|
:param seq_lens: |
|
|
|
:return: arc_pred: [B, L] |
|
|
|
label_pred: [B, L] |
|
|
|
seq_len: [B,] |
|
|
|
""" |
|
|
|
res = self(word_seq, pos_seq, word_seq_origin_len) |
|
|
|
res = self(word_seq, pos_seq, seq_lens) |
|
|
|
output = {} |
|
|
|
output['head_pred'] = res.pop('head_pred') |
|
|
|
output['arc_pred'] = res.pop('head_pred') |
|
|
|
_, label_pred = res.pop('label_pred').max(2) |
|
|
|
output['label_pred'] = label_pred |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
class ParserLoss(LossFunc):
    """Adapter exposing :meth:`BiaffineParser.loss` through the LossFunc API.

    Each keyword names the dataset field to feed into the corresponding
    argument of the wrapped loss function; ``None`` keeps the default
    field name.
    """

    def __init__(self, arc_pred=None, label_pred=None, arc_true=None, label_true=None):
        super(ParserLoss, self).__init__(
            BiaffineParser.loss,
            arc_pred=arc_pred,
            label_pred=label_pred,
            arc_true=arc_true,
            label_true=label_true,
        )
|
|
|
|
|
|
|
|
|
|
|
class ParserMetric(MetricBase):
    """Accumulates attachment statistics (UAS / LAS) for dependency parsing.

    Fix: the local variable ``seq_mask`` in :meth:`evaluate` shadowed the
    module-level ``seq_mask`` helper imported at the top of the file; it is
    renamed to ``mask``. Behavior is otherwise unchanged.
    """

    def __init__(self, arc_pred=None, label_pred=None,
                 arc_true=None, label_true=None, seq_lens=None):
        super().__init__()
        self._init_param_map(arc_pred=arc_pred, label_pred=label_pred,
                             arc_true=arc_true, label_true=label_true,
                             seq_lens=seq_lens)
        # running counts: correct heads, correct (head, label) pairs,
        # and total scored tokens
        self.num_arc = 0
        self.num_label = 0
        self.num_sample = 0

    def get_metric(self, reset=True):
        """Return ``{'UAS': ..., 'LAS': ...}``.

        :param reset: when True, zero the accumulated counters afterwards.
        :return dict: unlabeled / labeled attachment scores.
            NOTE(review): raises ZeroDivisionError if called before any
            ``evaluate`` — preserved from the original behavior.
        """
        res = {'UAS': self.num_arc * 1.0 / self.num_sample,
               'LAS': self.num_label * 1.0 / self.num_sample}
        if reset:
            self.num_sample = self.num_label = self.num_arc = 0
        return res

    def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_lens=None):
        """Evaluate the performance of prediction.

        :param arc_pred: [batch_size, seq_len] predicted head indices
        :param label_pred: [batch_size, seq_len] predicted label indices
        :param arc_true: [batch_size, seq_len] gold head indices
        :param label_true: [batch_size, seq_len] gold label indices
        :param seq_lens: [batch_size,] true lengths, or None to score
            every position.
        """
        if seq_lens is None:
            # no lengths given: count every position
            mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long)
        else:
            mask = seq_lens_to_masks(seq_lens, float=False).long()
        # a label only counts as correct when its head is also correct
        head_pred_correct = (arc_pred == arc_true).long() * mask
        label_pred_correct = (label_pred == label_true).long() * head_pred_correct
        self.num_arc += head_pred_correct.sum().item()
        self.num_label += label_pred_correct.sum().item()
        self.num_sample += mask.sum().item()