文本指导的语义分割模型,根据输入的文本信息,讲图像中对应文本描述的物体分割出来。 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9942863master
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:2c7d2f279e3b317f1d0de18410a0585e122166fa2464c17b88a0c813f6c58bd4 | |||
size 67861 |
@@ -29,6 +29,7 @@ class Models(object): | |||
video_summarization = 'pgl-video-summarization' | |||
swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
text_driven_segmentation = 'text-driven-segmentation' | |||
resnet50_bert = 'resnet50-bert' | |||
# EasyCV models | |||
@@ -143,6 +144,7 @@ class Pipelines(object): | |||
video_summarization = 'googlenet_pgl_video_summarization' | |||
image_semantic_segmentation = 'image-semantic-segmentation' | |||
image_reid_person = 'passvitb-image-reid-person' | |||
text_driven_segmentation = 'text-driven-segmentation' | |||
movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | |||
# nlp tasks | |||
@@ -0,0 +1 @@ | |||
from .lseg_base import TextDrivenSegmentation |
@@ -0,0 +1,170 @@ | |||
""" CLIP | |||
Adapted from https://github.com/openai/CLIP. | |||
Originally MIT License, Copyright (c) 2021 OpenAI. | |||
""" | |||
import hashlib | |||
import os | |||
import urllib | |||
import warnings | |||
from typing import Any, List, Union | |||
import torch | |||
from PIL import Image | |||
from pkg_resources import packaging | |||
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize, | |||
ToTensor) | |||
from tqdm import tqdm | |||
from .model import build_model | |||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer | |||
try: | |||
from torchvision.transforms import InterpolationMode | |||
BICUBIC = InterpolationMode.BICUBIC | |||
except ImportError: | |||
BICUBIC = Image.BICUBIC | |||
if packaging.version.parse( | |||
torch.__version__) < packaging.version.parse('1.7.1'): | |||
warnings.warn('PyTorch version 1.7.1 or higher is recommended') | |||
__all__ = ['load', 'tokenize'] | |||
def _convert_image_to_rgb(image): | |||
return image.convert('RGB') | |||
def _transform(n_px): | |||
return Compose([ | |||
Resize(n_px, interpolation=BICUBIC), | |||
CenterCrop(n_px), | |||
_convert_image_to_rgb, | |||
ToTensor(), | |||
Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)), | |||
]) | |||
def load(name: str, | |||
device: Union[str, torch.device] = 'cuda' | |||
if torch.cuda.is_available() else 'cpu', | |||
jit: bool = False, | |||
root: str = None): | |||
if not jit: | |||
model = build_model().to(device) | |||
if str(device) == 'cpu': | |||
model.float() | |||
return model, _transform(model.visual.input_resolution) | |||
# patch the device names | |||
device_holder = torch.jit.trace( | |||
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) | |||
device_node = [ | |||
n for n in device_holder.graph.findAllNodes('prim::Constant') | |||
if 'Device' in repr(n) | |||
][-1] | |||
def patch_device(module): | |||
try: | |||
graphs = [module.graph] if hasattr(module, 'graph') else [] | |||
except RuntimeError: | |||
graphs = [] | |||
if hasattr(module, 'forward1'): | |||
graphs.append(module.forward1.graph) | |||
for graph in graphs: | |||
for node in graph.findAllNodes('prim::Constant'): | |||
if 'value' in node.attributeNames() and str( | |||
node['value']).startswith('cuda'): | |||
node.copyAttributes(device_node) | |||
model.apply(patch_device) | |||
patch_device(model.encode_image) | |||
patch_device(model.encode_text) | |||
# patch dtype to float32 on CPU | |||
if str(device) == 'cpu': | |||
float_holder = torch.jit.trace( | |||
lambda: torch.ones([]).float(), example_inputs=[]) | |||
float_input = list(float_holder.graph.findNode('aten::to').inputs())[1] | |||
float_node = float_input.node() | |||
def patch_float(module): | |||
try: | |||
graphs = [module.graph] if hasattr(module, 'graph') else [] | |||
except RuntimeError: | |||
graphs = [] | |||
if hasattr(module, 'forward1'): | |||
graphs.append(module.forward1.graph) | |||
for graph in graphs: | |||
for node in graph.findAllNodes('aten::to'): | |||
inputs = list(node.inputs()) | |||
for i in [ | |||
1, 2 | |||
]: # dtype can be the second or third argument to aten::to() | |||
if inputs[i].node()['value'] == 5: | |||
inputs[i].node().copyAttributes(float_node) | |||
model.apply(patch_float) | |||
patch_float(model.encode_image) | |||
patch_float(model.encode_text) | |||
model.float() | |||
return model, _transform(model.input_resolution.item()) | |||
def tokenize( | |||
_tokenizer, | |||
texts: Union[str, List[str]], | |||
context_length: int = 77, | |||
truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: | |||
""" | |||
Returns the tokenized representation of given input string(s) | |||
Parameters | |||
---------- | |||
texts : Union[str, List[str]] | |||
An input string or a list of input strings to tokenize | |||
context_length : int | |||
The context length to use; all CLIP models use 77 as the context length | |||
truncate: bool | |||
Whether to truncate the text in case its encoding is longer than the context length | |||
Returns | |||
------- | |||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. | |||
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. | |||
""" | |||
if isinstance(texts, str): | |||
texts = [texts] | |||
sot_token = _tokenizer.encoder['<|startoftext|>'] | |||
eot_token = _tokenizer.encoder['<|endoftext|>'] | |||
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] | |||
for text in texts] | |||
if packaging.version.parse( | |||
torch.__version__) < packaging.version.parse('1.8.0'): | |||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) | |||
else: | |||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) | |||
for i, tokens in enumerate(all_tokens): | |||
if len(tokens) > context_length: | |||
if truncate: | |||
tokens = tokens[:context_length] | |||
tokens[-1] = eot_token | |||
else: | |||
raise RuntimeError( | |||
f'Input {texts[i]} is too long for context length {context_length}' | |||
) | |||
result[i, :len(tokens)] = torch.tensor(tokens) | |||
return result |
@@ -0,0 +1,28 @@ | |||
""" | |||
Adapted from https://github.com/isl-org/lang-seg. | |||
Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
from .lseg_net import LSeg | |||
class TextDrivenSegmentation(nn.Module): | |||
def __init__(self, model_dir): | |||
super(TextDrivenSegmentation, self).__init__() | |||
self.net = LSeg(model_dir=model_dir) | |||
self.model_dir = model_dir | |||
def forward(self, img, txt_list): | |||
b = img.size()[0] | |||
batch_name_list = txt_list | |||
xout_list = [] | |||
for i in range(b): | |||
labelset = ['others', batch_name_list[i]] | |||
xout = self.net(img[i:i + 1], labelset=labelset) | |||
xout_list.append(xout) | |||
score_map = torch.cat(xout_list, dim=0) | |||
return score_map |
@@ -0,0 +1,334 @@ | |||
""" | |||
Adapted from https://github.com/isl-org/lang-seg. | |||
Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
from .lseg_vit import _make_pretrained_clip_vitl16_384, forward_vit | |||
def _make_encoder( | |||
backbone, | |||
features, | |||
use_pretrained=True, | |||
groups=1, | |||
expand=False, | |||
exportable=True, | |||
hooks=None, | |||
use_vit_only=False, | |||
use_readout='ignore', | |||
enable_attention_hooks=False, | |||
): | |||
if backbone == 'clip_vitl16_384': | |||
clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384( | |||
use_pretrained, | |||
hooks=hooks, | |||
use_readout=use_readout, | |||
enable_attention_hooks=enable_attention_hooks, | |||
) | |||
scratch = _make_scratch([256, 512, 1024, 1024], | |||
features, | |||
groups=groups, | |||
expand=expand) | |||
else: | |||
raise NotImplementedError(f"Backbone '{backbone}' not implemented") | |||
return clip_pretrained, pretrained, scratch | |||
def _make_scratch(in_shape, out_shape, groups=1, expand=False): | |||
scratch = nn.Module() | |||
out_shape1 = out_shape | |||
out_shape2 = out_shape | |||
out_shape3 = out_shape | |||
out_shape4 = out_shape | |||
if expand is True: | |||
out_shape1 = out_shape | |||
out_shape2 = out_shape * 2 | |||
out_shape3 = out_shape * 4 | |||
out_shape4 = out_shape * 8 | |||
scratch.layer1_rn = nn.Conv2d( | |||
in_shape[0], | |||
out_shape1, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=False, | |||
groups=groups, | |||
) | |||
scratch.layer2_rn = nn.Conv2d( | |||
in_shape[1], | |||
out_shape2, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=False, | |||
groups=groups, | |||
) | |||
scratch.layer3_rn = nn.Conv2d( | |||
in_shape[2], | |||
out_shape3, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=False, | |||
groups=groups, | |||
) | |||
scratch.layer4_rn = nn.Conv2d( | |||
in_shape[3], | |||
out_shape4, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=False, | |||
groups=groups, | |||
) | |||
return scratch | |||
class Interpolate(nn.Module): | |||
"""Interpolation module.""" | |||
def __init__(self, scale_factor, mode, align_corners=False): | |||
"""Init. | |||
Args: | |||
scale_factor (float): scaling | |||
mode (str): interpolation mode | |||
""" | |||
super(Interpolate, self).__init__() | |||
self.interp = nn.functional.interpolate | |||
self.scale_factor = scale_factor | |||
self.mode = mode | |||
self.align_corners = align_corners | |||
def forward(self, x): | |||
"""Forward pass. | |||
Args: | |||
x (tensor): input | |||
Returns: | |||
tensor: interpolated data | |||
""" | |||
x = self.interp( | |||
x, | |||
scale_factor=self.scale_factor, | |||
mode=self.mode, | |||
align_corners=self.align_corners, | |||
) | |||
return x | |||
class ResidualConvUnit(nn.Module): | |||
"""Residual convolution module.""" | |||
def __init__(self, features): | |||
"""Init. | |||
Args: | |||
features (int): number of features | |||
""" | |||
super().__init__() | |||
self.conv1 = nn.Conv2d( | |||
features, features, kernel_size=3, stride=1, padding=1, bias=True) | |||
self.conv2 = nn.Conv2d( | |||
features, features, kernel_size=3, stride=1, padding=1, bias=True) | |||
self.relu = nn.ReLU(inplace=True) | |||
def forward(self, x): | |||
"""Forward pass. | |||
Args: | |||
x (tensor): input | |||
Returns: | |||
tensor: output | |||
""" | |||
out = self.relu(x) | |||
out = self.conv1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
return out + x | |||
class FeatureFusionBlock(nn.Module): | |||
"""Feature fusion block.""" | |||
def __init__(self, features): | |||
"""Init. | |||
Args: | |||
features (int): number of features | |||
""" | |||
super(FeatureFusionBlock, self).__init__() | |||
self.resConfUnit1 = ResidualConvUnit(features) | |||
self.resConfUnit2 = ResidualConvUnit(features) | |||
def forward(self, *xs): | |||
"""Forward pass. | |||
Returns: | |||
tensor: output | |||
""" | |||
output = xs[0] | |||
if len(xs) == 2: | |||
output += self.resConfUnit1(xs[1]) | |||
output = self.resConfUnit2(output) | |||
output = nn.functional.interpolate( | |||
output, scale_factor=2, mode='bilinear', align_corners=True) | |||
return output | |||
class ResidualConvUnit_custom(nn.Module): | |||
"""Residual convolution module.""" | |||
def __init__(self, features, activation, bn): | |||
"""Init. | |||
Args: | |||
features (int): number of features | |||
""" | |||
super().__init__() | |||
self.bn = bn | |||
self.groups = 1 | |||
self.conv1 = nn.Conv2d( | |||
features, | |||
features, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=not self.bn, | |||
groups=self.groups, | |||
) | |||
self.conv2 = nn.Conv2d( | |||
features, | |||
features, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=not self.bn, | |||
groups=self.groups, | |||
) | |||
if self.bn is True: | |||
self.bn1 = nn.BatchNorm2d(features) | |||
self.bn2 = nn.BatchNorm2d(features) | |||
self.activation = activation | |||
self.skip_add = nn.quantized.FloatFunctional() | |||
def forward(self, x): | |||
"""Forward pass. | |||
Args: | |||
x (tensor): input | |||
Returns: | |||
tensor: output | |||
""" | |||
out = self.activation(x) | |||
out = self.conv1(out) | |||
if self.bn is True: | |||
out = self.bn1(out) | |||
out = self.activation(out) | |||
out = self.conv2(out) | |||
if self.bn is True: | |||
out = self.bn2(out) | |||
if self.groups > 1: | |||
out = self.conv_merge(out) | |||
return self.skip_add.add(out, x) | |||
class FeatureFusionBlock_custom(nn.Module): | |||
"""Feature fusion block.""" | |||
def __init__( | |||
self, | |||
features, | |||
activation, | |||
deconv=False, | |||
bn=False, | |||
expand=False, | |||
align_corners=True, | |||
): | |||
"""Init. | |||
Args: | |||
features (int): number of features | |||
""" | |||
super(FeatureFusionBlock_custom, self).__init__() | |||
self.deconv = deconv | |||
self.align_corners = align_corners | |||
self.groups = 1 | |||
self.expand = expand | |||
out_features = features | |||
if self.expand is True: | |||
out_features = features // 2 | |||
self.out_conv = nn.Conv2d( | |||
features, | |||
out_features, | |||
kernel_size=1, | |||
stride=1, | |||
padding=0, | |||
bias=True, | |||
groups=1, | |||
) | |||
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) | |||
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) | |||
self.skip_add = nn.quantized.FloatFunctional() | |||
def forward(self, *xs): | |||
"""Forward pass. | |||
Returns: | |||
tensor: output | |||
""" | |||
output = xs[0] | |||
if len(xs) == 2: | |||
res = self.resConfUnit1(xs[1]) | |||
output = self.skip_add.add(output, res) | |||
output = self.resConfUnit2(output) | |||
output = nn.functional.interpolate( | |||
output, | |||
scale_factor=2, | |||
mode='bilinear', | |||
align_corners=self.align_corners) | |||
output = self.out_conv(output) | |||
return output |
@@ -0,0 +1,107 @@ | |||
import os.path as osp | |||
from typing import Any, Dict | |||
import json | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from PIL import Image | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.cv.text_driven_segmentation import \ | |||
TextDrivenSegmentation | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
__all__ = ['TextDrivenSeg'] | |||
@MODELS.register_module( | |||
Tasks.text_driven_segmentation, | |||
module_name=Models.text_driven_segmentation) | |||
class TextDrivenSeg(TorchModel): | |||
""" text driven segmentation model. | |||
""" | |||
def __init__(self, model_dir, device_id=0, *args, **kwargs): | |||
super().__init__( | |||
model_dir=model_dir, device_id=device_id, *args, **kwargs) | |||
self.model = TextDrivenSegmentation(model_dir=model_dir) | |||
pretrained_params = torch.load('{}/{}'.format( | |||
model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) | |||
self.model.load_state_dict(pretrained_params) | |||
self.model.eval() | |||
if device_id >= 0 and torch.cuda.is_available(): | |||
self.model.to('cuda:{}'.format(device_id)) | |||
logger.info('Use GPU: {}'.format(device_id)) | |||
else: | |||
device_id = -1 | |||
logger.info('Use CPU for inference') | |||
self.device_id = device_id | |||
def preprocess(self, img, size=640): | |||
mean = [0.48145466, 0.4578275, 0.40821073] | |||
std = [0.26862954, 0.26130258, 0.27577711] | |||
h, w, c = img.shape | |||
max_hw = max(h, w) | |||
ratio = 1.0 * size / max_hw | |||
crop_h, crop_w = int(ratio * h), int(ratio * w) | |||
pil_img = Image.fromarray(img) | |||
pil_img = pil_img.resize((crop_w, crop_h), Image.BILINEAR) | |||
np_img = np.array(pil_img, dtype=np.float32) / 255. | |||
for j in range(3): | |||
np_img[:, :, j] = (np_img[:, :, j] - mean[j]) / std[j] | |||
img_pad = np.zeros((size, size, 3), dtype=np.float32) | |||
img_pad[:crop_h, :crop_w] = np_img | |||
img_pad = torch.from_numpy(img_pad).permute(2, 0, | |||
1).unsqueeze(0).float() | |||
return img_pad, h, w, crop_h, crop_w | |||
def postprocess(self, tensors, crop_h, crop_w, ori_h, ori_w): | |||
output = np.clip(tensors * 255., a_min=0, a_max=255.) | |||
crop_output = np.array(output[:crop_h, :crop_w], dtype=np.uint8) | |||
pil_output = Image.fromarray(crop_output) | |||
pil_output = pil_output.resize((ori_w, ori_h), Image.BILINEAR) | |||
np_output = np.array(pil_output, dtype=np.uint8) | |||
np_output[np_output < 128] = 0 | |||
np_output[np_output >= 128] = 255 | |||
np_output = np.uint8(np_output) | |||
return np_output | |||
def forward(self, image, text): | |||
""" | |||
image should be numpy array, dtype=np.uint8, shape: height*width*3 | |||
""" | |||
image_tensor, ori_h, ori_w, crop_h, crop_w = self.preprocess( | |||
image, size=640) | |||
pred = self.inference(image_tensor, text) | |||
msk = self.postprocess(pred, crop_h, crop_w, ori_h, ori_w, size=640) | |||
outputs = {OutputKeys.MASKS: msk} | |||
return outputs | |||
def inference(self, image, text): | |||
""" | |||
image should be tensor, 1 * 3 * 640 * 640 | |||
""" | |||
with torch.no_grad(): | |||
if self.device_id == -1: | |||
output = self.model(image) | |||
else: | |||
device = torch.device('cuda', self.device_id) | |||
output = self.model(image.to(device), [text]) | |||
output = F.interpolate(output, size=(640, 640), mode='bilinear') | |||
output = F.softmax(output, dim=1) | |||
output = torch.argmax(output, dim=1) | |||
output = output[0] | |||
if self.device_id == -1: | |||
pred = output.data.numpy() | |||
else: | |||
pred = output.data.cpu().numpy() | |||
del output | |||
return pred |
@@ -0,0 +1,197 @@ | |||
""" | |||
Adapted from https://github.com/isl-org/lang-seg. | |||
Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||
""" | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from . import clip | |||
from .lseg_blocks import (FeatureFusionBlock, FeatureFusionBlock_custom, | |||
Interpolate, _make_encoder, forward_vit) | |||
from .simple_tokenizer import SimpleTokenizer | |||
class depthwise_clipseg_conv(nn.Module): | |||
def __init__(self): | |||
super(depthwise_clipseg_conv, self).__init__() | |||
self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1) | |||
def depthwise_clipseg(self, x, channels): | |||
x = torch.cat( | |||
[self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)], | |||
dim=1) | |||
return x | |||
def forward(self, x): | |||
channels = x.shape[1] | |||
out = self.depthwise_clipseg(x, channels) | |||
return out | |||
class depthwise_conv(nn.Module): | |||
def __init__(self, kernel_size=3, stride=1, padding=1): | |||
super(depthwise_conv, self).__init__() | |||
self.depthwise = nn.Conv2d( | |||
1, 1, kernel_size=kernel_size, stride=stride, padding=padding) | |||
def forward(self, x): | |||
# support for 4D tensor with NCHW | |||
C, H, W = x.shape[1:] | |||
x = x.reshape(-1, 1, H, W) | |||
x = self.depthwise(x) | |||
x = x.view(-1, C, H, W) | |||
return x | |||
class depthwise_block(nn.Module): | |||
def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'): | |||
super(depthwise_block, self).__init__() | |||
self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) | |||
if activation == 'relu': | |||
self.activation = nn.ReLU() | |||
elif activation == 'lrelu': | |||
self.activation = nn.LeakyReLU() | |||
elif activation == 'tanh': | |||
self.activation = nn.Tanh() | |||
def forward(self, x, act=True): | |||
x = self.depthwise(x) | |||
if act: | |||
x = self.activation(x) | |||
return x | |||
class bottleneck_block(nn.Module): | |||
def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'): | |||
super(bottleneck_block, self).__init__() | |||
self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) | |||
if activation == 'relu': | |||
self.activation = nn.ReLU() | |||
elif activation == 'lrelu': | |||
self.activation = nn.LeakyReLU() | |||
elif activation == 'tanh': | |||
self.activation = nn.Tanh() | |||
def forward(self, x, act=True): | |||
sum_layer = x.max(dim=1, keepdim=True)[0] | |||
x = self.depthwise(x) | |||
x = x + sum_layer | |||
if act: | |||
x = self.activation(x) | |||
return x | |||
class BaseModel(torch.nn.Module): | |||
def load(self, path): | |||
"""Load model from file. | |||
Args: | |||
path (str): file path | |||
""" | |||
parameters = torch.load(path, map_location=torch.device('cpu')) | |||
if 'optimizer' in parameters: | |||
parameters = parameters['model'] | |||
self.load_state_dict(parameters) | |||
def _make_fusion_block(features, use_bn): | |||
return FeatureFusionBlock_custom( | |||
features, | |||
activation=nn.ReLU(False), | |||
deconv=False, | |||
bn=use_bn, | |||
expand=False, | |||
align_corners=True, | |||
) | |||
class LSeg(BaseModel): | |||
def __init__( | |||
self, | |||
features=256, | |||
backbone='clip_vitl16_384', | |||
readout='project', | |||
use_bn=True, | |||
model_dir=None, | |||
): | |||
super(LSeg, self).__init__() | |||
hooks = { | |||
'clip_vitl16_384': [5, 11, 17, 23], | |||
} | |||
# Instantiate backbone and reassemble blocks | |||
self.clip_pretrained, self.pretrained, self.scratch = _make_encoder( | |||
backbone, | |||
features, | |||
groups=1, | |||
expand=False, | |||
exportable=False, | |||
hooks=hooks[backbone], | |||
use_readout=readout, | |||
) | |||
self.scratch.refinenet1 = _make_fusion_block(features, use_bn) | |||
self.scratch.refinenet2 = _make_fusion_block(features, use_bn) | |||
self.scratch.refinenet3 = _make_fusion_block(features, use_bn) | |||
self.scratch.refinenet4 = _make_fusion_block(features, use_bn) | |||
self.logit_scale = nn.Parameter(torch.ones([]) | |||
* np.log(1 / 0.07)).exp() | |||
self.out_c = 512 | |||
self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1) | |||
self.scratch.output_conv = nn.Sequential( | |||
Interpolate(scale_factor=2, mode='bilinear', align_corners=True), ) | |||
self.tau = 0.07 | |||
self.model_dir = model_dir | |||
self.tokenizer = SimpleTokenizer(model_dir | |||
+ '/bpe_simple_vocab_16e6.txt.gz') | |||
def forward(self, x, labelset=''): | |||
text = clip.tokenize(self.tokenizer, labelset) | |||
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) | |||
layer_1_rn = self.scratch.layer1_rn(layer_1) | |||
layer_2_rn = self.scratch.layer2_rn(layer_2) | |||
layer_3_rn = self.scratch.layer3_rn(layer_3) | |||
layer_4_rn = self.scratch.layer4_rn(layer_4) | |||
path_4 = self.scratch.refinenet4(layer_4_rn) | |||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn) | |||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn) | |||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn) | |||
text = text.to(x.device) | |||
text_features = self.clip_pretrained.encode_text(text) | |||
image_features = self.scratch.head1(path_1) | |||
imshape = image_features.shape | |||
image_features = image_features.permute(0, 2, 3, | |||
1).reshape(-1, self.out_c) | |||
# normalized features | |||
image_features = image_features / image_features.norm( | |||
dim=-1, keepdim=True) | |||
text_features = text_features / text_features.norm( | |||
dim=-1, keepdim=True) | |||
logits_per_image = image_features @ text_features.t() / self.tau | |||
out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], | |||
-1).permute(0, 3, 1, 2) | |||
out = self.scratch.output_conv(out) | |||
return out |
@@ -0,0 +1,543 @@ | |||
""" | |||
Adapted from https://github.com/isl-org/lang-seg. | |||
Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||
""" | |||
import math | |||
import types | |||
import timm | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torch.utils.checkpoint as checkpoint | |||
from . import clip | |||
activations = {} | |||
def get_activation(name): | |||
def hook(model, input, output): | |||
activations[name] = output | |||
return hook | |||
attention = {} | |||
def get_attention(name): | |||
def hook(module, input, output): | |||
x = input[0] | |||
B, N, C = x.shape | |||
qkv = ( | |||
module.qkv(x).reshape(B, N, 3, module.num_heads, | |||
C // module.num_heads).permute( | |||
2, 0, 3, 1, 4)) | |||
q, k, _ = ( | |||
qkv[0], | |||
qkv[1], | |||
qkv[2], | |||
) # make torchscript happy (cannot use tensor as tuple) | |||
attn = (q @ k.transpose(-2, -1)) * module.scale | |||
attn = attn.softmax(dim=-1) # [:,:,1,1:] | |||
attention[name] = attn | |||
return hook | |||
def get_mean_attention_map(attn, token, shape): | |||
attn = attn[:, :, token, 1:] | |||
attn = attn.unflatten(2, torch.Size([shape[2] // 16, | |||
shape[3] // 16])).float() | |||
attn = torch.nn.functional.interpolate( | |||
attn, size=shape[2:], mode='bicubic', align_corners=False).squeeze(0) | |||
all_attn = torch.mean(attn, 0) | |||
return all_attn | |||
class Slice(nn.Module): | |||
def __init__(self, start_index=1): | |||
super(Slice, self).__init__() | |||
self.start_index = start_index | |||
def forward(self, x): | |||
return x[:, self.start_index:] | |||
class AddReadout(nn.Module): | |||
def __init__(self, start_index=1): | |||
super(AddReadout, self).__init__() | |||
self.start_index = start_index | |||
def forward(self, x): | |||
if self.start_index == 2: | |||
readout = (x[:, 0] + x[:, 1]) / 2 | |||
else: | |||
readout = x[:, 0] | |||
return x[:, self.start_index:] + readout.unsqueeze(1) | |||
class ProjectReadout(nn.Module): | |||
def __init__(self, in_features, start_index=1): | |||
super(ProjectReadout, self).__init__() | |||
self.start_index = start_index | |||
self.project = nn.Sequential( | |||
nn.Linear(2 * in_features, in_features), nn.GELU()) | |||
def forward(self, x): | |||
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) | |||
features = torch.cat((x[:, self.start_index:], readout), -1) | |||
return self.project(features) | |||
class Transpose(nn.Module): | |||
def __init__(self, dim0, dim1): | |||
super(Transpose, self).__init__() | |||
self.dim0 = dim0 | |||
self.dim1 = dim1 | |||
def forward(self, x): | |||
x = x.transpose(self.dim0, self.dim1) | |||
return x | |||
def forward_vit(pretrained, x): | |||
b, c, h, w = x.shape | |||
# encoder | |||
_ = pretrained.model.forward_flex(x) | |||
layer_1 = pretrained.activations['1'] | |||
layer_2 = pretrained.activations['2'] | |||
layer_3 = pretrained.activations['3'] | |||
layer_4 = pretrained.activations['4'] | |||
layer_1 = pretrained.act_postprocess1[0:2](layer_1) | |||
layer_2 = pretrained.act_postprocess2[0:2](layer_2) | |||
layer_3 = pretrained.act_postprocess3[0:2](layer_3) | |||
layer_4 = pretrained.act_postprocess4[0:2](layer_4) | |||
unflatten = nn.Sequential( | |||
nn.Unflatten( | |||
2, | |||
torch.Size([ | |||
h // pretrained.model.patch_size[1], | |||
w // pretrained.model.patch_size[0], | |||
]), | |||
)) | |||
if layer_1.ndim == 3: | |||
layer_1 = unflatten(layer_1) | |||
if layer_2.ndim == 3: | |||
layer_2 = unflatten(layer_2) | |||
if layer_3.ndim == 3: | |||
layer_3 = unflatten(layer_3) | |||
if layer_4.ndim == 3: | |||
layer_4 = unflatten(layer_4) | |||
layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)]( | |||
layer_1) | |||
layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)]( | |||
layer_2) | |||
layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)]( | |||
layer_3) | |||
layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)]( | |||
layer_4) | |||
return layer_1, layer_2, layer_3, layer_4 | |||
def _resize_pos_embed(self, posemb, gs_h, gs_w): | |||
posemb_tok, posemb_grid = ( | |||
posemb[:, :self.start_index], | |||
posemb[0, self.start_index:], | |||
) | |||
gs_old = int(math.sqrt(len(posemb_grid))) | |||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, | |||
-1).permute(0, 3, 1, 2) | |||
posemb_grid = F.interpolate( | |||
posemb_grid, size=(gs_h, gs_w), mode='bilinear') | |||
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) | |||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1) | |||
return posemb | |||
def forward_flex(self, x): | |||
b, c, h, w = x.shape | |||
pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], | |||
w // self.patch_size[0]) | |||
B = x.shape[0] | |||
if hasattr(self.patch_embed, 'backbone'): | |||
x = self.patch_embed.backbone(x) | |||
if isinstance(x, (list, tuple)): | |||
x = x[ | |||
-1] # last feature if backbone outputs list/tuple of features | |||
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) | |||
if getattr(self, 'dist_token', None) is not None: | |||
cls_tokens = self.cls_token.expand( | |||
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks | |||
dist_token = self.dist_token.expand(B, -1, -1) | |||
x = torch.cat((cls_tokens, dist_token, x), dim=1) | |||
else: | |||
cls_tokens = self.cls_token.expand( | |||
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks | |||
x = torch.cat((cls_tokens, x), dim=1) | |||
x = x + pos_embed | |||
x = self.pos_drop(x) | |||
gradient_checkpoint = False | |||
for blk in self.blocks: | |||
if gradient_checkpoint: | |||
x = checkpoint.checkpoint(blk, x) | |||
else: | |||
x = blk(x) | |||
x = self.norm(x) | |||
return x | |||
def get_readout_oper(vit_features, features, use_readout, start_index=1): | |||
if use_readout == 'ignore': | |||
readout_oper = [Slice(start_index)] * len(features) | |||
elif use_readout == 'add': | |||
readout_oper = [AddReadout(start_index)] * len(features) | |||
elif use_readout == 'project': | |||
readout_oper = [ | |||
ProjectReadout(vit_features, start_index) for out_feat in features | |||
] | |||
else: | |||
assert ( | |||
False | |||
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" | |||
return readout_oper | |||
def adapt_input_conv(in_chans, conv_weight): | |||
conv_type = conv_weight.dtype | |||
conv_weight = conv_weight.float( | |||
) # Some weights are in torch.half, ensure it's float for sum on CPU | |||
O, II, J, K = conv_weight.shape | |||
if in_chans == 1: | |||
if II > 3: | |||
assert conv_weight.shape[1] % 3 == 0 | |||
# For models with space2depth stems | |||
conv_weight = conv_weight.reshape(O, II // 3, 3, J, K) | |||
conv_weight = conv_weight.sum(dim=2, keepdim=False) | |||
else: | |||
conv_weight = conv_weight.sum(dim=1, keepdim=True) | |||
elif in_chans != 3: | |||
if II != 3: | |||
raise NotImplementedError( | |||
'Weight format not supported by conversion.') | |||
else: | |||
# NOTE this strategy should be better than random init, but there could be other combinations of | |||
# the original RGB input layer weights that'd work better for specific cases. | |||
repeat = int(math.ceil(in_chans / 3)) | |||
conv_weight = conv_weight.repeat(1, repeat, 1, | |||
1)[:, :in_chans, :, :] | |||
conv_weight *= (3 / float(in_chans)) | |||
conv_weight = conv_weight.to(conv_type) | |||
return conv_weight | |||
@torch.no_grad() | |||
def _load_weights(model, checkpoint_path, prefix=''): | |||
""" Load weights from .npz checkpoints for official Google Brain Flax implementation | |||
""" | |||
import numpy as np | |||
def _n2p(w, t=True): | |||
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: | |||
w = w.flatten() | |||
if t: | |||
if w.ndim == 4: | |||
w = w.transpose([3, 2, 0, 1]) | |||
elif w.ndim == 3: | |||
w = w.transpose([2, 0, 1]) | |||
elif w.ndim == 2: | |||
w = w.transpose([1, 0]) | |||
return torch.from_numpy(w) | |||
w = np.load(checkpoint_path) | |||
if not prefix and 'opt/target/embedding/kernel' in w: | |||
prefix = 'opt/target/' | |||
if hasattr(model.patch_embed, 'backbone'): | |||
# hybrid | |||
backbone = model.patch_embed.backbone | |||
stem_only = not hasattr(backbone, 'stem') | |||
stem = backbone if stem_only else backbone.stem | |||
stem.conv.weight.copy_( | |||
adapt_input_conv(stem.conv.weight.shape[1], | |||
_n2p(w[f'{prefix}conv_root/kernel']))) | |||
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) | |||
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) | |||
if not stem_only: | |||
for i, stage in enumerate(backbone.stages): | |||
for j, block in enumerate(stage.blocks): | |||
bp = f'{prefix}block{i + 1}/unit{j + 1}/' | |||
for r in range(3): | |||
getattr(block, f'conv{r + 1}').weight.copy_( | |||
_n2p(w[f'{bp}conv{r + 1}/kernel'])) | |||
getattr(block, f'norm{r + 1}').weight.copy_( | |||
_n2p(w[f'{bp}gn{r + 1}/scale'])) | |||
getattr(block, f'norm{r + 1}').bias.copy_( | |||
_n2p(w[f'{bp}gn{r + 1}/bias'])) | |||
if block.downsample is not None: | |||
block.downsample.conv.weight.copy_( | |||
_n2p(w[f'{bp}conv_proj/kernel'])) | |||
block.downsample.norm.weight.copy_( | |||
_n2p(w[f'{bp}gn_proj/scale'])) | |||
block.downsample.norm.bias.copy_( | |||
_n2p(w[f'{bp}gn_proj/bias'])) | |||
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) | |||
else: | |||
embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1], | |||
_n2p(w[f'{prefix}embedding/kernel'])) | |||
model.patch_embed.proj.weight.copy_(embed_conv_w) | |||
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) | |||
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) | |||
pos_embed_w = _n2p( | |||
w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) | |||
if pos_embed_w.shape != model.pos_embed.shape: | |||
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights | |||
pos_embed_w, model.pos_embed, getattr(model, 'num_prefix_tokens', | |||
1), | |||
model.patch_embed.grid_size) | |||
model.pos_embed.copy_(pos_embed_w) | |||
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) | |||
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) | |||
if isinstance( | |||
model.head, nn.Linear | |||
) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: | |||
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) | |||
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) | |||
# NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights | |||
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: | |||
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) | |||
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) | |||
for i, block in enumerate(model.blocks.children()): | |||
block_prefix = f'{prefix}Transformer/encoderblock_{i}/' | |||
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' | |||
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) | |||
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) | |||
block.attn.qkv.weight.copy_( | |||
torch.cat([ | |||
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T | |||
for n in ('query', 'key', 'value') | |||
])) | |||
block.attn.qkv.bias.copy_( | |||
torch.cat([ | |||
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) | |||
for n in ('query', 'key', 'value') | |||
])) | |||
block.attn.proj.weight.copy_( | |||
_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) | |||
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) | |||
for r in range(2): | |||
getattr(block.mlp, f'fc{r + 1}').weight.copy_( | |||
_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) | |||
getattr(block.mlp, f'fc{r + 1}').bias.copy_( | |||
_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) | |||
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) | |||
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) | |||
def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): | |||
# Rescale the grid of position embeddings when loading from state_dict. Adapted from | |||
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 | |||
ntok_new = posemb_new.shape[1] | |||
if num_prefix_tokens: | |||
posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[ | |||
0, num_prefix_tokens:] | |||
ntok_new -= num_prefix_tokens | |||
else: | |||
posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] | |||
gs_old = int(math.sqrt(len(posemb_grid))) | |||
if not len(gs_new): # backwards compatibility | |||
gs_new = [int(math.sqrt(ntok_new))] * 2 | |||
assert len(gs_new) >= 2 | |||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, | |||
-1).permute(0, 3, 1, 2) | |||
posemb_grid = F.interpolate( | |||
posemb_grid, size=gs_new, mode='bicubic', align_corners=False) | |||
posemb_grid = posemb_grid.permute(0, 2, 3, | |||
1).reshape(1, gs_new[0] * gs_new[1], -1) | |||
posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) | |||
return posemb | |||
def _make_pretrained_clip_vitl16_384(pretrained, | |||
use_readout='ignore', | |||
hooks=None, | |||
enable_attention_hooks=False): | |||
clip_pretrained, _ = clip.load('ViT-B/32', device='cpu', jit=False) | |||
# model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) | |||
model = timm.create_model('vit_large_patch16_384', pretrained=False) | |||
hooks = [5, 11, 17, 23] if hooks is None else hooks | |||
pretrained = _make_vit_b16_backbone( | |||
model, | |||
features=[256, 512, 1024, 1024], | |||
hooks=hooks, | |||
vit_features=1024, | |||
use_readout=use_readout, | |||
enable_attention_hooks=enable_attention_hooks, | |||
) | |||
return clip_pretrained, pretrained | |||
def _make_vit_b16_backbone( | |||
model, | |||
features=[96, 192, 384, 768], | |||
size=[384, 384], | |||
hooks=[2, 5, 8, 11], | |||
vit_features=768, | |||
use_readout='ignore', | |||
start_index=1, | |||
enable_attention_hooks=False, | |||
): | |||
pretrained = nn.Module() | |||
pretrained.model = model | |||
pretrained.model.blocks[hooks[0]].register_forward_hook( | |||
get_activation('1')) | |||
pretrained.model.blocks[hooks[1]].register_forward_hook( | |||
get_activation('2')) | |||
pretrained.model.blocks[hooks[2]].register_forward_hook( | |||
get_activation('3')) | |||
pretrained.model.blocks[hooks[3]].register_forward_hook( | |||
get_activation('4')) | |||
pretrained.activations = activations | |||
if enable_attention_hooks: | |||
pretrained.model.blocks[hooks[0]].attn.register_forward_hook( | |||
get_attention('attn_1')) | |||
pretrained.model.blocks[hooks[1]].attn.register_forward_hook( | |||
get_attention('attn_2')) | |||
pretrained.model.blocks[hooks[2]].attn.register_forward_hook( | |||
get_attention('attn_3')) | |||
pretrained.model.blocks[hooks[3]].attn.register_forward_hook( | |||
get_attention('attn_4')) | |||
pretrained.attention = attention | |||
readout_oper = get_readout_oper(vit_features, features, use_readout, | |||
start_index) | |||
# 32, 48, 136, 384 | |||
pretrained.act_postprocess1 = nn.Sequential( | |||
readout_oper[0], | |||
Transpose(1, 2), | |||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||
nn.Conv2d( | |||
in_channels=vit_features, | |||
out_channels=features[0], | |||
kernel_size=1, | |||
stride=1, | |||
padding=0, | |||
), | |||
nn.ConvTranspose2d( | |||
in_channels=features[0], | |||
out_channels=features[0], | |||
kernel_size=4, | |||
stride=4, | |||
padding=0, | |||
bias=True, | |||
dilation=1, | |||
groups=1, | |||
), | |||
) | |||
pretrained.act_postprocess2 = nn.Sequential( | |||
readout_oper[1], | |||
Transpose(1, 2), | |||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||
nn.Conv2d( | |||
in_channels=vit_features, | |||
out_channels=features[1], | |||
kernel_size=1, | |||
stride=1, | |||
padding=0, | |||
), | |||
nn.ConvTranspose2d( | |||
in_channels=features[1], | |||
out_channels=features[1], | |||
kernel_size=2, | |||
stride=2, | |||
padding=0, | |||
bias=True, | |||
dilation=1, | |||
groups=1, | |||
), | |||
) | |||
pretrained.act_postprocess3 = nn.Sequential( | |||
readout_oper[2], | |||
Transpose(1, 2), | |||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||
nn.Conv2d( | |||
in_channels=vit_features, | |||
out_channels=features[2], | |||
kernel_size=1, | |||
stride=1, | |||
padding=0, | |||
), | |||
) | |||
pretrained.act_postprocess4 = nn.Sequential( | |||
readout_oper[3], | |||
Transpose(1, 2), | |||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||
nn.Conv2d( | |||
in_channels=vit_features, | |||
out_channels=features[3], | |||
kernel_size=1, | |||
stride=1, | |||
padding=0, | |||
), | |||
nn.Conv2d( | |||
in_channels=features[3], | |||
out_channels=features[3], | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
), | |||
) | |||
pretrained.model.start_index = start_index | |||
pretrained.model.patch_size = [16, 16] | |||
# We inject this function into the VisionTransformer instances so that | |||
# we can use it with interpolated position embeddings without modifying the library source. | |||
pretrained.model.forward_flex = types.MethodType(forward_flex, | |||
pretrained.model) | |||
pretrained.model._resize_pos_embed = types.MethodType( | |||
_resize_pos_embed, pretrained.model) | |||
return pretrained |
@@ -0,0 +1,458 @@ | |||
""" | |||
Adapted from https://github.com/isl-org/lang-seg. | |||
Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||
""" | |||
from collections import OrderedDict | |||
from typing import Tuple, Union | |||
import numpy as np | |||
import torch | |||
import torch.nn.functional as F | |||
from torch import nn | |||
class Bottleneck(nn.Module): | |||
expansion = 4 | |||
def __init__(self, inplanes, planes, stride=1): | |||
super().__init__() | |||
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 | |||
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) | |||
self.bn1 = nn.BatchNorm2d(planes) | |||
self.relu1 = nn.ReLU(inplace=True) | |||
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) | |||
self.bn2 = nn.BatchNorm2d(planes) | |||
self.relu2 = nn.ReLU(inplace=True) | |||
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() | |||
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) | |||
self.bn3 = nn.BatchNorm2d(planes * self.expansion) | |||
self.relu3 = nn.ReLU(inplace=True) | |||
self.downsample = None | |||
self.stride = stride | |||
if stride > 1 or inplanes != planes * Bottleneck.expansion: | |||
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 | |||
self.downsample = nn.Sequential( | |||
OrderedDict([('-1', nn.AvgPool2d(stride)), | |||
('0', | |||
nn.Conv2d( | |||
inplanes, | |||
planes * self.expansion, | |||
1, | |||
stride=1, | |||
bias=False)), | |||
('1', nn.BatchNorm2d(planes * self.expansion))])) | |||
def forward(self, x: torch.Tensor): | |||
identity = x | |||
out = self.relu1(self.bn1(self.conv1(x))) | |||
out = self.relu2(self.bn2(self.conv2(out))) | |||
out = self.avgpool(out) | |||
out = self.bn3(self.conv3(out)) | |||
if self.downsample is not None: | |||
identity = self.downsample(x) | |||
out += identity | |||
out = self.relu3(out) | |||
return out | |||
class AttentionPool2d(nn.Module): | |||
def __init__(self, | |||
spacial_dim: int, | |||
embed_dim: int, | |||
num_heads: int, | |||
output_dim: int = None): | |||
super().__init__() | |||
self.positional_embedding = nn.Parameter( | |||
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) | |||
self.k_proj = nn.Linear(embed_dim, embed_dim) | |||
self.q_proj = nn.Linear(embed_dim, embed_dim) | |||
self.v_proj = nn.Linear(embed_dim, embed_dim) | |||
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) | |||
self.num_heads = num_heads | |||
def forward(self, x): | |||
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC | |||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC | |||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC | |||
x, _ = F.multi_head_attention_forward( | |||
query=x[:1], | |||
key=x, | |||
value=x, | |||
embed_dim_to_check=x.shape[-1], | |||
num_heads=self.num_heads, | |||
q_proj_weight=self.q_proj.weight, | |||
k_proj_weight=self.k_proj.weight, | |||
v_proj_weight=self.v_proj.weight, | |||
in_proj_weight=None, | |||
in_proj_bias=torch.cat( | |||
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), | |||
bias_k=None, | |||
bias_v=None, | |||
add_zero_attn=False, | |||
dropout_p=0, | |||
out_proj_weight=self.c_proj.weight, | |||
out_proj_bias=self.c_proj.bias, | |||
use_separate_proj_weight=True, | |||
training=self.training, | |||
need_weights=False) | |||
return x.squeeze(0) | |||
class ModifiedResNet(nn.Module): | |||
""" | |||
A ResNet class that is similar to torchvision's but contains the following changes: | |||
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. | |||
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 | |||
- The final pooling layer is a QKV attention instead of an average pool | |||
""" | |||
def __init__(self, | |||
layers, | |||
output_dim, | |||
heads, | |||
input_resolution=224, | |||
width=64): | |||
super().__init__() | |||
self.output_dim = output_dim | |||
self.input_resolution = input_resolution | |||
# the 3-layer stem | |||
self.conv1 = nn.Conv2d( | |||
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) | |||
self.bn1 = nn.BatchNorm2d(width // 2) | |||
self.relu1 = nn.ReLU(inplace=True) | |||
self.conv2 = nn.Conv2d( | |||
width // 2, width // 2, kernel_size=3, padding=1, bias=False) | |||
self.bn2 = nn.BatchNorm2d(width // 2) | |||
self.relu2 = nn.ReLU(inplace=True) | |||
self.conv3 = nn.Conv2d( | |||
width // 2, width, kernel_size=3, padding=1, bias=False) | |||
self.bn3 = nn.BatchNorm2d(width) | |||
self.relu3 = nn.ReLU(inplace=True) | |||
self.avgpool = nn.AvgPool2d(2) | |||
# residual layers | |||
self._inplanes = width # this is a *mutable* variable used during construction | |||
self.layer1 = self._make_layer(width, layers[0]) | |||
self.layer2 = self._make_layer(width * 2, layers[1], stride=2) | |||
self.layer3 = self._make_layer(width * 4, layers[2], stride=2) | |||
self.layer4 = self._make_layer(width * 8, layers[3], stride=2) | |||
embed_dim = width * 32 # the ResNet feature dimension | |||
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, | |||
heads, output_dim) | |||
def _make_layer(self, planes, blocks, stride=1): | |||
layers = [Bottleneck(self._inplanes, planes, stride)] | |||
self._inplanes = planes * Bottleneck.expansion | |||
for _ in range(1, blocks): | |||
layers.append(Bottleneck(self._inplanes, planes)) | |||
return nn.Sequential(*layers) | |||
def forward(self, x): | |||
def stem(x): | |||
x = self.relu1(self.bn1(self.conv1(x))) | |||
x = self.relu2(self.bn2(self.conv2(x))) | |||
x = self.relu3(self.bn3(self.conv3(x))) | |||
x = self.avgpool(x) | |||
return x | |||
x = x.type(self.conv1.weight.dtype) | |||
x = stem(x) | |||
x = self.layer1(x) | |||
x = self.layer2(x) | |||
x = self.layer3(x) | |||
x = self.layer4(x) | |||
x = self.attnpool(x) | |||
return x | |||
class LayerNorm(nn.LayerNorm): | |||
"""Subclass torch's LayerNorm to handle fp16.""" | |||
def forward(self, x: torch.Tensor): | |||
orig_type = x.dtype | |||
ret = super().forward(x.type(torch.float32)) | |||
return ret.type(orig_type) | |||
class QuickGELU(nn.Module): | |||
def forward(self, x: torch.Tensor): | |||
return x * torch.sigmoid(1.702 * x) | |||
class ResidualAttentionBlock(nn.Module): | |||
def __init__(self, | |||
d_model: int, | |||
n_head: int, | |||
attn_mask: torch.Tensor = None): | |||
super().__init__() | |||
self.attn = nn.MultiheadAttention(d_model, n_head) | |||
self.ln_1 = LayerNorm(d_model) | |||
self.mlp = nn.Sequential( | |||
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), | |||
('gelu', QuickGELU()), | |||
('c_proj', nn.Linear(d_model * 4, d_model))])) | |||
self.ln_2 = LayerNorm(d_model) | |||
self.attn_mask = attn_mask | |||
def attention(self, x: torch.Tensor): | |||
self.attn_mask = self.attn_mask.to( | |||
dtype=x.dtype, | |||
device=x.device) if self.attn_mask is not None else None | |||
return self.attn( | |||
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] | |||
def forward(self, x: torch.Tensor): | |||
x = x + self.attention(self.ln_1(x)) | |||
x = x + self.mlp(self.ln_2(x)) | |||
return x | |||
class Transformer(nn.Module): | |||
def __init__(self, width, layers, heads, attn_mask=None): | |||
super().__init__() | |||
self.width = width | |||
self.layers = layers | |||
self.resblocks = nn.Sequential(*[ | |||
ResidualAttentionBlock(width, heads, attn_mask) | |||
for _ in range(layers) | |||
]) | |||
def forward(self, x: torch.Tensor): | |||
return self.resblocks(x) | |||
class VisionTransformer(nn.Module): | |||
def __init__(self, input_resolution: int, patch_size: int, width: int, | |||
layers: int, heads: int, output_dim: int): | |||
super().__init__() | |||
self.input_resolution = input_resolution | |||
self.output_dim = output_dim | |||
self.conv1 = nn.Conv2d( | |||
in_channels=3, | |||
out_channels=width, | |||
kernel_size=patch_size, | |||
stride=patch_size, | |||
bias=False) | |||
scale = width**-0.5 | |||
self.class_embedding = nn.Parameter(scale * torch.randn(width)) | |||
self.positional_embedding = nn.Parameter(scale * torch.randn( | |||
(input_resolution // patch_size)**2 + 1, width)) | |||
self.ln_pre = LayerNorm(width) | |||
self.transformer = Transformer(width, layers, heads) | |||
self.ln_post = LayerNorm(width) | |||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) | |||
def forward(self, x: torch.Tensor): | |||
x = self.conv1(x) # shape = [*, width, grid, grid] | |||
x = x.reshape(x.shape[0], x.shape[1], | |||
-1) # shape = [*, width, grid ** 2] | |||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] | |||
x1 = self.class_embedding.to(x.dtype) | |||
x2 = torch.zeros( | |||
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) | |||
x = torch.cat([x1 + x2, x], dim=1) # shape = [*, grid ** 2 + 1, width] | |||
x = x + self.positional_embedding.to(x.dtype) | |||
x = self.ln_pre(x) | |||
x = x.permute(1, 0, 2) # NLD -> LND | |||
x = self.transformer(x) | |||
x = x.permute(1, 0, 2) # LND -> NLD | |||
x = self.ln_post(x[:, 0, :]) | |||
if self.proj is not None: | |||
x = x @ self.proj | |||
return x | |||
class CLIP(nn.Module): | |||
def __init__( | |||
self, | |||
embed_dim: int, | |||
# vision | |||
image_resolution: int, | |||
vision_layers: Union[Tuple[int, int, int, int], int], | |||
vision_width: int, | |||
vision_patch_size: int, | |||
# text | |||
context_length: int, | |||
vocab_size: int, | |||
transformer_width: int, | |||
transformer_heads: int, | |||
transformer_layers: int): | |||
super().__init__() | |||
self.context_length = context_length | |||
if isinstance(vision_layers, (tuple, list)): | |||
vision_heads = vision_width * 32 // 64 | |||
self.visual = ModifiedResNet( | |||
layers=vision_layers, | |||
output_dim=embed_dim, | |||
heads=vision_heads, | |||
input_resolution=image_resolution, | |||
width=vision_width) | |||
else: | |||
vision_heads = vision_width // 64 | |||
self.visual = VisionTransformer( | |||
input_resolution=image_resolution, | |||
patch_size=vision_patch_size, | |||
width=vision_width, | |||
layers=vision_layers, | |||
heads=vision_heads, | |||
output_dim=embed_dim) | |||
self.transformer = Transformer( | |||
width=transformer_width, | |||
layers=transformer_layers, | |||
heads=transformer_heads, | |||
attn_mask=self.build_attention_mask()) | |||
self.vocab_size = vocab_size | |||
self.token_embedding = nn.Embedding(vocab_size, transformer_width) | |||
self.positional_embedding = nn.Parameter( | |||
torch.empty(self.context_length, transformer_width)) | |||
self.ln_final = LayerNorm(transformer_width) | |||
self.text_projection = nn.Parameter( | |||
torch.empty(transformer_width, embed_dim)) | |||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |||
self.initialize_parameters() | |||
def initialize_parameters(self): | |||
nn.init.normal_(self.token_embedding.weight, std=0.02) | |||
nn.init.normal_(self.positional_embedding, std=0.01) | |||
if isinstance(self.visual, ModifiedResNet): | |||
if self.visual.attnpool is not None: | |||
std = self.visual.attnpool.c_proj.in_features**-0.5 | |||
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) | |||
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) | |||
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) | |||
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) | |||
for resnet_block in [ | |||
self.visual.layer1, self.visual.layer2, self.visual.layer3, | |||
self.visual.layer4 | |||
]: | |||
for name, param in resnet_block.named_parameters(): | |||
if name.endswith('bn3.weight'): | |||
nn.init.zeros_(param) | |||
proj_std = (self.transformer.width**-0.5) * ( | |||
(2 * self.transformer.layers)**-0.5) | |||
attn_std = self.transformer.width**-0.5 | |||
fc_std = (2 * self.transformer.width)**-0.5 | |||
for block in self.transformer.resblocks: | |||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std) | |||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std) | |||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) | |||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) | |||
if self.text_projection is not None: | |||
nn.init.normal_( | |||
self.text_projection, std=self.transformer.width**-0.5) | |||
def build_attention_mask(self): | |||
# lazily create causal attention mask, with full attention between the vision tokens | |||
# pytorch uses additive attention mask; fill with -inf | |||
mask = torch.empty(self.context_length, self.context_length) | |||
mask.fill_(float('-inf')) | |||
mask.triu_(1) # zero out the lower diagonal | |||
return mask | |||
@property | |||
def dtype(self): | |||
return self.visual.conv1.weight.dtype | |||
def encode_image(self, image): | |||
return self.visual(image.type(self.dtype)) | |||
def encode_text(self, text): | |||
x = self.token_embedding(text).type(self.dtype) | |||
x = x + self.positional_embedding.type(self.dtype) | |||
x = x.permute(1, 0, 2) # NLD -> LND | |||
x = self.transformer(x) | |||
x = x.permute(1, 0, 2) # LND -> NLD | |||
x = self.ln_final(x).type(self.dtype) | |||
x = x[torch.arange(x.shape[0]), | |||
text.argmax(dim=-1)] @ self.text_projection | |||
return x | |||
def forward(self, image, text): | |||
image_features = self.encode_image(image) | |||
text_features = self.encode_text(text) | |||
# normalized features | |||
image_features = image_features / image_features.norm( | |||
dim=1, keepdim=True) | |||
text_features = text_features / text_features.norm(dim=1, keepdim=True) | |||
# cosine similarity as logits | |||
logit_scale = self.logit_scale.exp() | |||
logits_per_image = logit_scale * image_features @ text_features.t() | |||
logits_per_text = logits_per_image.t() | |||
# shape = [global_batch_size, global_batch_size] | |||
return logits_per_image, logits_per_text | |||
def convert_weights(model: nn.Module): | |||
"""Convert applicable model parameters to fp16""" | |||
def _convert_weights_to_fp16(ll): | |||
if isinstance(ll, (nn.Conv1d, nn.Conv2d, nn.Linear)): | |||
ll.weight.data = ll.weight.data.half() | |||
if ll.bias is not None: | |||
ll.bias.data = ll.bias.data.half() | |||
if isinstance(ll, nn.MultiheadAttention): | |||
for attr in [ | |||
*[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], | |||
'in_proj_bias', 'bias_k', 'bias_v' | |||
]: | |||
tensor = getattr(ll, attr) | |||
if tensor is not None: | |||
tensor.data = tensor.data.half() | |||
for name in ['text_projection', 'proj']: | |||
if hasattr(ll, name): | |||
attr = getattr(ll, name) | |||
if attr is not None: | |||
attr.data = attr.data.half() | |||
model.apply(_convert_weights_to_fp16) | |||
def build_model(): | |||
model = CLIP(512, 224, 12, 768, 32, 77, 49408, 512, 8, 12) | |||
convert_weights(model) | |||
return model.eval() |
@@ -0,0 +1,156 @@ | |||
""" CLIP | |||
Adapted from https://github.com/openai/CLIP. | |||
Originally MIT License, Copyright (c) 2021 OpenAI. | |||
""" | |||
import gzip | |||
import html | |||
import os | |||
from functools import lru_cache | |||
import ftfy | |||
import regex as re | |||
@lru_cache() | |||
def default_bpe(): | |||
return os.path.join( | |||
os.path.dirname(os.path.abspath(__file__)), | |||
'bpe_simple_vocab_16e6.txt.gz') | |||
@lru_cache() | |||
def bytes_to_unicode(): | |||
""" | |||
Returns list of utf-8 byte and a corresponding list of unicode strings. | |||
The reversible bpe codes work on unicode strings. | |||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||
This is a signficant percentage of your normal, say, 32K bpe vocab. | |||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||
And avoids mapping to whitespace/control characters the bpe code barfs on. | |||
""" | |||
bs = list(range(ord('!'), | |||
ord('~') + 1)) + list(range( | |||
ord('¡'), | |||
ord('¬') + 1)) + list(range(ord('®'), | |||
ord('ÿ') + 1)) | |||
cs = bs[:] | |||
n = 0 | |||
for b in range(2**8): | |||
if b not in bs: | |||
bs.append(b) | |||
cs.append(2**8 + n) | |||
n += 1 | |||
cs = [chr(n) for n in cs] | |||
return dict(zip(bs, cs)) | |||
def get_pairs(word): | |||
"""Return set of symbol pairs in a word. | |||
Word is represented as tuple of symbols (symbols being variable-length strings). | |||
""" | |||
pairs = set() | |||
prev_char = word[0] | |||
for char in word[1:]: | |||
pairs.add((prev_char, char)) | |||
prev_char = char | |||
return pairs | |||
def basic_clean(text): | |||
text = ftfy.fix_text(text) | |||
text = html.unescape(html.unescape(text)) | |||
return text.strip() | |||
def whitespace_clean(text): | |||
text = re.sub(r'\s+', ' ', text) | |||
text = text.strip() | |||
return text | |||
class SimpleTokenizer(object): | |||
def __init__(self, bpe_path: str = default_bpe()): | |||
self.byte_encoder = bytes_to_unicode() | |||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') | |||
merges = merges[1:49152 - 256 - 2 + 1] | |||
merges = [tuple(merge.split()) for merge in merges] | |||
vocab = list(bytes_to_unicode().values()) | |||
vocab = vocab + [v + '</w>' for v in vocab] | |||
for merge in merges: | |||
vocab.append(''.join(merge)) | |||
vocab.extend(['<|startoftext|>', '<|endoftext|>']) | |||
self.encoder = dict(zip(vocab, range(len(vocab)))) | |||
self.decoder = {v: k for k, v in self.encoder.items()} | |||
self.bpe_ranks = dict(zip(merges, range(len(merges)))) | |||
self.cache = { | |||
'<|startoftext|>': '<|startoftext|>', | |||
'<|endoftext|>': '<|endoftext|>' | |||
} | |||
self.pat = re.compile( | |||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", | |||
re.IGNORECASE) | |||
def bpe(self, token): | |||
if token in self.cache: | |||
return self.cache[token] | |||
word = tuple(token[:-1]) + (token[-1] + '</w>', ) | |||
pairs = get_pairs(word) | |||
if not pairs: | |||
return token + '</w>' | |||
while True: | |||
bigram = min( | |||
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||
if bigram not in self.bpe_ranks: | |||
break | |||
first, second = bigram | |||
new_word = [] | |||
i = 0 | |||
error_list = [] | |||
while i < len(word): | |||
try: | |||
j = word.index(first, i) | |||
new_word.extend(word[i:j]) | |||
i = j | |||
except Exception as err: | |||
new_word.extend(word[i:]) | |||
error_list.append(err) | |||
break | |||
if word[i] == first and i < len(word) - 1 and word[ | |||
i + 1] == second: | |||
new_word.append(first + second) | |||
i += 2 | |||
else: | |||
new_word.append(word[i]) | |||
i += 1 | |||
new_word = tuple(new_word) | |||
word = new_word | |||
if len(word) == 1: | |||
break | |||
else: | |||
pairs = get_pairs(word) | |||
word = ' '.join(word) | |||
self.cache[token] = word | |||
return word | |||
def encode(self, text): | |||
bpe_tokens = [] | |||
text = whitespace_clean(basic_clean(text)).lower() | |||
for token in re.findall(self.pat, text): | |||
token = ''.join(self.byte_encoder[b] | |||
for b in token.encode('utf-8')) | |||
bpe_tokens.extend(self.encoder[bpe_token] | |||
for bpe_token in self.bpe(token).split(' ')) | |||
return bpe_tokens | |||
def decode(self, tokens): | |||
text = ''.join([self.decoder[token] for token in tokens]) | |||
text = bytearray([self.byte_decoder[c] for c in text]).decode( | |||
'utf-8', errors='replace').replace('</w>', ' ') | |||
return text |
@@ -243,6 +243,13 @@ TASK_OUTPUTS = { | |||
# "output_img": np.ndarray with shape [height, width, 3] | |||
# } | |||
Tasks.virtual_try_on: [OutputKeys.OUTPUT_IMG], | |||
# text driven segmentation result for single sample | |||
# { | |||
# "masks": [ | |||
# np.array # 2D array containing only 0, 255 | |||
# ] | |||
# } | |||
Tasks.text_driven_segmentation: [OutputKeys.MASKS], | |||
# movide scene segmentation result for a single video | |||
# { | |||
@@ -149,6 +149,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
'damo/cv_vitb_video-single-object-tracking_ostrack'), | |||
Tasks.image_reid_person: (Pipelines.image_reid_person, | |||
'damo/cv_passvitb_image-reid-person_market'), | |||
Tasks.text_driven_segmentation: | |||
(Pipelines.text_driven_segmentation, | |||
'damo/cv_vitl16_segmentation_text-driven-seg'), | |||
Tasks.movie_scene_segmentation: | |||
(Pipelines.movie_scene_segmentation, | |||
'damo/cv_resnet50-bert_video-scene-segmentation_movienet') | |||
@@ -44,6 +44,7 @@ if TYPE_CHECKING: | |||
from .video_category_pipeline import VideoCategoryPipeline | |||
from .virtual_try_on_pipeline import VirtualTryonPipeline | |||
from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline | |||
from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipleline | |||
from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | |||
else: | |||
@@ -97,6 +98,8 @@ else: | |||
'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | |||
'easycv_pipeline': | |||
['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'], | |||
'text_driven_segmentation_pipeline': | |||
['TextDrivenSegmentationPipeline'], | |||
'movie_scene_segmentation_pipeline': | |||
['MovieSceneSegmentationPipeline'], | |||
} | |||
@@ -0,0 +1,51 @@ | |||
from typing import Any, Dict | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import Tasks | |||
@PIPELINES.register_module( | |||
Tasks.text_driven_segmentation, | |||
module_name=Pipelines.text_driven_segmentation) | |||
class TextDrivenSegmentationPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__(model=model, auto_collate=False, **kwargs) | |||
def preprocess(self, input: Dict) -> Dict[str, Any]: | |||
img = LoadImage.convert_to_ndarray(input['image']) | |||
img_tensor, ori_h, ori_w, crop_h, crop_w = self.model.preprocess(img) | |||
result = { | |||
'img': img_tensor, | |||
'ori_h': ori_h, | |||
'ori_w': ori_w, | |||
'crop_h': crop_h, | |||
'crop_w': crop_w, | |||
'text': input['text'], | |||
} | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
outputs = self.model.inference(input['img'], input['text']) | |||
result = { | |||
'data': outputs, | |||
'ori_h': input['ori_h'], | |||
'ori_w': input['ori_w'], | |||
'crop_h': input['crop_h'], | |||
'crop_w': input['crop_w'], | |||
} | |||
return result | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
data = self.model.postprocess(inputs['data'], inputs['crop_h'], | |||
inputs['crop_w'], inputs['ori_h'], | |||
inputs['ori_w']) | |||
outputs = {OutputKeys.MASKS: data} | |||
return outputs |
@@ -36,6 +36,7 @@ class CVTasks(object): | |||
image_segmentation = 'image-segmentation' | |||
portrait_matting = 'portrait-matting' | |||
text_driven_segmentation = 'text-driven-segmentation' | |||
# image editing | |||
skin_retouching = 'skin-retouching' | |||
@@ -0,0 +1,28 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class TextDrivenSegmentationTest(unittest.TestCase): | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_text_driven_segmentation(self): | |||
input_location = 'data/test/images/text_driven_segmentation.jpg' | |||
test_input = { | |||
'image': input_location, | |||
'text': 'bear', | |||
} | |||
model_id = 'damo/cv_vitl16_segmentation_text-driven-seg' | |||
shop_seg = pipeline(Tasks.text_driven_segmentation, model=model_id) | |||
result = shop_seg(test_input) | |||
import cv2 | |||
# result[OutputKeys.MASKS] is segment map result,other keys are not used | |||
cv2.imwrite(input_location + '_lseg.jpg', result[OutputKeys.MASKS]) | |||
if __name__ == '__main__': | |||
unittest.main() |