Browse Source

修复seq2seq_generator的bug

tags/v0.6.0
yh_cc 4 years ago
parent
commit
ee584e71dc
4 changed files with 74 additions and 28 deletions
  1. +1
    -0
      fastNLP/io/pipe/cws.py
  2. +9
    -9
      fastNLP/modules/generator/seq2seq_generator.py
  3. +11
    -0
      test/io/pipe/test_cws.py
  4. +53
    -19
      test/modules/generator/test_seq2seq_generator.py

+ 1
- 0
fastNLP/io/pipe/cws.py View File

@@ -202,6 +202,7 @@ class CWSPipe(Pipe):
subchar.append(c)
char.append(''.join(subchar))
subchar = []
continue
if subchar:
subchar.append(c)
else:


+ 9
- 9
fastNLP/modules/generator/seq2seq_generator.py View File

@@ -384,23 +384,23 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
else:
flag = False

# 更改state状态, 重组token_ids
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维
state.reorder_state(reorder_inds)
# 重新组织token_ids的状态
tokens = _next_tokens
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)

if flag:
_token_ids = torch.cat([token_ids, _next_tokens], dim=-1)
for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(),
eos_beam_idx.tolist()):
if not dones[batch_idx]:
score = next_scores[batch_idx, beam_ind].item()
# 之后需要在结尾新增一个eos
if _eos_token_id!=-1:
hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
else:
hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx].clone(), score)
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score)

# 更改state状态, 重组token_ids
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维
state.reorder_state(reorder_inds)
# 重新组织token_ids的状态
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1)

for batch_idx in range(batch_size):
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \


+ 11
- 0
test/io/pipe/test_cws.py View File

@@ -13,6 +13,17 @@ class TestCWSPipe(unittest.TestCase):
data_bundle = CWSPipe(dataset_name=dataset_name).process_from_file()
print(data_bundle)

def test_demo(self):
    """CWSPipe must not leak '<...>'-style placeholder tokens into the char vocab.

    Regression test for https://github.com/fastnlp/fastNLP/issues/324#issue-705081091
    """
    from fastNLP import DataSet, Instance
    from fastNLP.io import DataBundle

    bundle = DataBundle()
    dataset = DataSet()
    dataset.append(Instance(raw_words="截流 进入 最后 冲刺 ( 附 图片 1 张 )"))
    bundle.set_dataset(dataset, name='train')
    bundle = CWSPipe().process(bundle)
    # The '<' character only appears if special markers were split into chars.
    self.assertFalse('<' in bundle.get_vocab('chars'))


class TestRunCWSPipe(unittest.TestCase):
def test_process_from_file(self):


+ 53
- 19
test/modules/generator/test_seq2seq_generator.py View File

@@ -21,6 +21,27 @@ def prepare_env():
return embed, encoder_output, encoder_mask


class GreedyDummyDecoder(Seq2SeqDecoder):
    """Stub decoder that ignores its inputs and replays pre-computed scores.

    Each :meth:`decode` call advances an internal step counter and returns the
    score slice recorded for that step in ``decoder_output``.
    """

    def __init__(self, decoder_output):
        super().__init__()
        self.cur_length = 0  # index of the most recently decoded step
        self.decoder_output = decoder_output

    def decode(self, tokens, state):
        # `tokens` and `state` are deliberately unused: scores are canned.
        step = self.cur_length + 1
        self.cur_length = step
        return self.decoder_output[:, step]


class DummyState(State):
    """State wrapper that forwards beam reordering to the dummy decoder."""

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder

    def reorder_state(self, indices: torch.LongTensor):
        # Keep the canned scores aligned with the reordered beam hypotheses.
        reordered = self._reorder_state(self.decoder.decoder_output, indices, dim=0)
        self.decoder.decoder_output = reordered


class TestSequenceGenerator(unittest.TestCase):
def test_run(self):
# 测试能否运行 (1) 初始化decoder,(2) decode一发
@@ -59,25 +80,6 @@ class TestSequenceGenerator(unittest.TestCase):

def test_greedy_decode(self):
# 测试能否正确的generate
class GreedyDummyDecoder(Seq2SeqDecoder):
def __init__(self, decoder_output):
super().__init__()
self.cur_length = 0
self.decoder_output = decoder_output

def decode(self, tokens, state):
self.cur_length += 1
scores = self.decoder_output[:, self.cur_length]
return scores

class DummyState(State):
def __init__(self, decoder):
super().__init__()
self.decoder = decoder

def reorder_state(self, indices: torch.LongTensor):
self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0)

# greedy
for beam_search in [1, 3]:
decoder_output = torch.randn(2, 10, 5)
@@ -108,3 +110,35 @@ class TestSequenceGenerator(unittest.TestCase):
self.assertEqual(decode_path.size(1), 8) # 长度为8
self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8)
self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6)

def test_sample_decoder(self):
    """Sampling-based generation should respect eos_token_id (checked per beam size)."""
    # Same setup as the greedy test, but with do_sample=True: eos (token 4) is
    # forced by a huge logit at a fixed step for each of the two samples.
    for beam_search in [1, 3]:
        with self.subTest(beam_search=beam_search):
            decode_paths = []
            # Sampling is stochastic, so repeat num_tests times; the test passes
            # if at least one trial reproduces the expected length/path.
            num_tests = 10
            for i in range(num_tests):
                decoder_output = torch.randn(2, 10, 5) * 10
                decoder_output[:, :7, 4].fill_(-100)  # suppress eos before step 7
                decoder_output[0, 7, 4] = 10000  # sample 0: eos dominates at step 8
                decoder_output[1, 5, 4] = 10000  # sample 1: eos dominates at step 6
                path = decoder_output.argmax(dim=-1)  # 2 x 10
                decoder = GreedyDummyDecoder(decoder_output)
                generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
                                              do_sample=True, temperature=1, top_k=50, top_p=0.5, bos_token_id=1,
                                              eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0)
                decode_path = generator.generate(DummyState(decoder),
                                                 tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
                decode_paths.append([decode_path, path])
            sizes = []
            eqs = []
            eq2s = []
            for i in range(num_tests):
                decode_path, path = decode_paths[i]
                sizes.append(decode_path.size(1)==8)  # longest sample ends at step 8
                eqs.append(decode_path[0].eq(path[0, :8]).sum()==8)
                eq2s.append(decode_path[1, :6].eq(path[1, :6]).sum()==6)
            # At least one stochastic trial must match each expectation.
            self.assertTrue(any(sizes))
            self.assertTrue(any(eqs))
            self.assertTrue(any(eq2s))

Loading…
Cancel
Save