|
|
@@ -175,11 +175,18 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
     # tokens = tokens[:, -1:]
 
     if max_len_a!=0:
-        max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
-        real_max_length = max_lengths.max()
+        # (bsz x num_beams, )
+        if state.encoder_mask is not None:
+            max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
+        else:
+            max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long)
+        real_max_length = max_lengths.max().item()
     else:
         real_max_length = max_length
-        max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
+        if state.encoder_mask is not None:
+            max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
+        else:
+            max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
 
     while cur_len < real_max_length:
         scores = decoder.decode(tokens=token_ids, state=state)  # batch_size x vocab_size
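Note on the hunk above: the old code dereferenced state.encoder_mask unconditionally, which raises AttributeError when a state carries no encoder mask (e.g. decoder-only generation); it also switches real_max_length to .max().item(), so the loop bound is a plain Python int rather than a 0-dim tensor. A minimal standalone sketch of the guarded per-sample length budget, with names mirroring the diff (the helper function itself is illustrative, not part of the patch):

import torch

def compute_max_lengths(encoder_mask, tokens, max_len_a, max_length):
    # with a mask: budget = source_length * max_len_a + max_length, per sample
    if max_len_a != 0 and encoder_mask is not None:
        return (encoder_mask.sum(dim=1).float() * max_len_a).long() + max_length
    # without one: every sample gets the same constant budget
    return tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
toks = torch.zeros(2, 1, dtype=torch.long)
print(compute_max_lengths(mask, toks, 1.2, 5))  # tensor([8, 7])
print(compute_max_lengths(None, toks, 1.2, 5))  # tensor([5, 5])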
|
|
@@ -211,7 +218,8 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
             next_tokens = torch.argmax(scores, dim=-1)  # batch_size
 
         # if the corresponding sequence has already reached its length limit, fill it with eos directly
-        next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
+        if _eos_token_id!=-1:
+            next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
         next_tokens = next_tokens.masked_fill(dones, pad_token_id)  # pad samples whose search has already finished
         tokens = next_tokens.unsqueeze(1)
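Note: forcing eos at the length limit only makes sense when the model actually has an eos token; _eos_token_id is -1 when none exists, which the new guard checks. A toy run of the two masked_fill steps (values invented for illustration):

import torch

next_tokens = torch.tensor([5, 9, 2])
max_lengths = torch.tensor([4, 6, 4])
cur_len, _eos_token_id, pad_token_id = 3, 1, 0
dones = torch.tensor([False, False, True])

if _eos_token_id != -1:  # skip entirely when no eos token exists
    next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), _eos_token_id)
next_tokens = next_tokens.masked_fill(dones, pad_token_id)  # finished samples get pad
print(next_tokens)  # tensor([1, 9, 0])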
|
|
|
|
|
|
@@ -283,12 +291,17 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
 
     if max_len_a!=0:
         # (bsz x num_beams, )
-        max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
+        if state.encoder_mask is not None:
+            max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
+        else:
+            max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long)
         real_max_length = max_lengths.max().item()
     else:
         real_max_length = max_length
-        max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
+        if state.encoder_mask is not None:
+            max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
+        else:
+            max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
 
     hypos = [
         BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
     ]
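The same None guard, applied to the beam-search path. One shape detail: tokens here already has batch_size * num_beams rows, so the new_full fallback yields one length budget per beam hypothesis, matching what the encoder_mask branch produces after beam expansion. A quick check under assumed sizes:

import torch

batch_size, num_beams, max_length = 2, 3, 10
tokens = torch.zeros(batch_size * num_beams, 1, dtype=torch.long)  # beam-expanded rows
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
assert max_lengths.shape == (batch_size * num_beams,)  # one budget per hypothesis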
|
|
@@ -371,25 +384,28 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         else:
             flag = False
 
-        # update the state and reorder token_ids
-        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to 1-D
-        state.reorder_state(reorder_inds)
-        # reorganize token_ids
-        tokens = _next_tokens
-        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
-
         if flag:
             for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(),
                                                      eos_beam_idx.tolist()):
                 if not dones[batch_idx]:
                     score = next_scores[batch_idx, beam_ind].item()
                     # an eos needs to be appended at the end later
-                    hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
+                    if _eos_token_id!=-1:
+                        hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
+                    else:
+                        hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx].clone(), score)
 
         for batch_idx in range(batch_size):
             dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \
                                max_lengths[batch_idx*num_beams]==cur_len+1
 
+        # update the state and reorder token_ids
+        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to 1-D
+        state.reorder_state(reorder_inds)
+        # reorganize token_ids accordingly
+        tokens = _next_tokens
+        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
+
         cur_len += 1
 
         if all(dones):
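Note on the hunk above: moving the reorder block below the bookkeeping lets finished hypotheses be read out of token_ids before it is reshuffled for the next step, and the new _eos_token_id guard keeps the full sequence (no :cur_len truncation) when there is no eos to re-append later. For reference, a sketch of the reorder index arithmetic, with assumed shapes: _from_which_beam is (batch_size, num_beams) and batch_inds_with_numbeams_interval holds the row offsets [0, num_beams, 2*num_beams, ...]:

import torch

batch_size, num_beams = 2, 3
batch_inds = (torch.arange(batch_size) * num_beams).view(-1, 1)  # [[0], [3]]
_from_which_beam = torch.tensor([[1, 0, 2], [2, 2, 0]])  # winning parent beams
reorder_inds = (batch_inds + _from_which_beam).view(-1)  # flatten to 1-D
print(reorder_inds)  # tensor([1, 0, 2, 5, 5, 3])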
|
|
|