# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license which can be found at
# https://github.com/facebookresearch/fairseq/blob/main/LICENSE

import math
from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from .token_generation_constraints import (ConstraintState,
                                           OrderedConstraintState,
                                           UnorderedConstraintState)


class Search(nn.Module):

    def __init__(self, tokenizer):
        super().__init__()
        self.pad = tokenizer.pad_token_id
        self.unk = tokenizer.unk_token_id
        self.eos = tokenizer.eos_token_id
        # Invert the tokenizer's vocab (id -> token), including added tokens,
        # to recover the full target vocabulary size.
        tgt_dict = {value: key for key, value in tokenizer.get_vocab().items()}
        added = {
            value: key
            for key, value in tokenizer.get_added_vocab().items()
        }
        tgt_dict.update(added)
        self.vocab_size = len(tgt_dict)
        self.src_lengths = torch.tensor(-1)
        self.supports_constraints = False
        self.stop_on_max_len = False
-
- def step(self,
- step,
- lprobs,
- scores,
- prev_output_tokens=None,
- original_batch_idxs=None):
- """Take a single search step.
-
- Args:
- step: the current search step, starting at 0
- lprobs: (bsz x input_beam_size x vocab_size)
- the model's log-probabilities over the vocabulary at the current step
- scores: (bsz x input_beam_size x step)
- the historical model scores of each hypothesis up to this point
- prev_output_tokens: (bsz x step)
- the previously generated oputput tokens
- original_batch_idxs: (bsz)
- the tensor with the batch indices, in the range [0, bsz)
- this is useful in case there has been applied a re-ordering
- and we need to know the orignal indices
-
- Return: A tuple of (scores, indices, beams) where:
- scores: (bsz x output_beam_size)
- the scores of the chosen elements; output_beam_size can be
- larger than input_beam_size, e.g., we may return
- 2*input_beam_size to account for EOS
- indices: (bsz x output_beam_size)
- the indices of the chosen elements
- beams: (bsz x output_beam_size)
- the hypothesis ids of the chosen elements, in the range [0, input_beam_size)
- """
- raise NotImplementedError
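
    # A shape-only sketch of this contract (illustrative, kept as comments;
    # `search` stands for any concrete subclass instance and is not defined
    # in this module):
    #
    #   lprobs = torch.log_softmax(torch.randn(2, 5, 100), dim=-1)  # bsz=2, beam=5, vocab=100
    #   scores, indices, beams = search.step(0, lprobs, None)
    #   # each output has shape (2, k) with k <= 2 * 5; beams is in [0, 5)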

    @torch.jit.export
    def set_src_lengths(self, src_lengths):
        self.src_lengths = src_lengths

    @torch.jit.export
    def init_constraints(self, batch_constraints: Optional[Tensor],
                         beam_size: int):
        """Initialize constraint states for constrained decoding (if supported).

        Args:
            batch_constraints: (torch.Tensor, optional)
                the list of constraints, in packed form
            beam_size: (int)
                the beam size
        """
        pass
-
- def prune_sentences(self, batch_idxs: Tensor):
- """
- Removes constraint states for completed sentences (if supported).
- This is called from sequence_generator._generate() when sentences are
- deleted from the batch.
-
- Args:
- batch_idxs: Indices of *sentences* whose constraint state should be *kept*.
- """
- pass
-
- def update_constraints(self, active_hypos: Tensor):
- """
- Updates the constraint states by selecting the beam items that are retained.
- This is called at each time step of sequence_generator._generate() when
- the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size.
-
- Args:
- active_hypos: (batch size, beam size)
- list of integers denoting, for each sentence, which beam candidate items
- should be kept.
- """
- pass


class BeamSearch(Search):

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.constraint_states = None

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs,
        scores: Optional[Tensor],
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            assert scores is not None
            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)

        top_prediction = torch.topk(
            lprobs.view(bsz, -1),
            k=min(
                # Take the best 2 x beam_size predictions. We'll choose the first
                # beam_size of these which don't predict eos to continue with.
                beam_size * 2,
                lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
            ),
        )
        scores_buf = top_prediction[0]
        indices_buf = top_prediction[1]
        # Project back into relative indices and beams
        beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor')
        indices_buf = indices_buf.fmod(vocab_size)

        # At this point, beams_buf and indices_buf are (bsz x 2*beam_size)
        # and contain beam-relative values
        return scores_buf, indices_buf, beams_buf
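
    # A minimal usage sketch (illustrative, kept as comments; `tok` is an
    # assumed HuggingFace-style tokenizer, not defined in this module):
    #
    #   search = BeamSearch(tok)
    #   bsz, beam_size = 2, 4
    #   lprobs = torch.log_softmax(
    #       torch.randn(bsz, beam_size, search.vocab_size), dim=-1)
    #   scores, indices, beams = search.step(0, lprobs, None)
    #   # scores, indices, beams: (2, 8) -- 2 * beam_size candidates each;
    #   # at step 0 every beams entry is 0, since only beam 0 was scored.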


class PrefixConstrainedBeamSearch(Search):

    def __init__(self, tokenizer, prefix_allowed_tokens_fn):
        super().__init__(tokenizer)
        self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self.stop_on_max_len = True

    @torch.jit.export
    def apply_mask(self, x, prev_output_tokens, original_batch_idxs):
        beam_size = x.shape[0] // original_batch_idxs.shape[0]
        original_batch_idxs = (
            original_batch_idxs.unsqueeze(-1).repeat(
                (1, beam_size)).flatten().tolist())

        # Start from an all -inf mask and zero out the positions of the
        # tokens that prefix_allowed_tokens_fn allows for each hypothesis.
        mask = torch.full_like(x, -math.inf)
        for sent_i, (sent, batch_i) in enumerate(
                zip(prev_output_tokens, original_batch_idxs)):
            mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0

        return mask

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs: Tensor,
        scores: Tensor,
        prev_output_tokens: Tensor,
        original_batch_idxs: Tensor,
    ):
        bsz, beam_size, vocab_size = lprobs.size()

        lprobs += self.apply_mask(
            lprobs.view(bsz * beam_size, 1, vocab_size),
            prev_output_tokens,
            original_batch_idxs,
        ).view(bsz, beam_size, vocab_size)

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            assert scores is not None
            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)

        top_prediction = torch.topk(
            lprobs.view(bsz, -1),
            k=min(
                # Take only the best beam_size predictions; generation stops
                # at max_len rather than on EOS (stop_on_max_len is True).
                beam_size,
                lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
            ),
        )
        scores_buf = top_prediction[0]
        indices_buf = top_prediction[1]
        beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor')
        indices_buf = indices_buf.fmod(vocab_size)
        return scores_buf, indices_buf, beams_buf
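
    # A minimal sketch of a prefix_allowed_tokens_fn (illustrative; the
    # helper name, the token ids, and `tok` are assumptions, not part of
    # this module). It forces every hypothesis to continue with a fixed
    # prefix, then allows the full vocabulary:
    #
    #   forced_prefix = [5, 17, 42]  # hypothetical token ids
    #
    #   def allow_prefix(batch_id, sent):
    #       if len(sent) < len(forced_prefix):
    #           return [forced_prefix[len(sent)]]
    #       return list(range(len(tok)))
    #
    #   search = PrefixConstrainedBeamSearch(tok, allow_prefix)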


class LexicallyConstrainedBeamSearch(Search):
    """Implements lexically constrained beam search as described in

    Fast Lexically Constrained Decoding with Dynamic Beam
    Allocation for Neural Machine Translation. Post & Vilar,
    NAACL 2018. https://www.aclweb.org/anthology/N18-1119/

    and

    Improved Lexically Constrained Decoding for Translation and
    Monolingual Rewriting. Hu et al, NAACL
    2019. https://www.aclweb.org/anthology/N19-1090/

    This is accomplished by maintaining, for each beam hypothesis, a
    ConstraintState object (see token_generation_constraints.py) that
    tracks which constraints have been generated and using this information
    to shape the beam for each input sentence.
    """

    def __init__(self, tokenizer, representation):
        super().__init__(tokenizer)
        self.representation = representation
        tgt_dict = {value: key for key, value in tokenizer.get_vocab().items()}
        added = {
            value: key
            for key, value in tokenizer.get_added_vocab().items()
        }
        tgt_dict.update(added)
        self.vocab_size = len(tgt_dict)
        self.num_cands = 0
        self.supports_constraints = True

    @torch.jit.export
    def init_constraints(self, batch_constraints: Optional[Tensor],
                         beam_size: int):
        self.constraint_states = []
        for constraint_tensor in batch_constraints:
            if self.representation == 'ordered':
                constraint_state = OrderedConstraintState.create(
                    constraint_tensor)
            elif self.representation == 'unordered':
                constraint_state = UnorderedConstraintState.create(
                    constraint_tensor)
            else:
                raise ValueError(
                    f'unknown constraint representation: {self.representation}')

            self.constraint_states.append(
                [constraint_state for i in range(beam_size)])

    @torch.jit.export
    def prune_sentences(self, batch_idxs: Tensor):
        self.constraint_states = [
            self.constraint_states[i] for i in batch_idxs.tolist()
        ]

    @torch.jit.export
    def update_constraints(self, active_hypos: Tensor):
        if self.constraint_states:
            batch_size = active_hypos.size(0)
            for sentid in range(batch_size):
                self.constraint_states[sentid] = [
                    self.constraint_states[sentid][i]
                    for i in active_hypos[sentid]
                ]

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs: Tensor,
        scores: Optional[Tensor],
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        """
        A constrained step builds a large candidates list from the following:
        - the top 2 * {beam_size} items over the whole beam
        - for each item in the beam
          - the top {each_k} (default 1)
          - all next constraints
        We then compute the constrained state of each beam item, and assign
        stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so
        on. We then sort by (stripe, score), and truncate the list at
        2 * beam size.

        Args:
            step: the decoder step
            lprobs: (batch size, beam size, target vocab)
                the target-vocab distributions for each item in the beam.
        Return: A tuple of (scores, indices, beams) where:
            scores: (batch, output beam size)
                the scores of the chosen elements
            indices: (batch, output beam size)
                the target vocab indices of the chosen elements
            beams: (batch, output beam size)
                the 0-indexed hypothesis ids of the chosen elements
        The updated constraint states are stored in self.constraint_states
        rather than returned.
        """
        each_k = 1
        device = lprobs.device

        batch_size, beam_size, vocab_size = lprobs.size()

        self.num_cands = min(
            # Just take the k-best. We'll get another k from the 1-best from each
            # row, plus more from the constraints
            beam_size * 2,
            lprobs.view(batch_size, -1).size(1) - 1,  # -1 so we never select pad
        )

        # STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items
        constraint_states = self.constraint_states
        if constraint_states and step > 0:
            not_finished_indices = []
            for sentno, sent_constraints in enumerate(constraint_states):
                for beamno, state in enumerate(sent_constraints):
                    index = sentno * beam_size + beamno
                    if not state.finished:
                        not_finished_indices.append(index)
            not_finished_indices = torch.tensor(not_finished_indices)
            if not_finished_indices.numel() > 0:
                lprobs.view(batch_size * beam_size, -1)[not_finished_indices,
                                                        self.eos] = -math.inf

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam entry for each batch item
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            assert scores is not None
            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)

        top_prediction = torch.topk(
            lprobs.view(batch_size, -1),
            self.num_cands,
        )
        scores_buf, indices_buf = top_prediction
        # Project back into relative indices and beams
        beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor')
        indices_buf = indices_buf.fmod(vocab_size)

        # Short circuit if there are no constraints in this batch
        if not constraint_states:
            return scores_buf, indices_buf, beams_buf

        # STEP 1: get top-1 from each hypothesis across all sentences in the batch
        if step > 0:
            top_scores, top_indices = torch.topk(
                lprobs.view(batch_size * beam_size, -1),
                k=each_k,
                dim=1,
            )
            top_scores = top_scores.view(batch_size, -1)
            top_indices = top_indices.view(batch_size, -1)
            scores_buf = torch.cat((scores_buf, top_scores), dim=1)
            indices_buf = torch.cat((indices_buf, top_indices), dim=1)
            new_beams = torch.arange(
                0, beam_size, device=device).repeat(batch_size, 1)
            beams_buf = torch.cat((beams_buf, new_beams), dim=1)

        # Now, process sentences in the batch one by one.
        new_scores_buf = torch.zeros((batch_size, 2 * beam_size),
                                     device=device)
        new_indices_buf = torch.zeros((batch_size, 2 * beam_size),
                                      device=device).long()
        new_beams_buf = torch.zeros((batch_size, 2 * beam_size),
                                    device=device).long()
        for sentno, states in enumerate(constraint_states):
            scores, indices, beams, new_states = self.step_sentence(
                step,
                sentno,
                lprobs[sentno],
                constraint_states[sentno],
                beams_buf[sentno].clone(),
                indices_buf[sentno].clone(),
                scores_buf[sentno].clone(),
            )
            new_scores_buf[sentno] = scores
            new_indices_buf[sentno] = indices
            new_beams_buf[sentno] = beams
            self.constraint_states[sentno] = new_states

        return new_scores_buf, new_indices_buf, new_beams_buf

    @torch.jit.export
    def step_sentence(
        self,
        step: int,
        sentno: int,
        lprobs: Tensor,
        constraint_states: List[List[ConstraintState]],
        beams_buf: Tensor,
        indices_buf: Tensor,
        scores_buf: Tensor,
    ):
        """Does per-sentence processing. Adds all constraints for each
        hypothesis to the list of candidates; then removes duplicates,
        sorts, and dynamically stripes across the banks. All tensor inputs
        are collapsed to those pertaining to a single input sentence.
        """
        device = lprobs.device

        # STEP 2: Add all constraints for each beam item
        for beamno, state in enumerate(constraint_states):
            next_tokens = torch.tensor(
                list(state.next_tokens()), device=device).long()
            if next_tokens.numel() != 0:
                indices_buf = torch.cat((indices_buf, next_tokens))
                next_beams = (
                    torch.tensor(beamno, device=device).repeat(
                        next_tokens.size(0)).long())
                beams_buf = torch.cat((beams_buf, next_beams))
                next_values = lprobs[beamno].take(next_tokens.view(-1))
                scores_buf = torch.cat((scores_buf, next_values))

            # At the 0th time step, there is just one beam item
            if step == 0:
                break

        # STEP 3: Compute the "bank" for each candidate. This is the
        # number of constraints it's generated. We need this so that
        # we can do round-robin allocation of the beam across these
        # banks. If C is the number of constraints, we select the best
        # item in bank C, then the best in bank C-1, etc, followed by
        # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so
        # on, until the maximum beam size. We accomplish this by
        # creating a sort key and striping across the banks.

        # Compute the new states for all candidates
        cands_size = indices_buf.size(0)
        constraint_states = [
            constraint_states[beams_buf[i]].advance(indices_buf[i])
            for i in range(cands_size)
        ]

        banks = torch.tensor([state.bank for state in constraint_states],
                             device=device)

        # STEP 4: Sort
        # The total number of constraint tokens is the same for every state,
        # so reading it off the last `state` bound in the loop above is safe.
        num_constraint_tokens = len(state.tokens)

        # Sort by keys (bank, score), i.e., sort banks together, and scores
        # within banks. We emulate a multi-key sort with a single combined
        # key: MAX_SCORE is a lower bound on the log-probability scores, so
        # each bank occupies a disjoint range of key values.
        MAX_SCORE = -100
        sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf
        sort_values, sort_indices = sort_key.sort(dim=0, descending=True)
        scores_buf = scores_buf[sort_indices]
        indices_buf = indices_buf[sort_indices]
        beams_buf = beams_buf[sort_indices]
        banks = banks[sort_indices]

        # Sort the constraints to follow suit
        constraint_states = [constraint_states[i] for i in sort_indices]

        # STEP 5: Remove duplicates. The topk calls (overall and
        # per-row) plus the per-row generation of constraints will
        # produce duplicates. Here we remove them.

        def roll(t):
            """Rolls a 1d tensor right by 1 (the last element moves to the front).

            [0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3]
            """
            return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0)

        # We map candidates (beam, token_id) to a single dimension.
        # This is then shifted by 1. We can then easily identify
        # duplicates and create a mask that identifies unique
        # extensions.
        uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf
        uniques_mask = roll(uniques_mask) != uniques_mask
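
        # Illustrative example (comments only): with vocab_size = 9, the
        # sorted candidates (beam, token) = [(1, 3), (0, 3), (0, 3)] map to
        # keys [1 * 10 + 3, 0 * 10 + 3, 0 * 10 + 3] = [13, 3, 3]; rolling
        # gives [3, 13, 3], and comparing element-wise yields
        # [True, True, False], so the repeated (0, 3) is masked out.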

        # Use the mask to pare down the data structures
        scores_buf = torch.masked_select(scores_buf, uniques_mask)
        indices_buf = torch.masked_select(indices_buf, uniques_mask)
        beams_buf = torch.masked_select(beams_buf, uniques_mask)
        banks = torch.masked_select(banks, uniques_mask)
        i = 1
        for mask in uniques_mask[1:]:
            if not mask:
                constraint_states.pop(i)
            i += int(mask)

        # STEP 6: Assign IDs round-robin across banks, sort, and
        # truncate. Now that the candidates are sorted by (bank,
        # score) and uniqed, we dynamically allocate the {beam_size}
        # beam by striping across the candidates. These stripes will
        # be used as sort keys to do round-robin selection. This is
        # accomplished in a single pass with offsets. Sorting by
        # highest-banks (furthest-along hypotheses) first ensures
        # progress through the constraints.
        #
        # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0
        #      COUNTS: 0 1 2 0 1 2 3 0 1 2 0 1  (position within each bank)
        # With num_constraint_tokens = 3 and len(banks) = 12, the offset per
        # count is len(banks) + 1 = 13, so
        #     STRIPES: 0 13 26 1 14 27 40 2 15 28 3 16
        #
        # Sorting by this then gives the following banks:
        #
        #   3 2 1 0 3 2 1 0 3 2 1 2
        #
        # We'll take the top {beam_size} of these.
        stripe_offsets = [
            offset * (len(banks) + 1) for offset in range(len(banks) + 1)
        ]
        stripes = torch.zeros_like(banks)
        cur_bank_count = -1
        cur_bank = banks[0]
        for i, bank in enumerate(banks):
            if bank != cur_bank:
                cur_bank_count = 0
                cur_bank = bank
            else:
                cur_bank_count += 1
            stripes[i] = num_constraint_tokens - bank + stripe_offsets[
                cur_bank_count]

        # STEP 7: Sort by the stripes values
        sort_values, sort_indices = stripes.sort(dim=0)
        scores_buf = scores_buf[sort_indices]
        indices_buf = indices_buf[sort_indices]
        beams_buf = beams_buf[sort_indices]
        constraint_states = [constraint_states[i] for i in sort_indices]

        # STEP 8: Truncate to the candidates size!
        scores_buf = scores_buf[:self.num_cands]
        indices_buf = indices_buf[:self.num_cands]
        beams_buf = beams_buf[:self.num_cands]

        return scores_buf, indices_buf, beams_buf, constraint_states
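
    # A smaller worked example of the striping key (comments only): with
    # num_constraint_tokens C = 3 and len(banks) = 6, the per-count offset
    # is len(banks) + 1 = 7:
    #
    #   banks   = [3, 3, 2, 2, 1, 0]    (already sorted by bank)
    #   counts  = [0, 1, 0, 1, 0, 0]    (position within its bank)
    #   stripes = [C - bank + 7 * count, ...] = [0, 7, 1, 8, 2, 3]
    #
    #   Sorting ascending visits banks 3, 2, 1, 0, 3, 2: one candidate per
    #   bank per round, most-advanced hypotheses first.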


class LengthConstrainedBeamSearch(Search):

    def __init__(self, tokenizer, min_len_a, min_len_b, max_len_a, max_len_b):
        super().__init__(tokenizer)
        self.min_len_a = min_len_a
        self.min_len_b = min_len_b
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.beam = BeamSearch(tokenizer)
        self.needs_src_lengths = True

    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        # Output lengths are bounded by a * src_len + b: EOS is forbidden
        # before min_lens and forced (log-prob 0) once max_lens is reached.
        min_lens = self.min_len_a * self.src_lengths + self.min_len_b
        max_lens = self.max_len_a * self.src_lengths + self.max_len_b
        lprobs[step < min_lens, :, self.eos] = -math.inf
        lprobs[step >= max_lens, :, self.eos] = 0
        return self.beam.step(step, lprobs, scores)
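
    # Illustrative numbers (comments only): with min_len_a=0.5, min_len_b=2,
    # max_len_a=1.5, max_len_b=5 and a source of length 10, EOS is blocked
    # before step 0.5 * 10 + 2 = 7 and forced from step 1.5 * 10 + 5 = 20 on.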


class DiverseBeamSearch(Search):
    """Diverse Beam Search.

    See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
    Models" for details.

    We only implement the Hamming Diversity penalty here, which performed best
    in the original paper.
    """

    def __init__(self, tokenizer, num_groups, diversity_strength):
        super().__init__(tokenizer)
        self.num_groups = num_groups
        self.diversity_strength = -diversity_strength
        self.beam = BeamSearch(tokenizer)

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        bsz, beam_size, vocab_size = lprobs.size()
        if beam_size % self.num_groups != 0:
            raise ValueError(
                'DiverseBeamSearch requires --beam to be divisible by the number of groups'
            )

        # initialize diversity penalty
        diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs)

        scores_G, indices_G, beams_G = [], [], []
        for g in range(self.num_groups):
            lprobs_g = lprobs[:, g::self.num_groups, :]
            scores_g = scores[:, g::self.num_groups, :] if step > 0 else None

            # apply diversity penalty
            if g > 0:
                lprobs_g = torch.add(
                    lprobs_g,
                    other=diversity_buf.unsqueeze(1),
                    alpha=self.diversity_strength,
                )
            else:
                lprobs_g = lprobs_g.contiguous()

            scores_buf, indices_buf, beams_buf = self.beam.step(
                step, lprobs_g, scores_g)
            beams_buf.mul_(self.num_groups).add_(g)

            scores_G.append(scores_buf.clone())
            indices_G.append(indices_buf.clone())
            beams_G.append(beams_buf.clone())

            # update diversity penalty: count how often each token has been
            # selected by the groups processed so far
            diversity_buf.scatter_add_(
                1, indices_buf,
                torch.ones(indices_buf.size()).to(diversity_buf))

        # interleave results from different groups
        scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1)
        indices_buf = torch.stack(indices_G, dim=2).view(bsz, -1)
        beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1)
        return scores_buf, indices_buf, beams_buf
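
    # A minimal sketch of the Hamming diversity penalty (comments only;
    # `tok` is an assumed tokenizer): with num_groups=2 and strength 0.5,
    # group 1 sees each token's log-probability reduced by
    # 0.5 * (number of times earlier groups already picked that token):
    #
    #   search = DiverseBeamSearch(tok, num_groups=2, diversity_strength=0.5)
    #   lprobs = torch.log_softmax(
    #       torch.randn(1, 4, search.vocab_size), dim=-1)  # beam=4, 2 per group
    #   scores, indices, beams = search.step(0, lprobs, None)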


class Sampling(Search):
    sampling_topk: int
    sampling_topp: float

    def __init__(self, tokenizer, sampling_topk=-1, sampling_topp=-1.0):
        super().__init__(tokenizer)
        self.sampling_topk = sampling_topk
        self.sampling_topp = sampling_topp

    def _sample_topp(self, lprobs):
        """Sample among the smallest set of elements whose cumulative probability mass exceeds p.

        See `"The Curious Case of Neural Text Degeneration"
        (Holtzman et al., 2019) <https://arxiv.org/abs/1904.09751>`_.

        Args:
            lprobs: (bsz x input_beam_size x vocab_size)
                the model's log-probabilities over the vocabulary at the current step

        Return: A tuple of (trimmed_probs, truncated_indices) where:
            trimmed_probs: (bsz x input_beam_size x ?)
                the model's probabilities over the elements selected to sample from. The
                width of the third dimension is determined by top-P.
            truncated_indices: (bsz x input_beam_size x ?)
                the indices of the chosen elements.
        """
        probs = lprobs.exp_()

        # sort the last dimension (vocab dimension) in descending order
        sorted_probs, sorted_indices = probs.sort(descending=True)

        # compute a mask to indicate the words to be included in the top-P set.
        cumsum_probs = sorted_probs.cumsum(dim=2)
        mask = cumsum_probs.lt(self.sampling_topp)

        # note that mask was computed by 'lt'. One more word needs to be included
        # so that the cumulative probability mass can exceed p.
        cumsum_mask = mask.cumsum(dim=2)
        last_included = cumsum_mask[:, :, -1:]
        last_included.clamp_(0, mask.size()[2] - 1)
        mask = mask.scatter_(2, last_included, 1)

        # truncate unnecessary dims.
        max_dim = last_included.max()
        truncated_mask = mask[:, :, :max_dim + 1]
        truncated_probs = sorted_probs[:, :, :max_dim + 1]
        truncated_indices = sorted_indices[:, :, :max_dim + 1]

        # trim the words that are not in top-P by setting their probabilities
        # to 0, so that they would not be sampled later.
        trim_mask = ~truncated_mask
        trimmed_probs = truncated_probs.masked_fill_(trim_mask, 0)
        return trimmed_probs, truncated_indices
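
    # Worked example (comments only): with p = 0.9 and sorted probabilities
    # [0.5, 0.3, 0.15, 0.05], the cumulative sums are [0.5, 0.8, 0.95, 1.0],
    # so mask = cumsum < 0.9 = [1, 1, 0, 0]; one more word is then added to
    # push the mass past p, keeping [0.5, 0.3, 0.15] (total 0.95) to sample
    # from.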

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()

        if self.sampling_topp > 0:
            # only sample from the smallest set of words whose cumulative probability mass exceeds p
            probs, top_indices = self._sample_topp(lprobs)
        elif self.sampling_topk > 0:
            # only sample from top-k candidates
            lprobs, top_indices = lprobs.topk(self.sampling_topk)
            probs = lprobs.exp_()
        else:
            probs = lprobs.exp_()

            # dummy data to be consistent with true branch for type check
            top_indices = torch.empty(0).to(probs)
        # sample
        if step == 0:
            indices_buf = torch.multinomial(
                probs.view(bsz, -1),
                beam_size,
                replacement=True,
            ).view(bsz, beam_size)
        else:
            indices_buf = torch.multinomial(
                probs.view(bsz * beam_size, -1),
                1,
                replacement=True,
            ).view(bsz, beam_size)

        if step == 0:
            # expand to beam size
            probs = probs.expand(bsz, beam_size, -1)

        # gather scores
        scores_buf = torch.gather(
            probs, dim=2, index=indices_buf.unsqueeze(-1))
        scores_buf = scores_buf.log_().view(bsz, -1)

        # remap indices if using top-k or top-P sampling
        if self.sampling_topk > 0 or self.sampling_topp > 0:
            indices_buf = torch.gather(
                top_indices.expand(bsz, beam_size, -1),
                dim=2,
                index=indices_buf.unsqueeze(-1),
            ).squeeze(2)

        if step == 0:
            beams_buf = indices_buf.new_zeros(bsz, beam_size)
        else:
            beams_buf = torch.arange(0,
                                     beam_size).to(indices_buf).repeat(bsz, 1)
            # make scores cumulative
            scores_buf.add_(
                torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf))

        return scores_buf, indices_buf, beams_buf
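
    # Illustrative configuration (comments only; `tok` is an assumed
    # HuggingFace-style tokenizer):
    #
    #   sampler = Sampling(tok, sampling_topk=40)   # sample among the 40 best
    #   sampler = Sampling(tok, sampling_topp=0.9)  # nucleus (top-p) sampling
    #   sampler = Sampling(tok)                     # unrestricted sampling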


class DiverseSiblingsSearch(Search):
    """
    Beam search with diverse siblings.

    See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details.
    https://arxiv.org/abs/1611.08562

    1/ Calculate hypotheses for each beam
    2/ Intra-sibling ordering
    3/ Rewrite scores
    4/ Choose top K hypotheses

    If diversity_rate == 0, this is equivalent to BeamSearch.
    """

    def __init__(self, tokenizer, diversity_rate):
        super().__init__(tokenizer)
        self.diversity_rate = diversity_rate
        self.beam = BeamSearch(tokenizer)

    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        bsz, beam_size, vocab_size = lprobs.size()
        k = min(
            # Take the best 2 x beam_size predictions. We'll choose the first
            # beam_size of these which don't predict eos to continue with.
            beam_size * 2,
            lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
        )
        s_list: List[Tensor]
        i_list: List[Tensor]
        s_list = [torch.empty(0).to(lprobs) for i in range(beam_size)]
        i_list = [
            torch.LongTensor().to(device=lprobs.device)
            for i in range(beam_size)
        ]
        sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate

        if step == 0:
            return self.beam.step(step, lprobs, scores)
        lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))

        # 1/ Calculate hypotheses for each beam
        for i in range(beam_size):
            torch.topk(
                lprobs[:, i, :].view(bsz, -1), k, out=(s_list[i], i_list[i]))
            i_list[i].fmod_(vocab_size)

            # 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores
            s_list[i].sub_(sibling_score)

        # 4/ Choose top K hypotheses
        indices = torch.stack(i_list, dim=1).view(bsz, -1)

        final_scores = torch.empty(0).to(lprobs)
        final_indices = torch.LongTensor().to(device=lprobs.device)
        final_beams = torch.LongTensor().to(device=lprobs.device)
        (final_scores, final_indices) = torch.topk(
            torch.stack(s_list, dim=1).view(bsz, -1),
            k,
        )

        final_beams = torch.div(final_indices, k, rounding_mode='floor')

        for i in range(bsz):
            final_indices[i] = indices[i][final_indices[i]]

        return final_scores, final_indices, final_beams
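
    # Worked example of the sibling penalty (comments only): with
    # diversity_rate = 0.3, the rank-r sibling within a beam is penalized by
    # r * 0.3, so per-beam top scores [-1.0, -1.2, -1.5] become
    # [-1.3, -1.8, -2.4]. This discourages filling the beam with several
    # near-identical continuations of the same parent hypothesis.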