- import unittest
-
- import torch
- from fastNLP.modules.generator import SequenceGenerator
- from fastNLP.modules import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State
- from fastNLP import Vocabulary
- from fastNLP.embeddings import StaticEmbedding
- from torch import nn
- from fastNLP import seq_len_to_mask
-
-
- def prepare_env():
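-     """Build a small vocabulary, a random embedding, and fake encoder outputs shared by the tests."""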
-     vocab = Vocabulary().add_word_lst("This is a test .".split())
-     vocab.add_word_lst("Another test !".split())
-     embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)
-
-     encoder_output = torch.randn(2, 3, 10)
-     src_seq_len = torch.LongTensor([3, 2])
-     encoder_mask = seq_len_to_mask(src_seq_len)
-
-     return embed, encoder_output, encoder_mask
-
-
- class TestSequenceGenerator(unittest.TestCase):
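-     """Smoke tests for SequenceGenerator plus a correctness check for greedy decoding."""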
-     def test_run(self):
-         # Test that everything runs: (1) initialize a decoder, (2) call generate() once
-         embed, encoder_output, encoder_mask = prepare_env()
-
-         for do_sample in [True, False]:
-             for num_beams in [1, 3, 5]:
-                 with self.subTest(do_sample=do_sample, num_beams=num_beams):
-                     decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10, dropout=0.3,
-                                                  bind_decoder_input_output_embed=True, attention=True)
-                     state = decoder.init_state(encoder_output, encoder_mask)
-                     generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams,
-                                                   do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0,
-                                                   bos_token_id=1, eos_token_id=None, repetition_penalty=1,
-                                                   length_penalty=1.0, pad_token_id=0)
-                     generator.generate(state=state, tokens=None)
-
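-                     # Same smoke test with a Transformer decoder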
-                     decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
-                                                         d_model=encoder_output.size(-1), num_layers=2, n_head=2,
-                                                         dim_ff=10, dropout=0.1, bind_decoder_input_output_embed=True)
-                     state = decoder.init_state(encoder_output, encoder_mask)
-                     generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
-                                                   do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0,
-                                                   bos_token_id=1, eos_token_id=None, repetition_penalty=1,
-                                                   length_penalty=1.0, pad_token_id=0)
-                     generator.generate(state=state, tokens=None)
-
-                     # Exercise non-default values: lower temperature, nucleus sampling, an eos token, and penalties
-                     decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
-                                                         d_model=encoder_output.size(-1), num_layers=2, n_head=2,
-                                                         dim_ff=10, dropout=0.1, bind_decoder_input_output_embed=True)
-                     state = decoder.init_state(encoder_output, encoder_mask)
-                     generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
-                                                   do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5,
-                                                   bos_token_id=1, eos_token_id=3, repetition_penalty=2,
-                                                   length_penalty=1.5, pad_token_id=0)
-                     generator.generate(state=state, tokens=None)
-
-     def test_greedy_decode(self):
-         # Test that generate() follows the expected path
-         class GreedyDummyDecoder(Seq2SeqDecoder):
-             def __init__(self, decoder_output):
-                 super().__init__()
-                 self.cur_length = 0
-                 self.decoder_output = decoder_output
-
-             def decode(self, tokens, state):
-                 # Ignore the input tokens and return the pre-computed scores for the current step
-                 self.cur_length += 1
-                 scores = self.decoder_output[:, self.cur_length]
-                 return scores
-
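-         # Minimal State whose only job is to keep the dummy decoder's cached scores aligned with the beams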
-         class DummyState(State):
-             def __init__(self, decoder):
-                 super().__init__()
-                 self.decoder = decoder
-
-             def reorder_state(self, indices: torch.LongTensor):
-                 self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0)
-
-         # greedy: without an eos token, the generated path must equal the argmax path
-         for beam_search in [1, 3]:
-             decoder_output = torch.randn(2, 10, 5)
-             path = decoder_output.argmax(dim=-1)  # 2 x 10
-             decoder = GreedyDummyDecoder(decoder_output)
-             with self.subTest(beam_search=beam_search):
-                 generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1),
-                                               num_beams=beam_search, do_sample=False, temperature=1, top_k=50,
-                                               top_p=1, bos_token_id=1, eos_token_id=None, repetition_penalty=1,
-                                               length_penalty=1, pad_token_id=0)
-                 decode_path = generator.generate(DummyState(decoder),
-                                                  tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
-
-                 self.assertEqual(decode_path.eq(path).sum(), path.numel())
-
-         # greedy: check that eos_token_id terminates each sequence at the right step
-         for beam_search in [1, 3]:
-             decoder_output = torch.randn(2, 10, 5)
-             decoder_output[:, :7, 4].fill_(-100)
-             decoder_output[0, 7, 4] = 1000  # first sequence emits eos at step 8
-             decoder_output[1, 5, 4] = 1000  # second sequence emits eos at step 6
-             path = decoder_output.argmax(dim=-1)  # 2 x 10
-             decoder = GreedyDummyDecoder(decoder_output)
-             with self.subTest(beam_search=beam_search):
-                 generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1),
-                                               num_beams=beam_search, do_sample=False, temperature=1, top_k=50,
-                                               top_p=0.5, bos_token_id=1, eos_token_id=4, repetition_penalty=1,
-                                               length_penalty=1, pad_token_id=0)
-                 decode_path = generator.generate(DummyState(decoder),
-                                                  tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
-                 self.assertEqual(decode_path.size(1), 8)  # longest finished sequence has length 8
-                 self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8)
-                 self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6)