|
@@ -12,9 +12,11 @@ import torch.nn.functional as F

 from ...core.utils import _get_model_device
 from functools import partial


 class SequenceGenerator:
     """
-    Given a Seq2SeqDecoder, decode sentences.
+    Given a Seq2SeqDecoder, decode sentences. The decoder object passed in must have a decode() function whose first argument is all of the output decoded so far,
+    and whose second argument is the state. SequenceGenerator performs no operations on the state itself.

     """
     def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1,
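
The decode(tokens, state) contract above is all SequenceGenerator relies on. A minimal self-contained sketch of a conforming decoder, using plain PyTorch (every name below is illustrative, not part of this diff):

import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    # Satisfies the contract: decode(tokens, state) -> bsz x vocab_size scores.
    def __init__(self, vocab_size=100, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.rnn = nn.GRU(hidden, hidden, batch_first=True)
        self.out = nn.Linear(hidden, vocab_size)

    def decode(self, tokens, state):
        # tokens: bsz x decoded_len, everything generated so far.
        # state is opaque to SequenceGenerator; this toy simply ignores it.
        h, _ = self.rnn(self.embed(tokens))
        return self.out(h[:, -1])  # scores for the next token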
|
@@ -65,7 +67,8 @@ class SequenceGenerator:
         """

         :param State state: the State produced by the encoder; it is meant to be used together with the paired Decoder
-        :param torch.LongTensor,None tokens: batch_size x length, the initial tokens
+        :param torch.LongTensor,None tokens: batch_size x length, the initial tokens. If None, a bos_token is prepended
+            by default and generation starts from it.
         :return: bsz x max_length', the generated token sequence. If eos_token_id is not None, every sequence is guaranteed to end with eos_token_id
         """
|
|
|
|
|
|
|
@@ -168,6 +171,8 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
     _eos_token_id = eos_token_id

     scores = decoder.decode(tokens=tokens, state=state)  # mainly to update the state
+    if _eos_token_id != -1:  # prevent EOS at the first position
+        scores[:, _eos_token_id] = -1e12
     next_tokens = scores.argmax(dim=-1, keepdim=True)
     token_ids = torch.cat([tokens, next_tokens], dim=1)
     cur_len = token_ids.size(1)
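
The added `-1e12` acts as an effectively minus-infinite score, so `argmax` can never select EOS on the very first step. A toy, self-contained illustration of the same trick:

import torch

scores = torch.tensor([[0.1, 2.0, 0.3]])  # toy next-token scores; suppose EOS is index 1
eos_token_id = 1
scores[:, eos_token_id] = -1e12           # mask EOS out of the argmax
assert scores.argmax(dim=-1).item() != eos_token_id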
|
@@ -261,6 +266,8 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
     _eos_token_id = eos_token_id

     scores = decoder.decode(tokens=tokens, state=state)  # the whole sentence so far must be passed in here
+    if _eos_token_id != -1:  # prevent EOS at the first position
+        scores[:, _eos_token_id] = -1e12
     vocab_size = scores.size(1)
     assert vocab_size >= num_beams, "num_beams should not be larger than the vocabulary size."
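
The assert exists because the first beam expansion draws `num_beams` distinct candidates from a single next-token distribution, which is only possible when the vocabulary has at least that many entries. A toy sketch of that first expansion (`topk` stands in for whatever selection the real code performs):

import torch

num_beams = 2
scores = torch.log_softmax(torch.randn(1, 5), dim=-1)      # bsz=1, vocab_size=5
next_scores, next_tokens = scores.topk(num_beams, dim=-1)  # raises if vocab_size < num_beams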
|
|
|
|
|
|
|
@@ -322,7 +329,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         max_len_eos_mask = max_lengths.eq(cur_len+1)
         eos_scores = scores[:, _eos_token_id]
         # if the maximum length has been reached, boost the eos score
-        scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)
+        scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+float('inf'), eos_scores)

         if do_sample:
             if temperature > 0 and temperature != 1:
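
Switching the boost from `1e12` to `float('inf')` makes the forcing unconditional: an infinite score cannot be outweighed by any finite alternative, whereas `eos_scores + 1e12` could in principle still lose to another very large score. A toy version:

import torch

eos_scores = torch.tensor([0.5, -3.0])
max_len_eos_mask = torch.tensor([True, False])  # first sequence has hit its max length
forced = torch.where(max_len_eos_mask, eos_scores + float('inf'), eos_scores)
# forced -> tensor([inf, -3.0]); EOS now outranks every token for sequence 0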
|
@@ -356,9 +363,9 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_

         # Next, assemble the results for the next batch.
         # Decide which candidates to keep.
-        next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
-        next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
-        from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)
+        # next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
+        # next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
+        # from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)

         not_eos_mask = next_tokens.ne(_eos_token_id)  # positions with 1 are not eos
         keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)  # positions with 1 are kept
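
The cumsum trick scans candidates until `num_beams` non-EOS tokens have been collected: EOS candidates encountered along the way finish a hypothesis, non-EOS ones continue as beams. A self-contained toy run:

import torch

num_beams, eos = 2, 9
next_tokens = torch.tensor([[4, 9, 7, 9, 2, 5]])      # toy candidate tokens for one example
not_eos_mask = next_tokens.ne(eos)                    # [[1, 0, 1, 0, 1, 1]]
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)  # keep until 2 non-EOS are seen
# keep_mask -> [[True, True, True, True, False, False]]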
|
@@ -413,7 +420,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             break

     # select the best hypotheses
-    tgt_len = token_ids.new(batch_size)
+    tgt_len = token_ids.new_zeros(batch_size)
     best = []

     for i, hypotheses in enumerate(hypos):
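
`Tensor.new(n)` allocates uninitialized memory, so its contents are arbitrary until every slot is written; `new_zeros` gives a deterministic starting value with the same dtype and device. A quick comparison:

import torch

t = torch.tensor([1, 2, 3])
a = t.new(4)        # uninitialized: values are whatever happened to be in memory
b = t.new_zeros(4)  # deterministic: tensor([0, 0, 0, 0]), same dtype/device as t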
|
@@ -425,7 +432,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         best.append(best_hyp)

     # generate target batch
-    decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
+    decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id)
     for i, hypo in enumerate(best):
         decoded[i, :tgt_len[i]] = hypo
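
Here `fill_(pad_token_id)` already overwrites every element, so `new_zeros` is about consistency with the earlier change rather than correctness. The re-batching itself is the usual right-padding pattern; a toy equivalent:

import torch

pad_token_id = 0
best = [torch.tensor([5, 6, 2]), torch.tensor([7, 2])]  # two finished hypotheses
tgt_len = torch.tensor([3, 2])
decoded = torch.full((2, int(tgt_len.max())), pad_token_id)
for i, hypo in enumerate(best):
    decoded[i, :tgt_len[i]] = hypo                      # copy the hypothesis, leave the rest padded
# decoded -> [[5, 6, 2], [7, 2, 0]]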
|
|
|
|
|
|
|
|