Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491487master
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:24b78db10990c809380508b962decb53cb16db582135cb3c7d56c48f71d5ceb8
size 39683
@@ -31,6 +31,7 @@ class Models(object):
    # multi-modal models
    ofa = 'ofa'
    clip = 'clip-multi-modal-embedding'
    gemm = 'gemm-generative-multi-modal'
    mplug = 'mplug'
    imagen = 'imagen-text-to-image-synthesis'
@@ -95,6 +96,7 @@ class Pipelines(object):
    # multi-modal tasks
    image_captioning = 'image-captioning'
    multi_modal_embedding = 'multi-modal-embedding'
    generative_multi_modal_embedding = 'generative-multi-modal-embedding'
    visual_question_answering = 'visual-question-answering'
    text_to_image_synthesis = 'text-to-image-synthesis'
@@ -1,4 +1,5 @@
from .clip.clip_model import CLIPForMultiModalEmbedding
from .gemm.gemm_model import GEMMForMultiModalEmbedding
from .imagen.imagen_model import ImagenForTextToImageSynthesis
from .mplug_for_visual_question_answering import \
    MPlugForVisualQuestionAnswering
@@ -0,0 +1,550 @@
""" Generative Multimodal Model
Base modules are adapted from https://github.com/openai/CLIP/,
originally MIT License, Copyright (c) 2021 OpenAI,
and adapted from https://github.com/lucidrains/CoCa-pytorch/,
originally MIT License, Copyright (c) 2022 Phil Wang.
"""
import os
from collections import OrderedDict
from typing import Tuple, Union

import json
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import LayerNorm

from modelscope.models.multi_modal.gemm.tokenizer import (SimpleTokenizer,
                                                          clip_tokenize)

class Bottleneck(nn.Module):
    """ ResNet-style bottleneck module
    From https://github.com/openai/CLIP/blob/main/clip/model.py
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            self.downsample = nn.Sequential(
                OrderedDict([('-1', nn.AvgPool2d(stride)),
                             ('0',
                              nn.Conv2d(
                                  inplanes,
                                  planes * self.expansion,
                                  1,
                                  stride=1,
                                  bias=False)),
                             ('1', nn.BatchNorm2d(planes * self.expansion))]))

    def forward(self, x: torch.Tensor):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


class QuickGELU(nn.Module):
    """ A fast approximation of the GELU activation
    From https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    """ Multi-head attention block with residual connections
    Adapted from https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype,
            device=x.device) if self.attn_mask is not None else None
        attn_mask = self.attn_mask
        if attn_mask is not None and attn_mask.shape[0] > x.shape[0]:
            attn_mask = self.attn_mask[:x.shape[0], :x.shape[0]]
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    """ Transformer encoder module
    Adapted from https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None,
                 use_gc: bool = False):
        super().__init__()
        self.use_gc = use_gc
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(width, heads, attn_mask)
            for _ in range(layers)
        ])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class AttentionPool2d(nn.Module):
    """ Attention-based pooling layer
    Adapted from https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(self,
                 spacial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1],
                      x.shape[2] * x.shape[3]).permute(2, 0, 1)
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
        x = x + self.positional_embedding[:, None, :].to(x.dtype)
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False)
        return x.permute(1, 0, 2).contiguous()


class CrossAttention(nn.Module):
    """ Cross-attention module taking a query and a context as input
    Adapted from https://github.com/lucidrains/CoCa-pytorch/blob/main/coca_pytorch/coca_pytorch.py
    """

    def __init__(self,
                 dim,
                 *,
                 context_dim=None,
                 dim_head=64,
                 heads=8,
                 parallel_ff=False,
                 ff_mult=4,
                 norm_context=False):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head
        context_dim = dim if context_dim is None else context_dim
        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(
            context_dim) if norm_context else nn.Identity()
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        ff_inner_dim = ff_mult * dim
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False), SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)) if parallel_ff else None

    def forward(self, x, context):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        x = self.norm(x)
        context = self.context_norm(context)
        q = self.to_q(x)
        q = q.view(q.shape[0], q.shape[1], self.heads,
                   -1).permute(0, 2, 1, 3).contiguous()
        q = q * self.scale
        k, v = self.to_kv(context).chunk(2, dim=-1)
        sim = torch.einsum('b h i d, b j d -> b h i j', q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)
        out = torch.einsum('b h i j, b j d -> b h i d', attn, v)
        out = out.permute(0, 2, 1,
                          3).contiguous().reshape(out.shape[0], out.shape[2],
                                                  -1)
        out = self.to_out(out)
        if self.ff is not None:
            out = out + self.ff(x)
        return out


class ModifiedResNet(nn.Module):
    """ Modified ResNet backbone
    From https://github.com/openai/CLIP/blob/main/clip/model.py
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self,
                 layers,
                 output_dim,
                 heads,
                 input_resolution=224,
                 width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution
        self.conv1 = nn.Conv2d(
            3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(
            width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(
            width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
        embed_dim = width * 32
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
                                        heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):

        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
                             (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)
        return x


class VisualTransformer(nn.Module):
    """ Vision Transformer (ViT) backbone
    From https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int,
                 layers: int, heads: int, output_dim: int, use_gc: bool):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False)
        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads, use_gc=use_gc)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        z = torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([self.class_embedding.to(x.dtype) + z, x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        x = self.ln_post(x)
        if self.proj is not None:
            x = x @ self.proj
        return x

class GEVL(nn.Module):
    """ Generative vision-language model
    Supports learning from both generative and contrastive losses.
    Given image and text inputs, it outputs the features of the
    image and the text respectively. In addition, a caption can be
    produced when an image input is available.
    """

    def __init__(self, embed_dim: int, image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int],
                                      int], vision_width: int,
                 vision_patch_size: int, context_length: int, vocab_size: int,
                 transformer_width: int, transformer_heads: int,
                 transformer_layers: int, use_gc: bool, tokenizer):
        nn.Module.__init__(self)
        self.context_length = context_length
        self.vis_token_size = context_length
        self.tokenizer = tokenizer
        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width)
        else:
            vision_heads = vision_width // 64
            self.visual = VisualTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim,
                use_gc=use_gc)
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            use_gc=use_gc)
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.vis_token_projection = nn.Parameter(
            torch.empty(embed_dim, transformer_width))
        nn.init.normal_(
            self.vis_token_projection, std=self.transformer.width**-0.5)
        self.text_projection = nn.Parameter(
            torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.decoder = Transformer(
            width=transformer_width,
            layers=4,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(
                self.vis_token_size + self.context_length,
                self.vis_token_size),
            use_gc=use_gc)
        self.to_logits = nn.Sequential(
            LayerNorm(transformer_width),
            nn.Linear(transformer_width, transformer_width),
            nn.Linear(transformer_width, vocab_size, bias=False))
        self.gen_logit_scale = nn.Parameter(
            torch.ones([]) * np.log(np.log(vocab_size)))
        self.bias = nn.Parameter(torch.ones(vocab_size))
        self.to_logits[-1].weight = self.token_embedding.weight
        self.to_logits[-1].bias = self.bias
        self.img_queries = nn.Parameter(
            torch.randn(self.vis_token_size, transformer_width))
        self.img_attn_pool = CrossAttention(
            dim=transformer_width, norm_context=True)
        self.img_attn_pool_norm = LayerNorm(transformer_width)

    def build_attention_mask(self, seq_length=None, prefix_length=0):
        seq_length = self.context_length if seq_length is None else seq_length
        mask = torch.empty(seq_length, seq_length)
        mask.fill_(torch.tensor(torch.finfo(torch.float16).min))
        mask.triu_(1)
        if prefix_length > 0:
            mask[:prefix_length, :prefix_length] = 0
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image, return_tokens=False):
        image_outputs = self.visual(image)
        image_features = image_outputs[:, 0, :]
        image_features = image_features / image_features.norm(
            dim=-1, p=2, keepdim=True)
        if return_tokens:
            image_tokens = image_outputs[:, 1:, :] @ self.vis_token_projection
            return image_features, image_tokens
        else:
            return image_features

    def encode_text(self, text, return_tokens=False):
        x = self.token_embedding(text)
        x = x + self.positional_embedding[:x.shape[1], :]
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        x = self.ln_final(x)
        text_features = x[torch.arange(x.shape[0]),
                          text.argmax(dim=-1), ...] @ self.text_projection
        text_features = text_features / text_features.norm(
            dim=-1, p=2, keepdim=True)
        if return_tokens:
            text_tokens = x
            return text_features, text_tokens
        else:
            return text_features

    def image_to_text(self, image):
        image_features, image_tokens = self.encode_image(
            image, return_tokens=True)
        img_queries = self.img_queries.expand(image_tokens.shape[0], -1, -1)
        img_token_features = self.img_attn_pool(img_queries, image_tokens)
        img_token_features = self.img_attn_pool_norm(img_token_features)
        sot_token = self.tokenizer.encoder['<|startoftext|>']
        eot_token = self.tokenizer.encoder['<|endoftext|>']
        text_input = image.new_ones(
            image.shape[0], 1, dtype=torch.long) * sot_token
        input_tokens = img_token_features
        pred_tokens = []
        for text_idx in range(self.context_length):
            text_features, text_tokens = self.encode_text(
                text_input, return_tokens=True)
            input_tokens = torch.cat([img_token_features, text_tokens], axis=1)
            out_embs = self.decoder(input_tokens.permute(1, 0, 2).contiguous())
            gen_logits = self.to_logits(out_embs[-1:, ...])
            probs = F.softmax(self.gen_logit_scale.exp() * gen_logits, dim=-1)
            pred = torch.argmax(
                probs * (1.0 + torch.rand_like(probs)), axis=-1)
            pred_tokens.append(pred)
            text_input = torch.cat(
                [text_input, pred.permute(1, 0).contiguous()], axis=1)
        pred_text_tokens = torch.cat(pred_tokens, axis=0).permute(1, 0)
        text_list = []
        for out_tokens in pred_text_tokens:
            tokens = []
            for x in out_tokens:
                if x >= eot_token or x <= 0:
                    break
                tokens.append(int(x))
            out_text = self.tokenizer.decode(tokens)
            out_text = out_text.strip()
            text_list.append(out_text)
        return image_features, text_list[0]

class GEMMModel(nn.Module):
    """ Generative multi-modal model, a wrapper of the GEVL module.
    It takes an image, a text, or both as input, and outputs the
    features of the input; a caption can also be produced when an
    image input is available.
    """

    def __init__(self, model_dir):
        super().__init__()
        with open('{}/encoder_config.json'.format(model_dir), 'r') as f:
            model_config = json.loads(f.read())
        model_name = list(model_config.keys())[0]
        config_args = model_config[model_name]
        bpe_path = os.path.join(model_dir, 'bpe_vocab_16e6.txt.gz')
        self.tokenizer = SimpleTokenizer(bpe_path)
        self.model = GEVL(*config_args, self.tokenizer)

    def tokenize(self, text_str):
        text_tensor = clip_tokenize(self.tokenizer, [text_str])[0]
        return text_tensor

    def parse_feat(self, feat):
        out = feat.cpu().numpy()
        return out

    @torch.no_grad()
    def forward(self, image=None, text=None, captioning=True):
        img_feature, text_feature, caption = None, None, None
        if captioning and image is not None:
            img_feature, caption = self.model.image_to_text(image)
        elif image is not None:
            img_feature = self.parse_feat(self.model.encode_image(image))
        if text is not None:
            text_feature = self.parse_feat(self.model.encode_text(text))
        out = {
            'image_feature': img_feature,
            'text_feature': text_feature,
            'caption': caption,
        }
        return out
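
Reviewer note: a minimal standalone usage sketch of the GEMMModel wrapper above (not part of the patch). The local model_dir, image path, and weight-file name are assumptions for illustration; in the patch itself the weights are loaded by GEMMForMultiModalEmbedding in the next hunk.

import os
import torch
from PIL import Image
from torchvision import transforms as T
from modelscope.models.multi_modal.gemm.gemm_base import GEMMModel

model_dir = '/path/to/downloaded/model'  # assumed to contain encoder_config.json and bpe_vocab_16e6.txt.gz
model = GEMMModel(model_dir=model_dir)
# assumed checkpoint file name (ModelFile.TORCH_MODEL_BIN_FILE); adjust to the actual snapshot
state = torch.load(os.path.join(model_dir, 'pytorch_model.bin'), map_location='cpu')
model.load_state_dict(state)
model.eval()

# CLIP-style preprocessing, mirroring the transform used in GEMMForMultiModalEmbedding below
preprocess = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize((0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711))
])
image = preprocess(Image.open('example.jpg').convert('RGB'))[None, ...]  # hypothetical image path
text = model.tokenize('a photo of a living room').view(1, -1)

out = model(image=image, text=text, captioning=False)
print(out['image_feature'].shape, out['text_feature'].shape)  # both [1, D] numpy arrays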
@@ -0,0 +1,88 @@
import os.path as osp
from typing import Any, Dict

import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms as T

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.gemm.gemm_base import GEMMModel
from modelscope.outputs import OutputKeys
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['GEMMForMultiModalEmbedding']


@MODELS.register_module(
    Tasks.generative_multi_modal_embedding, module_name=Models.gemm)
class GEMMForMultiModalEmbedding(TorchModel):
""" Generative multi-modal model for multi-modal embedding | |||||
Inputs could be image or text or both of them. | |||||
Outputs could be features of input image or text, | |||||
image caption could also be produced when image is available. | |||||
""" | |||||

    def __init__(self, model_dir, device_id=0, *args, **kwargs):
        super().__init__(
            model_dir=model_dir, device_id=device_id, *args, **kwargs)
        self.gemm_model = GEMMModel(model_dir=model_dir)
        pretrained_params = torch.load('{}/{}'.format(
            model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
        self.gemm_model.load_state_dict(pretrained_params)
        self.gemm_model.eval()
        self.device_id = device_id
        if self.device_id >= 0 and torch.cuda.is_available():
            self.gemm_model.to('cuda:{}'.format(self.device_id))
            logger.info('Use GPU: {}'.format(self.device_id))
        else:
            # fall back to CPU so that parse_image/parse_text do not move
            # tensors to an unavailable GPU
            self.device_id = -1
            logger.info('Use CPU for inference')
        self.img_preprocessor = T.Compose([
            T.Resize(224),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize((0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711))
        ])

    def parse_image(self, input_img):
        if input_img is None:
            return None
        input_img = LoadImage.convert_to_img(input_img)
        img_tensor = self.img_preprocessor(input_img)[None, ...]
        if self.device_id >= 0:
            img_tensor = img_tensor.to('cuda:{}'.format(self.device_id))
        return img_tensor

    def parse_text(self, text_str):
        if text_str is None:
            return None
        if isinstance(text_str, str):
            text_ids_tensor = self.gemm_model.tokenize(text_str)
        else:
            raise TypeError(f'text should be str, but got {type(text_str)}')
        if self.device_id >= 0:
            text_ids_tensor = text_ids_tensor.to('cuda:{}'.format(
                self.device_id))
        return text_ids_tensor.view(1, -1)

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        image = self.parse_image(input.get('image', input.get('img', None)))
        text = self.parse_text(input.get('text', input.get('txt', None)))
        captioning = input.get('captioning', False) is True
        out = self.gemm_model(image, text, captioning)
        output = {
            OutputKeys.IMG_EMBEDDING: out.get('image_feature', None),
            OutputKeys.TEXT_EMBEDDING: out.get('text_feature', None),
            OutputKeys.CAPTION: out.get('caption', None)
        }
        return output
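
Reviewer note: a short sketch (not part of the patch) of invoking the model class above directly, bypassing the pipeline; the local model directory and image path are assumptions for illustration.

from modelscope.models.multi_modal.gemm.gemm_model import GEMMForMultiModalEmbedding
from modelscope.outputs import OutputKeys

model = GEMMForMultiModalEmbedding(model_dir='/path/to/model_dir')  # assumed local snapshot
result = model.forward({
    'image': 'example.jpg',  # hypothetical path; LoadImage.convert_to_img also accepts URLs or PIL images
    'text': 'interior design of a modern living room',
    'captioning': False,
})
img_emb = result[OutputKeys.IMG_EMBEDDING]   # np.array with shape [1, D]
txt_emb = result[OutputKeys.TEXT_EMBEDDING]  # np.array with shape [1, D]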
@@ -0,0 +1,197 @@
""" CLIP Tokenizer
Adapted from https://github.com/openai/CLIP.
Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re
import torch


@lru_cache()
def default_bpe():
    return os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        'bpe_simple_vocab_16e6.txt.gz')


@lru_cache()
def bytes_to_unicode():
""" | |||||
Returns list of utf-8 byte and a corresponding list of unicode strings. | |||||
The reversible bpe codes work on unicode strings. | |||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||||
This is a signficant percentage of your normal, say, 32K bpe vocab. | |||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||||
And avoids mapping to whitespace/control characters the bpe code barfs on. | |||||
""" | |||||
    bs = list(range(ord('!'),
                    ord('~') + 1)) + list(range(
                        ord('¡'),
                        ord('¬') + 1)) + list(range(ord('®'),
                                                    ord('ÿ') + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            '<|startoftext|>': '<|startoftext|>',
            '<|endoftext|>': '<|endoftext|>'
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>', )
        pairs = get_pairs(word)
        if not pairs:
            return token + '</w>'
        error_list = []
        while True:
            bigram = min(
                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except Exception as err:
                    error_list.append(err)
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[
                        i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        if len(error_list) > 100:
            print(error_list[-1])
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b]
                            for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token]
                              for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            'utf-8', errors='replace').replace('</w>', ' ')
        return text

def clip_tokenize(tokenizer, texts, context_length=77, truncate=True):
    """
    Returns the tokenized representation of the given input string(s).

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length
    truncate : bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    """
    if isinstance(texts, str):
        texts = [texts]
    sot_token = tokenizer.encoder['<|startoftext|>']
    eot_token = tokenizer.encoder['<|endoftext|>']
    all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token]
                  for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f'Input {texts[i]} is too long for context length {context_length}'
                )
        result[i, :len(tokens)] = torch.tensor(tokens)
    return result
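
Reviewer note: a quick usage sketch of the tokenizer above (illustration only; it assumes the bundled BPE vocabulary sits next to the module, as default_bpe() expects).

from modelscope.models.multi_modal.gemm.tokenizer import (SimpleTokenizer,
                                                          clip_tokenize)

tokenizer = SimpleTokenizer()  # falls back to bpe_simple_vocab_16e6.txt.gz via default_bpe()
ids = clip_tokenize(tokenizer, 'a photo of a cat', context_length=77)
print(ids.shape)  # torch.Size([1, 77]); <|startoftext|> ... <|endoftext|>, zero-padded

# decode() maps every id back, so strip padding and special tokens first
tokens = [t for t in ids[0].tolist() if 0 < t < tokenizer.encoder['<|startoftext|>']]
print(tokenizer.decode(tokens))  # "a photo of a cat "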
@@ -271,6 +271,15 @@ TASK_OUTPUTS = {
    Tasks.multi_modal_embedding:
    [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING],

    # generative multi-modal embedding result for single sample
    # {
    #   "img_embedding": np.array with shape [1, D],
    #   "text_embedding": np.array with shape [1, D],
    #   "caption": "this is an image caption text."
    # }
    Tasks.generative_multi_modal_embedding:
    [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION],

    # visual grounding result for single sample
    # {
    #   "boxes": [
@@ -62,6 +62,10 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.multi_modal_embedding:
    (Pipelines.multi_modal_embedding,
     'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding'),
    Tasks.generative_multi_modal_embedding:
    (Pipelines.generative_multi_modal_embedding,
     'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'
     ),
    Tasks.visual_question_answering:
    (Pipelines.visual_question_answering,
     'damo/mplug_visual-question-answering_coco_large_en'),
@@ -1,6 +1,7 @@
try:
    from .image_captioning_pipeline import ImageCaptionPipeline
    from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
    from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline
    from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline
    from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline
except ModuleNotFoundError as e:
@@ -0,0 +1,32 @@
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.generative_multi_modal_embedding,
    module_name=Pipelines.generative_multi_modal_embedding)
class GEMMMultiModalEmbeddingPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Use `model` to create a generative multi-modal embedding pipeline.
        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return self.model(input)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
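
Reviewer note: a hedged end-to-end sketch of the pipeline above. The model id matches the default registered in the builder hunk; the image path and the cosine-similarity step are illustrative assumptions about how the two embeddings would typically be compared.

import numpy as np
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

gemm_pipeline = pipeline(
    Tasks.generative_multi_modal_embedding,
    model='damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding')
result = gemm_pipeline({
    'image': 'example.jpg',  # hypothetical local path or URL
    'text': 'interior design of modern living room with fireplace',
    'captioning': False,
})
img_emb = result[OutputKeys.IMG_EMBEDDING]
txt_emb = result[OutputKeys.TEXT_EMBEDDING]
# both embeddings are L2-normalized inside GEVL, so a dot product is a cosine similarity
score = float(np.dot(img_emb, txt_emb.T)[0, 0])
print('image-text similarity:', score)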
@@ -75,6 +75,7 @@ class MultiModalTasks(object):
    visual_grounding = 'visual-grounding'
    text_to_image_synthesis = 'text-to-image-synthesis'
    multi_modal_embedding = 'multi-modal-embedding'
    generative_multi_modal_embedding = 'generative-multi-modal-embedding'
    visual_question_answering = 'visual-question-answering'
@@ -0,0 +1,70 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class GEMMMultiModalEmbeddingTest(unittest.TestCase):
    model_id = 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'
    test_input = {
        'image': 'data/test/images/generative_multimodal.jpg',
        'text':
        'interior design of modern living room with fireplace in a new house',
        'captioning': False
    }

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run(self):
        generative_multi_modal_embedding_pipeline = pipeline(
            Tasks.generative_multi_modal_embedding, model=self.model_id)
        output = generative_multi_modal_embedding_pipeline(self.test_input)
        print(output)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        generative_multi_modal_embedding_pipeline = pipeline(
            task=Tasks.generative_multi_modal_embedding)
        output = generative_multi_modal_embedding_pipeline(self.test_input)
        print(output)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        generative_multi_modal_embedding_pipeline = pipeline(
            task=Tasks.generative_multi_modal_embedding, model=model)
        output = generative_multi_modal_embedding_pipeline(self.test_input)
        print(output)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_output_captioning(self):
        generative_multi_modal_embedding_pipeline = pipeline(
            task=Tasks.generative_multi_modal_embedding, model=self.model_id)
        test_input = {'image': self.test_input['image'], 'captioning': True}
        output = generative_multi_modal_embedding_pipeline(test_input)
        print(output)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_output_only_image(self):
        generative_multi_modal_embedding_pipeline = pipeline(
            task=Tasks.generative_multi_modal_embedding, model=self.model_id)
        test_input = {'image': self.test_input['image'], 'captioning': False}
        output = generative_multi_modal_embedding_pipeline(test_input)
        print(output)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_output_only_text(self):
        generative_multi_modal_embedding_pipeline = pipeline(
            task=Tasks.generative_multi_modal_embedding, model=self.model_id)
        test_input = {'text': self.test_input['text']}
        output = generative_multi_modal_embedding_pipeline(test_input)
        print(output)


if __name__ == '__main__':
    unittest.main()