@@ -73,7 +73,7 @@ class Pipelines(object): | |||
asr_inference = 'asr-inference' | |||
# multi-modal tasks | |||
image_caption = 'image-captioning' | |||
image_captioning = 'image-captioning' | |||
multi_modal_embedding = 'multi-modal-embedding' | |||
visual_question_answering = 'visual-question-answering' | |||
text_to_image_synthesis = 'text-to-image-synthesis' | |||
@@ -1,5 +1,5 @@ | |||
from .clip.clip_model import CLIPForMultiModalEmbedding | |||
from .image_captioning_model import OfaForImageCaptioning | |||
from .imagen.imagen_model import ImagenForTextToImageSynthesis | |||
from .mplug_for_visual_question_answering import \ | |||
MPlugForVisualQuestionAnswering | |||
from .ofa_for_image_captioning_model import OfaForImageCaptioning |
@@ -0,0 +1,2 @@ | |||
from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel | |||
from .tokenization_ofa import OFATokenizer |
@@ -0,0 +1,194 @@ | |||
# Copyright 2022 Alibaba Group and The HuggingFace Inc. team. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" OFA model configuration""" | |||
import warnings | |||
from transformers import PretrainedConfig | |||
from transformers.utils import logging | |||
logger = logging.get_logger(__name__) | |||
OFA_PRETRAINED_CONFIG_ARCHIVE_MAP = { | |||
'ofa-medium': 'https://huggingface.co/ofa-base/resolve/main/config.json', | |||
    # OFA models are implemented to be compatible with both huggingface | |||
# and modelscope frameworks. For all OFA models available on huggingface, | |||
# please refer to https://huggingface.co/models?filter=ofa | |||
} | |||
class OFAConfig(PretrainedConfig): | |||
r""" | |||
This is the configuration class to store the configuration of a [`~OFAModel`]. It is used to instantiate an OFA | |||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the | |||
defaults will yield a similar configuration to that of the OFA [ofa-base](https://huggingface.co/ofa-base) | |||
architecture. | |||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | |||
documentation from [`PretrainedConfig`] for more information. | |||
Args: | |||
        vocab_size (`int`, *optional*, defaults to 59457): | |||
            Vocabulary size of the OFA model. Defines the number of different tokens that can be represented by the | |||
            `input_ids` passed when calling [`~OFAModel`] or [`~TFOFAModel`]. | |||
        d_model (`int`, *optional*, defaults to 512): | |||
            Dimension of the layers and the pooler layer. | |||
        encoder_layers (`int`, *optional*, defaults to 4): | |||
            Number of encoder layers. | |||
        decoder_layers (`int`, *optional*, defaults to 4): | |||
            Number of decoder layers. | |||
        encoder_attention_heads (`int`, *optional*, defaults to 8): | |||
            Number of attention heads for each attention layer in the Transformer encoder. | |||
        decoder_attention_heads (`int`, *optional*, defaults to 8): | |||
            Number of attention heads for each attention layer in the Transformer decoder. | |||
        decoder_ffn_dim (`int`, *optional*, defaults to 2048): | |||
            Dimension of the "intermediate" (often named feed-forward) layer in the decoder. | |||
        encoder_ffn_dim (`int`, *optional*, defaults to 2048): | |||
            Dimension of the "intermediate" (often named feed-forward) layer in the encoder. | |||
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): | |||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, | |||
`"relu"`, `"silu"` and `"gelu_new"` are supported. | |||
dropout (`float`, *optional*, defaults to 0.1): | |||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. | |||
attention_dropout (`float`, *optional*, defaults to 0.0): | |||
The dropout ratio for the attention probabilities. | |||
activation_dropout (`float`, *optional*, defaults to 0.0): | |||
The dropout ratio for activations inside the fully connected layer. | |||
classifier_dropout (`float`, *optional*, defaults to 0.0): | |||
The dropout ratio for classifier. | |||
max_position_embeddings (`int`, *optional*, defaults to 1024): | |||
The maximum sequence length that this model might ever be used with. Typically set this to something large | |||
just in case (e.g., 512 or 1024 or 2048). | |||
init_std (`float`, *optional*, defaults to 0.02): | |||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | |||
        encoder_layerdrop (`float`, *optional*, defaults to 0.0): | |||
            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) | |||
            for more details. | |||
        decoder_layerdrop (`float`, *optional*, defaults to 0.0): | |||
            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) | |||
            for more details. | |||
use_cache (`bool`, *optional*, defaults to `True`): | |||
Whether or not the model should return the last key/values attentions (not used by all models). | |||
""" | |||
model_type = 'ofa' | |||
keys_to_ignore_at_inference = ['past_key_values'] | |||
attribute_map = { | |||
'num_attention_heads': 'encoder_attention_heads', | |||
'hidden_size': 'd_model' | |||
} | |||
def __init__(self, | |||
vocab_size=59457, | |||
max_position_embeddings=1024, | |||
encoder_layers=4, | |||
encoder_ffn_dim=512 * 4, | |||
encoder_attention_heads=8, | |||
decoder_layers=4, | |||
decoder_ffn_dim=512 * 4, | |||
decoder_attention_heads=8, | |||
encoder_layerdrop=0.0, | |||
decoder_layerdrop=0.0, | |||
use_cache=True, | |||
is_encoder_decoder=True, | |||
activation_function='gelu', | |||
d_model=512, | |||
dropout=0.1, | |||
attention_dropout=0.0, | |||
activation_dropout=0.0, | |||
init_std=0.02, | |||
classifier_dropout=0.0, | |||
scale_embedding=False, | |||
pad_token_id=1, | |||
bos_token_id=0, | |||
decoder_start_token_id=0, | |||
eos_token_id=2, | |||
forced_eos_token_id=2, | |||
encoder_normalize_before=True, | |||
decoder_normalize_before=True, | |||
normformer=True, | |||
encoder_drop_path_rate=0.0, | |||
decoder_drop_path_rate=0.0, | |||
layernorm_embedding=True, | |||
patch_layernorm_embedding=True, | |||
resnet_type='resnet101', | |||
resnet_model_path=None, | |||
resnet_drop_path_rate=0.0, | |||
token_bucket_size=256, | |||
image_bucket_size=42, | |||
add_type_embedding=True, | |||
share_decoder_input_output_embed=True, | |||
attn_scale_factor=2., | |||
code_layernorm_embedding=True, | |||
code_image_size=128, | |||
entangle_position_embedding=False, | |||
**kwargs): | |||
self.vocab_size = vocab_size | |||
self.max_position_embeddings = max_position_embeddings | |||
self.d_model = d_model | |||
self.encoder_ffn_dim = encoder_ffn_dim | |||
self.encoder_layers = encoder_layers | |||
self.encoder_attention_heads = encoder_attention_heads | |||
self.decoder_ffn_dim = decoder_ffn_dim | |||
self.decoder_layers = decoder_layers | |||
self.decoder_attention_heads = decoder_attention_heads | |||
self.dropout = dropout | |||
self.attention_dropout = attention_dropout | |||
self.activation_dropout = activation_dropout | |||
self.activation_function = activation_function | |||
self.init_std = init_std | |||
self.encoder_layerdrop = encoder_layerdrop | |||
self.decoder_layerdrop = decoder_layerdrop | |||
self.classifier_dropout = classifier_dropout | |||
self.use_cache = use_cache | |||
self.num_hidden_layers = encoder_layers | |||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True | |||
self.encoder_normalize_before = encoder_normalize_before | |||
self.decoder_normalize_before = decoder_normalize_before | |||
self.normformer = normformer | |||
self.encoder_drop_path_rate = encoder_drop_path_rate | |||
self.decoder_drop_path_rate = decoder_drop_path_rate | |||
self.layernorm_embedding = layernorm_embedding | |||
self.patch_layernorm_embedding = patch_layernorm_embedding | |||
self.resnet_type = resnet_type | |||
self.resnet_model_path = resnet_model_path | |||
self.resnet_drop_path_rate = resnet_drop_path_rate | |||
self.token_bucket_size = token_bucket_size | |||
self.image_bucket_size = image_bucket_size | |||
self.add_type_embedding = add_type_embedding | |||
self.share_decoder_input_output_embed = share_decoder_input_output_embed | |||
self.attn_scale_factor = attn_scale_factor | |||
self.code_layernorm_embedding = code_layernorm_embedding | |||
self.code_image_size = code_image_size | |||
self.entangle_position_embedding = entangle_position_embedding | |||
super().__init__( | |||
pad_token_id=pad_token_id, | |||
bos_token_id=bos_token_id, | |||
eos_token_id=eos_token_id, | |||
is_encoder_decoder=is_encoder_decoder, | |||
decoder_start_token_id=decoder_start_token_id, | |||
forced_eos_token_id=forced_eos_token_id, | |||
**kwargs, | |||
) | |||
# ensure backward compatibility for BART CNN models | |||
if self.forced_bos_token_id is None and kwargs.get( | |||
'force_bos_token_to_be_generated', False): | |||
self.forced_bos_token_id = self.bos_token_id | |||
warnings.warn( | |||
f'Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. ' | |||
'The config can simply be saved and uploaded again to be fixed.' | |||
) |
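A minimal usage sketch for the configuration class above (not part of the diff). The import path is hypothetical and only stands in for wherever configuration_ofa.py ends up in the package; everything else uses the API defined in this file.

# Hypothetical import path, for illustration only.
from configuration_ofa import OFAConfig

# The defaults follow the __init__ signature above (d_model=512, 4 layers, 8 heads, ...).
config = OFAConfig()
assert config.hidden_size == 512          # resolved to d_model via attribute_map
assert config.num_attention_heads == 8    # resolved to encoder_attention_heads

# Any field can be overridden, as with other PretrainedConfig subclasses.
large = OFAConfig(d_model=1024, encoder_layers=12, decoder_layers=12,
                  encoder_attention_heads=16, decoder_attention_heads=16,
                  encoder_ffn_dim=4096, decoder_ffn_dim=4096)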
@@ -0,0 +1,51 @@ | |||
# Copyright (c) Facebook, Inc. and its affiliates. | |||
# | |||
# This source code is licensed under the MIT license which can be found at | |||
# https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||
import uuid | |||
from typing import Dict, Optional | |||
from torch import Tensor | |||
class FairseqIncrementalState(object): | |||
def __init__(self, *args, **kwargs): | |||
super().__init__(*args, **kwargs) | |||
self.init_incremental_state() | |||
def init_incremental_state(self): | |||
self._incremental_state_id = str(uuid.uuid4()) | |||
def _get_full_incremental_state_key(self, key: str) -> str: | |||
return '{}.{}'.format(self._incremental_state_id, key) | |||
def get_incremental_state( | |||
self, | |||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], | |||
key: str, | |||
) -> Optional[Dict[str, Optional[Tensor]]]: | |||
"""Helper for getting incremental state for an nn.Module.""" | |||
full_key = self._get_full_incremental_state_key(key) | |||
if incremental_state is None or full_key not in incremental_state: | |||
return None | |||
return incremental_state[full_key] | |||
def set_incremental_state( | |||
self, | |||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], | |||
key: str, | |||
value: Dict[str, Optional[Tensor]], | |||
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: | |||
"""Helper for setting incremental state for an nn.Module.""" | |||
if incremental_state is not None: | |||
full_key = self._get_full_incremental_state_key(key) | |||
incremental_state[full_key] = value | |||
return incremental_state | |||
def with_incremental_state(cls): | |||
cls.__bases__ = (FairseqIncrementalState, ) + tuple( | |||
b for b in cls.__bases__ if b != FairseqIncrementalState) | |||
return cls |
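A small sketch (not part of the diff) of how the decorator above is typically attached to a module. The layer and buffer names are made up for illustration; the get_incremental_state / set_incremental_state calls come from the FairseqIncrementalState base class that with_incremental_state injects.

import torch
from torch import nn

@with_incremental_state
class CachingLayer(nn.Module):
    """Toy module: concatenates each new input onto a cached history."""

    def forward(self, x, incremental_state=None):
        if incremental_state is not None:
            saved = self.get_incremental_state(incremental_state, 'history')
            if saved is not None and saved.get('prev') is not None:
                x = torch.cat([saved['prev'], x], dim=0)
            self.set_incremental_state(incremental_state, 'history', {'prev': x})
        return x

layer = CachingLayer()
state = {}                                                  # shared across decoding steps
out1 = layer(torch.ones(1, 3), incremental_state=state)     # shape (1, 3)
out2 = layer(torch.ones(1, 3), incremental_state=state)     # shape (2, 3): sees the cache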
@@ -0,0 +1,510 @@ | |||
# Copyright (c) Facebook, Inc. and its affiliates. | |||
# | |||
# This source code is licensed under the MIT license which can be found at | |||
# https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||
import math | |||
from typing import Dict, Optional, Tuple | |||
import torch | |||
import torch.nn.functional as F | |||
from fairseq import utils | |||
from fairseq.incremental_decoding_utils import with_incremental_state | |||
from fairseq.modules.fairseq_dropout import FairseqDropout | |||
from fairseq.modules.quant_noise import quant_noise | |||
from torch import Tensor, nn | |||
from torch.nn import Parameter | |||
@with_incremental_state | |||
class MultiheadAttention(nn.Module): | |||
"""Multi-headed attention. | |||
See "Attention Is All You Need" for more details. | |||
""" | |||
def __init__( | |||
self, | |||
embed_dim, | |||
num_heads, | |||
kdim=None, | |||
vdim=None, | |||
dropout=0.0, | |||
bias=True, | |||
add_bias_kv=False, | |||
add_zero_attn=False, | |||
self_attention=False, | |||
encoder_decoder_attention=False, | |||
q_noise=0.0, | |||
qn_block_size=8, | |||
): | |||
super().__init__() | |||
self.embed_dim = embed_dim | |||
self.kdim = kdim if kdim is not None else embed_dim | |||
self.vdim = vdim if vdim is not None else embed_dim | |||
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim | |||
self.num_heads = num_heads | |||
self.dropout_module = FairseqDropout( | |||
dropout, module_name=self.__class__.__name__) | |||
self.head_dim = embed_dim // num_heads | |||
assert (self.head_dim * num_heads == self.embed_dim | |||
), 'embed_dim must be divisible by num_heads' | |||
self.scaling = self.head_dim**-0.5 | |||
self.self_attention = self_attention | |||
self.encoder_decoder_attention = encoder_decoder_attention | |||
assert not self.self_attention or self.qkv_same_dim, ( | |||
'Self-attention requires query, key and ' | |||
'value to be of the same size') | |||
self.k_proj = quant_noise( | |||
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size) | |||
self.v_proj = quant_noise( | |||
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size) | |||
self.q_proj = quant_noise( | |||
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) | |||
self.out_proj = quant_noise( | |||
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) | |||
if add_bias_kv: | |||
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) | |||
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) | |||
else: | |||
self.bias_k = self.bias_v = None | |||
self.add_zero_attn = add_zero_attn | |||
self.reset_parameters() | |||
self.onnx_trace = False | |||
def prepare_for_onnx_export_(self): | |||
self.onnx_trace = True | |||
def reset_parameters(self): | |||
if self.qkv_same_dim: | |||
# Empirically observed the convergence to be much better with | |||
# the scaled initialization | |||
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) | |||
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) | |||
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) | |||
else: | |||
nn.init.xavier_uniform_(self.k_proj.weight) | |||
nn.init.xavier_uniform_(self.v_proj.weight) | |||
nn.init.xavier_uniform_(self.q_proj.weight) | |||
nn.init.xavier_uniform_(self.out_proj.weight) | |||
if self.out_proj.bias is not None: | |||
nn.init.constant_(self.out_proj.bias, 0.0) | |||
if self.bias_k is not None: | |||
nn.init.xavier_normal_(self.bias_k) | |||
if self.bias_v is not None: | |||
nn.init.xavier_normal_(self.bias_v) | |||
def forward( | |||
self, | |||
query, | |||
key: Optional[Tensor], | |||
value: Optional[Tensor], | |||
key_padding_mask: Optional[Tensor] = None, | |||
incremental_state: Optional[Dict[str, Dict[str, | |||
Optional[Tensor]]]] = None, | |||
need_weights: bool = True, | |||
static_kv: bool = False, | |||
attn_mask: Optional[Tensor] = None, | |||
before_softmax: bool = False, | |||
need_head_weights: bool = False, | |||
) -> Tuple[Tensor, Optional[Tensor]]: | |||
"""Input shape: Time x Batch x Channel | |||
Args: | |||
key_padding_mask (ByteTensor, optional): mask to exclude | |||
keys that are pads, of shape `(batch, src_len)`, where | |||
padding elements are indicated by 1s. | |||
            need_weights (bool, optional): return the attention weights, | |||
                averaged over heads (default: True). | |||
attn_mask (ByteTensor, optional): typically used to | |||
implement causal attention, where the mask prevents the | |||
attention from looking forward in time (default: None). | |||
before_softmax (bool, optional): return the raw attention | |||
weights and values before the attention softmax. | |||
need_head_weights (bool, optional): return the attention | |||
weights for each head. Implies *need_weights*. Default: | |||
return the average attention weights over all heads. | |||
""" | |||
if need_head_weights: | |||
need_weights = True | |||
is_tpu = query.device.type == 'xla' | |||
tgt_len, bsz, embed_dim = query.size() | |||
src_len = tgt_len | |||
assert embed_dim == self.embed_dim, f'query dim {embed_dim} != {self.embed_dim}' | |||
assert list(query.size()) == [tgt_len, bsz, embed_dim] | |||
if key is not None: | |||
src_len, key_bsz, _ = key.size() | |||
if not torch.jit.is_scripting(): | |||
assert key_bsz == bsz | |||
assert value is not None | |||
                assert (src_len, bsz) == value.shape[:2] | |||
if (not self.onnx_trace | |||
and not is_tpu # don't use PyTorch version on TPUs | |||
and incremental_state is None and not static_kv | |||
# A workaround for quantization to work. Otherwise JIT compilation | |||
# treats bias in linear module as method. | |||
and not torch.jit.is_scripting()): | |||
assert key is not None and value is not None | |||
return F.multi_head_attention_forward( | |||
query, | |||
key, | |||
value, | |||
self.embed_dim, | |||
self.num_heads, | |||
torch.empty([0]), | |||
torch.cat( | |||
(self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), | |||
self.bias_k, | |||
self.bias_v, | |||
self.add_zero_attn, | |||
self.dropout_module.p, | |||
self.out_proj.weight, | |||
self.out_proj.bias, | |||
self.training or self.dropout_module.apply_during_inference, | |||
key_padding_mask, | |||
need_weights, | |||
attn_mask, | |||
use_separate_proj_weight=True, | |||
q_proj_weight=self.q_proj.weight, | |||
k_proj_weight=self.k_proj.weight, | |||
v_proj_weight=self.v_proj.weight, | |||
) | |||
if incremental_state is not None: | |||
saved_state = self._get_input_buffer(incremental_state) | |||
if saved_state is not None and 'prev_key' in saved_state: | |||
# previous time steps are cached - no need to recompute | |||
# key and value if they are static | |||
if static_kv: | |||
assert self.encoder_decoder_attention and not self.self_attention | |||
key = value = None | |||
else: | |||
saved_state = None | |||
if self.self_attention: | |||
q = self.q_proj(query) | |||
k = self.k_proj(query) | |||
v = self.v_proj(query) | |||
elif self.encoder_decoder_attention: | |||
# encoder-decoder attention | |||
q = self.q_proj(query) | |||
if key is None: | |||
assert value is None | |||
k = v = None | |||
else: | |||
k = self.k_proj(key) | |||
v = self.v_proj(key) | |||
else: | |||
assert key is not None and value is not None | |||
q = self.q_proj(query) | |||
k = self.k_proj(key) | |||
v = self.v_proj(value) | |||
q *= self.scaling | |||
if self.bias_k is not None: | |||
assert self.bias_v is not None | |||
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) | |||
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) | |||
if attn_mask is not None: | |||
attn_mask = torch.cat( | |||
[attn_mask, | |||
attn_mask.new_zeros(attn_mask.size(0), 1)], | |||
dim=1) | |||
if key_padding_mask is not None: | |||
key_padding_mask = torch.cat( | |||
[ | |||
key_padding_mask, | |||
key_padding_mask.new_zeros( | |||
key_padding_mask.size(0), 1), | |||
], | |||
dim=1, | |||
) | |||
q = ( | |||
q.contiguous().view(tgt_len, bsz * self.num_heads, | |||
self.head_dim).transpose(0, 1)) | |||
if k is not None: | |||
k = ( | |||
k.contiguous().view(-1, bsz * self.num_heads, | |||
self.head_dim).transpose(0, 1)) | |||
if v is not None: | |||
v = ( | |||
v.contiguous().view(-1, bsz * self.num_heads, | |||
self.head_dim).transpose(0, 1)) | |||
if saved_state is not None: | |||
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim) | |||
if 'prev_key' in saved_state: | |||
_prev_key = saved_state['prev_key'] | |||
assert _prev_key is not None | |||
prev_key = _prev_key.view(bsz * self.num_heads, -1, | |||
self.head_dim) | |||
if static_kv: | |||
k = prev_key | |||
else: | |||
assert k is not None | |||
k = torch.cat([prev_key, k], dim=1) | |||
src_len = k.size(1) | |||
if 'prev_value' in saved_state: | |||
_prev_value = saved_state['prev_value'] | |||
assert _prev_value is not None | |||
prev_value = _prev_value.view(bsz * self.num_heads, -1, | |||
self.head_dim) | |||
if static_kv: | |||
v = prev_value | |||
else: | |||
assert v is not None | |||
v = torch.cat([prev_value, v], dim=1) | |||
prev_key_padding_mask: Optional[Tensor] = None | |||
if 'prev_key_padding_mask' in saved_state: | |||
prev_key_padding_mask = saved_state['prev_key_padding_mask'] | |||
assert k is not None and v is not None | |||
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( | |||
key_padding_mask=key_padding_mask, | |||
prev_key_padding_mask=prev_key_padding_mask, | |||
batch_size=bsz, | |||
src_len=k.size(1), | |||
static_kv=static_kv, | |||
) | |||
saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, | |||
self.head_dim) | |||
saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, | |||
self.head_dim) | |||
saved_state['prev_key_padding_mask'] = key_padding_mask | |||
# In this branch incremental_state is never None | |||
assert incremental_state is not None | |||
incremental_state = self._set_input_buffer(incremental_state, | |||
saved_state) | |||
assert k is not None | |||
assert k.size(1) == src_len | |||
# This is part of a workaround to get around fork/join parallelism | |||
# not supporting Optional types. | |||
if key_padding_mask is not None and key_padding_mask.dim() == 0: | |||
key_padding_mask = None | |||
if key_padding_mask is not None: | |||
assert key_padding_mask.size(0) == bsz | |||
assert key_padding_mask.size(1) == src_len | |||
if self.add_zero_attn: | |||
assert v is not None | |||
src_len += 1 | |||
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], | |||
dim=1) | |||
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], | |||
dim=1) | |||
if attn_mask is not None: | |||
attn_mask = torch.cat( | |||
[attn_mask, | |||
attn_mask.new_zeros(attn_mask.size(0), 1)], | |||
dim=1) | |||
if key_padding_mask is not None: | |||
key_padding_mask = torch.cat( | |||
[ | |||
key_padding_mask, | |||
torch.zeros(key_padding_mask.size(0), | |||
1).type_as(key_padding_mask), | |||
], | |||
dim=1, | |||
) | |||
attn_weights = torch.bmm(q, k.transpose(1, 2)) | |||
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, | |||
bsz) | |||
assert list( | |||
attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] | |||
if attn_mask is not None: | |||
attn_mask = attn_mask.unsqueeze(0) | |||
if self.onnx_trace: | |||
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) | |||
attn_weights += attn_mask | |||
if key_padding_mask is not None: | |||
# don't attend to padding symbols | |||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, | |||
src_len) | |||
if not is_tpu: | |||
attn_weights = attn_weights.masked_fill( | |||
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), | |||
float('-inf'), | |||
) | |||
else: | |||
attn_weights = attn_weights.transpose(0, 2) | |||
attn_weights = attn_weights.masked_fill( | |||
key_padding_mask, float('-inf')) | |||
attn_weights = attn_weights.transpose(0, 2) | |||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, | |||
src_len) | |||
if before_softmax: | |||
return attn_weights, v | |||
attn_weights_float = utils.softmax( | |||
attn_weights, dim=-1, onnx_trace=self.onnx_trace) | |||
attn_weights = attn_weights_float.type_as(attn_weights) | |||
attn_probs = self.dropout_module(attn_weights) | |||
assert v is not None | |||
attn = torch.bmm(attn_probs, v) | |||
assert list( | |||
attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] | |||
if self.onnx_trace and attn.size(1) == 1: | |||
# when ONNX tracing a single decoder step (sequence length == 1) | |||
# the transpose is a no-op copy before view, thus unnecessary | |||
attn = attn.contiguous().view(tgt_len, bsz, embed_dim) | |||
else: | |||
attn = attn.transpose(0, | |||
1).contiguous().view(tgt_len, bsz, embed_dim) | |||
attn = self.out_proj(attn) | |||
attn_weights: Optional[Tensor] = None | |||
if need_weights: | |||
attn_weights = attn_weights_float.view(bsz, self.num_heads, | |||
tgt_len, | |||
src_len).transpose(1, 0) | |||
if not need_head_weights: | |||
# average attention weights over heads | |||
attn_weights = attn_weights.mean(dim=0) | |||
return attn, attn_weights | |||
@staticmethod | |||
def _append_prev_key_padding_mask( | |||
key_padding_mask: Optional[Tensor], | |||
prev_key_padding_mask: Optional[Tensor], | |||
batch_size: int, | |||
src_len: int, | |||
static_kv: bool, | |||
) -> Optional[Tensor]: | |||
# saved key padding masks have shape (bsz, seq_len) | |||
if prev_key_padding_mask is not None and static_kv: | |||
new_key_padding_mask = prev_key_padding_mask | |||
elif prev_key_padding_mask is not None and key_padding_mask is not None: | |||
new_key_padding_mask = torch.cat( | |||
[prev_key_padding_mask.float(), | |||
key_padding_mask.float()], | |||
dim=1) | |||
# During incremental decoding, as the padding token enters and | |||
# leaves the frame, there will be a time when prev or current | |||
# is None | |||
elif prev_key_padding_mask is not None: | |||
if src_len > prev_key_padding_mask.size(1): | |||
filler = torch.zeros( | |||
(batch_size, src_len - prev_key_padding_mask.size(1)), | |||
device=prev_key_padding_mask.device, | |||
) | |||
new_key_padding_mask = torch.cat( | |||
[prev_key_padding_mask.float(), | |||
filler.float()], dim=1) | |||
else: | |||
new_key_padding_mask = prev_key_padding_mask.float() | |||
elif key_padding_mask is not None: | |||
if src_len > key_padding_mask.size(1): | |||
filler = torch.zeros( | |||
(batch_size, src_len - key_padding_mask.size(1)), | |||
device=key_padding_mask.device, | |||
) | |||
new_key_padding_mask = torch.cat( | |||
[filler.float(), key_padding_mask.float()], dim=1) | |||
else: | |||
new_key_padding_mask = key_padding_mask.float() | |||
else: | |||
new_key_padding_mask = prev_key_padding_mask | |||
return new_key_padding_mask | |||
@torch.jit.export | |||
def reorder_incremental_state( | |||
self, | |||
incremental_state: Dict[str, Dict[str, Optional[Tensor]]], | |||
new_order: Tensor, | |||
): | |||
"""Reorder buffered internal state (for incremental generation).""" | |||
input_buffer = self._get_input_buffer(incremental_state) | |||
if input_buffer is not None: | |||
for k in input_buffer.keys(): | |||
input_buffer_k = input_buffer[k] | |||
if input_buffer_k is not None: | |||
if self.encoder_decoder_attention and input_buffer_k.size( | |||
0) == new_order.size(0): | |||
break | |||
input_buffer[k] = input_buffer_k.index_select(0, new_order) | |||
incremental_state = self._set_input_buffer(incremental_state, | |||
input_buffer) | |||
return incremental_state | |||
def _get_input_buffer( | |||
self, incremental_state: Optional[Dict[str, Dict[str, | |||
Optional[Tensor]]]] | |||
) -> Dict[str, Optional[Tensor]]: | |||
result = self.get_incremental_state(incremental_state, 'attn_state') | |||
if result is not None: | |||
return result | |||
else: | |||
empty_result: Dict[str, Optional[Tensor]] = {} | |||
return empty_result | |||
def _set_input_buffer( | |||
self, | |||
incremental_state: Dict[str, Dict[str, Optional[Tensor]]], | |||
buffer: Dict[str, Optional[Tensor]], | |||
): | |||
return self.set_incremental_state(incremental_state, 'attn_state', | |||
buffer) | |||
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, | |||
bsz: int): | |||
return attn_weights | |||
def upgrade_state_dict_named(self, state_dict, name): | |||
prefix = name + '.' if name != '' else '' | |||
items_to_add = {} | |||
keys_to_remove = [] | |||
for k in state_dict.keys(): | |||
if k.endswith(prefix + 'in_proj_weight'): | |||
# in_proj_weight used to be q + k + v with same dimensions | |||
dim = int(state_dict[k].shape[0] / 3) | |||
items_to_add[prefix + 'q_proj.weight'] = state_dict[k][:dim] | |||
items_to_add[prefix + 'k_proj.weight'] = state_dict[k][dim:2 | |||
* dim] | |||
items_to_add[prefix + 'v_proj.weight'] = state_dict[k][2 | |||
* dim:] | |||
keys_to_remove.append(k) | |||
k_bias = prefix + 'in_proj_bias' | |||
if k_bias in state_dict.keys(): | |||
dim = int(state_dict[k].shape[0] / 3) | |||
items_to_add[prefix | |||
+ 'q_proj.bias'] = state_dict[k_bias][:dim] | |||
items_to_add[prefix | |||
+ 'k_proj.bias'] = state_dict[k_bias][dim:2 | |||
* dim] | |||
items_to_add[prefix | |||
+ 'v_proj.bias'] = state_dict[k_bias][2 | |||
* dim:] | |||
keys_to_remove.append(prefix + 'in_proj_bias') | |||
for k in keys_to_remove: | |||
del state_dict[k] | |||
for key, value in items_to_add.items(): | |||
state_dict[key] = value |
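A short sketch (not part of the diff) that exercises the module above in incremental decoding mode. It assumes fairseq is installed, since the file imports FairseqDropout, quant_noise and utils from it; shapes follow the documented Time x Batch x Channel layout.

import torch

embed_dim, num_heads, bsz = 16, 4, 2
attn = MultiheadAttention(embed_dim, num_heads, self_attention=True)
attn.eval()                                  # disable dropout for a deterministic run

incremental_state = {}                       # per-sequence cache reused across steps
with torch.no_grad():
    for _ in range(3):                       # decode three tokens, one step at a time
        x = torch.randn(1, bsz, embed_dim)   # (tgt_len=1, batch, embed_dim)
        out, _ = attn(x, x, x,
                      incremental_state=incremental_state,
                      need_weights=False)
print(out.shape)                                                     # torch.Size([1, 2, 16])
print(attn._get_input_buffer(incremental_state)['prev_key'].shape)   # (2, 4, 3, 4): 3 cached steps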
@@ -0,0 +1,155 @@ | |||
# Originally from Microsoft Corporation. | |||
# Licensed under the MIT License. | |||
""" Wrapper for ngram_repeat_block cuda extension """ | |||
import math | |||
import warnings | |||
from typing import Dict, List | |||
import torch | |||
from torch import nn | |||
try: | |||
from fairseq import ngram_repeat_block_cuda | |||
EXTENSION_BUILT = True | |||
except ImportError: | |||
EXTENSION_BUILT = False | |||
def is_cuda_extension_usable() -> bool: | |||
"""Check whether ngram_repeat_block_cuda is built properly""" | |||
if not EXTENSION_BUILT or not torch.cuda.is_available(): | |||
return False | |||
bsz = 2 | |||
tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], | |||
dtype=torch.long, | |||
device='cuda') | |||
lprobs = torch.rand((8, 12), device='cuda') | |||
try: | |||
outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3) | |||
outputs = outputs + 4 # This line breaks if the extension is built incorrectly. | |||
return True | |||
except RuntimeError: | |||
warnings.warn( | |||
            'NGramRepeatBlock extension must be rebuilt. ' | |||
'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace' | |||
) | |||
return False | |||
class NGramRepeatBlock(nn.Module): | |||
""" Wrapper class for calling ngram_repeat_block cuda extension """ | |||
def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True): | |||
super().__init__() | |||
self.use_extension = is_cuda_extension_usable( | |||
) if use_extension else False | |||
self.no_repeat_ngram_size = no_repeat_ngram_size | |||
def reset_parameters(self): | |||
pass | |||
@torch.jit.unused | |||
def call_cuda_extension( | |||
self, | |||
tokens, | |||
lprobs, | |||
bsz: int, | |||
beam_size: int, | |||
step: int, | |||
): | |||
return ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, step, | |||
beam_size, | |||
self.no_repeat_ngram_size) | |||
def forward( | |||
self, | |||
tokens, | |||
lprobs, | |||
bsz: int, | |||
beam_size: int, | |||
step: int, | |||
): | |||
""" | |||
Args: | |||
            tokens (Tensor): input tokens, of shape (bsz * beam_size, seq_len) | |||
            lprobs (Tensor): log-probabilities over the vocabulary, of shape | |||
                (bsz * beam_size, vocab_size); expected to be updated in place | |||
            bsz (int): batch size | |||
            beam_size (int): beam size | |||
            step (int): current decoding step | |||
""" | |||
msg = f'expected {bsz * beam_size} got' | |||
assert tokens.size(0) == bsz * beam_size, f'{msg} {tokens.size(0)}' | |||
assert lprobs.size(0) == bsz * beam_size, f'{msg} {lprobs.size(0)}' | |||
if self.use_extension: | |||
return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, | |||
step) | |||
else: | |||
return self._no_repeat_ngram( | |||
tokens, | |||
lprobs, | |||
bsz, | |||
beam_size, | |||
step, | |||
) | |||
def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, | |||
step: int): | |||
"""For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf""" | |||
gen_ngrams: List[Dict[str, List[int]]] = [ | |||
torch.jit.annotate(Dict[str, List[int]], {}) | |||
for bbsz_idx in range(bsz * beam_size) | |||
] | |||
cpu_tokens = tokens.cpu() | |||
for bbsz_idx in range(bsz * beam_size): | |||
gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist() | |||
for ngram in self.transpose_list([ | |||
gen_tokens[i:] for i in range(self.no_repeat_ngram_size) | |||
]): # noqa | |||
key = ','.join([str(x) for x in ngram[:-1]]) | |||
gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get( | |||
key, torch.jit.annotate(List[int], [])) + [ngram[-1]] | |||
if step + 2 - self.no_repeat_ngram_size >= 0: | |||
# no banned tokens if we haven't generated no_repeat_ngram_size tokens yet | |||
banned_tokens = [ | |||
self.calculate_banned_tokens(tokens, step, gen_ngrams, | |||
self.no_repeat_ngram_size, | |||
bbsz_idx) | |||
for bbsz_idx in range(bsz * beam_size) | |||
] | |||
else: | |||
banned_tokens = [ | |||
torch.jit.annotate(List[int], []) | |||
for bbsz_idx in range(bsz * beam_size) | |||
] | |||
for bbsz_idx in range(bsz * beam_size): | |||
lprobs[bbsz_idx][torch.tensor( | |||
banned_tokens[bbsz_idx], | |||
dtype=torch.int64)] = torch.tensor(-math.inf).to(lprobs) | |||
return lprobs | |||
@staticmethod | |||
def calculate_banned_tokens( | |||
tokens, | |||
step: int, | |||
gen_ngrams: List[Dict[str, List[int]]], | |||
no_repeat_ngram_size: int, | |||
bbsz_idx: int, | |||
): | |||
tokens_list: List[int] = tokens[bbsz_idx, | |||
step + 2 - no_repeat_ngram_size:step | |||
+ 1].tolist() # noqa | |||
# before decoding the next token, prevent decoding of ngrams that have already appeared | |||
ngram_index = ','.join([str(x) for x in tokens_list]) | |||
return gen_ngrams[bbsz_idx].get(ngram_index, | |||
torch.jit.annotate(List[int], [])) | |||
@staticmethod | |||
def transpose_list(l: List[List[int]]): # noqa | |||
# GeneratorExp aren't supported in TS so ignoring the lint | |||
min_len = min([len(x) for x in l]) # noqa | |||
l2 = [[row[i] for row in l] for i in range(min_len)] | |||
return l2 |
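A brief sketch (not part of the diff) of the pure-PyTorch fallback path. With use_extension=False the wrapper never touches the CUDA kernel, so this runs on CPU.

import torch

bsz, beam_size, vocab = 1, 2, 6
blocker = NGramRepeatBlock(no_repeat_ngram_size=2, use_extension=False)

# Two beams for one batch item. Beam 0 has already produced the bigram (4, 5) and its
# most recent token is 4 again, so emitting 5 next would repeat that bigram.
tokens = torch.tensor([[3, 4, 5, 4],
                       [3, 1, 2, 1]], dtype=torch.long)
lprobs = torch.zeros(bsz * beam_size, vocab)

step = tokens.size(1) - 1                    # index of the most recent token
out = blocker(tokens, lprobs, bsz, beam_size, step)
print(out[0, 5], out[1, 2])                  # both -inf: the banned continuations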
@@ -0,0 +1,848 @@ | |||
# Copyright (c) Facebook, Inc. and its affiliates. | |||
# | |||
# This source code is licensed under the MIT license which can be found at | |||
# https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||
import math | |||
from typing import List, Optional | |||
import torch | |||
import torch.nn as nn | |||
from torch import Tensor | |||
from .token_generation_constraints import (ConstraintState, | |||
OrderedConstraintState, | |||
UnorderedConstraintState) | |||
class Search(nn.Module): | |||
def __init__(self, tokenizer): | |||
super().__init__() | |||
self.pad = tokenizer.pad_token_id | |||
self.unk = tokenizer.unk_token_id | |||
self.eos = tokenizer.eos_token_id | |||
tgt_dict = {value: key for key, value in tokenizer.get_vocab().items()} | |||
added = { | |||
value: key | |||
for key, value in tokenizer.get_added_vocab().items() | |||
} | |||
tgt_dict.update(added) | |||
self.vocab_size = len(tgt_dict) | |||
self.src_lengths = torch.tensor(-1) | |||
self.supports_constraints = False | |||
self.stop_on_max_len = False | |||
def step(self, | |||
step, | |||
lprobs, | |||
scores, | |||
prev_output_tokens=None, | |||
original_batch_idxs=None): | |||
"""Take a single search step. | |||
Args: | |||
step: the current search step, starting at 0 | |||
lprobs: (bsz x input_beam_size x vocab_size) | |||
the model's log-probabilities over the vocabulary at the current step | |||
scores: (bsz x input_beam_size x step) | |||
the historical model scores of each hypothesis up to this point | |||
prev_output_tokens: (bsz x step) | |||
                the previously generated output tokens | |||
            original_batch_idxs: (bsz) | |||
                the tensor with the batch indices, in the range [0, bsz); | |||
                useful when a re-ordering has been applied and we need to | |||
                know the original indices | |||
Return: A tuple of (scores, indices, beams) where: | |||
scores: (bsz x output_beam_size) | |||
the scores of the chosen elements; output_beam_size can be | |||
larger than input_beam_size, e.g., we may return | |||
2*input_beam_size to account for EOS | |||
indices: (bsz x output_beam_size) | |||
the indices of the chosen elements | |||
beams: (bsz x output_beam_size) | |||
the hypothesis ids of the chosen elements, in the range [0, input_beam_size) | |||
""" | |||
raise NotImplementedError | |||
@torch.jit.export | |||
def set_src_lengths(self, src_lengths): | |||
self.src_lengths = src_lengths | |||
@torch.jit.export | |||
def init_constraints(self, batch_constraints: Optional[Tensor], | |||
beam_size: int): | |||
"""Initialize constraint states for constrained decoding (if supported). | |||
Args: | |||
batch_constraints: (torch.Tensor, optional) | |||
the list of constraints, in packed form | |||
beam_size: (int) | |||
the beam size | |||
        """ | |||
pass | |||
def prune_sentences(self, batch_idxs: Tensor): | |||
""" | |||
Removes constraint states for completed sentences (if supported). | |||
This is called from sequence_generator._generate() when sentences are | |||
deleted from the batch. | |||
Args: | |||
batch_idxs: Indices of *sentences* whose constraint state should be *kept*. | |||
""" | |||
pass | |||
def update_constraints(self, active_hypos: Tensor): | |||
""" | |||
Updates the constraint states by selecting the beam items that are retained. | |||
This is called at each time step of sequence_generator._generate() when | |||
        the set of 2 * {beam_size} candidate hypotheses is reduced to the beam size. | |||
Args: | |||
active_hypos: (batch size, beam size) | |||
list of integers denoting, for each sentence, which beam candidate items | |||
should be kept. | |||
""" | |||
pass | |||
class BeamSearch(Search): | |||
def __init__(self, tgt_dict): | |||
super().__init__(tgt_dict) | |||
self.constraint_states = None | |||
@torch.jit.export | |||
def step( | |||
self, | |||
step: int, | |||
lprobs, | |||
scores: Optional[Tensor], | |||
prev_output_tokens: Optional[Tensor] = None, | |||
original_batch_idxs: Optional[Tensor] = None, | |||
): | |||
bsz, beam_size, vocab_size = lprobs.size() | |||
if step == 0: | |||
# at the first step all hypotheses are equally likely, so use | |||
# only the first beam | |||
lprobs = lprobs[:, ::beam_size, :].contiguous() | |||
else: | |||
# make probs contain cumulative scores for each hypothesis | |||
assert scores is not None | |||
lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) | |||
top_prediction = torch.topk( | |||
lprobs.view(bsz, -1), | |||
k=min( | |||
# Take the best 2 x beam_size predictions. We'll choose the first | |||
# beam_size of these which don't predict eos to continue with. | |||
beam_size * 2, | |||
lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad | |||
), | |||
) | |||
scores_buf = top_prediction[0] | |||
indices_buf = top_prediction[1] | |||
# Project back into relative indices and beams | |||
beams_buf = indices_buf // vocab_size | |||
indices_buf = indices_buf.fmod(vocab_size) | |||
# At this point, beams_buf and indices_buf are single-dim and contain relative indices | |||
return scores_buf, indices_buf, beams_buf | |||
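# Illustrative aside (not part of the diff): the index arithmetic at the end of
# BeamSearch.step. topk runs over a (beam_size * vocab_size) flattened view of the
# log-probs, so a flat index f decomposes into beam = f // vocab_size and
# token = f % vocab_size, which is exactly what beams_buf and indices_buf hold.
import torch

vocab_size, beam_size = 5, 2
lprobs = torch.log_softmax(torch.randn(1, beam_size, vocab_size), dim=-1)
scores, flat_idx = torch.topk(lprobs.view(1, -1), k=2 * beam_size)
beams = flat_idx // vocab_size               # which hypothesis each candidate extends
tokens = flat_idx.fmod(vocab_size)           # which vocabulary item it proposes
assert torch.equal(beams * vocab_size + tokens, flat_idx)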
class PrefixConstrainedBeamSearch(Search): | |||
def __init__(self, tgt_dict, prefix_allowed_tokens_fn): | |||
super().__init__(tgt_dict) | |||
self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn | |||
self.stop_on_max_len = True | |||
@torch.jit.export | |||
def apply_mask(self, x, prev_output_tokens, original_batch_idxs): | |||
beam_size = x.shape[0] // original_batch_idxs.shape[0] | |||
original_batch_idxs = ( | |||
original_batch_idxs.unsqueeze(-1).repeat( | |||
(1, beam_size)).flatten().tolist()) | |||
mask = torch.full_like(x, -math.inf) | |||
for sent_i, (sent, batch_i) in enumerate( | |||
zip(prev_output_tokens, original_batch_idxs)): | |||
mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0 | |||
return mask | |||
@torch.jit.export | |||
def step( | |||
self, | |||
step: int, | |||
lprobs: Tensor, | |||
scores: Tensor, | |||
prev_output_tokens: Tensor, | |||
original_batch_idxs: Tensor, | |||
): | |||
bsz, beam_size, vocab_size = lprobs.size() | |||
lprobs += self.apply_mask( | |||
lprobs.view(bsz * beam_size, 1, vocab_size), | |||
prev_output_tokens, | |||
original_batch_idxs, | |||
).view(bsz, beam_size, vocab_size) | |||
if step == 0: | |||
# at the first step all hypotheses are equally likely, so use | |||
# only the first beam | |||
lprobs = lprobs[:, ::beam_size, :].contiguous() | |||
else: | |||
# make probs contain cumulative scores for each hypothesis | |||
assert scores is not None | |||
lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) | |||
top_prediction = torch.topk( | |||
lprobs.view(bsz, -1), | |||
k=min( | |||
# Take the best beam_size predictions. We'll choose the first | |||
# beam_size of these which don't predict eos to continue with. | |||
beam_size, | |||
lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad | |||
), | |||
) | |||
scores_buf = top_prediction[0] | |||
indices_buf = top_prediction[1] | |||
beams_buf = indices_buf // vocab_size | |||
indices_buf = indices_buf.fmod(vocab_size) | |||
return scores_buf, indices_buf, beams_buf | |||
class LexicallyConstrainedBeamSearch(Search): | |||
"""Implements lexically constrained beam search as described in | |||
Fast Lexically Constrained Decoding with Dynamic Beam | |||
Allocation for Neural Machine Translation. Post & Vilar, | |||
NAACL 2018. https://www.aclweb.org/anthology/N18-1119/ | |||
and | |||
Improved Lexically Constrained Decoding for Translation and | |||
Monolingual Rewriting. Hu et al, NAACL | |||
2019. https://www.aclweb.org/anthology/N19-1090/ | |||
This is accomplished by maintaining, for each beam hypothesis, a | |||
ConstraintState object (see constraints.py) that tracks which | |||
constraints have been generated and using this information to | |||
shape the beam for each input sentence. | |||
""" | |||
def __init__(self, tokenizer, representation): | |||
super().__init__(tokenizer) | |||
self.representation = representation | |||
tgt_dict = {value: key for key, value in tokenizer.get_vocab().items()} | |||
added = { | |||
value: key | |||
for key, value in tokenizer.get_added_vocab().items() | |||
} | |||
tgt_dict.update(added) | |||
self.vocab_size = len(tgt_dict) | |||
self.num_cands = 0 | |||
self.supports_constraints = True | |||
@torch.jit.export | |||
def init_constraints(self, batch_constraints: Optional[Tensor], | |||
beam_size: int): | |||
self.constraint_states = [] | |||
for constraint_tensor in batch_constraints: | |||
if self.representation == 'ordered': | |||
constraint_state = OrderedConstraintState.create( | |||
constraint_tensor) | |||
elif self.representation == 'unordered': | |||
constraint_state = UnorderedConstraintState.create( | |||
constraint_tensor) | |||
self.constraint_states.append( | |||
[constraint_state for i in range(beam_size)]) | |||
@torch.jit.export | |||
def prune_sentences(self, batch_idxs: Tensor): | |||
self.constraint_states = [ | |||
self.constraint_states[i] for i in batch_idxs.tolist() | |||
] | |||
@torch.jit.export | |||
def update_constraints(self, active_hypos: Tensor): | |||
if self.constraint_states: | |||
batch_size = active_hypos.size(0) | |||
for sentid in range(batch_size): | |||
self.constraint_states[sentid] = [ | |||
self.constraint_states[sentid][i] | |||
for i in active_hypos[sentid] | |||
] | |||
@torch.jit.export | |||
def step( | |||
self, | |||
step: int, | |||
lprobs: Tensor, | |||
scores: Optional[Tensor], | |||
prev_output_tokens: Optional[Tensor] = None, | |||
original_batch_idxs: Optional[Tensor] = None, | |||
): | |||
""" | |||
A constrained step builds a large candidates list from the following: | |||
- the top 2 * {beam_size} items over the whole beam | |||
- for each item in the beam | |||
- the top {each_k} (default 1) | |||
- all next constraints | |||
We then compute the constrained state of each beam item, and assign | |||
stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so | |||
on. We then sort by (stripe, score), and truncate the list at | |||
2 * beam size. | |||
Args: | |||
step: the decoder step | |||
lprobs: (batch size, beam size, target vocab) | |||
the target-vocab distributions for each item in the beam. | |||
        Return: A tuple of (scores, indices, beams, constraints) where: | |||
scores: (batch, output beam size) | |||
the scores of the chosen elements | |||
indices: (batch, output beam size) | |||
the target vocab indices of the chosen elements | |||
beams: (batch, output beam size) | |||
the 0-indexed hypothesis ids of the chosen elements | |||
constraints: (batch, output beam size) | |||
the new constraint states | |||
""" | |||
each_k = 1 | |||
device = lprobs.device | |||
batch_size, beam_size, vocab_size = lprobs.size() | |||
self.num_cands = min( | |||
# Just take the k-best. We'll get another k from the 1-best from each | |||
# row, plus more from the constraints | |||
beam_size * 2, | |||
lprobs.view(batch_size, -1).size(1) | |||
- 1, # -1 so we never select pad | |||
) | |||
# STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items | |||
constraint_states = self.constraint_states | |||
if constraint_states and step > 0: | |||
not_finished_indices = [] | |||
for sentno, sent_constraints in enumerate(constraint_states): | |||
for beamno, state in enumerate(sent_constraints): | |||
index = sentno * beam_size + beamno | |||
if not state.finished: | |||
not_finished_indices.append(index) | |||
not_finished_indices = torch.tensor(not_finished_indices) | |||
if not_finished_indices.numel() > 0: | |||
lprobs.view(batch_size * beam_size, -1)[not_finished_indices, | |||
self.eos] = -math.inf | |||
if step == 0: | |||
# at the first step all hypotheses are equally likely, so use | |||
# only the first beam entry for each batch item | |||
lprobs = lprobs[:, ::beam_size, :].contiguous() | |||
else: | |||
# make probs contain cumulative scores for each hypothesis | |||
assert scores is not None | |||
lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) | |||
top_prediction = torch.topk( | |||
lprobs.view(batch_size, -1), | |||
self.num_cands, | |||
) | |||
scores_buf, indices_buf = top_prediction | |||
# Project back into relative indices and beams | |||
beams_buf = indices_buf // vocab_size | |||
indices_buf = indices_buf.fmod(vocab_size) | |||
# Short circuit if there are no constraints in this batch | |||
if not constraint_states: | |||
return scores_buf, indices_buf, beams_buf | |||
# STEP 1: get top-1 from each hypothesis across all sentences in the batch | |||
if step > 0: | |||
top_scores, top_indices = torch.topk( | |||
lprobs.view(batch_size * beam_size, -1), | |||
k=each_k, | |||
dim=1, | |||
) | |||
top_scores = top_scores.view(batch_size, -1) | |||
top_indices = top_indices.view(batch_size, -1) | |||
scores_buf = torch.cat((scores_buf, top_scores), dim=1) | |||
indices_buf = torch.cat((indices_buf, top_indices), dim=1) | |||
new_beams = torch.arange( | |||
0, beam_size, device=device).repeat(batch_size, 1) | |||
beams_buf = torch.cat((beams_buf, new_beams), dim=1) | |||
# Now, process sentences in the batch one by one. | |||
new_scores_buf = torch.zeros((batch_size, 2 * beam_size), | |||
device=device) | |||
new_indices_buf = torch.zeros((batch_size, 2 * beam_size), | |||
device=device).long() | |||
new_beams_buf = torch.zeros((batch_size, 2 * beam_size), | |||
device=device).long() | |||
for sentno, states in enumerate(constraint_states): | |||
scores, indices, beams, new_states = self.step_sentence( | |||
step, | |||
sentno, | |||
lprobs[sentno], | |||
constraint_states[sentno], | |||
beams_buf[sentno].clone(), | |||
indices_buf[sentno].clone(), | |||
scores_buf[sentno].clone(), | |||
) | |||
new_scores_buf[sentno] = scores | |||
new_indices_buf[sentno] = indices | |||
new_beams_buf[sentno] = beams | |||
self.constraint_states[sentno] = new_states | |||
return new_scores_buf, new_indices_buf, new_beams_buf | |||
@torch.jit.export | |||
def step_sentence( | |||
self, | |||
step: int, | |||
sentno: int, | |||
lprobs: Tensor, | |||
constraint_states: List[List[ConstraintState]], | |||
beams_buf: Tensor, | |||
indices_buf: Tensor, | |||
scores_buf: Tensor, | |||
): | |||
"""Does per-sentence processing. Adds all constraints for each | |||
hypothesis to the list of candidates; then removes duplicates, | |||
sorts, and dynamically stripes across the banks. All tensor inputs | |||
are collapsed to those pertaining to a single input sentence. | |||
""" | |||
device = lprobs.device | |||
# STEP 2: Add all constraints for each beam item | |||
for beamno, state in enumerate(constraint_states): | |||
next_tokens = torch.tensor( | |||
list(state.next_tokens()), device=device).long() | |||
if next_tokens.numel() != 0: | |||
indices_buf = torch.cat((indices_buf, next_tokens)) | |||
next_beams = ( | |||
torch.tensor(beamno, device=device).repeat( | |||
next_tokens.size(0)).long()) | |||
beams_buf = torch.cat((beams_buf, next_beams)) | |||
next_values = lprobs[beamno].take(next_tokens.view(-1)) | |||
scores_buf = torch.cat((scores_buf, next_values)) | |||
# At the 0th time step, there is just one beam item | |||
if step == 0: | |||
break | |||
# STEP 3: Compute the "bank" for each candidate. This is the | |||
# number of constraints it's generated. We need this so that | |||
# we can do round-robin allocation of the beam across these | |||
# banks. If C is the number of constraints, we select the best | |||
# item in bank C, then the best in bank C-1, etc, followed by | |||
# the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so | |||
# on, until the maximum beam size. We accomplish this by | |||
# creating a sort key and striping across the banks. | |||
# Compute the new states for all candidates | |||
cands_size = indices_buf.size(0) | |||
constraint_states = [ | |||
constraint_states[beams_buf[i]].advance(indices_buf[i]) | |||
for i in range(cands_size) | |||
] | |||
banks = torch.tensor([state.bank for state in constraint_states], | |||
device=device) | |||
# STEP 4: Sort | |||
num_constraint_tokens = len(state.tokens) | |||
# Sort by keys (bank, score) (i.e., sort banks together, and scores | |||
# within banks). AFAIK pytorch doesn't support either stable sort or | |||
# multi-key sorting, so we have to hack this. | |||
MAX_SCORE = -100 | |||
sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf | |||
sort_values, sort_indices = sort_key.sort(dim=0, descending=True) | |||
scores_buf = scores_buf[sort_indices] | |||
indices_buf = indices_buf[sort_indices] | |||
beams_buf = beams_buf[sort_indices] | |||
banks = banks[sort_indices] | |||
# Sort the constraints to follow suit | |||
constraint_states = [constraint_states[i] for i in sort_indices] | |||
# STEP 5: Remove duplicates. The topk calls (overall and | |||
# per-row) plus the per-row generation of constraints will | |||
# produce duplicates. Here we remove them. | |||
def roll(t): | |||
"""Rolls a 1d tensor left by 1. | |||
[0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3] | |||
""" | |||
return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0) | |||
# We map candidates (beam, token_id) to a single dimension. | |||
# This is then shifted by 1. We can then easily identify | |||
# duplicates and create a mask that identifies unique | |||
# extensions. | |||
uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf | |||
uniques_mask = roll(uniques_mask) != uniques_mask | |||
# Use the mask to pare down the data structures | |||
scores_buf = torch.masked_select(scores_buf, uniques_mask) | |||
indices_buf = torch.masked_select(indices_buf, uniques_mask) | |||
beams_buf = torch.masked_select(beams_buf, uniques_mask) | |||
banks = torch.masked_select(banks, uniques_mask) | |||
i = 1 | |||
for mask in uniques_mask[1:]: | |||
if not mask: | |||
constraint_states.pop(i) | |||
i += mask | |||
# STEP 6: Assign IDs round-robin across banks, sort, and | |||
# truncate. Now that the candidates are sorted by (bank, | |||
# score) and uniqed, we dynamically allocate the {beam_size} | |||
# beam by striping across the candidates. These stripes will | |||
# be used as sort keys to do round-robin selection. This is | |||
# accomplished in a single pass with offsets. Sorting by | |||
# highest-banks (furthest-along hypotheses) first ensures | |||
# progress through the constraints. | |||
# | |||
# e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0 | |||
# OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1 | |||
# NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7 | |||
# = 0 5 10 1 6 11 13 2 7 12 3 8 | |||
# | |||
# Sorting by this then gives the following banks: | |||
# | |||
# 3 2 1 0 3 2 1 0 3 2 1 2 | |||
# | |||
# We'll take the top {beam_size} of these. | |||
stripe_offsets = [ | |||
offset * (len(banks) + 1) for offset in range(len(banks) + 1) | |||
] | |||
stripes = torch.zeros_like(banks) | |||
cur_bank_count = -1 | |||
cur_bank = banks[0] | |||
for i, bank in enumerate(banks): | |||
if bank != cur_bank: | |||
cur_bank_count = 0 | |||
cur_bank = bank | |||
else: | |||
cur_bank_count += 1 | |||
stripes[i] = num_constraint_tokens - bank + stripe_offsets[ | |||
cur_bank_count] | |||
# STEP 7: Sort by the stripes values | |||
sort_values, sort_indices = stripes.sort(dim=0) | |||
scores_buf = scores_buf[sort_indices] | |||
indices_buf = indices_buf[sort_indices] | |||
beams_buf = beams_buf[sort_indices] | |||
constraint_states = [constraint_states[i] for i in sort_indices] | |||
# STEP 8: Truncate to the candidates size! | |||
scores_buf = scores_buf[:self.num_cands] | |||
indices_buf = indices_buf[:self.num_cands] | |||
beams_buf = beams_buf[:self.num_cands] | |||
return scores_buf, indices_buf, beams_buf, constraint_states | |||
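# Illustrative aside (not part of the diff): the duplicate-removal trick used in STEP 5
# above. Each (beam, token) candidate is mapped to a single integer key; duplicates are
# adjacent after sorting, so comparing the keys with a left-rolled copy flags unique rows.
import torch

vocab_size = 10
beams_buf = torch.tensor([0, 0, 1, 1, 1])
indices_buf = torch.tensor([4, 4, 2, 7, 7])   # (0, 4) and (1, 7) each appear twice
keys = beams_buf * (vocab_size + 1) + indices_buf
rolled = torch.cat((keys[-1].unsqueeze(0), keys[0:-1]), dim=0)
uniques_mask = rolled != keys
print(uniques_mask)                           # tensor([ True, False,  True,  True, False])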
class LengthConstrainedBeamSearch(Search): | |||
def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b): | |||
super().__init__(tgt_dict) | |||
self.min_len_a = min_len_a | |||
self.min_len_b = min_len_b | |||
self.max_len_a = max_len_a | |||
self.max_len_b = max_len_b | |||
self.beam = BeamSearch(tgt_dict) | |||
self.needs_src_lengths = True | |||
def step( | |||
self, | |||
step: int, | |||
lprobs, | |||
scores, | |||
prev_output_tokens: Optional[Tensor] = None, | |||
original_batch_idxs: Optional[Tensor] = None, | |||
): | |||
min_lens = self.min_len_a * self.src_lengths + self.min_len_b | |||
max_lens = self.max_len_a * self.src_lengths + self.max_len_b | |||
lprobs[step < min_lens, :, self.eos] = -math.inf | |||
lprobs[step >= max_lens, :, self.eos] = 0 | |||
return self.beam.step(step, lprobs, scores) | |||
class DiverseBeamSearch(Search): | |||
"""Diverse Beam Search. | |||
See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence | |||
Models" for details. | |||
We only implement the Hamming Diversity penalty here, which performed best | |||
in the original paper. | |||
""" | |||
def __init__(self, tgt_dict, num_groups, diversity_strength): | |||
super().__init__(tgt_dict) | |||
self.num_groups = num_groups | |||
self.diversity_strength = -diversity_strength | |||
self.beam = BeamSearch(tgt_dict) | |||
@torch.jit.export | |||
def step( | |||
self, | |||
step: int, | |||
lprobs, | |||
scores, | |||
prev_output_tokens: Optional[Tensor] = None, | |||
original_batch_idxs: Optional[Tensor] = None, | |||
): | |||
bsz, beam_size, vocab_size = lprobs.size() | |||
if beam_size % self.num_groups != 0: | |||
raise ValueError( | |||
'DiverseBeamSearch requires --beam to be divisible by the number of groups' | |||
) | |||
# initialize diversity penalty | |||
diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs) | |||
scores_G, indices_G, beams_G = [], [], [] | |||
for g in range(self.num_groups): | |||
lprobs_g = lprobs[:, g::self.num_groups, :] | |||
scores_g = scores[:, g::self.num_groups, :] if step > 0 else None | |||
# apply diversity penalty | |||
if g > 0: | |||
lprobs_g = torch.add( | |||
lprobs_g, | |||
other=diversity_buf.unsqueeze(1), | |||
alpha=self.diversity_strength, | |||
) | |||
else: | |||
lprobs_g = lprobs_g.contiguous() | |||
scores_buf, indices_buf, beams_buf = self.beam.step( | |||
step, lprobs_g, scores_g) | |||
beams_buf.mul_(self.num_groups).add_(g) | |||
scores_G.append(scores_buf.clone()) | |||
indices_G.append(indices_buf.clone()) | |||
beams_G.append(beams_buf.clone()) | |||
# update diversity penalty | |||
diversity_buf.scatter_add_( | |||
1, indices_buf, | |||
torch.ones(indices_buf.size()).to(diversity_buf)) | |||
# interleave results from different groups | |||
scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1) | |||
indices_buf = torch.stack(indices_G, dim=2).view(bsz, -1) | |||
beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1) | |||
return scores_buf, indices_buf, beams_buf | |||
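# Illustrative aside (not part of the diff): the Hamming diversity penalty applied above.
# Tokens chosen by earlier groups accumulate counts in diversity_buf, and each later group
# adds a negative multiple (diversity_strength) of those counts to its own log-probs.
import torch

vocab_size, diversity_strength = 6, 0.5
diversity_buf = torch.zeros(1, vocab_size)

picked_by_group_0 = torch.tensor([[2, 4]])    # token ids chosen by an earlier group
diversity_buf.scatter_add_(1, picked_by_group_0,
                           torch.ones_like(picked_by_group_0, dtype=torch.float))

lprobs_group_1 = torch.full((1, vocab_size), -1.0)
penalized = torch.add(lprobs_group_1, diversity_buf, alpha=-diversity_strength)
print(penalized)                              # tokens 2 and 4 are now 0.5 lower in log-space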
class Sampling(Search): | |||
sampling_topk: int | |||
sampling_topp: float | |||
def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0): | |||
super().__init__(tgt_dict) | |||
self.sampling_topk = sampling_topk | |||
self.sampling_topp = sampling_topp | |||
def _sample_topp(self, lprobs): | |||
"""Sample among the smallest set of elements whose cumulative probability mass exceeds p. | |||
See `"The Curious Case of Neural Text Degeneration" | |||
(Holtzman et al., 2019) <https://arxiv.org/abs/1904.09751>`_. | |||
Args: | |||
lprobs: (bsz x input_beam_size x vocab_size) | |||
the model's log-probabilities over the vocabulary at the current step | |||
Return: A tuple of (trimed_probs, truncated_indices) where: | |||
trimed_probs: (bsz x input_beam_size x ?) | |||
the model's probabilities over the elements selected to sample from. The | |||
width of the third dimension is determined by top-P. | |||
truncated_indices: (bsz x input_beam_size x ?) | |||
the indices of the chosen elements. | |||
""" | |||
probs = lprobs.exp_() | |||
# sort the last dimension (vocab dimension) in descending order | |||
sorted_probs, sorted_indices = probs.sort(descending=True) | |||
# compute a mask to indicate the words to be included in the top-P set. | |||
cumsum_probs = sorted_probs.cumsum(dim=2) | |||
mask = cumsum_probs.lt(self.sampling_topp) | |||
# note that mask was computed by 'lt'. One more word needs to be included | |||
# so that the cumulative probability mass can exceed p. | |||
cumsum_mask = mask.cumsum(dim=2) | |||
last_included = cumsum_mask[:, :, -1:] | |||
last_included.clamp_(0, mask.size()[2] - 1) | |||
mask = mask.scatter_(2, last_included, 1) | |||
# truncate unnecessary dims. | |||
max_dim = last_included.max() | |||
truncated_mask = mask[:, :, :max_dim + 1] | |||
truncated_probs = sorted_probs[:, :, :max_dim + 1] | |||
truncated_indices = sorted_indices[:, :, :max_dim + 1] | |||
# trim the words that are not in top-P by setting their probabilities | |||
# to 0, so that they would not be sampled later. | |||
trim_mask = ~truncated_mask | |||
trimed_probs = truncated_probs.masked_fill_(trim_mask, 0) | |||
return trimed_probs, truncated_indices | |||
@torch.jit.export | |||
def step( | |||
self, | |||
step: int, | |||
lprobs, | |||
scores, | |||
prev_output_tokens: Optional[Tensor] = None, | |||
original_batch_idxs: Optional[Tensor] = None, | |||
): | |||
bsz, beam_size, vocab_size = lprobs.size() | |||
if step == 0: | |||
# at the first step all hypotheses are equally likely, so use | |||
# only the first beam | |||
lprobs = lprobs[:, ::beam_size, :].contiguous() | |||
if self.sampling_topp > 0: | |||
# only sample from the smallest set of words whose cumulative probability mass exceeds p | |||
probs, top_indices = self._sample_topp(lprobs) | |||
elif self.sampling_topk > 0: | |||
# only sample from top-k candidates | |||
lprobs, top_indices = lprobs.topk(self.sampling_topk) | |||
probs = lprobs.exp_() | |||
else: | |||
probs = lprobs.exp_() | |||
# dummy data to be consistent with true branch for type check | |||
top_indices = torch.empty(0).to(probs) | |||
# sample | |||
if step == 0: | |||
indices_buf = torch.multinomial( | |||
probs.view(bsz, -1), | |||
beam_size, | |||
replacement=True, | |||
).view(bsz, beam_size) | |||
else: | |||
indices_buf = torch.multinomial( | |||
probs.view(bsz * beam_size, -1), | |||
1, | |||
replacement=True, | |||
).view(bsz, beam_size) | |||
if step == 0: | |||
# expand to beam size | |||
probs = probs.expand(bsz, beam_size, -1) | |||
# gather scores | |||
scores_buf = torch.gather( | |||
probs, dim=2, index=indices_buf.unsqueeze(-1)) | |||
scores_buf = scores_buf.log_().view(bsz, -1) | |||
# remap indices if using top-k or top-P sampling | |||
if self.sampling_topk > 0 or self.sampling_topp > 0: | |||
indices_buf = torch.gather( | |||
top_indices.expand(bsz, beam_size, -1), | |||
dim=2, | |||
index=indices_buf.unsqueeze(-1), | |||
).squeeze(2) | |||
if step == 0: | |||
beams_buf = indices_buf.new_zeros(bsz, beam_size) | |||
else: | |||
beams_buf = torch.arange(0, | |||
beam_size).to(indices_buf).repeat(bsz, 1) | |||
# make scores cumulative | |||
scores_buf.add_( | |||
torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf)) | |||
return scores_buf, indices_buf, beams_buf | |||
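# Illustrative sketch: a standalone walk-through of the top-P truncation performed by
# Sampling._sample_topp above, on one tiny distribution. The helper name and the hard-coded
# probabilities are assumptions for illustration, not part of the original code.
def _top_p_truncation_sketch(p: float = 0.9):
    lprobs = torch.log(torch.tensor([[[0.5, 0.3, 0.1, 0.05, 0.05]]]))
    probs = lprobs.exp()
    sorted_probs, sorted_indices = probs.sort(descending=True)
    cumsum_probs = sorted_probs.cumsum(dim=2)
    mask = cumsum_probs.lt(p)  # strictly-less mask over the sorted vocabulary
    # include one more token so the kept mass actually exceeds p (here: 0.5 + 0.3 + 0.1)
    last_included = mask.cumsum(dim=2)[:, :, -1:].clamp_(0, mask.size(2) - 1)
    mask = mask.scatter_(2, last_included, 1)
    max_dim = last_included.max()
    trimmed_probs = sorted_probs[:, :, :max_dim + 1].masked_fill(~mask[:, :, :max_dim + 1], 0)
    return trimmed_probs, sorted_indices[:, :, :max_dim + 1]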
class DiverseSiblingsSearch(Search): | |||
""" | |||
Beam search with diverse siblings. | |||
See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details. | |||
https://arxiv.org/abs/1611.08562 | |||
1/ Calculate hypotheses for each beam | |||
2/ Intra-sibling ordering | |||
3/ Rewrite scores | |||
4/ Choose top K hypotheses | |||
With diversity_rate == 0 this is equivalent to plain BeamSearch (an illustrative sketch follows this class). | |||
""" | |||
def __init__(self, tgt_dict, diversity_rate): | |||
super().__init__(tgt_dict) | |||
self.diversity_rate = diversity_rate | |||
self.beam = BeamSearch(tgt_dict) | |||
def step( | |||
self, | |||
step: int, | |||
lprobs, | |||
scores, | |||
prev_output_tokens: Optional[Tensor] = None, | |||
original_batch_idxs: Optional[Tensor] = None, | |||
): | |||
bsz, beam_size, vocab_size = lprobs.size() | |||
k = min( | |||
# Take the best 2 x beam_size predictions. We'll choose the first | |||
# beam_size of these which don't predict eos to continue with. | |||
beam_size * 2, | |||
lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad | |||
) | |||
s_list: List[Tensor] | |||
i_list: List[Tensor] | |||
s_list = [torch.empty(0).to(lprobs) for i in range(beam_size)] | |||
i_list = [ | |||
torch.LongTensor().to(device=lprobs.device) | |||
for i in range(beam_size) | |||
] | |||
sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate | |||
if step == 0: | |||
return self.beam.step(step, lprobs, scores) | |||
lprobs.add_(scores[:, :, step - 1].unsqueeze(-1)) | |||
# 1/ Calculate hypotheses for each beam | |||
for i in range(beam_size): | |||
torch.topk( | |||
lprobs[:, i, :].view(bsz, -1), k, out=(s_list[i], i_list[i])) | |||
i_list[i].fmod_(vocab_size) | |||
# 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores | |||
s_list[i].sub_(sibling_score) | |||
# 4/ Choose top K hypotheses | |||
indices = torch.stack(i_list, dim=1).view(bsz, -1) | |||
final_scores = torch.empty(0).to(lprobs) | |||
final_indices = torch.LongTensor().to(device=lprobs.device) | |||
final_beams = torch.LongTensor().to(device=lprobs.device) | |||
(final_scores, final_indices) = torch.topk( | |||
torch.stack(s_list, dim=1).view(bsz, -1), | |||
k, | |||
) | |||
final_beams = final_indices // k | |||
for i in range(bsz): | |||
final_indices[i] = indices[i][final_indices[i]] | |||
return final_scores, final_indices, final_beams |
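# Illustrative sketch: the intra-sibling rank penalty applied by DiverseSiblingsSearch.step
# above. Within one beam, the r-th ranked candidate (1-indexed) has r * diversity_rate
# subtracted from its score, so near-duplicate siblings of a single beam are less likely to
# survive the final top-K. The helper name and toy sizes are assumptions for illustration.
def _sibling_rank_penalty_sketch(diversity_rate: float = 0.3, k: int = 4):
    beam_scores = torch.log_softmax(torch.randn(1, 10), dim=-1)  # one beam, vocab of 10
    top_scores, top_indices = beam_scores.topk(k)
    sibling_score = torch.arange(1, k + 1).to(top_scores) * diversity_rate
    return top_scores - sibling_score, top_indices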
@@ -0,0 +1,996 @@ | |||
# Copyright 2022 The OFA-Sys Team. | |||
# All rights reserved. | |||
# This source code is licensed under the Apache 2.0 license | |||
# You may obtain a copy of the License at | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
import math | |||
import sys | |||
from typing import Dict, List, Optional, Tuple | |||
import torch | |||
import torch.nn as nn | |||
from torch import Tensor | |||
from ..generate import search | |||
from .ngram_repeat_block import NGramRepeatBlock | |||
def _expand_mask(mask: torch.Tensor, | |||
dtype: torch.dtype, | |||
tgt_len: Optional[int] = None): | |||
r""" | |||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. | |||
""" | |||
bsz, src_len = mask.size() | |||
tgt_len = tgt_len if tgt_len is not None else src_len | |||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, | |||
src_len).to(dtype) | |||
return expanded_mask.masked_fill(expanded_mask.bool(), | |||
torch.finfo(dtype).min) | |||
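# Illustrative sketch: what _expand_mask produces for a toy padding mask. Positions that are
# non-zero in the input mask end up holding the most negative representable value, so adding
# the result to attention logits effectively masks them out. The helper name and the toy mask
# are assumptions for illustration, not part of the original file.
def _expand_mask_example():
    pad_mask = torch.tensor([[0.0, 0.0, 1.0]])  # 1 marks a padded position
    expanded = _expand_mask(pad_mask, dtype=torch.float32, tgt_len=2)
    # expanded has shape (1, 1, 2, 3); its last column equals torch.finfo(torch.float32).min
    return expanded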
class SequenceGenerator(nn.Module): | |||
def __init__(self, | |||
tokenizer, | |||
beam_size=1, | |||
max_len_a=0, | |||
max_len_b=200, | |||
max_len=0, | |||
min_len=1, | |||
normalize_scores=True, | |||
len_penalty=1.0, | |||
unk_penalty=0.0, | |||
temperature=1.0, | |||
match_source_len=False, | |||
no_repeat_ngram_size=0, | |||
search_strategy=None, | |||
eos=None, | |||
symbols_to_strip_from_output=None, | |||
lm_model=None, | |||
lm_weight=1.0, | |||
constraint_trie=None, | |||
constraint_range=None, | |||
gen_code=False, | |||
gen_box=False, | |||
ignore_eos=False, | |||
zero_shot=False): | |||
"""Generates translations of a given source sentence. | |||
Args: | |||
tokenizer: tokenizer providing the vocabulary and the special token | |||
ids (pad/unk/bos/eos) used to build the target dictionary for generation | |||
beam_size (int, optional): beam width (default: 1) | |||
max_len_a/b (int, optional): generate sequences of maximum length | |||
ax + b, where x is the source length | |||
max_len (int, optional): the maximum length of the generated output | |||
(not including end-of-sentence) | |||
min_len (int, optional): the minimum length of the generated output | |||
(not including end-of-sentence) | |||
normalize_scores (bool, optional): normalize scores by the length | |||
of the output (default: True) | |||
len_penalty (float, optional): length penalty, where <1.0 favors | |||
shorter, >1.0 favors longer sentences (default: 1.0) | |||
unk_penalty (float, optional): unknown word penalty, where <0 | |||
produces more unks, >0 produces fewer (default: 0.0) | |||
temperature (float, optional): temperature, where values | |||
>1.0 produce more uniform samples and values <1.0 produce | |||
sharper samples (default: 1.0) | |||
match_source_len (bool, optional): outputs should match the source | |||
length (default: False) | |||
""" | |||
super().__init__() | |||
self.gen_code = gen_code | |||
self.gen_box = gen_box | |||
self.ignore_eos = ignore_eos | |||
self.tokenizer = tokenizer | |||
self.tgt_dict = { | |||
value: key | |||
for key, value in tokenizer.get_vocab().items() | |||
} | |||
added = { | |||
value: key | |||
for key, value in tokenizer.get_added_vocab().items() | |||
} | |||
self.tgt_dict.update(added) | |||
self.pad = tokenizer.pad_token_id | |||
self.unk = tokenizer.unk_token_id | |||
self.bos = tokenizer.bos_token_id | |||
self.eos = tokenizer.eos_token_id | |||
self.symbols_to_strip_from_output = ( | |||
symbols_to_strip_from_output.union({self.eos}) if | |||
symbols_to_strip_from_output is not None else {self.bos, self.eos}) | |||
self.vocab_size = len(self.tgt_dict) | |||
self.beam_size = beam_size | |||
# the max beam size is the dictionary size - 1, since we never select pad | |||
self.beam_size = min(beam_size, self.vocab_size - 1) | |||
self.max_len_a = max_len_a | |||
self.max_len_b = max_len_b | |||
self.min_len = min_len | |||
self.max_len = max_len | |||
self.normalize_scores = normalize_scores | |||
self.len_penalty = len_penalty | |||
self.unk_penalty = unk_penalty | |||
self.temperature = temperature | |||
self.match_source_len = match_source_len | |||
self.zero_shot = zero_shot | |||
if no_repeat_ngram_size > 0: | |||
self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) | |||
else: | |||
self.repeat_ngram_blocker = None | |||
assert temperature > 0, '--temperature must be greater than 0' | |||
self.search = ( | |||
search.BeamSearch(self.tokenizer) | |||
if search_strategy is None else search_strategy) | |||
# We only need to set src_lengths in LengthConstrainedBeamSearch. | |||
# As a module attribute, setting it would break in multithread | |||
# settings when the model is shared. | |||
self.should_set_src_lengths = ( | |||
hasattr(self.search, 'needs_src_lengths') | |||
and self.search.needs_src_lengths) | |||
self.lm_model = lm_model | |||
self.lm_weight = lm_weight | |||
if self.lm_model is not None: | |||
self.lm_model.eval() | |||
self.constraint_trie = constraint_trie | |||
self.constraint_start = None | |||
self.constraint_end = None | |||
if constraint_range is not None: | |||
constraint_start, constraint_end = constraint_range.split(',') | |||
self.constraint_start = int(constraint_start) | |||
self.constraint_end = int(constraint_end) | |||
@torch.no_grad() | |||
def forward( | |||
self, | |||
sample: Dict[str, Dict[str, Tensor]], | |||
prefix_tokens: Optional[Tensor] = None, | |||
bos_token: Optional[int] = None, | |||
): | |||
"""Generate a batch of translations. | |||
Args: | |||
sample (dict): batch | |||
prefix_tokens (torch.LongTensor, optional): force decoder to begin | |||
with these tokens | |||
bos_token (int, optional): beginning of sentence token | |||
(default: self.eos) | |||
""" | |||
return self._generate(sample, prefix_tokens, bos_token=bos_token) | |||
@torch.no_grad() | |||
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], | |||
**kwargs) -> List[List[Dict[str, Tensor]]]: | |||
"""Generate translations. Match the api of other fairseq generators. | |||
Args: | |||
models (List[~fairseq.models.FairseqModel]): ensemble of models | |||
sample (dict): batch | |||
prefix_tokens (torch.LongTensor, optional): force decoder to begin | |||
with these tokens | |||
constraints (torch.LongTensor, optional): force decoder to include | |||
the list of constraints | |||
bos_token (int, optional): beginning of sentence token | |||
(default: self.eos) | |||
""" | |||
return self._generate(models, sample, **kwargs) | |||
def _generate( | |||
self, | |||
models, | |||
sample: Dict[str, Dict[str, Tensor]], | |||
prefix_tokens: Optional[Tensor] = None, | |||
constraints: Optional[Tensor] = None, | |||
bos_token: Optional[int] = None, | |||
): | |||
model = EnsembleModel(models) | |||
# incremental_states = torch.jit.annotate( | |||
# List[Dict[str, Dict[str, Optional[Tensor]]]], | |||
# [ | |||
# torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) | |||
# for i in range(model.models_size) | |||
# ], | |||
# ) | |||
incremental_states = torch.jit.annotate( | |||
List[Tuple[Tuple[torch.Tensor]]], | |||
[ | |||
torch.jit.annotate(Tuple[Tuple[torch.Tensor]], {}) | |||
for i in range(model.models_size) | |||
], | |||
) | |||
# print("incremental_states",incremental_states) | |||
# print("incremental_states[0]",incremental_states[0]) | |||
net_input = sample['net_input'] | |||
if 'src_tokens' in net_input: | |||
src_tokens = net_input['src_tokens'] | |||
# source length: number of tokens, excluding end-of-sentence and padding | |||
src_lengths = ((src_tokens.ne(self.eos) | |||
& src_tokens.ne(self.pad)).long().sum(dim=1)) | |||
elif 'input_ids' in net_input: | |||
src_tokens = net_input['input_ids'] | |||
# source length: number of tokens, excluding end-of-sentence and padding | |||
src_lengths = ((src_tokens.ne(self.eos) | |||
& src_tokens.ne(self.pad)).long().sum(dim=1)) | |||
elif 'source' in net_input: | |||
src_tokens = net_input['source'] | |||
src_lengths = ( | |||
net_input['padding_mask'].size(-1) | |||
- net_input['padding_mask'].sum(-1) | |||
if net_input['padding_mask'] is not None else torch.tensor( | |||
src_tokens.size(-1)).to(src_tokens)) | |||
elif 'features' in net_input: | |||
src_tokens = net_input['features'] | |||
src_lengths = ( | |||
net_input['padding_mask'].size(-1) | |||
- net_input['padding_mask'].sum(-1) | |||
if net_input['padding_mask'] is not None else torch.tensor( | |||
src_tokens.size(-1)).to(src_tokens)) | |||
else: | |||
raise Exception( | |||
'expected src_tokens or source in net input. input keys: ' | |||
+ str(net_input.keys())) | |||
# bsz: total number of sentences in beam | |||
# Note that src_tokens may have more than 2 dimensions (e.g. audio features) | |||
bsz, src_len = src_tokens.size()[:2] | |||
beam_size = self.beam_size | |||
if constraints is not None and not self.search.supports_constraints: | |||
raise NotImplementedError( | |||
"Target-side constraints were provided, but search method doesn't support them" | |||
) | |||
# Initialize constraints, when active | |||
self.search.init_constraints(constraints, beam_size) | |||
max_len: int = -1 | |||
if self.match_source_len: | |||
max_len = src_lengths.max().item() | |||
else: | |||
max_len = int(self.max_len_a * src_len + self.max_len_b) | |||
assert ( | |||
self.min_len <= max_len | |||
), 'min_len cannot be larger than max_len, please adjust these!' | |||
# compute the encoder output for each beam | |||
with torch.autograd.profiler.record_function( | |||
'EnsembleModel: forward_encoder'): | |||
encoder_outs = model.forward_encoder(net_input) | |||
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores | |||
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) | |||
new_order = new_order.to(src_tokens.device).long() | |||
encoder_outs = model.reorder_encoder_out(encoder_outs, new_order) | |||
# ensure encoder_outs is a List. | |||
assert encoder_outs is not None | |||
# initialize buffers | |||
scores = (torch.zeros(bsz * beam_size, | |||
max_len + 1).to(src_tokens).float() | |||
) # +1 for eos; pad is never chosen for scoring | |||
tokens = (torch.zeros(bsz * beam_size, | |||
max_len + 2).to(src_tokens).long().fill_( | |||
self.pad)) # +2 for eos and pad | |||
# tokens[:, 0] = self.eos if bos_token is None else bos_token | |||
tokens[:, 0] = self.bos | |||
attn: Optional[Tensor] = None | |||
# A list that indicates candidates that should be ignored. | |||
# For example, suppose we're sampling and have already finalized 2/5 | |||
# samples. Then cands_to_ignore would mark 2 positions as being ignored, | |||
# so that we only finalize the remaining 3 samples. | |||
cands_to_ignore = (torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) | |||
) # forward and backward-compatible False mask | |||
# list of completed sentences | |||
finalized = torch.jit.annotate( | |||
List[List[Dict[str, Tensor]]], | |||
[ | |||
torch.jit.annotate(List[Dict[str, Tensor]], []) | |||
for i in range(bsz) | |||
], | |||
) # contains lists of dictionaries of information about the hypothesis being finalized at each step | |||
# a boolean array indicating if the sentence at the index is finished or not | |||
finished = [False for i in range(bsz)] | |||
num_remaining_sent = bsz # number of sentences remaining | |||
# number of candidate hypos per step | |||
cand_size = 2 * beam_size # 2 x beam size in case half are EOS | |||
# offset arrays for converting between different indexing schemes | |||
bbsz_offsets = ((torch.arange(0, bsz) | |||
* beam_size).unsqueeze(1).type_as(tokens).to( | |||
src_tokens.device)) | |||
cand_offsets = torch.arange(0, cand_size).type_as(tokens).to( | |||
src_tokens.device) | |||
reorder_state: Optional[Tensor] = None | |||
batch_idxs: Optional[Tensor] = None | |||
original_batch_idxs: Optional[Tensor] = None | |||
if 'id' in sample and isinstance(sample['id'], Tensor): | |||
original_batch_idxs = sample['id'] | |||
else: | |||
original_batch_idxs = torch.arange(0, bsz).type_as(tokens) | |||
for step in range(max_len + 1): # one extra step for EOS marker | |||
# reorder decoder internal states based on the prev choice of beams | |||
if reorder_state is not None: | |||
if batch_idxs is not None: | |||
# update beam indices to take into account removed sentences | |||
corr = batch_idxs - torch.arange( | |||
batch_idxs.numel()).type_as(batch_idxs) | |||
reorder_state.view(-1, beam_size).add_( | |||
corr.unsqueeze(-1) * beam_size) | |||
original_batch_idxs = original_batch_idxs[batch_idxs] | |||
model.reorder_incremental_state(incremental_states, | |||
reorder_state) # todo | |||
encoder_outs = model.reorder_encoder_out( | |||
encoder_outs, reorder_state) | |||
with torch.autograd.profiler.record_function( | |||
'EnsembleModel: forward_decoder'): | |||
lprobs, avg_attn_scores = model.forward_decoder( | |||
tokens[:, :step + 1], | |||
encoder_outs, | |||
incremental_states, | |||
self.temperature, | |||
constraint_trie=self.constraint_trie, | |||
constraint_start=self.constraint_start, | |||
constraint_end=self.constraint_end, | |||
gen_code=self.gen_code, | |||
zero_shot=self.zero_shot, | |||
prefix_tokens=prefix_tokens) | |||
if self.lm_model is not None: | |||
lm_out = self.lm_model(tokens[:, :step + 1]) | |||
probs = self.lm_model.get_normalized_probs( | |||
lm_out, log_probs=True, sample=None) | |||
probs = probs[:, -1, :] * self.lm_weight | |||
lprobs += probs | |||
# handle prefix tokens (possibly with different lengths) | |||
if (prefix_tokens is not None and step < prefix_tokens.size(1) | |||
and step < max_len): | |||
lprobs, tokens, scores = self._prefix_tokens( | |||
step, lprobs, scores, tokens, prefix_tokens, beam_size) | |||
elif step < self.min_len: | |||
# minimum length constraint (does not apply if using prefix_tokens) | |||
lprobs[:, self.eos] = -math.inf | |||
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) | |||
lprobs[:, self.pad] = -math.inf # never select pad | |||
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty | |||
if (self.gen_code or self.gen_box) and step < max_len: | |||
lprobs[:, :4] = -math.inf | |||
if self.gen_box: | |||
lprobs[:, -1] = -math.inf | |||
if (step + 1) % 5 == 0: | |||
lprobs[:, self.constraint_start:59457] = -math.inf | |||
else: | |||
lprobs[:, 59457:] = -math.inf | |||
# handle max length constraint | |||
if step >= max_len: | |||
lprobs[:, :self.eos] = -math.inf | |||
lprobs[:, self.eos + 1:] = -math.inf | |||
if self.ignore_eos: | |||
lprobs[:, self.eos] = 1 | |||
# Record attention scores; only supported when avg_attn_scores is a Tensor | |||
if avg_attn_scores is not None: | |||
if attn is None: | |||
attn = torch.empty(bsz * beam_size, | |||
avg_attn_scores.size(1), | |||
max_len + 2).to(scores) | |||
# print("+++++++ debug attention shape +++++++") | |||
# print("attn", attn.shape) | |||
# print("avg_attn_scores", avg_attn_scores.shape) | |||
attn[:, :, step + 1].copy_(avg_attn_scores) | |||
# print("attn[:, :, step + 1]", attn[:, :, step + 1].shape) | |||
# print("attn", attn.shape) | |||
scores = scores.type_as(lprobs) | |||
eos_bbsz_idx = torch.empty(0).to( | |||
tokens | |||
) # indices of hypothesis ending with eos (finished sentences) | |||
eos_scores = torch.empty(0).to( | |||
scores | |||
) # scores of hypothesis ending with eos (finished sentences) | |||
if self.should_set_src_lengths: | |||
self.search.set_src_lengths(src_lengths) | |||
if self.repeat_ngram_blocker is not None: | |||
lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, | |||
beam_size, step) | |||
# Shape: (batch, cand_size) | |||
cand_scores, cand_indices, cand_beams = self.search.step( | |||
step, | |||
lprobs.view(bsz, -1, self.vocab_size), | |||
scores.view(bsz, beam_size, -1)[:, :, :step], | |||
tokens[:, :step + 1], | |||
original_batch_idxs, | |||
) | |||
# cand_bbsz_idx contains beam indices for the top candidate | |||
# hypotheses, with a range of values: [0, bsz*beam_size), | |||
# and dimensions: [bsz, cand_size] | |||
cand_bbsz_idx = cand_beams.add(bbsz_offsets) | |||
# finalize hypotheses that end in eos | |||
# Shape of eos_mask: (batch size, beam size) | |||
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) | |||
eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to( | |||
eos_mask) | |||
# only consider eos when it's among the top beam_size indices | |||
# Now we know what beam item(s) to finish | |||
# Shape: 1d list of absolute-numbered | |||
eos_bbsz_idx = torch.masked_select( | |||
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]) | |||
finalized_sents: List[int] = [] | |||
if eos_bbsz_idx.numel() > 0: | |||
eos_scores = torch.masked_select( | |||
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]) | |||
finalized_sents = self.finalize_hypos( | |||
step, | |||
eos_bbsz_idx, | |||
eos_scores, | |||
tokens, | |||
scores, | |||
finalized, | |||
finished, | |||
beam_size, | |||
attn, | |||
src_lengths, | |||
max_len, | |||
) | |||
num_remaining_sent -= len(finalized_sents) | |||
assert num_remaining_sent >= 0 | |||
if num_remaining_sent == 0: | |||
break | |||
if self.search.stop_on_max_len and step >= max_len: | |||
break | |||
assert step < max_len, f'{step} < {max_len}' | |||
# Remove finalized sentences (ones for which {beam_size} | |||
# finished hypotheses have been generated) from the batch. | |||
if len(finalized_sents) > 0: | |||
new_bsz = bsz - len(finalized_sents) | |||
# construct batch_idxs which holds indices of batches to keep for the next pass | |||
batch_mask = torch.ones( | |||
bsz, dtype=torch.bool, device=cand_indices.device) | |||
batch_mask[finalized_sents] = False | |||
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it | |||
batch_idxs = torch.arange( | |||
bsz, device=cand_indices.device).masked_select(batch_mask) | |||
# Choose the subset of the hypothesized constraints that will continue | |||
self.search.prune_sentences(batch_idxs) | |||
eos_mask = eos_mask[batch_idxs] | |||
cand_beams = cand_beams[batch_idxs] | |||
bbsz_offsets.resize_(new_bsz, 1) | |||
cand_bbsz_idx = cand_beams.add(bbsz_offsets) | |||
cand_scores = cand_scores[batch_idxs] | |||
cand_indices = cand_indices[batch_idxs] | |||
if prefix_tokens is not None: | |||
prefix_tokens = prefix_tokens[batch_idxs] | |||
src_lengths = src_lengths[batch_idxs] | |||
cands_to_ignore = cands_to_ignore[batch_idxs] | |||
scores = scores.view(bsz, -1)[batch_idxs].view( | |||
new_bsz * beam_size, -1) | |||
tokens = tokens.view(bsz, -1)[batch_idxs].view( | |||
new_bsz * beam_size, -1) | |||
if attn is not None: | |||
attn = attn.view(bsz, -1)[batch_idxs].view( | |||
new_bsz * beam_size, attn.size(1), -1) | |||
bsz = new_bsz | |||
else: | |||
batch_idxs = None | |||
# Set active_mask so that values > cand_size indicate eos hypos | |||
# and values < cand_size indicate candidate active hypos. | |||
# After, the min values per row are the top candidate active hypos | |||
# Rewrite the operator since the element-wise or is not supported in torchscript. | |||
eos_mask[:, :beam_size] = ~( # noqa | |||
(~cands_to_ignore) & (~eos_mask[:, :beam_size])) # noqa | |||
active_mask = torch.add( | |||
eos_mask.type_as(cand_offsets) * cand_size, | |||
cand_offsets[:eos_mask.size(1)], | |||
) | |||
# get the top beam_size active hypotheses, which are just | |||
# the hypos with the smallest values in active_mask. | |||
# {active_hypos} indicates which {beam_size} hypotheses | |||
# from the list of {2 * beam_size} candidates were | |||
# selected. Shapes: (batch size, beam size) | |||
new_cands_to_ignore, active_hypos = torch.topk( | |||
active_mask, k=beam_size, dim=1, largest=False) | |||
# update cands_to_ignore to ignore any finalized hypos. | |||
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] | |||
# Make sure there is at least one active item for each sentence in the batch. | |||
assert (~cands_to_ignore).any(dim=1).all() | |||
# update cands_to_ignore to ignore any finalized hypos | |||
# {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam | |||
# can be selected more than once). | |||
active_bbsz_idx = torch.gather( | |||
cand_bbsz_idx, dim=1, index=active_hypos) | |||
active_scores = torch.gather( | |||
cand_scores, dim=1, index=active_hypos) | |||
active_bbsz_idx = active_bbsz_idx.view(-1) | |||
active_scores = active_scores.view(-1) | |||
# copy tokens and scores for active hypotheses | |||
# Set the tokens for each beam (can select the same row more than once) | |||
tokens[:, :step + 1] = torch.index_select( | |||
tokens[:, :step + 1], dim=0, index=active_bbsz_idx) | |||
# Select the next token for each of them | |||
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( | |||
cand_indices, dim=1, index=active_hypos) | |||
if step > 0: | |||
scores[:, :step] = torch.index_select( | |||
scores[:, :step], dim=0, index=active_bbsz_idx) | |||
scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( | |||
cand_scores, dim=1, index=active_hypos) | |||
# Update constraints based on which candidates were selected for the next beam | |||
self.search.update_constraints(active_hypos) | |||
# copy attention for active hypotheses | |||
if attn is not None: | |||
attn[:, :, :step + 2] = torch.index_select( | |||
attn[:, :, :step + 2], dim=0, index=active_bbsz_idx) | |||
# reorder incremental state in decoder | |||
reorder_state = active_bbsz_idx | |||
# sort by score descending | |||
for sent in range(len(finalized)): | |||
scores = torch.tensor( | |||
[float(elem['score'].item()) for elem in finalized[sent]]) | |||
_, sorted_scores_indices = torch.sort(scores, descending=True) | |||
finalized[sent] = [ | |||
finalized[sent][ssi] for ssi in sorted_scores_indices | |||
] | |||
finalized[sent] = torch.jit.annotate(List[Dict[str, Tensor]], | |||
finalized[sent]) | |||
return finalized | |||
def _prefix_tokens(self, step: int, lprobs, scores, tokens, prefix_tokens, | |||
beam_size: int): | |||
"""Handle prefix tokens""" | |||
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat( | |||
1, beam_size).view(-1) | |||
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) | |||
prefix_mask = prefix_toks.ne(self.pad) | |||
if self.constraint_trie is None: | |||
lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1 | |||
else: | |||
lprobs[prefix_mask] = -math.inf | |||
lprobs[prefix_mask] = lprobs[prefix_mask].scatter( | |||
-1, prefix_toks[prefix_mask].unsqueeze(-1), | |||
prefix_lprobs[prefix_mask]) | |||
# if prefix includes eos, then we should make sure tokens and | |||
# scores are the same across all beams | |||
eos_mask = prefix_toks.eq(self.eos) | |||
if eos_mask.any(): | |||
# validate that the first beam matches the prefix | |||
first_beam = tokens[eos_mask].view(-1, beam_size, | |||
tokens.size(-1))[:, 0, | |||
1:step + 1] | |||
eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] | |||
target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] | |||
assert (first_beam == target_prefix).all() | |||
# copy tokens, scores and lprobs from the first beam to all beams | |||
tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, | |||
beam_size) | |||
scores = self.replicate_first_beam(scores, eos_mask_batch_dim, | |||
beam_size) | |||
lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, | |||
beam_size) | |||
return lprobs, tokens, scores | |||
def replicate_first_beam(self, tensor, mask, beam_size: int): | |||
tensor = tensor.view(-1, beam_size, tensor.size(-1)) | |||
tensor[mask] = tensor[mask][:, :1, :] | |||
return tensor.view(-1, tensor.size(-1)) | |||
def finalize_hypos( | |||
self, | |||
step: int, | |||
bbsz_idx, | |||
eos_scores, | |||
tokens, | |||
scores, | |||
finalized: List[List[Dict[str, Tensor]]], | |||
finished: List[bool], | |||
beam_size: int, | |||
attn: Optional[Tensor], | |||
src_lengths, | |||
max_len: int, | |||
): | |||
"""Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. | |||
A sentence is finalized when {beam_size} finished items have been collected for it. | |||
Returns number of sentences (not beam items) being finalized. | |||
These will be removed from the batch and not processed further. | |||
Args: | |||
bbsz_idx (Tensor): indices (in the flattened batch x beam dimension) of the hypotheses ending with eos at this step | |||
""" | |||
assert bbsz_idx.numel() == eos_scores.numel() | |||
# clone relevant token and attention tensors. | |||
# tokens is (batch * beam, max_len). So the index_select | |||
# gets the newly EOS rows, then selects cols 1..{step + 2} | |||
tokens_clone = tokens.index_select( | |||
0, bbsz_idx)[:, 1:step + 2] # skip the first index, which is EOS | |||
tokens_clone[:, step] = self.eos | |||
attn_clone = ( | |||
attn.index_select(0, bbsz_idx)[:, :, 1:step | |||
+ 2] if attn is not None else None) | |||
# compute scores per token position | |||
pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1] | |||
pos_scores[:, step] = eos_scores | |||
# convert from cumulative to per-position scores | |||
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] | |||
# normalize sentence-level scores | |||
if self.normalize_scores: | |||
eos_scores /= (step + 1)**self.len_penalty | |||
# cum_unfin records which sentences in the batch are finished. | |||
# It helps match indexing between (a) the original sentences | |||
# in the batch and (b) the current, possibly-reduced set of | |||
# sentences. | |||
cum_unfin: List[int] = [] | |||
prev = 0 | |||
for f in finished: | |||
if f: | |||
prev += 1 | |||
else: | |||
cum_unfin.append(prev) | |||
cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) | |||
unfin_idx = bbsz_idx // beam_size | |||
sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) | |||
# Create a set of "{sent}{unfin_idx}", where | |||
# "unfin_idx" is the index in the current (possibly reduced) | |||
# list of sentences, and "sent" is the index in the original, | |||
# unreduced batch | |||
# For every finished beam item | |||
# sentence index in the current (possibly reduced) batch | |||
seen = (sent << 32) + unfin_idx | |||
unique_seen: List[int] = torch.unique(seen).tolist() | |||
if self.match_source_len: | |||
condition = step > torch.index_select(src_lengths, 0, unfin_idx) | |||
eos_scores = torch.where(condition, torch.tensor(-math.inf), | |||
eos_scores) | |||
sent_list: List[int] = sent.tolist() | |||
for i in range(bbsz_idx.size()[0]): | |||
# An input sentence (among those in a batch) is finished when | |||
# beam_size hypotheses have been collected for it | |||
if len(finalized[sent_list[i]]) < beam_size: | |||
if attn_clone is not None: | |||
# remove padding tokens from attn scores | |||
hypo_attn = attn_clone[i] | |||
else: | |||
hypo_attn = torch.empty(0) | |||
finalized[sent_list[i]].append({ | |||
'tokens': | |||
tokens_clone[i], | |||
'score': | |||
eos_scores[i], | |||
'attention': | |||
hypo_attn, # src_len x tgt_len | |||
'alignment': | |||
torch.empty(0), | |||
'positional_scores': | |||
pos_scores[i], | |||
}) | |||
newly_finished: List[int] = [] | |||
for unique_s in unique_seen: | |||
# check termination conditions for this sentence | |||
unique_sent: int = unique_s >> 32 | |||
unique_unfin_idx: int = unique_s - (unique_sent << 32) | |||
if not finished[unique_sent] and self.is_finished( | |||
step, unique_unfin_idx, max_len, len( | |||
finalized[unique_sent]), beam_size): | |||
finished[unique_sent] = True | |||
newly_finished.append(unique_unfin_idx) | |||
return newly_finished | |||
def is_finished( | |||
self, | |||
step: int, | |||
unfin_idx: int, | |||
max_len: int, | |||
finalized_sent_len: int, | |||
beam_size: int, | |||
): | |||
""" | |||
Check whether decoding for a sentence is finished, which | |||
occurs when the list of finalized sentences has reached the | |||
beam size, or when we reach the maximum length. | |||
""" | |||
assert finalized_sent_len <= beam_size | |||
if finalized_sent_len == beam_size or step == max_len: | |||
return True | |||
return False | |||
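# Illustrative sketch: the sentence-level score normalization applied in finalize_hypos
# above when normalize_scores is True. The cumulative log-probability of a finished
# hypothesis is divided by (step + 1) ** len_penalty, so len_penalty > 1.0 favors longer
# outputs and < 1.0 favors shorter ones. Helper name and example numbers are assumptions.
def _length_normalization_sketch(cum_logprob: float = -6.0,
                                 step: int = 11,
                                 len_penalty: float = 1.0) -> float:
    return cum_logprob / (step + 1) ** len_penalty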
class EnsembleModel(nn.Module): | |||
"""A wrapper around an ensemble of models.""" | |||
def __init__(self, models): | |||
super().__init__() | |||
self.models_size = len(models) | |||
# method '__len__' is not supported in ModuleList for torch script | |||
self.single_model = models[0] | |||
self.models = nn.ModuleList(models) | |||
# self.has_incremental: bool = False | |||
# if all( | |||
# hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) | |||
# for m in models | |||
# ): | |||
# self.has_incremental = True | |||
self.has_incremental = True | |||
def forward(self): | |||
pass | |||
def has_encoder(self): | |||
return hasattr(self.single_model, 'encoder') | |||
def has_incremental_states(self): | |||
return self.has_incremental | |||
def max_decoder_positions(self): | |||
return min([ | |||
m.max_decoder_positions() | |||
for m in self.models if hasattr(m, 'max_decoder_positions') | |||
] + [sys.maxsize]) # | |||
@torch.jit.export | |||
def forward_encoder(self, net_input: Dict[str, Tensor]): | |||
if not self.has_encoder(): | |||
return None | |||
encoder_input = { | |||
k: v | |||
for k, v in net_input.items() if k != 'decoder_input_ids' | |||
} | |||
encoder_input['output_hidden_states'] = True | |||
return [ | |||
model.encoder.forward(**encoder_input) for model in self.models | |||
] | |||
@torch.jit.export | |||
def forward_decoder(self, | |||
tokens, | |||
encoder_outs: List[Dict[str, List[Tensor]]], | |||
incremental_states: List[Optional[torch.Tensor]], | |||
temperature: float = 1.0, | |||
constraint_trie=None, | |||
constraint_start=None, | |||
constraint_end=None, | |||
gen_code=False, | |||
zero_shot=False, | |||
prefix_tokens=None): | |||
log_probs = [] | |||
avg_attn: Optional[Tensor] = None | |||
encoder_out: Optional[Dict[str, List[Tensor]]] = None | |||
code_mask = (tokens.new_ones(tokens.size(0)) * gen_code).bool() | |||
for i, model in enumerate(self.models): | |||
if self.has_encoder(): | |||
encoder_out = encoder_outs[i] | |||
encoder_hidden_states = encoder_out.last_hidden_state | |||
encoder_attention_mask = _expand_mask( | |||
encoder_out.padding_mask, encoder_hidden_states.dtype, | |||
tokens.shape[-1]) | |||
src_pos_embed = encoder_out.position_embedding | |||
# if tokens.eq(self.single_model.config.pad_token_id).any(): | |||
attention_mask = tokens.eq(self.single_model.padding_idx) | |||
# decode each model | |||
if self.has_incremental_states(): | |||
decoder_out = model.decoder.forward( # todo: model inputs differ | |||
input_ids=tokens, | |||
attention_mask=attention_mask, | |||
encoder_hidden_states=encoder_hidden_states, | |||
encoder_attention_mask=encoder_attention_mask, | |||
code_masks=code_mask, | |||
src_pos_embed=src_pos_embed, | |||
past_key_values=incremental_states[i], | |||
use_cache=True, | |||
output_attentions=True) | |||
else: | |||
if hasattr(model, 'decoder'): | |||
# decoder_out = model.decoder.forward(tokens, code_masks=code_mask, encoder_out=encoder_out) | |||
decoder_out = model.decoder.forward( # todo: model inputs differ | |||
input_ids=tokens, | |||
attention_mask=attention_mask, | |||
encoder_hidden_states=encoder_hidden_states, | |||
encoder_attention_mask=encoder_attention_mask, | |||
code_masks=code_mask, | |||
src_pos_embed=src_pos_embed) | |||
else: | |||
decoder_out = model.forward(tokens) | |||
attn: Optional[Tensor] = None | |||
decoder_len = len(decoder_out) | |||
# if decoder_len > 1 and decoder_out[1] is not None: | |||
# if isinstance(decoder_out[1], Tensor): | |||
# attn = decoder_out[1] | |||
# else: | |||
# attn_holder = decoder_out[1]["attn"] | |||
# if isinstance(attn_holder, Tensor): | |||
# attn = attn_holder | |||
# elif attn_holder is not None: | |||
# attn = attn_holder[0] | |||
# if attn is not None: | |||
# attn = attn[:, -1, :] | |||
if 'cross_attentions' in decoder_out: | |||
attn = decoder_out['cross_attentions'][-1].transpose(1, 0) | |||
attn = attn.mean(dim=0) # (B, tgt_len, src_len) | |||
if attn is not None: | |||
attn = attn[:, -1, :] | |||
# decoder_out_tuple = ( | |||
# decoder_out[0][:, -1:, :].div_(temperature), | |||
# None if decoder_len <= 1 else decoder_out[1], | |||
# ) | |||
decoder_out_tuple = ( | |||
decoder_out[0][:, -1:, :].div_(temperature), | |||
None if decoder_len <= 1 else attn, | |||
) | |||
beam_size = decoder_out_tuple[0].size(0) // prefix_tokens.size( | |||
0) if prefix_tokens is not None else 0 | |||
if constraint_trie is not None and not zero_shot: | |||
assert constraint_start is None and constraint_end is None | |||
constraint_masks = decoder_out_tuple[0].new_zeros( | |||
decoder_out_tuple[0].size()).bool() | |||
constraint_prefix_tokens = tokens.tolist() | |||
for token_index, constraint_prefix_token in enumerate( | |||
constraint_prefix_tokens): | |||
prefix_len = prefix_tokens[token_index // beam_size].ne( | |||
1).sum().item() if prefix_tokens is not None else 0 | |||
if len(constraint_prefix_token) > prefix_len: | |||
constraint_prefix_token = [ | |||
0 | |||
] + constraint_prefix_token[prefix_len + 1:] | |||
constraint_nodes = constraint_trie.get_next_layer( | |||
constraint_prefix_token) | |||
constraint_masks[token_index][:, | |||
constraint_nodes] = True | |||
else: | |||
constraint_masks[token_index] = True | |||
decoder_out_tuple[0].masked_fill_(~constraint_masks, -math.inf) | |||
if constraint_start is not None and constraint_end is not None and not zero_shot: | |||
assert constraint_trie is None | |||
decoder_out_tuple[0][:, :, 4:constraint_start] = -math.inf | |||
decoder_out_tuple[0][:, :, constraint_end:] = -math.inf | |||
probs = model.get_normalized_probs( | |||
decoder_out_tuple, log_probs=True, sample=None) | |||
if constraint_trie is not None and zero_shot: | |||
assert constraint_start is None and constraint_end is None | |||
constraint_masks = decoder_out_tuple[0].new_zeros( | |||
decoder_out_tuple[0].size()).bool() | |||
constraint_prefix_tokens = tokens.tolist() | |||
for token_index, constraint_prefix_token in enumerate( | |||
constraint_prefix_tokens): | |||
constraint_nodes = constraint_trie.get_next_layer( | |||
constraint_prefix_token) | |||
constraint_masks[token_index][:, constraint_nodes] = True | |||
probs.masked_fill_(~constraint_masks, -math.inf) | |||
if constraint_start is not None and constraint_end is not None and zero_shot: | |||
assert constraint_trie is None | |||
probs[:, :, 4:constraint_start] = -math.inf | |||
probs[:, :, constraint_end:] = -math.inf | |||
probs = probs[:, -1, :] | |||
if self.models_size == 1: | |||
return probs, attn | |||
log_probs.append(probs) | |||
if attn is not None: | |||
if avg_attn is None: | |||
avg_attn = attn | |||
else: | |||
avg_attn.add_(attn) | |||
avg_probs = torch.logsumexp( | |||
torch.stack(log_probs, dim=0), dim=0) - math.log(self.models_size) | |||
if avg_attn is not None: | |||
avg_attn.div_(self.models_size) | |||
return avg_probs, avg_attn | |||
@torch.jit.export | |||
def reorder_encoder_out(self, | |||
encoder_outs: Optional[List[Dict[str, | |||
List[Tensor]]]], | |||
new_order): | |||
""" | |||
Reorder encoder output according to *new_order*. | |||
Args: | |||
encoder_out: output from the ``forward()`` method | |||
new_order (LongTensor): desired order | |||
Returns: | |||
*encoder_out* rearranged according to *new_order* | |||
""" | |||
new_outs: List[Dict[str, List[Tensor]]] = [] | |||
if not self.has_encoder(): | |||
return new_outs | |||
for i, model in enumerate(self.models): | |||
assert encoder_outs is not None | |||
new_outs.append( | |||
model.encoder.reorder_encoder_out(encoder_outs[i], new_order)) | |||
return new_outs | |||
@torch.jit.export | |||
def reorder_incremental_state( | |||
self, | |||
incremental_states: List[Optional[torch.Tensor]], | |||
new_order, | |||
): | |||
if not self.has_incremental_states(): | |||
return | |||
for i, model in enumerate(self.models): | |||
model.decoder.reorder_incremental_state_scripting( # todo | |||
incremental_states[i], new_order) |
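# Illustrative sketch: the log-space averaging used at the end of EnsembleModel.forward_decoder.
# logsumexp over the model dimension minus log(models_size) is the log of the arithmetic mean of
# the per-model probabilities. The helper name and toy shapes are assumptions for illustration.
def _ensemble_average_sketch():
    log_probs = [torch.log_softmax(torch.randn(2, 7), dim=-1) for _ in range(3)]
    avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(len(log_probs))
    # numerically equivalent to torch.log(torch.stack(log_probs).exp().mean(dim=0))
    return avg_probs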
@@ -0,0 +1,512 @@ | |||
# Copyright (c) Facebook, Inc. and its affiliates. | |||
# | |||
# This source code is licensed under the MIT license which can be found at | |||
# https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||
"""Implements tracking of constraints for a beam item. | |||
A list of constraints is given as a list of one or more token | |||
sequences, each of length at least one token. For example, for an input sentence | |||
> Die maschinelle Übersetzung ist schwer zu kontrollieren. | |||
We could have the constraints: | |||
* to influence | |||
* hard | |||
There are two implementations: | |||
* OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. | |||
* UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. | |||
The difference is that in the first, the constraints are assumed to be | |||
in order; the algorithm will permit zero or more tokens between them. | |||
In the second, the constraints are not ordered, so many orderings will | |||
be explored. | |||
The same sequence can be present any number of times, and will appear | |||
that many times in the output. | |||
""" | |||
from collections import Counter | |||
from typing import List, Set | |||
import torch | |||
class ConstraintState: | |||
def __init__(self): | |||
pass | |||
def pack_constraints( | |||
batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor: | |||
"""Takes a list of list of constraints in tensor form (a list of | |||
tensor constraints for each sentence) and transforms it into a | |||
packed Tensor. For example, here is a batch of size 3 with 3, 0, | |||
and 1 constraints: | |||
[ [ [3 1 2], [3], [4 5 6 7], ] | |||
[], | |||
[ [1 8 9 10 1 4 11 12], ] | |||
] | |||
Its corresponding packed structure is: | |||
[ [ 3 3 1 2 0 3 0 4 5 6 7 0], | |||
[ 0 0 0 0 0 0 0 0 0 0 0 0], | |||
[ 1 1 8 9 10 1 4 11 12 0 0 0] ] | |||
The packed tensor has shape (batch size, maxlen), where | |||
maxlen is defined below. Each row contains concatenated | |||
constraint tokens for that sentence, with 0 appended after | |||
each constraint. The first item in each row is the number | |||
of constraints for that sentence. So maxlen is the maximum | |||
of | |||
(number of constraints) + (sum length of constraints) + 1. | |||
across all sentences in the batch. | |||
""" | |||
# The maximum word length of concatenated constraints for any sentence | |||
max_constraints_len = 1 | |||
for sentence_constraints in batch_constraints: | |||
if len(sentence_constraints): | |||
# number of constraints, plus the sum of constraint lengths, plus a zero after each | |||
constraints_len = (1 | |||
+ sum([c.size(0) for c in sentence_constraints]) | |||
+ len(sentence_constraints)) | |||
max_constraints_len = max(max_constraints_len, constraints_len) | |||
batch_size = len(batch_constraints) | |||
constraints_tensor = torch.zeros((batch_size, max_constraints_len)).long() | |||
for i, sentence_constraints in enumerate(batch_constraints): | |||
constraints_tensor[i, 0] = len(sentence_constraints) | |||
offset = 1 | |||
for j, constraint in enumerate(sentence_constraints): | |||
this_len = constraint.size(0) | |||
constraints_tensor[i, offset:offset + this_len] = constraint | |||
offset += this_len + 1 | |||
return constraints_tensor.long() | |||
def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]: | |||
""" | |||
Transforms *one row* of a packed constraint tensor (e.g., for one | |||
sentence in the batch) into a list of constraint tensors. | |||
""" | |||
constraint_list = [] | |||
num_constraints = constraint_tensor[0] | |||
constraints = constraint_tensor.tolist() | |||
offset = 1 | |||
for i in range(num_constraints): | |||
where = constraints.index(0, offset) | |||
constraint_list.append(constraint_tensor[offset:where]) | |||
offset = where + 1 | |||
return constraint_list | |||
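# Illustrative, runnable sketch of the round trip between pack_constraints and
# unpack_constraints, using the batch from the pack_constraints docstring above.
# The helper name is an assumption for illustration, not part of the original file.
def _pack_unpack_example():
    batch_constraints = [
        [torch.tensor([3, 1, 2]), torch.tensor([3]), torch.tensor([4, 5, 6, 7])],
        [],
        [torch.tensor([1, 8, 9, 10, 1, 4, 11, 12])],
    ]
    packed = pack_constraints(batch_constraints)  # shape (3, 12); first column is 3, 0, 1
    # Unpacking the first row recovers the three constraint tensors for sentence 0.
    return unpack_constraints(packed[0])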
class ConstraintNode: | |||
""" | |||
Represents a node in a trie managing unordered constraints. | |||
""" | |||
def __init__(self, token: int = None, parent=None): | |||
# The token associated with this node (None for the root) | |||
self.token = int(token) if token is not None else None | |||
# The parent (None at the root) | |||
self.parent = parent | |||
# Whether this node is a completed constraint | |||
self.terminal = 0 | |||
# List of child nodes | |||
self.children = {} | |||
# The cumulative number of constraints from this point in the | |||
# trie forward | |||
self.num_constraints = 0 | |||
@property | |||
def id(self): | |||
return self.token | |||
def __str__(self): | |||
term = self.terminal != 0 | |||
return f'[{self.token}].{term}#{self.num_constraints}' | |||
def __getitem__(self, key: int): | |||
return self.children.get(key, None) | |||
def next_tokens(self) -> Set[int]: | |||
"""The set of child labels.""" | |||
return set(self.children.keys()) | |||
@staticmethod | |||
def create(constraints: List[List[int]]): | |||
root = ConstraintNode() | |||
for sequence in constraints: | |||
root.add_sequence(sequence) | |||
return root | |||
@staticmethod | |||
def print_graph(node: 'ConstraintNode'): | |||
if len(node.children) == 0: | |||
return str(node) | |||
else: | |||
s = f'({node}' | |||
for child in node.children.values(): | |||
s += ' ' + ConstraintNode.print_graph(child) | |||
s += ')' | |||
return s | |||
def token_counts(self) -> Counter: | |||
"""Returns a counter of the number of times each token is used | |||
in a constraint. | |||
""" | |||
token_counts = Counter() | |||
kids = list(self.children.values()) | |||
while len(kids) > 0: | |||
kid = kids.pop() | |||
token_counts[kid.id] += kid.num_constraints | |||
kids += list(kid.children.values()) | |||
return token_counts | |||
def tokens(self) -> Set[int]: | |||
"""Returns the set of tokens in constraints.""" | |||
return set(self.token_counts().keys()) | |||
def add_sequence(self, sequence: List[int]): | |||
"""Adds a constraint, represented as a list of integers, to | |||
the trie.""" | |||
assert len(sequence) > 0 | |||
token = int(sequence[0]) | |||
if token not in self.children: | |||
self.children[token] = ConstraintNode(token, parent=self) | |||
node = self.children[token] | |||
if len(sequence) == 1: | |||
node.terminal += 1 | |||
node.num_constraints += 1 | |||
parent = node.parent | |||
while parent is not None: | |||
parent.num_constraints += 1 | |||
parent = parent.parent | |||
else: | |||
node.add_sequence(sequence[1:]) | |||
class UnorderedConstraintState(ConstraintState): | |||
""" | |||
Records progress through the set of constraints for each item in the beam | |||
using a trie. | |||
""" | |||
def __init__(self, | |||
node: ConstraintNode, | |||
copy_from: 'ConstraintState' = None): | |||
self.node = node | |||
if copy_from is None: | |||
# The root node | |||
self.root = node | |||
# The set of states in the graph that have been completed | |||
self.completed = Counter() | |||
# Counter of how many times each trie node has been generated so far | |||
self.generated = Counter() | |||
# The list of tokens we need to generate | |||
self.needed_tokens = self.root.tokens() | |||
else: | |||
self.completed = Counter(copy_from.completed) | |||
self.generated = Counter(copy_from.generated) | |||
self.root = copy_from.root | |||
# Mark the node as generated | |||
if self.node != self.root: | |||
self.generated[node] += 1 | |||
@staticmethod | |||
def create(constraint_tensor: torch.Tensor): | |||
constraint_list = unpack_constraints(constraint_tensor) | |||
constraint_trie_root = ConstraintNode.create(constraint_list) | |||
return UnorderedConstraintState(constraint_trie_root) | |||
def __str__(self): | |||
gen_str = ','.join([str(node) for node in self.generated]) | |||
return f'{self.name}/{self.bank}({gen_str})x{self.num_completed}' | |||
def __copy__(self): | |||
copied_state = UnorderedConstraintState(self.node, copy_from=self) | |||
return copied_state | |||
def copy(self): | |||
return self.__copy__() | |||
@property | |||
def name(self): | |||
if self.node.id is None: | |||
return 'ROOT' | |||
else: | |||
return str(self.node.id) | |||
@property | |||
def is_root(self): | |||
return self.node == self.root | |||
@property | |||
def bank(self): | |||
return sum(self.generated.values()) | |||
@property | |||
def num_completed(self): | |||
"""The number of constraints (not constraint tokens) that are completed. | |||
In addition to the already-completed states, we need to account for the | |||
current state, which might get marked as completed when another token | |||
is generated. | |||
""" | |||
in_final = self.node.terminal and self.completed[ | |||
self.node] < self.node.terminal | |||
return sum(self.completed.values()) + in_final | |||
@property | |||
def finished(self): | |||
return self.root.num_constraints - self.num_completed == 0 | |||
@property | |||
def token_counts(self): | |||
return self.root.token_counts() | |||
@property | |||
def tokens(self): | |||
return self.root.tokens() | |||
@property | |||
def num_constraint_tokens(self): | |||
return sum(self.token_counts.values()) | |||
def next_tokens(self) -> Set[int]: | |||
"""Returns the list of tokens that could come next. | |||
These are (a) all tokens extending the root state and, for | |||
non-root states, additionally all tokens extending the current | |||
state.""" | |||
if self.node != self.root: | |||
return self.root.next_tokens().union(self.node.next_tokens()) | |||
else: | |||
return self.root.next_tokens() | |||
def advance(self, token: int): | |||
"""Reads in a token and advances the state. Here's how it works. | |||
We can advance to the next state if: | |||
- there is a matching child | |||
- its path isn't blocked | |||
A path is blocked when all constraints that are descendants of | |||
that node have already been generated, in the current state. | |||
If we are not able to advance from the current state, we "fall | |||
off the graph" and return to the root state. There, we again | |||
try to advance, checking the same criteria. | |||
In any case, when falling off the graph, we need to do some | |||
bookkeeping. We: | |||
- check whether any constraints were met (all prefixes of | |||
current state) | |||
- if one is found, mark it as completed | |||
- adjust visited nodes accordingly | |||
""" | |||
token = int(token) | |||
next_state = None | |||
child = self.node[token] | |||
if child is not None and self.generated[child] < child.num_constraints: | |||
next_state = UnorderedConstraintState(child, copy_from=self) | |||
def rewind(): | |||
"""If we're mid-trie and an "illegal" token is chosen next, we need | |||
to reset our state to the root state. However, along the way, we need | |||
to check whether a prefix of the current trie state represents a state | |||
we could mark as completed. | |||
""" | |||
node = self.node | |||
while node != self.root: | |||
if node.terminal and self.completed[node] < node.terminal: | |||
next_state.completed[node] += 1 | |||
return | |||
next_state.generated[node] -= 1 | |||
node = node.parent | |||
# Fall off the graph, check the root | |||
if next_state is None and token in self.root.next_tokens(): | |||
child = self.root[token] | |||
# We can only traverse this edge if it's not saturated | |||
if self.generated[child] < child.num_constraints: | |||
next_state = UnorderedConstraintState(child, copy_from=self) | |||
else: | |||
next_state = UnorderedConstraintState( | |||
self.root, copy_from=self) | |||
# Rewind | |||
rewind() | |||
elif next_state is None: | |||
next_state = UnorderedConstraintState(self.root, copy_from=self) | |||
# Rewind | |||
rewind() | |||
return next_state | |||
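# Illustrative sketch: tracking one beam item through an unordered constraint set with the
# trie-based state above. The token ids are arbitrary toy values and the helper name is an
# assumption for illustration, not part of the original file.
def _unordered_state_example():
    root = ConstraintNode.create([[3, 1], [4]])  # two constraints: "3 1" and "4"
    state = UnorderedConstraintState(root)
    state = state.advance(3)  # inside the first constraint; next_tokens() now includes 1
    state = state.advance(1)  # completes "3 1"
    state = state.advance(4)  # completes "4"
    return state.num_completed, state.finished  # -> (2, True)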
class ConstraintSequence: | |||
def __init__(self, sequences: List[List[int]]): | |||
"""Represents a set of possibly multitoken constraints by | |||
concatenating them and internally recording the end points. | |||
""" | |||
self.sequences = [] | |||
self.endpoints = [] | |||
self.num_tokens = 0 | |||
self.tokens = set() | |||
for sequence in sequences: | |||
for token in sequence: | |||
self.tokens.add(token) | |||
self.num_tokens += len(sequence) | |||
self.endpoints += [False | |||
for x in range(len(sequence) - 1)] + [True] | |||
self.sequences += sequence | |||
def __getitem__(self, key: int): | |||
return self.sequences[key] | |||
def __len__(self): | |||
return len(self.sequences) | |||
def __str__(self): | |||
return str(self.sequences) | |||
class OrderedConstraintState(ConstraintState): | |||
""" | |||
Records progress through the set of linear nonbranching constraints with gaps. | |||
""" | |||
def __init__(self, sequence: ConstraintSequence, state: int = -1): | |||
self.sequence = sequence | |||
self.state = state | |||
@staticmethod | |||
def create(constraint_tensor: torch.Tensor): | |||
constraint_list = unpack_constraints(constraint_tensor) | |||
return OrderedConstraintState(ConstraintSequence(constraint_list), -1) | |||
def __str__(self): | |||
return f'{self.state}/{self.bank}x{self.num_completed}' | |||
def __copy__(self): | |||
return OrderedConstraintState(self.sequence, self.state) | |||
def copy(self): | |||
return self.__copy__() | |||
@property | |||
def num_completed(self): | |||
if self.state == -1: | |||
return 0 | |||
count = len( | |||
list( | |||
filter(lambda x: x, | |||
self.sequence.endpoints[0:self.state + 1]))) | |||
return count | |||
@property | |||
def is_root(self): | |||
return self.state == -1 | |||
@property | |||
def name(self): | |||
if self.state == -1: | |||
return 'ROOT' | |||
else: | |||
return str(self.sequence[self.state]) | |||
@property | |||
def bank(self) -> int: | |||
return self.state + 1 | |||
@property | |||
def finished(self): | |||
return self.state + 1 == len(self.sequence) | |||
@property | |||
def token_counts(self): | |||
return self.sequence.token_counts() | |||
@property | |||
def tokens(self): | |||
return self.sequence.tokens | |||
@property | |||
def num_constraint_tokens(self): | |||
return sum(self.token_counts.values()) | |||
def next_tokens(self) -> Set[int]: | |||
"""Returns the list of tokens that could come next. | |||
These are (a) all tokens extending the root state and, for | |||
non-root states, additionally all tokens extending the current | |||
state.""" | |||
tokens = set() | |||
if self.state > 0: | |||
tokens.add(self.sequence[0]) | |||
if not self.finished: | |||
tokens.add(self.sequence[self.state + 1]) | |||
return tokens | |||
def advance(self, token: int): | |||
"""Reads in a token and advances the state. Here's how it works. | |||
We can advance to the next state if: | |||
- there is a matching child | |||
- its path isn't blocked | |||
A path is blocked when all constraints that are descendants of | |||
that node have already been generated, in the current state. | |||
If we are not able to advance from the current state, we "fall | |||
off the graph" and return to the root state. There, we again | |||
try to advance, checking the same criteria. | |||
In any case, when falling off the graph, we need to do some | |||
bookkeeping. We: | |||
- check whether any constraints were met (all prefixes of | |||
current state) | |||
- if one is found, mark it as completed | |||
- adjust visited nodes accordingly | |||
""" | |||
token = int(token) | |||
# print(f"{self} ADVANCE({token}) {self.sequence} -> ", end="") | |||
if self.finished: | |||
# Accept anything | |||
next_state = self.copy() | |||
elif self.sequence[self.state + 1] == token: | |||
# Advance to the next token | |||
next_state = OrderedConstraintState(self.sequence, self.state + 1) | |||
elif self.sequence.endpoints[self.state]: | |||
# Accept anything between constraints (*) | |||
next_state = self.copy() | |||
elif token == self.sequence[0]: | |||
# Start over having generated the first token | |||
next_state = OrderedConstraintState(self.sequence, 0) | |||
else: | |||
# Start over from the root | |||
next_state = OrderedConstraintState(self.sequence, -1) | |||
return next_state |
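
To make the advance() logic above concrete, here is a minimal, hypothetical walk-through; it relies only on the ConstraintSequence and OrderedConstraintState classes shown above, and the token ids 5, 6, 7 and 9 are arbitrary vocabulary indices chosen for illustration.

# Two constraints, [5, 6] and [9]; endpoints become [False, True, True].
seq = ConstraintSequence([[5, 6], [9]])
state = OrderedConstraintState(seq)            # state == -1, i.e. ROOT

state = state.advance(5)                       # matches sequence[0]             -> state 0
state = state.advance(7)                       # no match, endpoints[0] is False -> back to ROOT
state = state.advance(5).advance(6)            # finishes the first constraint   -> state 1, num_completed == 1
state = state.advance(9)                       # finishes the second constraint  -> state 2, finished == True
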
@@ -0,0 +1,124 @@ | |||
# Copyright (c) Facebook, Inc. and its affiliates. | |||
# | |||
# This source code is licensed under the MIT license which can be found at | |||
# https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||
import collections | |||
from collections import abc | |||
from itertools import accumulate | |||
import torch | |||
import torch.nn.functional as F | |||
try: | |||
from amp_C import multi_tensor_l2norm | |||
multi_tensor_l2norm_available = True | |||
except ImportError: | |||
multi_tensor_l2norm_available = False | |||
try: | |||
import torch_xla.core.xla_model as xm | |||
except ImportError: | |||
xm = None | |||
MANIFOLD_PATH_SEP = '|' | |||
def apply_to_sample(f, sample): | |||
if hasattr(sample, '__len__') and len(sample) == 0: | |||
return {} | |||
def _apply(x): | |||
if torch.is_tensor(x): | |||
return f(x) | |||
elif isinstance(x, collections.OrderedDict): | |||
# OrderedDict has attributes that needs to be preserved | |||
od = collections.OrderedDict( | |||
(key, _apply(value)) for key, value in x.items()) | |||
od.__dict__ = x.__dict__ | |||
return od | |||
elif isinstance(x, dict): | |||
return {key: _apply(value) for key, value in x.items()} | |||
elif isinstance(x, list): | |||
return [_apply(x) for x in x] | |||
elif isinstance(x, tuple): | |||
return tuple(_apply(x) for x in x) | |||
elif isinstance(x, set): | |||
return {_apply(x) for x in x} | |||
else: | |||
return x | |||
return _apply(sample) | |||
def move_to_device(batch, device): | |||
r"""Puts each data field to the device""" | |||
if isinstance(batch, torch.Tensor): | |||
return batch.to(device) | |||
elif isinstance(batch, (list, tuple)): | |||
return tuple(move_to_device(item, device) for item in batch) | |||
elif isinstance(batch, abc.Mapping): | |||
return { | |||
key: move_to_device(value, device) | |||
for key, value in batch.items() | |||
} | |||
else: | |||
return batch | |||
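
As a quick sketch of how the helper above behaves (keys and shapes are illustrative), move_to_device recurses through mappings and sequences, moves only tensors, and returns everything else unchanged; note that lists come back as tuples because of the shared list/tuple branch.

import torch

batch = {
    'net_input': {
        'input_ids': torch.zeros(1, 16, dtype=torch.long),
        'patch_masks': torch.tensor([True]),
    },
    'id': ['42'],  # non-tensor leaf; returned as a tuple ('42',)
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch = move_to_device(batch, device)  # nested tensors now live on `device`
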
def strip_pad(tensor, pad): | |||
return tensor[tensor.ne(pad)] | |||
def get_token_to_word_mapping(tokens, exclude_list): | |||
n = len(tokens) | |||
word_start = [int(token not in exclude_list) for token in tokens] | |||
word_idx = list(accumulate(word_start)) | |||
token_to_word = {i: word_idx[i] for i in range(n)} | |||
return token_to_word | |||
def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): | |||
tgt_valid = (((tgt_sent != pad) & # noqa | |||
(tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1)) | |||
src_invalid = (((src_sent == pad) | # noqa | |||
(src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1)) | |||
src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) | |||
tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) | |||
alignment = [] | |||
if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent): | |||
attn_valid = attn[tgt_valid] | |||
attn_valid[:, src_invalid] = float('-inf') | |||
_, src_indices = attn_valid.max(dim=1) | |||
for tgt_idx, src_idx in zip(tgt_valid, src_indices): | |||
alignment.append(( | |||
src_token_to_word[src_idx.item()] - 1, | |||
tgt_token_to_word[tgt_idx.item()] - 1, | |||
)) | |||
return alignment | |||
def softmax(x, dim: int, onnx_trace: bool = False): | |||
if onnx_trace: | |||
return F.softmax(x.float(), dim=dim) | |||
else: | |||
return F.softmax(x, dim=dim, dtype=torch.float32) | |||
def log_softmax(x, dim: int, onnx_trace: bool = False): | |||
if onnx_trace: | |||
return F.log_softmax(x.float(), dim=dim) | |||
else: | |||
return F.log_softmax(x, dim=dim, dtype=torch.float32) | |||
def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos): | |||
tgt_valid = (tgt_sent != pad).nonzero(as_tuple=False) | |||
src_valid = (src_sent != pad).nonzero(as_tuple=False).squeeze(dim=-1) | |||
alignment = [] | |||
if len(tgt_valid) != 0 and len(src_valid) != 0: | |||
attn_valid = attn[tgt_valid, src_valid] | |||
alignment = [['{:.6f}'.format(p) for p in src_probs.tolist()] | |||
for src_probs in attn_valid] | |||
return alignment |
@@ -0,0 +1,283 @@ | |||
import torch | |||
import torch.nn as nn | |||
def drop_path(x, drop_prob: float = 0., training: bool = False): | |||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, | |||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... | |||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for | |||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use | |||
'survival rate' as the argument. | |||
""" | |||
if drop_prob == 0. or not training: | |||
return x | |||
keep_prob = 1 - drop_prob | |||
shape = (x.shape[0], ) + (1, ) * ( | |||
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets | |||
random_tensor = keep_prob + torch.rand( | |||
shape, dtype=x.dtype, device=x.device) | |||
random_tensor.floor_() # binarize | |||
output = x.div(keep_prob) * random_tensor | |||
return output | |||
class DropPath(nn.Module): | |||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
""" | |||
def __init__(self, drop_prob=None): | |||
super(DropPath, self).__init__() | |||
self.drop_prob = drop_prob | |||
def forward(self, x): | |||
return drop_path(x, self.drop_prob, self.training) | |||
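
A short numerical sketch of what DropPath does in training mode (values are illustrative; the seed is only there to make the sketch repeatable): each sample in the batch is either zeroed out entirely or rescaled by 1/keep_prob, so the expected value of the residual branch is preserved.

import torch

torch.manual_seed(0)
x = torch.ones(4, 3, 8, 8)            # 4 samples
dp = DropPath(drop_prob=0.5)
dp.train()                            # drop_path is a no-op in eval mode
y = dp(x)
# The random mask has shape (4, 1, 1, 1) and is broadcast over every non-batch
# dimension, so each y[i] is either all zeros or all 2.0 (== 1 / keep_prob).
print([float(y[i, 0, 0, 0]) for i in range(4)])
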
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | |||
"""3x3 convolution with padding""" | |||
return nn.Conv2d( | |||
in_planes, | |||
out_planes, | |||
kernel_size=3, | |||
stride=stride, | |||
padding=dilation, | |||
groups=groups, | |||
bias=False, | |||
dilation=dilation) | |||
def conv1x1(in_planes, out_planes, stride=1): | |||
"""1x1 convolution""" | |||
return nn.Conv2d( | |||
in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||
class BasicBlock(nn.Module): | |||
expansion = 1 | |||
def __init__(self, | |||
inplanes, | |||
planes, | |||
stride=1, | |||
downsample=None, | |||
groups=1, | |||
base_width=64, | |||
dilation=1, | |||
norm_layer=None): | |||
super(BasicBlock, self).__init__() | |||
if norm_layer is None: | |||
norm_layer = nn.BatchNorm2d | |||
if groups != 1 or base_width != 64: | |||
raise ValueError( | |||
'BasicBlock only supports groups=1 and base_width=64') | |||
if dilation > 1: | |||
raise NotImplementedError( | |||
'Dilation > 1 not supported in BasicBlock') | |||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1 | |||
self.conv1 = conv3x3(inplanes, planes, stride) | |||
self.bn1 = norm_layer(planes) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.conv2 = conv3x3(planes, planes) | |||
self.bn2 = norm_layer(planes) | |||
self.downsample = downsample | |||
self.stride = stride | |||
def forward(self, x): | |||
# This backbone only ever instantiates Bottleneck blocks (see the _make_layer calls | |||
# in ResNet below); BasicBlock is kept for the zero_init_residual branch, so its | |||
# forward path is never expected to run. | |||
assert False | |||
identity = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
if self.downsample is not None: | |||
identity = self.downsample(x) | |||
out += identity | |||
out = self.relu(out) | |||
return out | |||
class Bottleneck(nn.Module): | |||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) | |||
# while original implementation places the stride at the first 1x1 convolution(self.conv1) | |||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. | |||
# This variant is also known as ResNet V1.5 and improves accuracy according to | |||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. | |||
expansion = 4 | |||
def __init__(self, | |||
inplanes, | |||
planes, | |||
stride=1, | |||
downsample=None, | |||
groups=1, | |||
base_width=64, | |||
dilation=1, | |||
norm_layer=None, | |||
drop_path_rate=0.0): | |||
super(Bottleneck, self).__init__() | |||
if norm_layer is None: | |||
norm_layer = nn.BatchNorm2d | |||
width = int(planes * (base_width / 64.)) * groups | |||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1 | |||
self.conv1 = conv1x1(inplanes, width) | |||
self.bn1 = norm_layer(width) | |||
self.conv2 = conv3x3(width, width, stride, groups, dilation) | |||
self.bn2 = norm_layer(width) | |||
self.conv3 = conv1x1(width, planes * self.expansion) | |||
self.bn3 = norm_layer(planes * self.expansion) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.downsample = downsample | |||
self.stride = stride | |||
self.drop_path = DropPath( | |||
drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() | |||
def forward(self, x): | |||
identity = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
out = self.relu(out) | |||
out = self.conv3(out) | |||
out = self.bn3(out) | |||
if self.downsample is not None: | |||
identity = self.downsample(x) | |||
out = identity + self.drop_path(out) | |||
out = self.relu(out) | |||
return out | |||
class ResNet(nn.Module): | |||
def __init__(self, | |||
layers, | |||
zero_init_residual=False, | |||
groups=1, | |||
width_per_group=64, | |||
replace_stride_with_dilation=None, | |||
norm_layer=None, | |||
drop_path_rate=0.0): | |||
super(ResNet, self).__init__() | |||
if norm_layer is None: | |||
norm_layer = nn.BatchNorm2d | |||
self._norm_layer = norm_layer | |||
self.inplanes = 64 | |||
self.dilation = 1 | |||
if replace_stride_with_dilation is None: | |||
# each element in the tuple indicates if we should replace | |||
# the 2x2 stride with a dilated convolution instead | |||
replace_stride_with_dilation = [False, False, False] | |||
if len(replace_stride_with_dilation) != 3: | |||
raise ValueError('replace_stride_with_dilation should be None ' | |||
'or a 3-element tuple, got {}'.format( | |||
replace_stride_with_dilation)) | |||
self.groups = groups | |||
self.base_width = width_per_group | |||
self.conv1 = nn.Conv2d( | |||
3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) | |||
self.bn1 = norm_layer(self.inplanes) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
self.layer1 = self._make_layer( | |||
Bottleneck, 64, layers[0], drop_path_rate=drop_path_rate) | |||
self.layer2 = self._make_layer( | |||
Bottleneck, | |||
128, | |||
layers[1], | |||
stride=2, | |||
dilate=replace_stride_with_dilation[0], | |||
drop_path_rate=drop_path_rate) | |||
self.layer3 = self._make_layer( | |||
Bottleneck, | |||
256, | |||
layers[2], | |||
stride=2, | |||
dilate=replace_stride_with_dilation[1], | |||
drop_path_rate=drop_path_rate) | |||
for m in self.modules(): | |||
if isinstance(m, nn.Conv2d): | |||
nn.init.kaiming_normal_( | |||
m.weight, mode='fan_out', nonlinearity='relu') | |||
elif isinstance(m, | |||
(nn.SyncBatchNorm, nn.BatchNorm2d, nn.GroupNorm)): | |||
nn.init.constant_(m.weight, 1) | |||
nn.init.constant_(m.bias, 0) | |||
# Zero-initialize the last BN in each residual branch, | |||
# so that the residual branch starts with zeros, and each residual block behaves like an identity. | |||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | |||
if zero_init_residual: | |||
for m in self.modules(): | |||
if isinstance(m, Bottleneck): | |||
nn.init.constant_(m.bn3.weight, 0) | |||
elif isinstance(m, BasicBlock): | |||
nn.init.constant_(m.bn2.weight, 0) | |||
def _make_layer(self, | |||
block, | |||
planes, | |||
blocks, | |||
stride=1, | |||
dilate=False, | |||
drop_path_rate=0.0): | |||
norm_layer = self._norm_layer | |||
downsample = None | |||
previous_dilation = self.dilation | |||
if dilate: | |||
self.dilation *= stride | |||
stride = 1 | |||
if stride != 1 or self.inplanes != planes * block.expansion: | |||
downsample = nn.Sequential( | |||
conv1x1(self.inplanes, planes * block.expansion, stride), | |||
norm_layer(planes * block.expansion), | |||
) | |||
layers = [] | |||
layers.append( | |||
block(self.inplanes, planes, stride, downsample, self.groups, | |||
self.base_width, previous_dilation, norm_layer)) | |||
self.inplanes = planes * block.expansion | |||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, blocks)] | |||
for i in range(1, blocks): | |||
layers.append( | |||
block( | |||
self.inplanes, | |||
planes, | |||
groups=self.groups, | |||
base_width=self.base_width, | |||
dilation=self.dilation, | |||
norm_layer=norm_layer, | |||
drop_path_rate=dpr[i])) | |||
return nn.Sequential(*layers) | |||
def _forward_impl(self, x): | |||
x = self.conv1(x) | |||
x = self.bn1(x) | |||
x = self.relu(x) | |||
x = self.maxpool(x) | |||
x = self.layer1(x) | |||
x = self.layer2(x) | |||
x = self.layer3(x) | |||
return x | |||
def forward(self, x): | |||
return self._forward_impl(x) |
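
Note that this ResNet is a truncated backbone: it only builds layer1 through layer3 (all Bottleneck blocks) and returns a feature map rather than logits. A small shape sketch for the 480x480 inputs used by the captioning preprocessor later in this PR; the [3, 4, 6] layer counts are chosen purely for illustration and are not taken from any particular OFA checkpoint.

import torch

backbone = ResNet(layers=[3, 4, 6])
images = torch.randn(2, 3, 480, 480)
feats = backbone(images)
# conv1 (stride 2) + maxpool (stride 2) + layer2 and layer3 (stride 2 each) give an
# overall stride of 16, and layer3 outputs 256 * Bottleneck.expansion = 1024 channels.
print(feats.shape)  # torch.Size([2, 1024, 30, 30])
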
@@ -0,0 +1,48 @@ | |||
# Copyright 2022 OFA-Sys Team. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Tokenization classes for OFA.""" | |||
from transformers.models.bart.tokenization_bart import BartTokenizer | |||
from transformers.utils import logging | |||
logger = logging.get_logger(__name__) | |||
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.json', 'merges_file': 'merges.txt'} | |||
PRETRAINED_VOCAB_FILES_MAP = { | |||
'vocab_file': { | |||
'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/vocab.json', | |||
}, | |||
'merges_file': { | |||
'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt', | |||
}, | |||
} | |||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { | |||
'ofa-base': 1024, | |||
} | |||
class OFATokenizer(BartTokenizer): | |||
""" | |||
Construct a OFA tokenizer. | |||
[`~OFATokenizer`] is identical to [`BartTokenizer`] and runs end-to-end tokenization: punctuation splitting and | |||
wordpiece. | |||
Refer to superclass [`BartTokenizer`] for usage examples and documentation concerning parameters. | |||
""" | |||
vocab_files_names = VOCAB_FILES_NAMES | |||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP | |||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES |
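
A minimal usage sketch; the checkpoint directory is a placeholder and only needs to contain the vocab.json and merges.txt files named above. The prompt is written with a leading space on purpose, matching how the captioning preprocessor later in this PR feeds text to the BPE tokenizer.

tokenizer = OFATokenizer.from_pretrained('/path/to/ofa-checkpoint')  # hypothetical local dir
ids = tokenizer([' what does the image describe?'],
                max_length=1024, return_tensors='pt')['input_ids']
print(ids.shape)  # torch.Size([1, prompt_length])
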
@@ -0,0 +1,59 @@ | |||
# Copyright 2022 OFA-Sys Team. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Tokenization classes for OFA.""" | |||
from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast | |||
from transformers.utils import logging | |||
from .tokenization_ofa import OFATokenizer | |||
logger = logging.get_logger(__name__) | |||
VOCAB_FILES_NAMES = { | |||
'vocab_file': 'vocab.json', | |||
'merges_file': 'merges.txt', | |||
'tokenizer_file': 'tokenizer.json' | |||
} | |||
PRETRAINED_VOCAB_FILES_MAP = { | |||
'vocab_file': { | |||
'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/vocab.json', | |||
}, | |||
'merges_file': { | |||
'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt', | |||
}, | |||
'tokenizer_file': { | |||
'ofa-base': | |||
'https://huggingface.co/ofa-base/resolve/main/tokenizer.json', | |||
}, | |||
} | |||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { | |||
'ofa-base': 1024, | |||
} | |||
class OFATokenizerFast(BartTokenizerFast): | |||
r""" | |||
Construct a "fast" OFA tokenizer (backed by HuggingFace's *tokenizers* library). | |||
[`~OFATokenizerFast`] is identical to [`BartTokenizerFast`] and runs end-to-end tokenization: punctuation splitting | |||
and wordpiece. | |||
Refer to superclass [`BartTokenizerFast`] for usage examples and documentation concerning parameters. | |||
""" | |||
vocab_files_names = VOCAB_FILES_NAMES | |||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP | |||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES | |||
slow_tokenizer_class = OFATokenizer |
@@ -0,0 +1,53 @@ | |||
from typing import Any, Dict | |||
import torch.cuda | |||
from modelscope.metainfo import Models | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import Tasks | |||
from ..base import Model | |||
from ..builder import MODELS | |||
from .ofa import OFAModel, OFATokenizer | |||
from .ofa.generate import sequence_generator as sg | |||
from .ofa.generate.utils import move_to_device | |||
__all__ = ['OfaForImageCaptioning'] | |||
@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) | |||
class OfaForImageCaptioning(Model): | |||
def __init__(self, model_dir, *args, **kwargs): | |||
super().__init__(model_dir=model_dir, *args, **kwargs) | |||
model = OFAModel.from_pretrained(model_dir) | |||
self.model = model.module if hasattr(model, 'module') else model | |||
self.tokenizer = OFATokenizer.from_pretrained(model_dir) | |||
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) | |||
self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)]) | |||
self._device = torch.device('cuda') if torch.cuda.is_available() \ | |||
else torch.device('cpu') | |||
self.model.to(self._device) | |||
# Initialize generator | |||
sg_args = { | |||
'tokenizer': self.tokenizer, | |||
'beam_size': 5, | |||
'max_len_b': 16, | |||
'min_len': 1, | |||
'no_repeat_ngram_size': 3, | |||
'constraint_range': None | |||
} | |||
if 'beam_search' in kwargs:  # kwargs is a plain dict, so use a membership test rather than hasattr | |||
sg_args.update(kwargs['beam_search']) | |||
self.generator = sg.SequenceGenerator(**sg_args) | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
input = move_to_device(input, self._device) | |||
gen_output = self.generator.generate([self.model], input) | |||
gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] | |||
result = self.tokenizer.batch_decode(gen, skip_special_tokens=True) | |||
return {'image_id': '42', OutputKeys.CAPTION: result[0]} | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
# Generation and decoding already happen in forward(), so there is nothing left to post-process. | |||
return inputs |
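
Since the beam-search settings can be overridden through kwargs (see the 'beam_search' check above), here is a hedged sketch of how a caller could tweak them; the checkpoint directory is a placeholder, and the override keys come from the sg_args defaults defined in __init__.

model = OfaForImageCaptioning(
    '/path/to/ofa-image-caption-checkpoint',             # hypothetical local dir
    beam_search={'beam_size': 10, 'max_len_b': 32},      # overrides the defaults above
)
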
@@ -44,7 +44,7 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
'damo/nlp_space_dialog-modeling'), | |||
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | |||
'damo/nlp_space_dialog-state-tracking'), | |||
Tasks.image_captioning: (Pipelines.image_caption, | |||
Tasks.image_captioning: (Pipelines.image_captioning, | |||
'damo/ofa_image-caption_coco_large_en'), | |||
Tasks.image_generation: | |||
(Pipelines.person_image_cartoon, | |||
@@ -11,7 +11,7 @@ logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.image_captioning, module_name=Pipelines.image_caption) | |||
Tasks.image_captioning, module_name=Pipelines.image_captioning) | |||
class ImageCaptionPipeline(Pipeline): | |||
def __init__(self, | |||
@@ -2,13 +2,14 @@ | |||
import os.path as osp | |||
from typing import Any, Dict, Union | |||
import numpy as np | |||
import torch | |||
from PIL import Image | |||
from torchvision import transforms | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.metainfo import Preprocessors | |||
from modelscope.utils.constant import Fields, ModelFile | |||
from modelscope.models.multi_modal.ofa import OFATokenizer | |||
from modelscope.utils.constant import Fields | |||
from modelscope.utils.type_assert import type_assert | |||
from .base import Preprocessor | |||
from .builder import PREPROCESSORS | |||
@@ -31,84 +32,39 @@ class OfaImageCaptionPreprocessor(Preprocessor): | |||
model_dir (str): model path | |||
""" | |||
super().__init__(*args, **kwargs) | |||
model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | |||
model_dir) | |||
self.tokenizer = OFATokenizer.from_pretrained(model_dir) | |||
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) | |||
self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)]) | |||
if osp.exists(model_dir): | |||
local_model_dir = model_dir | |||
else: | |||
local_model_dir = snapshot_download(model_dir) | |||
local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE) | |||
bpe_dir = local_model_dir | |||
from fairseq import checkpoint_utils, tasks, utils | |||
from ofa.tasks.mm_tasks import CaptionTask | |||
tasks.register_task('caption', CaptionTask) | |||
overrides = { | |||
'bpe_dir': bpe_dir, | |||
'eval_cider': False, | |||
'beam': 5, | |||
'max_len_b': 16, | |||
'no_repeat_ngram_size': 3, | |||
'seed': 7 | |||
} | |||
model, cfg, task = checkpoint_utils.load_model_ensemble_and_task( | |||
utils.split_paths(local_model), arg_overrides=overrides) | |||
del model | |||
# Initialize transform | |||
from torchvision import transforms | |||
mean = [0.5, 0.5, 0.5] | |||
std = [0.5, 0.5, 0.5] | |||
patch_image_size = 480 | |||
self.patch_resize_transform = transforms.Compose([ | |||
lambda image: image.convert('RGB'), | |||
transforms.Resize( | |||
(cfg.task.patch_image_size, cfg.task.patch_image_size), | |||
interpolation=Image.BICUBIC), | |||
transforms.Resize((patch_image_size, patch_image_size), | |||
interpolation=Image.BICUBIC), | |||
transforms.ToTensor(), | |||
transforms.Normalize(mean=mean, std=std), | |||
]) | |||
self.task = task | |||
self.bos_item = torch.LongTensor([task.src_dict.bos()]) | |||
self.eos_item = torch.LongTensor([task.src_dict.eos()]) | |||
self.pad_idx = task.src_dict.pad() | |||
@type_assert(object, (str, tuple, Image.Image)) | |||
def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]: | |||
def encode_text(text, length=None, append_bos=False, append_eos=False): | |||
s = self.task.tgt_dict.encode_line( | |||
line=self.task.bpe.encode(text), | |||
add_if_not_exist=False, | |||
append_eos=False).long() | |||
if length is not None: | |||
s = s[:length] | |||
if append_bos: | |||
s = torch.cat([self.bos_item, s]) | |||
if append_eos: | |||
s = torch.cat([s, self.eos_item]) | |||
return s | |||
if isinstance(data, Image.Image): | |||
patch_image = self.patch_resize_transform(data).unsqueeze(0) | |||
else: | |||
patch_image = self.patch_resize_transform( | |||
load_image(data)).unsqueeze(0) | |||
patch_mask = torch.tensor([True]) | |||
text = 'what does the image describe?' | |||
src_text = encode_text( | |||
text, append_bos=True, append_eos=True).unsqueeze(0) | |||
src_length = torch.LongTensor( | |||
[s.ne(self.pad_idx).long().sum() for s in src_text]) | |||
sample = { | |||
'id': np.array(['42']), | |||
'net_input': { | |||
'src_tokens': src_text, | |||
'src_lengths': src_length, | |||
'patch_images': patch_image, | |||
'patch_masks': patch_mask, | |||
} | |||
text = ' what does the image describe?' | |||
inputs = self.tokenizer([text], max_length=1024, | |||
return_tensors='pt')['input_ids'] | |||
sample = dict() | |||
sample['net_input'] = { | |||
'input_ids': inputs, | |||
'patch_images': patch_image, | |||
'patch_masks': torch.tensor([True]) | |||
} | |||
return sample | |||
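
For reference, the sample returned above is exactly the structure that OfaForImageCaptioning.forward feeds to the sequence generator. A hedged end-to-end sketch (it assumes the preprocessor is constructed with a ModelScope model id, as its docstring suggests, and that the test image used later in this PR is available locally):

preprocessor = OfaImageCaptionPreprocessor('damo/ofa_image-caption_coco_large_en')
sample = preprocessor('data/test/images/image_captioning.png')
print(sample['net_input']['input_ids'].shape)     # (1, prompt_length) token ids
print(sample['net_input']['patch_images'].shape)  # (1, 3, 480, 480) resized image
print(sample['net_input']['patch_masks'])         # tensor([True])
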
@@ -14,7 +14,7 @@ class ImageCaptionTest(unittest.TestCase): | |||
def test_run(self): | |||
img_captioning = pipeline( | |||
Tasks.image_captioning, | |||
model='damo/ofa_image-caption_coco_large_en') | |||
model='damo/ofa_image-caption_coco_distilled_en') | |||
result = img_captioning('data/test/images/image_captioning.png') | |||
print(result[OutputKeys.CAPTION]) | |||