From 33a94a3d1bad64705f3055b6df1ed2d08c63e2ef Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sun, 27 Dec 2020 20:27:42 +0800 Subject: [PATCH] update for inf --- fastNLP/modules/generator/seq2seq_generator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fastNLP/modules/generator/seq2seq_generator.py b/fastNLP/modules/generator/seq2seq_generator.py index 5b34df24..0ba9c02a 100644 --- a/fastNLP/modules/generator/seq2seq_generator.py +++ b/fastNLP/modules/generator/seq2seq_generator.py @@ -280,7 +280,6 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_ scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size) # 得到(batch_size, num_beams), (batch_size, num_beams) next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) - # TODO 这里需要考虑如果在第一个位置就结束的情况 # 根据index来做顺序的调转 indices = torch.arange(batch_size, dtype=torch.long).to(device) @@ -329,7 +328,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_ max_len_eos_mask = max_lengths.eq(cur_len+1) eos_scores = scores[:, _eos_token_id] # 如果已经达到最大长度,就把eos的分数加大 - scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+float('inf'), eos_scores) + scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores) if do_sample: if temperature > 0 and temperature != 1: