diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py
index a8139807..cdea7cfb 100644
--- a/fastNLP/embeddings/static_embedding.py
+++ b/fastNLP/embeddings/static_embedding.py
@@ -177,7 +177,7 @@ class StaticEmbedding(TokenEmbedding):
             else:
                 unknown_idx = embedding.size(0) - 1  # otherwise the last one is the unknown token
                 self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
-            words_to_words = torch.full((len(vocab),), fill_value=unknown_idx).long()
+            words_to_words = torch.full((len(vocab),), fill_value=unknown_idx, dtype=torch.long).long()
             for word, index in vocab:
                 if word not in lowered_vocab:
                     word = word.lower()
@@ -306,7 +306,7 @@ class StaticEmbedding(TokenEmbedding):
                 vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
             else:
                 unknown_idx = vocab.unknown_idx
-            self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx).long())
+            self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx, dtype=torch.long).long())
             index = 0
             for word, index_in_vocab in vocab:
                 if index_in_vocab in matrix:
diff --git a/fastNLP/modules/generator/seq2seq_generator.py b/fastNLP/modules/generator/seq2seq_generator.py
index cf6e0978..60dc5b71 100644
--- a/fastNLP/modules/generator/seq2seq_generator.py
+++ b/fastNLP/modules/generator/seq2seq_generator.py
@@ -175,11 +175,18 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
         # tokens = tokens[:, -1:]
 
     if max_len_a!=0:
-        max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
-        real_max_length = max_lengths.max()
+        # (bsz x num_beams, )
+        if state.encoder_mask is not None:
+            max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
+        else:
+            max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long)
+        real_max_length = max_lengths.max().item()
     else:
         real_max_length = max_length
-        max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
+        if state.encoder_mask is not None:
+            max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
+        else:
+            max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
 
     while cur_len < real_max_length:
         scores = decoder.decode(tokens=token_ids, state=state)  # batch_size x vocab_size
@@ -211,7 +218,8 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
             next_tokens = torch.argmax(scores, dim=-1)  # batch_size
 
         # if the corresponding sequence length has been reached, fill directly with eos
-        next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
+        if _eos_token_id!=-1:
+            next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
         next_tokens = next_tokens.masked_fill(dones, pad_token_id)  # pad the samples whose search has finished
         tokens = next_tokens.unsqueeze(1)
 
@@ -283,12 +291,17 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
     if max_len_a!=0:
         # (bsz x num_beams, )
-        max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
+        if state.encoder_mask is not None:
+            max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
+        else:
+            max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long)
         real_max_length = max_lengths.max().item()
     else:
         real_max_length = max_length
-        max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
-
+        if state.encoder_mask is not None:
+            max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
+        else:
+            max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
 
     hypos = [
         BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
     ]
@@ -371,25 +384,28 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         else:
             flag = False
 
+        # update the state and reorder token_ids
+        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to 1D
+        state.reorder_state(reorder_inds)
+        # reorganize token_ids
+        tokens = _next_tokens
+        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
+
         if flag:
             for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), eos_beam_idx.tolist()):
                 if not dones[batch_idx]:
                     score = next_scores[batch_idx, beam_ind].item()
                     # an eos still needs to be appended at the end later
-                    hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
+                    if _eos_token_id!=-1:
+                        hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
+                    else:
+                        hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx].clone(), score)
 
         for batch_idx in range(batch_size):
             dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \
                                max_lengths[batch_idx*num_beams]==cur_len+1
 
-        # update the state and reorder token_ids
-        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to 1D
-        state.reorder_state(reorder_inds)
-        # reorganize token_ids
-        tokens = _next_tokens
-        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
-
 
         cur_len += 1
 
         if all(dones):
diff --git a/test/embeddings/test_gpt2_embedding.py b/test/embeddings/test_gpt2_embedding.py
index d31f20bc..e8d0d043 100644
--- a/test/embeddings/test_gpt2_embedding.py
+++ b/test/embeddings/test_gpt2_embedding.py
@@ -254,6 +254,7 @@ class TestGPT2WordPieceEncoder(unittest.TestCase):
         self.assertTrue(ds.has_field('word_pieces'))
         result = embed(torch.LongTensor([[1, 2, 3, 4]]))
 
+    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_generate(self):
         # weight_path = 'test/data_for_tests/embedding/small_gpt2'
         weight_path = 'en'
diff --git a/test/modules/generator/test_seq2seq_generator.py b/test/modules/generator/test_seq2seq_generator.py
index d7c0fbfa..a60e4b4c 100644
--- a/test/modules/generator/test_seq2seq_generator.py
+++ b/test/modules/generator/test_seq2seq_generator.py
@@ -81,9 +81,9 @@ class TestSequenceGenerator(unittest.TestCase):
         # greedy
         for beam_search in [1, 3]:
             decoder_output = torch.randn(2, 10, 5)
-            path = decoder_output.argmax(dim=-1)  # 2 x 4
+            path = decoder_output.argmax(dim=-1)  # 2 x 10
             decoder = GreedyDummyDecoder(decoder_output)
-            with self.subTest(beam_search=beam_search):
+            with self.subTest(msg=beam_search, beam_search=beam_search):
                 generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
                                               do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1,
                                               eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0)
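The `dtype=torch.long` additions in `static_embedding.py` address a PyTorch behavior change: depending on the version, `torch.full` with an integer `fill_value` and no explicit `dtype` either emits a deprecation warning or produces a floating tensor, which then fails as an index tensor. A minimal standalone sketch of the pitfall (toy sizes, not fastNLP code):

```python
import torch

# Without an explicit dtype, some PyTorch versions build a float tensor from
# an integer fill_value (with a deprecation warning); a float tensor cannot
# be used for embedding lookups. dtype=torch.long makes the intent explicit.
words_to_words = torch.full((10,), fill_value=3, dtype=torch.long)
assert words_to_words.dtype == torch.long

# A long index tensor works directly as input to an embedding lookup.
weight = torch.randn(5, 4)
print(torch.nn.functional.embedding(words_to_words, weight).shape)  # torch.Size([10, 4])
```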
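The `state.encoder_mask is not None` and `_eos_token_id != -1` guards in `seq2seq_generator.py` let generation run for decoder-only models (which carry no encoder mask) and for configurations where `eos_token_id=None` (represented internally as `-1`). A hedged sketch of the eos guard with toy tensors (names mirror the diff but the values are illustrative):

```python
import torch

# Toy next-token ids for a batch of 3, with per-sample length budgets.
next_tokens = torch.tensor([4, 7, 2])
max_lengths = torch.tensor([3, 5, 3])
cur_len = 2

_eos_token_id = -1  # stands in for eos_token_id=None
if _eos_token_id != -1:
    # Only force eos when a real eos id exists; without the guard,
    # masked_fill would write -1 into next_tokens, producing invalid
    # token ids for the next decoding step.
    next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), _eos_token_id)

print(next_tokens)  # tensor([4, 7, 2]) -- unchanged when there is no eos token
```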
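The subtlest change moves the state reordering and `token_ids` reorganization *before* finished hypotheses are stored, so each saved prefix belongs to the beam that actually emitted the eos. A toy demonstration of the index arithmetic (batch size 2, `num_beams=2`; numbers are illustrative, not the library's API):

```python
import torch

# bsz*num_beams rows of partial sequences, one per live beam.
token_ids = torch.tensor([[1, 5], [1, 6], [1, 7], [1, 8]])
# For each batch element, which previous beam each new beam came from.
_from_which_beam = torch.tensor([[1, 0], [0, 1]])  # bsz x num_beams

# Offsets [[0, 0], [2, 2]] map per-batch beam indices to flat row indices.
batch_inds_with_numbeams_interval = (torch.arange(2) * 2).view(2, 1).repeat(1, 2)
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
print(reorder_inds)  # tensor([1, 0, 2, 3])

# Gathering rows first ensures a hypothesis added afterwards reads the
# prefix of the surviving beam, not a stale row.
token_ids = token_ids.index_select(dim=0, index=reorder_inds)
print(token_ids)  # rows now follow the surviving beams: [1,6], [1,5], [1,7], [1,8]
```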