Migrate the MPLUG model code from sofa to maas (ModelScope). Checkpoints no longer need to be downloaded from Hugging Face. Added the OutputKeys definition for visual question answering.
@@ -0,0 +1,18 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from .configuration_mplug import MPlugConfig | |||
from .modeling_mplug import (CONFIG_NAME, VOCAB_NAME, | |||
MPlugForVisualQuestionAnswering) |
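A minimal usage sketch of what these re-exports enable after the migration: the model is restored entirely from a local ModelScope model directory, so nothing is fetched from the Hugging Face hub. The directory path below is illustrative, not part of this PR.

```python
from modelscope.models.multi_modal.mplug import MPlugForVisualQuestionAnswering

model_dir = '/path/to/local/mplug_vqa_model'  # hypothetical local snapshot
model = MPlugForVisualQuestionAnswering.from_pretrained(model_dir)
tokenizer = model.tokenizer  # attached by the model, as used by the pipeline below
```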
@@ -0,0 +1 @@ | |||
from .clip import load_from_config |
@@ -0,0 +1,401 @@ | |||
# Copyright 2021 The OpenAI CLIP Authors. All rights reserved. | |||
from collections import OrderedDict | |||
from typing import Tuple, Union | |||
import torch | |||
import torch.nn.functional as F | |||
from torch import nn | |||
from modelscope.models.multi_modal.clip.clip_vit import Transformer | |||
class Bottleneck(nn.Module): | |||
expansion = 4 | |||
def __init__(self, inplanes, planes, stride=1): | |||
super().__init__() | |||
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 | |||
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) | |||
self.bn1 = nn.BatchNorm2d(planes) | |||
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) | |||
self.bn2 = nn.BatchNorm2d(planes) | |||
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() | |||
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) | |||
self.bn3 = nn.BatchNorm2d(planes * self.expansion) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.downsample = None | |||
self.stride = stride | |||
if stride > 1 or inplanes != planes * Bottleneck.expansion: | |||
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 | |||
self.downsample = nn.Sequential( | |||
OrderedDict([('-1', nn.AvgPool2d(stride)), | |||
('0', | |||
nn.Conv2d( | |||
inplanes, | |||
planes * self.expansion, | |||
1, | |||
stride=1, | |||
bias=False)), | |||
('1', nn.BatchNorm2d(planes * self.expansion))])) | |||
def forward(self, x: torch.Tensor): | |||
identity = x | |||
out = self.relu(self.bn1(self.conv1(x))) | |||
out = self.relu(self.bn2(self.conv2(out))) | |||
out = self.avgpool(out) | |||
out = self.bn3(self.conv3(out)) | |||
if self.downsample is not None: | |||
identity = self.downsample(x) | |||
out += identity | |||
out = self.relu(out) | |||
return out | |||
class AttentionPool2d(nn.Module): | |||
def __init__(self, | |||
spacial_dim: int, | |||
embed_dim: int, | |||
num_heads: int, | |||
output_dim: int = None): | |||
super().__init__() | |||
self.positional_embedding = nn.Parameter( | |||
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) | |||
self.k_proj = nn.Linear(embed_dim, embed_dim) | |||
self.q_proj = nn.Linear(embed_dim, embed_dim) | |||
self.v_proj = nn.Linear(embed_dim, embed_dim) | |||
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) | |||
self.num_heads = num_heads | |||
def forward(self, x): | |||
x = x.reshape(x.shape[0], x.shape[1], | |||
x.shape[2] * x.shape[3]).permute(2, 0, | |||
1) # NCHW -> (HW)NC | |||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC | |||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC | |||
if self.training: | |||
dropout = 0.1 | |||
else: | |||
dropout = 0.0 | |||
x, _ = F.multi_head_attention_forward( | |||
query=x, | |||
key=x, | |||
value=x, | |||
embed_dim_to_check=x.shape[-1], | |||
num_heads=self.num_heads, | |||
q_proj_weight=self.q_proj.weight, | |||
k_proj_weight=self.k_proj.weight, | |||
v_proj_weight=self.v_proj.weight, | |||
in_proj_weight=None, | |||
in_proj_bias=torch.cat( | |||
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), | |||
bias_k=None, | |||
bias_v=None, | |||
add_zero_attn=False, | |||
dropout_p=dropout, | |||
out_proj_weight=self.c_proj.weight, | |||
out_proj_bias=self.c_proj.bias, | |||
use_separate_proj_weight=True, | |||
training=self.training, | |||
need_weights=False) | |||
return x[0] | |||
class ModifiedResNet(nn.Module): | |||
""" | |||
A ResNet class that is similar to torchvision's but contains the following changes: | |||
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. | |||
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 | |||
- The final pooling layer is a QKV attention instead of an average pool | |||
""" | |||
def __init__(self, | |||
layers, | |||
output_dim, | |||
heads, | |||
input_resolution=224, | |||
width=64): | |||
super().__init__() | |||
self.output_dim = output_dim | |||
self.input_resolution = input_resolution | |||
# the 3-layer stem | |||
self.conv1 = nn.Conv2d( | |||
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) | |||
self.bn1 = nn.BatchNorm2d(width // 2) | |||
self.conv2 = nn.Conv2d( | |||
width // 2, width // 2, kernel_size=3, padding=1, bias=False) | |||
self.bn2 = nn.BatchNorm2d(width // 2) | |||
self.conv3 = nn.Conv2d( | |||
width // 2, width, kernel_size=3, padding=1, bias=False) | |||
self.bn3 = nn.BatchNorm2d(width) | |||
self.avgpool = nn.AvgPool2d(2) | |||
self.relu = nn.ReLU(inplace=True) | |||
# residual layers | |||
self._inplanes = width # this is a *mutable* variable used during construction | |||
self.layer1 = self._make_layer(width, layers[0]) | |||
self.layer2 = self._make_layer(width * 2, layers[1], stride=2) | |||
self.layer3 = self._make_layer(width * 4, layers[2], stride=2) | |||
self.layer4 = self._make_layer(width * 8, layers[3], stride=2) | |||
embed_dim = width * 32 # the ResNet feature dimension | |||
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, | |||
heads, output_dim) | |||
def _make_layer(self, planes, blocks, stride=1): | |||
layers = [Bottleneck(self._inplanes, planes, stride)] | |||
self._inplanes = planes * Bottleneck.expansion | |||
for _ in range(1, blocks): | |||
layers.append(Bottleneck(self._inplanes, planes)) | |||
return nn.Sequential(*layers) | |||
def forward(self, x, skip_last_layer=False): | |||
def stem(x): | |||
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), | |||
(self.conv3, self.bn3)]: | |||
x = self.relu(bn(conv(x))) | |||
x = self.avgpool(x) | |||
return x | |||
x = x.type(self.conv1.weight.dtype) | |||
x = stem(x) | |||
x = self.layer1(x) | |||
x = self.layer2(x) | |||
x = self.layer3(x) | |||
x = self.layer4(x) | |||
if not skip_last_layer: | |||
x = self.attnpool(x) | |||
return x | |||
class LayerNorm(nn.LayerNorm): | |||
"""Subclass torch's LayerNorm to handle fp16.""" | |||
def forward(self, x: torch.Tensor): | |||
orig_type = x.dtype | |||
ret = super().forward(x) | |||
return ret.type(orig_type) | |||
class VisualTransformer(nn.Module): | |||
def __init__(self, input_resolution: int, patch_size: int, width: int, | |||
layers: int, heads: int, output_dim: int): | |||
super().__init__() | |||
self.input_resolution = input_resolution | |||
self.output_dim = output_dim | |||
self.heads = heads | |||
self.conv1 = nn.Conv2d( | |||
in_channels=3, | |||
out_channels=width, | |||
kernel_size=patch_size, | |||
stride=patch_size, | |||
bias=False) | |||
scale = width**-0.5 | |||
self.class_embedding = nn.Parameter(scale * torch.randn(width)) | |||
self.positional_embedding = nn.Parameter(scale * torch.randn( | |||
(input_resolution // patch_size)**2 + 1, width)) | |||
self.ln_pre = LayerNorm(width) | |||
self.transformer = Transformer(width, layers, heads) | |||
self.ln_post = LayerNorm(width) | |||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) | |||
def forward(self, | |||
x: torch.Tensor, | |||
skip_last_layer=False, | |||
text_embedding=None, | |||
text_mask=None): | |||
x = self.conv1(x) # shape = [*, width, grid, grid] | |||
x = x.reshape(x.shape[0], x.shape[1], | |||
-1) # shape = [*, width, grid ** 2] | |||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] | |||
cls_emb = self.class_embedding.to(x.dtype) | |||
x_zeros = torch.zeros( | |||
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) | |||
x = torch.cat([cls_emb + x_zeros, x], | |||
dim=1) # shape = [*, grid ** 2 + 1, width] | |||
x = x + self.positional_embedding.to(x.dtype)[:x.size(1), :] | |||
x = self.ln_pre(x) | |||
x = x.permute(1, 0, 2) # NLD -> LND | |||
x = self.transformer(x) | |||
x = x.permute(1, 0, 2) # LND -> NLD | |||
if skip_last_layer: | |||
x = self.ln_post(x) | |||
# x = x @ self.proj | |||
else: | |||
x = x @ self.proj | |||
return x | |||
class CLIP(nn.Module): | |||
def __init__( | |||
self, | |||
embed_dim: int, | |||
# vision | |||
image_resolution: int, | |||
vision_layers: Union[Tuple[int, int, int, int], int], | |||
vision_width: int, | |||
vision_patch_size: int, | |||
# text | |||
context_length: int, | |||
vocab_size: int, | |||
transformer_width: int, | |||
transformer_heads: int, | |||
transformer_layers: int): | |||
super().__init__() | |||
self.context_length = context_length | |||
if isinstance(vision_layers, (tuple, list)): | |||
vision_heads = vision_width * 32 // 64 | |||
self.visual = ModifiedResNet( | |||
layers=vision_layers, | |||
output_dim=embed_dim, | |||
heads=vision_heads, | |||
input_resolution=image_resolution, | |||
width=vision_width) | |||
else: | |||
vision_heads = vision_width // 64 | |||
self.visual = VisualTransformer( | |||
input_resolution=image_resolution, | |||
patch_size=vision_patch_size, | |||
width=vision_width, | |||
layers=vision_layers, | |||
heads=vision_heads, | |||
output_dim=embed_dim) | |||
self.transformer = Transformer( | |||
width=transformer_width, | |||
layers=transformer_layers, | |||
heads=transformer_heads, | |||
attn_mask=self.build_attention_mask()) | |||
self.vocab_size = vocab_size | |||
self.token_embedding = nn.Embedding(vocab_size, transformer_width) | |||
self.positional_embedding = nn.Parameter( | |||
torch.empty(self.context_length, transformer_width)) | |||
self.ln_final = LayerNorm(transformer_width) | |||
self.text_projection = nn.Parameter( | |||
torch.empty(transformer_width, embed_dim)) | |||
self.logit_scale = nn.Parameter(torch.ones([])) | |||
self.initialize_parameters() | |||
def initialize_parameters(self): | |||
nn.init.normal_(self.token_embedding.weight, std=0.02) | |||
nn.init.normal_(self.positional_embedding, std=0.01) | |||
if isinstance(self.visual, ModifiedResNet): | |||
if self.visual.attnpool is not None: | |||
std = self.visual.attnpool.c_proj.in_features**-0.5 | |||
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) | |||
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) | |||
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) | |||
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) | |||
for resnet_block in [ | |||
self.visual.layer1, self.visual.layer2, self.visual.layer3, | |||
self.visual.layer4 | |||
]: | |||
for name, param in resnet_block.named_parameters(): | |||
if name.endswith('bn3.weight'): | |||
nn.init.zeros_(param) | |||
proj_std = (self.transformer.width**-0.5) * ( | |||
(2 * self.transformer.layers)**-0.5) | |||
attn_std = self.transformer.width**-0.5 | |||
fc_std = (2 * self.transformer.width)**-0.5 | |||
for block in self.transformer.resblocks: | |||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std) | |||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std) | |||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) | |||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) | |||
if self.text_projection is not None: | |||
nn.init.normal_( | |||
self.text_projection, std=self.transformer.width**-0.5) | |||
def build_attention_mask(self): | |||
# lazily create the causal attention mask used by the text transformer
# pytorch uses an additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float('-inf'))
mask.triu_(1)  # zero out the diagonal and everything below it
return mask | |||
@property | |||
def dtype(self): | |||
return self.visual.conv1.weight.dtype | |||
def encode_image(self, image): | |||
return self.visual(image.type(self.dtype)) | |||
def encode_text(self, text): | |||
x = self.token_embedding(text).type( | |||
self.dtype) # [batch_size, n_ctx, d_model] | |||
x = x + self.positional_embedding.type(self.dtype) | |||
x = x.permute(1, 0, 2) # NLD -> LND | |||
x = self.transformer(x) | |||
x = x.permute(1, 0, 2) # LND -> NLD | |||
x = self.ln_final(x).type(self.dtype) | |||
# x.shape = [batch_size, n_ctx, transformer.width] | |||
# take features from the eot embedding (eot_token is the highest number in each sequence) | |||
x = x[torch.arange(x.shape[0]), | |||
text.argmax(dim=-1)] @ self.text_projection | |||
return x | |||
def forward(self, image, text): | |||
image_features = self.encode_image(image) | |||
text_features = self.encode_text(text) | |||
# normalized features | |||
image_features = image_features / image_features.norm( | |||
dim=-1, keepdim=True) | |||
text_features = text_features / text_features.norm( | |||
dim=-1, keepdim=True) | |||
# cosine similarity as logits | |||
logit_scale = self.logit_scale.exp() | |||
logits_per_image = logit_scale * image_features @ text_features.t() | |||
logits_per_text = logit_scale * text_features @ image_features.t() | |||
# shape = [global_batch_size, global_batch_size] | |||
return logits_per_image, logits_per_text | |||
def load_from_config(config): | |||
return CLIP(config.clip_embed_dim, config.clip_image_resolution, | |||
config.clip_vision_layers, config.clip_vision_width, | |||
config.clip_vision_patch_size, config.clip_context_length, | |||
config.clip_vocab_size, config.clip_transformer_width, | |||
config.clip_transformer_heads, config.clip_transformer_layers) |
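For reference, a rough sketch of how `load_from_config` is driven by the `MPlugConfig` introduced further down in this PR. The module paths are assumed from the package layout of the diff and may differ; the ViT-L/14 sizes come from the config defaults. MPLUG uses CLIP as a vision backbone, so `skip_last_layer=True` keeps the per-patch features rather than projecting into the CLIP joint embedding space.

```python
import torch

from modelscope.models.multi_modal.mplug import MPlugConfig
from modelscope.models.multi_modal.mplug.clip import load_from_config

config = MPlugConfig()  # clip_* defaults describe the ViT-L/14 geometry
clip_model = load_from_config(config)

images = torch.randn(2, 3, config.clip_image_resolution,
                     config.clip_image_resolution)
with torch.no_grad():
    # keep per-patch features instead of projecting into the CLIP joint space
    patch_features = clip_model.visual(images, skip_last_layer=True)
print(patch_features.shape)  # (2, 1 + (224 // 14) ** 2, 1024) == (2, 257, 1024)
```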
@@ -0,0 +1,125 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" MPLUG model configuration """ | |||
import os | |||
from collections import OrderedDict | |||
from typing import Any, Dict, Mapping, Union | |||
import yaml | |||
from transformers import PretrainedConfig | |||
from transformers.onnx import OnnxConfig | |||
from transformers.utils import logging | |||
logger = logging.get_logger(__name__) | |||
class MPlugConfig(PretrainedConfig): | |||
model_type = 'mplug' | |||
def __init__( | |||
self, | |||
bert_config='config_bert.json', | |||
image_res=504, | |||
batch_size_train=128, | |||
vision_width=1024, | |||
distill=True, | |||
clip_name='ViT-L-14', # ViT-B-16 | ViT-L-14 | |||
batch_size_test=64, | |||
k_test=128, | |||
alpha=0.4, | |||
warm_up=True, | |||
eos='[SEP]', | |||
optimizer=None, | |||
schedular=None, | |||
min_length=1, | |||
max_length=10, | |||
beam_size=5, | |||
add_ocr=False, | |||
add_object=False, | |||
text_encoder='bert-base-uncased', | |||
text_decoder='bert-base-uncased', | |||
# clip | |||
clip_embed_dim=768, | |||
clip_image_resolution=224, | |||
clip_vision_layers=24, | |||
clip_vision_width=1024, | |||
clip_vision_patch_size=14, | |||
clip_context_length=77, | |||
clip_vocab_size=49408, | |||
clip_transformer_width=768, | |||
clip_transformer_heads=12, | |||
clip_transformer_layers=12, | |||
**kwargs): | |||
super().__init__(**kwargs) | |||
self.bert_config = bert_config | |||
self.image_res = image_res | |||
self.batch_size_train = batch_size_train | |||
self.vision_width = vision_width | |||
self.distill = distill | |||
self.clip_name = clip_name | |||
self.batch_size_test = batch_size_test | |||
self.k_test = k_test | |||
self.alpha = alpha | |||
self.warm_up = warm_up | |||
self.eos = eos | |||
self.optimizer = optimizer | |||
self.schedular = schedular | |||
self.min_length = min_length | |||
self.max_length = max_length | |||
self.beam_size = beam_size | |||
self.add_ocr = add_ocr | |||
self.add_object = add_object | |||
self.text_encoder = text_encoder | |||
self.text_decoder = text_decoder | |||
# clip | |||
self.clip_embed_dim = clip_embed_dim | |||
self.clip_image_resolution = clip_image_resolution | |||
self.clip_vision_layers = clip_vision_layers | |||
self.clip_vision_width = clip_vision_width | |||
self.clip_vision_patch_size = clip_vision_patch_size | |||
self.clip_context_length = clip_context_length | |||
self.clip_vocab_size = clip_vocab_size | |||
self.clip_transformer_width = clip_transformer_width | |||
self.clip_transformer_heads = clip_transformer_heads | |||
self.clip_transformer_layers = clip_transformer_layers | |||
@classmethod | |||
def from_yaml_file(cls, yaml_file: Union[str, | |||
os.PathLike]) -> Dict[str, Any]: | |||
with open(yaml_file, 'r') as reader: | |||
config_dict = yaml.load(reader, Loader=yaml.Loader) | |||
return cls(**config_dict) | |||
class MPlugOnnxConfig(OnnxConfig): | |||
@property | |||
def inputs(self) -> Mapping[str, Mapping[int, str]]: | |||
return OrderedDict([ | |||
('input_ids', { | |||
0: 'batch', | |||
1: 'sequence' | |||
}), | |||
('attention_mask', { | |||
0: 'batch', | |||
1: 'sequence' | |||
}), | |||
('token_type_ids', { | |||
0: 'batch', | |||
1: 'sequence' | |||
}), | |||
]) |
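A hedged sketch of the two entry points in this file: `from_yaml_file`, which feeds plain YAML keys into the config, and the ONNX input spec. The YAML file name and keys below are illustrative (in practice the config shipped with the model, i.e. CONFIG_NAME, is used), and `MPlugOnnxConfig` is imported from an assumed module path since it is not re-exported in the package `__init__`.

```python
from modelscope.models.multi_modal.mplug import MPlugConfig
from modelscope.models.multi_modal.mplug.configuration_mplug import MPlugOnnxConfig

with open('mplug_vqa.yaml', 'w') as f:  # illustrative config file
    f.write('image_res: 504\nclip_name: ViT-L-14\nbeam_size: 5\n')

config = MPlugConfig.from_yaml_file('mplug_vqa.yaml')
print(config.image_res, config.clip_name, config.beam_size)  # 504 ViT-L-14 5

onnx_config = MPlugOnnxConfig(config)
print(list(onnx_config.inputs))  # ['input_ids', 'attention_mask', 'token_type_ids']
```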
@@ -0,0 +1,535 @@ | |||
from __future__ import print_function | |||
import torch | |||
import torch.nn.functional as F | |||
def build_predictor(args, tokenizer, symbols, model, logger=None): | |||
scorer = None | |||
translator = TextGenerator( | |||
args, model, tokenizer, symbols, global_scorer=scorer, logger=logger) | |||
return translator | |||
class TextGenerator(object): | |||
""" | |||
Uses a model to translate a batch of sentences. | |||
Args: | |||
model (:obj:`onmt.modules.NMTModel`): | |||
NMT model to use for translation | |||
fields (dict of Fields): data fields | |||
beam_size (int): size of beam to use | |||
n_best (int): number of translations produced | |||
max_length (int): maximum length output to produce | |||
global_scores (:obj:`GlobalScorer`): | |||
object to rescore final translations | |||
copy_attn (bool): use copy attention during translation | |||
cuda (bool): use cuda | |||
beam_trace (bool): trace beam search for debugging | |||
logger(logging.Logger): logger. | |||
""" | |||
def __init__(self, | |||
args, | |||
model, | |||
vocab=None, | |||
symbols=None, | |||
global_scorer=None, | |||
logger=None, | |||
dump_beam=''): | |||
self.alpha = 0.6 | |||
self.logger = logger | |||
self.cuda = (torch.cuda.device_count() > 0) | |||
self.args = args | |||
self.model = model | |||
self.vocab = vocab | |||
self.symbols = symbols | |||
self.start_token = 101  # BERT '[CLS]' token id
self.end_token = 102  # BERT '[SEP]' token id
self.global_scorer = global_scorer | |||
self.beam_size = args.beam_size | |||
self.min_length = args.min_length | |||
self.max_length = args.max_length | |||
self.dump_beam = dump_beam | |||
# for debugging | |||
self.beam_trace = self.dump_beam != '' | |||
self.beam_accum = None | |||
if self.beam_trace: | |||
self.beam_accum = { | |||
'predicted_ids': [], | |||
'beam_parent_ids': [], | |||
'scores': [], | |||
'log_probs': [] | |||
} | |||
def _build_target_tokens(self, pred): | |||
tokens = [] | |||
for tok in pred: | |||
tok = int(tok) | |||
tokens.append(tok) | |||
if tokens[-1] == self.end_token: | |||
tokens = tokens[:-1] | |||
break | |||
tokens = [t for t in tokens if t < len(self.vocab)] | |||
tokens = self.vocab.DecodeIds(tokens).split(' ') | |||
return tokens | |||
def translate_batch(self, encoder_inputs, do_sample=False, out_size=1): | |||
""" | |||
Translate a batch of sentences. | |||
Mostly a wrapper around :obj:`Beam`. | |||
Args: | |||
batch (:obj:`Batch`): a batch from a dataset object | |||
data (:obj:`Dataset`): the dataset object | |||
fast (bool): enables fast beam search (may not support all features) | |||
Todo: | |||
Shouldn't need the original dataset. | |||
""" | |||
if do_sample: | |||
return self._fast_translate_batch( | |||
encoder_inputs, | |||
self.max_length, | |||
min_length=self.min_length, | |||
do_sample=do_sample, | |||
out_size=out_size) | |||
else: | |||
with torch.no_grad(): | |||
return self._fast_translate_batch( | |||
encoder_inputs, | |||
self.max_length, | |||
min_length=self.min_length, | |||
do_sample=do_sample, | |||
out_size=out_size) | |||
def translate_batch_scst(self, | |||
encoder_inputs, | |||
do_sample=False, | |||
out_size=1): | |||
return self._fast_translate_batch( | |||
encoder_inputs, | |||
self.max_length, | |||
min_length=self.min_length, | |||
do_sample=do_sample, | |||
out_size=out_size) | |||
def _fast_translate_batch(self, | |||
encoder_inputs, | |||
max_length, | |||
min_length=0, | |||
do_sample=False, | |||
out_size=1): | |||
assert not self.dump_beam | |||
if do_sample: | |||
beam_size = 1 | |||
else: | |||
beam_size = self.beam_size | |||
if len(encoder_inputs) == 3: | |||
src_features, padding_mask, input_ids = encoder_inputs | |||
elif len(encoder_inputs) == 2: | |||
src_features, padding_mask = encoder_inputs | |||
input_ids = None | |||
device = src_features.device | |||
# Tile states and memory beam_size times. | |||
batch_size = src_features.size(0) | |||
src_features = tile(src_features, beam_size, dim=0) | |||
attention_mask = tile(padding_mask, beam_size, dim=0) | |||
batch_offset = torch.arange( | |||
batch_size, dtype=torch.long, device=device) | |||
beam_offset = torch.arange( | |||
0, | |||
batch_size * beam_size, | |||
step=beam_size, | |||
dtype=torch.long, | |||
device=device) | |||
if input_ids is not None: | |||
alive_seq = tile(input_ids, beam_size, dim=0) | |||
else: | |||
alive_seq = torch.full([batch_size * beam_size, 1], | |||
self.start_token, | |||
dtype=torch.long, | |||
device=device) | |||
# Give full probability to the first beam on the first step. | |||
topk_log_probs = ( | |||
torch.tensor( | |||
[0.0] + [float('-inf')] * (beam_size - 1), | |||
device=device).repeat(batch_size)) | |||
# Structure that holds finished hypotheses. | |||
hypotheses = [[] for _ in range(batch_size)] # noqa: F812 | |||
results = {} | |||
results['predictions'] = [[] for _ in range(batch_size)] # noqa: F812 | |||
results['scores'] = [[] for _ in range(batch_size)] # noqa: F812 | |||
results['gold_score'] = [0] * batch_size | |||
results['batch'] = [] | |||
for step in range(max_length): | |||
dec_feat_seq = self.model( | |||
alive_seq, | |||
encoder_hidden_states=src_features, | |||
encoder_attention_mask=attention_mask, | |||
return_dict=True, | |||
reduction='none') | |||
dec_feat_seq = dec_feat_seq.logits[:, -1, :] | |||
vocab_size = dec_feat_seq.size(-1) | |||
log_probs = torch.log( | |||
torch.softmax(dec_feat_seq.view(-1, vocab_size), dim=-1)) | |||
if step < min_length: | |||
log_probs[:, self.end_token] = -1e20 | |||
alpha = self.alpha | |||
if do_sample: | |||
length_penalty = 1.0 | |||
else: | |||
length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha | |||
if do_sample: | |||
_scores = log_probs / self.args.temperature | |||
_scores = top_k_top_p_filtering( | |||
_scores, | |||
top_k=self.args.top_k, | |||
top_p=self.args.top_p, | |||
min_tokens_to_keep=1 | |||
) # (batch_size * num_beams, vocab_size) | |||
# Sample one next token for each beam
topk_ids = torch.multinomial(
    F.softmax(_scores, dim=-1),
    num_samples=1)  # (batch_size * num_beams, 1)
# Compute next scores
_scores = F.log_softmax(
    _scores, dim=1)  # (batch_size * num_beams, vocab_size)
_scores += topk_log_probs.view(-1).unsqueeze(1)
topk_scores = torch.gather(
    _scores, -1, topk_ids)  # (batch_size * num_beams, 1)
# Match the shapes produced by beam search
topk_ids = topk_ids.view(
    -1, beam_size)  # (batch_size, num_beams)
topk_scores = topk_scores.view(
    -1, beam_size)  # (batch_size, num_beams)
else: | |||
log_probs += topk_log_probs.view(-1).unsqueeze(1) | |||
curr_scores = log_probs / length_penalty | |||
curr_scores = curr_scores.reshape(-1, beam_size * vocab_size) | |||
topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) | |||
topk_log_probs = topk_scores * length_penalty | |||
# Resolve beam origin and true word ids. | |||
# topk_beam_index = topk_ids.div(vocab_size) | |||
topk_beam_index = torch.div( | |||
topk_ids, vocab_size, rounding_mode='floor') | |||
topk_ids = topk_ids.fmod(vocab_size) | |||
# Map beam_index to batch_index in the flat representation. | |||
batch_index = ( | |||
topk_beam_index | |||
+ beam_offset[:topk_beam_index.size(0)].unsqueeze(1)) | |||
select_indices = batch_index.view(-1) | |||
# Append last prediction. | |||
alive_seq = torch.cat([ | |||
alive_seq.index_select(0, select_indices), | |||
topk_ids.view(-1, 1) | |||
], -1) | |||
is_finished = topk_ids.eq(self.end_token) | |||
if step + 1 == max_length:
    is_finished.fill_(1)
# End condition: the top beam for each batch item has finished.
end_condition = is_finished[:, 0].eq(1)
# Save finished hypotheses. | |||
if is_finished.any(): | |||
predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) | |||
for i in range(is_finished.size(0)): | |||
b = batch_offset[i] | |||
if end_condition[i]: | |||
    is_finished[i].fill_(1)
finished_hyp = is_finished[i].nonzero().view(-1) | |||
# Store finished hypotheses for this batch. | |||
for j in finished_hyp: | |||
hypotheses[b].append( | |||
(topk_scores[i, j], predictions[i, j, 0:])) | |||
# If the batch reached the end, save the n_best hypotheses. | |||
if end_condition[i]: | |||
best_hyp = sorted( | |||
hypotheses[b], key=lambda x: x[0], reverse=True) | |||
for each in best_hyp[:beam_size]: | |||
score, pred = each | |||
results['scores'][b].append(score) | |||
results['predictions'][b].append(pred) | |||
non_finished = end_condition.eq(0).nonzero().view(-1) | |||
# If all sentences are translated, no need to go further. | |||
if len(non_finished) == 0: | |||
break | |||
# Remove finished batches for the next step. | |||
topk_log_probs = topk_log_probs.index_select(0, non_finished) | |||
batch_index = batch_index.index_select(0, non_finished) | |||
batch_offset = batch_offset.index_select(0, non_finished) | |||
alive_seq = predictions.index_select(0, non_finished) \ | |||
.view(-1, alive_seq.size(-1)) | |||
# Reorder states. | |||
select_indices = batch_index.view(-1) | |||
src_features = src_features.index_select(0, select_indices) | |||
attention_mask = attention_mask.index_select(0, select_indices) | |||
pred_ids = [] | |||
scores = [] | |||
for each in results['scores']: | |||
scores.append(each[:out_size]) | |||
for each in results['predictions']: | |||
pred_ids.append(each[:out_size]) | |||
return pred_ids, scores | |||
def _generate_no_beam_search( | |||
self, | |||
input_ids, | |||
cur_len, | |||
max_length, | |||
do_sample, | |||
temperature, | |||
top_k, | |||
top_p, | |||
repetition_penalty, | |||
pad_token_id, | |||
eos_token_ids, | |||
batch_size, | |||
): | |||
""" Generate sequences for each example without beam search (num_beams == 1). | |||
All returned sequence are generated independantly. | |||
""" | |||
assert self.num_keep_best == 1, 'cannot generate >1 sentences in greedy search' | |||
# current position / max lengths / length of generated sentences / unfinished sentences | |||
unfinished_sents = [] | |||
cur_unfinished = input_ids.new(batch_size).fill_(1) | |||
# log of scores for each sentence in the batch | |||
logprobs = [] | |||
past = None | |||
while cur_len < max_length: | |||
model_inputs = self.prepare_inputs_for_generation( | |||
input_ids, past=past) | |||
outputs = self(**model_inputs) | |||
if cur_len == 1: | |||
token_len = 2 + self.od_labels_len | |||
next_token_idx = 1 | |||
else: | |||
assert cur_len > 1 | |||
if not self._do_output_past(outputs): | |||
token_len = cur_len + 1 + self.od_labels_len | |||
next_token_idx = cur_len | |||
else: | |||
token_len = 2 | |||
next_token_idx = 1 | |||
assert outputs[0].shape[1] == token_len | |||
next_token_logits = outputs[0][:, next_token_idx, :] | |||
# if model has past, then set the past variable to speed up decoding | |||
if self._do_output_past(outputs): | |||
past = outputs[1] | |||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) | |||
if repetition_penalty != 1.0: | |||
for i in range(batch_size): | |||
for previous_token in set(input_ids[i].tolist()): | |||
# if score < 0 then the repetition penalty has to be multiplied
# to reduce the previous token probability
if next_token_logits[i, previous_token] < 0: | |||
next_token_logits[ | |||
i, previous_token] *= repetition_penalty | |||
else: | |||
next_token_logits[ | |||
i, previous_token] /= repetition_penalty | |||
if do_sample: | |||
# Temperature (higher temperature => more likely to sample low probability tokens) | |||
if temperature != 1.0: | |||
next_token_logits = next_token_logits / temperature | |||
# Top-p/top-k filtering | |||
next_token_logits = top_k_top_p_filtering( | |||
next_token_logits, top_k=top_k, top_p=top_p) | |||
# Sample | |||
next_token = torch.multinomial( | |||
F.softmax(next_token_logits, dim=-1), | |||
num_samples=1).squeeze(1) | |||
else: | |||
# Greedy decoding | |||
next_token = torch.argmax(next_token_logits, dim=-1) | |||
# Compute scores | |||
_scores = F.log_softmax( | |||
next_token_logits, dim=-1) # (batch_size, vocab_size) | |||
_scores = torch.gather(_scores, -1, | |||
next_token.unsqueeze(-1)) # (batch_size, 1) | |||
logprobs.append(_scores) # (batch_size, 1) | |||
unfinished_sents.append(cur_unfinished) | |||
# update generations and finished sentences | |||
tokens_to_add = next_token * cur_unfinished + pad_token_id * ( | |||
1 - cur_unfinished) | |||
input_ids = torch.cat( | |||
[input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) | |||
for eos_token_id in eos_token_ids: | |||
cur_unfinished = cur_unfinished.mul( | |||
tokens_to_add.ne(eos_token_id).long()) | |||
cur_len = cur_len + 1 | |||
# stop when there is a </s> in each sentence, or if we exceed the maximum length
if cur_unfinished.max() == 0: | |||
break | |||
# add eos_token_ids to unfinished sentences | |||
if cur_len == max_length: | |||
input_ids[:, -1].masked_fill_( | |||
cur_unfinished.to(dtype=torch.bool), eos_token_ids[0]) | |||
logprobs = torch.cat(logprobs, dim=1) | |||
unfinished_sents = torch.stack(unfinished_sents, dim=1).float() | |||
sum_logprobs = (logprobs * unfinished_sents).sum(dim=1) | |||
# return logprobs to keep consistent with beam search output | |||
logprobs = sum_logprobs / unfinished_sents.sum(dim=1) | |||
# pad to the same length, otherwise DataParallel will give error | |||
pad_len = max_length - input_ids.shape[1] | |||
if pad_len > 0: | |||
padding_ids = input_ids.new(batch_size, | |||
pad_len).fill_(pad_token_id) | |||
input_ids = torch.cat([input_ids, padding_ids], dim=1) | |||
# (batch_size, n_best, max_len), (batch_size, n_best) | |||
return input_ids.unsqueeze(1), logprobs.unsqueeze(1) | |||
def top_k_top_p_filtering(logits, | |||
top_k=10, | |||
top_p=1.0, | |||
filter_value=-float('Inf'), | |||
min_tokens_to_keep=1): | |||
if top_k > 0: | |||
top_k = min(max(top_k, min_tokens_to_keep), | |||
logits.size(-1)) # Safety check | |||
# Remove all tokens with a probability less than the last token of the top-k | |||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, | |||
None] | |||
logits[indices_to_remove] = filter_value | |||
if top_p < 1.0: | |||
sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |||
cumulative_probs = torch.cumsum( | |||
F.softmax(sorted_logits, dim=-1), dim=-1) | |||
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p | |||
if min_tokens_to_keep > 1: | |||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) | |||
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 | |||
# Shift the indices to the right so the first token above the threshold is also kept
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ | |||
..., :-1].clone() | |||
sorted_indices_to_remove[..., 0] = 0 | |||
# scatter sorted tensors to original indexing | |||
indices_to_remove = sorted_indices_to_remove.scatter( | |||
1, sorted_indices, sorted_indices_to_remove) | |||
logits[indices_to_remove] = filter_value | |||
return logits | |||
class Translation(object): | |||
""" | |||
Container for a translated sentence. | |||
Attributes: | |||
src (`LongTensor`): src word ids | |||
src_raw ([str]): raw src words | |||
pred_sents ([[str]]): words from the n-best translations | |||
pred_scores ([[float]]): log-probs of n-best translations | |||
attns ([`FloatTensor`]) : attention dist for each translation | |||
gold_sent ([str]): words from gold translation | |||
gold_score ([float]): log-prob of gold translation | |||
""" | |||
def __init__(self, fname, src, src_raw, pred_sents, attn, pred_scores, | |||
tgt_sent, gold_score): | |||
self.fname = fname | |||
self.src = src | |||
self.src_raw = src_raw | |||
self.pred_sents = pred_sents | |||
self.attns = attn | |||
self.pred_scores = pred_scores | |||
self.gold_sent = tgt_sent | |||
self.gold_score = gold_score | |||
def log(self, sent_number): | |||
""" | |||
Log translation. | |||
""" | |||
output = '\nSENT {}: {}\n'.format(sent_number, self.src_raw) | |||
best_pred = self.pred_sents[0] | |||
best_score = self.pred_scores[0] | |||
pred_sent = ' '.join(best_pred) | |||
output += 'PRED {}: {}\n'.format(sent_number, pred_sent) | |||
output += 'PRED SCORE: {:.4f}\n'.format(best_score) | |||
if self.gold_sent is not None: | |||
tgt_sent = ' '.join(self.gold_sent) | |||
output += 'GOLD {}: {}\n'.format(sent_number, tgt_sent) | |||
output += ('GOLD SCORE: {:.4f}\n'.format(self.gold_score)) | |||
if len(self.pred_sents) > 1: | |||
output += '\nBEST HYP:\n' | |||
for score, sent in zip(self.pred_scores, self.pred_sents): | |||
output += '[{:.4f}] {}\n'.format(score, sent) | |||
return output | |||
def tile(x, count, dim=0): | |||
""" | |||
Tiles x on dimension dim count times. | |||
""" | |||
perm = list(range(len(x.size()))) | |||
if dim != 0: | |||
perm[0], perm[dim] = perm[dim], perm[0] | |||
x = x.permute(perm).contiguous() | |||
out_size = list(x.size()) | |||
out_size[0] *= count | |||
batch = x.size(0) | |||
x = x.view(batch, -1) \ | |||
.transpose(0, 1) \ | |||
.repeat(count, 1) \ | |||
.transpose(0, 1) \ | |||
.contiguous() \ | |||
.view(*out_size) | |||
if dim != 0: | |||
x = x.permute(perm).contiguous() | |||
return x |
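The two standalone helpers above, `top_k_top_p_filtering` and `tile`, are easiest to understand on toy tensors. The import path below is assumed from the package layout and may differ from the actual file name in this PR.

```python
import torch
import torch.nn.functional as F

from modelscope.models.multi_modal.mplug.predictor import (
    tile, top_k_top_p_filtering)

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=1.0)
print(F.softmax(filtered, dim=-1))  # only the two largest logits keep any mass

x = torch.tensor([[1, 2], [3, 4]])
print(tile(x, 3, dim=0))
# tensor([[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]])
```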
@@ -19,7 +19,7 @@ class MPlugForVisualQuestionAnswering(Model): | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
from sofa.models.mplug import MPlugForVisualQuestionAnswering | |||
from modelscope.models.multi_modal.mplug import MPlugForVisualQuestionAnswering | |||
self.model = MPlugForVisualQuestionAnswering.from_pretrained(model_dir) | |||
self.tokenizer = self.model.tokenizer | |||
@@ -306,5 +306,10 @@ TASK_OUTPUTS = { | |||
# { | |||
# "output_img": np.ndarray with shape [height, width, 3] | |||
# } | |||
Tasks.virtual_tryon: [OutputKeys.OUTPUT_IMG] | |||
Tasks.virtual_tryon: [OutputKeys.OUTPUT_IMG], | |||
# visual_question_answering result for a single sample | |||
# { | |||
# "text": "this is the text generated by a model." | |||
# } | |||
Tasks.visual_question_answering: [OutputKeys.TEXT] | |||
} |
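A small, hypothetical sanity check showing how the new entry is meant to be consumed: a VQA pipeline result should expose every key declared for the task. The result dict and the import path for `Tasks` are assumptions, not part of this PR.

```python
from modelscope.outputs import TASK_OUTPUTS, OutputKeys
from modelscope.utils.constant import Tasks  # assumed import path

result = {OutputKeys.TEXT: 'two dogs'}  # made-up pipeline output
expected_keys = TASK_OUTPUTS[Tasks.visual_question_answering]  # [OutputKeys.TEXT]
assert all(key in result for key in expected_keys)
```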
@@ -5,6 +5,7 @@ import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models import Model | |||
from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Pipeline, Tensor | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor | |||
@@ -62,4 +63,4 @@ class VisualQuestionAnsweringPipeline(Pipeline): | |||
for _old, _new in replace_tokens_bert: | |||
pred_string = pred_string.replace(_old, _new) | |||
pred_string.strip() | |||
return {'answer': pred_string} | |||
return {OutputKeys.TEXT: pred_string} |
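With the postprocess result now keyed by `OutputKeys.TEXT`, an end user would read the answer roughly as sketched below. The model id, image path, and input-dict keys are placeholders; only the 'question' key is confirmed by the test in this PR.

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

vqa = pipeline(
    Tasks.visual_question_answering,
    model='damo/mplug_visual-question-answering')  # placeholder model id
result = vqa({'image': 'demo.jpg', 'question': 'What is the dog doing?'})
print(result[OutputKeys.TEXT])
```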
@@ -78,14 +78,16 @@ class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor): | |||
"""preprocess the data via 'bert-base-uncased' tokenizer and configuration | |||
""" | |||
from transformers import BertTokenizer | |||
from modelscope.models.multi_modal.mplug import CONFIG_NAME, VOCAB_NAME, MPlugConfig | |||
super().__init__(*args, **kwargs) | |||
# tokenizer | |||
from transformers import AutoTokenizer | |||
self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') | |||
self.tokenizer = BertTokenizer.from_pretrained( | |||
osp.join(model_dir, VOCAB_NAME)) | |||
# load configuration | |||
from sofa.models.mplug import CONFIG_NAME, MPlugConfig | |||
config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME)) | |||
# Initialize transform | |||
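A sketch of the offline loading path the preprocessor now uses: both the tokenizer vocabulary and the model configuration are read from files inside the model directory (VOCAB_NAME / CONFIG_NAME), so no request to the Hugging Face hub is needed. The directory path is illustrative.

```python
import os.path as osp

from transformers import BertTokenizer

from modelscope.models.multi_modal.mplug import (CONFIG_NAME, VOCAB_NAME,
                                                 MPlugConfig)

model_dir = '/path/to/local/mplug_vqa_model'  # hypothetical local snapshot
tokenizer = BertTokenizer.from_pretrained(osp.join(model_dir, VOCAB_NAME))
config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME))
print(tokenizer.cls_token_id, tokenizer.sep_token_id)  # 101 102 for BERT uncased
```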
@@ -30,8 +30,8 @@ class VisualQuestionAnsweringTest(unittest.TestCase): | |||
model=model, | |||
preprocessor=preprocessor) | |||
print(f"question: {self.input_vqa['question']}") | |||
print(f"pipeline1: {pipeline1(self.input_vqa)['answer']}") | |||
print(f"pipeline2: {pipeline2(self.input_vqa)['answer']}") | |||
print(f'pipeline1: {pipeline1(self.input_vqa)}') | |||
print(f'pipeline2: {pipeline2(self.input_vqa)}') | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_model_from_modelhub(self): | |||