
[to #42322933] move mplug code to maas

Migrate the MPLUG model code from sofa to maas (ModelScope). Checkpoints no longer need to be downloaded from Hugging Face; the tokenizer vocabulary and model configuration are now read from the local model directory instead. Added an OutputKeys definition for the visual question answering (VQA) task.
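
Note: for orientation only, a minimal usage sketch of the migrated model (not part of this commit). It shows that VQA results are now returned under OutputKeys.TEXT; the task string and model id below are placeholders/assumptions, not taken from this diff.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline

# task name and model id are assumptions for illustration only
vqa_pipeline = pipeline(
    task='visual-question-answering',
    model='damo/mplug_visual-question-answering_coco_large_en')
result = vqa_pipeline({'image': 'demo.jpg', 'question': 'What is the woman doing?'})
print(result[OutputKeys.TEXT])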
master · hemu.zp · 3 years ago · commit 5985bd10ba
11 changed files with 3175 additions and 8 deletions
  1. modelscope/models/multi_modal/mplug/__init__.py (+18, -0)
  2. modelscope/models/multi_modal/mplug/clip/__init__.py (+1, -0)
  3. modelscope/models/multi_modal/mplug/clip/clip.py (+401, -0)
  4. modelscope/models/multi_modal/mplug/configuration_mplug.py (+125, -0)
  5. modelscope/models/multi_modal/mplug/modeling_mplug.py (+2079, -0)
  6. modelscope/models/multi_modal/mplug/predictor.py (+535, -0)
  7. modelscope/models/multi_modal/mplug_for_visual_question_answering.py (+1, -1)
  8. modelscope/outputs.py (+6, -1)
  9. modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py (+2, -1)
  10. modelscope/preprocessors/multi_modal.py (+5, -3)
  11. tests/pipelines/test_visual_question_answering.py (+2, -2)

modelscope/models/multi_modal/mplug/__init__.py (+18, -0)

@@ -0,0 +1,18 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration_mplug import MPlugConfig
from .modeling_mplug import (CONFIG_NAME, VOCAB_NAME,
MPlugForVisualQuestionAnswering)
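
Note (illustrative, not part of the commit): after this change the public MPLUG entry points live under modelscope rather than sofa, mirroring the exports above.

# before this commit (sofa):
#   from sofa.models.mplug import MPlugForVisualQuestionAnswering
# after this commit (modelscope/maas):
from modelscope.models.multi_modal.mplug import (CONFIG_NAME, VOCAB_NAME,
                                                 MPlugConfig,
                                                 MPlugForVisualQuestionAnswering)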

modelscope/models/multi_modal/mplug/clip/__init__.py (+1, -0)

@@ -0,0 +1 @@
from .clip import load_from_config

modelscope/models/multi_modal/mplug/clip/clip.py (+401, -0)

@@ -0,0 +1,401 @@
# Copyright 2021 The OpenAI CLIP Authors. All rights reserved.

from collections import OrderedDict
from typing import Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from modelscope.models.multi_modal.clip.clip_vit import Transformer


class Bottleneck(nn.Module):
expansion = 4

def __init__(self, inplanes, planes, stride=1):
super().__init__()

# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)

self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)

self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)

self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride

if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(
OrderedDict([('-1', nn.AvgPool2d(stride)),
('0',
nn.Conv2d(
inplanes,
planes * self.expansion,
1,
stride=1,
bias=False)),
('1', nn.BatchNorm2d(planes * self.expansion))]))

def forward(self, x: torch.Tensor):
identity = x

out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))

if self.downsample is not None:
identity = self.downsample(x)

out += identity
out = self.relu(out)
return out


class AttentionPool2d(nn.Module):

def __init__(self,
spacial_dim: int,
embed_dim: int,
num_heads: int,
output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads

def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1],
x.shape[2] * x.shape[3]).permute(2, 0,
1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
if self.training:
dropout = 0.1
else:
dropout = 0.0
x, _ = F.multi_head_attention_forward(
query=x,
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=dropout,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False)

return x[0]


class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""

def __init__(self,
layers,
output_dim,
heads,
input_resolution=224,
width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution

# the 3-layer stem
self.conv1 = nn.Conv2d(
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.conv2 = nn.Conv2d(
width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.conv3 = nn.Conv2d(
width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.avgpool = nn.AvgPool2d(2)
self.relu = nn.ReLU(inplace=True)

# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
heads, output_dim)

def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]

self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))

return nn.Sequential(*layers)

def forward(self, x, skip_last_layer=False):

def stem(x):
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
(self.conv3, self.bn3)]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x

x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if not skip_last_layer:
x = self.attnpool(x)

return x


class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""

def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x)
return ret.type(orig_type)


class VisualTransformer(nn.Module):

def __init__(self, input_resolution: int, patch_size: int, width: int,
layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.heads = heads
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)

scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)

self.transformer = Transformer(width, layers, heads)

self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

def forward(self,
x: torch.Tensor,
skip_last_layer=False,
text_embedding=None,
text_mask=None):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]

cls_emb = self.class_embedding.to(x.dtype)
x_zeros = torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
x = torch.cat([cls_emb + x_zeros, x],
dim=1) # shape = [*, grid ** 2 + 1, width]

x = x + self.positional_embedding.to(x.dtype)[:x.size(1), :]
x = self.ln_pre(x)

x = x.permute(1, 0, 2) # NLD -> LND

x = self.transformer(x)

x = x.permute(1, 0, 2) # LND -> NLD

if skip_last_layer:
x = self.ln_post(x)
# x = x @ self.proj
else:
x = x @ self.proj
return x


class CLIP(nn.Module):

def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int):
super().__init__()

self.context_length = context_length

if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width)
else:
vision_heads = vision_width // 64
self.visual = VisualTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim)

self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask())

self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)

self.text_projection = nn.Parameter(
torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]))

self.initialize_parameters()

def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)

if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

for resnet_block in [
self.visual.layer1, self.visual.layer2, self.visual.layer3,
self.visual.layer4
]:
for name, param in resnet_block.named_parameters():
if name.endswith('bn3.weight'):
nn.init.zeros_(param)

proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers)**-0.5)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width)**-0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

if self.text_projection is not None:
nn.init.normal_(
self.text_projection, std=self.transformer.width**-0.5)

def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float('-inf'))
mask.triu_(1) # zero out the lower diagonal
return mask

@property
def dtype(self):
return self.visual.conv1.weight.dtype

def encode_image(self, image):
return self.visual(image.type(self.dtype))

def encode_text(self, text):
x = self.token_embedding(text).type(
self.dtype) # [batch_size, n_ctx, d_model]

x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)

# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]),
text.argmax(dim=-1)] @ self.text_projection

return x

def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)

# normalized features
image_features = image_features / image_features.norm(
dim=-1, keepdim=True)
text_features = text_features / text_features.norm(
dim=-1, keepdim=True)

# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()

# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text


def load_from_config(config):
return CLIP(config.clip_embed_dim, config.clip_image_resolution,
config.clip_vision_layers, config.clip_vision_width,
config.clip_vision_patch_size, config.clip_context_length,
config.clip_vocab_size, config.clip_transformer_width,
config.clip_transformer_heads, config.clip_transformer_layers)
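
Note: a hedged sketch (not in the commit) of how load_from_config is wired to the clip_* fields of MPlugConfig, which is defined in the next file. The shape comment assumes the ViT-L/14 defaults shown there.

import torch

from modelscope.models.multi_modal.mplug import MPlugConfig
from modelscope.models.multi_modal.mplug.clip import load_from_config

config = MPlugConfig()  # ViT-L/14 defaults: clip_vision_layers=24, patch 14, 224px input
clip_model = load_from_config(config).eval()

image = torch.randn(1, 3, config.clip_image_resolution, config.clip_image_resolution)
with torch.no_grad():
    tokens = clip_model.encode_image(image)
# With an int clip_vision_layers a VisualTransformer is built; encode_image returns
# per-token features projected to clip_embed_dim, i.e. [1, (224 // 14) ** 2 + 1, 768].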

modelscope/models/multi_modal/mplug/configuration_mplug.py (+125, -0)

@@ -0,0 +1,125 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MPLUG model configuration """
import os
from collections import OrderedDict
from typing import Any, Dict, Mapping, Union

import yaml
from transformers import PretrainedConfig
from transformers.onnx import OnnxConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class MPlugConfig(PretrainedConfig):

model_type = 'mplug'

def __init__(
self,
bert_config='config_bert.json',
image_res=504,
batch_size_train=128,
vision_width=1024,
distill=True,
clip_name='ViT-L-14', # ViT-B-16 | ViT-L-14
batch_size_test=64,
k_test=128,
alpha=0.4,
warm_up=True,
eos='[SEP]',
optimizer=None,
schedular=None,
min_length=1,
max_length=10,
beam_size=5,
add_ocr=False,
add_object=False,
text_encoder='bert-base-uncased',
text_decoder='bert-base-uncased',
# clip
clip_embed_dim=768,
clip_image_resolution=224,
clip_vision_layers=24,
clip_vision_width=1024,
clip_vision_patch_size=14,
clip_context_length=77,
clip_vocab_size=49408,
clip_transformer_width=768,
clip_transformer_heads=12,
clip_transformer_layers=12,
**kwargs):
super().__init__(**kwargs)
self.bert_config = bert_config
self.image_res = image_res
self.batch_size_train = batch_size_train
self.vision_width = vision_width
self.distill = distill
self.clip_name = clip_name
self.batch_size_test = batch_size_test
self.k_test = k_test
self.alpha = alpha
self.warm_up = warm_up
self.eos = eos
self.optimizer = optimizer
self.schedular = schedular
self.min_length = min_length
self.max_length = max_length
self.beam_size = beam_size
self.add_ocr = add_ocr
self.add_object = add_object
self.text_encoder = text_encoder
self.text_decoder = text_decoder
# clip
self.clip_embed_dim = clip_embed_dim
self.clip_image_resolution = clip_image_resolution
self.clip_vision_layers = clip_vision_layers
self.clip_vision_width = clip_vision_width
self.clip_vision_patch_size = clip_vision_patch_size
self.clip_context_length = clip_context_length
self.clip_vocab_size = clip_vocab_size
self.clip_transformer_width = clip_transformer_width
self.clip_transformer_heads = clip_transformer_heads
self.clip_transformer_layers = clip_transformer_layers

@classmethod
def from_yaml_file(cls, yaml_file: Union[str,
os.PathLike]) -> Dict[str, Any]:
with open(yaml_file, 'r') as reader:
config_dict = yaml.load(reader, Loader=yaml.Loader)
return cls(**config_dict)


class MPlugOnnxConfig(OnnxConfig):

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict([
('input_ids', {
0: 'batch',
1: 'sequence'
}),
('attention_mask', {
0: 'batch',
1: 'sequence'
}),
('token_type_ids', {
0: 'batch',
1: 'sequence'
}),
])
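
Note (illustrative, not part of the commit): from_yaml_file simply feeds the YAML mapping into the constructor, so a config file only needs the keys it wants to override. The file path below is hypothetical.

import yaml

from modelscope.models.multi_modal.mplug import MPlugConfig

with open('/tmp/mplug_config.yaml', 'w') as f:  # hypothetical path
    yaml.safe_dump({'image_res': 504, 'clip_name': 'ViT-L-14', 'beam_size': 5}, f)

config = MPlugConfig.from_yaml_file('/tmp/mplug_config.yaml')
print(config.image_res, config.beam_size, config.clip_embed_dim)  # 504 5 768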

modelscope/models/multi_modal/mplug/modeling_mplug.py (+2079, -0)
(file diff not shown because it is too large)


modelscope/models/multi_modal/mplug/predictor.py (+535, -0)

@@ -0,0 +1,535 @@
from __future__ import print_function

import torch
import torch.nn.functional as F


def build_predictor(args, tokenizer, symbols, model, logger=None):
scorer = None

translator = TextGenerator(
args, model, tokenizer, symbols, global_scorer=scorer, logger=logger)
return translator


class TextGenerator(object):
"""
Uses a model to translate a batch of sentences.


Args:
model (:obj:`onmt.modules.NMTModel`):
NMT model to use for translation
fields (dict of Fields): data fields
beam_size (int): size of beam to use
n_best (int): number of translations produced
max_length (int): maximum length output to produce
global_scorer (:obj:`GlobalScorer`):
object to rescore final translations
copy_attn (bool): use copy attention during translation
cuda (bool): use cuda
beam_trace (bool): trace beam search for debugging
logger(logging.Logger): logger.
"""

def __init__(self,
args,
model,
vocab=None,
symbols=None,
global_scorer=None,
logger=None,
dump_beam=''):
self.alpha = 0.6

self.logger = logger
self.cuda = (torch.cuda.device_count() > 0)

self.args = args
self.model = model

self.vocab = vocab
self.symbols = symbols
self.start_token = 101 # ['[CLS]']
self.end_token = 102 # ['[SEP]']

self.global_scorer = global_scorer
self.beam_size = args.beam_size
self.min_length = args.min_length
self.max_length = args.max_length

self.dump_beam = dump_beam

# for debugging
self.beam_trace = self.dump_beam != ''
self.beam_accum = None

if self.beam_trace:
self.beam_accum = {
'predicted_ids': [],
'beam_parent_ids': [],
'scores': [],
'log_probs': []
}

def _build_target_tokens(self, pred):
tokens = []
for tok in pred:
tok = int(tok)
tokens.append(tok)
if tokens[-1] == self.end_token:
tokens = tokens[:-1]
break
tokens = [t for t in tokens if t < len(self.vocab)]
tokens = self.vocab.DecodeIds(tokens).split(' ')
return tokens

def translate_batch(self, encoder_inputs, do_sample=False, out_size=1):
"""
Translate a batch of sentences.

Mostly a wrapper around :obj:`Beam`.

Args:
batch (:obj:`Batch`): a batch from a dataset object
data (:obj:`Dataset`): the dataset object
fast (bool): enables fast beam search (may not support all features)

Todo:
Shouldn't need the original dataset.
"""
if do_sample:
return self._fast_translate_batch(
encoder_inputs,
self.max_length,
min_length=self.min_length,
do_sample=do_sample,
out_size=out_size)
else:
with torch.no_grad():
return self._fast_translate_batch(
encoder_inputs,
self.max_length,
min_length=self.min_length,
do_sample=do_sample,
out_size=out_size)

def translate_batch_scst(self,
encoder_inputs,
do_sample=False,
out_size=1):
return self._fast_translate_batch(
encoder_inputs,
self.max_length,
min_length=self.min_length,
do_sample=do_sample,
out_size=out_size)

def _fast_translate_batch(self,
encoder_inputs,
max_length,
min_length=0,
do_sample=False,
out_size=1):

assert not self.dump_beam
if do_sample:
beam_size = 1
else:
beam_size = self.beam_size
if len(encoder_inputs) == 3:
src_features, padding_mask, input_ids = encoder_inputs
elif len(encoder_inputs) == 2:
src_features, padding_mask = encoder_inputs
input_ids = None

device = src_features.device

# Tile states and memory beam_size times.
batch_size = src_features.size(0)
src_features = tile(src_features, beam_size, dim=0)
attention_mask = tile(padding_mask, beam_size, dim=0)

batch_offset = torch.arange(
batch_size, dtype=torch.long, device=device)
beam_offset = torch.arange(
0,
batch_size * beam_size,
step=beam_size,
dtype=torch.long,
device=device)
if input_ids is not None:
alive_seq = tile(input_ids, beam_size, dim=0)
else:
alive_seq = torch.full([batch_size * beam_size, 1],
self.start_token,
dtype=torch.long,
device=device)

# Give full probability to the first beam on the first step.
topk_log_probs = (
torch.tensor(
[0.0] + [float('-inf')] * (beam_size - 1),
device=device).repeat(batch_size))

# Structure that holds finished hypotheses.
hypotheses = [[] for _ in range(batch_size)] # noqa: F812

results = {}
results['predictions'] = [[] for _ in range(batch_size)] # noqa: F812
results['scores'] = [[] for _ in range(batch_size)] # noqa: F812
results['gold_score'] = [0] * batch_size
results['batch'] = []

for step in range(max_length):
dec_feat_seq = self.model(
alive_seq,
encoder_hidden_states=src_features,
encoder_attention_mask=attention_mask,
return_dict=True,
reduction='none')

dec_feat_seq = dec_feat_seq.logits[:, -1, :]
vocab_size = dec_feat_seq.size(-1)
log_probs = torch.log(
torch.softmax(dec_feat_seq.view(-1, vocab_size), dim=-1))
if step < min_length:
log_probs[:, self.end_token] = -1e20
alpha = self.alpha
if do_sample:
length_penalty = 1.0
else:
length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha

if do_sample:
_scores = log_probs / self.args.temperature
_scores = top_k_top_p_filtering(
_scores,
top_k=self.args.top_k,
top_p=self.args.top_p,
min_tokens_to_keep=1
) # (batch_size * num_beams, vocab_size)
# Sample one next word for each beam
topk_ids = torch.multinomial(
F.softmax(_scores, dim=-1),
num_samples=1) # (batch_size * num_beams, 1)
# Compute next scores
_scores = F.log_softmax(
_scores, dim=1) # (batch_size * num_beams, vocab_size)

_scores += topk_log_probs.view(-1).unsqueeze(1)
topk_scores = torch.gather(
_scores, -1, topk_ids) # (batch_size * num_beams, 1)
# Match shape of greedy beam search
topk_ids = topk_ids.view(
-1, beam_size) # (batch_size, beam_size)
topk_scores = topk_scores.view(
-1, beam_size) # (batch_size, beam_size)
else:
log_probs += topk_log_probs.view(-1).unsqueeze(1)
curr_scores = log_probs / length_penalty

curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)
topk_log_probs = topk_scores * length_penalty

# Resolve beam origin and true word ids.
# topk_beam_index = topk_ids.div(vocab_size)
topk_beam_index = torch.div(
topk_ids, vocab_size, rounding_mode='floor')
topk_ids = topk_ids.fmod(vocab_size)

# Map beam_index to batch_index in the flat representation.
batch_index = (
topk_beam_index
+ beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
select_indices = batch_index.view(-1)

# Append last prediction.
alive_seq = torch.cat([
alive_seq.index_select(0, select_indices),
topk_ids.view(-1, 1)
], -1)

is_finished = topk_ids.eq(self.end_token)
if step + 1 == max_length:
is_finished.fill_(1) # self.end_token)
# End condition is top beam is finished.
end_condition = is_finished[:, 0].eq(1) # self.end_token)
# Save finished hypotheses.
if is_finished.any():
predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
for i in range(is_finished.size(0)):
b = batch_offset[i]
if end_condition[i]:
is_finished[i].fill_(1) # self.end_token)
finished_hyp = is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
for j in finished_hyp:
hypotheses[b].append(
(topk_scores[i, j], predictions[i, j, 0:]))
# If the batch reached the end, save the n_best hypotheses.
if end_condition[i]:
best_hyp = sorted(
hypotheses[b], key=lambda x: x[0], reverse=True)

for each in best_hyp[:beam_size]:
score, pred = each
results['scores'][b].append(score)
results['predictions'][b].append(pred)
non_finished = end_condition.eq(0).nonzero().view(-1)
# If all sentences are translated, no need to go further.
if len(non_finished) == 0:
break
# Remove finished batches for the next step.
topk_log_probs = topk_log_probs.index_select(0, non_finished)
batch_index = batch_index.index_select(0, non_finished)
batch_offset = batch_offset.index_select(0, non_finished)
alive_seq = predictions.index_select(0, non_finished) \
.view(-1, alive_seq.size(-1))
# Reorder states.
select_indices = batch_index.view(-1)
src_features = src_features.index_select(0, select_indices)
attention_mask = attention_mask.index_select(0, select_indices)
pred_ids = []
scores = []
# print (pred_ids, scores)
for each in results['scores']:
scores.append(each[:out_size])
for each in results['predictions']:
pred_ids.append(each[:out_size])
return pred_ids, scores

def _generate_no_beam_search(
self,
input_ids,
cur_len,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
pad_token_id,
eos_token_ids,
batch_size,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
"""
assert self.num_keep_best == 1, 'cannot generate >1 sentences in greedy search'
# current position / max lengths / length of generated sentences / unfinished sentences
unfinished_sents = []
cur_unfinished = input_ids.new(batch_size).fill_(1)

# log of scores for each sentence in the batch
logprobs = []

past = None

while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past)
outputs = self(**model_inputs)
if cur_len == 1:
token_len = 2 + self.od_labels_len
next_token_idx = 1
else:
assert cur_len > 1
if not self._do_output_past(outputs):
token_len = cur_len + 1 + self.od_labels_len
next_token_idx = cur_len
else:
token_len = 2
next_token_idx = 1
assert outputs[0].shape[1] == token_len

next_token_logits = outputs[0][:, next_token_idx, :]

# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
past = outputs[1]

# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied
# to reduce the previous token probability
if next_token_logits[i, previous_token] < 0:
next_token_logits[
i, previous_token] *= repetition_penalty
else:
next_token_logits[
i, previous_token] /= repetition_penalty

if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
# Top-p/top-k filtering
next_token_logits = top_k_top_p_filtering(
next_token_logits, top_k=top_k, top_p=top_p)
# Sample
next_token = torch.multinomial(
F.softmax(next_token_logits, dim=-1),
num_samples=1).squeeze(1)
else:
# Greedy decoding
next_token = torch.argmax(next_token_logits, dim=-1)

# Compute scores
_scores = F.log_softmax(
next_token_logits, dim=-1) # (batch_size, vocab_size)
_scores = torch.gather(_scores, -1,
next_token.unsqueeze(-1)) # (batch_size, 1)
logprobs.append(_scores) # (batch_size, 1)
unfinished_sents.append(cur_unfinished)

# update generations and finished sentences
tokens_to_add = next_token * cur_unfinished + pad_token_id * (
1 - cur_unfinished)
input_ids = torch.cat(
[input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

for eos_token_id in eos_token_ids:
cur_unfinished = cur_unfinished.mul(
tokens_to_add.ne(eos_token_id).long())
cur_len = cur_len + 1

# stop when there is a </s> in each sentence, or if we exceed the maximum length
if cur_unfinished.max() == 0:
break

# add eos_token_ids to unfinished sentences
if cur_len == max_length:
input_ids[:, -1].masked_fill_(
cur_unfinished.to(dtype=torch.bool), eos_token_ids[0])

logprobs = torch.cat(logprobs, dim=1)
unfinished_sents = torch.stack(unfinished_sents, dim=1).float()
sum_logprobs = (logprobs * unfinished_sents).sum(dim=1)
# return logprobs to keep consistent with beam search output
logprobs = sum_logprobs / unfinished_sents.sum(dim=1)

# pad to the same length, otherwise DataParallel will give error
pad_len = max_length - input_ids.shape[1]
if pad_len > 0:
padding_ids = input_ids.new(batch_size,
pad_len).fill_(pad_token_id)
input_ids = torch.cat([input_ids, padding_ids], dim=1)

# (batch_size, n_best, max_len), (batch_size, n_best)
return input_ids.unsqueeze(1), logprobs.unsqueeze(1)


def top_k_top_p_filtering(logits,
top_k=10,
top_p=1.0,
filter_value=-float('Inf'),
min_tokens_to_keep=1):

if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep),
logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1,
None]
logits[indices_to_remove] = filter_value

if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0

# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits


class Translation(object):
"""
Container for a translated sentence.

Attributes:
src (`LongTensor`): src word ids
src_raw ([str]): raw src words

pred_sents ([[str]]): words from the n-best translations
pred_scores ([[float]]): log-probs of n-best translations
attns ([`FloatTensor`]) : attention dist for each translation
gold_sent ([str]): words from gold translation
gold_score ([float]): log-prob of gold translation

"""

def __init__(self, fname, src, src_raw, pred_sents, attn, pred_scores,
tgt_sent, gold_score):
self.fname = fname
self.src = src
self.src_raw = src_raw
self.pred_sents = pred_sents
self.attns = attn
self.pred_scores = pred_scores
self.gold_sent = tgt_sent
self.gold_score = gold_score

def log(self, sent_number):
"""
Log translation.
"""

output = '\nSENT {}: {}\n'.format(sent_number, self.src_raw)

best_pred = self.pred_sents[0]
best_score = self.pred_scores[0]
pred_sent = ' '.join(best_pred)
output += 'PRED {}: {}\n'.format(sent_number, pred_sent)
output += 'PRED SCORE: {:.4f}\n'.format(best_score)

if self.gold_sent is not None:
tgt_sent = ' '.join(self.gold_sent)
output += 'GOLD {}: {}\n'.format(sent_number, tgt_sent)
output += ('GOLD SCORE: {:.4f}\n'.format(self.gold_score))
if len(self.pred_sents) > 1:
output += '\nBEST HYP:\n'
for score, sent in zip(self.pred_scores, self.pred_sents):
output += '[{:.4f}] {}\n'.format(score, sent)

return output


def tile(x, count, dim=0):
"""
Tiles x on dimension dim count times.
"""
perm = list(range(len(x.size())))
if dim != 0:
perm[0], perm[dim] = perm[dim], perm[0]
x = x.permute(perm).contiguous()
out_size = list(x.size())
out_size[0] *= count
batch = x.size(0)
x = x.view(batch, -1) \
.transpose(0, 1) \
.repeat(count, 1) \
.transpose(0, 1) \
.contiguous() \
.view(*out_size)
if dim != 0:
x = x.permute(perm).contiguous()
return x
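
Note: a small illustration (not part of the commit) of the two standalone helpers above, top_k_top_p_filtering and tile, used outside the beam-search loop. The vocabulary size is a made-up example value.

import torch
import torch.nn.functional as F

from modelscope.models.multi_modal.mplug.predictor import tile, top_k_top_p_filtering

logits = torch.randn(2, 30522)  # fake vocabulary-sized scores
filtered = top_k_top_p_filtering(logits.clone(), top_k=10, top_p=0.9)  # filters in place
next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)

# tile repeats every row `count` times along dim 0, the way the beam search above
# expands encoder features to beam_size hypotheses per sample
beams = tile(torch.arange(2).unsqueeze(1), 3, dim=0)  # -> [[0], [0], [0], [1], [1], [1]]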

modelscope/models/multi_modal/mplug_for_visual_question_answering.py (+1, -1)

@@ -19,7 +19,7 @@ class MPlugForVisualQuestionAnswering(Model):
"""

super().__init__(model_dir, *args, **kwargs)
- from sofa.models.mplug import MPlugForVisualQuestionAnswering
+ from modelscope.models.multi_modal.mplug import MPlugForVisualQuestionAnswering
self.model = MPlugForVisualQuestionAnswering.from_pretrained(model_dir)
self.tokenizer = self.model.tokenizer



modelscope/outputs.py (+6, -1)

@@ -306,5 +306,10 @@ TASK_OUTPUTS = {
# {
# "output_img": np.ndarray with shape [height, width, 3]
# }
- Tasks.virtual_tryon: [OutputKeys.OUTPUT_IMG]
+ Tasks.virtual_tryon: [OutputKeys.OUTPUT_IMG],
+ # visual_question_answering result for a single sample
+ # {
+ #     "text": "this is the text generated by a model."
+ # }
+ Tasks.visual_question_answering: [OutputKeys.TEXT]
}

modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py (+2, -1)

@@ -5,6 +5,7 @@ import torch
from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
+ from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor
@@ -62,4 +63,4 @@ class VisualQuestionAnsweringPipeline(Pipeline):
for _old, _new in replace_tokens_bert:
pred_string = pred_string.replace(_old, _new)
pred_string.strip()
- return {'answer': pred_string}
+ return {OutputKeys.TEXT: pred_string}

modelscope/preprocessors/multi_modal.py (+5, -3)

@@ -78,14 +78,16 @@ class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor):
"""preprocess the data via 'bert-base-uncased' tokenizer and configuration

"""
+ from transformers import BertTokenizer
+ from modelscope.models.multi_modal.mplug import CONFIG_NAME, VOCAB_NAME, MPlugConfig

super().__init__(*args, **kwargs)

# tokenizer
- from transformers import AutoTokenizer
- self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+ self.tokenizer = BertTokenizer.from_pretrained(
+     osp.join(model_dir, VOCAB_NAME))

# load configuration
- from sofa.models.mplug import CONFIG_NAME, MPlugConfig
config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME))

# Initialize transform
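
Note: a hedged sketch of what the preprocessor now does instead of downloading 'bert-base-uncased' at runtime; model_dir is a placeholder for a local model snapshot directory.

import os.path as osp

from transformers import BertTokenizer

from modelscope.models.multi_modal.mplug import CONFIG_NAME, VOCAB_NAME, MPlugConfig

model_dir = '/path/to/local/mplug_model'  # placeholder
tokenizer = BertTokenizer.from_pretrained(osp.join(model_dir, VOCAB_NAME))
config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME))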


tests/pipelines/test_visual_question_answering.py (+2, -2)

@@ -30,8 +30,8 @@ class VisualQuestionAnsweringTest(unittest.TestCase):
model=model,
preprocessor=preprocessor)
print(f"question: {self.input_vqa['question']}")
print(f"pipeline1: {pipeline1(self.input_vqa)['answer']}")
print(f"pipeline2: {pipeline2(self.input_vqa)['answer']}")
print(f'pipeline1: {pipeline1(self.input_vqa)}')
print(f'pipeline2: {pipeline2(self.input_vqa)}')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):

