
update for inf

tags/v1.0.0alpha
yh_cc 4 years ago
commit 33a94a3d1b
1 changed file with 1 addition and 2 deletions:

  +1 -2   fastNLP/modules/generator/seq2seq_generator.py

fastNLP/modules/generator/seq2seq_generator.py

@@ -280,7 +280,6 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
 scores = F.log_softmax(scores, dim=-1)  # (batch_size, vocab_size)
 # yields scores and tokens, each of shape (batch_size, num_beams)
 next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
-# TODO consider the case where a sequence finishes at the very first position

 # reorder according to the indices
 indices = torch.arange(batch_size, dtype=torch.long).to(device)
@@ -329,7 +328,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
 max_len_eos_mask = max_lengths.eq(cur_len+1)
 eos_scores = scores[:, _eos_token_id]
 # if the maximum length has been reached, boost the eos score
-scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+float('inf'), eos_scores)
+scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)

 if do_sample:
     if temperature > 0 and temperature != 1:
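
The point of replacing float('inf') with the large finite constant 1e12 appears to be nan safety: log_softmax over masked logits yields -inf entries, and inf + (-inf) is nan in IEEE arithmetic, which can then corrupt the topk/sort steps of beam search, whereas a finite boost keeps every score well-defined. A minimal standalone sketch of the hazard (illustrative only, not part of this commit; the example tensor is made up):

import torch

# A masked-out token has a -inf logit, so its log-probability is -inf.
logits = torch.tensor([[2.0, float('-inf'), 0.5]])
scores = torch.log_softmax(logits, dim=-1)   # ~[-0.20, -inf, -1.70]

# Boosting with +inf turns the -inf entry into nan (inf + -inf == nan),
# and nan breaks comparisons, so topk/sort results become unreliable.
print(scores + float('inf'))                 # tensor([[inf, nan, inf]])

# Boosting with a large finite constant keeps everything well-defined:
# finite entries become ~1e12 and the masked entry stays -inf.
boosted = scores + 1e12
print(boosted)                               # tensor([[1.0000e+12, -inf, 1.0000e+12]])
print(torch.topk(boosted, k=2).indices)      # the two finite entries win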

