@@ -7,16 +7,17 @@ import torch
 from torch import nn
 from torch.nn import functional as F
-from fastNLP.core.losses import LossFunc
-from fastNLP.core.metrics import MetricBase
-from fastNLP.core.utils import seq_lens_to_masks
-from fastNLP.models.base_model import BaseModel
-from fastNLP.modules.dropout import TimestepDropout
-from fastNLP.modules.encoder.transformer import TransformerEncoder
-from fastNLP.modules.encoder.variational_rnn import VarLSTM
-from fastNLP.modules.utils import initial_parameter
-from fastNLP.modules.utils import seq_mask
-from fastNLP.modules.utils import get_embeddings
+from ..core.const import Const as C
+from ..core.losses import LossFunc
+from ..core.metrics import MetricBase
+from ..core.utils import seq_lens_to_masks
+from ..modules.dropout import TimestepDropout
+from ..modules.encoder.transformer import TransformerEncoder
+from ..modules.encoder.variational_rnn import VarLSTM
+from ..modules.utils import initial_parameter
+from ..modules.utils import seq_mask
+from ..modules.utils import get_embeddings
+from .base_model import BaseModel

 def _mst(scores):
     """
@@ -325,21 +326,20 @@ class BiaffineParser(GraphParser):
             for p in m.parameters():
                 nn.init.normal_(p, 0, 0.1)

-    def forward(self, words1, words2, seq_len, gold_heads=None):
+    def forward(self, words1, words2, seq_len, target1=None):
         """Forward pass of the model.

         :param words1: [batch_size, seq_len] input word id sequence
         :param words2: [batch_size, seq_len] input POS id sequence
         :param seq_len: [batch_size] input sequence lengths
-        :param gold_heads: [batch_size, seq_len] gold-annotated heads, only effective during training,
+        :param target1: [batch_size, seq_len] gold-annotated heads, only effective during training,
             used to train the label classifier. If ``None``, the predicted heads are fed to the label classifier.
             Default: ``None``
         :return dict: parsing results::

-            arc_pred: [batch_size, seq_len, seq_len] edge prediction logits
-            label_pred: [batch_size, seq_len, num_label] label prediction logits
-            mask: [batch_size, seq_len] mask of the predictions
-            head_pred: [batch_size, seq_len] predicted heads, computed when ``gold_heads=None``
+            pred1: [batch_size, seq_len, seq_len] edge prediction logits
+            pred2: [batch_size, seq_len, num_label] label prediction logits
+            pred3: [batch_size, seq_len] predicted heads, computed when ``target1=None``
         """
         # prepare embeddings
         batch_size, length = words1.shape
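A note on the renamed outputs: the `pred1`/`pred2`/`pred3` keys in the docstring come from the newly imported `Const`. A minimal sketch of the assumed mapping (inferred from the names in the docstring, not from the `Const` source)::

    from fastNLP.core.const import Const as C

    # assumed: Const.OUTPUTS(i) yields 'pred' + str(i + 1)
    assert C.OUTPUTS(0) == 'pred1'   # arc (edge) logits
    assert C.OUTPUTS(1) == 'pred2'   # label logits
    assert C.OUTPUTS(2) == 'pred3'   # predicted heads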
@@ -365,7 +365,7 @@ class BiaffineParser(GraphParser):
             _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
             feat = feat[unsort_idx]
         else:
-            seq_range = torch.arange(seq_len, dtype=torch.long, device=x.device)[None, :]
+            seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None, :]
             x = x + self.position_emb(seq_range)
             feat = self.encoder(x, mask.float())
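The actual fix in this hunk is `seq_len` → `length`: `seq_len` is a `[batch_size]` tensor of per-sample lengths, while `length` is the Python int padded width, which is what `torch.arange` needs to build position indices. A sketch for a padded width of 5::

    import torch

    torch.arange(5, dtype=torch.long)[None, :]   # tensor([[0, 1, 2, 3, 4]])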
@@ -380,7 +380,7 @@ class BiaffineParser(GraphParser):
         arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]

         # use gold or predicted arc to predict label
-        if gold_heads is None or not self.training:
+        if target1 is None or not self.training:
             # use greedy decoding in training
             if self.training or self.use_greedy_infer:
                 heads = self.greedy_decoder(arc_pred, mask)
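For context, greedy decoding amounts to an argmax over the head dimension; a rough sketch (the real `greedy_decoder` presumably also applies `mask` to exclude padded positions, which this omits)::

    import torch

    def greedy_heads_sketch(arc_pred):
        # arc_pred: [batch_size, seq_len, seq_len] edge logits
        # for each token, pick the highest-scoring head
        return arc_pred.argmax(dim=2)   # [batch_size, seq_len]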
@@ -389,44 +389,45 @@ class BiaffineParser(GraphParser):
                 head_pred = heads
         else:
             assert self.training  # must be training mode
-            if gold_heads is None:
+            if target1 is None:
                 heads = self.greedy_decoder(arc_pred, mask)
                 head_pred = heads
             else:
                 head_pred = None
-                heads = gold_heads
+                heads = target1

         batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1)
         label_head = label_head[batch_range, heads].contiguous()
         label_pred = self.label_predictor(label_head, label_dep)  # [N, L, num_label]
-        res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
+        res_dict = {C.OUTPUTS(0): arc_pred, C.OUTPUTS(1): label_pred}
         if head_pred is not None:
-            res_dict['head_pred'] = head_pred
+            res_dict[C.OUTPUTS(2)] = head_pred
         return res_dict

     @staticmethod
-    def loss(arc_pred, label_pred, arc_true, label_true, mask):
+    def loss(pred1, pred2, target1, target2, seq_len):
         """
-        Compute loss.
-
-        :param arc_pred: [batch_size, seq_len, seq_len] edge prediction logits
-        :param label_pred: [batch_size, seq_len, num_label] label prediction logits
-        :param arc_true: [batch_size, seq_len] gold edge annotations
-        :param label_true: [batch_size, seq_len] gold label annotations
-        :param mask: [batch_size, seq_len] mask of the predictions
-        :return: loss value
+        Compute the parser loss.
+
+        :param pred1: [batch_size, seq_len, seq_len] edge prediction logits
+        :param pred2: [batch_size, seq_len, num_label] label prediction logits
+        :param target1: [batch_size, seq_len] gold edge annotations
+        :param target2: [batch_size, seq_len] gold label annotations
+        :param seq_len: [batch_size] lengths of the target sequences
+        :return loss: scalar
         """
-        batch_size, seq_len, _ = arc_pred.shape
+        batch_size, length, _ = pred1.shape
+        mask = seq_mask(seq_len, length)
         flip_mask = (mask == 0)
-        _arc_pred = arc_pred.clone()
+        _arc_pred = pred1.clone()
         _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))
         arc_logits = F.log_softmax(_arc_pred, dim=2)
-        label_logits = F.log_softmax(label_pred, dim=2)
+        label_logits = F.log_softmax(pred2, dim=2)
         batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
-        child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
-        arc_loss = arc_logits[batch_index, child_index, arc_true]
-        label_loss = label_logits[batch_index, child_index, label_true]
+        child_index = torch.arange(length, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
+        arc_loss = arc_logits[batch_index, child_index, target1]
+        label_loss = label_logits[batch_index, child_index, target2]

         byte_mask = flip_mask.byte()
         arc_loss.masked_fill_(byte_mask, 0)
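The fancy-indexing gather in `loss` is terse; here is a self-contained sketch with toy shapes of what `arc_logits[batch_index, child_index, target1]` selects (names mirror the diff, sizes are made up)::

    import torch
    from torch.nn import functional as F

    B, L = 2, 4                                 # batch size, padded length
    arc_logits = F.log_softmax(torch.randn(B, L, L), dim=2)
    target1 = torch.randint(0, L, (B, L))       # gold head index per token
    batch_index = torch.arange(B).unsqueeze(1)  # [B, 1], broadcasts over tokens
    child_index = torch.arange(L).unsqueeze(0)  # [1, L], broadcasts over batch
    # log P(head of token i == target1[b, i]) for every pair (b, i)
    arc_loss = arc_logits[batch_index, child_index, target1]  # [B, L]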
@@ -441,21 +442,16 @@ class BiaffineParser(GraphParser):
         :param words1: [batch_size, seq_len] input word id sequence
         :param words2: [batch_size, seq_len] input POS id sequence
         :param seq_len: [batch_size] input sequence lengths
-        :param gold_heads: [batch_size, seq_len] gold-annotated heads, only effective during training,
-            used to train the label classifier. If ``None``, the predicted heads are fed to the label classifier.
-            Default: ``None``
         :return dict: parsing results::

-            arc_pred: [batch_size, seq_len, seq_len] edge prediction logits
-            label_pred: [batch_size, seq_len, num_label] label prediction logits
-            mask: [batch_size, seq_len] mask of the predictions
-            head_pred: [batch_size, seq_len] predicted heads, computed when ``gold_heads=None``
+            pred1: [batch_size, seq_len] predicted heads
+            pred2: [batch_size, seq_len, num_label] label prediction logits
         """
         res = self(words1, words2, seq_len)
         output = {}
-        output['arc_pred'] = res.pop('head_pred')
-        _, label_pred = res.pop('label_pred').max(2)
-        output['label_pred'] = label_pred
+        output[C.OUTPUTS(0)] = res.pop(C.OUTPUTS(2))
+        _, label_pred = res.pop(C.OUTPUTS(1)).max(2)
+        output[C.OUTPUTS(1)] = label_pred
         return output
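A hedged usage sketch of the renamed `predict` outputs (constructor arguments copied from the updated test below; the input tensors are made up)::

    import torch
    from fastNLP.core.const import Const as C
    from fastNLP.models.biaffine_parser import BiaffineParser

    model = BiaffineParser(init_embed=(100, 30), pos_vocab_size=100,
                           pos_emb_dim=30, num_label=10)
    model.eval()
    words1 = torch.randint(1, 100, (2, 6))   # fake word ids
    words2 = torch.randint(1, 100, (2, 6))   # fake POS ids
    seq_len = torch.tensor([6, 4])
    out = model.predict(words1, words2, seq_len)
    heads = out[C.OUTPUTS(0)]    # [batch_size, seq_len] predicted head indices
    labels = out[C.OUTPUTS(1)]   # [batch_size, seq_len] predicted label ids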
@@ -463,41 +459,44 @@ class ParserLoss(LossFunc):
     """
     Compute the parser loss.

-    :param arc_pred: [batch_size, seq_len, seq_len] edge prediction logits
-    :param label_pred: [batch_size, seq_len, num_label] label prediction logits
-    :param arc_true: [batch_size, seq_len] gold edge annotations
-    :param label_true: [batch_size, seq_len] gold label annotations
-    :param mask: [batch_size, seq_len] mask of the predictions
+    :param pred1: [batch_size, seq_len, seq_len] edge prediction logits
+    :param pred2: [batch_size, seq_len, num_label] label prediction logits
+    :param target1: [batch_size, seq_len] gold edge annotations
+    :param target2: [batch_size, seq_len] gold label annotations
+    :param seq_len: [batch_size] lengths of the target sequences
     :return loss: scalar
     """

-    def __init__(self, arc_pred=None, label_pred=None, arc_true=None, label_true=None):
+    def __init__(self, pred1=None, pred2=None,
+                 target1=None, target2=None,
+                 seq_len=None):
         super(ParserLoss, self).__init__(BiaffineParser.loss,
-                                         arc_pred=arc_pred,
-                                         label_pred=label_pred,
-                                         arc_true=arc_true,
-                                         label_true=label_true)
+                                         pred1=pred1,
+                                         pred2=pred2,
+                                         target1=target1,
+                                         target2=target2,
+                                         seq_len=seq_len)
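Since `ParserLoss` just forwards these keyword arguments to `LossFunc`, the constructor maps the loss-function parameters onto dataset field names. A sketch with hypothetical field names, only needed when the dataset fields are not already named `pred1`, `target1`, and so on::

    loss = ParserLoss(pred1='arc_logits', pred2='label_logits',
                      target1='heads', target2='labels',
                      seq_len='seq_len')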
 class ParserMetric(MetricBase):
     """
     Evaluate parser performance.

-    :param arc_pred: edge prediction logits
-    :param label_pred: label prediction logits
-    :param arc_true: gold edge annotations
-    :param label_true: gold label annotations
+    :param pred1: edge prediction logits
+    :param pred2: label prediction logits
+    :param target1: gold edge annotations
+    :param target2: gold label annotations
     :param seq_len: sequence lengths
     :return dict: evaluation results::

         UAS: accuracy of edge prediction, ignoring labels
         LAS: accuracy of predicting both the edge and its label
     """

-    def __init__(self, arc_pred=None, label_pred=None,
-                 arc_true=None, label_true=None, seq_len=None):
+    def __init__(self, pred1=None, pred2=None,
+                 target1=None, target2=None, seq_len=None):
         super().__init__()
-        self._init_param_map(arc_pred=arc_pred, label_pred=label_pred,
-                             arc_true=arc_true, label_true=label_true,
+        self._init_param_map(pred1=pred1, pred2=pred2,
+                             target1=target1, target2=target2,
                              seq_len=seq_len)
         self.num_arc = 0
         self.num_label = 0
@@ -509,17 +508,17 @@ class ParserMetric(MetricBase):
         self.num_sample = self.num_label = self.num_arc = 0
         return res

-    def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_len=None):
+    def evaluate(self, pred1, pred2, target1, target2, seq_len=None):
         """Evaluate the performance of prediction.
         """
         if seq_len is None:
-            seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long)
+            seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long)
         else:
             seq_mask = seq_lens_to_masks(seq_len.long(), float=False).long()
         # mask out <root> tag
         seq_mask[:, 0] = 0
-        head_pred_correct = (arc_pred == arc_true).long() * seq_mask
-        label_pred_correct = (label_pred == label_true).long() * head_pred_correct
+        head_pred_correct = (pred1 == target1).long() * seq_mask
+        label_pred_correct = (pred2 == target2).long() * head_pred_correct
         self.num_arc += head_pred_correct.sum().item()
         self.num_label += label_pred_correct.sum().item()
         self.num_sample += seq_mask.sum().item()
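Given the counters accumulated in `evaluate`, the final scores presumably reduce to ratios over the unmasked positions; a sketch of the assumed `get_metric` reduction (not shown in this diff)::

    # assumed reduction, mirroring the counter names above
    UAS = self.num_arc / self.num_sample     # heads correct
    LAS = self.num_label / self.num_sample   # heads and labels both correct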
@@ -108,7 +108,7 @@ class STSeqLabel(nn.Module):
     :param emb_dropout: dropout probability for the word embeddings. Default: 0.1
     :param dropout: dropout probability for the model, excluding the embeddings. Default: 0.1
     """

-    def __init__(self, vocab_size, emb_dim, num_cls,
+    def __init__(self, init_embed, num_cls,
                  hidden_size=300,
                  num_layers=4,
                  num_head=8,
@@ -118,8 +118,7 @@ class STSeqLabel(nn.Module):
                  emb_dropout=0.1,
                  dropout=0.1,):
         super(STSeqLabel, self).__init__()
-        self.enc = StarTransEnc(vocab_size=vocab_size,
-                                emb_dim=emb_dim,
+        self.enc = StarTransEnc(init_embed=init_embed,
                                 hidden_size=hidden_size,
                                 num_layers=num_layers,
                                 num_head=num_head,
@@ -170,7 +169,7 @@ class STSeqCls(nn.Module):
     :param dropout: dropout probability for the model, excluding the embeddings. Default: 0.1
     """

-    def __init__(self, vocab_size, emb_dim, num_cls,
+    def __init__(self, init_embed, num_cls,
                  hidden_size=300,
                  num_layers=4,
                  num_head=8,
@@ -180,8 +179,7 @@ class STSeqCls(nn.Module):
                  emb_dropout=0.1,
                  dropout=0.1,):
         super(STSeqCls, self).__init__()
-        self.enc = StarTransEnc(vocab_size=vocab_size,
-                                emb_dim=emb_dim,
+        self.enc = StarTransEnc(init_embed=init_embed,
                                 hidden_size=hidden_size,
                                 num_layers=num_layers,
                                 num_head=num_head,
@@ -232,7 +230,7 @@ class STNLICls(nn.Module):
     :param dropout: dropout probability for the model, excluding the embeddings. Default: 0.1
     """

-    def __init__(self, vocab_size, emb_dim, num_cls,
+    def __init__(self, init_embed, num_cls,
                  hidden_size=300,
                  num_layers=4,
                  num_head=8,
@@ -242,8 +240,7 @@ class STNLICls(nn.Module):
                  emb_dropout=0.1,
                  dropout=0.1,):
         super(STNLICls, self).__init__()
-        self.enc = StarTransEnc(vocab_size=vocab_size,
-                                emb_dim=emb_dim,
+        self.enc = StarTransEnc(init_embed=init_embed,
                                 hidden_size=hidden_size,
                                 num_layers=num_layers,
                                 num_head=num_head,
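All three models swap `vocab_size`/`emb_dim` for a single `init_embed`, which presumably flows into `get_embeddings` (imported in the parser diff above). A hedged sketch of the accepted forms, assuming `get_embeddings` accepts either a `(vocab_size, emb_dim)` tuple or a ready-made embedding module::

    import torch.nn as nn
    from fastNLP.models.star_transformer import STSeqCls

    # build a fresh embedding from a (vocab_size, emb_dim) tuple
    model = STSeqCls(init_embed=(10000, 300), num_cls=5)

    # assumed alternative: reuse an existing nn.Embedding
    emb = nn.Embedding(10000, 300)
    model = STSeqCls(init_embed=emb, num_cls=5)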
@@ -0,0 +1,151 @@
+"""
+This module makes it easy to smoke-test models.
+
+If your model does text classification, sequence labeling, or natural
+language inference (NLI), you can test it with this module directly.
+If it does not fit one of those tasks, you can still prepare your own fake
+data and supply a loss and metric to test with.
+
+The tests here only guarantee that a model can be trained and evaluated with
+fastNLP; they do not test actual model performance.
+
+Example::
+
+    # import the ALL-CAPS variables...
+    from model_runner import *
+
+    # test a text classification model
+    init_emb = (VOCAB_SIZE, 50)
+    model = SomeModel(init_emb, num_cls=NUM_CLS)
+    RUNNER.run_model_with_task(TEXT_CLS, model)
+
+    # a sequence labeling model
+    RUNNER.run_model_with_task(POS_TAGGING, model)
+
+    # an NLI model
+    RUNNER.run_model_with_task(NLI, model)
+
+    # a custom model
+    RUNNER.run_model(model, data=get_mydata(),
+                     loss=Myloss(), metrics=Mymetric())
+"""
+from fastNLP import Trainer, Tester, DataSet
+from fastNLP import AccuracyMetric
+from fastNLP import CrossEntropyLoss
+from fastNLP.core.const import Const as C
+from random import randrange
+
+VOCAB_SIZE = 100
+NUM_CLS = 100
+MAX_LEN = 10
+N_SAMPLES = 100
+N_EPOCHS = 1
+BATCH_SIZE = 5
+
+TEXT_CLS = 'text_cls'
+POS_TAGGING = 'pos_tagging'
+NLI = 'nli'
+
+
+class ModelRunner():
+    def gen_seq(self, length, vocab_size):
+        """generate fake sequence indexes with given length"""
+        # reserve 0 for padding
+        return [randrange(1, vocab_size) for _ in range(length)]
+
+    def gen_var_seq(self, max_len, vocab_size):
+        """generate fake sequence indexes with variable length"""
+        length = randrange(3, max_len)  # at least 3 words in a seq
+        return self.gen_seq(length, vocab_size)
+
+    def prepare_text_classification_data(self):
+        index = 'index'
+        ds = DataSet({index: list(range(N_SAMPLES))})
+        ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
+                       field_name=index, new_field_name=C.INPUT,
+                       is_input=True)
+        ds.apply_field(lambda x: randrange(NUM_CLS),
+                       field_name=index, new_field_name=C.TARGET,
+                       is_target=True)
+        ds.apply_field(len, C.INPUT, C.INPUT_LEN,
+                       is_input=True)
+        return ds
+
+    def prepare_pos_tagging_data(self):
+        index = 'index'
+        ds = DataSet({index: list(range(N_SAMPLES))})
+        ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
+                       field_name=index, new_field_name=C.INPUT,
+                       is_input=True)
+        ds.apply_field(lambda x: self.gen_seq(len(x), NUM_CLS),
+                       field_name=C.INPUT, new_field_name=C.TARGET,
+                       is_target=True)
+        ds.apply_field(len, C.INPUT, C.INPUT_LEN,
+                       is_input=True, is_target=True)
+        return ds
+
+    def prepare_nli_data(self):
+        index = 'index'
+        ds = DataSet({index: list(range(N_SAMPLES))})
+        ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
+                       field_name=index, new_field_name=C.INPUTS(0),
+                       is_input=True)
+        ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
+                       field_name=index, new_field_name=C.INPUTS(1),
+                       is_input=True)
+        ds.apply_field(lambda x: randrange(NUM_CLS),
+                       field_name=index, new_field_name=C.TARGET,
+                       is_target=True)
+        ds.apply_field(len, C.INPUTS(0), C.INPUT_LENS(0),
+                       is_input=True, is_target=True)
+        ds.apply_field(len, C.INPUTS(1), C.INPUT_LENS(1),
+                       is_input=True, is_target=True)
+        ds.set_input(C.INPUTS(0), C.INPUTS(1))
+        ds.set_target(C.TARGET)
+        return ds
+
+    def run_text_classification(self, model, data=None):
+        if data is None:
+            data = self.prepare_text_classification_data()
+        loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
+        metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
+        self.run_model(model, data, loss, metric)
+
+    def run_pos_tagging(self, model, data=None):
+        if data is None:
+            data = self.prepare_pos_tagging_data()
+        loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET, padding_idx=0)
+        metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN)
+        self.run_model(model, data, loss, metric)
+
+    def run_nli(self, model, data=None):
+        if data is None:
+            data = self.prepare_nli_data()
+        loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
+        metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
+        self.run_model(model, data, loss, metric)
+
+    def run_model(self, model, data, loss, metrics):
+        """run a model, test if it can run with fastNLP"""
+        print('testing model:', model.__class__.__name__)
+        tester = Tester(data=data, model=model, metrics=metrics,
+                        batch_size=BATCH_SIZE, verbose=0)
+        before_train = tester.test()
+        trainer = Trainer(model=model, train_data=data, dev_data=None,
+                          n_epochs=N_EPOCHS, batch_size=BATCH_SIZE,
+                          loss=loss,
+                          save_path=None,
+                          use_tqdm=False)
+        trainer.train(load_best_model=False)
+        after_train = tester.test()
+        for metric_name, v1 in before_train.items():
+            assert metric_name in after_train
+            # # at least we can be sure the model params changed, even if we
+            # # don't know the performance
+            # v2 = after_train[metric_name]
+            # assert v1 != v2
+
+    def run_model_with_task(self, task, model):
+        """run a model with a certain task"""
+        TASKS = {
+            TEXT_CLS: self.run_text_classification,
+            POS_TAGGING: self.run_pos_tagging,
+            NLI: self.run_nli,
+        }
+        assert task in TASKS
+        TASKS[task](model)
+
+
+RUNNER = ModelRunner()
@@ -2,90 +2,33 @@ import unittest

 import fastNLP

 from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric
-
-data_file = """
-1 The _ DET DT _ 3 det _ _
-2 new _ ADJ JJ _ 3 amod _ _
-3 rate _ NOUN NN _ 6 nsubj _ _
-4 will _ AUX MD _ 6 aux _ _
-5 be _ VERB VB _ 6 cop _ _
-6 payable _ ADJ JJ _ 0 root _ _
-7 mask _ ADJ JJ _ 6 punct _ _
-8 mask _ ADJ JJ _ 6 punct _ _
-9 cents _ NOUN NNS _ 4 nmod _ _
-10 from _ ADP IN _ 12 case _ _
-11 seven _ NUM CD _ 12 nummod _ _
-12 cents _ NOUN NNS _ 4 nmod _ _
-13 a _ DET DT _ 14 det _ _
-14 share _ NOUN NN _ 12 nmod:npmod _ _
-15 . _ PUNCT . _ 4 punct _ _
-
-1 The _ DET DT _ 3 det _ _
-2 new _ ADJ JJ _ 3 amod _ _
-3 rate _ NOUN NN _ 6 nsubj _ _
-4 will _ AUX MD _ 6 aux _ _
-5 be _ VERB VB _ 6 cop _ _
-6 payable _ ADJ JJ _ 0 root _ _
-7 Feb. _ PROPN NNP _ 6 nmod:tmod _ _
-8 15 _ NUM CD _ 7 nummod _ _
-9 . _ PUNCT . _ 6 punct _ _
-
-1 A _ DET DT _ 3 det _ _
-2 record _ NOUN NN _ 3 compound _ _
-3 date _ NOUN NN _ 7 nsubjpass _ _
-4 has _ AUX VBZ _ 7 aux _ _
-5 n't _ PART RB _ 7 neg _ _
-6 been _ AUX VBN _ 7 auxpass _ _
-7 set _ VERB VBN _ 0 root _ _
-8 . _ PUNCT . _ 7 punct _ _
-"""
-
-
-def init_data():
-    ds = fastNLP.DataSet()
-    v = {'words1': fastNLP.Vocabulary(),
-         'words2': fastNLP.Vocabulary(),
-         'label_true': fastNLP.Vocabulary()}
-    data = []
-    for line in data_file.split('\n'):
-        line = line.split()
-        if len(line) == 0 and len(data) > 0:
-            data = list(zip(*data))
-            ds.append(fastNLP.Instance(words1=data[1],
-                                       words2=data[4],
-                                       arc_true=data[6],
-                                       label_true=data[7]))
-            data = []
-        elif len(line) > 0:
-            data.append(line)
-    for name in ['words1', 'words2', 'label_true']:
-        ds.apply(lambda x: ['<st>'] + list(x[name]), new_field_name=name)
-        ds.apply(lambda x: v[name].add_word_lst(x[name]))
-    for name in ['words1', 'words2', 'label_true']:
-        ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name)
-    ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true')
-    ds.apply(lambda x: len(x['words1']), new_field_name='seq_len')
-    ds.set_input('words1', 'words2', 'seq_len', flag=True)
-    ds.set_target('arc_true', 'label_true', 'seq_len', flag=True)
-    return ds, v['words1'], v['words2'], v['label_true']
+from .model_runner import *
+
+
+def prepare_parser_data():
+    index = 'index'
+    ds = DataSet({index: list(range(N_SAMPLES))})
+    ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE),
+                   field_name=index, new_field_name=C.INPUTS(0),
+                   is_input=True)
+    ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
+                   field_name=C.INPUTS(0), new_field_name=C.INPUTS(1),
+                   is_input=True)
+    # target1 is the heads; values should be in range(0, len(words))
+    ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)),
+                   field_name=C.INPUTS(0), new_field_name=C.TARGETS(0),
+                   is_target=True)
+    ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
+                   field_name=C.INPUTS(0), new_field_name=C.TARGETS(1),
+                   is_target=True)
+    ds.apply_field(len, field_name=C.INPUTS(0), new_field_name=C.INPUT_LEN,
+                   is_input=True, is_target=True)
+    return ds


 class TestBiaffineParser(unittest.TestCase):
     def test_train(self):
-        ds, v1, v2, v3 = init_data()
-        model = BiaffineParser(init_embed=(len(v1), 30),
-                               pos_vocab_size=len(v2), pos_emb_dim=30,
-                               num_label=len(v3), encoder='var-lstm')
-        trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds,
-                                  loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
-                                  batch_size=1, validate_every=10,
-                                  n_epochs=10, use_tqdm=False)
-        trainer.train(load_best_model=False)
-
-
-if __name__ == '__main__':
-    unittest.main()
+        model = BiaffineParser(init_embed=(VOCAB_SIZE, 30),
+                               pos_vocab_size=VOCAB_SIZE, pos_emb_dim=30,
+                               num_label=NUM_CLS, encoder='var-lstm')
+        ds = prepare_parser_data()
+        RUNNER.run_model(model, ds, loss=ParserLoss(), metrics=ParserMetric())
@@ -0,0 +1,16 @@
+from .model_runner import *
+from fastNLP.models.star_transformer import STNLICls, STSeqCls, STSeqLabel
+
+
+# add star-transformer tests, for 3 kinds of tasks
+def test_cls():
+    model = STSeqCls((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    RUNNER.run_model_with_task(TEXT_CLS, model)
+
+
+def test_nli():
+    model = STNLICls((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    RUNNER.run_model_with_task(NLI, model)
+
+
+def test_seq_label():
+    model = STSeqLabel((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    RUNNER.run_model_with_task(POS_TAGGING, model)