|
- import unittest
-
- import torch
-
- from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
- from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, \
- LSTMSeq2SeqDecoder
- from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel
- from fastNLP import Vocabulary
- from fastNLP.embeddings import StaticEmbedding
- from fastNLP.core.utils import seq_len_to_mask
-
-
- class TestTransformerSeq2SeqDecoder(unittest.TestCase):
- def test_case(self):
- vocab = Vocabulary().add_word_lst("This is a test .".split())
- vocab.add_word_lst("Another test !".split())
- embed = StaticEmbedding(vocab, embedding_dim=512)
-
- args = TransformerSeq2SeqModel.add_args()
- model = TransformerSeq2SeqModel.build_model(args, vocab, vocab)
-
- src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
- tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
- src_seq_len = torch.LongTensor([3, 2])
-
- output = model(src_words_idx, src_seq_len, tgt_words_idx)
- print(output)
-
- # self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))
-
- def test_decode(self):
- pass # todo
-
-
- class TestLSTMDecoder(unittest.TestCase):
- def test_case(self):
- vocab = Vocabulary().add_word_lst("This is a test .".split())
- vocab.add_word_lst("Another test !".split())
- embed = StaticEmbedding(vocab, embedding_dim=512)
-
- encoder = BiLSTMEncoder(embed)
- decoder = LSTMDecoder(embed, bind_input_output_embed=True)
-
- src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
- tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
- src_seq_len = torch.LongTensor([3, 2])
-
- words, hx = encoder(src_words_idx, src_seq_len)
- encode_mask = seq_len_to_mask(src_seq_len)
- hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
- cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
- past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell))
- decoder_outputs = decoder(tgt_words_idx, past)
-
- print(decoder_outputs)
- print(encode_mask)
-
- self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))
-
- def test_decode(self):
- pass # todo
|