|
-
- import unittest
- from fastNLP.models import SequenceGeneratorModel
- from fastNLP.models import LSTMSeq2SeqModel, TransformerSeq2SeqModel
- from fastNLP import Vocabulary, DataSet
- import torch
- from fastNLP.embeddings import StaticEmbedding
- from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
- from fastNLP import Callback
-
-
def prepare_env():
    """Build a tiny embedding + dataset pair for seq2seq smoke tests.

    The target side reuses the source token indices, so a model that learns
    to copy its input reaches accuracy 1.0 quickly.
    """
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

    token_ids = [[3, 1, 2], [1, 2]]
    lengths = [3, 2]

    dataset = DataSet({
        'src_tokens': token_ids,
        'src_seq_len': lengths,
        'tgt_tokens': token_ids,
        'tgt_seq_len': lengths,
    })
    dataset.set_input('src_tokens', 'tgt_tokens', 'src_seq_len')
    dataset.set_target('tgt_seq_len', 'tgt_tokens')

    return embed, dataset
-
-
class ExitCallback(Callback):
    """Stop training (via KeyboardInterrupt) once dev accuracy reaches 1.0.

    Trainer treats KeyboardInterrupt raised from a callback as an early
    exit, so this lets the test finish as soon as the model is perfect.
    """

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        acc = eval_result['AccuracyMetric']['acc']
        if acc == 1:
            raise KeyboardInterrupt()
-
-
class TestSeq2SeqGeneratorModel(unittest.TestCase):
    def test_run(self):
        # Check that both seq2seq model flavours can be trained end-to-end
        # on a trivial copy task (target == source) until perfect accuracy.
        embed, ds = prepare_env()
        transformer = TransformerSeq2SeqModel.build_model(
            src_embed=embed, tgt_embed=None,
            pos_embed='sin', max_position=20, num_layers=2, d_model=30,
            n_head=6, dim_ff=20, dropout=0.1,
            bind_encoder_decoder_embed=True,
            bind_decoder_input_output_embed=True)
        self._fit_until_perfect(ds, transformer, optimizer=None,
                                n_epochs=100, print_every=5)

        embed, ds = prepare_env()
        lstm = LSTMSeq2SeqModel.build_model(
            src_embed=embed, tgt_embed=None,
            num_layers=1, hidden_size=20, dropout=0.1,
            bind_encoder_decoder_embed=True,
            bind_decoder_input_output_embed=True, attention=True)
        adam = torch.optim.Adam(lstm.parameters(), lr=0.01)
        self._fit_until_perfect(ds, lstm, optimizer=adam,
                                n_epochs=200, print_every=1)

    def _fit_until_perfect(self, ds, model, optimizer, n_epochs, print_every):
        # Train on `ds` while validating against the same data; ExitCallback
        # interrupts training as soon as dev accuracy hits 1.0, and the best
        # recorded eval result must then be perfect.
        trainer = Trainer(ds, model, optimizer=optimizer,
                          loss=CrossEntropyLoss(target='tgt_tokens', seq_len='tgt_seq_len'),
                          batch_size=32, sampler=None, drop_last=False, update_every=1,
                          num_workers=0, n_epochs=n_epochs, print_every=print_every,
                          dev_data=ds,
                          metrics=AccuracyMetric(target='tgt_tokens', seq_len='tgt_seq_len'),
                          metric_key=None, validate_every=-1, save_path=None,
                          use_tqdm=False, device=None,
                          callbacks=ExitCallback(), check_code_level=0)
        result = trainer.train()
        self.assertEqual(result['best_eval']['AccuracyMetric']['acc'], 1)
-
-
-
|