|
@@ -12,19 +12,19 @@ import torch.nn.functional as F
 from ...core.utils import _get_model_device
 from functools import partial
 
 
 class SequenceGenerator:
     """
     Given a Seq2SeqDecoder, decode sentences.
 
     """
-    def __init__(self, decoder: Seq2SeqDecoder, max_length=20, num_beams=1,
+    def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1,
                  do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None,
                  repetition_penalty=1, length_penalty=1.0, pad_token_id=0):
         """
 
         :param Seq2SeqDecoder decoder: the Decoder object
-        :param int max_length: the maximum sentence length
+        :param int max_length: the maximum length of a generated sentence; each sentence is decoded for up to max_length + max_len_a*src_len steps
+        :param float max_len_a: each sentence is decoded for up to max_length + max_len_a*src_len steps. If non-zero, the State must contain encoder_mask
         :param int num_beams: the beam size for beam search
         :param bool do_sample: whether to generate by sampling
         :param float temperature: only meaningful when do_sample is True
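The new length budget is worth a quick worked example. A minimal standalone sketch in plain PyTorch (toy tensors, not fastNLP API), mirroring the computation this patch adds further down:

    import torch

    max_length, max_len_a = 20, 0.5
    # encoder_mask marks real source tokens (1) vs padding (0),
    # so its row sums are the source lengths.
    encoder_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                                 [1, 1, 1, 1, 1, 1]])
    src_len = encoder_mask.sum(dim=1)                        # tensor([4, 6])
    max_lengths = (src_len.float() * max_len_a).long() + max_length
    print(max_lengths)                                       # tensor([22, 23])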
@@ -37,12 +37,14 @@ class SequenceGenerator:
         :param int pad_token_id: once a sentence has finished generating, subsequent positions are filled with pad_token_id
         """
         if do_sample:
-            self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
+            self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a,
+                                         num_beams=num_beams,
                                          temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id,
                                          eos_token_id=eos_token_id, repetition_penalty=repetition_penalty,
                                          length_penalty=length_penalty, pad_token_id=pad_token_id)
         else:
-            self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
+            self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a,
+                                         num_beams=num_beams,
                                          bos_token_id=bos_token_id, eos_token_id=eos_token_id,
                                          repetition_penalty=repetition_penalty,
                                          length_penalty=length_penalty, pad_token_id=pad_token_id)
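For orientation, a hedged usage sketch of the class after this change; `decoder` and `state` stand in for a concrete Seq2SeqDecoder and the State produced by its matching encoder, neither of which appears in this patch:

    # Assumed objects: `decoder` (a Seq2SeqDecoder) and `state` (encoder output State).
    generator = SequenceGenerator(decoder, max_length=20, max_len_a=0.5, num_beams=4,
                                  do_sample=False, bos_token_id=1, eos_token_id=2,
                                  pad_token_id=0)
    # __init__ stores the configured routine in self.generate_func:
    token_ids = generator.generate_func(tokens=None, state=state)  # (batch_size, seq_len)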
@@ -71,7 +73,7 @@ class SequenceGenerator:
 
 
 @torch.no_grad()
-def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1,
+def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1,
                     bos_token_id=None, eos_token_id=None, pad_token_id=0,
                     repetition_penalty=1, length_penalty=1.0):
     """
@@ -80,7 +82,8 @@ def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
     :param Decoder decoder: the Decoder object
     :param torch.LongTensor tokens: batch_size x len, the input to decoding; if None, generation starts automatically from bos_token_id
     :param State state: should contain some outputs of the encoder.
-    :param int max_length: the maximum length of a generated sentence.
+    :param int max_length: the maximum length of a generated sentence; each sentence is decoded for up to max_length + max_len_a*src_len steps
+    :param float max_len_a: each sentence is decoded for up to max_length + max_len_a*src_len steps. If non-zero, the State must contain encoder_mask
     :param int num_beams: the beam size used for decoding.
     :param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
     :param int eos_token_id: the end token; if None, decoding always runs to max_length.
@@ -90,13 +93,14 @@ def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
     :return:
     """
     if num_beams == 1:
-        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=1, top_k=50, top_p=1,
+        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                             temperature=1, top_k=50, top_p=1,
                                              bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
                                              repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                              pad_token_id=pad_token_id)
     else:
-        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams,
-                                          temperature=1, top_k=50, top_p=1,
+        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                          num_beams=num_beams, temperature=1, top_k=50, top_p=1,
                                           bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
                                           repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                           pad_token_id=pad_token_id)
@@ -105,7 +109,7 @@ def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
 
 
 @torch.no_grad()
-def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1, temperature=1.0, top_k=50,
+def sample_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, temperature=1.0, top_k=50,
                     top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0,
                     length_penalty=1.0):
     """
@@ -114,7 +118,8 @@ def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
     :param Decoder decoder: the Decoder object
     :param torch.LongTensor tokens: batch_size x len, the input to decoding; if None, generation starts automatically from bos_token_id
     :param State state: should contain some outputs of the encoder.
-    :param int max_length: the maximum length of a generated sentence.
+    :param int max_length: the maximum length of a generated sentence; each sentence is decoded for up to max_length + max_len_a*src_len steps
+    :param float max_len_a: each sentence is decoded for up to max_length + max_len_a*src_len steps. If non-zero, the State must contain encoder_mask
     :param int num_beams: the beam size used for decoding.
     :param float temperature: the annealing temperature used when sampling
     :param int top_k: sample only among the top_k candidates
@@ -128,21 +133,21 @@ def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
     """
     # each position is generated by sampling
     if num_beams == 1:
-        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=temperature,
-                                             top_k=top_k, top_p=top_p,
+        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                             temperature=temperature, top_k=top_k, top_p=top_p,
                                              bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
                                              repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                              pad_token_id=pad_token_id)
     else:
-        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams,
-                                          temperature=temperature, top_k=top_k, top_p=top_p,
+        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                          num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p,
                                           bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
                                           repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                           pad_token_id=pad_token_id)
     return token_ids
 
 
-def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, temperature=1.0, top_k=50,
+def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, max_len_a=0.0, temperature=1.0, top_k=50,
                              top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
                              repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0):
     device = _get_model_device(decoder)
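The top_k/top_p arguments threaded through here are applied by a filtering step this diff does not touch. As a point of reference only, a standard top-k / nucleus filter looks roughly like the following sketch (not this file's actual helper):

    import torch
    import torch.nn.functional as F

    def top_k_top_p_filtering(logits, top_k=50, top_p=1.0, filter_value=-float("inf")):
        # logits: (batch_size, vocab_size)
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            # drop everything below the k-th largest logit
            threshold = torch.topk(logits, top_k)[0][..., -1, None]
            logits = logits.masked_fill(logits < threshold, filter_value)
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # mask tokens once cumulative probability exceeds top_p,
            # always keeping at least the single best token
            remove = cum_probs > top_p
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            logits = logits.masked_fill(remove.scatter(1, sorted_idx, remove), filter_value)
        return logits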
@@ -169,7 +174,14 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
     dones = token_ids.new_zeros(batch_size).eq(1)
     # tokens = tokens[:, -1:]
 
-    while cur_len < max_length:
+    if max_len_a != 0:
+        max_lengths = (state.encoder_mask.sum(dim=1).float() * max_len_a).long() + max_length
+        real_max_length = max_lengths.max()
+    else:
+        real_max_length = max_length
+        max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long() * max_length
+
+    while cur_len < real_max_length:
         scores = decoder.decode(tokens=token_ids, state=state)  # batch_size x vocab_size
 
         if repetition_penalty != 1.0:
@@ -194,11 +206,12 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
             # adding 1e-12 avoids https://github.com/pytorch/pytorch/pull/27523
             probs = F.softmax(scores, dim=-1) + 1e-12
 
+            # make sure at least one sampled value is not eos
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # batch_size
         else:
             next_tokens = torch.argmax(scores, dim=-1)  # batch_size
 
+        # if a sequence has reached its assigned length, fill directly with eos
+        next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
         next_tokens = next_tokens.masked_fill(dones, pad_token_id)  # pad the samples whose search has finished
         tokens = next_tokens.unsqueeze(1)
 
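A tiny standalone illustration of the two masked_fill steps above, with toy values:

    import torch

    next_tokens = torch.tensor([17, 23, 42])
    max_lengths = torch.tensor([5, 9, 9])       # per-sentence decode budgets
    dones = torch.tensor([False, True, False])  # sample 1 already finished
    cur_len, eos_id, pad_id = 4, 2, 0
    next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), eos_id)
    print(next_tokens)                          # tensor([ 2, 23, 42])
    next_tokens = next_tokens.masked_fill(dones, pad_id)
    print(next_tokens)                          # tensor([ 2,  0, 42])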
@@ -211,14 +224,14 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
         if dones.min() == 1:
             break
 
-    if eos_token_id is not None:
-        if cur_len == max_length:
-            token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # if EOS has not appeared by the maximum length, force the last token to eos
+    # if eos_token_id is not None:
+    #     tokens.scatter(index=max_lengths[:, None], dim=1, value=eos_token_id)  # set the position at the maximum length to eos
+    # if cur_len == max_length:
+    #     token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # if EOS has not appeared by the maximum length, force the last token to eos
     return token_ids
 
 
-def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, num_beams=4, temperature=1.0,
+def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=4, temperature=1.0,
                           top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
                           repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor:
     # run beam search
@@ -268,14 +281,22 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
     # records the length of the tokens generated so far
     cur_len = token_ids.size(1)
 
+    if max_len_a != 0:
+        # (bsz x num_beams, )
+        max_lengths = (state.encoder_mask.sum(dim=1).float() * max_len_a).long() + max_length
+        real_max_length = max_lengths.max().item()
+    else:
+        real_max_length = max_length
+        max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long() * max_length
+
     hypos = [
-        BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
+        BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
     ]
-    # 0,num_beams, 2*num_beams, ...
+    # 0, num_beams, 2*num_beams, ...
     batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
 
-    while cur_len < max_length:
-        scores = decoder.decode(token_ids, state)
+    while cur_len < real_max_length:
+        scores = decoder.decode(token_ids, state)  # (bsz x num_beams, vocab_size)
         if repetition_penalty != 1.0:
             token_scores = scores.gather(dim=1, index=token_ids)
             lt_zero_mask = token_scores.lt(0).float()
@@ -283,6 +304,12 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
             scores.scatter_(dim=1, index=token_ids, src=token_scores)
 
+        if _eos_token_id != -1:
+            max_len_eos_mask = max_lengths.eq(cur_len+1)
+            eos_scores = scores[:, _eos_token_id]
+            # if a sequence has reached its maximum length, boost the score of eos
+            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+100, eos_scores)
+
         if do_sample:
             if temperature > 0 and temperature != 1:
                 scores = scores / temperature
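The repetition-penalty branch above pushes already-generated tokens toward "less likely" regardless of sign: negative scores are multiplied by the penalty, non-negative ones divided by it. A toy check:

    import torch

    token_scores = torch.tensor([-1.0, 2.0])
    repetition_penalty = 1.2
    lt_zero_mask = token_scores.lt(0).float()   # tensor([1., 0.])
    ge_zero_mask = lt_zero_mask.eq(0).float()   # tensor([0., 1.])
    penalized = lt_zero_mask * repetition_penalty * token_scores \
                + ge_zero_mask / repetition_penalty * token_scores
    print(penalized)                            # tensor([-1.2000,  1.6667])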
@@ -309,7 +336,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
             _scores = scores + beam_scores[:, None]  # (batch_size * num_beams, vocab_size)
             _scores = _scores.view(batch_size, -1)  # (batch_size, num_beams*vocab_size)
-            next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+            next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)  # (bsz, 2*num_beams)
         from_which_beam = ids // vocab_size  # (batch_size, 2*num_beams)
         next_tokens = ids % vocab_size  # (batch_size, 2*num_beams)
 
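Why the floor division and modulus recover (beam, token) pairs: after the per-batch (num_beams x vocab_size) scores are flattened, flat position p sits in beam p // vocab_size at token p % vocab_size. On toy sizes:

    import torch

    num_beams, vocab_size = 2, 5
    _scores = torch.arange(10.).view(1, -1)   # one batch entry, beams flattened
    next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1)
    print(ids)                # tensor([[9, 8, 7, 6]])
    print(ids // vocab_size)  # beams:  tensor([[1, 1, 1, 1]])
    print(ids % vocab_size)   # tokens: tensor([[4, 3, 2, 1]])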
@@ -328,12 +355,8 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
         beam_scores = _next_scores.view(-1)
 
-        # update the state and regroup token_ids
-        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to one dimension
-        state.reorder_state(reorder_inds)
-
         flag = True
-        if cur_len+1 == max_length:
+        if cur_len+1 == real_max_length:
             eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
             eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)  # these are indices
             eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)  # which beam each entry was taken from
@@ -348,19 +371,24 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         else:
             flag = False
 
-        # regroup the contents of token_ids
-        tokens = _next_tokens
-        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
-
         if flag:
             for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(),
                                                      eos_beam_idx.tolist()):
                 if not dones[batch_idx]:
                     score = next_scores[batch_idx, beam_ind].item()
-                    hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len+1].clone(), score)
+                    # an eos still needs to be appended at the end afterwards
+                    hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
 
         for batch_idx in range(batch_size):
-            dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
+            dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \
+                               max_lengths[batch_idx*num_beams] == cur_len+1
+
+        # update the state and regroup token_ids
+        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to one dimension
+        state.reorder_state(reorder_inds)
+        # regroup the contents of token_ids
+        tokens = _next_tokens
+        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
 
         cur_len += 1
 
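Note the reordering now runs after the hypothesis collection, so finished hypotheses are read from the not-yet-reordered token_ids (hence the `:cur_len` slice, with eos appended at the very end instead). The index arithmetic itself, on toy sizes:

    import torch

    batch_size, num_beams = 2, 3
    batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1)  # [[0], [3]]
    _from_which_beam = torch.tensor([[2, 0, 1],    # batch 0 keeps beams 2, 0, 1
                                     [1, 1, 0]])   # batch 1 keeps beams 1, 1, 0
    reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
    print(reorder_inds)  # tensor([2, 0, 1, 4, 4, 3]) -> rows of the (bsz*num_beams, ...) tensors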
@@ -373,15 +401,16 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
     for i, hypotheses in enumerate(hypos):
         best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
-        tgt_len[i] = len(best_hyp)  # +1 for the <EOS> symbol
+        # restore the eos for the token that was replaced by a non-eos word above
+        if _eos_token_id != -1:
+            best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1)*_eos_token_id])
+        tgt_len[i] = len(best_hyp)
         best.append(best_hyp)
 
     # generate target batch
     decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
     for i, hypo in enumerate(best):
         decoded[i, :tgt_len[i]] = hypo
-        if eos_token_id is not None:
-            decoded[i, tgt_len[i] - 1] = _eos_token_id
 
     return decoded
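The net effect of that final loop, on toy values: each best hypothesis now already ends in eos (appended above whenever _eos_token_id != -1), which is why the explicit decoded[i, tgt_len[i]-1] write could be dropped:

    import torch

    best = [torch.tensor([5, 7, 2]), torch.tensor([5, 2])]   # 2 = eos, already appended
    tgt_len = torch.tensor([3, 2])
    decoded = torch.zeros(2, 3, dtype=torch.long)            # pad_token_id = 0
    for i, hypo in enumerate(best):
        decoded[i, :tgt_len[i]] = hypo
    print(decoded)  # tensor([[5, 7, 2],
                    #         [5, 2, 0]])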