|
@@ -21,6 +21,27 @@ def prepare_env(): |
|
|
return embed, encoder_output, encoder_mask |
|
|
return embed, encoder_output, encoder_mask |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GreedyDummyDecoder(Seq2SeqDecoder):
    """Stub decoder that replays a precomputed score tensor, one step per call.

    ``decoder_output`` is indexed as ``decoder_output[:, step]``, so its second
    dimension is the decoding step.
    """

    def __init__(self, decoder_output):
        """Store the scripted scores and reset the step counter to zero."""
        super().__init__()
        self.decoder_output = decoder_output
        self.cur_length = 0

    def decode(self, tokens, state):
        """Advance the step counter and return the scores stored for that step.

        ``tokens`` and ``state`` are accepted but not read. The counter is
        incremented before indexing, so the first call returns index 1.
        """
        step = self.cur_length + 1
        self.cur_length = step
        return self.decoder_output[:, step]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DummyState(State):
    """State wrapper that keeps a reference to the decoder it belongs to."""

    def __init__(self, decoder):
        """Remember ``decoder`` so that its scripted scores can be reordered later."""
        super().__init__()
        self.decoder = decoder

    def reorder_state(self, indices: torch.LongTensor):
        """Reorder the decoder's stored scores along dim 0 according to ``indices``."""
        target = self.decoder
        reordered = self._reorder_state(target.decoder_output, indices, dim=0)
        target.decoder_output = reordered
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSequenceGenerator(unittest.TestCase): |
|
|
class TestSequenceGenerator(unittest.TestCase): |
|
|
def test_run(self): |
|
|
def test_run(self): |
|
|
# 测试能否运行 (1) 初始化decoder,(2) decode一发 |
|
|
# 测试能否运行 (1) 初始化decoder,(2) decode一发 |
|
@@ -59,25 +80,6 @@ class TestSequenceGenerator(unittest.TestCase): |
|
|
|
|
|
|
|
|
def test_greedy_decode(self): |
|
|
def test_greedy_decode(self): |
|
|
# 测试能否正确的generate |
|
|
# 测试能否正确的generate |
|
|
class GreedyDummyDecoder(Seq2SeqDecoder):
    """Fake decoder that hands back pre-built scores for each successive step."""

    def __init__(self, decoder_output):
        """Keep the scripted score tensor and start the step counter at zero."""
        super().__init__()
        self.decoder_output = decoder_output
        self.cur_length = 0

    def decode(self, tokens, state):
        """Return ``decoder_output[:, step]`` for the next step.

        Neither ``tokens`` nor ``state`` is read. The counter is bumped before
        the lookup, so index 0 is never returned.
        """
        next_step = self.cur_length + 1
        self.cur_length = next_step
        return self.decoder_output[:, next_step]
|
|
|
|
|
|
|
|
|
|
|
class DummyState(State):
    """Minimal state object that only tracks the owning decoder."""

    def __init__(self, decoder):
        """Hold on to ``decoder``; its ``decoder_output`` is what gets reordered."""
        super().__init__()
        self.decoder = decoder

    def reorder_state(self, indices: torch.LongTensor):
        """Replace the decoder's scores with a copy reordered along dim 0 by ``indices``."""
        current_scores = self.decoder.decoder_output
        self.decoder.decoder_output = self._reorder_state(current_scores, indices, dim=0)
|
|
|
|
|
|
|
|
|
|
|
# greedy |
|
|
# greedy |
|
|
for beam_search in [1, 3]: |
|
|
for beam_search in [1, 3]: |
|
|
decoder_output = torch.randn(2, 10, 5) |
|
|
decoder_output = torch.randn(2, 10, 5) |
|
@@ -108,3 +110,35 @@ class TestSequenceGenerator(unittest.TestCase): |
|
|
self.assertEqual(decode_path.size(1), 8) # 长度为8 |
|
|
self.assertEqual(decode_path.size(1), 8) # 长度为8 |
|
|
self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8) |
|
|
self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8) |
|
|
self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6) |
|
|
self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6) |
|
|
|
|
|
|
|
|
|
|
|
def test_sample_decoder(self):
    """Check that sampling-based generation stops at ``eos_token_id``.

    For each beam size, ``num_tests`` independent trials are generated. In
    every trial the EOS token (id 4) is scored -100 at step indices 0..6 and
    10000 at step index 7 for sample 0 and step index 5 for sample 1.
    Because sampling is random, the test only requires that at least one
    trial has a decoded length of 8, that at least one matches the argmax
    path of sample 0 over 8 tokens, and that at least one matches the
    argmax path of sample 1 over its first 6 tokens.
    """
    num_tests = 10  # number of random trials per beam size
    for beam_search in [1, 3]:
        with self.subTest(beam_search=beam_search):
            decode_paths = []
            for _ in range(num_tests):
                decoder_output = torch.randn(2, 10, 5) * 10
                decoder_output[:, :7, 4].fill_(-100)
                decoder_output[0, 7, 4] = 10000  # sample 0 ends at the 8th position
                decoder_output[1, 5, 4] = 10000  # sample 1 ends at the 6th position
                path = decoder_output.argmax(dim=-1)  # 2 x 10
                decoder = GreedyDummyDecoder(decoder_output)
                generator = SequenceGenerator(
                    decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
                    do_sample=True, temperature=1, top_k=50, top_p=0.5, bos_token_id=1,
                    eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0)
                decode_path = generator.generate(
                    DummyState(decoder),
                    tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
                decode_paths.append((decode_path, path))

            sizes = []
            eqs = []
            eq2s = []
            for decode_path, path in decode_paths:
                length = decode_path.size(1)
                sizes.append(length == 8)
                # Compare token-by-token only when the lengths are compatible;
                # eq() on mismatched lengths would raise and abort the test
                # instead of recording this trial as a miss.
                eqs.append(length == 8
                           and bool(decode_path[0].eq(path[0, :8]).sum() == 8))
                eq2s.append(length >= 6
                            and bool(decode_path[1, :6].eq(path[1, :6]).sum() == 6))

            self.assertTrue(any(sizes))
            self.assertTrue(any(eqs))
            self.assertTrue(any(eq2s))