@@ -7,16 +7,17 @@ import torch | |||
from torch import nn | |||
from torch.nn import functional as F | |||
from fastNLP.core.losses import LossFunc | |||
from fastNLP.core.metrics import MetricBase | |||
from fastNLP.core.utils import seq_lens_to_masks | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.modules.dropout import TimestepDropout | |||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||
from fastNLP.modules.utils import initial_parameter | |||
from fastNLP.modules.utils import seq_mask | |||
from fastNLP.modules.utils import get_embeddings | |||
from ..core.const import Const as C | |||
from ..core.losses import LossFunc | |||
from ..core.metrics import MetricBase | |||
from ..core.utils import seq_lens_to_masks | |||
from ..modules.dropout import TimestepDropout | |||
from ..modules.encoder.transformer import TransformerEncoder | |||
from ..modules.encoder.variational_rnn import VarLSTM | |||
from ..modules.utils import initial_parameter | |||
from ..modules.utils import seq_mask | |||
from ..modules.utils import get_embeddings | |||
from .base_model import BaseModel | |||
def _mst(scores): | |||
""" | |||
@@ -325,21 +326,20 @@ class BiaffineParser(GraphParser): | |||
for p in m.parameters(): | |||
nn.init.normal_(p, 0, 0.1) | |||
def forward(self, words1, words2, seq_len, gold_heads=None): | |||
def forward(self, words1, words2, seq_len, target1=None): | |||
"""模型forward阶段 | |||
:param words1: [batch_size, seq_len] 输入word序列 | |||
:param words2: [batch_size, seq_len] 输入pos序列 | |||
:param seq_len: [batch_size, seq_len] 输入序列长度 | |||
:param gold_heads: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效, | |||
:param target1: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效, | |||
用于训练label分类器. 若为 ``None`` , 使用预测的heads输入到label分类器 | |||
Default: ``None`` | |||
:return dict: parsing结果:: | |||
arc_pred: [batch_size, seq_len, seq_len] 边预测logits | |||
label_pred: [batch_size, seq_len, num_label] label预测logits | |||
mask: [batch_size, seq_len] 预测结果的mask | |||
head_pred: [batch_size, seq_len] heads的预测结果, 在 ``gold_heads=None`` 时预测 | |||
pred1: [batch_size, seq_len, seq_len] 边预测logits | |||
pred2: [batch_size, seq_len, num_label] label预测logits | |||
pred3: [batch_size, seq_len] heads的预测结果, 在 ``target1=None`` 时预测 | |||
""" | |||
# prepare embeddings | |||
batch_size, length = words1.shape | |||
@@ -365,7 +365,7 @@ class BiaffineParser(GraphParser): | |||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||
feat = feat[unsort_idx] | |||
else: | |||
seq_range = torch.arange(seq_len, dtype=torch.long, device=x.device)[None,:] | |||
seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None,:] | |||
x = x + self.position_emb(seq_range) | |||
feat = self.encoder(x, mask.float()) | |||
@@ -380,7 +380,7 @@ class BiaffineParser(GraphParser): | |||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | |||
# use gold or predicted arc to predict label | |||
if gold_heads is None or not self.training: | |||
if target1 is None or not self.training: | |||
# use greedy decoding in training | |||
if self.training or self.use_greedy_infer: | |||
heads = self.greedy_decoder(arc_pred, mask) | |||
@@ -389,44 +389,45 @@ class BiaffineParser(GraphParser): | |||
head_pred = heads | |||
else: | |||
assert self.training # must be training mode | |||
if gold_heads is None: | |||
if target1 is None: | |||
heads = self.greedy_decoder(arc_pred, mask) | |||
head_pred = heads | |||
else: | |||
head_pred = None | |||
heads = gold_heads | |||
heads = target1 | |||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1) | |||
label_head = label_head[batch_range, heads].contiguous() | |||
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | |||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} | |||
res_dict = {C.OUTPUTS(0): arc_pred, C.OUTPUTS(1): label_pred} | |||
if head_pred is not None: | |||
res_dict['head_pred'] = head_pred | |||
res_dict[C.OUTPUTS(2)] = head_pred | |||
return res_dict | |||
@staticmethod | |||
def loss(arc_pred, label_pred, arc_true, label_true, mask): | |||
def loss(pred1, pred2, target1, target2, seq_len): | |||
""" | |||
Compute loss. | |||
:param arc_pred: [batch_size, seq_len, seq_len] 边预测logits | |||
:param label_pred: [batch_size, seq_len, num_label] label预测logits | |||
:param arc_true: [batch_size, seq_len] 真实边的标注 | |||
:param label_true: [batch_size, seq_len] 真实类别的标注 | |||
:param mask: [batch_size, seq_len] 预测结果的mask | |||
:return: loss value | |||
计算parser的loss | |||
:param pred1: [batch_size, seq_len, seq_len] 边预测logits | |||
:param pred2: [batch_size, seq_len, num_label] label预测logits | |||
:param target1: [batch_size, seq_len] 真实边的标注 | |||
:param target2: [batch_size, seq_len] 真实类别的标注 | |||
:param seq_len: [batch_size, seq_len] 真实目标的长度 | |||
:return loss: scalar | |||
""" | |||
batch_size, seq_len, _ = arc_pred.shape | |||
batch_size, length, _ = pred1.shape | |||
mask = seq_mask(seq_len, length) | |||
flip_mask = (mask == 0) | |||
_arc_pred = arc_pred.clone() | |||
_arc_pred = pred1.clone() | |||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) | |||
arc_logits = F.log_softmax(_arc_pred, dim=2) | |||
label_logits = F.log_softmax(label_pred, dim=2) | |||
label_logits = F.log_softmax(pred2, dim=2) | |||
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) | |||
child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0) | |||
arc_loss = arc_logits[batch_index, child_index, arc_true] | |||
label_loss = label_logits[batch_index, child_index, label_true] | |||
child_index = torch.arange(length, device=arc_logits.device, dtype=torch.long).unsqueeze(0) | |||
arc_loss = arc_logits[batch_index, child_index, target1] | |||
label_loss = label_logits[batch_index, child_index, target2] | |||
byte_mask = flip_mask.byte() | |||
arc_loss.masked_fill_(byte_mask, 0) | |||
@@ -441,21 +442,16 @@ class BiaffineParser(GraphParser): | |||
:param words1: [batch_size, seq_len] 输入word序列 | |||
:param words2: [batch_size, seq_len] 输入pos序列 | |||
:param seq_len: [batch_size, seq_len] 输入序列长度 | |||
:param gold_heads: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效, | |||
用于训练label分类器. 若为 ``None`` , 使用预测的heads输入到label分类器 | |||
Default: ``None`` | |||
:return dict: parsing结果:: | |||
arc_pred: [batch_size, seq_len, seq_len] 边预测logits | |||
label_pred: [batch_size, seq_len, num_label] label预测logits | |||
mask: [batch_size, seq_len] 预测结果的mask | |||
head_pred: [batch_size, seq_len] heads的预测结果, 在 ``gold_heads=None`` 时预测 | |||
pred1: [batch_size, seq_len] heads的预测结果 | |||
pred2: [batch_size, seq_len, num_label] label预测logits | |||
""" | |||
res = self(words1, words2, seq_len) | |||
output = {} | |||
output['arc_pred'] = res.pop('head_pred') | |||
_, label_pred = res.pop('label_pred').max(2) | |||
output['label_pred'] = label_pred | |||
output[C.OUTPUTS(0)] = res.pop(C.OUTPUTS(2)) | |||
_, label_pred = res.pop(C.OUTPUTS(1)).max(2) | |||
output[C.OUTPUTS(1)] = label_pred | |||
return output | |||
@@ -463,41 +459,44 @@ class ParserLoss(LossFunc): | |||
""" | |||
计算parser的loss | |||
:param arc_pred: [batch_size, seq_len, seq_len] 边预测logits | |||
:param label_pred: [batch_size, seq_len, num_label] label预测logits | |||
:param arc_true: [batch_size, seq_len] 真实边的标注 | |||
:param label_true: [batch_size, seq_len] 真实类别的标注 | |||
:param mask: [batch_size, seq_len] 预测结果的mask | |||
:param pred1: [batch_size, seq_len, seq_len] 边预测logits | |||
:param pred2: [batch_size, seq_len, num_label] label预测logits | |||
:param target1: [batch_size, seq_len] 真实边的标注 | |||
:param target2: [batch_size, seq_len] 真实类别的标注 | |||
:param seq_len: [batch_size, seq_len] 真实目标的长度 | |||
:return loss: scalar | |||
""" | |||
def __init__(self, arc_pred=None, label_pred=None, arc_true=None, label_true=None): | |||
def __init__(self, pred1=None, pred2=None, | |||
target1=None, target2=None, | |||
seq_len=None): | |||
super(ParserLoss, self).__init__(BiaffineParser.loss, | |||
arc_pred=arc_pred, | |||
label_pred=label_pred, | |||
arc_true=arc_true, | |||
label_true=label_true) | |||
pred1=pred1, | |||
pred2=pred2, | |||
target1=target1, | |||
target2=target2, | |||
seq_len=seq_len) | |||
class ParserMetric(MetricBase): | |||
""" | |||
评估parser的性能 | |||
:param arc_pred: 边预测logits | |||
:param label_pred: label预测logits | |||
:param arc_true: 真实边的标注 | |||
:param label_true: 真实类别的标注 | |||
:param pred1: 边预测logits | |||
:param pred2: label预测logits | |||
:param target1: 真实边的标注 | |||
:param target2: 真实类别的标注 | |||
:param seq_len: 序列长度 | |||
:return dict: 评估结果:: | |||
UAS: 不带label时, 边预测的准确率 | |||
LAS: 同时预测边和label的准确率 | |||
""" | |||
def __init__(self, arc_pred=None, label_pred=None, | |||
arc_true=None, label_true=None, seq_len=None): | |||
def __init__(self, pred1=None, pred2=None, | |||
target1=None, target2=None, seq_len=None): | |||
super().__init__() | |||
self._init_param_map(arc_pred=arc_pred, label_pred=label_pred, | |||
arc_true=arc_true, label_true=label_true, | |||
self._init_param_map(pred1=pred1, pred2=pred2, | |||
target1=target1, target2=target2, | |||
seq_len=seq_len) | |||
self.num_arc = 0 | |||
self.num_label = 0 | |||
@@ -509,17 +508,17 @@ class ParserMetric(MetricBase): | |||
self.num_sample = self.num_label = self.num_arc = 0 | |||
return res | |||
def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_len=None): | |||
def evaluate(self, pred1, pred2, target1, target2, seq_len=None): | |||
"""Evaluate the performance of prediction. | |||
""" | |||
if seq_len is None: | |||
seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long) | |||
seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long) | |||
else: | |||
seq_mask = seq_lens_to_masks(seq_len.long(), float=False).long() | |||
# mask out <root> tag | |||
seq_mask[:,0] = 0 | |||
head_pred_correct = (arc_pred == arc_true).long() * seq_mask | |||
label_pred_correct = (label_pred == label_true).long() * head_pred_correct | |||
head_pred_correct = (pred1 == target1).long() * seq_mask | |||
label_pred_correct = (pred2 == target2).long() * head_pred_correct | |||
self.num_arc += head_pred_correct.sum().item() | |||
self.num_label += label_pred_correct.sum().item() | |||
self.num_sample += seq_mask.sum().item() |
@@ -108,7 +108,7 @@ class STSeqLabel(nn.Module): | |||
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1 | |||
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1 | |||
""" | |||
def __init__(self, vocab_size, emb_dim, num_cls, | |||
def __init__(self, init_embed, num_cls, | |||
hidden_size=300, | |||
num_layers=4, | |||
num_head=8, | |||
@@ -118,8 +118,7 @@ class STSeqLabel(nn.Module): | |||
emb_dropout=0.1, | |||
dropout=0.1,): | |||
super(STSeqLabel, self).__init__() | |||
self.enc = StarTransEnc(vocab_size=vocab_size, | |||
emb_dim=emb_dim, | |||
self.enc = StarTransEnc(init_embed=init_embed, | |||
hidden_size=hidden_size, | |||
num_layers=num_layers, | |||
num_head=num_head, | |||
@@ -170,7 +169,7 @@ class STSeqCls(nn.Module): | |||
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1 | |||
""" | |||
def __init__(self, vocab_size, emb_dim, num_cls, | |||
def __init__(self, init_embed, num_cls, | |||
hidden_size=300, | |||
num_layers=4, | |||
num_head=8, | |||
@@ -180,8 +179,7 @@ class STSeqCls(nn.Module): | |||
emb_dropout=0.1, | |||
dropout=0.1,): | |||
super(STSeqCls, self).__init__() | |||
self.enc = StarTransEnc(vocab_size=vocab_size, | |||
emb_dim=emb_dim, | |||
self.enc = StarTransEnc(init_embed=init_embed, | |||
hidden_size=hidden_size, | |||
num_layers=num_layers, | |||
num_head=num_head, | |||
@@ -232,7 +230,7 @@ class STNLICls(nn.Module): | |||
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1 | |||
""" | |||
def __init__(self, vocab_size, emb_dim, num_cls, | |||
def __init__(self, init_embed, num_cls, | |||
hidden_size=300, | |||
num_layers=4, | |||
num_head=8, | |||
@@ -242,8 +240,7 @@ class STNLICls(nn.Module): | |||
emb_dropout=0.1, | |||
dropout=0.1,): | |||
super(STNLICls, self).__init__() | |||
self.enc = StarTransEnc(vocab_size=vocab_size, | |||
emb_dim=emb_dim, | |||
self.enc = StarTransEnc(init_embed=init_embed, | |||
hidden_size=hidden_size, | |||
num_layers=num_layers, | |||
num_head=num_head, | |||
@@ -0,0 +1,151 @@ | |||
""" | |||
此模块可以非常方便的测试模型。 | |||
若你的模型属于:文本分类,序列标注,自然语言推理(NLI),可以直接使用此模块测试 | |||
若模型不属于上述类别,也可以自己准备假数据,设定loss和metric进行测试 | |||
此模块的测试仅保证模型能使用fastNLP进行训练和测试,不测试模型实际性能 | |||
Example:: | |||
# import 全大写变量... | |||
from model_runner import * | |||
# 测试一个文本分类模型 | |||
init_emb = (VOCAB_SIZE, 50) | |||
model = SomeModel(init_emb, num_cls=NUM_CLS) | |||
RUNNER.run_model_with_task(TEXT_CLS, model) | |||
# 序列标注模型 | |||
RUNNER.run_model_with_task(POS_TAGGING, model) | |||
# NLI模型 | |||
RUNNER.run_model_with_task(NLI, model) | |||
# 自定义模型 | |||
RUNNER.run_model(model, data=get_mydata(), | |||
loss=Myloss(), metrics=Mymetric()) | |||
""" | |||
from fastNLP import Trainer, Tester, DataSet | |||
from fastNLP import AccuracyMetric | |||
from fastNLP import CrossEntropyLoss | |||
from fastNLP.core.const import Const as C | |||
from random import randrange | |||
VOCAB_SIZE = 100 | |||
NUM_CLS = 100 | |||
MAX_LEN = 10 | |||
N_SAMPLES = 100 | |||
N_EPOCHS = 1 | |||
BATCH_SIZE = 5 | |||
TEXT_CLS = 'text_cls' | |||
POS_TAGGING = 'pos_tagging' | |||
NLI = 'nli' | |||
class ModelRunner(): | |||
def gen_seq(self, length, vocab_size): | |||
"""generate fake sequence indexes with given length""" | |||
# reserve 0 for padding | |||
return [randrange(1, vocab_size) for _ in range(length)] | |||
def gen_var_seq(self, max_len, vocab_size): | |||
"""generate fake sequence indexes in variant length""" | |||
length = randrange(3, max_len) # at least 3 words in a seq | |||
return self.gen_seq(length, vocab_size) | |||
def prepare_text_classification_data(self): | |||
index = 'index' | |||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
field_name=index, new_field_name=C.INPUT, | |||
is_input=True) | |||
ds.apply_field(lambda x: randrange(NUM_CLS), | |||
field_name=index, new_field_name=C.TARGET, | |||
is_target=True) | |||
ds.apply_field(len, C.INPUT, C.INPUT_LEN, | |||
is_input=True) | |||
return ds | |||
def prepare_pos_tagging_data(self): | |||
index = 'index' | |||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
field_name=index, new_field_name=C.INPUT, | |||
is_input=True) | |||
ds.apply_field(lambda x: self.gen_seq(len(x), NUM_CLS), | |||
field_name=C.INPUT, new_field_name=C.TARGET, | |||
is_target=True) | |||
ds.apply_field(len, C.INPUT, C.INPUT_LEN, | |||
is_input=True, is_target=True) | |||
return ds | |||
def prepare_nli_data(self): | |||
index = 'index' | |||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
field_name=index, new_field_name=C.INPUTS(0), | |||
is_input=True) | |||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
field_name=index, new_field_name=C.INPUTS(1), | |||
is_input=True) | |||
ds.apply_field(lambda x: randrange(NUM_CLS), | |||
field_name=index, new_field_name=C.TARGET, | |||
is_target=True) | |||
ds.apply_field(len, C.INPUTS(0), C.INPUT_LENS(0), | |||
is_input=True, is_target=True) | |||
ds.apply_field(len, C.INPUTS(1), C.INPUT_LENS(1), | |||
is_input = True, is_target = True) | |||
ds.set_input(C.INPUTS(0), C.INPUTS(1)) | |||
ds.set_target(C.TARGET) | |||
return ds | |||
def run_text_classification(self, model, data=None): | |||
if data is None: | |||
data = self.prepare_text_classification_data() | |||
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) | |||
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) | |||
self.run_model(model, data, loss, metric) | |||
def run_pos_tagging(self, model, data=None): | |||
if data is None: | |||
data = self.prepare_pos_tagging_data() | |||
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET, padding_idx=0) | |||
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN) | |||
self.run_model(model, data, loss, metric) | |||
def run_nli(self, model, data=None): | |||
if data is None: | |||
data = self.prepare_nli_data() | |||
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) | |||
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) | |||
self.run_model(model, data, loss, metric) | |||
def run_model(self, model, data, loss, metrics): | |||
"""run a model, test if it can run with fastNLP""" | |||
print('testing model:', model.__class__.__name__) | |||
tester = Tester(data=data, model=model, metrics=metrics, | |||
batch_size=BATCH_SIZE, verbose=0) | |||
before_train = tester.test() | |||
trainer = Trainer(model=model, train_data=data, dev_data=None, | |||
n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, | |||
loss=loss, | |||
save_path=None, | |||
use_tqdm=False) | |||
trainer.train(load_best_model=False) | |||
after_train = tester.test() | |||
for metric_name, v1 in before_train.items(): | |||
assert metric_name in after_train | |||
# # at least we can sure model params changed, even if we don't know performance | |||
# v2 = after_train[metric_name] | |||
# assert v1 != v2 | |||
def run_model_with_task(self, task, model): | |||
"""run a model with certain task""" | |||
TASKS = { | |||
TEXT_CLS: self.run_text_classification, | |||
POS_TAGGING: self.run_pos_tagging, | |||
NLI: self.run_nli, | |||
} | |||
assert task in TASKS | |||
TASKS[task](model) | |||
RUNNER = ModelRunner() |
@@ -2,90 +2,33 @@ import unittest | |||
import fastNLP | |||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | |||
data_file = """ | |||
1 The _ DET DT _ 3 det _ _ | |||
2 new _ ADJ JJ _ 3 amod _ _ | |||
3 rate _ NOUN NN _ 6 nsubj _ _ | |||
4 will _ AUX MD _ 6 aux _ _ | |||
5 be _ VERB VB _ 6 cop _ _ | |||
6 payable _ ADJ JJ _ 0 root _ _ | |||
7 mask _ ADJ JJ _ 6 punct _ _ | |||
8 mask _ ADJ JJ _ 6 punct _ _ | |||
9 cents _ NOUN NNS _ 4 nmod _ _ | |||
10 from _ ADP IN _ 12 case _ _ | |||
11 seven _ NUM CD _ 12 nummod _ _ | |||
12 cents _ NOUN NNS _ 4 nmod _ _ | |||
13 a _ DET DT _ 14 det _ _ | |||
14 share _ NOUN NN _ 12 nmod:npmod _ _ | |||
15 . _ PUNCT . _ 4 punct _ _ | |||
1 The _ DET DT _ 3 det _ _ | |||
2 new _ ADJ JJ _ 3 amod _ _ | |||
3 rate _ NOUN NN _ 6 nsubj _ _ | |||
4 will _ AUX MD _ 6 aux _ _ | |||
5 be _ VERB VB _ 6 cop _ _ | |||
6 payable _ ADJ JJ _ 0 root _ _ | |||
7 Feb. _ PROPN NNP _ 6 nmod:tmod _ _ | |||
8 15 _ NUM CD _ 7 nummod _ _ | |||
9 . _ PUNCT . _ 6 punct _ _ | |||
1 A _ DET DT _ 3 det _ _ | |||
2 record _ NOUN NN _ 3 compound _ _ | |||
3 date _ NOUN NN _ 7 nsubjpass _ _ | |||
4 has _ AUX VBZ _ 7 aux _ _ | |||
5 n't _ PART RB _ 7 neg _ _ | |||
6 been _ AUX VBN _ 7 auxpass _ _ | |||
7 set _ VERB VBN _ 0 root _ _ | |||
8 . _ PUNCT . _ 7 punct _ _ | |||
""" | |||
def init_data(): | |||
ds = fastNLP.DataSet() | |||
v = {'words1': fastNLP.Vocabulary(), | |||
'words2': fastNLP.Vocabulary(), | |||
'label_true': fastNLP.Vocabulary()} | |||
data = [] | |||
for line in data_file.split('\n'): | |||
line = line.split() | |||
if len(line) == 0 and len(data) > 0: | |||
data = list(zip(*data)) | |||
ds.append(fastNLP.Instance(words1=data[1], | |||
words2=data[4], | |||
arc_true=data[6], | |||
label_true=data[7])) | |||
data = [] | |||
elif len(line) > 0: | |||
data.append(line) | |||
for name in ['words1', 'words2', 'label_true']: | |||
ds.apply(lambda x: ['<st>'] + list(x[name]), new_field_name=name) | |||
ds.apply(lambda x: v[name].add_word_lst(x[name])) | |||
for name in ['words1', 'words2', 'label_true']: | |||
ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) | |||
ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true') | |||
ds.apply(lambda x: len(x['words1']), new_field_name='seq_len') | |||
ds.set_input('words1', 'words2', 'seq_len', flag=True) | |||
ds.set_target('arc_true', 'label_true', 'seq_len', flag=True) | |||
return ds, v['words1'], v['words2'], v['label_true'] | |||
from .model_runner import * | |||
def prepare_parser_data(): | |||
index = 'index' | |||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||
ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
field_name=index, new_field_name=C.INPUTS(0), | |||
is_input=True) | |||
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), | |||
field_name=C.INPUTS(0), new_field_name=C.INPUTS(1), | |||
is_input=True) | |||
# target1 is heads, should in range(0, len(words)) | |||
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)), | |||
field_name=C.INPUTS(0), new_field_name=C.TARGETS(0), | |||
is_target=True) | |||
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), | |||
field_name=C.INPUTS(0), new_field_name=C.TARGETS(1), | |||
is_target=True) | |||
ds.apply_field(len, field_name=C.INPUTS(0), new_field_name=C.INPUT_LEN, | |||
is_input=True, is_target=True) | |||
return ds | |||
class TestBiaffineParser(unittest.TestCase): | |||
def test_train(self): | |||
ds, v1, v2, v3 = init_data() | |||
model = BiaffineParser(init_embed=(len(v1), 30), | |||
pos_vocab_size=len(v2), pos_emb_dim=30, | |||
num_label=len(v3), encoder='var-lstm') | |||
trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, | |||
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | |||
batch_size=1, validate_every=10, | |||
n_epochs=10, use_tqdm=False) | |||
trainer.train(load_best_model=False) | |||
if __name__ == '__main__': | |||
unittest.main() | |||
model = BiaffineParser(init_embed=(VOCAB_SIZE, 30), | |||
pos_vocab_size=VOCAB_SIZE, pos_emb_dim=30, | |||
num_label=NUM_CLS, encoder='var-lstm') | |||
ds = prepare_parser_data() | |||
RUNNER.run_model(model, ds, loss=ParserLoss(), metrics=ParserMetric()) |
@@ -0,0 +1,16 @@ | |||
from .model_runner import * | |||
from fastNLP.models.star_transformer import STNLICls, STSeqCls, STSeqLabel | |||
# add star-transformer tests, for 3 kinds of tasks. | |||
def test_cls(): | |||
model = STSeqCls((VOCAB_SIZE, 100), NUM_CLS, dropout=0) | |||
RUNNER.run_model_with_task(TEXT_CLS, model) | |||
def test_nli(): | |||
model = STNLICls((VOCAB_SIZE, 100), NUM_CLS, dropout=0) | |||
RUNNER.run_model_with_task(NLI, model) | |||
def test_seq_label(): | |||
model = STSeqLabel((VOCAB_SIZE, 100), NUM_CLS, dropout=0) | |||
RUNNER.run_model_with_task(POS_TAGGING, model) |