@@ -368,13 +368,13 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
             _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
             next_tokens = _tokens.gather(dim=1, index=ids)  # (batch_size, 2*num_beams)
-            from_which_beam = torch.floor(ids / (num_beams + 1)).long()  # (batch_size, 2*num_beams)
+            from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long()  # (batch_size, 2*num_beams)
         else:
             scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
             _scores = scores + beam_scores[:, None]  # (batch_size * num_beams, vocab_size)
             _scores = _scores.view(batch_size, -1)  # (batch_size, num_beams*vocab_size)
             next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)  # (bsz, 2*num_beams)
-            from_which_beam = torch.floor(ids / vocab_size).long()  # (batch_size, 2*num_beams)
+            from_which_beam = torch.floor(ids.float() / vocab_size).long()  # (batch_size, 2*num_beams)
             next_tokens = ids % vocab_size  # (batch_size, 2*num_beams)
 
             # Next, assemble the results for the next batch.
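
For context on why the `.float()` cast is needed: `torch.floor` is not implemented for integer dtypes, and on recent PyTorch versions dividing an integer tensor by a Python int with `/` performs true (float) division rather than integer division. A minimal sketch of recovering the beam index and token id from the flattened top-k indices (the tensor values below are made up for illustration):

    import torch

    vocab_size = 10  # toy vocabulary size, for illustration only
    ids = torch.tensor([[5, 12, 23]])  # hypothetical top-k indices into the flattened (num_beams * vocab_size) axis

    # Cast to float before dividing, since torch.floor rejects integer tensors,
    # then cast the result back to long for use as an index.
    from_which_beam = torch.floor(ids.float() / vocab_size).long()  # tensor([[0, 1, 2]])
    next_tokens = ids % vocab_size                                  # tensor([[5, 2, 3]])

    # On PyTorch >= 1.8 an integer-only equivalent would be:
    # from_which_beam = torch.div(ids, vocab_size, rounding_mode='floor')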
|
|
|