From 4f65e17d1a9f2d4fda0af836f3d4d36811b54599 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Mon, 29 Apr 2019 18:53:05 +0800
Subject: [PATCH] - add a model runner for easier model testing - add model tests

---
 fastNLP/models/biaffine_parser.py   | 137 +++++++++++++------------
 fastNLP/models/star_transformer.py  |  15 +--
 test/models/__init__.py             |   0
 test/models/model_runner.py         | 151 ++++++++++++++++++++++++++++
 test/models/test_biaffine_parser.py | 111 +++++---------------
 test/models/test_star_trans.py      |  16 +++
 6 files changed, 268 insertions(+), 162 deletions(-)
 create mode 100644 test/models/__init__.py
 create mode 100644 test/models/model_runner.py
 create mode 100644 test/models/test_star_trans.py

diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py
index f2329dca..b8d1e7a9 100644
--- a/fastNLP/models/biaffine_parser.py
+++ b/fastNLP/models/biaffine_parser.py
@@ -7,16 +7,17 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from fastNLP.core.losses import LossFunc
-from fastNLP.core.metrics import MetricBase
-from fastNLP.core.utils import seq_lens_to_masks
-from fastNLP.models.base_model import BaseModel
-from fastNLP.modules.dropout import TimestepDropout
-from fastNLP.modules.encoder.transformer import TransformerEncoder
-from fastNLP.modules.encoder.variational_rnn import VarLSTM
-from fastNLP.modules.utils import initial_parameter
-from fastNLP.modules.utils import seq_mask
-from fastNLP.modules.utils import get_embeddings
+from ..core.const import Const as C
+from ..core.losses import LossFunc
+from ..core.metrics import MetricBase
+from ..core.utils import seq_lens_to_masks
+from ..modules.dropout import TimestepDropout
+from ..modules.encoder.transformer import TransformerEncoder
+from ..modules.encoder.variational_rnn import VarLSTM
+from ..modules.utils import initial_parameter
+from ..modules.utils import seq_mask
+from ..modules.utils import get_embeddings
+from .base_model import BaseModel
 
 def _mst(scores):
     """
@@ -325,21 +326,20 @@ class BiaffineParser(GraphParser):
                 for p in m.parameters():
                     nn.init.normal_(p, 0, 0.1)
 
-    def forward(self, words1, words2, seq_len, gold_heads=None):
+    def forward(self, words1, words2, seq_len, target1=None):
         """模型forward阶段
 
         :param words1: [batch_size, seq_len] 输入word序列
         :param words2: [batch_size, seq_len] 输入pos序列
         :param seq_len: [batch_size, seq_len] 输入序列长度
-        :param gold_heads: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效,
+        :param target1: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效,
            用于训练label分类器.
若为 ``None`` , 使用预测的heads输入到label分类器 Default: ``None`` :return dict: parsing结果:: - arc_pred: [batch_size, seq_len, seq_len] 边预测logits - label_pred: [batch_size, seq_len, num_label] label预测logits - mask: [batch_size, seq_len] 预测结果的mask - head_pred: [batch_size, seq_len] heads的预测结果, 在 ``gold_heads=None`` 时预测 + pred1: [batch_size, seq_len, seq_len] 边预测logits + pred2: [batch_size, seq_len, num_label] label预测logits + pred3: [batch_size, seq_len] heads的预测结果, 在 ``target1=None`` 时预测 """ # prepare embeddings batch_size, length = words1.shape @@ -365,7 +365,7 @@ class BiaffineParser(GraphParser): _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) feat = feat[unsort_idx] else: - seq_range = torch.arange(seq_len, dtype=torch.long, device=x.device)[None,:] + seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None,:] x = x + self.position_emb(seq_range) feat = self.encoder(x, mask.float()) @@ -380,7 +380,7 @@ class BiaffineParser(GraphParser): arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] # use gold or predicted arc to predict label - if gold_heads is None or not self.training: + if target1 is None or not self.training: # use greedy decoding in training if self.training or self.use_greedy_infer: heads = self.greedy_decoder(arc_pred, mask) @@ -389,44 +389,45 @@ class BiaffineParser(GraphParser): head_pred = heads else: assert self.training # must be training mode - if gold_heads is None: + if target1 is None: heads = self.greedy_decoder(arc_pred, mask) head_pred = heads else: head_pred = None - heads = gold_heads + heads = target1 batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1) label_head = label_head[batch_range, heads].contiguous() label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] - res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} + res_dict = {C.OUTPUTS(0): arc_pred, C.OUTPUTS(1): label_pred} if head_pred is not None: - res_dict['head_pred'] = head_pred + res_dict[C.OUTPUTS(2)] = head_pred return res_dict @staticmethod - def loss(arc_pred, label_pred, arc_true, label_true, mask): + def loss(pred1, pred2, target1, target2, seq_len): """ - Compute loss. 
- - :param arc_pred: [batch_size, seq_len, seq_len] 边预测logits - :param label_pred: [batch_size, seq_len, num_label] label预测logits - :param arc_true: [batch_size, seq_len] 真实边的标注 - :param label_true: [batch_size, seq_len] 真实类别的标注 - :param mask: [batch_size, seq_len] 预测结果的mask - :return: loss value + 计算parser的loss + + :param pred1: [batch_size, seq_len, seq_len] 边预测logits + :param pred2: [batch_size, seq_len, num_label] label预测logits + :param target1: [batch_size, seq_len] 真实边的标注 + :param target2: [batch_size, seq_len] 真实类别的标注 + :param seq_len: [batch_size, seq_len] 真实目标的长度 + :return loss: scalar """ - batch_size, seq_len, _ = arc_pred.shape + batch_size, length, _ = pred1.shape + mask = seq_mask(seq_len, length) flip_mask = (mask == 0) - _arc_pred = arc_pred.clone() + _arc_pred = pred1.clone() _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) arc_logits = F.log_softmax(_arc_pred, dim=2) - label_logits = F.log_softmax(label_pred, dim=2) + label_logits = F.log_softmax(pred2, dim=2) batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) - child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0) - arc_loss = arc_logits[batch_index, child_index, arc_true] - label_loss = label_logits[batch_index, child_index, label_true] + child_index = torch.arange(length, device=arc_logits.device, dtype=torch.long).unsqueeze(0) + arc_loss = arc_logits[batch_index, child_index, target1] + label_loss = label_logits[batch_index, child_index, target2] byte_mask = flip_mask.byte() arc_loss.masked_fill_(byte_mask, 0) @@ -441,21 +442,16 @@ class BiaffineParser(GraphParser): :param words1: [batch_size, seq_len] 输入word序列 :param words2: [batch_size, seq_len] 输入pos序列 :param seq_len: [batch_size, seq_len] 输入序列长度 - :param gold_heads: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效, - 用于训练label分类器. 
若为 ``None`` , 使用预测的heads输入到label分类器 - Default: ``None`` :return dict: parsing结果:: - arc_pred: [batch_size, seq_len, seq_len] 边预测logits - label_pred: [batch_size, seq_len, num_label] label预测logits - mask: [batch_size, seq_len] 预测结果的mask - head_pred: [batch_size, seq_len] heads的预测结果, 在 ``gold_heads=None`` 时预测 + pred1: [batch_size, seq_len] heads的预测结果 + pred2: [batch_size, seq_len, num_label] label预测logits """ res = self(words1, words2, seq_len) output = {} - output['arc_pred'] = res.pop('head_pred') - _, label_pred = res.pop('label_pred').max(2) - output['label_pred'] = label_pred + output[C.OUTPUTS(0)] = res.pop(C.OUTPUTS(2)) + _, label_pred = res.pop(C.OUTPUTS(1)).max(2) + output[C.OUTPUTS(1)] = label_pred return output @@ -463,41 +459,44 @@ class ParserLoss(LossFunc): """ 计算parser的loss - :param arc_pred: [batch_size, seq_len, seq_len] 边预测logits - :param label_pred: [batch_size, seq_len, num_label] label预测logits - :param arc_true: [batch_size, seq_len] 真实边的标注 - :param label_true: [batch_size, seq_len] 真实类别的标注 - :param mask: [batch_size, seq_len] 预测结果的mask + :param pred1: [batch_size, seq_len, seq_len] 边预测logits + :param pred2: [batch_size, seq_len, num_label] label预测logits + :param target1: [batch_size, seq_len] 真实边的标注 + :param target2: [batch_size, seq_len] 真实类别的标注 + :param seq_len: [batch_size, seq_len] 真实目标的长度 :return loss: scalar """ - def __init__(self, arc_pred=None, label_pred=None, arc_true=None, label_true=None): + def __init__(self, pred1=None, pred2=None, + target1=None, target2=None, + seq_len=None): super(ParserLoss, self).__init__(BiaffineParser.loss, - arc_pred=arc_pred, - label_pred=label_pred, - arc_true=arc_true, - label_true=label_true) + pred1=pred1, + pred2=pred2, + target1=target1, + target2=target2, + seq_len=seq_len) class ParserMetric(MetricBase): """ 评估parser的性能 - :param arc_pred: 边预测logits - :param label_pred: label预测logits - :param arc_true: 真实边的标注 - :param label_true: 真实类别的标注 + :param pred1: 边预测logits + :param pred2: label预测logits + :param target1: 真实边的标注 + :param target2: 真实类别的标注 :param seq_len: 序列长度 :return dict: 评估结果:: UAS: 不带label时, 边预测的准确率 LAS: 同时预测边和label的准确率 """ - def __init__(self, arc_pred=None, label_pred=None, - arc_true=None, label_true=None, seq_len=None): + def __init__(self, pred1=None, pred2=None, + target1=None, target2=None, seq_len=None): super().__init__() - self._init_param_map(arc_pred=arc_pred, label_pred=label_pred, - arc_true=arc_true, label_true=label_true, + self._init_param_map(pred1=pred1, pred2=pred2, + target1=target1, target2=target2, seq_len=seq_len) self.num_arc = 0 self.num_label = 0 @@ -509,17 +508,17 @@ class ParserMetric(MetricBase): self.num_sample = self.num_label = self.num_arc = 0 return res - def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_len=None): + def evaluate(self, pred1, pred2, target1, target2, seq_len=None): """Evaluate the performance of prediction. 
""" if seq_len is None: - seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long) + seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long) else: seq_mask = seq_lens_to_masks(seq_len.long(), float=False).long() # mask out tag seq_mask[:,0] = 0 - head_pred_correct = (arc_pred == arc_true).long() * seq_mask - label_pred_correct = (label_pred == label_true).long() * head_pred_correct + head_pred_correct = (pred1 == target1).long() * seq_mask + label_pred_correct = (pred2 == target2).long() * head_pred_correct self.num_arc += head_pred_correct.sum().item() self.num_label += label_pred_correct.sum().item() self.num_sample += seq_mask.sum().item() diff --git a/fastNLP/models/star_transformer.py b/fastNLP/models/star_transformer.py index f0c6f33f..c3247333 100644 --- a/fastNLP/models/star_transformer.py +++ b/fastNLP/models/star_transformer.py @@ -108,7 +108,7 @@ class STSeqLabel(nn.Module): :param emb_dropout: 词嵌入的dropout概率. Default: 0.1 :param dropout: 模型除词嵌入外的dropout概率. Default: 0.1 """ - def __init__(self, vocab_size, emb_dim, num_cls, + def __init__(self, init_embed, num_cls, hidden_size=300, num_layers=4, num_head=8, @@ -118,8 +118,7 @@ class STSeqLabel(nn.Module): emb_dropout=0.1, dropout=0.1,): super(STSeqLabel, self).__init__() - self.enc = StarTransEnc(vocab_size=vocab_size, - emb_dim=emb_dim, + self.enc = StarTransEnc(init_embed=init_embed, hidden_size=hidden_size, num_layers=num_layers, num_head=num_head, @@ -170,7 +169,7 @@ class STSeqCls(nn.Module): :param dropout: 模型除词嵌入外的dropout概率. Default: 0.1 """ - def __init__(self, vocab_size, emb_dim, num_cls, + def __init__(self, init_embed, num_cls, hidden_size=300, num_layers=4, num_head=8, @@ -180,8 +179,7 @@ class STSeqCls(nn.Module): emb_dropout=0.1, dropout=0.1,): super(STSeqCls, self).__init__() - self.enc = StarTransEnc(vocab_size=vocab_size, - emb_dim=emb_dim, + self.enc = StarTransEnc(init_embed=init_embed, hidden_size=hidden_size, num_layers=num_layers, num_head=num_head, @@ -232,7 +230,7 @@ class STNLICls(nn.Module): :param dropout: 模型除词嵌入外的dropout概率. Default: 0.1 """ - def __init__(self, vocab_size, emb_dim, num_cls, + def __init__(self, init_embed, num_cls, hidden_size=300, num_layers=4, num_head=8, @@ -242,8 +240,7 @@ class STNLICls(nn.Module): emb_dropout=0.1, dropout=0.1,): super(STNLICls, self).__init__() - self.enc = StarTransEnc(vocab_size=vocab_size, - emb_dim=emb_dim, + self.enc = StarTransEnc(init_embed=init_embed, hidden_size=hidden_size, num_layers=num_layers, num_head=num_head, diff --git a/test/models/__init__.py b/test/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/models/model_runner.py b/test/models/model_runner.py new file mode 100644 index 00000000..7a8d0593 --- /dev/null +++ b/test/models/model_runner.py @@ -0,0 +1,151 @@ +""" +此模块可以非常方便的测试模型。 +若你的模型属于:文本分类,序列标注,自然语言推理(NLI),可以直接使用此模块测试 +若模型不属于上述类别,也可以自己准备假数据,设定loss和metric进行测试 + +此模块的测试仅保证模型能使用fastNLP进行训练和测试,不测试模型实际性能 + +Example:: + # import 全大写变量... 
+ from model_runner import * + + # 测试一个文本分类模型 + init_emb = (VOCAB_SIZE, 50) + model = SomeModel(init_emb, num_cls=NUM_CLS) + RUNNER.run_model_with_task(TEXT_CLS, model) + + # 序列标注模型 + RUNNER.run_model_with_task(POS_TAGGING, model) + + # NLI模型 + RUNNER.run_model_with_task(NLI, model) + + # 自定义模型 + RUNNER.run_model(model, data=get_mydata(), + loss=Myloss(), metrics=Mymetric()) +""" +from fastNLP import Trainer, Tester, DataSet +from fastNLP import AccuracyMetric +from fastNLP import CrossEntropyLoss +from fastNLP.core.const import Const as C +from random import randrange + +VOCAB_SIZE = 100 +NUM_CLS = 100 +MAX_LEN = 10 +N_SAMPLES = 100 +N_EPOCHS = 1 +BATCH_SIZE = 5 + +TEXT_CLS = 'text_cls' +POS_TAGGING = 'pos_tagging' +NLI = 'nli' + +class ModelRunner(): + def gen_seq(self, length, vocab_size): + """generate fake sequence indexes with given length""" + # reserve 0 for padding + return [randrange(1, vocab_size) for _ in range(length)] + + def gen_var_seq(self, max_len, vocab_size): + """generate fake sequence indexes in variant length""" + length = randrange(3, max_len) # at least 3 words in a seq + return self.gen_seq(length, vocab_size) + + def prepare_text_classification_data(self): + index = 'index' + ds = DataSet({index: list(range(N_SAMPLES))}) + ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), + field_name=index, new_field_name=C.INPUT, + is_input=True) + ds.apply_field(lambda x: randrange(NUM_CLS), + field_name=index, new_field_name=C.TARGET, + is_target=True) + ds.apply_field(len, C.INPUT, C.INPUT_LEN, + is_input=True) + return ds + + def prepare_pos_tagging_data(self): + index = 'index' + ds = DataSet({index: list(range(N_SAMPLES))}) + ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), + field_name=index, new_field_name=C.INPUT, + is_input=True) + ds.apply_field(lambda x: self.gen_seq(len(x), NUM_CLS), + field_name=C.INPUT, new_field_name=C.TARGET, + is_target=True) + ds.apply_field(len, C.INPUT, C.INPUT_LEN, + is_input=True, is_target=True) + return ds + + def prepare_nli_data(self): + index = 'index' + ds = DataSet({index: list(range(N_SAMPLES))}) + ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), + field_name=index, new_field_name=C.INPUTS(0), + is_input=True) + ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), + field_name=index, new_field_name=C.INPUTS(1), + is_input=True) + ds.apply_field(lambda x: randrange(NUM_CLS), + field_name=index, new_field_name=C.TARGET, + is_target=True) + ds.apply_field(len, C.INPUTS(0), C.INPUT_LENS(0), + is_input=True, is_target=True) + ds.apply_field(len, C.INPUTS(1), C.INPUT_LENS(1), + is_input = True, is_target = True) + ds.set_input(C.INPUTS(0), C.INPUTS(1)) + ds.set_target(C.TARGET) + return ds + + def run_text_classification(self, model, data=None): + if data is None: + data = self.prepare_text_classification_data() + loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) + metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) + self.run_model(model, data, loss, metric) + + def run_pos_tagging(self, model, data=None): + if data is None: + data = self.prepare_pos_tagging_data() + loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET, padding_idx=0) + metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN) + self.run_model(model, data, loss, metric) + + def run_nli(self, model, data=None): + if data is None: + data = self.prepare_nli_data() + loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) + metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) + 
self.run_model(model, data, loss, metric) + + def run_model(self, model, data, loss, metrics): + """run a model, test if it can run with fastNLP""" + print('testing model:', model.__class__.__name__) + tester = Tester(data=data, model=model, metrics=metrics, + batch_size=BATCH_SIZE, verbose=0) + before_train = tester.test() + trainer = Trainer(model=model, train_data=data, dev_data=None, + n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, + loss=loss, + save_path=None, + use_tqdm=False) + trainer.train(load_best_model=False) + after_train = tester.test() + for metric_name, v1 in before_train.items(): + assert metric_name in after_train + # # at least we can sure model params changed, even if we don't know performance + # v2 = after_train[metric_name] + # assert v1 != v2 + + def run_model_with_task(self, task, model): + """run a model with certain task""" + TASKS = { + TEXT_CLS: self.run_text_classification, + POS_TAGGING: self.run_pos_tagging, + NLI: self.run_nli, + } + assert task in TASKS + TASKS[task](model) + +RUNNER = ModelRunner() diff --git a/test/models/test_biaffine_parser.py b/test/models/test_biaffine_parser.py index 918f0fd9..e4746391 100644 --- a/test/models/test_biaffine_parser.py +++ b/test/models/test_biaffine_parser.py @@ -2,90 +2,33 @@ import unittest import fastNLP from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric - -data_file = """ -1 The _ DET DT _ 3 det _ _ -2 new _ ADJ JJ _ 3 amod _ _ -3 rate _ NOUN NN _ 6 nsubj _ _ -4 will _ AUX MD _ 6 aux _ _ -5 be _ VERB VB _ 6 cop _ _ -6 payable _ ADJ JJ _ 0 root _ _ -7 mask _ ADJ JJ _ 6 punct _ _ -8 mask _ ADJ JJ _ 6 punct _ _ -9 cents _ NOUN NNS _ 4 nmod _ _ -10 from _ ADP IN _ 12 case _ _ -11 seven _ NUM CD _ 12 nummod _ _ -12 cents _ NOUN NNS _ 4 nmod _ _ -13 a _ DET DT _ 14 det _ _ -14 share _ NOUN NN _ 12 nmod:npmod _ _ -15 . _ PUNCT . _ 4 punct _ _ - -1 The _ DET DT _ 3 det _ _ -2 new _ ADJ JJ _ 3 amod _ _ -3 rate _ NOUN NN _ 6 nsubj _ _ -4 will _ AUX MD _ 6 aux _ _ -5 be _ VERB VB _ 6 cop _ _ -6 payable _ ADJ JJ _ 0 root _ _ -7 Feb. _ PROPN NNP _ 6 nmod:tmod _ _ -8 15 _ NUM CD _ 7 nummod _ _ -9 . _ PUNCT . _ 6 punct _ _ - -1 A _ DET DT _ 3 det _ _ -2 record _ NOUN NN _ 3 compound _ _ -3 date _ NOUN NN _ 7 nsubjpass _ _ -4 has _ AUX VBZ _ 7 aux _ _ -5 n't _ PART RB _ 7 neg _ _ -6 been _ AUX VBN _ 7 auxpass _ _ -7 set _ VERB VBN _ 0 root _ _ -8 . _ PUNCT . 
_ 7 punct _ _
-
-"""
-
-
-def init_data():
-    ds = fastNLP.DataSet()
-    v = {'words1': fastNLP.Vocabulary(),
-         'words2': fastNLP.Vocabulary(),
-         'label_true': fastNLP.Vocabulary()}
-    data = []
-    for line in data_file.split('\n'):
-        line = line.split()
-        if len(line) == 0 and len(data) > 0:
-            data = list(zip(*data))
-            ds.append(fastNLP.Instance(words1=data[1],
-                                       words2=data[4],
-                                       arc_true=data[6],
-                                       label_true=data[7]))
-            data = []
-        elif len(line) > 0:
-            data.append(line)
-
-    for name in ['words1', 'words2', 'label_true']:
-        ds.apply(lambda x: [''] + list(x[name]), new_field_name=name)
-        ds.apply(lambda x: v[name].add_word_lst(x[name]))
-
-    for name in ['words1', 'words2', 'label_true']:
-        ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name)
-
-    ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true')
-    ds.apply(lambda x: len(x['words1']), new_field_name='seq_len')
-    ds.set_input('words1', 'words2', 'seq_len', flag=True)
-    ds.set_target('arc_true', 'label_true', 'seq_len', flag=True)
-    return ds, v['words1'], v['words2'], v['label_true']
-
+from .model_runner import *
+
+
+def prepare_parser_data():
+    index = 'index'
+    ds = DataSet({index: list(range(N_SAMPLES))})
+    ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE),
+                   field_name=index, new_field_name=C.INPUTS(0),
+                   is_input=True)
+    ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
+                   field_name=C.INPUTS(0), new_field_name=C.INPUTS(1),
+                   is_input=True)
+    # target1 is the heads, which should be in range(0, len(words))
+    ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)),
+                   field_name=C.INPUTS(0), new_field_name=C.TARGETS(0),
+                   is_target=True)
+    ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS),
+                   field_name=C.INPUTS(0), new_field_name=C.TARGETS(1),
+                   is_target=True)
+    ds.apply_field(len, field_name=C.INPUTS(0), new_field_name=C.INPUT_LEN,
+                   is_input=True, is_target=True)
+    return ds
 
 
 class TestBiaffineParser(unittest.TestCase):
     def test_train(self):
-        ds, v1, v2, v3 = init_data()
-        model = BiaffineParser(init_embed=(len(v1), 30),
-                               pos_vocab_size=len(v2), pos_emb_dim=30,
-                               num_label=len(v3), encoder='var-lstm')
-        trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds,
-                                  loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
-                                  batch_size=1, validate_every=10,
-                                  n_epochs=10, use_tqdm=False)
-        trainer.train(load_best_model=False)
-
-
-if __name__ == '__main__':
-    unittest.main()
+        model = BiaffineParser(init_embed=(VOCAB_SIZE, 30),
+                               pos_vocab_size=VOCAB_SIZE, pos_emb_dim=30,
+                               num_label=NUM_CLS, encoder='var-lstm')
+        ds = prepare_parser_data()
+        RUNNER.run_model(model, ds, loss=ParserLoss(), metrics=ParserMetric())
diff --git a/test/models/test_star_trans.py b/test/models/test_star_trans.py
new file mode 100644
index 00000000..b08e2efe
--- /dev/null
+++ b/test/models/test_star_trans.py
@@ -0,0 +1,16 @@
+from .model_runner import *
+from fastNLP.models.star_transformer import STNLICls, STSeqCls, STSeqLabel
+
+
+# Star-Transformer tests for 3 kinds of tasks.
+def test_cls():
+    model = STSeqCls((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    RUNNER.run_model_with_task(TEXT_CLS, model)
+
+def test_nli():
+    model = STNLICls((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    RUNNER.run_model_with_task(NLI, model)
+
+def test_seq_label():
+    model = STSeqLabel((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    RUNNER.run_model_with_task(POS_TAGGING, model)
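
A possible follow-up, not included in this patch: any other bundled model can be exercised the same way as the Star-Transformer tests above. The sketch below mirrors test_star_trans.py for fastNLP's CNN text classifier; the CNNText import path and constructor arguments are assumptions here, not taken from this diff.

    # hypothetical test/models/test_cnn_text.py -- a sketch only, not part of this patch
    from .model_runner import *
    from fastNLP.models.cnn_text_classification import CNNText  # assumed import path

    def test_cnn_text_cls():
        # (VOCAB_SIZE, 100) builds a fresh embedding, as in test_star_trans.py
        model = CNNText((VOCAB_SIZE, 100), NUM_CLS)
        RUNNER.run_model_with_task(TEXT_CLS, model)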