@@ -177,7 +177,8 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
             scores = scores / temperature
 
         scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
-        probs = F.softmax(scores, dim=-1)
+        # Add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
+        probs = F.softmax(scores, dim=-1) + 1e-12
 
         # Make sure there is at least one non-eos value
         next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # batch_size
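The fix is the same in each sampling branch: after `top_k_top_p_filtering`, pruned vocabulary entries are `-inf`, softmax maps them to exactly 0, and `torch.multinomial` builds affected by pytorch/pytorch#27523 could still return such zero-probability indices. The 1e-12 offset leaves no zero-probability events and also keeps a later `probs.log()` finite. A minimal standalone sketch of the effect (illustrative values, not code from this repository):

```python
import torch
import torch.nn.functional as F

# After top-k/top-p filtering, pruned entries are -inf,
# and softmax maps them to exactly 0.
scores = torch.tensor([[2.0, float("-inf"), 0.5, float("-inf")]])

probs = F.softmax(scores, dim=-1)  # [[0.8176, 0.0000, 0.1824, 0.0000]]
probs = probs + 1e-12              # no zero-probability events remain

# An affected torch.multinomial build could return index 1 or 3 above even
# though their probability is exactly 0; with the offset every index is a
# legal (if negligible) outcome, and probs.log() stays finite.
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
```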
@@ -230,7 +231,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
     assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."
 
     if do_sample:
-        probs = F.softmax(scores, dim=-1)
+        probs = F.softmax(scores, dim=-1) + 1e-12
         next_tokens = torch.multinomial(probs, num_samples=num_beams)  # (batch_size, num_beams)
         logits = probs.log()
         next_scores = logits.gather(dim=1, index=next_tokens)  # (batch_size, num_beams)
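In the beam-search branch, the context lines above draw `num_beams` candidates per batch element and then read off each candidate's log-probability with `gather`; without the offset, a zero-probability index returned by the buggy sampler would surface here as a `-inf` score. A short sketch of that bookkeeping under assumed shapes (batch of 2, vocabulary of 10):

```python
import torch
import torch.nn.functional as F

num_beams = 3
scores = torch.randn(2, 10)  # (batch_size, vocab_size)

probs = F.softmax(scores, dim=-1) + 1e-12
# Draw num_beams distinct candidates per row (replacement defaults to False).
next_tokens = torch.multinomial(probs, num_samples=num_beams)  # (2, num_beams)
logits = probs.log()
# Log-probability of each sampled candidate, aligned with next_tokens.
next_scores = logits.gather(dim=1, index=next_tokens)          # (2, num_beams)
```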
@@ -276,7 +277,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
 
         # Recall one extra candidate to guard against eos
         scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1)
-        probs = F.softmax(scores, dim=-1)
+        # Add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
+        probs = F.softmax(scores, dim=-1) + 1e-12
 
         # Make sure there is at least one non-eos value
         _tokens = torch.multinomial(probs, num_samples=num_beams + 1)  # batch_size' x (num_beams+1)
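Here both `min_tokens_to_keep=num_beams + 1` and `num_samples=num_beams + 1` over-provision by one candidate, so that even if eos survives the filtering and is drawn, `num_beams` non-eos continuations remain. A sketch of that guard with a hypothetical top-k-only stand-in for `top_k_top_p_filtering` (the `keep_top_k` helper below is illustrative, not the repository's implementation):

```python
import torch
import torch.nn.functional as F

num_beams = 3
scores = torch.randn(2, 10)  # (batch_size', vocab_size)

def keep_top_k(scores, k):
    # Illustrative stand-in for top_k_top_p_filtering: keep the k largest
    # entries per row and mask everything else to -inf.
    kth = scores.topk(k, dim=-1).values[..., -1:]
    return scores.masked_fill(scores < kth, float("-inf"))

# Keep and sample num_beams + 1 candidates: one extra in case eos is drawn.
filtered = keep_top_k(scores, k=num_beams + 1)
probs = F.softmax(filtered, dim=-1) + 1e-12
_tokens = torch.multinomial(probs, num_samples=num_beams + 1)  # (2, num_beams + 1)
```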