
- add a model runner for easier testing of models

- add model tests
tags/v0.4.10
yunfan, 5 years ago
commit 4f65e17d1a
6 changed files with 268 additions and 162 deletions:

  1. fastNLP/models/biaffine_parser.py (+68, -69)
  2. fastNLP/models/star_transformer.py (+6, -9)
  3. test/models/__init__.py (+0, -0)
  4. test/models/model_runner.py (+151, -0)
  5. test/models/test_biaffine_parser.py (+27, -84)
  6. test/models/test_star_trans.py (+16, -0)
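
For context, the new test/models/model_runner.py exposes module-level constants (VOCAB_SIZE, NUM_CLS, MAX_LEN, ...) and a shared RUNNER instance. A minimal sketch of the intended usage, following the module docstring added below (SomeModel is a placeholder for any classifier that accepts an init_embed tuple):

    # import the ALL_CAPS constants and the shared RUNNER instance
    from model_runner import *

    # SomeModel is a placeholder; any classifier taking a (vocab_size, emb_dim) tuple works
    init_emb = (VOCAB_SIZE, 50)
    model = SomeModel(init_emb, num_cls=NUM_CLS)

    # trains one epoch on generated fake data and re-runs the tester afterwards
    RUNNER.run_model_with_task(TEXT_CLS, model)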

fastNLP/models/biaffine_parser.py (+68, -69)

@@ -7,16 +7,17 @@ import torch
from torch import nn
from torch.nn import functional as F

from fastNLP.core.losses import LossFunc
from fastNLP.core.metrics import MetricBase
from fastNLP.core.utils import seq_lens_to_masks
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.dropout import TimestepDropout
from fastNLP.modules.encoder.transformer import TransformerEncoder
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.utils import seq_mask
from fastNLP.modules.utils import get_embeddings
from ..core.const import Const as C
from ..core.losses import LossFunc
from ..core.metrics import MetricBase
from ..core.utils import seq_lens_to_masks
from ..modules.dropout import TimestepDropout
from ..modules.encoder.transformer import TransformerEncoder
from ..modules.encoder.variational_rnn import VarLSTM
from ..modules.utils import initial_parameter
from ..modules.utils import seq_mask
from ..modules.utils import get_embeddings
from .base_model import BaseModel

def _mst(scores):
"""
@@ -325,21 +326,20 @@ class BiaffineParser(GraphParser):
for p in m.parameters():
nn.init.normal_(p, 0, 0.1)

def forward(self, words1, words2, seq_len, gold_heads=None):
def forward(self, words1, words2, seq_len, target1=None):
"""模型forward阶段

:param words1: [batch_size, seq_len] 输入word序列
:param words2: [batch_size, seq_len] 输入pos序列
:param seq_len: [batch_size, seq_len] 输入序列长度
:param gold_heads: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效,
:param target1: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效,
用于训练label分类器. 若为 ``None`` , 使用预测的heads输入到label分类器
Default: ``None``
:return dict: parsing结果::

arc_pred: [batch_size, seq_len, seq_len] 边预测logits
label_pred: [batch_size, seq_len, num_label] label预测logits
mask: [batch_size, seq_len] 预测结果的mask
head_pred: [batch_size, seq_len] heads的预测结果, 在 ``gold_heads=None`` 时预测
pred1: [batch_size, seq_len, seq_len] 边预测logits
pred2: [batch_size, seq_len, num_label] label预测logits
pred3: [batch_size, seq_len] heads的预测结果, 在 ``target1=None`` 时预测
"""
# prepare embeddings
batch_size, length = words1.shape
@@ -365,7 +365,7 @@ class BiaffineParser(GraphParser):
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
feat = feat[unsort_idx]
else:
seq_range = torch.arange(seq_len, dtype=torch.long, device=x.device)[None,:]
seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None,:]
x = x + self.position_emb(seq_range)
feat = self.encoder(x, mask.float())

@@ -380,7 +380,7 @@ class BiaffineParser(GraphParser):
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]

# use gold or predicted arc to predict label
if gold_heads is None or not self.training:
if target1 is None or not self.training:
# use greedy decoding in training
if self.training or self.use_greedy_infer:
heads = self.greedy_decoder(arc_pred, mask)
@@ -389,44 +389,45 @@ class BiaffineParser(GraphParser):
head_pred = heads
else:
assert self.training # must be training mode
if gold_heads is None:
if target1 is None:
heads = self.greedy_decoder(arc_pred, mask)
head_pred = heads
else:
head_pred = None
heads = gold_heads
heads = target1

batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1)
label_head = label_head[batch_range, heads].contiguous()
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
res_dict = {C.OUTPUTS(0): arc_pred, C.OUTPUTS(1): label_pred}
if head_pred is not None:
res_dict['head_pred'] = head_pred
res_dict[C.OUTPUTS(2)] = head_pred
return res_dict

@staticmethod
def loss(arc_pred, label_pred, arc_true, label_true, mask):
def loss(pred1, pred2, target1, target2, seq_len):
"""
Compute loss.
:param arc_pred: [batch_size, seq_len, seq_len] arc prediction logits
:param label_pred: [batch_size, seq_len, num_label] label prediction logits
:param arc_true: [batch_size, seq_len] gold arc annotations
:param label_true: [batch_size, seq_len] gold label annotations
:param mask: [batch_size, seq_len] mask over the predictions
:return: loss value
Compute the parser loss.
:param pred1: [batch_size, seq_len, seq_len] arc prediction logits
:param pred2: [batch_size, seq_len, num_label] label prediction logits
:param target1: [batch_size, seq_len] gold arc annotations
:param target2: [batch_size, seq_len] gold label annotations
:param seq_len: [batch_size, seq_len] lengths of the gold targets
:return loss: scalar
"""

batch_size, seq_len, _ = arc_pred.shape
batch_size, length, _ = pred1.shape
mask = seq_mask(seq_len, length)
flip_mask = (mask == 0)
_arc_pred = arc_pred.clone()
_arc_pred = pred1.clone()
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))
arc_logits = F.log_softmax(_arc_pred, dim=2)
label_logits = F.log_softmax(label_pred, dim=2)
label_logits = F.log_softmax(pred2, dim=2)
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
arc_loss = arc_logits[batch_index, child_index, arc_true]
label_loss = label_logits[batch_index, child_index, label_true]
child_index = torch.arange(length, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
arc_loss = arc_logits[batch_index, child_index, target1]
label_loss = label_logits[batch_index, child_index, target2]

byte_mask = flip_mask.byte()
arc_loss.masked_fill_(byte_mask, 0)
@@ -441,21 +442,16 @@ class BiaffineParser(GraphParser):
:param words1: [batch_size, seq_len] input word index sequence
:param words2: [batch_size, seq_len] input POS index sequence
:param seq_len: [batch_size, seq_len] input sequence lengths
:param gold_heads: [batch_size, seq_len] gold annotated heads; only used during training,
used to train the label classifier. If ``None``, the predicted heads are fed to the label classifier.
Default: ``None``
:return dict: parsing results::

arc_pred: [batch_size, seq_len, seq_len] arc prediction logits
label_pred: [batch_size, seq_len, num_label] label prediction logits
mask: [batch_size, seq_len] mask over the predictions
head_pred: [batch_size, seq_len] predicted heads, produced when ``gold_heads=None``
pred1: [batch_size, seq_len] predicted heads
pred2: [batch_size, seq_len, num_label] label prediction logits
"""
res = self(words1, words2, seq_len)
output = {}
output['arc_pred'] = res.pop('head_pred')
_, label_pred = res.pop('label_pred').max(2)
output['label_pred'] = label_pred
output[C.OUTPUTS(0)] = res.pop(C.OUTPUTS(2))
_, label_pred = res.pop(C.OUTPUTS(1)).max(2)
output[C.OUTPUTS(1)] = label_pred
return output


@@ -463,41 +459,44 @@ class ParserLoss(LossFunc):
"""
Compute the parser loss.

:param arc_pred: [batch_size, seq_len, seq_len] arc prediction logits
:param label_pred: [batch_size, seq_len, num_label] label prediction logits
:param arc_true: [batch_size, seq_len] gold arc annotations
:param label_true: [batch_size, seq_len] gold label annotations
:param mask: [batch_size, seq_len] mask over the predictions
:param pred1: [batch_size, seq_len, seq_len] arc prediction logits
:param pred2: [batch_size, seq_len, num_label] label prediction logits
:param target1: [batch_size, seq_len] gold arc annotations
:param target2: [batch_size, seq_len] gold label annotations
:param seq_len: [batch_size, seq_len] lengths of the gold targets
:return loss: scalar
"""
def __init__(self, arc_pred=None, label_pred=None, arc_true=None, label_true=None):
def __init__(self, pred1=None, pred2=None,
target1=None, target2=None,
seq_len=None):
super(ParserLoss, self).__init__(BiaffineParser.loss,
arc_pred=arc_pred,
label_pred=label_pred,
arc_true=arc_true,
label_true=label_true)
pred1=pred1,
pred2=pred2,
target1=target1,
target2=target2,
seq_len=seq_len)


class ParserMetric(MetricBase):
"""
Evaluate parser performance.

:param arc_pred: arc prediction logits
:param label_pred: label prediction logits
:param arc_true: gold arc annotations
:param label_true: gold label annotations
:param pred1: arc prediction logits
:param pred2: label prediction logits
:param target1: gold arc annotations
:param target2: gold label annotations
:param seq_len: sequence lengths
:return dict: evaluation results::

UAS: unlabeled attachment score, the accuracy of arc prediction ignoring labels
LAS: labeled attachment score, the accuracy of predicting both arcs and labels
"""
def __init__(self, arc_pred=None, label_pred=None,
arc_true=None, label_true=None, seq_len=None):
def __init__(self, pred1=None, pred2=None,
target1=None, target2=None, seq_len=None):

super().__init__()
self._init_param_map(arc_pred=arc_pred, label_pred=label_pred,
arc_true=arc_true, label_true=label_true,
self._init_param_map(pred1=pred1, pred2=pred2,
target1=target1, target2=target2,
seq_len=seq_len)
self.num_arc = 0
self.num_label = 0
@@ -509,17 +508,17 @@ class ParserMetric(MetricBase):
self.num_sample = self.num_label = self.num_arc = 0
return res

def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_len=None):
def evaluate(self, pred1, pred2, target1, target2, seq_len=None):
"""Evaluate the performance of prediction.
"""
if seq_len is None:
seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long)
seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long)
else:
seq_mask = seq_lens_to_masks(seq_len.long(), float=False).long()
# mask out <root> tag
seq_mask[:,0] = 0
head_pred_correct = (arc_pred == arc_true).long() * seq_mask
label_pred_correct = (label_pred == label_true).long() * head_pred_correct
head_pred_correct = (pred1 == target1).long() * seq_mask
label_pred_correct = (pred2 == target2).long() * head_pred_correct
self.num_arc += head_pred_correct.sum().item()
self.num_label += label_pred_correct.sum().item()
self.num_sample += seq_mask.sum().item()

fastNLP/models/star_transformer.py (+6, -9)

@@ -108,7 +108,7 @@ class STSeqLabel(nn.Module):
:param emb_dropout: dropout probability for the word embeddings. Default: 0.1
:param dropout: dropout probability for the rest of the model (excluding word embeddings). Default: 0.1
"""
def __init__(self, vocab_size, emb_dim, num_cls,
def __init__(self, init_embed, num_cls,
hidden_size=300,
num_layers=4,
num_head=8,
@@ -118,8 +118,7 @@ class STSeqLabel(nn.Module):
emb_dropout=0.1,
dropout=0.1,):
super(STSeqLabel, self).__init__()
self.enc = StarTransEnc(vocab_size=vocab_size,
emb_dim=emb_dim,
self.enc = StarTransEnc(init_embed=init_embed,
hidden_size=hidden_size,
num_layers=num_layers,
num_head=num_head,
@@ -170,7 +169,7 @@ class STSeqCls(nn.Module):
:param dropout: dropout probability for the rest of the model (excluding word embeddings). Default: 0.1
"""

def __init__(self, vocab_size, emb_dim, num_cls,
def __init__(self, init_embed, num_cls,
hidden_size=300,
num_layers=4,
num_head=8,
@@ -180,8 +179,7 @@ class STSeqCls(nn.Module):
emb_dropout=0.1,
dropout=0.1,):
super(STSeqCls, self).__init__()
self.enc = StarTransEnc(vocab_size=vocab_size,
emb_dim=emb_dim,
self.enc = StarTransEnc(init_embed=init_embed,
hidden_size=hidden_size,
num_layers=num_layers,
num_head=num_head,
@@ -232,7 +230,7 @@ class STNLICls(nn.Module):
:param dropout: dropout probability for the rest of the model (excluding word embeddings). Default: 0.1
"""

def __init__(self, vocab_size, emb_dim, num_cls,
def __init__(self, init_embed, num_cls,
hidden_size=300,
num_layers=4,
num_head=8,
@@ -242,8 +240,7 @@ class STNLICls(nn.Module):
emb_dropout=0.1,
dropout=0.1,):
super(STNLICls, self).__init__()
self.enc = StarTransEnc(vocab_size=vocab_size,
emb_dim=emb_dim,
self.enc = StarTransEnc(init_embed=init_embed,
hidden_size=hidden_size,
num_layers=num_layers,
num_head=num_head,


test/models/__init__.py (+0, -0)


test/models/model_runner.py (+151, -0)

@@ -0,0 +1,151 @@
"""
This module makes it very convenient to test models.
If your model does text classification, sequence labeling, or natural language inference (NLI), it can be tested with this module directly.
For models outside those categories, you can still prepare fake data and set a loss and metric to test them.

The tests here only verify that a model can be trained and evaluated with fastNLP; they do not test actual model performance.

Example::
# imports the ALL_CAPS variables...
from model_runner import *

# test a text classification model
init_emb = (VOCAB_SIZE, 50)
model = SomeModel(init_emb, num_cls=NUM_CLS)
RUNNER.run_model_with_task(TEXT_CLS, model)

# sequence labeling model
RUNNER.run_model_with_task(POS_TAGGING, model)

# NLI model
RUNNER.run_model_with_task(NLI, model)

# custom model
RUNNER.run_model(model, data=get_mydata(),
loss=Myloss(), metrics=Mymetric())
"""
from fastNLP import Trainer, Tester, DataSet
from fastNLP import AccuracyMetric
from fastNLP import CrossEntropyLoss
from fastNLP.core.const import Const as C
from random import randrange

VOCAB_SIZE = 100
NUM_CLS = 100
MAX_LEN = 10
N_SAMPLES = 100
N_EPOCHS = 1
BATCH_SIZE = 5

TEXT_CLS = 'text_cls'
POS_TAGGING = 'pos_tagging'
NLI = 'nli'

class ModelRunner():
def gen_seq(self, length, vocab_size):
"""generate fake sequence indexes with given length"""
# reserve 0 for padding
return [randrange(1, vocab_size) for _ in range(length)]

def gen_var_seq(self, max_len, vocab_size):
"""generate fake sequence indexes in variant length"""
length = randrange(3, max_len) # at least 3 words in a seq
return self.gen_seq(length, vocab_size)

def prepare_text_classification_data(self):
index = 'index'
ds = DataSet({index: list(range(N_SAMPLES))})
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
field_name=index, new_field_name=C.INPUT,
is_input=True)
ds.apply_field(lambda x: randrange(NUM_CLS),
field_name=index, new_field_name=C.TARGET,
is_target=True)
ds.apply_field(len, C.INPUT, C.INPUT_LEN,
is_input=True)
return ds

def prepare_pos_tagging_data(self):
index = 'index'
ds = DataSet({index: list(range(N_SAMPLES))})
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
field_name=index, new_field_name=C.INPUT,
is_input=True)
ds.apply_field(lambda x: self.gen_seq(len(x), NUM_CLS),
field_name=C.INPUT, new_field_name=C.TARGET,
is_target=True)
ds.apply_field(len, C.INPUT, C.INPUT_LEN,
is_input=True, is_target=True)
return ds

def prepare_nli_data(self):
index = 'index'
ds = DataSet({index: list(range(N_SAMPLES))})
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
field_name=index, new_field_name=C.INPUTS(0),
is_input=True)
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
field_name=index, new_field_name=C.INPUTS(1),
is_input=True)
ds.apply_field(lambda x: randrange(NUM_CLS),
field_name=index, new_field_name=C.TARGET,
is_target=True)
ds.apply_field(len, C.INPUTS(0), C.INPUT_LENS(0),
is_input=True, is_target=True)
ds.apply_field(len, C.INPUTS(1), C.INPUT_LENS(1),
is_input=True, is_target=True)
ds.set_input(C.INPUTS(0), C.INPUTS(1))
ds.set_target(C.TARGET)
return ds

def run_text_classification(self, model, data=None):
if data is None:
data = self.prepare_text_classification_data()
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
self.run_model(model, data, loss, metric)

def run_pos_tagging(self, model, data=None):
if data is None:
data = self.prepare_pos_tagging_data()
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET, padding_idx=0)
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN)
self.run_model(model, data, loss, metric)

def run_nli(self, model, data=None):
if data is None:
data = self.prepare_nli_data()
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
self.run_model(model, data, loss, metric)

def run_model(self, model, data, loss, metrics):
"""run a model, test if it can run with fastNLP"""
print('testing model:', model.__class__.__name__)
tester = Tester(data=data, model=model, metrics=metrics,
batch_size=BATCH_SIZE, verbose=0)
before_train = tester.test()
trainer = Trainer(model=model, train_data=data, dev_data=None,
n_epochs=N_EPOCHS, batch_size=BATCH_SIZE,
loss=loss,
save_path=None,
use_tqdm=False)
trainer.train(load_best_model=False)
after_train = tester.test()
for metric_name, v1 in before_train.items():
assert metric_name in after_train
# # at least we can be sure the model params changed, even if we don't know performance
# v2 = after_train[metric_name]
# assert v1 != v2

def run_model_with_task(self, task, model):
"""run a model with certain task"""
TASKS = {
TEXT_CLS: self.run_text_classification,
POS_TAGGING: self.run_pos_tagging,
NLI: self.run_nli,
}
assert task in TASKS
TASKS[task](model)

RUNNER = ModelRunner()
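
For a model that does not fit the three built-in tasks, run_model can be called directly with hand-built data, a loss, and a metric. A minimal sketch, reusing the fake-data helpers above; MyModel, MyLoss, and MyMetric are hypothetical stand-ins:

    from model_runner import *   # DataSet, C, randrange, RUNNER and the constants

    # build a fake dataset the same way prepare_text_classification_data does
    ds = DataSet({'index': list(range(N_SAMPLES))})
    ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE),
                   field_name='index', new_field_name=C.INPUT, is_input=True)
    ds.apply_field(lambda x: randrange(NUM_CLS),
                   field_name='index', new_field_name=C.TARGET, is_target=True)
    ds.apply_field(len, C.INPUT, C.INPUT_LEN, is_input=True)

    # MyModel / MyLoss / MyMetric are hypothetical; any fastNLP-compatible triple works
    RUNNER.run_model(MyModel(), data=ds, loss=MyLoss(), metrics=MyMetric())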

test/models/test_biaffine_parser.py (+27, -84)

@@ -2,90 +2,33 @@ import unittest

import fastNLP
from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric

data_file = """
1 The _ DET DT _ 3 det _ _
2 new _ ADJ JJ _ 3 amod _ _
3 rate _ NOUN NN _ 6 nsubj _ _
4 will _ AUX MD _ 6 aux _ _
5 be _ VERB VB _ 6 cop _ _
6 payable _ ADJ JJ _ 0 root _ _
7 mask _ ADJ JJ _ 6 punct _ _
8 mask _ ADJ JJ _ 6 punct _ _
9 cents _ NOUN NNS _ 4 nmod _ _
10 from _ ADP IN _ 12 case _ _
11 seven _ NUM CD _ 12 nummod _ _
12 cents _ NOUN NNS _ 4 nmod _ _
13 a _ DET DT _ 14 det _ _
14 share _ NOUN NN _ 12 nmod:npmod _ _
15 . _ PUNCT . _ 4 punct _ _

1 The _ DET DT _ 3 det _ _
2 new _ ADJ JJ _ 3 amod _ _
3 rate _ NOUN NN _ 6 nsubj _ _
4 will _ AUX MD _ 6 aux _ _
5 be _ VERB VB _ 6 cop _ _
6 payable _ ADJ JJ _ 0 root _ _
7 Feb. _ PROPN NNP _ 6 nmod:tmod _ _
8 15 _ NUM CD _ 7 nummod _ _
9 . _ PUNCT . _ 6 punct _ _

1 A _ DET DT _ 3 det _ _
2 record _ NOUN NN _ 3 compound _ _
3 date _ NOUN NN _ 7 nsubjpass _ _
4 has _ AUX VBZ _ 7 aux _ _
5 n't _ PART RB _ 7 neg _ _
6 been _ AUX VBN _ 7 auxpass _ _
7 set _ VERB VBN _ 0 root _ _
8 . _ PUNCT . _ 7 punct _ _

"""


def init_data():
ds = fastNLP.DataSet()
v = {'words1': fastNLP.Vocabulary(),
'words2': fastNLP.Vocabulary(),
'label_true': fastNLP.Vocabulary()}
data = []
for line in data_file.split('\n'):
line = line.split()
if len(line) == 0 and len(data) > 0:
data = list(zip(*data))
ds.append(fastNLP.Instance(words1=data[1],
words2=data[4],
arc_true=data[6],
label_true=data[7]))
data = []
elif len(line) > 0:
data.append(line)

for name in ['words1', 'words2', 'label_true']:
ds.apply(lambda x: ['<st>'] + list(x[name]), new_field_name=name)
ds.apply(lambda x: v[name].add_word_lst(x[name]))

for name in ['words1', 'words2', 'label_true']:
ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name)

ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true')
ds.apply(lambda x: len(x['words1']), new_field_name='seq_len')
ds.set_input('words1', 'words2', 'seq_len', flag=True)
ds.set_target('arc_true', 'label_true', 'seq_len', flag=True)
return ds, v['words1'], v['words2'], v['label_true']

from .model_runner import *


def prepare_parser_data():
index = 'index'
ds = DataSet({index: list(range(N_SAMPLES))})
ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE),
field_name=index, new_field_name=C.INPUTS(0),
is_input=True)
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
field_name=C.INPUTS(0), new_field_name=C.INPUTS(1),
is_input=True)
# target1 is heads, which should be in range(0, len(words))
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)),
field_name=C.INPUTS(0), new_field_name=C.TARGETS(0),
is_target=True)
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
field_name=C.INPUTS(0), new_field_name=C.TARGETS(1),
is_target=True)
ds.apply_field(len, field_name=C.INPUTS(0), new_field_name=C.INPUT_LEN,
is_input=True, is_target=True)
return ds

class TestBiaffineParser(unittest.TestCase):
def test_train(self):
ds, v1, v2, v3 = init_data()
model = BiaffineParser(init_embed=(len(v1), 30),
pos_vocab_size=len(v2), pos_emb_dim=30,
num_label=len(v3), encoder='var-lstm')
trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds,
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
batch_size=1, validate_every=10,
n_epochs=10, use_tqdm=False)
trainer.train(load_best_model=False)


if __name__ == '__main__':
unittest.main()
model = BiaffineParser(init_embed=(VOCAB_SIZE, 30),
pos_vocab_size=VOCAB_SIZE, pos_emb_dim=30,
num_label=NUM_CLS, encoder='var-lstm')
ds = prepare_parser_data()
RUNNER.run_model(model, ds, loss=ParserLoss(), metrics=ParserMetric())

test/models/test_star_trans.py (+16, -0)

@@ -0,0 +1,16 @@
from .model_runner import *
from fastNLP.models.star_transformer import STNLICls, STSeqCls, STSeqLabel


# star-transformer tests for 3 kinds of tasks
def test_cls():
model = STSeqCls((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
RUNNER.run_model_with_task(TEXT_CLS, model)

def test_nli():
model = STNLICls((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
RUNNER.run_model_with_task(NLI, model)

def test_seq_label():
model = STSeqLabel((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
RUNNER.run_model_with_task(POS_TAGGING, model)
