diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 33d62084..063b4d4f 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -73,7 +73,7 @@ class Pipelines(object): asr_inference = 'asr-inference' # multi-modal tasks - image_caption = 'image-captioning' + image_captioning = 'image-captioning' multi_modal_embedding = 'multi-modal-embedding' visual_question_answering = 'visual-question-answering' text_to_image_synthesis = 'text-to-image-synthesis' diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py index a69491af..1f60878b 100644 --- a/modelscope/models/multi_modal/__init__.py +++ b/modelscope/models/multi_modal/__init__.py @@ -1,5 +1,5 @@ from .clip.clip_model import CLIPForMultiModalEmbedding -from .image_captioning_model import OfaForImageCaptioning from .imagen.imagen_model import ImagenForTextToImageSynthesis from .mplug_for_visual_question_answering import \ MPlugForVisualQuestionAnswering +from .ofa_for_image_captioning_model import OfaForImageCaptioning diff --git a/modelscope/models/multi_modal/ofa/__init__.py b/modelscope/models/multi_modal/ofa/__init__.py new file mode 100644 index 00000000..433e8266 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/__init__.py @@ -0,0 +1,2 @@ +from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel +from .tokenization_ofa import OFATokenizer diff --git a/modelscope/models/multi_modal/ofa/configuration_ofa.py b/modelscope/models/multi_modal/ofa/configuration_ofa.py new file mode 100644 index 00000000..4d28dcc5 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/configuration_ofa.py @@ -0,0 +1,194 @@ +# Copyright 2022 Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" OFA model configuration""" +import warnings + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +OFA_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'ofa-medium': 'https://huggingface.co/ofa-base/resolve/main/config.json', + # OFA models are implemeted to be compatible with both huggingface + # and modelscope frameworks. For all OFA models available on huggingface, + # please refer to https://huggingface.co/models?filter=ofa +} + + +class OFAConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~OFAModel`]. It is used to instantiate an OFA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the OFA [ofa-base](https://huggingface.co/ofa-base) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the OFA model. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~OFAModel`] or [`~TFOFAModel`]. + d_model (`int`, *optional*, defaults to 1024): + Dimension of the layers and the pooler layer. + encoder_layers (`int`, *optional*, defaults to 12): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 12): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier. + max_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + encoder_layerdrop: (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + decoder_layerdrop: (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). 
+ """ + + model_type = 'ofa' + keys_to_ignore_at_inference = ['past_key_values'] + + attribute_map = { + 'num_attention_heads': 'encoder_attention_heads', + 'hidden_size': 'd_model' + } + + def __init__(self, + vocab_size=59457, + max_position_embeddings=1024, + encoder_layers=4, + encoder_ffn_dim=512 * 4, + encoder_attention_heads=8, + decoder_layers=4, + decoder_ffn_dim=512 * 4, + decoder_attention_heads=8, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function='gelu', + d_model=512, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + pad_token_id=1, + bos_token_id=0, + decoder_start_token_id=0, + eos_token_id=2, + forced_eos_token_id=2, + encoder_normalize_before=True, + decoder_normalize_before=True, + normformer=True, + encoder_drop_path_rate=0.0, + decoder_drop_path_rate=0.0, + layernorm_embedding=True, + patch_layernorm_embedding=True, + resnet_type='resnet101', + resnet_model_path=None, + resnet_drop_path_rate=0.0, + token_bucket_size=256, + image_bucket_size=42, + add_type_embedding=True, + share_decoder_input_output_embed=True, + attn_scale_factor=2., + code_layernorm_embedding=True, + code_image_size=128, + entangle_position_embedding=False, + **kwargs): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.encoder_normalize_before = encoder_normalize_before + self.decoder_normalize_before = decoder_normalize_before + self.normformer = normformer + self.encoder_drop_path_rate = encoder_drop_path_rate + self.decoder_drop_path_rate = decoder_drop_path_rate + self.layernorm_embedding = layernorm_embedding + self.patch_layernorm_embedding = patch_layernorm_embedding + self.resnet_type = resnet_type + self.resnet_model_path = resnet_model_path + self.resnet_drop_path_rate = resnet_drop_path_rate + self.token_bucket_size = token_bucket_size + self.image_bucket_size = image_bucket_size + self.add_type_embedding = add_type_embedding + self.share_decoder_input_output_embed = share_decoder_input_output_embed + self.attn_scale_factor = attn_scale_factor + self.code_layernorm_embedding = code_layernorm_embedding + self.code_image_size = code_image_size + self.entangle_position_embedding = entangle_position_embedding + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, + **kwargs, + ) + + # ensure backward compatibility for BART CNN models + if self.forced_bos_token_id is None and kwargs.get( + 'force_bos_token_to_be_generated', False): + 
self.forced_bos_token_id = self.bos_token_id + warnings.warn( + f'Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. ' + 'The config can simply be saved and uploaded again to be fixed.' + ) diff --git a/modelscope/models/multi_modal/ofa/generate/__init__.py b/modelscope/models/multi_modal/ofa/generate/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/ofa/generate/incremental_decoding_utils.py b/modelscope/models/multi_modal/ofa/generate/incremental_decoding_utils.py new file mode 100644 index 00000000..db0df9b2 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/incremental_decoding_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license which can be found at +# https://github.com/facebookresearch/fairseq/blob/main/LICENSE + +import uuid +from typing import Dict, Optional + +from torch import Tensor + + +class FairseqIncrementalState(object): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_incremental_state() + + def init_incremental_state(self): + self._incremental_state_id = str(uuid.uuid4()) + + def _get_full_incremental_state_key(self, key: str) -> str: + return '{}.{}'.format(self._incremental_state_id, key) + + def get_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + ) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + full_key = self._get_full_incremental_state_key(key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + def set_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], + ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = self._get_full_incremental_state_key(key) + incremental_state[full_key] = value + return incremental_state + + +def with_incremental_state(cls): + cls.__bases__ = (FairseqIncrementalState, ) + tuple( + b for b in cls.__bases__ if b != FairseqIncrementalState) + return cls diff --git a/modelscope/models/multi_modal/ofa/generate/multihead_attention.py b/modelscope/models/multi_modal/ofa/generate/multihead_attention.py new file mode 100644 index 00000000..9101d52d --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/multihead_attention.py @@ -0,0 +1,510 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license which can be found at +# https://github.com/facebookresearch/fairseq/blob/main/LICENSE + +import math +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise +from torch import Tensor, nn +from torch.nn import Parameter + + +@with_incremental_state +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. 
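+
+    A minimal self-attention usage sketch (sizes are illustrative; inputs follow the
+    fairseq ``Time x Batch x Channel`` convention):
+
+        >>> import torch
+        >>> mha = MultiheadAttention(embed_dim=512, num_heads=8, self_attention=True)
+        >>> x = torch.randn(10, 2, 512)  # (tgt_len, batch, embed_dim)
+        >>> out, attn_weights = mha(query=x, key=x, value=x)  # out: (tgt_len, batch, embed_dim)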
+ """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__) + + self.head_dim = embed_dim // num_heads + assert (self.head_dim * num_heads == self.embed_dim + ), 'embed_dim must be divisible by num_heads' + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + 'Self-attention requires query, key and ' + 'value to be of the same size') + + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.onnx_trace = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, + Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. 
+ need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == 'xla' + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim, f'query dim {embed_dim} != {self.embed_dim}' + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if (not self.onnx_trace + and not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting()): + assert key is not None and value is not None + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat( + (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and 'prev_key' in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, + attn_mask.new_zeros(attn_mask.size(0), 1)], + dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros( + key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous().view(tgt_len, bsz * self.num_heads, + self.head_dim).transpose(0, 1)) + if k is not None: + k = ( + k.contiguous().view(-1, bsz * self.num_heads, + self.head_dim).transpose(0, 1)) + if v is not None: + v = ( + v.contiguous().view(-1, bsz * self.num_heads, + self.head_dim).transpose(0, 1)) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if 'prev_key' in saved_state: + _prev_key = saved_state['prev_key'] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, + self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + 
src_len = k.size(1) + if 'prev_value' in saved_state: + _prev_value = saved_state['prev_value'] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, + self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if 'prev_key_padding_mask' in saved_state: + prev_key_padding_mask = saved_state['prev_key_padding_mask'] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, + self.head_dim) + saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, + self.head_dim) + saved_state['prev_key_padding_mask'] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, + saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], + dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], + dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, + attn_mask.new_zeros(attn_mask.size(0), 1)], + dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), + 1).type_as(key_padding_mask), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, + bsz) + + assert list( + attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float('-inf'), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill( + key_padding_mask, float('-inf')) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = utils.softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list( + attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, 
embed_dim) + else: + attn = attn.transpose(0, + 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, + tgt_len, + src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), + key_padding_mask.float()], + dim=1) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), + filler.float()], dim=1) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention and input_buffer_k.size( + 0) == new_order.size(0): + break + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, + input_buffer) + return incremental_state + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, + Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, 'attn_state') + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, 'attn_state', + buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, + bsz: int): + return attn_weights + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + '.' 
if name != '' else '' + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + 'in_proj_weight'): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + 'q_proj.weight'] = state_dict[k][:dim] + items_to_add[prefix + 'k_proj.weight'] = state_dict[k][dim:2 + * dim] + items_to_add[prefix + 'v_proj.weight'] = state_dict[k][2 + * dim:] + + keys_to_remove.append(k) + + k_bias = prefix + 'in_proj_bias' + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + + 'q_proj.bias'] = state_dict[k_bias][:dim] + items_to_add[prefix + + 'k_proj.bias'] = state_dict[k_bias][dim:2 + * dim] + items_to_add[prefix + + 'v_proj.bias'] = state_dict[k_bias][2 + * dim:] + + keys_to_remove.append(prefix + 'in_proj_bias') + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value diff --git a/modelscope/models/multi_modal/ofa/generate/ngram_repeat_block.py b/modelscope/models/multi_modal/ofa/generate/ngram_repeat_block.py new file mode 100644 index 00000000..4bccfa76 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/ngram_repeat_block.py @@ -0,0 +1,155 @@ +# Originally from Microsoft Corporation. +# Licensed under the MIT License. +""" Wrapper for ngram_repeat_block cuda extension """ +import math +import warnings +from typing import Dict, List + +import torch +from torch import nn + +try: + from fairseq import ngram_repeat_block_cuda + + EXTENSION_BUILT = True +except ImportError: + EXTENSION_BUILT = False + + +def is_cuda_extension_usable() -> bool: + """Check whether ngram_repeat_block_cuda is built properly""" + if not EXTENSION_BUILT or not torch.cuda.is_available(): + return False + bsz = 2 + tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], + dtype=torch.long, + device='cuda') + lprobs = torch.rand((8, 12), device='cuda') + try: + outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3) + outputs = outputs + 4 # This line breaks if the extension is built incorrectly. + return True + except RuntimeError: + warnings.warn( + 'NGramRepeatBlock extension must be rebuilt.' 
+ 'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace' + ) + return False + + +class NGramRepeatBlock(nn.Module): + """ Wrapper class for calling ngram_repeat_block cuda extension """ + + def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True): + super().__init__() + self.use_extension = is_cuda_extension_usable( + ) if use_extension else False + self.no_repeat_ngram_size = no_repeat_ngram_size + + def reset_parameters(self): + pass + + @torch.jit.unused + def call_cuda_extension( + self, + tokens, + lprobs, + bsz: int, + beam_size: int, + step: int, + ): + return ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, step, + beam_size, + self.no_repeat_ngram_size) + + def forward( + self, + tokens, + lprobs, + bsz: int, + beam_size: int, + step: int, + ): + """ + Args: + tokens(Tensor): Input tokens(Bsz*beam, seq_len) + lprobs(Tensor): likelihood probability, + Expected to be updated in place.(Bsz*beam, vocab_size) + bsz(int): batch size + step(int): current step + beam_size(int): beam size + no_repeat_ngram_size(int): Ngram size + """ + msg = f'expected {bsz * beam_size} got' + assert tokens.size(0) == bsz * beam_size, f'{msg} {tokens.size(0)}' + assert lprobs.size(0) == bsz * beam_size, f'{msg} {lprobs.size(0)}' + if self.use_extension: + return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, + step) + + else: + return self._no_repeat_ngram( + tokens, + lprobs, + bsz, + beam_size, + step, + ) + + def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, + step: int): + """For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf""" + gen_ngrams: List[Dict[str, List[int]]] = [ + torch.jit.annotate(Dict[str, List[int]], {}) + for bbsz_idx in range(bsz * beam_size) + ] + cpu_tokens = tokens.cpu() + for bbsz_idx in range(bsz * beam_size): + gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist() + for ngram in self.transpose_list([ + gen_tokens[i:] for i in range(self.no_repeat_ngram_size) + ]): # noqa + key = ','.join([str(x) for x in ngram[:-1]]) + gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get( + key, torch.jit.annotate(List[int], [])) + [ngram[-1]] + if step + 2 - self.no_repeat_ngram_size >= 0: + # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + banned_tokens = [ + self.calculate_banned_tokens(tokens, step, gen_ngrams, + self.no_repeat_ngram_size, + bbsz_idx) + for bbsz_idx in range(bsz * beam_size) + ] + else: + banned_tokens = [ + torch.jit.annotate(List[int], []) + for bbsz_idx in range(bsz * beam_size) + ] + for bbsz_idx in range(bsz * beam_size): + lprobs[bbsz_idx][torch.tensor( + banned_tokens[bbsz_idx], + dtype=torch.int64)] = torch.tensor(-math.inf).to(lprobs) + return lprobs + + @staticmethod + def calculate_banned_tokens( + tokens, + step: int, + gen_ngrams: List[Dict[str, List[int]]], + no_repeat_ngram_size: int, + bbsz_idx: int, + ): + tokens_list: List[int] = tokens[bbsz_idx, + step + 2 - no_repeat_ngram_size:step + + 1].tolist() # noqa + # before decoding the next token, prevent decoding of ngrams that have already appeared + ngram_index = ','.join([str(x) for x in tokens_list]) + return gen_ngrams[bbsz_idx].get(ngram_index, + torch.jit.annotate(List[int], [])) + + @staticmethod + def transpose_list(l: List[List[int]]): # noqa + # GeneratorExp aren't supported in TS so ignoring the lint + min_len = min([len(x) for x in l]) # noqa + l2 = [[row[i] for row in l] for i in range(min_len)] + return l2 diff --git 
a/modelscope/models/multi_modal/ofa/generate/search.py b/modelscope/models/multi_modal/ofa/generate/search.py new file mode 100644 index 00000000..63ecb0a9 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/search.py @@ -0,0 +1,848 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license which can be found at +# https://github.com/facebookresearch/fairseq/blob/main/LICENSE + +import math +from typing import List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from .token_generation_constraints import (ConstraintState, + OrderedConstraintState, + UnorderedConstraintState) + + +class Search(nn.Module): + + def __init__(self, tokenizer): + super().__init__() + self.pad = tokenizer.pad_token_id + self.unk = tokenizer.unk_token_id + self.eos = tokenizer.eos_token_id + tgt_dict = {value: key for key, value in tokenizer.get_vocab().items()} + added = { + value: key + for key, value in tokenizer.get_added_vocab().items() + } + tgt_dict.update(added) + self.vocab_size = len(tgt_dict) + self.src_lengths = torch.tensor(-1) + self.supports_constraints = False + self.stop_on_max_len = False + + def step(self, + step, + lprobs, + scores, + prev_output_tokens=None, + original_batch_idxs=None): + """Take a single search step. + + Args: + step: the current search step, starting at 0 + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + scores: (bsz x input_beam_size x step) + the historical model scores of each hypothesis up to this point + prev_output_tokens: (bsz x step) + the previously generated oputput tokens + original_batch_idxs: (bsz) + the tensor with the batch indices, in the range [0, bsz) + this is useful in case there has been applied a re-ordering + and we need to know the orignal indices + + Return: A tuple of (scores, indices, beams) where: + scores: (bsz x output_beam_size) + the scores of the chosen elements; output_beam_size can be + larger than input_beam_size, e.g., we may return + 2*input_beam_size to account for EOS + indices: (bsz x output_beam_size) + the indices of the chosen elements + beams: (bsz x output_beam_size) + the hypothesis ids of the chosen elements, in the range [0, input_beam_size) + """ + raise NotImplementedError + + @torch.jit.export + def set_src_lengths(self, src_lengths): + self.src_lengths = src_lengths + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], + beam_size: int): + """Initialize constraint states for constrained decoding (if supported). + + Args: + batch_constraints: (torch.Tensor, optional) + the list of constraints, in packed form + beam_size: (int) + the beam size + Returns: + *encoder_out* rearranged according to *new_order* + """ + pass + + def prune_sentences(self, batch_idxs: Tensor): + """ + Removes constraint states for completed sentences (if supported). + This is called from sequence_generator._generate() when sentences are + deleted from the batch. + + Args: + batch_idxs: Indices of *sentences* whose constraint state should be *kept*. + """ + pass + + def update_constraints(self, active_hypos: Tensor): + """ + Updates the constraint states by selecting the beam items that are retained. + This is called at each time step of sequence_generator._generate() when + the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size. 
+ + Args: + active_hypos: (batch size, beam size) + list of integers denoting, for each sentence, which beam candidate items + should be kept. + """ + pass + + +class BeamSearch(Search): + + def __init__(self, tgt_dict): + super().__init__(tgt_dict) + self.constraint_states = None + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size * 2, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + # Project back into relative indices and beams + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + + # At this point, beams_buf and indices_buf are single-dim and contain relative indices + return scores_buf, indices_buf, beams_buf + + +class PrefixConstrainedBeamSearch(Search): + + def __init__(self, tgt_dict, prefix_allowed_tokens_fn): + super().__init__(tgt_dict) + self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn + self.stop_on_max_len = True + + @torch.jit.export + def apply_mask(self, x, prev_output_tokens, original_batch_idxs): + beam_size = x.shape[0] // original_batch_idxs.shape[0] + original_batch_idxs = ( + original_batch_idxs.unsqueeze(-1).repeat( + (1, beam_size)).flatten().tolist()) + + mask = torch.full_like(x, -math.inf) + for sent_i, (sent, batch_i) in enumerate( + zip(prev_output_tokens, original_batch_idxs)): + mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0 + + return mask + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Tensor, + prev_output_tokens: Tensor, + original_batch_idxs: Tensor, + ): + bsz, beam_size, vocab_size = lprobs.size() + + lprobs += self.apply_mask( + lprobs.view(bsz * beam_size, 1, vocab_size), + prev_output_tokens, + original_batch_idxs, + ).view(bsz, beam_size, vocab_size) + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + return scores_buf, indices_buf, beams_buf + + +class LexicallyConstrainedBeamSearch(Search): + """Implements lexically constrained beam search as described in + + Fast Lexically Constrained Decoding with Dynamic Beam + Allocation for Neural Machine Translation. Post & Vilar, + NAACL 2018. 
https://www.aclweb.org/anthology/N18-1119/ + + and + + Improved Lexically Constrained Decoding for Translation and + Monolingual Rewriting. Hu et al, NAACL + 2019. https://www.aclweb.org/anthology/N19-1090/ + + This is accomplished by maintaining, for each beam hypothesis, a + ConstraintState object (see constraints.py) that tracks which + constraints have been generated and using this information to + shape the beam for each input sentence. + """ + + def __init__(self, tokenizer, representation): + super().__init__(tokenizer) + self.representation = representation + tgt_dict = {value: key for key, value in tokenizer.get_vocab().items()} + added = { + value: key + for key, value in tokenizer.get_added_vocab().items() + } + tgt_dict.update(added) + self.vocab_size = len(tgt_dict) + self.num_cands = 0 + self.supports_constraints = True + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], + beam_size: int): + self.constraint_states = [] + for constraint_tensor in batch_constraints: + if self.representation == 'ordered': + constraint_state = OrderedConstraintState.create( + constraint_tensor) + elif self.representation == 'unordered': + constraint_state = UnorderedConstraintState.create( + constraint_tensor) + + self.constraint_states.append( + [constraint_state for i in range(beam_size)]) + + @torch.jit.export + def prune_sentences(self, batch_idxs: Tensor): + self.constraint_states = [ + self.constraint_states[i] for i in batch_idxs.tolist() + ] + + @torch.jit.export + def update_constraints(self, active_hypos: Tensor): + if self.constraint_states: + batch_size = active_hypos.size(0) + for sentid in range(batch_size): + self.constraint_states[sentid] = [ + self.constraint_states[sentid][i] + for i in active_hypos[sentid] + ] + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + """ + A constrained step builds a large candidates list from the following: + - the top 2 * {beam_size} items over the whole beam + - for each item in the beam + - the top {each_k} (default 1) + - all next constraints + We then compute the constrained state of each beam item, and assign + stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so + on. We then sort by (stripe, score), and truncate the list at + 2 * beam size. + + Args: + step: the decoder step + lprobs: (batch size, beam size, target vocab) + the target-vocab distributions for each item in the beam. + Retrun: A tuple of (scores, indices, beams, constraints) where: + scores: (batch, output beam size) + the scores of the chosen elements + indices: (batch, output beam size) + the target vocab indices of the chosen elements + beams: (batch, output beam size) + the 0-indexed hypothesis ids of the chosen elements + constraints: (batch, output beam size) + the new constraint states + """ + each_k = 1 + device = lprobs.device + + batch_size, beam_size, vocab_size = lprobs.size() + + self.num_cands = min( + # Just take the k-best. We'll get another k from the 1-best from each + # row, plus more from the constraints + beam_size * 2, + lprobs.view(batch_size, -1).size(1) + - 1, # -1 so we never select pad + ) + + # STEP 0: Preliminary. 
Prevent EOS for unfinished hyps across all batch items + constraint_states = self.constraint_states + if constraint_states and step > 0: + not_finished_indices = [] + for sentno, sent_constraints in enumerate(constraint_states): + for beamno, state in enumerate(sent_constraints): + index = sentno * beam_size + beamno + if not state.finished: + not_finished_indices.append(index) + not_finished_indices = torch.tensor(not_finished_indices) + if not_finished_indices.numel() > 0: + lprobs.view(batch_size * beam_size, -1)[not_finished_indices, + self.eos] = -math.inf + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam entry for each batch item + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(batch_size, -1), + self.num_cands, + ) + scores_buf, indices_buf = top_prediction + # Project back into relative indices and beams + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + + # Short circuit if there are no constraints in this batch + if not constraint_states: + return scores_buf, indices_buf, beams_buf + + # STEP 1: get top-1 from each hypothesis across all sentences in the batch + if step > 0: + top_scores, top_indices = torch.topk( + lprobs.view(batch_size * beam_size, -1), + k=each_k, + dim=1, + ) + top_scores = top_scores.view(batch_size, -1) + top_indices = top_indices.view(batch_size, -1) + scores_buf = torch.cat((scores_buf, top_scores), dim=1) + indices_buf = torch.cat((indices_buf, top_indices), dim=1) + new_beams = torch.arange( + 0, beam_size, device=device).repeat(batch_size, 1) + beams_buf = torch.cat((beams_buf, new_beams), dim=1) + + # Now, process sentences in the batch one by one. + new_scores_buf = torch.zeros((batch_size, 2 * beam_size), + device=device) + new_indices_buf = torch.zeros((batch_size, 2 * beam_size), + device=device).long() + new_beams_buf = torch.zeros((batch_size, 2 * beam_size), + device=device).long() + for sentno, states in enumerate(constraint_states): + scores, indices, beams, new_states = self.step_sentence( + step, + sentno, + lprobs[sentno], + constraint_states[sentno], + beams_buf[sentno].clone(), + indices_buf[sentno].clone(), + scores_buf[sentno].clone(), + ) + new_scores_buf[sentno] = scores + new_indices_buf[sentno] = indices + new_beams_buf[sentno] = beams + self.constraint_states[sentno] = new_states + + return new_scores_buf, new_indices_buf, new_beams_buf + + @torch.jit.export + def step_sentence( + self, + step: int, + sentno: int, + lprobs: Tensor, + constraint_states: List[List[ConstraintState]], + beams_buf: Tensor, + indices_buf: Tensor, + scores_buf: Tensor, + ): + """Does per-sentence processing. Adds all constraints for each + hypothesis to the list of candidates; then removes duplicates, + sorts, and dynamically stripes across the banks. All tensor inputs + are collapsed to those pertaining to a single input sentence. 
+ """ + device = lprobs.device + + # STEP 2: Add all constraints for each beam item + for beamno, state in enumerate(constraint_states): + next_tokens = torch.tensor( + list(state.next_tokens()), device=device).long() + if next_tokens.numel() != 0: + indices_buf = torch.cat((indices_buf, next_tokens)) + next_beams = ( + torch.tensor(beamno, device=device).repeat( + next_tokens.size(0)).long()) + beams_buf = torch.cat((beams_buf, next_beams)) + next_values = lprobs[beamno].take(next_tokens.view(-1)) + scores_buf = torch.cat((scores_buf, next_values)) + + # At the 0th time step, there is just one beam item + if step == 0: + break + + # STEP 3: Compute the "bank" for each candidate. This is the + # number of constraints it's generated. We need this so that + # we can do round-robin allocation of the beam across these + # banks. If C is the number of constraints, we select the best + # item in bank C, then the best in bank C-1, etc, followed by + # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so + # on, until the maximum beam size. We accomplish this by + # creating a sort key and striping across the banks. + + # Compute the new states for all candidates + cands_size = indices_buf.size(0) + constraint_states = [ + constraint_states[beams_buf[i]].advance(indices_buf[i]) + for i in range(cands_size) + ] + + banks = torch.tensor([state.bank for state in constraint_states], + device=device) + + # STEP 4: Sort + num_constraint_tokens = len(state.tokens) + + # Sort by keys (bank, score) (i.e., sort banks together, and scores + # within banks). AFAIK pytorch doesn't support either stable sort or + # multi-key sorting, so we have to hack this. + MAX_SCORE = -100 + sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf + sort_values, sort_indices = sort_key.sort(dim=0, descending=True) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + banks = banks[sort_indices] + + # Sort the constraints to follow suit + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 5: Remove duplicates. The topk calls (overall and + # per-row) plus the per-row generation of constraints will + # produce duplicates. Here we remove them. + + def roll(t): + """Rolls a 1d tensor left by 1. + + [0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3] + """ + return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0) + + # We map candidates (beam, token_id) to a single dimension. + # This is then shifted by 1. We can then easily identify + # duplicates and create a mask that identifies unique + # extensions. + uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf + uniques_mask = roll(uniques_mask) != uniques_mask + + # Use the mask to pare down the data structures + scores_buf = torch.masked_select(scores_buf, uniques_mask) + indices_buf = torch.masked_select(indices_buf, uniques_mask) + beams_buf = torch.masked_select(beams_buf, uniques_mask) + banks = torch.masked_select(banks, uniques_mask) + i = 1 + for mask in uniques_mask[1:]: + if not mask: + constraint_states.pop(i) + i += mask + + # STEP 6: Assign IDs round-robin across banks, sort, and + # truncate. Now that the candidates are sorted by (bank, + # score) and uniqed, we dynamically allocate the {beam_size} + # beam by striping across the candidates. These stripes will + # be used as sort keys to do round-robin selection. This is + # accomplished in a single pass with offsets. 
Sorting by + # highest-banks (furthest-along hypotheses) first ensures + # progress through the constraints. + # + # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0 + # OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1 + # NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7 + # = 0 5 10 1 6 11 13 2 7 12 3 8 + # + # Sorting by this then gives the following banks: + # + # 3 2 1 0 3 2 1 0 3 2 1 2 + # + # We'll take the top {beam_size} of these. + stripe_offsets = [ + offset * (len(banks) + 1) for offset in range(len(banks) + 1) + ] + stripes = torch.zeros_like(banks) + cur_bank_count = -1 + cur_bank = banks[0] + for i, bank in enumerate(banks): + if bank != cur_bank: + cur_bank_count = 0 + cur_bank = bank + else: + cur_bank_count += 1 + stripes[i] = num_constraint_tokens - bank + stripe_offsets[ + cur_bank_count] + + # STEP 7: Sort by the stripes values + sort_values, sort_indices = stripes.sort(dim=0) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 8: Truncate to the candidates size! + scores_buf = scores_buf[:self.num_cands] + indices_buf = indices_buf[:self.num_cands] + beams_buf = beams_buf[:self.num_cands] + + return scores_buf, indices_buf, beams_buf, constraint_states + + +class LengthConstrainedBeamSearch(Search): + + def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b): + super().__init__(tgt_dict) + self.min_len_a = min_len_a + self.min_len_b = min_len_b + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.beam = BeamSearch(tgt_dict) + self.needs_src_lengths = True + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + min_lens = self.min_len_a * self.src_lengths + self.min_len_b + max_lens = self.max_len_a * self.src_lengths + self.max_len_b + lprobs[step < min_lens, :, self.eos] = -math.inf + lprobs[step >= max_lens, :, self.eos] = 0 + return self.beam.step(step, lprobs, scores) + + +class DiverseBeamSearch(Search): + """Diverse Beam Search. + + See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence + Models" for details. + + We only implement the Hamming Diversity penalty here, which performed best + in the original paper. 
+ """ + + def __init__(self, tgt_dict, num_groups, diversity_strength): + super().__init__(tgt_dict) + self.num_groups = num_groups + self.diversity_strength = -diversity_strength + self.beam = BeamSearch(tgt_dict) + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + if beam_size % self.num_groups != 0: + raise ValueError( + 'DiverseBeamSearch requires --beam to be divisible by the number of groups' + ) + + # initialize diversity penalty + diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs) + + scores_G, indices_G, beams_G = [], [], [] + for g in range(self.num_groups): + lprobs_g = lprobs[:, g::self.num_groups, :] + scores_g = scores[:, g::self.num_groups, :] if step > 0 else None + + # apply diversity penalty + if g > 0: + lprobs_g = torch.add( + lprobs_g, + other=diversity_buf.unsqueeze(1), + alpha=self.diversity_strength, + ) + else: + lprobs_g = lprobs_g.contiguous() + + scores_buf, indices_buf, beams_buf = self.beam.step( + step, lprobs_g, scores_g) + beams_buf.mul_(self.num_groups).add_(g) + + scores_G.append(scores_buf.clone()) + indices_G.append(indices_buf.clone()) + beams_G.append(beams_buf.clone()) + + # update diversity penalty + diversity_buf.scatter_add_( + 1, indices_buf, + torch.ones(indices_buf.size()).to(diversity_buf)) + + # interleave results from different groups + scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1) + indices_buf = torch.stack(indices_G, dim=2).view(bsz, -1) + beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1) + return scores_buf, indices_buf, beams_buf + + +class Sampling(Search): + sampling_topk: int + sampling_topp: float + + def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0): + super().__init__(tgt_dict) + self.sampling_topk = sampling_topk + self.sampling_topp = sampling_topp + + def _sample_topp(self, lprobs): + """Sample among the smallest set of elements whose cumulative probability mass exceeds p. + + See `"The Curious Case of Neural Text Degeneration" + (Holtzman et al., 2019) `_. + + Args: + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + + Return: A tuple of (trimed_probs, truncated_indices) where: + trimed_probs: (bsz x input_beam_size x ?) + the model's probabilities over the elements selected to sample from. The + width of the third dimension is determined by top-P. + truncated_indices: (bsz x input_beam_size x ?) + the indices of the chosen elements. + """ + probs = lprobs.exp_() + + # sort the last dimension (vocab dimension) in descending order + sorted_probs, sorted_indices = probs.sort(descending=True) + + # compute a mask to indicate the words to be included in the top-P set. + cumsum_probs = sorted_probs.cumsum(dim=2) + mask = cumsum_probs.lt(self.sampling_topp) + + # note that mask was computed by 'lt'. One more word needs to be included + # so that the cumulative probability mass can exceed p. + cumsum_mask = mask.cumsum(dim=2) + last_included = cumsum_mask[:, :, -1:] + last_included.clamp_(0, mask.size()[2] - 1) + mask = mask.scatter_(2, last_included, 1) + + # truncate unnecessary dims. 
+ max_dim = last_included.max() + truncated_mask = mask[:, :, :max_dim + 1] + truncated_probs = sorted_probs[:, :, :max_dim + 1] + truncated_indices = sorted_indices[:, :, :max_dim + 1] + + # trim the words that are not in top-P by setting their probabilities + # to 0, so that they would not be sampled later. + trim_mask = ~truncated_mask + trimed_probs = truncated_probs.masked_fill_(trim_mask, 0) + return trimed_probs, truncated_indices + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + + if self.sampling_topp > 0: + # only sample from the smallest set of words whose cumulative probability mass exceeds p + probs, top_indices = self._sample_topp(lprobs) + elif self.sampling_topk > 0: + # only sample from top-k candidates + lprobs, top_indices = lprobs.topk(self.sampling_topk) + probs = lprobs.exp_() + else: + probs = lprobs.exp_() + + # dummy data to be consistent with true branch for type check + top_indices = torch.empty(0).to(probs) + # sample + if step == 0: + indices_buf = torch.multinomial( + probs.view(bsz, -1), + beam_size, + replacement=True, + ).view(bsz, beam_size) + else: + indices_buf = torch.multinomial( + probs.view(bsz * beam_size, -1), + 1, + replacement=True, + ).view(bsz, beam_size) + + if step == 0: + # expand to beam size + probs = probs.expand(bsz, beam_size, -1) + + # gather scores + scores_buf = torch.gather( + probs, dim=2, index=indices_buf.unsqueeze(-1)) + scores_buf = scores_buf.log_().view(bsz, -1) + + # remap indices if using top-k or top-P sampling + if self.sampling_topk > 0 or self.sampling_topp > 0: + indices_buf = torch.gather( + top_indices.expand(bsz, beam_size, -1), + dim=2, + index=indices_buf.unsqueeze(-1), + ).squeeze(2) + + if step == 0: + beams_buf = indices_buf.new_zeros(bsz, beam_size) + else: + beams_buf = torch.arange(0, + beam_size).to(indices_buf).repeat(bsz, 1) + # make scores cumulative + scores_buf.add_( + torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf)) + + return scores_buf, indices_buf, beams_buf + + +class DiverseSiblingsSearch(Search): + """ + Beam search with diverse siblings. + + See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details. + https://arxiv.org/abs/1611.08562 + + 1/ Calculate hypotheses for each beam + 2/ Intra-sibling ordering + 3/ Rewrite scores + 4/ Choose top K hypotheses + + if diversity_rate == 0 is equivalent to BeamSearch + """ + + def __init__(self, tgt_dict, diversity_rate): + super().__init__(tgt_dict) + self.diversity_rate = diversity_rate + self.beam = BeamSearch(tgt_dict) + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + k = min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. 
+ beam_size * 2, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ) + s_list: List[Tensor] + i_list: List[Tensor] + s_list = [torch.empty(0).to(lprobs) for i in range(beam_size)] + i_list = [ + torch.LongTensor().to(device=lprobs.device) + for i in range(beam_size) + ] + sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate + + if step == 0: + return self.beam.step(step, lprobs, scores) + lprobs.add_(scores[:, :, step - 1].unsqueeze(-1)) + + # 1/ Calculate hypotheses for each beam + for i in range(beam_size): + torch.topk( + lprobs[:, i, :].view(bsz, -1), k, out=(s_list[i], i_list[i])) + i_list[i].fmod_(vocab_size) + + # 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores + s_list[i].sub_(sibling_score) + + # 4/ Choose top K hypotheses + indices = torch.stack(i_list, dim=1).view(bsz, -1) + + final_scores = torch.empty(0).to(lprobs) + final_indices = torch.LongTensor().to(device=lprobs.device) + final_beams = torch.LongTensor().to(device=lprobs.device) + (final_scores, final_indices) = torch.topk( + torch.stack(s_list, dim=1).view(bsz, -1), + k, + ) + + final_beams = final_indices // k + + for i in range(bsz): + final_indices[i] = indices[i][final_indices[i]] + + return final_scores, final_indices, final_beams diff --git a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py new file mode 100644 index 00000000..d592f2eb --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py @@ -0,0 +1,996 @@ +# Copyright 2022 The OFA-Sys Team. +# All rights reserved. +# This source code is licensed under the Apache 2.0 license +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 + +import math +import sys +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor + +from ..generate import search +from .ngram_repeat_block import NGramRepeatBlock + + +def _expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + r""" + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + return expanded_mask.masked_fill(expanded_mask.bool(), + torch.finfo(dtype).min) + + +class SequenceGenerator(nn.Module): + + def __init__(self, + tokenizer, + beam_size=1, + max_len_a=0, + max_len_b=200, + max_len=0, + min_len=1, + normalize_scores=True, + len_penalty=1.0, + unk_penalty=0.0, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + search_strategy=None, + eos=None, + symbols_to_strip_from_output=None, + lm_model=None, + lm_weight=1.0, + constraint_trie=None, + constraint_range=None, + gen_code=False, + gen_box=False, + ignore_eos=False, + zero_shot=False): + """Generates translations of a given source sentence. 
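+
+        A minimal construction sketch (``tokenizer``, ``model`` and ``sample`` below
+        are placeholders, not fixed names)::
+
+            generator = SequenceGenerator(
+                tokenizer, beam_size=5, max_len_b=16, no_repeat_ngram_size=3)
+            hypos = generator.generate([model], sample)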
+ + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + max_len (int, optional): the maximum length of the generated output + (not including end-of-sentence) + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + len_penalty (float, optional): length penalty, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + """ + super().__init__() + self.gen_code = gen_code + self.gen_box = gen_box + self.ignore_eos = ignore_eos + self.tokenizer = tokenizer + self.tgt_dict = { + value: key + for key, value in tokenizer.get_vocab().items() + } + added = { + value: key + for key, value in tokenizer.get_added_vocab().items() + } + self.tgt_dict.update(added) + self.pad = tokenizer.pad_token_id + self.unk = tokenizer.unk_token_id + self.bos = tokenizer.bos_token_id + self.eos = tokenizer.eos_token_id + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) if + symbols_to_strip_from_output is not None else {self.bos, self.eos}) + self.vocab_size = len(self.tgt_dict) + self.beam_size = beam_size + # the max beam size is the dictionary size - 1, since we never select pad + self.beam_size = min(beam_size, self.vocab_size - 1) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.min_len = min_len + self.max_len = max_len + + self.normalize_scores = normalize_scores + self.len_penalty = len_penalty + self.unk_penalty = unk_penalty + self.temperature = temperature + self.match_source_len = match_source_len + self.zero_shot = zero_shot + + if no_repeat_ngram_size > 0: + self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + else: + self.repeat_ngram_blocker = None + + assert temperature > 0, '--temperature must be greater than 0' + + self.search = ( + search.BeamSearch(self.tokenizer) + if search_strategy is None else search_strategy) + # We only need to set src_lengths in LengthConstrainedBeamSearch. + # As a module attribute, setting it would break in multithread + # settings when the model is shared. + self.should_set_src_lengths = ( + hasattr(self.search, 'needs_src_lengths') + and self.search.needs_src_lengths) + + self.lm_model = lm_model + self.lm_weight = lm_weight + if self.lm_model is not None: + self.lm_model.eval() + + self.constraint_trie = constraint_trie + + self.constraint_start = None + self.constraint_end = None + if constraint_range is not None: + constraint_start, constraint_end = constraint_range.split(',') + self.constraint_start = int(constraint_start) + self.constraint_end = int(constraint_end) + + @torch.no_grad() + def forward( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + """Generate a batch of translations. 
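+
+        ``sample`` is expected to carry a ``net_input`` dict; a minimal sketch
+        (``input_ids`` is one of the accepted input keys, see ``_generate``)::
+
+            sample = {'net_input': {'input_ids': input_ids}}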
+ + Args: + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, prefix_tokens, bos_token=bos_token) + + @torch.no_grad() + def generate(self, models, sample: Dict[str, Dict[str, Tensor]], + **kwargs) -> List[List[Dict[str, Tensor]]]: + """Generate translations. Match the api of other fairseq generators. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(models, sample, **kwargs) + + def _generate( + self, + models, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + model = EnsembleModel(models) + # incremental_states = torch.jit.annotate( + # List[Dict[str, Dict[str, Optional[Tensor]]]], + # [ + # torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + # for i in range(model.models_size) + # ], + # ) + incremental_states = torch.jit.annotate( + List[Tuple[Tuple[torch.Tensor]]], + [ + torch.jit.annotate(Tuple[Tuple[torch.Tensor]], {}) + for i in range(model.models_size) + ], + ) + # print("incremental_states",incremental_states) + # print("incremental_states[0]",incremental_states[0]) + net_input = sample['net_input'] + + if 'src_tokens' in net_input: + src_tokens = net_input['src_tokens'] + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ((src_tokens.ne(self.eos) + & src_tokens.ne(self.pad)).long().sum(dim=1)) + elif 'input_ids' in net_input: + src_tokens = net_input['input_ids'] + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ((src_tokens.ne(self.eos) + & src_tokens.ne(self.pad)).long().sum(dim=1)) + elif 'source' in net_input: + src_tokens = net_input['source'] + src_lengths = ( + net_input['padding_mask'].size(-1) + - net_input['padding_mask'].sum(-1) + if net_input['padding_mask'] is not None else torch.tensor( + src_tokens.size(-1)).to(src_tokens)) + elif 'features' in net_input: + src_tokens = net_input['features'] + src_lengths = ( + net_input['padding_mask'].size(-1) + - net_input['padding_mask'].sum(-1) + if net_input['padding_mask'] is not None else torch.tensor( + src_tokens.size(-1)).to(src_tokens)) + else: + raise Exception( + 'expected src_tokens or source in net input. input keys: ' + + str(net_input.keys())) + + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimensions (i.e. audio features) + bsz, src_len = src_tokens.size()[:2] + beam_size = self.beam_size + + if constraints is not None and not self.search.supports_constraints: + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) + + # Initialize constraints, when active + self.search.init_constraints(constraints, beam_size) + + max_len: int = -1 + if self.match_source_len: + max_len = src_lengths.max().item() + else: + max_len = int(self.max_len_a * src_len + self.max_len_b) + assert ( + self.min_len <= max_len + ), 'min_len cannot be larger than max_len, please adjust these!' 
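+        # For example, with the constructor defaults (max_len_a=0, max_len_b=200) and
+        # match_source_len=False, max_len is simply 200 regardless of the source length;
+        # with max_len_a=1, max_len_b=10 it would be src_len + 10.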
+ # compute the encoder output for each beam + with torch.autograd.profiler.record_function( + 'EnsembleModel: forward_encoder'): + encoder_outs = model.forward_encoder(net_input) + + # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_tokens.device).long() + encoder_outs = model.reorder_encoder_out(encoder_outs, new_order) + # ensure encoder_outs is a List. + assert encoder_outs is not None + + # initialize buffers + scores = (torch.zeros(bsz * beam_size, + max_len + 1).to(src_tokens).float() + ) # +1 for eos; pad is never chosen for scoring + tokens = (torch.zeros(bsz * beam_size, + max_len + 2).to(src_tokens).long().fill_( + self.pad)) # +2 for eos and pad + # tokens[:, 0] = self.eos if bos_token is None else bos_token + tokens[:, 0] = self.bos + attn: Optional[Tensor] = None + + # A list that indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. + cands_to_ignore = (torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) + ) # forward and backward-compatible False mask + + # list of completed sentences + finalized = torch.jit.annotate( + List[List[Dict[str, Tensor]]], + [ + torch.jit.annotate(List[Dict[str, Tensor]], []) + for i in range(bsz) + ], + ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step + + # a boolean array indicating if the sentence at the index is finished or not + finished = [False for i in range(bsz)] + num_remaining_sent = bsz # number of sentences remaining + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = ((torch.arange(0, bsz) + * beam_size).unsqueeze(1).type_as(tokens).to( + src_tokens.device)) + cand_offsets = torch.arange(0, cand_size).type_as(tokens).to( + src_tokens.device) + + reorder_state: Optional[Tensor] = None + batch_idxs: Optional[Tensor] = None + + original_batch_idxs: Optional[Tensor] = None + if 'id' in sample and isinstance(sample['id'], Tensor): + original_batch_idxs = sample['id'] + else: + original_batch_idxs = torch.arange(0, bsz).type_as(tokens) + + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange( + batch_idxs.numel()).type_as(batch_idxs) + reorder_state.view(-1, beam_size).add_( + corr.unsqueeze(-1) * beam_size) + original_batch_idxs = original_batch_idxs[batch_idxs] + model.reorder_incremental_state(incremental_states, + reorder_state) # todo + encoder_outs = model.reorder_encoder_out( + encoder_outs, reorder_state) + + with torch.autograd.profiler.record_function( + 'EnsembleModel: forward_decoder'): + lprobs, avg_attn_scores = model.forward_decoder( + tokens[:, :step + 1], + encoder_outs, + incremental_states, + self.temperature, + constraint_trie=self.constraint_trie, + constraint_start=self.constraint_start, + constraint_end=self.constraint_end, + gen_code=self.gen_code, + zero_shot=self.zero_shot, + prefix_tokens=prefix_tokens) + + if self.lm_model is not None: + lm_out = 
self.lm_model(tokens[:, :step + 1]) + probs = self.lm_model.get_normalized_probs( + lm_out, log_probs=True, sample=None) + probs = probs[:, -1, :] * self.lm_weight + lprobs += probs + # handle prefix tokens (possibly with different lengths) + if (prefix_tokens is not None and step < prefix_tokens.size(1) + and step < max_len): + lprobs, tokens, scores = self._prefix_tokens( + step, lprobs, scores, tokens, prefix_tokens, beam_size) + elif step < self.min_len: + # minimum length constraint (does not apply if using prefix_tokens) + lprobs[:, self.eos] = -math.inf + + lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) + + lprobs[:, self.pad] = -math.inf # never select pad + lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + + if (self.gen_code or self.gen_box) and step < max_len: + lprobs[:, :4] = -math.inf + if self.gen_box: + lprobs[:, -1] = -math.inf + if (step + 1) % 5 == 0: + lprobs[:, self.constraint_start:59457] = -math.inf + else: + lprobs[:, 59457:] = -math.inf + + # handle max length constraint + if step >= max_len: + lprobs[:, :self.eos] = -math.inf + lprobs[:, self.eos + 1:] = -math.inf + if self.ignore_eos: + lprobs[:, self.eos] = 1 + + # Record attention scores, only support avg_attn_scores is a Tensor + if avg_attn_scores is not None: + if attn is None: + attn = torch.empty(bsz * beam_size, + avg_attn_scores.size(1), + max_len + 2).to(scores) + # print("+++++++ debug attention shape +++++++") + # print("attn", attn.shape) + # print("avg_attn_scores", avg_attn_scores.shape) + attn[:, :, step + 1].copy_(avg_attn_scores) + # print("attn[:, :, step + 1]", attn[:, :, step + 1].shape) + # print("attn", attn.shape) + + scores = scores.type_as(lprobs) + eos_bbsz_idx = torch.empty(0).to( + tokens + ) # indices of hypothesis ending with eos (finished sentences) + eos_scores = torch.empty(0).to( + scores + ) # scores of hypothesis ending with eos (finished sentences) + + if self.should_set_src_lengths: + self.search.set_src_lengths(src_lengths) + + if self.repeat_ngram_blocker is not None: + lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, + beam_size, step) + + # Shape: (batch, cand_size) + cand_scores, cand_indices, cand_beams = self.search.step( + step, + lprobs.view(bsz, -1, self.vocab_size), + scores.view(bsz, beam_size, -1)[:, :, :step], + tokens[:, :step + 1], + original_batch_idxs, + ) + + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos + # Shape of eos_mask: (batch size, beam size) + eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) + eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to( + eos_mask) + + # only consider eos when it's among the top beam_size indices + # Now we know what beam item(s) to finish + # Shape: 1d list of absolute-numbered + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]) + + finalized_sents: List[int] = [] + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.masked_select( + cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]) + + finalized_sents = self.finalize_hypos( + step, + eos_bbsz_idx, + eos_scores, + tokens, + scores, + finalized, + finished, + beam_size, + attn, + src_lengths, + max_len, + ) + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + if self.search.stop_on_max_len and step 
>= max_len:
+                break
+            assert step < max_len, f'{step} < {max_len}'
+
+            # Remove finalized sentences (ones for which {beam_size}
+            # finished hypotheses have been generated) from the batch.
+            if len(finalized_sents) > 0:
+                new_bsz = bsz - len(finalized_sents)
+
+                # construct batch_idxs which holds indices of batches to keep for the next pass
+                batch_mask = torch.ones(
+                    bsz, dtype=torch.bool, device=cand_indices.device)
+                batch_mask[finalized_sents] = False
+                # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
+                batch_idxs = torch.arange(
+                    bsz, device=cand_indices.device).masked_select(batch_mask)
+
+                # Choose the subset of the hypothesized constraints that will continue
+                self.search.prune_sentences(batch_idxs)
+
+                eos_mask = eos_mask[batch_idxs]
+                cand_beams = cand_beams[batch_idxs]
+                bbsz_offsets.resize_(new_bsz, 1)
+                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
+                cand_scores = cand_scores[batch_idxs]
+                cand_indices = cand_indices[batch_idxs]
+
+                if prefix_tokens is not None:
+                    prefix_tokens = prefix_tokens[batch_idxs]
+                src_lengths = src_lengths[batch_idxs]
+                cands_to_ignore = cands_to_ignore[batch_idxs]
+
+                scores = scores.view(bsz, -1)[batch_idxs].view(
+                    new_bsz * beam_size, -1)
+                tokens = tokens.view(bsz, -1)[batch_idxs].view(
+                    new_bsz * beam_size, -1)
+                if attn is not None:
+                    attn = attn.view(bsz, -1)[batch_idxs].view(
+                        new_bsz * beam_size, attn.size(1), -1)
+                bsz = new_bsz
+            else:
+                batch_idxs = None
+
+            # Set active_mask so that values > cand_size indicate eos hypos
+            # and values < cand_size indicate candidate active hypos.
+            # After, the min values per row are the top candidate active hypos
+
+            # Rewrite the operator since the element-wise or is not supported in torchscript.
+
+            eos_mask[:, :beam_size] = ~(  # noqa
+                (~cands_to_ignore) & (~eos_mask[:, :beam_size]))  # noqa
+            active_mask = torch.add(
+                eos_mask.type_as(cand_offsets) * cand_size,
+                cand_offsets[:eos_mask.size(1)],
+            )
+
+            # get the top beam_size active hypotheses, which are just
+            # the hypos with the smallest values in active_mask.
+            # {active_hypos} indicates which {beam_size} hypotheses
+            # from the list of {2 * beam_size} candidates were
+            # selected. Shapes: (batch size, beam size)
+            new_cands_to_ignore, active_hypos = torch.topk(
+                active_mask, k=beam_size, dim=1, largest=False)
+
+            # update cands_to_ignore to ignore any finalized hypos.
+            cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
+            # Make sure there is at least one active item for each sentence in the batch.
+            assert (~cands_to_ignore).any(dim=1).all()
+
+            # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
+            # can be selected more than once).
+ active_bbsz_idx = torch.gather( + cand_bbsz_idx, dim=1, index=active_hypos) + active_scores = torch.gather( + cand_scores, dim=1, index=active_hypos) + + active_bbsz_idx = active_bbsz_idx.view(-1) + active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + + # Set the tokens for each beam (can select the same row more than once) + tokens[:, :step + 1] = torch.index_select( + tokens[:, :step + 1], dim=0, index=active_bbsz_idx) + # Select the next token for each of them + tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( + cand_indices, dim=1, index=active_hypos) + if step > 0: + scores[:, :step] = torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx) + scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( + cand_scores, dim=1, index=active_hypos) + + # Update constraints based on which candidates were selected for the next beam + self.search.update_constraints(active_hypos) + + # copy attention for active hypotheses + if attn is not None: + attn[:, :, :step + 2] = torch.index_select( + attn[:, :, :step + 2], dim=0, index=active_bbsz_idx) + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # sort by score descending + for sent in range(len(finalized)): + scores = torch.tensor( + [float(elem['score'].item()) for elem in finalized[sent]]) + _, sorted_scores_indices = torch.sort(scores, descending=True) + finalized[sent] = [ + finalized[sent][ssi] for ssi in sorted_scores_indices + ] + finalized[sent] = torch.jit.annotate(List[Dict[str, Tensor]], + finalized[sent]) + return finalized + + def _prefix_tokens(self, step: int, lprobs, scores, tokens, prefix_tokens, + beam_size: int): + """Handle prefix tokens""" + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat( + 1, beam_size).view(-1) + prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + prefix_mask = prefix_toks.ne(self.pad) + if self.constraint_trie is None: + lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1 + else: + lprobs[prefix_mask] = -math.inf + lprobs[prefix_mask] = lprobs[prefix_mask].scatter( + -1, prefix_toks[prefix_mask].unsqueeze(-1), + prefix_lprobs[prefix_mask]) + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = tokens[eos_mask].view(-1, beam_size, + tokens.size(-1))[:, 0, + 1:step + 1] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, + beam_size) + scores = self.replicate_first_beam(scores, eos_mask_batch_dim, + beam_size) + lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, + beam_size) + return lprobs, tokens, scores + + def replicate_first_beam(self, tensor, mask, beam_size: int): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + def finalize_hypos( + self, + step: int, + bbsz_idx, + eos_scores, + tokens, + scores, + finalized: List[List[Dict[str, Tensor]]], + finished: List[bool], + beam_size: int, + attn: Optional[Tensor], + src_lengths, + max_len: int, + ): + """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. 
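+
+        Each finalized hypothesis is appended below as a dict, e.g.::
+
+            {'tokens': ..., 'score': ..., 'attention': ..., 'alignment': ..., 'positional_scores': ...}
+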
+ A sentence is finalized when {beam_size} finished items have been collected for it. + + Returns number of sentences (not beam items) being finalized. + These will be removed from the batch and not processed further. + Args: + bbsz_idx (Tensor): + """ + assert bbsz_idx.numel() == eos_scores.numel() + + # clone relevant token and attention tensors. + # tokens is (batch * beam, max_len). So the index_select + # gets the newly EOS rows, then selects cols 1..{step + 2} + tokens_clone = tokens.index_select( + 0, bbsz_idx)[:, 1:step + 2] # skip the first index, which is EOS + + tokens_clone[:, step] = self.eos + attn_clone = ( + attn.index_select(0, bbsz_idx)[:, :, 1:step + + 2] if attn is not None else None) + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1] + pos_scores[:, step] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + eos_scores /= (step + 1)**self.len_penalty + + # cum_unfin records which sentences in the batch are finished. + # It helps match indexing between (a) the original sentences + # in the batch and (b) the current, possibly-reduced set of + # sentences. + cum_unfin: List[int] = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) + + unfin_idx = bbsz_idx // beam_size + sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) + + # Create a set of "{sent}{unfin_idx}", where + # "unfin_idx" is the index in the current (possibly reduced) + # list of sentences, and "sent" is the index in the original, + # unreduced batch + # For every finished beam item + # sentence index in the current (possibly reduced) batch + seen = (sent << 32) + unfin_idx + unique_seen: List[int] = torch.unique(seen).tolist() + + if self.match_source_len: + condition = step > torch.index_select(src_lengths, 0, unfin_idx) + eos_scores = torch.where(condition, torch.tensor(-math.inf), + eos_scores) + sent_list: List[int] = sent.tolist() + for i in range(bbsz_idx.size()[0]): + # An input sentence (among those in a batch) is finished when + # beam_size hypotheses have been collected for it + if len(finalized[sent_list[i]]) < beam_size: + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i] + else: + hypo_attn = torch.empty(0) + + finalized[sent_list[i]].append({ + 'tokens': + tokens_clone[i], + 'score': + eos_scores[i], + 'attention': + hypo_attn, # src_len x tgt_len + 'alignment': + torch.empty(0), + 'positional_scores': + pos_scores[i], + }) + + newly_finished: List[int] = [] + for unique_s in unique_seen: + # check termination conditions for this sentence + unique_sent: int = unique_s >> 32 + unique_unfin_idx: int = unique_s - (unique_sent << 32) + + if not finished[unique_sent] and self.is_finished( + step, unique_unfin_idx, max_len, len( + finalized[unique_sent]), beam_size): + finished[unique_sent] = True + newly_finished.append(unique_unfin_idx) + + return newly_finished + + def is_finished( + self, + step: int, + unfin_idx: int, + max_len: int, + finalized_sent_len: int, + beam_size: int, + ): + """ + Check whether decoding for a sentence is finished, which + occurs when the list of finalized sentences has reached the + beam size, or when we reach the maximum length. 
+ """ + assert finalized_sent_len <= beam_size + if finalized_sent_len == beam_size or step == max_len: + return True + return False + + +class EnsembleModel(nn.Module): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__() + self.models_size = len(models) + # method '__len__' is not supported in ModuleList for torch script + self.single_model = models[0] + self.models = nn.ModuleList(models) + + # self.has_incremental: bool = False + # if all( + # hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) + # for m in models + # ): + # self.has_incremental = True + + self.has_incremental = True + + def forward(self): + pass + + def has_encoder(self): + return hasattr(self.single_model, 'encoder') + + def has_incremental_states(self): + return self.has_incremental + + def max_decoder_positions(self): + return min([ + m.max_decoder_positions() + for m in self.models if hasattr(m, 'max_decoder_positions') + ] + [sys.maxsize]) # + + @torch.jit.export + def forward_encoder(self, net_input: Dict[str, Tensor]): + if not self.has_encoder(): + return None + encoder_input = { + k: v + for k, v in net_input.items() if k != 'decoder_input_ids' + } + encoder_input['output_hidden_states'] = True + return [ + model.encoder.forward(**encoder_input) for model in self.models + ] + + @torch.jit.export + def forward_decoder(self, + tokens, + encoder_outs: List[Dict[str, List[Tensor]]], + incremental_states: List[Optional[torch.Tensor]], + temperature: float = 1.0, + constraint_trie=None, + constraint_start=None, + constraint_end=None, + gen_code=False, + zero_shot=False, + prefix_tokens=None): + log_probs = [] + avg_attn: Optional[Tensor] = None + encoder_out: Optional[Dict[str, List[Tensor]]] = None + code_mask = (tokens.new_ones(tokens.size(0)) * gen_code).bool() + + for i, model in enumerate(self.models): + if self.has_encoder(): + encoder_out = encoder_outs[i] + encoder_hidden_states = encoder_out.last_hidden_state + encoder_attention_mask = _expand_mask( + encoder_out.padding_mask, encoder_hidden_states.dtype, + tokens.shape[-1]) + src_pos_embed = encoder_out.position_embedding + + # if tokens.eq(self.single_model.config.pad_token_id).any(): + attention_mask = tokens.eq(self.single_model.padding_idx) + + # decode each model + if self.has_incremental_states(): + decoder_out = model.decoder.forward( # todo 模型输入不同 + input_ids=tokens, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + code_masks=code_mask, + src_pos_embed=src_pos_embed, + past_key_values=incremental_states[i], + use_cache=True, + output_attentions=True) + else: + if hasattr(model, 'decoder'): + # decoder_out = model.decoder.forward(tokens, code_masks=code_mask, encoder_out=encoder_out) + decoder_out = model.decoder.forward( # todo 模型输入不同 + input_ids=tokens, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + code_masks=code_mask, + src_pos_embed=src_pos_embed) + else: + decoder_out = model.forward(tokens) + # print('#### decoder_out ####', decoder_out) + # print('#### decoder_out ####', decoder_out.keys()) + # for k,v in decoder_out.items(): + # print(k) + # if isinstance(v, Tensor): + # print(v.shape) + # elif k == "past_key_values": + # print(len(v)) + # print([v[0][i].shape for i in range(len(v[0]))]) + # else: + # print(len(v)) + # print([v[i].shape for i in range(len(v))]) + + attn: Optional[Tensor] = None + 
decoder_len = len(decoder_out) + # if decoder_len > 1 and decoder_out[1] is not None: + # if isinstance(decoder_out[1], Tensor): + # attn = decoder_out[1] + # else: + # attn_holder = decoder_out[1]["attn"] + # if isinstance(attn_holder, Tensor): + # attn = attn_holder + # elif attn_holder is not None: + # attn = attn_holder[0] + # if attn is not None: + # attn = attn[:, -1, :] + + if 'cross_attentions' in decoder_out: + attn = decoder_out['cross_attentions'][-1].transpose(1, 0) + attn = attn.mean(dim=0) # (B, tgt_len, src_len) + if attn is not None: + attn = attn[:, -1, :] + + # decoder_out_tuple = ( + # decoder_out[0][:, -1:, :].div_(temperature), + # None if decoder_len <= 1 else decoder_out[1], + # ) + + decoder_out_tuple = ( + decoder_out[0][:, -1:, :].div_(temperature), + None if decoder_len <= 1 else attn, + ) + + beam_size = decoder_out_tuple[0].size(0) // prefix_tokens.size( + 0) if prefix_tokens is not None else 0 + if constraint_trie is not None and not zero_shot: + assert constraint_start is None and constraint_end is None + constraint_masks = decoder_out_tuple[0].new_zeros( + decoder_out_tuple[0].size()).bool() + constraint_prefix_tokens = tokens.tolist() + for token_index, constraint_prefix_token in enumerate( + constraint_prefix_tokens): + prefix_len = prefix_tokens[token_index // beam_size].ne( + 1).sum().item() if prefix_tokens is not None else 0 + if len(constraint_prefix_token) > prefix_len: + constraint_prefix_token = [ + 0 + ] + constraint_prefix_token[prefix_len + 1:] + constraint_nodes = constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_masks[token_index][:, + constraint_nodes] = True + else: + constraint_masks[token_index] = True + decoder_out_tuple[0].masked_fill_(~constraint_masks, -math.inf) + if constraint_start is not None and constraint_end is not None and not zero_shot: + assert constraint_trie is None + decoder_out_tuple[0][:, :, 4:constraint_start] = -math.inf + decoder_out_tuple[0][:, :, constraint_end:] = -math.inf + + probs = model.get_normalized_probs( + decoder_out_tuple, log_probs=True, sample=None) + if constraint_trie is not None and zero_shot: + assert constraint_start is None and constraint_end is None + constraint_masks = decoder_out_tuple[0].new_zeros( + decoder_out_tuple[0].size()).bool() + constraint_prefix_tokens = tokens.tolist() + for token_index, constraint_prefix_token in enumerate( + constraint_prefix_tokens): + constraint_nodes = constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_masks[token_index][:, constraint_nodes] = True + probs.masked_fill_(~constraint_masks, -math.inf) + if constraint_start is not None and constraint_end is not None and zero_shot: + assert constraint_trie is None + probs[:, :, 4:constraint_start] = -math.inf + probs[:, :, constraint_end:] = -math.inf + probs = probs[:, -1, :] + if self.models_size == 1: + return probs, attn + + log_probs.append(probs) + if attn is not None: + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + + avg_probs = torch.logsumexp( + torch.stack(log_probs, dim=0), dim=0) - math.log(self.models_size) + + if avg_attn is not None: + avg_attn.div_(self.models_size) + return avg_probs, avg_attn + + @torch.jit.export + def reorder_encoder_out(self, + encoder_outs: Optional[List[Dict[str, + List[Tensor]]]], + new_order): + """ + Reorder encoder output according to *new_order*. 
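+
+        In beam search the encoder runs once per sentence and its output is then
+        tiled to ``bsz * beam_size`` rows; the typical ``new_order`` for that
+        (as used in ``SequenceGenerator._generate``) is::
+
+            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)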
+ + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_outs: List[Dict[str, List[Tensor]]] = [] + if not self.has_encoder(): + return new_outs + for i, model in enumerate(self.models): + assert encoder_outs is not None + new_outs.append( + model.encoder.reorder_encoder_out(encoder_outs[i], new_order)) + return new_outs + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_states: List[Optional[torch.Tensor]], + new_order, + ): + if not self.has_incremental_states(): + return + for i, model in enumerate(self.models): + model.decoder.reorder_incremental_state_scripting( # todo + incremental_states[i], new_order) diff --git a/modelscope/models/multi_modal/ofa/generate/token_generation_constraints.py b/modelscope/models/multi_modal/ofa/generate/token_generation_constraints.py new file mode 100644 index 00000000..13fb3fcf --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/token_generation_constraints.py @@ -0,0 +1,512 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license which can be found at +# https://github.com/facebookresearch/fairseq/blob/main/LICENSE +"""Implements tracking of constraints for a beam item. + +A list of constraints is given as a list of one or more token +sequences, each of length at least one token. For example, for an input sentence + +> Die maschinelle Übersetzung ist schwer zu kontrollieren. + +We could have the constraints: +* to influence +* hard + +There are two implementations: +* OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. +* UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. + +The difference is that in the first, the constraints are assumed to be +in order; the algorithm will permit zero or more tokens between them. +In the second, the constraints are not ordered, so many orderings will +be explored. + +The same sequence can be present any number of times, and will appear +that many times in the output. +""" + +from collections import Counter +from typing import List, Set + +import torch + + +class ConstraintState: + + def __init__(self): + pass + + +def pack_constraints( + batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor: + """Takes a list of list of constraints in tensor form (a list of + tensor constraints for each sentence) and transforms it into a + packed Tensor. For example, here is a batch of size 3 with 3, 0, + and 1 constraints: + + [ [ [3 1 2], [3], [4 5 6 7], ] + [], + [ [1 8 9 10 1 4 11 12], ] + ] + + Its corresponding packed structure is: + + [ [ 3 3 1 2 0 3 0 4 5 6 7 0], + [ 0 0 0 0 0 0 0 0 0 0 0 0], + [ 1 1 8 9 10 1 4 11 12 0 0 0] ] + + The packed tensor has shape (batch size, maxlen), where + maxlen is defined below. Each row contains concatenated + constraint tokens for that sentence, with 0 appended after + each constraint. The first item in each row is the number + of constraints for that sentence. So maxlen is the maximum + of + + (number of constraints) + (sum length of constraints) + 1. + + across all sentences in the batch. 
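+
+    A usage sketch mirroring the example above::
+
+        batch_constraints = [
+            [torch.tensor([3, 1, 2]), torch.tensor([3]), torch.tensor([4, 5, 6, 7])],
+            [],
+            [torch.tensor([1, 8, 9, 10, 1, 4, 11, 12])],
+        ]
+        packed = pack_constraints(batch_constraints)  # shape (3, 12)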
+ """ + # The maximum word length of concatenated constraints for any sentence + max_constraints_len = 1 + for sentence_constraints in batch_constraints: + if len(sentence_constraints): + # number of constraints, plus sum of constrain lens, plus a zero after each + constraints_len = (1 + + sum([c.size(0) for c in sentence_constraints]) + + len(sentence_constraints)) + max_constraints_len = max(max_constraints_len, constraints_len) + + batch_size = len(batch_constraints) + constraints_tensor = torch.zeros((batch_size, max_constraints_len)).long() + for i, sentence_constraints in enumerate(batch_constraints): + constraints_tensor[i, 0] = len(sentence_constraints) + offset = 1 + for j, constraint in enumerate(sentence_constraints): + this_len = constraint.size(0) + constraints_tensor[i, offset:offset + this_len] = constraint + offset += this_len + 1 + + return constraints_tensor.long() + + +def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]: + """ + Transforms *one row* of a packed constraint tensor (e.g., for one + sentence in the batch) into a list of constraint tensors. + """ + constraint_list = [] + num_constraints = constraint_tensor[0] + constraints = constraint_tensor.tolist() + offset = 1 + for i in range(num_constraints): + where = constraints.index(0, offset) + constraint_list.append(constraint_tensor[offset:where]) + offset = where + 1 + + return constraint_list + + +class ConstraintNode: + """ + Represents a node in a trie managing unordered constraints. + """ + + def __init__(self, token: int = None, parent=None): + # The token associate with this node (None for the root) + self.token = int(token) if token is not None else None + # The parent (None at the root) + self.parent = parent + # Whether this node is a completed constraint + self.terminal = 0 + # List of child nodes + self.children = {} + + # The cumulative number of constraints from this point in the + # trie forward + self.num_constraints = 0 + + @property + def id(self): + return self.token + + def __str__(self): + term = self.terminal != 0 + return f'[{self.token}].{term}#{self.num_constraints}' + + def __getitem__(self, key: int): + return self.children.get(key, None) + + def next_tokens(self) -> Set[int]: + """The set of child labels.""" + return set(self.children.keys()) + + @staticmethod + def create(constraints: List[List[int]]): + root = ConstraintNode() + for sequence in constraints: + root.add_sequence(sequence) + + return root + + @staticmethod + def print_graph(node: 'ConstraintNode'): + if len(node.children) == 0: + return str(node) + else: + s = f'({node}' + for child in node.children.values(): + s += ' ' + ConstraintNode.print_graph(child) + s += ')' + return s + + def token_counts(self) -> Counter: + """Returns a counter of the number of times each token is used + in a constraint. 
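+
+        A small worked sketch::
+
+            root = ConstraintNode.create([[3, 1, 2], [3]])
+            root.token_counts()  # Counter({3: 2, 1: 1, 2: 1})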
+ """ + token_counts = Counter() + kids = list(self.children.values()) + while len(kids) > 0: + kid = kids.pop() + token_counts[kid.id] += kid.num_constraints + kids += list(kid.children.values()) + + return token_counts + + def tokens(self) -> Set[int]: + """Returns the set of tokens in constraints.""" + return set(self.token_counts().keys()) + + def add_sequence(self, sequence: List[int]): + """Adds a constraint, represented as a list of integers, to + the trie.""" + assert len(sequence) > 0 + + token = int(sequence[0]) + if token not in self.children: + self.children[token] = ConstraintNode(token, parent=self) + + node = self.children[token] + if len(sequence) == 1: + node.terminal += 1 + node.num_constraints += 1 + parent = node.parent + while parent is not None: + parent.num_constraints += 1 + parent = parent.parent + else: + node.add_sequence(sequence[1:]) + + +class UnorderedConstraintState(ConstraintState): + """ + Records progress through the set of constraints for each item in the beam + using a trie. + """ + + def __init__(self, + node: ConstraintNode, + copy_from: 'ConstraintState' = None): + self.node = node + + if copy_from is None: + # The root node + self.root = node + # The set of states in the graph that have been completed + self.completed = Counter() + # The... + self.generated = Counter() + # The list of tokens we need to generate + self.needed_tokens = self.root.tokens() + else: + self.completed = Counter(copy_from.completed) + self.generated = Counter(copy_from.generated) + self.root = copy_from.root + + # Mark the node as generated + if self.node != self.root: + self.generated[node] += 1 + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + constraint_trie_root = ConstraintNode.create(constraint_list) + return UnorderedConstraintState(constraint_trie_root) + + def __str__(self): + gen_str = ','.join([str(node) for node in self.generated]) + return f'{self.name}/{self.bank}({gen_str})x{self.num_completed}' + + def __copy__(self): + copied_state = UnorderedConstraintState(self.node, copy_from=self) + return copied_state + + def copy(self): + return self.__copy__() + + @property + def name(self): + if self.node.id is None: + return 'ROOT' + else: + return str(self.node.id) + + @property + def is_root(self): + return self.node == self.root + + @property + def bank(self): + return sum(self.generated.values()) + + @property + def num_completed(self): + """The number of constraints (not constraint tokens) that are completed. + In addition to the already-completed states, we need to account for the + current state, which might get marked as completed when another token + is generated. + """ + in_final = self.node.terminal and self.completed[ + self.node] < self.node.terminal + return sum(self.completed.values()) + in_final + + @property + def finished(self): + return self.root.num_constraints - self.num_completed == 0 + + @property + def token_counts(self): + return self.root.token_counts() + + @property + def tokens(self): + return self.root.tokens() + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. 
+ These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + if self.node != self.root: + return self.root.next_tokens().union(self.node.next_tokens()) + else: + return self.root.next_tokens() + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + + next_state = None + child = self.node[token] + if child is not None and self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + + def rewind(): + """If we're mid-trie and an "illegal" token is chosen next, we need + to reset our state to the root state. However, along the way, we need + to check whether a prefix of the current trie state represents a state + we could mark as completed. + """ + node = self.node + while node != self.root: + if node.terminal and self.completed[node] < node.terminal: + next_state.completed[node] += 1 + return + + next_state.generated[node] -= 1 + node = node.parent + + # Fall off the graph, check the root + if next_state is None and token in self.root.next_tokens(): + child = self.root[token] + # We can only traverse this edge if it's not saturated + if self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + else: + next_state = UnorderedConstraintState( + self.root, copy_from=self) + + # Rewind + rewind() + + elif next_state is None: + next_state = UnorderedConstraintState(self.root, copy_from=self) + # Rewind + rewind() + + return next_state + + +class ConstraintSequence: + + def __init__(self, sequences: List[List[int]]): + """Represents a set of possibly multitoken constraints by + concatenating them and internally recording the end points. + """ + self.sequences = [] + self.endpoints = [] + self.num_tokens = 0 + self.tokens = set() + for sequence in sequences: + for token in sequence: + self.tokens.add(token) + self.num_tokens += len(sequence) + self.endpoints += [False + for x in range(len(sequence) - 1)] + [True] + self.sequences += sequence + + def __getitem__(self, key: int): + return self.sequences[key] + + def __len__(self): + return len(self.sequences) + + def __str__(self): + return str(self.sequences) + + +class OrderedConstraintState(ConstraintState): + """ + Records progress through the set of linear nonbranching constraints with gaps. 
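+
+    A minimal sketch of how the state advances (the token ids are made up)::
+
+        seq = ConstraintSequence([[3, 1, 2], [4]])
+        state = OrderedConstraintState(seq)             # root state (-1)
+        state = state.advance(3).advance(1).advance(2)  # completes the first constraint
+        state.num_completed                             # 1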
+ """ + + def __init__(self, sequence: ConstraintSequence, state: int = -1): + self.sequence = sequence + self.state = state + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + return OrderedConstraintState(ConstraintSequence(constraint_list), -1) + + def __str__(self): + return f'{self.state}/{self.bank}x{self.num_completed}' + + def __copy__(self): + return OrderedConstraintState(self.sequence, self.state) + + def copy(self): + return self.__copy__() + + @property + def num_completed(self): + if self.state == -1: + return 0 + count = len( + list( + filter(lambda x: x, + self.sequence.endpoints[0:self.state + 1]))) + return count + + @property + def is_root(self): + return self.state == -1 + + @property + def name(self): + if self.state == -1: + return 'ROOT' + else: + return str(self.sequence[self.state]) + + @property + def bank(self) -> int: + return self.state + 1 + + @property + def finished(self): + return self.state + 1 == len(self.sequence) + + @property + def token_counts(self): + return self.sequence.token_counts() + + @property + def tokens(self): + return self.sequence.tokens + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + tokens = set() + if self.state > 0: + tokens.add(self.sequence[0]) + if not self.finished: + tokens.add(self.sequence[self.state + 1]) + return tokens + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + # print(f"{self} ADVANCE({token}) {self.sequence} -> ", end="") + + if self.finished: + # Accept anything + next_state = self.copy() + + elif self.sequence[self.state + 1] == token: + # Advance to the next token + next_state = OrderedConstraintState(self.sequence, self.state + 1) + + elif self.sequence.endpoints[self.state]: + # Accept anything between constraints (*) + next_state = self.copy() + + elif token == self.sequence[0]: + # Start over having generated the first token + next_state = OrderedConstraintState(self.sequence, 0) + else: + # Start over from the root + next_state = OrderedConstraintState(self.sequence, -1) + + return next_state diff --git a/modelscope/models/multi_modal/ofa/generate/utils.py b/modelscope/models/multi_modal/ofa/generate/utils.py new file mode 100644 index 00000000..8c8abf99 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/utils.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license which can be found at +# https://github.com/facebookresearch/fairseq/blob/main/LICENSE + +import collections +from collections import abc +from itertools import accumulate + +import torch +import torch.nn.functional as F + +try: + from amp_C import multi_tensor_l2norm + + multi_tensor_l2norm_available = True +except ImportError: + multi_tensor_l2norm_available = False + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + +MANIFOLD_PATH_SEP = '|' + + +def apply_to_sample(f, sample): + if hasattr(sample, '__len__') and len(sample) == 0: + return {} + + def _apply(x): + if torch.is_tensor(x): + return f(x) + elif isinstance(x, collections.OrderedDict): + # OrderedDict has attributes that needs to be preserved + od = collections.OrderedDict( + (key, _apply(value)) for key, value in x.items()) + od.__dict__ = x.__dict__ + return od + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + return [_apply(x) for x in x] + elif isinstance(x, tuple): + return tuple(_apply(x) for x in x) + elif isinstance(x, set): + return {_apply(x) for x in x} + else: + return x + + return _apply(sample) + + +def move_to_device(batch, device): + r"""Puts each data field to the device""" + if isinstance(batch, torch.Tensor): + return batch.to(device) + elif isinstance(batch, (list, tuple)): + return tuple(move_to_device(item, device) for item in batch) + elif isinstance(batch, abc.Mapping): + return { + key: move_to_device(value, device) + for key, value in batch.items() + } + else: + return batch + + +def strip_pad(tensor, pad): + return tensor[tensor.ne(pad)] + + +def get_token_to_word_mapping(tokens, exclude_list): + n = len(tokens) + word_start = [int(token not in exclude_list) for token in tokens] + word_idx = list(accumulate(word_start)) + token_to_word = {i: word_idx[i] for i in range(n)} + return token_to_word + + +def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = (((tgt_sent != pad) & # noqa + (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1)) + src_invalid = (((src_sent == pad) | # noqa + (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1)) + src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) + tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) + alignment = [] + if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent): + attn_valid = attn[tgt_valid] + attn_valid[:, src_invalid] = float('-inf') + _, src_indices = attn_valid.max(dim=1) + for tgt_idx, src_idx in zip(tgt_valid, src_indices): + alignment.append(( + src_token_to_word[src_idx.item()] - 1, + tgt_token_to_word[tgt_idx.item()] - 1, + )) + return alignment + + +def softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) + + +def log_softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.log_softmax(x.float(), dim=dim) + else: + return F.log_softmax(x, dim=dim, dtype=torch.float32) + + +def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = (tgt_sent != pad).nonzero(as_tuple=False) + src_valid = (src_sent != pad).nonzero(as_tuple=False).squeeze(dim=-1) + alignment = [] + if len(tgt_valid) != 0 and len(src_valid) != 0: + attn_valid = attn[tgt_valid, src_valid] + alignment = [['{:.6f}'.format(p) for p in src_probs.tolist()] + for src_probs in attn_valid] + return alignment diff 
--git a/modelscope/models/multi_modal/ofa/modeling_ofa.py b/modelscope/models/multi_modal/ofa/modeling_ofa.py new file mode 100755 index 00000000..b0350d1d --- /dev/null +++ b/modelscope/models/multi_modal/ofa/modeling_ofa.py @@ -0,0 +1,2192 @@ +# Copyright 2022 OFA-Sys Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch OFA model.""" + +import math +import random +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch +from torch import Tensor, nn +from torch.nn import functional as F +from transformers.activations import ACT2FN +from transformers.file_utils import (ModelOutput, add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, + Seq2SeqModelOutput) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from .configuration_ofa import OFAConfig +from .generate import utils +from .resnet import ResNet + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = 'ofa-base' +_CONFIG_FOR_DOC = 'OFAConfig' +_TOKENIZER_FOR_DOC = 'OFATokenizer' + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 + +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + +OFA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'ofa-tiny', + 'ofa-medium', + 'ofa-base', + 'ofa-large', +] + +try: + from apex.normalization import FusedLayerNorm as _FusedLayerNorm + + has_fused_layernorm = True + + class FusedLayerNorm(_FusedLayerNorm): + + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) + +except ImportError: + has_fused_layernorm = False + + +def LayerNorm(normalized_shape, + eps=1e-5, + elementwise_affine=True, + export=False): + r""" + Layer normalization. + If apex is available, use `FusedLayerNorm` instead. + """ + if torch.jit.is_scripting(): + export = True + if not export and torch.cuda.is_available() and has_fused_layernorm: + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +def make_token_bucket_position(bucket_size, + max_position=DEFAULT_MAX_SOURCE_POSITIONS): + r""" + Make relative position indices for the text. 
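+
+    A usage sketch (``bucket_size`` is a model hyper-parameter; 256 is only an example)::
+
+        bucket_pos = make_token_bucket_position(bucket_size=256)
+        # shape (max_position, max_position); values are relative-position bucket
+        # indices in roughly [0, 2 * bucket_size - 2]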
+ """ + context_pos = torch.arange(max_position, dtype=torch.long)[:, None] + memory_pos = torch.arange(max_position, dtype=torch.long)[None, :] + relative_pos = context_pos - memory_pos + sign = torch.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), + mid - 1, torch.abs(relative_pos)) + log_pos = torch.ceil( # noqa + torch.log(abs_pos / mid) / math.log((max_position - 1) / mid) * # noqa + (mid - 1)) + mid # noqa + log_pos = log_pos.int() + bucket_pos = torch.where(abs_pos.le(mid), relative_pos, + log_pos * sign).long() + return bucket_pos + bucket_size - 1 + + +def make_image_bucket_position(bucket_size, num_relative_distance): + r""" + Make relative position indices for the image. + """ + coords_h = torch.arange(bucket_size) + coords_w = torch.arange(bucket_size) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - \ + coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += bucket_size - 1 + relative_coords[:, :, 0] *= 2 * bucket_size - 1 + relative_position_index = torch.zeros( + size=(bucket_size * bucket_size + 1, ) * 2, + dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + return relative_position_index + + +def new_arange(x, *size): + r""" + Return a Tensor of `size` filled with a range function on the device of x. + If size is empty, using the size of the variable x. + """ + if len(size) == 0: + size = x.size() + return torch.arange(size[-1], device=x.device).expand(*size).contiguous() + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, + decoder_start_token_id: int): + r""" + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, 'self.model.config.pad_token_id has to be defined.' + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + past_key_values_length: int = 0): + r""" + Make causal mask used for uni-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float('-inf')) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], + dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, + tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + r""" + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
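+
+    Here ``mask`` is expected to be 1 at padding positions (e.g. ``tokens.eq(padding_idx)``);
+    those positions are filled with ``torch.finfo(dtype).min`` so the attention softmax
+    effectively ignores them. A sketch::
+
+        pad_mask = input_ids.eq(pad_token_id)              # (bsz, seq_len), 1 = padding
+        attn_bias = _expand_mask(pad_mask, torch.float32)  # (bsz, 1, seq_len, seq_len)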
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + return expanded_mask.masked_fill(expanded_mask.bool(), + torch.finfo(dtype).min) + + +def Embedding(num_embeddings, + embedding_dim, + padding_idx=None, + zero_init=False): + r""" + Embedding for tokens + """ + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + if padding_idx is not None: + nn.init.constant_(m.weight[padding_idx], 0) + if zero_init: + nn.init.constant_(m.weight, 0) + return m + + +def Linear(in_features, out_features, bias=True): + r""" + Implementation of linear projection with xavier initialization + """ + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +class LayerDropModuleList(nn.ModuleList): + r""" + A LayerDrop implementation based on :class:`torch.nn.ModuleList`. + + Args: + p (float): probability of dropping out each layer + modules (iterable, optional): an iterable of modules to add + """ + + def __init__(self, p, modules=None): + super().__init__(modules) + self.p = p + + def __iter__(self): + dropout_probs = torch.empty(len(self)).uniform_() + for i, m in enumerate(super().__iter__()): + if not self.training or (dropout_probs[i] > self.p): + yield m + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + r""" + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Args: + x (`nn.Modules`): input nn layers. + drop_prob (`float`): drop path ratio. + training (`bool`): whether is training or inference. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (1, x.shape[1], 1) + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + r""" + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Args: + drop_prob: drop path ratio. + """ + + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class OFAAttention(nn.Module): + r""" + Multi-headed attention, with additional implementation for NormFormer. + + Args: + embed_dim (`int`): embedding dimension. + num_heads (`int`): the number of attention heads. + dropout (`float32`): the ratio for dropout. + is_decoder (`bool`): whether or not decoder attention. + bias (`bool`): whether to add bias. + scale_heads (`bool`): whether to learn scaling heads, only for Normformer. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + scale_heads: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f'embed_dim must be divisible by num_heads ' \ + f'(got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads}).' + # self.scaling = self.head_dim ** -0.5 + # 1. 
difference + scale_factor = 2 + self.scaling = float(self.head_dim * scale_factor)**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.attn_dropout = nn.Dropout(p=dropout) + self.c_attn = nn.Parameter( + torch.ones((self.num_heads, )), + requires_grad=True) if scale_heads else None + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + r""" + Reshape tensors for multi-head attention. + """ + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + attn_bias: Optional[torch.Tensor] = None, + ): + r""" + Args: + hidden_states (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`)`: input states. + key_value_states (`torch.FloatTensor` of shape (bsz, tgt_len, embed_dim), *optional*): key value states. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): + cached past key value states for fast inference. + attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, seq_len)`, *optional*): attention mask. + output_attentions (`bool`, *optional*): whether to output attention weights of all layers. + attn_bias (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`, *optional*): + the attention bias for positional information. + + Returns: + attn_output (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`): attention outputs. + attn_weights_reshaped (`torch.FloatTensor`, *optional*): attention weights of all layers. + past_key_value (`torch.FloatTensor`, *optional*): cached key value states for fast inference. 
+ """ + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, + bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError(f'Attention weights should be of size ' + f'{(bsz * self.num_heads, tgt_len, src_len)}, ' + f'but is {attn_weights.size()}') + + # Add attention bias for positional information + if attn_bias is not None: + attn_weights += attn_bias + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}' + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + attn_weights = F.softmax(attn_weights, dim=-1) + + if output_attentions: + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, + tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, + tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = self.attn_dropout(attn_weights) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, + self.head_dim): + raise ValueError( + f'`attn_output` should be of size ' + f'{(bsz, self.num_heads, tgt_len, self.head_dim)}, ' + f'but is {attn_output.size()}') + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, + self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + if self.c_attn is not None: + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, + self.head_dim) + attn_output = torch.einsum('bthd,h->bthd', attn_output, + self.c_attn) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class OFAEncoderLayer(nn.Module): + r""" + OFA encoder layer implementation. + + Args: + config: configuration for OFA. 
+ drop_path_rate: the ratio for drop path. + """ + + def __init__(self, config: OFAConfig, drop_path_rate=0.0): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = OFAAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.self_attn_mid_layer_norm = LayerNorm( + self.embed_dim) if config.normformer else None + self.dropout = nn.Dropout(config.dropout) + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = nn.Dropout(config.activation_dropout) + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.ffn_layer_norm = LayerNorm( + config.encoder_ffn_dim) if config.normformer else None + self.final_layer_norm = LayerNorm(self.embed_dim) + self.normalize_before = config.encoder_normalize_before + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def residual_connection(self, x, residual): + r""" + Residual connection with drop path. + """ + return residual + self.drop_path(x) + + def forward(self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: bool = False, + attn_bias: Optional[torch.Tensor] = None): + r""" + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape *(bsz, src_len, embed_dim)* + attention_mask (`torch.FloatTensor`): attention mask of size + *(bsz, 1, src_len, src_len)* where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + whether to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + attn_bias (`torch.FloatTensor`): bias for positional information. + + Returns: + outputs (`tuple(torch.FloatTensor)`): + output hidden states of size (bsz, src_len, embed_dim), optionally with attention weights. 
+ """ + + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + attn_bias=attn_bias, + ) + if self.self_attn_mid_layer_norm: + hidden_states = self.self_attn_mid_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states) + if self.ffn_layer_norm: + hidden_states = self.ffn_layer_norm(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (attn_weights, ) + + return outputs + + +class OFADecoderLayer(nn.Module): + r""" + OFA decoder layer implementation. + + Args: + config: configuration for OFA. + drop_path_rate: the ratio for drop path. + """ + + def __init__(self, config: OFAConfig, drop_path_rate=0.0): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = OFAAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = nn.Dropout(p=config.dropout) + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = nn.Dropout(p=config.activation_dropout) + + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.self_attn_mid_layer_norm = LayerNorm( + self.embed_dim) if config.normformer else None + self.cross_attn = OFAAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.cross_attn_layer_norm = LayerNorm(self.embed_dim) + self.cross_attn_mid_layer_norm = LayerNorm( + self.embed_dim) if config.normformer else None + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.ffn_layer_norm = LayerNorm( + config.decoder_ffn_dim) if config.normformer else None + self.final_layer_norm = LayerNorm(self.embed_dim) + self.normalize_before = config.decoder_normalize_before + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def residual_connection(self, x, residual): + r""" + Residual connection with drop path. 
+ """ + return residual + self.drop_path(x) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + self_attn_bias: Optional[torch.Tensor] = None, + cross_attn_bias: Optional[torch.Tensor] = None, + ): + r""" + Args: + hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): input to the layer. + attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`): + attention mask where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, embed_dim)`): + cross attention input to the layer. + encoder_attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`): + encoder attention mask where padding elements are indicated by very large negative values. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): whether to return the attentions tensors of all attention layers. + use_cache (`bool`, *optional*): whether to use cache + self_attn_bias (`torch.FloatTensor`): self attention bias for positional information. + cross_attn_bias (`torch.FloatTensor`): cross attention bias for positional information. + """ + + # Self attention with intermediate layernorm + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + # add present self-attn cache to position 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + output_attentions=output_attentions, + attn_bias=self_attn_bias, + ) + if self.self_attn_mid_layer_norm: + hidden_states = self.self_attn_mid_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross attention with intermediate layernorm + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + if self.normalize_before: + hidden_states = self.cross_attn_layer_norm(hidden_states) + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + attn_bias=cross_attn_bias, + ) + if self.cross_attn_mid_layer_norm: + hidden_states = self.cross_attn_mid_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.cross_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + 
present_key_value = present_key_value + cross_attn_present_key_value + + # FFN with intermediate layernorm + residual = hidden_states + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states) + if self.ffn_layer_norm: + hidden_states = self.ffn_layer_norm(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value, ) + + return outputs + + +class OFAPreTrainedModel(PreTrainedModel): + r""" + Base class OFA + """ + + config_class = OFAConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + + def _init_weights(self, module): + r""" + Weight initialization which follows BERT. + """ + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + r""" + Turn on the switch of gradient checkpointing. + """ + if isinstance(module, (OFADecoder, OFAEncoder)): + module.gradient_checkpointing = value + + +@dataclass +class OFAEncoderOutput(ModelOutput): + r""" + Base class for OFA's outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): + Sequence of hidden-states at the output of the last layer of the model. + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed + or when `config.output_hidden_states=True`): + + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(bsz, seq_len, hidden)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed + or when `config.output_attentions=True`): + + Tuple of `torch.FloatTensor` (one for each layer) of shape `(bsz, num_heads, seq_len, seq_len)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): + postional embeddings of the inputs. + """ + + last_hidden_state: torch.FloatTensor = None + padding_mask: torch.Tensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + position_embedding: Optional[torch.FloatTensor] = None + + +OFA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
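As a small aside on `_init_weights` above, the lines below show the same BERT-style scheme in isolation: normal(0, std) for Linear and Embedding weights, zeros for biases and the padding row. The `init_std` value of 0.02 is used here purely as an example.

```python
# Sketch of the BERT-style initialization applied by _init_weights above.
import torch
from torch import nn

init_std = 0.02
linear = nn.Linear(16, 16)
embedding = nn.Embedding(10, 16, padding_idx=1)

linear.weight.data.normal_(mean=0.0, std=init_std)
linear.bias.data.zero_()
embedding.weight.data.normal_(mean=0.0, std=init_std)
embedding.weight.data[embedding.padding_idx].zero_()

print(embedding.weight[1].abs().sum().item())  # 0.0 -- the padding row stays zero
```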
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`~OFAConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +OFA_GENERATION_EXAMPLE = r""" + Image captioning example: + + ```python + >>> from PIL import Image + >>> from torchvision import transforms + >>> from transformers import OFATokenizer, OFAForConditionalGeneration + + >>> mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + >>> resolution = 256 + >>> patch_resize_transform = transforms.Compose([ + lambda image: image.convert("RGB"), + transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std) + ]) + + >>> model = OFAForConditionalGeneration.from_pretrained(ckpt_dir) + >>> tokenizer = OFATokenizer.from_pretrained(ckpt_dir) + + >>> txt = " what is the description of the image?" + >>> inputs = tokenizer([txt], max_length=1024, return_tensors="pt")["input_ids"] + >>> img = Image.open(path_to_image) + >>> patch_img = patch_resize_transform(img).unsqueeze(0) + + >>> gen = model.generate(inputs, patch_img=patch_img, num_beams=4) + >>> print(tokenizer.decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` +""" + +OFA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): + indices of input sequence tokens in the vocabular, and padding will be ignored by default; + + indices can be obtained using [`~OFATokenizer`]. + + patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the resized image, which are transformed by the default operations. + patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the second (if it exists) image. + patch_masks (`torch.BoolTensor`): the patches to be masked. + token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings. + sample_patch_num (`int`): the number of patches to sample. + decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary. + code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation. + attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding. + encoder_outputs (`OFAEncoderOutput`): + encoder outputs with hidden states, positional embeddings, and padding masks. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of + shape `(bsz, num_heads, src_len, head_size)`. + use_cache (`bool`): whether to use cache for faster inference. + output_attentions (`bool`): whether to output attention weights. + output_hidden_states (`bool`): whether to output hidden states. + return_dict (`bool`): unused. Keep it for generation only. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. +""" + + +class OFAEncoder(OFAPreTrainedModel): + r""" + OFA encoder consisting of layers of [`OFAEncoderLayer`]. + + Args: + config: OFAConfig + embed_tokens (`nn.Embedding`, *optional*): output embedding + """ + + def __init__(self, + config: OFAConfig, + embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = nn.Dropout(config.dropout) + self.encoder_layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt( + embed_dim) if config.scale_embedding else 1.0 + self.num_attention_heads = config.encoder_attention_heads + + if getattr(config, 'layernorm_embedding', False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, + self.padding_idx) + + if config.add_type_embedding: + self.type_embedding = Embedding(2, embed_dim, padding_idx=None) + else: + self.type_embedding = None + + if config.resnet_type == 'resnet18': + self.embed_images = ResNet( + [2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet34': + self.embed_images = ResNet( + [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet50': + self.embed_images = ResNet( + [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet101': + self.embed_images = ResNet( + [3, 4, 23], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet152': + self.embed_images = ResNet( + [3, 8, 36], drop_path_rate=config.resnet_drop_path_rate) + else: + raise NotImplementedError + + # self.image_proj = nn.Linear(1024, embed_dim) + self.image_proj = Linear(1024, embed_dim) + + if config.resnet_model_path: + print('load resnet {}'.format(config.resnet_model_path)) + resnet_state_dict = torch.load(config.resnet_model_path) + self.embed_images.load_state_dict(resnet_state_dict) + if config.patch_layernorm_embedding: + self.patch_layernorm_embedding = LayerNorm(embed_dim) + else: + self.patch_layernorm_embedding = None + + self.embed_positions = Embedding(self.max_source_positions + 2, + embed_dim) + self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, + embed_dim) + self.pos_ln = LayerNorm(embed_dim) + self.image_pos_ln = LayerNorm(embed_dim) + self.pos_scaling = float(embed_dim / self.num_attention_heads + * config.attn_scale_factor)**-0.5 + self.pos_q_linear = nn.Linear(embed_dim, embed_dim) + self.pos_k_linear = nn.Linear(embed_dim, embed_dim) + + if self.encoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.encoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + + dpr = [ + x.item() for x in torch.linspace(0, config.encoder_drop_path_rate, + config.encoder_layers) + ] + self.layers.extend([ + OFAEncoderLayer(config, drop_path_rate=dpr[i]) + for i in range(config.encoder_layers) + ]) + self.num_layers = len(self.layers) + + if config.encoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + self.token_bucket_size = config.token_bucket_size + token_num_rel_dis = 2 * config.token_bucket_size - 1 + token_rp_bucket = 
make_token_bucket_position(config.token_bucket_size) + self.token_rel_pos_table_list = nn.ModuleList([ + Embedding( + token_num_rel_dis, self.num_attention_heads, zero_init=True) + for _ in range(config.encoder_layers) + ]) + + self.image_bucket_size = config.image_bucket_size + image_num_rel_dis = (2 * config.image_bucket_size + - 1) * (2 * config.image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position(config.image_bucket_size, + image_num_rel_dis) + self.image_rel_pos_table_list = nn.ModuleList([ + Embedding( + image_num_rel_dis, self.num_attention_heads, zero_init=True) + for _ in range(config.encoder_layers) + ]) + + if config.layernorm_embedding: + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + self.register_buffer('token_rp_bucket', token_rp_bucket) + self.register_buffer('image_rp_bucket', image_rp_bucket) + self.entangle_position_embedding = config.entangle_position_embedding + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + r""" + Get the embedding weight. + """ + return self.embed_tokens + + def set_input_embeddings(self, value): + r""" + Set the weight of embedding with the given tensor. + """ + self.embed_tokens = value + + def get_rel_pos_bias(self, x, idx): + r""" + Get the relative positional bias of the text, for attention. + """ + + seq_len = x.size(1) + rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] + values = F.embedding(rp_bucket, + self.token_rel_pos_table_list[idx].weight) + values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) + values = values.permute([0, 3, 1, 2]) + return values.contiguous() + + def get_image_rel_pos_bias(self, image_position_ids, idx): + r""" + Get the relative positional bias of the image, for attention. + """ + + bsz, seq_len = image_position_ids.shape + rp_bucket_size = self.image_rp_bucket.size(1) + + rp_bucket = self.image_rp_bucket.unsqueeze(0).expand( + bsz, rp_bucket_size, rp_bucket_size).gather( + 1, image_position_ids[:, :, None].expand( + bsz, seq_len, rp_bucket_size)).gather( + 2, image_position_ids[:, None, :].expand( + bsz, seq_len, seq_len)) + values = F.embedding(rp_bucket, + self.image_rel_pos_table_list[idx].weight) + values = values.permute(0, 3, 1, 2) + return values + + def get_patch_images_info(self, patch_images, sample_patch_num, device): + r""" + Get the basic information of the resized image. + + Args: + patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): the resized image. + sample_patch_num (`int`): + the number of patches to sample. If it is equal to -1, no sampling will be performed. + device: GPU device. + + Returns: + image_embed (`torch.FloatTensor` of shape `(bsz, h * w, hidden)`): the output of the visual encoder. + image_num_patches (`int`, equal to `h * w`): the number of patches. + image_padding_mask (`torch.BooleanTensor` of shape `(bsz, h*w)`): image padding mask. + image_position_ids (`torch.LongTensor` of shape `(bsz, h*w)`): image position ids. + image_pos_embed (`torch.FloatTensor` of shape (bsz, h*w, hidden)): the positional embedding. 
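The relative-position methods above (`get_rel_pos_bias`, `get_image_rel_pos_bias`) all reduce to looking up a per-head bias from a bucket-index matrix with `F.embedding`. The snippet below is a tiny sketch of that lookup with an artificial bucket rule and arbitrary sizes; it is not the model's actual bucketing.

```python
# Tiny sketch of the relative-position lookup: a [seq, seq] bucket-index
# matrix becomes per-head biases via F.embedding, then heads are moved first.
import torch
from torch.nn import functional as F

num_buckets, num_heads, seq_len = 7, 4, 5
rel_pos_table = torch.nn.Embedding(num_buckets, num_heads)

# Toy bucket matrix: the bucket id depends only on the clipped relative offset.
offsets = torch.arange(seq_len)[:, None] - torch.arange(seq_len)[None, :]
rp_bucket = offsets.clamp(-3, 3) + 3                    # values in [0, 6]

values = F.embedding(rp_bucket, rel_pos_table.weight)   # [seq, seq, heads]
bias = values.permute(2, 0, 1).contiguous()             # [heads, seq, seq]
print(bias.shape)  # torch.Size([4, 5, 5])
```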
+ """ + + image_embed = self.embed_images(patch_images) + h, w = image_embed.shape[-2:] + image_num_patches = h * w + image_padding_mask = patch_images.new_zeros( + (patch_images.size(0), image_num_patches)).bool() + image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w)\ + + torch.arange(h).unsqueeze(1) * self.image_bucket_size + 1 + image_position_idx = image_position_idx.view(-1).to(device) + image_position_ids = image_position_idx[None, :].expand( + patch_images.size(0), image_num_patches) + + image_embed = image_embed.flatten(2).transpose(1, 2) + if sample_patch_num is not None: + patch_orders = [ + random.sample(range(image_num_patches), k=sample_patch_num) + for _ in range(patch_images.size(0)) + ] + patch_orders = torch.LongTensor(patch_orders).to(device) + image_embed = image_embed.gather( + 1, + patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2))) + image_num_patches = sample_patch_num + image_padding_mask = image_padding_mask.gather(1, patch_orders) + image_position_ids = image_position_ids.gather(1, patch_orders) + image_pos_embed = self.embed_image_positions(image_position_ids) + + return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed + + def forward_embedding(self, + input_ids, + image_embed: Optional[torch.Tensor] = None, + image_embed_2: Optional[torch.Tensor] = None, + token_embedding: Optional[torch.Tensor] = None, + pos_embed: Optional[torch.Tensor] = None, + image_pos_embed: Optional[torch.Tensor] = None, + image_pos_embed_2: Optional[torch.Tensor] = None): + r""" + Generate embeddings of both the image and the text. + Actually since OFA unifies both unimodal and multimodal data, + image inputs are optional. + + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the tokens in the vocabulary. + image_embed (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): image embeddings. + image_embed_2 (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): + image embeddings of the second image (if it exists). + token_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`, *optional*): + input token embeddings to replace the embeddings of input ids. + image_pos_embed (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): + positional embeddings of the image. + image_pos_embed_2 (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): + positional embeddings of the second image. + + Returns: + x (`torch.FloatTensor` of shape `(bsz, h*w+seq_len, embed_dim)`): embeddings of the input. + embed (`torch.FloatTensor` of shape `(bsz, h*w+seq_len, embed_dim)`): + embeddings without adding positional and type embeddings. 
+ """ + + # embed tokens and positions + if token_embedding is None: + token_embedding = self.embed_tokens(input_ids) + x = embed = self.embed_scale * token_embedding + if self.entangle_position_embedding and pos_embed is not None: + x += pos_embed + if self.type_embedding is not None: + x += self.type_embedding(input_ids.new_zeros(x.size()[:2])) + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + x = self.dropout(x) + + # embed raw images + if image_embed is not None: + image_embed = self.image_proj(image_embed) + image_x = image_embed = self.embed_scale * image_embed + if self.entangle_position_embedding and image_pos_embed is not None: + image_x += image_pos_embed + if self.type_embedding is not None: + image_x += self.type_embedding( + input_ids.new_ones(image_x.size()[:2])) + if self.patch_layernorm_embedding is not None: + image_x = self.patch_layernorm_embedding(image_x) + image_x = self.dropout(image_x) + x = torch.cat([image_x, x], dim=1) + embed = torch.cat([image_embed, embed], dim=1) + + if image_embed_2 is not None: + assert self.type_embedding is not None + image_embed_2 = self.image_proj(image_embed_2) + image_x_2 = image_embed_2 = self.embed_scale * image_embed_2 + if self.entangle_position_embedding and image_pos_embed_2 is not None: + image_x_2 += image_pos_embed_2 + if self.type_embedding is not None: + image_x_2 += self.type_embedding( + input_ids.new_full(image_x_2.size()[:2], fill_value=2)) + if self.patch_layernorm_embedding is not None: + image_x_2 = self.patch_layernorm_embedding(image_x_2) + image_x_2 = self.dropout(image_x_2) + if self.quant_noise is not None: + image_x_2 = self.quant_noise(image_x_2) + x = torch.cat([image_x_2, x], dim=1) + embed = torch.cat([image_embed_2, embed], dim=1) + + return x, embed + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to *new_order*. 
+ + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + # if encoder_out["last_hidden_state"] is None: + if 'last_hidden_state' not in encoder_out: + new_encoder_out = None + else: + new_encoder_out = encoder_out['last_hidden_state'].index_select( + 0, new_order) + # if encoder_out["padding_mask"] is None: + if 'padding_mask' not in encoder_out: + new_encoder_padding_mask = None + else: + new_encoder_padding_mask = encoder_out[ + 'padding_mask'].index_select(0, new_order) + + # if encoder_out["position_embedding"] is None: + if 'position_embedding' not in encoder_out: + new_position_embeddings = None + else: + new_position_embeddings = encoder_out[ + 'position_embedding'].index_select(0, new_order) + + if 'hidden_states' not in encoder_out: + new_encoer_states = None + else: + encoder_states = encoder_out['hidden_states'] + new_encoer_states = () + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + new_encoer_states += (state.index_select(0, new_order), ) + + if 'attentions' not in encoder_out: + attentions = None + else: + attentions = encoder_out['attentions'] + + return OFAEncoderOutput( + last_hidden_state=new_encoder_out, # B x T x C + padding_mask=new_encoder_padding_mask, # B x T + hidden_states=new_encoer_states, # List[T x B x C] + attentions=attentions, + position_embedding=new_position_embeddings # B x T x C + ) + + def forward( + self, + input_ids=None, + patch_images: Optional[torch.Tensor] = None, + patch_images_2: Optional[torch.Tensor] = None, + patch_masks: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + sample_patch_num: Optional[int] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): + indices of input sequence tokens in the vocabular, and padding will be ignored by default; + + indices can be obtained using [`~OFATokenizer`]. + + patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the resized image, which are transformed by the default operations. + patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the second (if it exists) image. + patch_masks (`torch.BoolTensor`): the patches to be masked. + output_attentions (`bool`): whether to return all attention weights, + output_hidden_states (`bool`): whether to return all hidden states. + token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings. + sample_patch_num (`int`): the number of patches to sample. + + Returns: + [`OFAEncoderOutput`]: + last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + the states of the last layer. + padding_mask (`torch.BoolTensor` of shape `(bsz, seq_len)`): + the padding mask of the source context. + hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + the states of all layers including the embeddings. + attentions (`torch.FloatTensor` of shape `(bsz, num_heads, seq_len, seq_len)`): + the attention weights of all layers. + position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + positional embeddings of the input image and tokens. 
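`reorder_encoder_out` above rearranges every cached encoder tensor along the batch dimension during beam search. The demo below isolates the underlying `index_select`; the tensor contents and the new order are arbitrary.

```python
# Minimal demo of the batch reordering performed for beam search:
# each cached tensor is gathered along dim 0 with index_select.
import torch

last_hidden_state = torch.arange(6.0).view(3, 2, 1)   # pretend [bsz=3, seq=2, hidden=1]
new_order = torch.tensor([2, 0, 0])                    # beams now track samples 2, 0, 0
print(last_hidden_state.index_select(0, new_order).squeeze(-1))
# tensor([[4., 5.],
#         [0., 1.],
#         [0., 1.]])
```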
+ """ + + image_embed = None + image_embed_2 = None + image_pos_embed = None + image_pos_embed_2 = None + if patch_images is not None: + image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \ + self.get_patch_images_info(patch_images, sample_patch_num, input_ids.device) + # print("patch_masks.shape") + # print(patch_masks.shape) + # print(patch_masks) + # print("image_padding_mask.shape") + # print(image_padding_mask.shape) + # print(image_padding_mask) + image_padding_mask[~patch_masks] = True + # print(image_padding_mask) + if patch_images_2 is not None: + image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \ + self.get_patch_images_info(patch_images_2, sample_patch_num, input_ids.device) + image_padding_mask_2[~patch_masks] = True + + encoder_padding_mask = input_ids.eq(self.padding_idx) + if patch_images is not None: + encoder_padding_mask = torch.cat( + [image_padding_mask, encoder_padding_mask], dim=1) + if patch_images_2 is not None: + encoder_padding_mask = torch.cat( + [image_padding_mask_2, encoder_padding_mask], dim=1) + has_pads = encoder_padding_mask.any() + + pos_embed = self.embed_positions(new_arange(input_ids)) + x, encoder_embedding = self.forward_embedding( + input_ids, image_embed, image_embed_2, token_embeddings, pos_embed, + image_pos_embed, image_pos_embed_2) + + # account for padding while computing the representation + if has_pads: + x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + + pos_embed = self.pos_ln(pos_embed) + if patch_images is not None: + image_pos_embed = self.image_pos_ln(image_pos_embed) + pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) + if patch_images_2 is not None: + image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) + pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) + + pos_q = self.pos_q_linear(pos_embed).view( + x.size(0), x.size(1), self.num_attention_heads, -1).transpose( + 1, 2) * self.pos_scaling + pos_k = self.pos_k_linear(pos_embed).view( + x.size(0), x.size(1), self.num_attention_heads, + -1).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + + # expand attention_mask + if has_pads: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(encoder_padding_mask, dtype=x.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # if output_hidden_states: + # # encoder_states.append(x) + # encoder_states += (x,) + + # encoder layers + for idx, layer in enumerate(self.layers): + if output_hidden_states: + encoder_states += (x, ) + self_attn_bias = abs_pos_bias.clone() + self_attn_bias[:, :, -input_ids.size(1):, + -input_ids.size(1):] += self.get_rel_pos_bias( + input_ids, idx) + if patch_images_2 is not None: + self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \ + self.get_image_rel_pos_bias(image_position_ids_2, idx) + self_attn_bias[:, :, + image_num_patches_2:image_num_patches_2 + image_num_patches, # noqa + image_num_patches_2:image_num_patches_2 + image_num_patches] += \ + self.get_image_rel_pos_bias(image_position_ids, idx) # noqa + elif patch_images is not None: + self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \ + self.get_image_rel_pos_bias(image_position_ids, idx) + self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1)) + + hidden_outputs = layer( + x, + attention_mask if has_pads else None, + attn_bias=self_attn_bias, 
+ output_attentions=output_attentions) + x = hidden_outputs[0] + + if output_attentions: + attention = hidden_outputs[1] + all_attentions = all_attentions + (attention, ) + + if output_hidden_states: + encoder_states += (x, ) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + return OFAEncoderOutput( + last_hidden_state=x, + padding_mask=encoder_padding_mask, + hidden_states=encoder_states, + attentions=all_attentions, + position_embedding=pos_embed) + + +class OFADecoder(OFAPreTrainedModel): + r""" + OFA decoder consisting of layers of [`OFADecoderLayer`] + + Args: + config: OFAConfig + embed_tokens (`nn.Embedding`, *optional*): output embedding + """ + + def __init__(self, + config: OFAConfig, + embed_tokens: Optional[nn.Embedding] = None, + output_projection=None): + super().__init__(config) + self.dropout = nn.Dropout(config.dropout) + self.decoder_layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self._future_mask = torch.empty(0) + self.share_input_output_embed = config.share_decoder_input_output_embed + self.num_attention_heads = config.decoder_attention_heads + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, + self.padding_idx) + + self.embed_dim = config.d_model + self.output_embed_dim = config.d_model + + self.layers = nn.ModuleList( + [OFADecoderLayer(config) for _ in range(config.decoder_layers)]) + if config.layernorm_embedding: + self.layernorm_embedding = LayerNorm(self.embed_dim) + else: + self.layernorm_embedding = None + + self.window_size = config.code_image_size // 8 + + self.embed_positions = Embedding(self.max_target_positions + 2, + self.embed_dim) + self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, + self.embed_dim) + self.pos_ln = LayerNorm(self.embed_dim) + self.image_pos_ln = LayerNorm(self.embed_dim) + self.pos_scaling = float(self.embed_dim / self.num_attention_heads + * config.attn_scale_factor)**-0.5 + self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) + self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) + self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) + self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) + + if config.code_layernorm_embedding: + self.code_layernorm_embedding = LayerNorm(self.embed_dim) + else: + self.code_layernorm_embedding = None + + if self.decoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.decoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + + dpr = [ + x.item() for x in torch.linspace(0, config.decoder_drop_path_rate, + config.decoder_layers) + ] + self.layers.extend([ + OFADecoderLayer(config, drop_path_rate=dpr[i]) + for i in range(config.decoder_layers) + ]) + self.num_layers = len(self.layers) + + if config.decoder_normalize_before: + self.layer_norm = LayerNorm(self.embed_dim) + else: + self.layer_norm = None + + self.adaptive_softmax = None + self.output_projection = output_projection + if self.output_projection is None: + self.build_output_projection(config) + + self.token_bucket_size = config.token_bucket_size + token_num_rel_dis = 2 * config.token_bucket_size - 1 + token_rp_bucket = make_token_bucket_position(config.token_bucket_size) + self.token_rel_pos_table_list = nn.ModuleList([ + Embedding( + token_num_rel_dis, 
self.num_attention_heads, zero_init=True) + for _ in range(config.decoder_layers) + ]) + + self.image_bucket_size = config.image_bucket_size + image_num_rel_dis = (2 * config.image_bucket_size + - 1) * (2 * config.image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position(config.image_bucket_size, + image_num_rel_dis) + image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ + torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa + image_position_idx = torch.cat( + [torch.tensor([0]), image_position_idx.view(-1)]) + image_position_idx = torch.cat( + [image_position_idx, + torch.tensor([1024] * 768)]) + self.image_rel_pos_table_list = nn.ModuleList([ + Embedding( + image_num_rel_dis, self.num_attention_heads, zero_init=True) + for _ in range(config.decoder_layers) + ]) + + self.register_buffer('token_rp_bucket', token_rp_bucket) + self.register_buffer('image_rp_bucket', image_rp_bucket) + self.register_buffer('image_position_idx', image_position_idx) + self.entangle_position_embedding = config.entangle_position_embedding + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def build_output_projection(self, config): + if self.share_input_output_embed: + self.output_projection = nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = self.embed_tokens.weight + else: + self.output_projection = nn.Linear( + self.output_embed_dim, config.vocab_size, bias=False) + nn.init.normal_( + self.output_projection.weight, + mean=0, + std=self.output_embed_dim**-0.5) + + def get_rel_pos_bias(self, x, idx): + r""" + Get the relative positional bias of the text, for attention. + """ + + seq_len = x.size(1) + rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] + values = F.embedding(rp_bucket, + self.token_rel_pos_table_list[idx].weight) + values = values.permute([2, 0, 1]) + return values.contiguous() + + def get_image_rel_pos_bias(self, x, idx): + r""" + Get the relative positional bias of the image, for attention. + """ + + seq_len = x.size(1) + image_position_idx = self.image_position_idx[:seq_len] + rp_bucket = self.image_rp_bucket[ + image_position_idx][:, image_position_idx] + values = F.embedding(rp_bucket, + self.image_rel_pos_table_list[idx].weight) + values = values.permute(2, 0, 1) + return values + + def get_pos_info(self, tgt_pos_embed, src_pos_embed=None, use_image=False): + r""" + Get the positional information. + + Args: + tgt_pos_embed (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`): + the target-side positional embeddings. + src_pos_embed (`torch.FloatTensor` of shape `(bsz, src_len, embed_dim)`, *optional*): + the source-side positional embeddings. + use_image (`bool`): whether to use image. + + Returns: + abs_pos_bias (`torch.FloatTensor` of shape `(bsz, src_len, tgt_len, src_len)`): + absolute positional bias for attention. 
+ """ + + batch_size = tgt_pos_embed.size(0) + tgt_len = tgt_pos_embed.size(1) + tgt_pos_embed = self.image_pos_ln( + tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) + + if src_pos_embed is not None: + src_len = src_pos_embed.size(1) + pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( + batch_size, tgt_len, self.num_attention_heads, -1).transpose( + 1, 2) * self.pos_scaling + pos_k = self.cross_pos_k_linear(src_pos_embed).view( + batch_size, src_len, self.num_attention_heads, + -1).transpose(1, 2) + else: + src_len = tgt_pos_embed.size(1) + pos_q = self.self_pos_q_linear(tgt_pos_embed).view( + batch_size, tgt_len, self.num_attention_heads, -1).transpose( + 1, 2) * self.pos_scaling + pos_k = self.self_pos_k_linear(tgt_pos_embed).view( + batch_size, src_len, self.num_attention_heads, + -1).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + + return abs_pos_bias + + def get_input_embeddings(self): + r""" + Get the input embeddings + """ + return self.embed_tokens + + def set_input_embeddings(self, value): + r""" + Set the weights of the embeddings with the given tensor. + """ + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, + dtype, past_key_values_length): + r""" + Create causal mask for unidirectional decoding. + [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + """ + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + dtype, + past_key_values_length=past_key_values_length).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else + expanded_attn_mask + combined_attention_mask) + + return combined_attention_mask + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return self.max_target_positions + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, + sample) + + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + + if hasattr(self, + 'adaptive_softmax') and self.adaptive_softmax is not None: + if sample is not None: + assert 'target' in sample + target = sample['target'] + else: + target = None + out = self.adaptive_softmax.get_log_prob( + net_output[0], target=target) + return out.exp_() if not log_probs else out + + logits = net_output[0] + if log_probs: + return utils.log_softmax(logits, dim=-1) + else: + return utils.softmax(logits, dim=-1) + + def reorder_incremental_state_scripting( + self, + # incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + past_key_values: Optional[torch.Tensor], + new_order: Tensor, + ): + """Main entry point for reordering the incremental state. 
+ + Due to limitations in TorchScript, we call this function in + :class:`fairseq.sequence_generator.SequenceGenerator` instead of + calling :func:`reorder_incremental_state` directly. + """ + input_buffer = past_key_values + new_past_key_values = [] + if input_buffer is not None: + for input_buffer_k in input_buffer: + new_input_buffer_k = [] + for input in input_buffer_k: + if input is None: + input = None + else: + input = input.index_select(0, new_order) + new_input_buffer_k.append(input) + new_past_key_values.append(new_input_buffer_k) + return new_past_key_values + + def forward( + self, + input_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + encoder_hidden_states: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, + code_masks: Optional[torch.Tensor] = None, + src_pos_embed: torch.Tensor = None, + past_key_values: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary. + attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): mask to avoid attention on padding tokens. + encoder_hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last hidden state of the encoder. + encoder_attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): the padding mask of the source side. + code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation. + src_pos_embed (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the positional embeddings of the source side. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of + shape `(bsz, num_heads, src_len, head_size)`. + use_cache (`bool`): whether to use cache for faster inference. + output_attentions (`bool`): whether to output attention weights. + output_hidden_states (`bool`): whether to output hidden states. + + Returns: + BaseModelOutputWithPastAndCrossAttentions or a plain tuple: + last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last hidden states. + past_key_values (`tuple(tuple(torch.FloatTensor)): past keys and values for faster inference. + hidden_states (`tuple(torch.FloatTensor)`): hidden states of all layers. + attentions (`tuple(torch.FloatTensor)): self attention weights of all layers. + cross_attentions (`tuple(torch.FloatTensor)): cross attention weights of all layers. 
+ """ # noqa + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if past_key_values is not None and len(past_key_values) > 0: + size = past_key_values[0][0].size() + bsz, tgt_len = size[0], size[-2] + 1 + token_position_idx = torch.arange( + tgt_len, + device=input_ids.device).expand([bsz, tgt_len]).contiguous() + else: + bsz, tgt_len = input_ids.shape + token_position_idx = new_arange(input_ids) + tgt_pos_embed = self.embed_positions(token_position_idx) + if code_masks is not None and torch.any(code_masks): + image_position_idx = self.image_position_idx[:input_ids.size( + 1)].unsqueeze(0).expand(bsz, tgt_len) + tgt_pos_embed[code_masks] = self.embed_image_positions( + image_position_idx)[code_masks] + + # self attn position bias + self_abs_pos_bias = self.get_pos_info(tgt_pos_embed, use_image=False) + if code_masks is not None and torch.any(code_masks): + self_image_abs_pos_bias = self.get_pos_info( + tgt_pos_embed, use_image=True) + self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks] + # cross attn position bias + cross_abs_pos_bias = self.get_pos_info( + tgt_pos_embed, src_pos_embed=src_pos_embed) + if code_masks is not None and torch.any(code_masks): + cross_image_abs_pos_bias = self.get_pos_info( + tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True) + cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[ + code_masks] + cross_abs_pos_bias = cross_abs_pos_bias.reshape( + -1, + *cross_abs_pos_bias.size()[-2:]) + + all_prev_output_tokens = input_ids.clone() + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :] + tgt_pos_embed = tgt_pos_embed[:, -1:, :] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(input_ids) + + if self.entangle_position_embedding and not self.disable_entangle: + x += tgt_pos_embed + + if self.layernorm_embedding is not None: + if code_masks is None or not code_masks.any( + ) or not self.code_layernorm_embedding: + x = self.layernorm_embedding(x) + elif code_masks is not None and code_masks.all(): + x = self.code_layernorm_embedding(x) + else: + x[~code_masks] = self.layernorm_embedding(x[~code_masks]) + x[code_masks] = self.code_layernorm_embedding(x[code_masks]) + + hidden_states = self.dropout(x) + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None and len( + past_key_values) > 0 else 0 + + shape, dtype = input_ids.shape, hidden_states.dtype + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, shape, dtype, past_key_values_length) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if ( + output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # decoder layers + for idx, layer in enumerate(self.layers): + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + past_key_value = past_key_values[ + idx] if past_key_values is not None and len( + past_key_values) > 0 else None + + self_attn_bias = self_abs_pos_bias.clone() + if code_masks is 
None or not code_masks.any(): + # print("code_masks is None or not code_masks.any()") + self_attn_bias += self.get_rel_pos_bias( + all_prev_output_tokens, idx).unsqueeze(0) + elif code_masks is not None and code_masks.all(): + # print("code_masks is not None and code_masks.all()") + self_attn_bias += self.get_image_rel_pos_bias( + all_prev_output_tokens, idx).unsqueeze(0) + else: + # print("else") + self_attn_bias[~code_masks] += self.get_rel_pos_bias( + all_prev_output_tokens, idx).unsqueeze(0) + self_attn_bias[code_masks] += self.get_image_rel_pos_bias( + all_prev_output_tokens, idx).unsqueeze(0) + self_attn_bias = self_attn_bias.reshape( + -1, + *self_attn_bias.size()[-2:]) + if past_key_value is not None and len(past_key_values) > 0: + self_attn_bias = self_attn_bias[:, -1:, :] + + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + self_attn_bias=self_attn_bias, + cross_attn_bias=cross_abs_pos_bias, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += ( + layer_outputs[3 if output_attentions else 1], ) + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2], ) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = next_decoder_cache if use_cache else None + + if self.layer_norm is not None: + hidden_states = self.layer_norm(hidden_states) + + if self.output_projection is not None: + hidden_states = self.output_projection(hidden_states) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, # (bz, + past_key_values=next_cache, # (bz, n_heads, seq_len, head_dim) + hidden_states=all_hidden_states, + attentions=all_self_attns, # (bz, n_heads, tgt_len, src_len) + cross_attentions= # noqa + all_cross_attentions # (bz, n_heads, tgt_len, src_len) # noqa + ) + + +@add_start_docstrings( + 'The bare OFA Model outputting raw hidden-states without any specific head on top.', + OFA_START_DOCSTRING, +) +class OFAModel(OFAPreTrainedModel): + r""" + The OFA model built with an encoder and a decoder only, without any classification head. + + Args: + config (OFAConfig): OFA configuration. + """ + + def __init__(self, config: OFAConfig, **kwargs): + super().__init__(config) + self.disable_entangle = getattr(kwargs, 'disable_entangle', False) + + self.padding_idx, vocab_size = config.pad_token_id, config.vocab_size + shared = nn.Embedding(vocab_size, config.d_model, self.padding_idx) + + self.encoder = OFAEncoder(config, shared) + self.decoder = OFADecoder(config, shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + r""" + Retrieve input embeddings. 
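+ The embeddings are the `nn.Embedding` table shared by the encoder and the decoder.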
+ """ + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, value): + r""" + Set values for input embeddings + """ + shared = value + self.encoder.embed_tokens = shared + self.decoder.embed_tokens = shared + + def get_encoder(self): + r""" + Retrieve the encoder + """ + return self.encoder + + def get_decoder(self): + r""" + Retrieve the decoder + """ + return self.decoder + + @add_start_docstrings_to_model_forward(OFA_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + # 新增函数以适配fairseq的generator + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, + sample) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel""" + if hasattr(self, 'decoder'): + return self.decoder.get_normalized_probs(net_output, log_probs, + sample) + elif torch.is_tensor(net_output): + # syntactic sugar for simple models which don't have a decoder + # (e.g., the classification tutorial) + logits = net_output.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + def forward(self, + input_ids=None, + patch_images=None, + patch_images_2=None, + patch_masks=None, + token_embeddings=None, + sample_patch_num=None, + decoder_input_ids=None, + code_masks=None, + attention_mask=None, + encoder_outputs=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=False): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): + indices of input sequence tokens in the vocabular, and padding will be ignored by default; + + indices can be obtained using [`~OFATokenizer`]. + + patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the resized image, which are transformed by the default operations. + patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the second (if it exists) image. + patch_masks (`torch.BoolTensor`): the patches to be masked. + token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings. + sample_patch_num (`int`): the number of patches to sample. + decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary. + code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation. + attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding. + encoder_outputs (`OFAEncoderOutput`): + encoder outputs with hidden states, positional embeddings, and padding masks. 
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of + shape `(bsz, num_heads, src_len, head_size)`. + use_cache (`bool`): whether to use cache for faster inference. + output_attentions (`bool`): whether to output attention weights. + output_hidden_states (`bool`): whether to output hidden states. + return_dict (`bool`): unused. Keep it for generation only. + + Returns: + Seq2SeqModelOutput: + last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last decoder hidden states. + past_key_values (`tuple(tuple(torch.FloatTensor)): past keys and values for faster inference. + decoder_hidden_states (`tuple(torch.FloatTensor)`): the decoder hidden states of all layers. + decoder_attentions (`tuple(torch.FloatTensor)): the decoder self attention weights of all layers. + cross_attentions (`tuple(torch.FloatTensor)): cross attention weights of all layers. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + the encoder last hidden state. + encoder_hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + the encoder states of all layers including the embeddings. + encoder_attentions (`torch.FloatTensor` of shape `(bsz, num_heads, seq_len, seq_len)`): + the encoder attention weights of all layers. + """ # noqa + + output_attentions = output_attentions if output_attentions else self.config.output_attentions + output_hidden_states = ( + output_hidden_states + if output_hidden_states else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + patch_images=patch_images, + patch_images_2=patch_images_2, + patch_masks=patch_masks, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + token_embeddings=token_embeddings, + sample_patch_num=sample_patch_num, + ) + + if decoder_input_ids.eq(self.config.pad_token_id).any(): + attention_mask = decoder_input_ids.eq(self.padding_idx) + + encoder_hidden_states = encoder_outputs.last_hidden_state + encoder_attention_mask = _expand_mask(encoder_outputs.padding_mask, + encoder_hidden_states.dtype, + decoder_input_ids.shape[-1]) + src_pos_embed = encoder_outputs.position_embedding + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + code_masks=code_masks, + src_pos_embed=src_pos_embed, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + return Seq2SeqLMOutput( + logits=decoder_outputs.last_hidden_state, + # last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation(self, + decoder_input_ids=None, + past=None, + attention_mask=None, + 
code_masks=None, + use_cache=False, + encoder_outputs=None, + **kwargs): + # if attention_mask is None: + attention_mask = decoder_input_ids.new_zeros(decoder_input_ids.shape) + + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + 'input_ids': None, + 'patch_images': None, + 'patch_images_2': None, + 'patch_masks': None, + 'token_embeddings': None, + 'sample_patch_num': None, + 'attention_mask': attention_mask, + 'encoder_outputs': encoder_outputs, + 'past_key_values': past, + 'decoder_input_ids': decoder_input_ids, + 'code_masks': code_masks, + 'use_cache': use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, + self.config.decoder_start_token_id) + + def _prepare_encoder_decoder_kwargs_for_generation( + self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str] = None): + # 1. get encoder + encoder = self.get_encoder() + + # 2. prepare encoder args and encoder kwargs from model kwargs + irrelevant_prefix = [ + 'decoder_', 'cross_attn', 'use_cache', 'attention_mask' + ] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + + if encoder_kwargs.get('patch_masks') is None: + encoder_kwargs['patch_masks'] = torch.tensor([True]) + + # 3. make sure that encoder returns `ModelOutput` + model_input_name = model_input_name if model_input_name is not None else self.main_input_name + encoder_kwargs[model_input_name] = inputs_tensor + model_kwargs['encoder_outputs']: ModelOutput = encoder( + **encoder_kwargs) + model_kwargs['attention_mask'] = None + + return model_kwargs + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + @staticmethod + def _expand_inputs_for_generation( + input_ids: torch.LongTensor, + expand_size: int = 1, + is_encoder_decoder: bool = False, + attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[ModelOutput] = None, + **model_kwargs, + ): + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat( + 1, expand_size).view(-1).to(input_ids.device)) + input_ids = input_ids.index_select(0, expanded_return_idx) + + if 'token_type_ids' in model_kwargs: + token_type_ids = model_kwargs['token_type_ids'] + model_kwargs['token_type_ids'] = token_type_ids.index_select( + 0, expanded_return_idx) + + if attention_mask is not None: + model_kwargs['attention_mask'] = attention_mask.index_select( + 0, expanded_return_idx) + + if is_encoder_decoder: + if encoder_outputs is None: + raise ValueError( + 'If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.' 
+ )
+ encoder_outputs[
+ 'last_hidden_state'] = encoder_outputs.last_hidden_state.index_select(
+ 0,
+ expanded_return_idx.to(
+ encoder_outputs.last_hidden_state.device))
+ encoder_outputs[
+ 'position_embedding'] = encoder_outputs.position_embedding.index_select(
+ 0,
+ expanded_return_idx.to(
+ encoder_outputs.position_embedding.device))
+ encoder_outputs[
+ 'padding_mask'] = encoder_outputs.padding_mask.index_select(
+ 0,
+ expanded_return_idx.to(
+ encoder_outputs.padding_mask.device))
+ model_kwargs['encoder_outputs'] = encoder_outputs
+ return input_ids, model_kwargs
diff --git a/modelscope/models/multi_modal/ofa/resnet.py b/modelscope/models/multi_modal/ofa/resnet.py
new file mode 100644
index 00000000..de6444ab
--- /dev/null
+++ b/modelscope/models/multi_modal/ofa/resnet.py
@@ -0,0 +1,283 @@
+import torch
+import torch.nn as nn
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0], ) + (1, ) * (
+ x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(
+ shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
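+ During training, each sample's residual branch is dropped (zeroed) with probability
+ `drop_prob`, and the surviving samples are rescaled by `1 / (1 - drop_prob)`; at inference
+ time this module is an identity.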
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + 'Dilation > 1 not supported in BasicBlock') + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + assert False + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
+ + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + drop_path_rate=0.0): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = identity + self.drop_path(out) + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, + layers, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + drop_path_rate=0.0): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError('replace_stride_with_dilation should be None ' + 'or a 3-element tuple, got {}'.format( + replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer( + Bottleneck, 64, layers[0], drop_path_rate=drop_path_rate) + self.layer2 = self._make_layer( + Bottleneck, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0], + drop_path_rate=drop_path_rate) + self.layer3 = self._make_layer( + Bottleneck, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1], + drop_path_rate=drop_path_rate) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, + (nn.SyncBatchNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, + block, + planes, + blocks, + stride=1, + dilate=False, + drop_path_rate=0.0): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, blocks)] + for i in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + drop_path_rate=dpr[i])) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + return x + + def forward(self, x): + return self._forward_impl(x) diff --git a/modelscope/models/multi_modal/ofa/tokenization_ofa.py b/modelscope/models/multi_modal/ofa/tokenization_ofa.py new file mode 100644 index 00000000..e40436b6 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/tokenization_ofa.py @@ -0,0 +1,48 @@ +# Copyright 2022 OFA-Sys Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OFA.""" +from transformers.models.bart.tokenization_bart import BartTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {'vocab_file': 'vocab.json', 'merges_file': 'merges.txt'} + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': { + 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/vocab.json', + }, + 'merges_file': { + 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt', + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'ofa-base': 1024, +} + + +class OFATokenizer(BartTokenizer): + """ + Construct a OFA tokenizer. + + [`~OFATokenizer`] is identical to [`BartTokenizer`] and runs end-to-end tokenization: punctuation splitting and + wordpiece. + + Refer to superclass [`BartTokenizer`] for usage examples and documentation concerning parameters. 
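+
+ Example (a minimal sketch; assumes a checkpoint such as `ofa-base` that provides `vocab.json` and `merges.txt`):
+
+ >>> tokenizer = OFATokenizer.from_pretrained('ofa-base')
+ >>> input_ids = tokenizer([' what does the image describe?'], return_tensors='pt')['input_ids']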
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES diff --git a/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py b/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py new file mode 100644 index 00000000..235d1b34 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py @@ -0,0 +1,59 @@ +# Copyright 2022 OFA-Sys Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OFA.""" +from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast +from transformers.utils import logging + +from .tokenization_ofa import OFATokenizer + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + 'vocab_file': 'vocab.json', + 'merges_file': 'merges.txt', + 'tokenizer_file': 'tokenizer.json' +} + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': { + 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/vocab.json', + }, + 'merges_file': { + 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt', + }, + 'tokenizer_file': { + 'ofa-base': + 'https://huggingface.co/ofa-base/resolve/main/tokenizer.json', + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'ofa-base': 1024, +} + + +class OFATokenizerFast(BartTokenizerFast): + r""" + Construct a "fast" OFA tokenizer (backed by HuggingFace's *tokenizers* library). + + [`~OFATokenizerFast`] is identical to [`BartTokenizerFast`] and runs end-to-end tokenization: punctuation splitting + and wordpiece. + + Refer to superclass [`BartTokenizerFast`] for usage examples and documentation concerning parameters. 
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + slow_tokenizer_class = OFATokenizer diff --git a/modelscope/models/multi_modal/ofa_for_image_captioning_model.py b/modelscope/models/multi_modal/ofa_for_image_captioning_model.py new file mode 100644 index 00000000..d560852c --- /dev/null +++ b/modelscope/models/multi_modal/ofa_for_image_captioning_model.py @@ -0,0 +1,53 @@ +from typing import Any, Dict + +import torch.cuda + +from modelscope.metainfo import Models +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks +from ..base import Model +from ..builder import MODELS +from .ofa import OFAModel, OFATokenizer +from .ofa.generate import sequence_generator as sg +from .ofa.generate.utils import move_to_device + +__all__ = ['OfaForImageCaptioning'] + + +@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) +class OfaForImageCaptioning(Model): + + def __init__(self, model_dir, *args, **kwargs): + super().__init__(model_dir=model_dir, *args, **kwargs) + model = OFAModel.from_pretrained(model_dir) + + self.model = model.module if hasattr(model, 'module') else model + self.tokenizer = OFATokenizer.from_pretrained(model_dir) + self.tokenizer.add_tokens([''.format(i) for i in range(8192)]) + self.tokenizer.add_tokens([''.format(i) for i in range(1000)]) + self._device = torch.device('cuda') if torch.cuda.is_available() \ + else torch.device('cpu') + self.model.to(self._device) + # Initialize generator + sg_args = { + 'tokenizer': self.tokenizer, + 'beam_size': 5, + 'max_len_b': 16, + 'min_len': 1, + 'no_repeat_ngram_size': 3, + 'constraint_range': None + } + if hasattr(kwargs, 'beam_search'): + sg_args.update(kwargs['beam_search']) + self.generator = sg.SequenceGenerator(**sg_args) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + input = move_to_device(input, self._device) + gen_output = self.generator.generate([self.model], input) + gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] + result = self.tokenizer.batch_decode(gen, skip_special_tokens=True) + return {'image_id': '42', OutputKeys.CAPTION: result[0]} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + # What should we do here ? 
+ return inputs
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index c47c6744..702700eb 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -44,7 +44,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
 'damo/nlp_space_dialog-modeling'),
 Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
 'damo/nlp_space_dialog-state-tracking'),
- Tasks.image_captioning: (Pipelines.image_caption,
+ Tasks.image_captioning: (Pipelines.image_captioning,
 'damo/ofa_image-caption_coco_large_en'),
 Tasks.image_generation:
 (Pipelines.person_image_cartoon,
diff --git a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
index 039f61dd..62226ff5 100644
--- a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
+++ b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
@@ -11,7 +11,7 @@ logger = get_logger()
 @PIPELINES.register_module(
- Tasks.image_captioning, module_name=Pipelines.image_caption)
+ Tasks.image_captioning, module_name=Pipelines.image_captioning)
 class ImageCaptionPipeline(Pipeline):
 def __init__(self,
diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py
index 1bc686eb..b5dc0cf4 100644
--- a/modelscope/preprocessors/multi_modal.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -2,13 +2,14 @@ import os.path as osp
 from typing import Any, Dict, Union
-import numpy as np
 import torch
 from PIL import Image
+from torchvision import transforms
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Preprocessors
-from modelscope.utils.constant import Fields, ModelFile
+from modelscope.models.multi_modal.ofa import OFATokenizer
+from modelscope.utils.constant import Fields
 from modelscope.utils.type_assert import type_assert
 from .base import Preprocessor
 from .builder import PREPROCESSORS
@@ -31,84 +32,39 @@ class OfaImageCaptionPreprocessor(Preprocessor):
 model_dir (str): model path
 """
 super().__init__(*args, **kwargs)
+ model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
+ model_dir)
+ self.tokenizer = OFATokenizer.from_pretrained(model_dir)
+ self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
+ self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
-
- if osp.exists(model_dir):
- local_model_dir = model_dir
- else:
- local_model_dir = snapshot_download(model_dir)
- local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE)
- bpe_dir = local_model_dir
-
- from fairseq import checkpoint_utils, tasks, utils
- from ofa.tasks.mm_tasks import CaptionTask
-
- tasks.register_task('caption', CaptionTask)
-
- overrides = {
- 'bpe_dir': bpe_dir,
- 'eval_cider': False,
- 'beam': 5,
- 'max_len_b': 16,
- 'no_repeat_ngram_size': 3,
- 'seed': 7
- }
- model, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
- utils.split_paths(local_model), arg_overrides=overrides)
- del model
 # Initialize transform
- from torchvision import transforms
 mean = [0.5, 0.5, 0.5]
 std = [0.5, 0.5, 0.5]
-
+ patch_image_size = 480
 self.patch_resize_transform = transforms.Compose([
 lambda image: image.convert('RGB'),
- transforms.Resize(
- (cfg.task.patch_image_size, cfg.task.patch_image_size),
- interpolation=Image.BICUBIC),
+ transforms.Resize((patch_image_size, patch_image_size),
+ interpolation=Image.BICUBIC),
 transforms.ToTensor(),
 transforms.Normalize(mean=mean, std=std),
 ])
- self.task = task
- self.bos_item = torch.LongTensor([task.src_dict.bos()])
-
self.eos_item = torch.LongTensor([task.src_dict.eos()]) - self.pad_idx = task.src_dict.pad() - @type_assert(object, (str, tuple, Image.Image)) def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]: - - def encode_text(text, length=None, append_bos=False, append_eos=False): - s = self.task.tgt_dict.encode_line( - line=self.task.bpe.encode(text), - add_if_not_exist=False, - append_eos=False).long() - if length is not None: - s = s[:length] - if append_bos: - s = torch.cat([self.bos_item, s]) - if append_eos: - s = torch.cat([s, self.eos_item]) - return s - if isinstance(data, Image.Image): patch_image = self.patch_resize_transform(data).unsqueeze(0) else: patch_image = self.patch_resize_transform( load_image(data)).unsqueeze(0) - patch_mask = torch.tensor([True]) - text = 'what does the image describe?' - src_text = encode_text( - text, append_bos=True, append_eos=True).unsqueeze(0) - src_length = torch.LongTensor( - [s.ne(self.pad_idx).long().sum() for s in src_text]) - sample = { - 'id': np.array(['42']), - 'net_input': { - 'src_tokens': src_text, - 'src_lengths': src_length, - 'patch_images': patch_image, - 'patch_masks': patch_mask, - } + text = ' what does the image describe?' + inputs = self.tokenizer([text], max_length=1024, + return_tensors='pt')['input_ids'] + sample = dict() + sample['net_input'] = { + 'input_ids': inputs, + 'patch_images': patch_image, + 'patch_masks': torch.tensor([True]) } return sample diff --git a/tests/pipelines/test_image_captioning.py b/tests/pipelines/test_image_captioning.py index 6bede92c..fc029146 100644 --- a/tests/pipelines/test_image_captioning.py +++ b/tests/pipelines/test_image_captioning.py @@ -14,7 +14,7 @@ class ImageCaptionTest(unittest.TestCase): def test_run(self): img_captioning = pipeline( Tasks.image_captioning, - model='damo/ofa_image-caption_coco_large_en') + model='damo/ofa_image-caption_coco_distilled_en') result = img_captioning('data/test/images/image_captioning.png') print(result[OutputKeys.CAPTION])
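# Usage sketch for the new image-captioning pipeline. The model id and the test image path follow
# the updated unit test above; the import locations of `pipeline`, `Tasks` and `OutputKeys` are
# assumed to be the ones used by the modules touched in this patch.
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# The preprocessor tokenizes the fixed prompt and resizes the image to 480x480;
# the model then decodes a caption with the OFA beam-search sequence generator.
img_captioning = pipeline(
    Tasks.image_captioning, model='damo/ofa_image-caption_coco_distilled_en')
result = img_captioning('data/test/images/image_captioning.png')
print(result[OutputKeys.CAPTION])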