diff --git a/modelscope/models/multi_modal/mplug/__init__.py b/modelscope/models/multi_modal/mplug/__init__.py new file mode 100644 index 00000000..bca5849b --- /dev/null +++ b/modelscope/models/multi_modal/mplug/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_mplug import MPlugConfig +from .modeling_mplug import (CONFIG_NAME, VOCAB_NAME, + MPlugForVisualQuestionAnswering) diff --git a/modelscope/models/multi_modal/mplug/clip/__init__.py b/modelscope/models/multi_modal/mplug/clip/__init__.py new file mode 100644 index 00000000..05826f46 --- /dev/null +++ b/modelscope/models/multi_modal/mplug/clip/__init__.py @@ -0,0 +1 @@ +from .clip import load_from_config diff --git a/modelscope/models/multi_modal/mplug/clip/clip.py b/modelscope/models/multi_modal/mplug/clip/clip.py new file mode 100644 index 00000000..fbdfbd29 --- /dev/null +++ b/modelscope/models/multi_modal/mplug/clip/clip.py @@ -0,0 +1,401 @@ +# Copyright 2021 The OpenAI CLIP Authors. All rights reserved. + +from collections import OrderedDict +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.models.multi_modal.clip.clip_vit import Transformer + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, + 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + if self.training: + dropout = 0.1 + else: + dropout = 0.0 + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=dropout, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, skip_last_layer=False): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + if not skip_last_layer: + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x) + return ret.type(orig_type) + + +class VisualTransformer(nn.Module): + + def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.heads = heads + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, + x: torch.Tensor, + skip_last_layer=False, + text_embedding=None, + text_mask=None): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + cls_emb = self.class_embedding.to(x.dtype) + x_zeros = torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([cls_emb + x_zeros, x], + dim=1) # shape = [*, grid ** 2 + 1, width] + + x = x + self.positional_embedding.to(x.dtype)[:x.size(1), :] + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + + x = self.transformer(x) + + x = x.permute(1, 0, 2) # LND -> NLD + + if skip_last_layer: + x = self.ln_post(x) + # x = x @ self.proj + else: + x = x @ self.proj + return x + + +class CLIP(nn.Module): + + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([])) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [ + self.visual.layer1, self.visual.layer2, self.visual.layer3, + self.visual.layer4 + ]: + for name, param in resnet_block.named_parameters(): + if name.endswith('bn3.weight'): + nn.init.zeros_(param) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type( + self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), + text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def load_from_config(config): + return CLIP(config.clip_embed_dim, config.clip_image_resolution, + config.clip_vision_layers, config.clip_vision_width, + config.clip_vision_patch_size, config.clip_context_length, + config.clip_vocab_size, config.clip_transformer_width, + config.clip_transformer_heads, config.clip_transformer_layers) diff --git a/modelscope/models/multi_modal/mplug/configuration_mplug.py b/modelscope/models/multi_modal/mplug/configuration_mplug.py new file mode 100644 index 00000000..6b2914c4 --- /dev/null +++ b/modelscope/models/multi_modal/mplug/configuration_mplug.py @@ -0,0 +1,125 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MPLUG model configuration """ +import os +from collections import OrderedDict +from typing import Any, Dict, Mapping, Union + +import yaml +from transformers import PretrainedConfig +from transformers.onnx import OnnxConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class MPlugConfig(PretrainedConfig): + + model_type = 'mplug' + + def __init__( + self, + bert_config='config_bert.json', + image_res=504, + batch_size_train=128, + vision_width=1024, + distill=True, + clip_name='ViT-L-14', # ViT-B-16 | ViT-L-14 + batch_size_test=64, + k_test=128, + alpha=0.4, + warm_up=True, + eos='[SEP]', + optimizer=None, + schedular=None, + min_length=1, + max_length=10, + beam_size=5, + add_ocr=False, + add_object=False, + text_encoder='bert-base-uncased', + text_decoder='bert-base-uncased', + # clip + clip_embed_dim=768, + clip_image_resolution=224, + clip_vision_layers=24, + clip_vision_width=1024, + clip_vision_patch_size=14, + clip_context_length=77, + clip_vocab_size=49408, + clip_transformer_width=768, + clip_transformer_heads=12, + clip_transformer_layers=12, + **kwargs): + super().__init__(**kwargs) + self.bert_config = bert_config + self.image_res = image_res + self.batch_size_train = batch_size_train + self.vision_width = vision_width + self.distill = distill + self.clip_name = clip_name + self.batch_size_test = batch_size_test + self.k_test = k_test + self.alpha = alpha + self.warm_up = warm_up + self.eos = eos + self.optimizer = optimizer + self.schedular = schedular + self.min_length = min_length + self.max_length = max_length + self.beam_size = beam_size + self.add_ocr = add_ocr + self.add_object = add_object + self.text_encoder = text_encoder + self.text_decoder = text_decoder + # clip + self.clip_embed_dim = clip_embed_dim + self.clip_image_resolution = clip_image_resolution + self.clip_vision_layers = clip_vision_layers + self.clip_vision_width = clip_vision_width + self.clip_vision_patch_size = clip_vision_patch_size + self.clip_context_length = clip_context_length + self.clip_vocab_size = clip_vocab_size + self.clip_transformer_width = clip_transformer_width + self.clip_transformer_heads = clip_transformer_heads + self.clip_transformer_layers = clip_transformer_layers + + @classmethod + def from_yaml_file(cls, yaml_file: Union[str, + os.PathLike]) -> Dict[str, Any]: + with open(yaml_file, 'r') as reader: + config_dict = yaml.load(reader, Loader=yaml.Loader) + return cls(**config_dict) + + +class MPlugOnnxConfig(OnnxConfig): + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([ + ('input_ids', { + 0: 'batch', + 1: 'sequence' + }), + ('attention_mask', { + 0: 'batch', + 1: 'sequence' + }), + ('token_type_ids', { + 0: 'batch', + 1: 'sequence' + }), + ]) diff --git a/modelscope/models/multi_modal/mplug/modeling_mplug.py b/modelscope/models/multi_modal/mplug/modeling_mplug.py new file mode 100755 index 00000000..0b45ea12 --- /dev/null +++ b/modelscope/models/multi_modal/mplug/modeling_mplug.py @@ -0,0 +1,2079 @@ +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MPLUG model. """ + +import math +import os +from typing import Tuple + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from torch import Tensor, device, nn +from torch.nn import CrossEntropyLoss +from transformers import BertConfig, BertTokenizer +from transformers.activations import ACT2FN +from transformers.file_utils import (add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.utils import logging + +from modelscope.models.multi_modal.mplug.configuration_mplug import MPlugConfig +from modelscope.models.multi_modal.mplug.predictor import TextGenerator + +transformers.logging.set_verbosity_error() + +logger = logging.get_logger(__name__) + +CONFIG_NAME = 'config.yaml' +WEIGHTS_NAME = 'pytorch_model.bin' +VOCAB_NAME = 'vocab.txt' + +_CONFIG_FOR_DOC = 'BertConfig' +_TOKENIZER_FOR_DOC = 'BertTokenizer' + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' + 'https://www.tensorflow.org/install/ for installation instructions.' + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info('Converting TensorFlow checkpoint from {}'.format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info('Loading TF weight {} with shape {}'.format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in [ + 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer', + 'AdamWeightDecayOptimizer_1', 'global_step' + ] for n in name): + logger.info('Skipping {}'.format('/'.join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + scope_names = re.split(r'_(\d+)', m_name) + else: + scope_names = [m_name] + if scope_names[0] == 'kernel' or scope_names[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif scope_names[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'squad': + pointer = getattr(pointer, 'classifier') + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info('Skipping {}'.format('/'.join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info('Initialize PyTorch weight {}'.format(name)) + pointer.data = torch.from_numpy(array) + return model + + +def clamp_inf(tensor): + if tensor.dtype == torch.float16 and torch.isinf(tensor).any(): + clamp_value = torch.finfo(tensor.dtype).max - 1000 + tensor = torch.clamp(tensor, min=-clamp_value, max=clamp_value) + return tensor + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, + past_key_values_length:seq_length + + past_key_values_length] + + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, 'embedding_size'): + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = clamp_inf(attention_scores) + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = clamp_inf(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = clamp_inf(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class FusionLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.stride_layer = getattr(self.config, 'stride_layer', 100) + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + + self.crossattention = BertAttention(config, is_cross_attention=True) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + layer_nums=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + if layer_nums == 0 or layer_nums % self.stride_layer != 0: + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + assert encoder_hidden_states is not None, 'encoder_hidden_states must be given for cross-attention layers' + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + elif layer_nums != 0 and layer_nums % self.stride_layer == 0: + self_attention_outputs = self.attention( + torch.cat([encoder_hidden_states, hidden_states], 1), + torch.cat([encoder_attention_mask, attention_mask], 3), + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value[0], present_key_value[1]) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + + self.has_cross_attention = getattr(self.config, 'add_cross_attention', + False) + if self.has_cross_attention: + self.crossattention = BertAttention( + config, is_cross_attention=True) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if self.has_cross_attention: + assert encoder_hidden_states is not None, 'encoder_hidden_states must be given for cross-attention layers' + + if type(encoder_hidden_states) == list: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states[(self.layer_num + - self.config.fusion_layer) + % len(encoder_hidden_states)], + encoder_attention_mask[(self.layer_num + - self.config.fusion_layer) + % len(encoder_hidden_states)], + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] + + else: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1: + -1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value[0], present_key_value[1]) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class FusionEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [FusionLayer(config, i) for i in range(config.num_hidden_layers)]) + self.start_layer = max(0, + config.num_hidden_layers - config.fusion_layers) + + def forward(self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + next_decoder_cache = () if use_cache else None + + self.stride_layer = getattr(self.config, 'stride_layer', 100) + image_length = encoder_hidden_states.shape[1] + text_length = hidden_states.shape[1] + + for i in range(self.start_layer, len(self.layer)): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if getattr(self.config, 'gradient_checkpointing', + False) and self.training: + if use_cache: + logger.warn( + '`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting ' + '`use_cache=False`...') + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return tuple( + module(*inputs, past_key_value, output_attentions)) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + i - self.start_layer, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + i - self.start_layer, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + if hidden_states.shape[1] == (image_length + text_length): + encoder_hidden_states_new, hidden_states = torch.split( + hidden_states, (image_length, text_length), 1) + encoder_hidden_states += encoder_hidden_states_new + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + return [encoder_hidden_states, hidden_states] + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + + def forward(self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(len(self.layer)): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if getattr(self.config, 'gradient_checkpointing', + False) and self.training: + if use_cache: + logger.warn( + '`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting ' + '`use_cache=False`...') + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return tuple( + module(*inputs, past_key_value, output_attentions)) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), hidden_states, + attention_mask, layer_head_mask, encoder_hidden_states, + encoder_attention_mask) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = 'bert' + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +BERT_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + Parameters: + config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + Indices can be obtained using :class:`~transformers.BertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare Bert Model transformer outputting raw hidden-states without any specific head on top.', + BERT_START_DOCSTRING, +) +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint='bert-base-uncased', + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def get_extended_attention_mask(self, attention_mask: Tensor, + input_shape: Tuple[int], device: device, + is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to + # [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat( + batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, + None, :, :] * attention_mask[:, + None, + None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})' + .format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states + (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values + (:obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape + :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds or encoder_embeds' + ) + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( + ) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class FusionModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + self.encoder = FusionEncoder(config) + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + # tokenizer_class=_TOKENIZER_FOR_DOC, + processor_class=_TOKENIZER_FOR_DOC, + checkpoint='bert-base-uncased', + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def get_extended_attention_mask(self, attention_mask: Tensor, + input_shape: Tuple[int], device: device, + is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to + # [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat( + batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, + None, :, :] * attention_mask[:, + None, + None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})' + .format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False): + r""" + encoder_hidden_states + (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values + (:obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape + :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds or encoder_embeds' + ) + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( + ) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + encoder_hidden_states, sequence_output = encoder_outputs + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return [encoder_hidden_states, sequence_output] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, + BERT_START_DOCSTRING) +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @replace_return_docstrings( + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=True, + reduction='mean', + soft_labels=None, + alpha=0, + return_logits=False, + ): + r""" + encoder_hidden_states + (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values + (:obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape + :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if soft_labels is not None: + loss_distill = -torch.sum( + F.log_softmax(shifted_prediction_scores, dim=1) * soft_labels, + dim=-1) + loss_distill = (loss_distill * (labels != -100)).sum(1) + lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + +class MPlugForVisualQuestionAnswering(PreTrainedModel): + config_class = MPlugConfig + + def __init__(self, config): + super().__init__(config) + self.config = config + self.tokenizer = BertTokenizer.from_pretrained( + os.path.join(config.model_dir, VOCAB_NAME)) + self.module_setting(config) + self.visual_encoder = self._initialize_clip(config) + self.text_encoder = BertModel( + self.config_encoder, add_pooling_layer=False) + self.fusion_encoder = FusionModel( + self.config_fusion, add_pooling_layer=False) + self.text_decoder = BertLMHeadModel(self.config_decoder) + self.init_distill(config) + self.beam_generator = TextGenerator(config, self.text_decoder) + + @classmethod + def from_pretrained(cls, model_dir, load_checkpoint=True): + config = MPlugConfig.from_yaml_file( + os.path.join(model_dir, CONFIG_NAME)) + config.model_dir = model_dir + model = cls(config) + if load_checkpoint: + checkpoint_path = os.path.join(model_dir, WEIGHTS_NAME) + checkpoint = torch.load(checkpoint_path, map_location='cpu') + if 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint['module'] + + msg = model.load_state_dict(state_dict, strict=False) + print('load checkpoint from %s' % checkpoint_path) + print(msg) + return model + + @staticmethod + def _initialize_clip(config, num_patches=240): + + def resize_pos_embed(posemb, posemb_new): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + ntok_new = posemb_new.shape[1] + + posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + ntok_new -= 1 + + gs_old = int(math.sqrt(len(posemb_grid))) + gs_new = int(math.sqrt(ntok_new)) + # _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, + -1).permute(0, 3, 1, 2) + orig = posemb_grid.dtype + posemb_grid = F.interpolate( + posemb_grid.float(), size=(gs_new, gs_new), mode='bilinear') + posemb_grid = posemb_grid.to(orig) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape( + 1, gs_new * gs_new, -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + from .clip import clip + clip_model = clip.load_from_config(config) + if 'ViT-B-16' in config.clip_name: + num_patches = int(config.image_res * config.image_res / (16 * 16)) + pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768).float()) + else: + num_patches = int(config.image_res * config.image_res / (14 * 14)) + pos_embed = nn.Parameter( + torch.zeros(num_patches + 1, 1024).float()) + pos_embed.weight = resize_pos_embed( + clip_model.visual.positional_embedding.unsqueeze(0), + pos_embed.unsqueeze(0)) + clip_model.visual.positional_embedding = pos_embed + return clip_model + + def forward(self, + image, + question, + answer=None, + alpha=0, + k=None, + weights=None, + train=True): + image = image.to(dtype=next(self.parameters()).dtype) + image_embeds = self.visual_encoder.visual(image, skip_last_layer=True) + if self.large: + image_embeds = self.dropout( + self.visn_layer_norm(self.visn_fc(image_embeds))) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + if train: + ''' + k: number of answers for each question + weights: weight for each answer + ''' + answer_targets = answer.input_ids.masked_fill( + answer.input_ids == self.tokenizer.pad_token_id, -100) + text_output = self.text_encoder( + question.input_ids, + attention_mask=question.attention_mask, + return_dict=True) + text_embeds = text_output.last_hidden_state + fusion_output = self.fusion_encoder( + encoder_embeds=text_embeds, + attention_mask=question.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False) + + image_output, question_output = fusion_output + + question_output = torch.cat([image_output, question_output], 1) + merge_text_attention = torch.cat( + [image_atts, question.attention_mask], 1) + + question_states = [] + question_atts = [] + for b, n in enumerate(k): + question_states += [question_output[b]] * n + question_atts += [merge_text_attention[b]] * n + question_states = torch.stack(question_states, 0) + question_atts = torch.stack(question_atts, 0) + + if self.distill: + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m.visual( + image, skip_last_layer=True) + if self.large: + image_embeds_m = self.dropout_m( + self.visn_layer_norm_m( + self.visn_fc_m(image_embeds_m))) + text_output_m = self.text_encoder_m( + question.input_ids, + attention_mask=question.attention_mask, + return_dict=True) + text_embeds_m = text_output_m.last_hidden_state + fusion_output_m = self.fusion_encoder_m( + encoder_embeds=text_embeds_m, + attention_mask=question.attention_mask, + encoder_hidden_states=image_embeds_m, + encoder_attention_mask=image_atts, + return_dict=False) + + image_output_m, question_output_m = fusion_output_m + question_output_m = torch.cat( + [image_output_m, question_output_m], 1) + + question_states_m = [] + for b, n in enumerate(k): + question_states_m += [question_output_m[b]] * n + question_states_m = torch.stack(question_states_m, 0) + + logits_m = self.text_decoder_m( + answer.input_ids, + attention_mask=answer.attention_mask, + encoder_hidden_states=question_states_m, + encoder_attention_mask=question_atts, + return_logits=True, + ) + + answer_output = self.text_decoder( + answer.input_ids, + attention_mask=answer.attention_mask, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=answer_targets, + return_dict=True, + soft_labels=F.softmax(logits_m, dim=-1), + reduction='none', + ) + else: + answer_output = self.text_decoder( + answer.input_ids, + attention_mask=answer.attention_mask, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=answer_targets, + return_dict=True, + reduction='none', + ) + loss = weights * answer_output.loss + loss = loss.sum() / image.size(0) + + return loss + + else: + text_output = self.text_encoder( + question.input_ids, + attention_mask=question.attention_mask, + return_dict=True) + text_embeds = text_output.last_hidden_state + fusion_output = self.fusion_encoder( + encoder_embeds=text_embeds, + attention_mask=question.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False) + image_output, question_output = fusion_output + question_output = torch.cat([image_output, question_output], 1) + merge_text_attention = torch.cat( + [image_atts, question.attention_mask], 1) + topk_ids, topk_probs = self.generation(question_output, + merge_text_attention) + return topk_ids, topk_probs + + def module_setting(self, config): + bert_config_path = os.path.join(config.model_dir, config.bert_config) + self.config_encoder = BertConfig.from_json_file(bert_config_path) + self.config_encoder.num_hidden_layers = self.config_encoder.text_encoder_layers + self.config_fusion = BertConfig.from_json_file(bert_config_path) + self.config_decoder = BertConfig.from_json_file(bert_config_path) + self.config_decoder.add_cross_attention = True + self.config_decoder.num_hidden_layers = self.config_decoder.text_decode_layers + self.large = False + if self.config_encoder.hidden_size != config.vision_width: + self.visn_fc = nn.Linear(config.vision_width, + self.config_encoder.hidden_size) + self.visn_layer_norm = nn.LayerNorm( + self.config_encoder.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(self.config_encoder.hidden_dropout_prob) + self.large = True + + def init_distill(self, config): + self.distill = config.distill + if self.distill: + self.visual_encoder_m = self._initialize_clip(config) + self.text_encoder_m = BertModel( + self.config_encoder, add_pooling_layer=False) + self.fusion_encoder_m = FusionModel( + self.config_fusion, add_pooling_layer=False) + self.text_decoder_m = BertLMHeadModel(self.config_decoder) + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.text_decoder, self.text_decoder_m], + ] + if self.config_encoder.hidden_size != config.vision_width: + self.visn_fc_m = nn.Linear(config.vision_width, + self.config_encoder.hidden_size) + self.visn_layer_norm_m = nn.LayerNorm( + self.config_encoder.hidden_size, eps=1e-12) + self.dropout_m = nn.Dropout( + self.config_encoder.hidden_dropout_prob) + self.model_pairs.extend( + [[self.visn_fc, self.visn_fc_m], + [self.visn_layer_norm, self.visn_layer_norm_m]]) + self.copy_params() + self.momentum = 0.995 + + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), + model_pair[1].parameters()): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), + model_pair[1].parameters()): + param_m.data = param_m.data * self.momentum + param.data * ( + 1. - self.momentum) + + def generation(self, question_states, question_atts): + encoder_inputs = [question_states, question_atts] + topk_ids, topk_scores = self.beam_generator.translate_batch( + encoder_inputs) + return topk_ids, topk_scores + + @staticmethod + def _tile(x, dim, n_tile): + import numpy as np + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor( + np.concatenate( + [init_dim * np.arange(n_tile) + i for i in range(init_dim)])) + return torch.index_select(x, dim, order_index.to(x.device)) + + def rank_answer(self, question_states, question_atts, answer_ids, + answer_atts, k): + + num_ques = question_states.size(0) + start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token + + start_output = self.text_decoder( + start_ids, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + return_dict=True, + reduction='none') + logits = start_output.logits[:, 0, :] # first token's logit + + # topk_probs: top-k probability + # topk_ids: [num_question, k] + answer_first_token = answer_ids[:, 1] + prob_first_token = F.softmax( + logits, dim=1).index_select( + dim=1, index=answer_first_token) + topk_probs, topk_ids = prob_first_token.topk(k, dim=1) + + # answer input: [num_question*k, answer_len] + input_ids = [] + input_atts = [] + for b, topk_id in enumerate(topk_ids): + input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) + input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) + input_ids = torch.cat(input_ids, dim=0) + input_atts = torch.cat(input_atts, dim=0) + + targets_ids = input_ids.masked_fill( + input_ids == self.tokenizer.pad_token_id, -100) + + # repeat encoder's output for top-k answers + question_states = self._tile(question_states, 0, k) + question_atts = self._tile(question_atts, 0, k) + + output = self.text_decoder( + input_ids, + attention_mask=input_atts, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=targets_ids, + return_dict=True, + reduction='none') + + answer_loss = output.loss + answer_loss = answer_loss.view(input_ids.size(0), -1) + + # topk_prob: first token probability + topk_probs = topk_probs.view(-1, 1) + log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1) + + # re-calculate log probabilities for the answer sequences using chain rule + log_probs_sum = log_probs.sum(1) + log_probs_sum = log_probs_sum.view(num_ques, k) + + topk_probs = F.softmax(log_probs_sum, dim=-1) + # get top-k after re-ranking + topk_probs, rerank_id = topk_probs.topk(k, dim=1) + topk_ids = torch.gather(topk_ids, 1, rerank_id) + + return topk_ids, topk_probs diff --git a/modelscope/models/multi_modal/mplug/predictor.py b/modelscope/models/multi_modal/mplug/predictor.py new file mode 100755 index 00000000..c976baa1 --- /dev/null +++ b/modelscope/models/multi_modal/mplug/predictor.py @@ -0,0 +1,535 @@ +from __future__ import print_function + +import torch +import torch.nn.functional as F + + +def build_predictor(args, tokenizer, symbols, model, logger=None): + scorer = None + + translator = TextGenerator( + args, model, tokenizer, symbols, global_scorer=scorer, logger=logger) + return translator + + +class TextGenerator(object): + """ + Uses a model to translate a batch of sentences. + + + Args: + model (:obj:`onmt.modules.NMTModel`): + NMT model to use for translation + fields (dict of Fields): data fields + beam_size (int): size of beam to use + n_best (int): number of translations produced + max_length (int): maximum length output to produce + global_scores (:obj:`GlobalScorer`): + object to rescore final translations + copy_attn (bool): use copy attention during translation + cuda (bool): use cuda + beam_trace (bool): trace beam search for debugging + logger(logging.Logger): logger. + """ + + def __init__(self, + args, + model, + vocab=None, + symbols=None, + global_scorer=None, + logger=None, + dump_beam=''): + self.alpha = 0.6 + + self.logger = logger + self.cuda = (torch.cuda.device_count() > 0) + + self.args = args + self.model = model + + self.vocab = vocab + self.symbols = symbols + self.start_token = 101 # ['[PAD]'] + self.end_token = 102 # ['[PAD]'] + + self.global_scorer = global_scorer + self.beam_size = args.beam_size + self.min_length = args.min_length + self.max_length = args.max_length + + self.dump_beam = dump_beam + + # for debugging + self.beam_trace = self.dump_beam != '' + self.beam_accum = None + + if self.beam_trace: + self.beam_accum = { + 'predicted_ids': [], + 'beam_parent_ids': [], + 'scores': [], + 'log_probs': [] + } + + def _build_target_tokens(self, pred): + tokens = [] + for tok in pred: + tok = int(tok) + tokens.append(tok) + if tokens[-1] == self.end_token: + tokens = tokens[:-1] + break + tokens = [t for t in tokens if t < len(self.vocab)] + tokens = self.vocab.DecodeIds(tokens).split(' ') + return tokens + + def translate_batch(self, encoder_inputs, do_sample=False, out_size=1): + """ + Translate a batch of sentences. + + Mostly a wrapper around :obj:`Beam`. + + Args: + batch (:obj:`Batch`): a batch from a dataset object + data (:obj:`Dataset`): the dataset object + fast (bool): enables fast beam search (may not support all features) + + Todo: + Shouldn't need the original dataset. + """ + if do_sample: + return self._fast_translate_batch( + encoder_inputs, + self.max_length, + min_length=self.min_length, + do_sample=do_sample, + out_size=out_size) + else: + with torch.no_grad(): + return self._fast_translate_batch( + encoder_inputs, + self.max_length, + min_length=self.min_length, + do_sample=do_sample, + out_size=out_size) + + def translate_batch_scst(self, + encoder_inputs, + do_sample=False, + out_size=1): + return self._fast_translate_batch( + encoder_inputs, + self.max_length, + min_length=self.min_length, + do_sample=do_sample, + out_size=out_size) + + def _fast_translate_batch(self, + encoder_inputs, + max_length, + min_length=0, + do_sample=False, + out_size=1): + + assert not self.dump_beam + if do_sample: + beam_size = 1 + else: + beam_size = self.beam_size + if len(encoder_inputs) == 3: + src_features, padding_mask, input_ids = encoder_inputs + elif len(encoder_inputs) == 2: + src_features, padding_mask = encoder_inputs + input_ids = None + + device = src_features.device + + # Tile states and memory beam_size times. + batch_size = src_features.size(0) + src_features = tile(src_features, beam_size, dim=0) + attention_mask = tile(padding_mask, beam_size, dim=0) + + batch_offset = torch.arange( + batch_size, dtype=torch.long, device=device) + beam_offset = torch.arange( + 0, + batch_size * beam_size, + step=beam_size, + dtype=torch.long, + device=device) + if input_ids is not None: + alive_seq = tile(input_ids, beam_size, dim=0) + else: + alive_seq = torch.full([batch_size * beam_size, 1], + self.start_token, + dtype=torch.long, + device=device) + + # Give full probability to the first beam on the first step. + topk_log_probs = ( + torch.tensor( + [0.0] + [float('-inf')] * (beam_size - 1), + device=device).repeat(batch_size)) + + # Structure that holds finished hypotheses. + hypotheses = [[] for _ in range(batch_size)] # noqa: F812 + + results = {} + results['predictions'] = [[] for _ in range(batch_size)] # noqa: F812 + results['scores'] = [[] for _ in range(batch_size)] # noqa: F812 + results['gold_score'] = [0] * batch_size + results['batch'] = [] + + for step in range(max_length): + dec_feat_seq = self.model( + alive_seq, + encoder_hidden_states=src_features, + encoder_attention_mask=attention_mask, + return_dict=True, + reduction='none') + + dec_feat_seq = dec_feat_seq.logits[:, -1, :] + vocab_size = dec_feat_seq.size(-1) + log_probs = torch.log( + torch.softmax(dec_feat_seq.view(-1, vocab_size), dim=-1)) + if step < min_length: + log_probs[:, self.end_token] = -1e20 + alpha = self.alpha + if do_sample: + length_penalty = 1.0 + else: + length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha + + if do_sample: + _scores = log_probs / self.args.temperature + _scores = top_k_top_p_filtering( + _scores, + top_k=self.args.top_k, + top_p=self.args.top_p, + min_tokens_to_keep=1 + ) # (batch_size * num_beams, vocab_size) + # Sample 2 next words for each beam + # (so we have some spare tokens and match output of greedy beam search) + topk_ids = torch.multinomial( + F.softmax(_scores, dim=-1), + num_samples=1) # (batch_size * num_beams, 2) + # Compute next scores + _scores = F.log_softmax( + _scores, dim=1) # (batch_size * num_beams, vocab_size) + + _scores += topk_log_probs.view(-1).unsqueeze(1) + topk_scores = torch.gather( + _scores, -1, topk_ids) # (batch_size * num_beams, 2) + # log_probs += # (batch_size * num_beams, 2) + # Match shape of greedy beam search + topk_ids = topk_ids.view( + -1, beam_size) # (batch_size, 2 * num_beams) + topk_scores = topk_scores.view( + -1, beam_size) # (batch_size, 2 * num_beams) + else: + log_probs += topk_log_probs.view(-1).unsqueeze(1) + curr_scores = log_probs / length_penalty + + curr_scores = curr_scores.reshape(-1, beam_size * vocab_size) + topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) + topk_log_probs = topk_scores * length_penalty + + # Resolve beam origin and true word ids. + # topk_beam_index = topk_ids.div(vocab_size) + topk_beam_index = torch.div( + topk_ids, vocab_size, rounding_mode='floor') + topk_ids = topk_ids.fmod(vocab_size) + + # Map beam_index to batch_index in the flat representation. + batch_index = ( + topk_beam_index + + beam_offset[:topk_beam_index.size(0)].unsqueeze(1)) + select_indices = batch_index.view(-1) + + # Append last prediction. + alive_seq = torch.cat([ + alive_seq.index_select(0, select_indices), + topk_ids.view(-1, 1) + ], -1) + + is_finished = topk_ids.eq(self.end_token) + if step + 1 == max_length: + is_finished.fill_(1) # self.end_token) + # End condition is top beam is finished. + end_condition = is_finished[:, 0].eq(1) # self.end_token) + # Save finished hypotheses. + if is_finished.any(): + predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) + for i in range(is_finished.size(0)): + b = batch_offset[i] + if end_condition[i]: + is_finished[i].fill_(1) # self.end_token) + finished_hyp = is_finished[i].nonzero().view(-1) + # Store finished hypotheses for this batch. + for j in finished_hyp: + hypotheses[b].append( + (topk_scores[i, j], predictions[i, j, 0:])) + # If the batch reached the end, save the n_best hypotheses. + if end_condition[i]: + best_hyp = sorted( + hypotheses[b], key=lambda x: x[0], reverse=True) + + for each in best_hyp[:beam_size]: + score, pred = each + results['scores'][b].append(score) + results['predictions'][b].append(pred) + non_finished = end_condition.eq(0).nonzero().view(-1) + # If all sentences are translated, no need to go further. + if len(non_finished) == 0: + break + # Remove finished batches for the next step. + topk_log_probs = topk_log_probs.index_select(0, non_finished) + batch_index = batch_index.index_select(0, non_finished) + batch_offset = batch_offset.index_select(0, non_finished) + alive_seq = predictions.index_select(0, non_finished) \ + .view(-1, alive_seq.size(-1)) + # Reorder states. + select_indices = batch_index.view(-1) + src_features = src_features.index_select(0, select_indices) + attention_mask = attention_mask.index_select(0, select_indices) + pred_ids = [] + scores = [] + # print (pred_ids, scores) + for each in results['scores']: + scores.append(each[:out_size]) + for each in results['predictions']: + pred_ids.append(each[:out_size]) + return pred_ids, scores + + def _generate_no_beam_search( + self, + input_ids, + cur_len, + max_length, + do_sample, + temperature, + top_k, + top_p, + repetition_penalty, + pad_token_id, + eos_token_ids, + batch_size, + ): + """ Generate sequences for each example without beam search (num_beams == 1). + All returned sequence are generated independantly. + """ + assert self.num_keep_best == 1, 'cannot generate >1 sentences in greedy search' + # current position / max lengths / length of generated sentences / unfinished sentences + unfinished_sents = [] + cur_unfinished = input_ids.new(batch_size).fill_(1) + + # log of scores for each sentence in the batch + logprobs = [] + + past = None + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past) + outputs = self(**model_inputs) + if cur_len == 1: + token_len = 2 + self.od_labels_len + next_token_idx = 1 + else: + assert cur_len > 1 + if not self._do_output_past(outputs): + token_len = cur_len + 1 + self.od_labels_len + next_token_idx = cur_len + else: + token_len = 2 + next_token_idx = 1 + assert outputs[0].shape[1] == token_len + + next_token_logits = outputs[0][:, next_token_idx, :] + + # if model has past, then set the past variable to speed up decoding + if self._do_output_past(outputs): + past = outputs[1] + + # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + for i in range(batch_size): + for previous_token in set(input_ids[i].tolist()): + # if score < 0 then repetition penalty has to multiplied + # to reduce the previous token probability + if next_token_logits[i, previous_token] < 0: + next_token_logits[ + i, previous_token] *= repetition_penalty + else: + next_token_logits[ + i, previous_token] /= repetition_penalty + + if do_sample: + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + # Top-p/top-k filtering + next_token_logits = top_k_top_p_filtering( + next_token_logits, top_k=top_k, top_p=top_p) + # Sample + next_token = torch.multinomial( + F.softmax(next_token_logits, dim=-1), + num_samples=1).squeeze(1) + else: + # Greedy decoding + next_token = torch.argmax(next_token_logits, dim=-1) + + # Compute scores + _scores = F.log_softmax( + next_token_logits, dim=-1) # (batch_size, vocab_size) + _scores = torch.gather(_scores, -1, + next_token.unsqueeze(-1)) # (batch_size, 1) + logprobs.append(_scores) # (batch_size, 1) + unfinished_sents.append(cur_unfinished) + + # update generations and finished sentences + tokens_to_add = next_token * cur_unfinished + pad_token_id * ( + 1 - cur_unfinished) + input_ids = torch.cat( + [input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) + + for eos_token_id in eos_token_ids: + cur_unfinished = cur_unfinished.mul( + tokens_to_add.ne(eos_token_id).long()) + cur_len = cur_len + 1 + + # stop when there is a in each sentence, or if we exceed the maximul length + if cur_unfinished.max() == 0: + break + + # add eos_token_ids to unfinished sentences + if cur_len == max_length: + input_ids[:, -1].masked_fill_( + cur_unfinished.to(dtype=torch.bool), eos_token_ids[0]) + + logprobs = torch.cat(logprobs, dim=1) + unfinished_sents = torch.stack(unfinished_sents, dim=1).float() + sum_logprobs = (logprobs * unfinished_sents).sum(dim=1) + # return logprobs to keep consistent with beam search output + logprobs = sum_logprobs / unfinished_sents.sum(dim=1) + + # pad to the same length, otherwise DataParallel will give error + pad_len = max_length - input_ids.shape[1] + if pad_len > 0: + padding_ids = input_ids.new(batch_size, + pad_len).fill_(pad_token_id) + input_ids = torch.cat([input_ids, padding_ids], dim=1) + + # (batch_size, n_best, max_len), (batch_size, n_best) + return input_ids.unsqueeze(1), logprobs.unsqueeze(1) + + +def top_k_top_p_filtering(logits, + top_k=10, + top_p=1.0, + filter_value=-float('Inf'), + min_tokens_to_keep=1): + + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), + logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, + None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits + + +class Translation(object): + """ + Container for a translated sentence. + + Attributes: + src (`LongTensor`): src word ids + src_raw ([str]): raw src words + + pred_sents ([[str]]): words from the n-best translations + pred_scores ([[float]]): log-probs of n-best translations + attns ([`FloatTensor`]) : attention dist for each translation + gold_sent ([str]): words from gold translation + gold_score ([float]): log-prob of gold translation + + """ + + def __init__(self, fname, src, src_raw, pred_sents, attn, pred_scores, + tgt_sent, gold_score): + self.fname = fname + self.src = src + self.src_raw = src_raw + self.pred_sents = pred_sents + self.attns = attn + self.pred_scores = pred_scores + self.gold_sent = tgt_sent + self.gold_score = gold_score + + def log(self, sent_number): + """ + Log translation. + """ + + output = '\nSENT {}: {}\n'.format(sent_number, self.src_raw) + + best_pred = self.pred_sents[0] + best_score = self.pred_scores[0] + pred_sent = ' '.join(best_pred) + output += 'PRED {}: {}\n'.format(sent_number, pred_sent) + output += 'PRED SCORE: {:.4f}\n'.format(best_score) + + if self.gold_sent is not None: + tgt_sent = ' '.join(self.gold_sent) + output += 'GOLD {}: {}\n'.format(sent_number, tgt_sent) + output += ('GOLD SCORE: {:.4f}\n'.format(self.gold_score)) + if len(self.pred_sents) > 1: + output += '\nBEST HYP:\n' + for score, sent in zip(self.pred_scores, self.pred_sents): + output += '[{:.4f}] {}\n'.format(score, sent) + + return output + + +def tile(x, count, dim=0): + """ + Tiles x on dimension dim count times. + """ + perm = list(range(len(x.size()))) + if dim != 0: + perm[0], perm[dim] = perm[dim], perm[0] + x = x.permute(perm).contiguous() + out_size = list(x.size()) + out_size[0] *= count + batch = x.size(0) + x = x.view(batch, -1) \ + .transpose(0, 1) \ + .repeat(count, 1) \ + .transpose(0, 1) \ + .contiguous() \ + .view(*out_size) + if dim != 0: + x = x.permute(perm).contiguous() + return x diff --git a/modelscope/models/multi_modal/mplug_for_visual_question_answering.py b/modelscope/models/multi_modal/mplug_for_visual_question_answering.py index 0f69cc2d..dc4fcce0 100644 --- a/modelscope/models/multi_modal/mplug_for_visual_question_answering.py +++ b/modelscope/models/multi_modal/mplug_for_visual_question_answering.py @@ -19,7 +19,7 @@ class MPlugForVisualQuestionAnswering(Model): """ super().__init__(model_dir, *args, **kwargs) - from sofa.models.mplug import MPlugForVisualQuestionAnswering + from modelscope.models.multi_modal.mplug import MPlugForVisualQuestionAnswering self.model = MPlugForVisualQuestionAnswering.from_pretrained(model_dir) self.tokenizer = self.model.tokenizer diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 921b2bc3..306a76cb 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -306,5 +306,10 @@ TASK_OUTPUTS = { # { # "output_img": np.ndarray with shape [height, width, 3] # } - Tasks.virtual_tryon: [OutputKeys.OUTPUT_IMG] + Tasks.virtual_tryon: [OutputKeys.OUTPUT_IMG], + # visual_question_answering result for a single sample + # { + # "text": "this is the text generated by a model." + # } + Tasks.visual_question_answering: [OutputKeys.TEXT] } diff --git a/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py index 9b51efa0..0b1fedff 100644 --- a/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py +++ b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py @@ -5,6 +5,7 @@ import torch from modelscope.metainfo import Pipelines from modelscope.models import Model from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering +from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline, Tensor from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor @@ -62,4 +63,4 @@ class VisualQuestionAnsweringPipeline(Pipeline): for _old, _new in replace_tokens_bert: pred_string = pred_string.replace(_old, _new) pred_string.strip() - return {'answer': pred_string} + return {OutputKeys.TEXT: pred_string} diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index b5dc0cf4..56bcfcd1 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -78,14 +78,16 @@ class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor): """preprocess the data via 'bert-base-uncased' tokenizer and configuration """ + from transformers import BertTokenizer + from modelscope.models.multi_modal.mplug import CONFIG_NAME, VOCAB_NAME, MPlugConfig + super().__init__(*args, **kwargs) # tokenizer - from transformers import AutoTokenizer - self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') + self.tokenizer = BertTokenizer.from_pretrained( + osp.join(model_dir, VOCAB_NAME)) # load configuration - from sofa.models.mplug import CONFIG_NAME, MPlugConfig config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME)) # Initialize transform diff --git a/tests/pipelines/test_visual_question_answering.py b/tests/pipelines/test_visual_question_answering.py index 4577607e..3583c3a4 100644 --- a/tests/pipelines/test_visual_question_answering.py +++ b/tests/pipelines/test_visual_question_answering.py @@ -30,8 +30,8 @@ class VisualQuestionAnsweringTest(unittest.TestCase): model=model, preprocessor=preprocessor) print(f"question: {self.input_vqa['question']}") - print(f"pipeline1: {pipeline1(self.input_vqa)['answer']}") - print(f"pipeline2: {pipeline2(self.input_vqa)['answer']}") + print(f'pipeline1: {pipeline1(self.input_vqa)}') + print(f'pipeline2: {pipeline2(self.input_vqa)}') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_model_from_modelhub(self):