diff --git a/fastNLP/io/pipe/cws.py b/fastNLP/io/pipe/cws.py index 74f05a42..c3aab4e6 100644 --- a/fastNLP/io/pipe/cws.py +++ b/fastNLP/io/pipe/cws.py @@ -202,6 +202,7 @@ class CWSPipe(Pipe): subchar.append(c) char.append(''.join(subchar)) subchar = [] + continue if subchar: subchar.append(c) else: diff --git a/fastNLP/modules/generator/seq2seq_generator.py b/fastNLP/modules/generator/seq2seq_generator.py index 60dc5b71..faa9a93a 100644 --- a/fastNLP/modules/generator/seq2seq_generator.py +++ b/fastNLP/modules/generator/seq2seq_generator.py @@ -384,23 +384,23 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_ else: flag = False - # 更改state状态, 重组token_ids - reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维 - state.reorder_state(reorder_inds) - # 重新组织token_ids的状态 - tokens = _next_tokens - token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1) - if flag: + _token_ids = torch.cat([token_ids, _next_tokens], dim=-1) for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), eos_beam_idx.tolist()): if not dones[batch_idx]: score = next_scores[batch_idx, beam_ind].item() # 之后需要在结尾新增一个eos if _eos_token_id!=-1: - hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) + hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) else: - hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx].clone(), score) + hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score) + + # 更改state状态, 重组token_ids + reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维 + state.reorder_state(reorder_inds) + # 重新组织token_ids的状态 + token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1) for batch_idx in range(batch_size): dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \ diff --git a/test/io/pipe/test_cws.py b/test/io/pipe/test_cws.py index 5ca0f164..f3a95596 100644 --- a/test/io/pipe/test_cws.py +++ b/test/io/pipe/test_cws.py @@ -13,6 +13,17 @@ class TestCWSPipe(unittest.TestCase): data_bundle = CWSPipe(dataset_name=dataset_name).process_from_file() print(data_bundle) + def test_demo(self): + # related to issue https://github.com/fastnlp/fastNLP/issues/324#issue-705081091 + from fastNLP import DataSet, Instance + from fastNLP.io import DataBundle + data_bundle = DataBundle() + ds = DataSet() + ds.append(Instance(raw_words="截流 进入 最后 冲刺 ( 附 图片 1 张 )")) + data_bundle.set_dataset(ds, name='train') + data_bundle = CWSPipe().process(data_bundle) + self.assertFalse('<' in data_bundle.get_vocab('chars')) + class TestRunCWSPipe(unittest.TestCase): def test_process_from_file(self): diff --git a/test/modules/generator/test_seq2seq_generator.py b/test/modules/generator/test_seq2seq_generator.py index a60e4b4c..2a2f9d78 100644 --- a/test/modules/generator/test_seq2seq_generator.py +++ b/test/modules/generator/test_seq2seq_generator.py @@ -21,6 +21,27 @@ def prepare_env(): return embed, encoder_output, encoder_mask +class GreedyDummyDecoder(Seq2SeqDecoder): + def __init__(self, decoder_output): + super().__init__() + self.cur_length = 0 + self.decoder_output = decoder_output + + def decode(self, tokens, state): + self.cur_length += 1 + scores = self.decoder_output[:, self.cur_length] + return scores + + +class DummyState(State): + def __init__(self, decoder): + super().__init__() + self.decoder = decoder + + def reorder_state(self, indices: torch.LongTensor): + self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0) + + class TestSequenceGenerator(unittest.TestCase): def test_run(self): # 测试能否运行 (1) 初始化decoder,(2) decode一发 @@ -59,25 +80,6 @@ class TestSequenceGenerator(unittest.TestCase): def test_greedy_decode(self): # 测试能否正确的generate - class GreedyDummyDecoder(Seq2SeqDecoder): - def __init__(self, decoder_output): - super().__init__() - self.cur_length = 0 - self.decoder_output = decoder_output - - def decode(self, tokens, state): - self.cur_length += 1 - scores = self.decoder_output[:, self.cur_length] - return scores - - class DummyState(State): - def __init__(self, decoder): - super().__init__() - self.decoder = decoder - - def reorder_state(self, indices: torch.LongTensor): - self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0) - # greedy for beam_search in [1, 3]: decoder_output = torch.randn(2, 10, 5) @@ -108,3 +110,35 @@ class TestSequenceGenerator(unittest.TestCase): self.assertEqual(decode_path.size(1), 8) # 长度为8 self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8) self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6) + + def test_sample_decoder(self): + # greedy check eos_token_id + for beam_search in [1, 3]: + with self.subTest(beam_search=beam_search): + decode_paths = [] + # 因为是随机,所以需要测试100次,如果至少有一次是对的,应该就问题不大 + num_tests = 10 + for i in range(num_tests): + decoder_output = torch.randn(2, 10, 5) * 10 + decoder_output[:, :7, 4].fill_(-100) + decoder_output[0, 7, 4] = 10000 # 在第8个结束 + decoder_output[1, 5, 4] = 10000 + path = decoder_output.argmax(dim=-1) # 2 x 4 + decoder = GreedyDummyDecoder(decoder_output) + generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, + do_sample=True, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, + eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) + decode_path = generator.generate(DummyState(decoder), + tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) + decode_paths.append([decode_path, path]) + sizes = [] + eqs = [] + eq2s = [] + for i in range(num_tests): + decode_path, path = decode_paths[i] + sizes.append(decode_path.size(1)==8) + eqs.append(decode_path[0].eq(path[0, :8]).sum()==8) + eq2s.append(decode_path[1, :6].eq(path[1, :6]).sum()==6) + self.assertTrue(any(sizes)) + self.assertTrue(any(eqs)) + self.assertTrue(any(eq2s)) \ No newline at end of file