@@ -280,7 +280,6 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
     scores = F.log_softmax(scores, dim=-1)  # (batch_size, vocab_size)
     # yields (batch_size, num_beams), (batch_size, num_beams)
     next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
-    # TODO handle the case where generation ends right at the first position
     # reorder according to the indices
     indices = torch.arange(batch_size, dtype=torch.long).to(device)
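As context for the hunk above: F.log_softmax turns the decoder logits into per-token log-probabilities, and torch.topk then keeps the num_beams best continuations per batch element. A minimal standalone sketch of that selection (the batch_size, vocab_size, and num_beams values here are invented for illustration, not taken from this diff):

import torch
import torch.nn.functional as F

batch_size, vocab_size, num_beams = 2, 10, 3
logits = torch.randn(batch_size, vocab_size)          # stand-in for decoder output
scores = F.log_softmax(logits, dim=-1)                # (batch_size, vocab_size)
# the num_beams highest log-probs and their token ids, per row
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
print(next_scores.shape, next_tokens.shape)           # torch.Size([2, 3]) torch.Size([2, 3])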
@@ -329,7 +328,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         max_len_eos_mask = max_lengths.eq(cur_len+1)
         eos_scores = scores[:, _eos_token_id]
         # if the maximum length has been reached, boost the eos score
-        scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+float('inf'), eos_scores)
+        scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)
         if do_sample:
             if temperature > 0 and temperature != 1:
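The point of swapping float('inf') for 1e12 in the hunk above is presumably NaN safety: if an eos score was already masked to -inf (e.g., eos was disallowed at an earlier step), adding inf yields nan, and a single nan poisons the subsequent topk/multinomial step. A large finite constant still dominates every real log-probability while leaving -inf entries as -inf. A small illustration (the tensor values are invented for the demo):

import torch

eos_scores = torch.tensor([-2.3, float('-inf')])  # second entry: eos masked out earlier
mask = torch.tensor([True, True])                 # both sequences hit max length

# adding inf turns the -inf entry into nan, which would poison topk/sampling
print(torch.where(mask, eos_scores + float('inf'), eos_scores))  # tensor([inf, nan])

# a large finite bonus forces eos just as well, but stays nan-free
print(torch.where(mask, eos_scores + 1e12, eos_scores))          # tensor([1.0000e+12, -inf])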