
[to #42322933] Text-driven semantic segmentation model

The text-driven semantic segmentation model segments out the objects in an image that match the input text description.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9942863
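A minimal usage sketch of the new pipeline (it mirrors the test added in tests/pipelines/test_text_driven_segmentation.py; the model id and image path are the ones used there):

# Minimal usage sketch; mirrors the unit test added in this commit.
import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

seg = pipeline(
    Tasks.text_driven_segmentation,
    model='damo/cv_vitl16_segmentation_text-driven-seg')
result = seg({
    'image': 'data/test/images/text_driven_segmentation.jpg',
    'text': 'bear',
})
# result[OutputKeys.MASKS] is a 2D uint8 map containing only 0 and 255.
cv2.imwrite('text_driven_segmentation_mask.jpg', result[OutputKeys.MASKS])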
master
xingguang.zxg yingda.chen 3 years ago
parent
commit
4d3716cf4e
17 changed files with 2092 additions and 0 deletions
  1. +3 -0    data/test/images/text_driven_segmentation.jpg
  2. +2 -0    modelscope/metainfo.py
  3. +1 -0    modelscope/models/cv/text_driven_segmentation/__init__.py
  4. +170 -0  modelscope/models/cv/text_driven_segmentation/clip.py
  5. +28 -0   modelscope/models/cv/text_driven_segmentation/lseg_base.py
  6. +334 -0  modelscope/models/cv/text_driven_segmentation/lseg_blocks.py
  7. +107 -0  modelscope/models/cv/text_driven_segmentation/lseg_model.py
  8. +197 -0  modelscope/models/cv/text_driven_segmentation/lseg_net.py
  9. +543 -0  modelscope/models/cv/text_driven_segmentation/lseg_vit.py
  10. +458 -0 modelscope/models/cv/text_driven_segmentation/model.py
  11. +156 -0 modelscope/models/cv/text_driven_segmentation/simple_tokenizer.py
  12. +7 -0   modelscope/outputs.py
  13. +3 -0   modelscope/pipelines/builder.py
  14. +3 -0   modelscope/pipelines/cv/__init__.py
  15. +51 -0  modelscope/pipelines/cv/text_driven_segmentation_pipleline.py
  16. +1 -0   modelscope/utils/constant.py
  17. +28 -0  tests/pipelines/test_text_driven_segmentation.py

+3 -0   data/test/images/text_driven_segmentation.jpg   View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c7d2f279e3b317f1d0de18410a0585e122166fa2464c17b88a0c813f6c58bd4
size 67861

+2 -0   modelscope/metainfo.py   View File

@@ -29,6 +29,7 @@ class Models(object):
video_summarization = 'pgl-video-summarization'
swinL_semantic_segmentation = 'swinL-semantic-segmentation'
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
text_driven_segmentation = 'text-driven-segmentation'
resnet50_bert = 'resnet50-bert'

# EasyCV models
@@ -143,6 +144,7 @@ class Pipelines(object):
video_summarization = 'googlenet_pgl_video_summarization'
image_semantic_segmentation = 'image-semantic-segmentation'
image_reid_person = 'passvitb-image-reid-person'
text_driven_segmentation = 'text-driven-segmentation'
movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'

# nlp tasks


+1 -0   modelscope/models/cv/text_driven_segmentation/__init__.py   View File

@@ -0,0 +1 @@
from .lseg_base import TextDrivenSegmentation

+170 -0   modelscope/models/cv/text_driven_segmentation/clip.py   View File

@@ -0,0 +1,170 @@
""" CLIP
Adapted from https://github.com/openai/CLIP.
Originally MIT License, Copyright (c) 2021 OpenAI.
"""

import hashlib
import os
import urllib
import warnings
from typing import Any, List, Union

import torch
from PIL import Image
from pkg_resources import packaging
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
ToTensor)
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC

if packaging.version.parse(
torch.__version__) < packaging.version.parse('1.7.1'):
warnings.warn('PyTorch version 1.7.1 or higher is recommended')
__all__ = ['load', 'tokenize']


def _convert_image_to_rgb(image):
return image.convert('RGB')


def _transform(n_px):
return Compose([
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])


def load(name: str,
device: Union[str, torch.device] = 'cuda'
if torch.cuda.is_available() else 'cpu',
jit: bool = False,
root: str = None):

if not jit:
model = build_model().to(device)
if str(device) == 'cpu':
model.float()
return model, _transform(model.visual.input_resolution)

# patch the device names
device_holder = torch.jit.trace(
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [
n for n in device_holder.graph.findAllNodes('prim::Constant')
if 'Device' in repr(n)
][-1]

def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, 'graph') else []
except RuntimeError:
graphs = []

if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph)

for graph in graphs:
for node in graph.findAllNodes('prim::Constant'):
if 'value' in node.attributeNames() and str(
node['value']).startswith('cuda'):
node.copyAttributes(device_node)

model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)

# patch dtype to float32 on CPU
if str(device) == 'cpu':
float_holder = torch.jit.trace(
lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
float_node = float_input.node()

def patch_float(module):
try:
graphs = [module.graph] if hasattr(module, 'graph') else []
except RuntimeError:
graphs = []

if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph)

for graph in graphs:
for node in graph.findAllNodes('aten::to'):
inputs = list(node.inputs())
for i in [
1, 2
]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()['value'] == 5:
inputs[i].node().copyAttributes(float_node)

model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)

model.float()

return model, _transform(model.input_resolution.item())


def tokenize(
_tokenizer,
texts: Union[str, List[str]],
context_length: int = 77,
truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
"""
Returns the tokenized representation of given input string(s)

Parameters
----------
_tokenizer : SimpleTokenizer
The BPE tokenizer instance used to encode the input strings

texts : Union[str, List[str]]
An input string or a list of input strings to tokenize

context_length : int
The context length to use; all CLIP models use 77 as the context length

truncate: bool
Whether to truncate the text in case its encoding is longer than the context length

Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
"""
if isinstance(texts, str):
texts = [texts]

sot_token = _tokenizer.encoder['<|startoftext|>']
eot_token = _tokenizer.encoder['<|endoftext|>']
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
for text in texts]
if packaging.version.parse(
torch.__version__) < packaging.version.parse('1.8.0'):
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
else:
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate:
tokens = tokens[:context_length]
tokens[-1] = eot_token
else:
raise RuntimeError(
f'Input {texts[i]} is too long for context length {context_length}'
)
result[i, :len(tokens)] = torch.tensor(tokens)

return result
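
For reference, a standalone sketch of the token layout tokenize() produces: [SOT] + BPE ids + [EOT], zero-padded (or truncated with EOT restored as the last id) to context_length. The ids below are stand-ins; 49406/49407 are the SOT/EOT positions in the vocabulary built by SimpleTokenizer:

# Standalone illustration of tokenize()'s padding/truncation; dummy BPE ids.
import torch

sot, eot, context_length = 49406, 49407, 77
dummy_ids = [320, 4298, 525, 518, 10750]   # illustrative stand-in token ids
tokens = [sot] + dummy_ids + [eot]
if len(tokens) > context_length:           # truncate=True branch
    tokens = tokens[:context_length]
    tokens[-1] = eot
result = torch.zeros(1, context_length, dtype=torch.int)
result[0, :len(tokens)] = torch.tensor(tokens)
print(result[0, :9])  # tensor([49406, 320, 4298, 525, 518, 10750, 49407, 0, 0], ...)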

+28 -0   modelscope/models/cv/text_driven_segmentation/lseg_base.py   View File

@@ -0,0 +1,28 @@
"""
Adapted from https://github.com/isl-org/lang-seg.
Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org.
"""

import torch
import torch.nn as nn

from .lseg_net import LSeg


class TextDrivenSegmentation(nn.Module):

def __init__(self, model_dir):
super(TextDrivenSegmentation, self).__init__()
self.net = LSeg(model_dir=model_dir)
self.model_dir = model_dir

def forward(self, img, txt_list):
b = img.size()[0]
batch_name_list = txt_list
xout_list = []
for i in range(b):
labelset = ['others', batch_name_list[i]]
xout = self.net(img[i:i + 1], labelset=labelset)
xout_list.append(xout)
score_map = torch.cat(xout_list, dim=0)
return score_map

+334 -0   modelscope/models/cv/text_driven_segmentation/lseg_blocks.py   View File

@@ -0,0 +1,334 @@
"""
Adapted from https://github.com/isl-org/lang-seg.
Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org.
"""

import torch
import torch.nn as nn

from .lseg_vit import _make_pretrained_clip_vitl16_384, forward_vit


def _make_encoder(
backbone,
features,
use_pretrained=True,
groups=1,
expand=False,
exportable=True,
hooks=None,
use_vit_only=False,
use_readout='ignore',
enable_attention_hooks=False,
):
if backbone == 'clip_vitl16_384':
clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384(
use_pretrained,
hooks=hooks,
use_readout=use_readout,
enable_attention_hooks=enable_attention_hooks,
)
scratch = _make_scratch([256, 512, 1024, 1024],
features,
groups=groups,
expand=expand)
else:
raise NotImplementedError(f"Backbone '{backbone}' not implemented")

return clip_pretrained, pretrained, scratch


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()

out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand is True:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
out_shape4 = out_shape * 8

scratch.layer1_rn = nn.Conv2d(
in_shape[0],
out_shape1,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1],
out_shape2,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2],
out_shape3,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3],
out_shape4,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)

return scratch


class Interpolate(nn.Module):
"""Interpolation module."""

def __init__(self, scale_factor, mode, align_corners=False):
"""Init.

Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()

self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners

def forward(self, x):
"""Forward pass.

Args:
x (tensor): input

Returns:
tensor: interpolated data
"""

x = self.interp(
x,
scale_factor=self.scale_factor,
mode=self.mode,
align_corners=self.align_corners,
)

return x


class ResidualConvUnit(nn.Module):
"""Residual convolution module."""

def __init__(self, features):
"""Init.

Args:
features (int): number of features
"""
super().__init__()

self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True)

self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True)

self.relu = nn.ReLU(inplace=True)

def forward(self, x):
"""Forward pass.

Args:
x (tensor): input

Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)

return out + x


class FeatureFusionBlock(nn.Module):
"""Feature fusion block."""

def __init__(self, features):
"""Init.

Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()

self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)

def forward(self, *xs):
"""Forward pass.

Returns:
tensor: output
"""
output = xs[0]

if len(xs) == 2:
output += self.resConfUnit1(xs[1])

output = self.resConfUnit2(output)

output = nn.functional.interpolate(
output, scale_factor=2, mode='bilinear', align_corners=True)

return output


class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module."""

def __init__(self, features, activation, bn):
"""Init.

Args:
features (int): number of features
"""
super().__init__()

self.bn = bn

self.groups = 1

self.conv1 = nn.Conv2d(
features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=not self.bn,
groups=self.groups,
)

self.conv2 = nn.Conv2d(
features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=not self.bn,
groups=self.groups,
)

if self.bn is True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)

self.activation = activation

self.skip_add = nn.quantized.FloatFunctional()

def forward(self, x):
"""Forward pass.

Args:
x (tensor): input

Returns:
tensor: output
"""

out = self.activation(x)
out = self.conv1(out)
if self.bn is True:
out = self.bn1(out)

out = self.activation(out)
out = self.conv2(out)
if self.bn is True:
out = self.bn2(out)

if self.groups > 1:
out = self.conv_merge(out)

return self.skip_add.add(out, x)


class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block."""

def __init__(
self,
features,
activation,
deconv=False,
bn=False,
expand=False,
align_corners=True,
):
"""Init.

Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()

self.deconv = deconv
self.align_corners = align_corners

self.groups = 1

self.expand = expand
out_features = features
if self.expand is True:
out_features = features // 2

self.out_conv = nn.Conv2d(
features,
out_features,
kernel_size=1,
stride=1,
padding=0,
bias=True,
groups=1,
)

self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

self.skip_add = nn.quantized.FloatFunctional()

def forward(self, *xs):
"""Forward pass.

Returns:
tensor: output
"""
output = xs[0]

if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)

output = self.resConfUnit2(output)

output = nn.functional.interpolate(
output,
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners)

output = self.out_conv(output)
return output

+107 -0   modelscope/models/cv/text_driven_segmentation/lseg_model.py   View File

@@ -0,0 +1,107 @@
import os.path as osp
from typing import Any, Dict

import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.text_driven_segmentation import \
TextDrivenSegmentation
from modelscope.outputs import OutputKeys
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()
__all__ = ['TextDrivenSeg']


@MODELS.register_module(
Tasks.text_driven_segmentation,
module_name=Models.text_driven_segmentation)
class TextDrivenSeg(TorchModel):
""" text driven segmentation model.
"""

def __init__(self, model_dir, device_id=0, *args, **kwargs):
super().__init__(
model_dir=model_dir, device_id=device_id, *args, **kwargs)
self.model = TextDrivenSegmentation(model_dir=model_dir)
pretrained_params = torch.load('{}/{}'.format(
model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
self.model.load_state_dict(pretrained_params)
self.model.eval()
if device_id >= 0 and torch.cuda.is_available():
self.model.to('cuda:{}'.format(device_id))
logger.info('Use GPU: {}'.format(device_id))
else:
device_id = -1
logger.info('Use CPU for inference')
self.device_id = device_id

def preprocess(self, img, size=640):
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
h, w, c = img.shape
max_hw = max(h, w)
ratio = 1.0 * size / max_hw
crop_h, crop_w = int(ratio * h), int(ratio * w)
pil_img = Image.fromarray(img)
pil_img = pil_img.resize((crop_w, crop_h), Image.BILINEAR)
np_img = np.array(pil_img, dtype=np.float32) / 255.
for j in range(3):
np_img[:, :, j] = (np_img[:, :, j] - mean[j]) / std[j]
img_pad = np.zeros((size, size, 3), dtype=np.float32)
img_pad[:crop_h, :crop_w] = np_img
img_pad = torch.from_numpy(img_pad).permute(2, 0,
1).unsqueeze(0).float()
return img_pad, h, w, crop_h, crop_w

def postprocess(self, tensors, crop_h, crop_w, ori_h, ori_w):
output = np.clip(tensors * 255., a_min=0, a_max=255.)
crop_output = np.array(output[:crop_h, :crop_w], dtype=np.uint8)
pil_output = Image.fromarray(crop_output)
pil_output = pil_output.resize((ori_w, ori_h), Image.BILINEAR)
np_output = np.array(pil_output, dtype=np.uint8)
np_output[np_output < 128] = 0
np_output[np_output >= 128] = 255
np_output = np.uint8(np_output)
return np_output

def forward(self, image, text):
"""
image should be numpy array, dtype=np.uint8, shape: height*width*3
"""
image_tensor, ori_h, ori_w, crop_h, crop_w = self.preprocess(
image, size=640)
pred = self.inference(image_tensor, text)
msk = self.postprocess(pred, crop_h, crop_w, ori_h, ori_w)
outputs = {OutputKeys.MASKS: msk}
return outputs

def inference(self, image, text):
"""
image should be tensor, 1 * 3 * 640 * 640
"""
with torch.no_grad():
if self.device_id == -1:
output = self.model(image, [text])
else:
device = torch.device('cuda', self.device_id)
output = self.model(image.to(device), [text])
output = F.interpolate(output, size=(640, 640), mode='bilinear')
output = F.softmax(output, dim=1)
output = torch.argmax(output, dim=1)
output = output[0]
if self.device_id == -1:
pred = output.data.numpy()
else:
pred = output.data.cpu().numpy()
del output
return pred
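
A quick geometric sketch of preprocess() above: the longer image side is resized to 640, the result is zero-padded into a fixed 640x640 canvas (top-left aligned), and per-channel CLIP mean/std normalization is applied. Dummy input, normalization omitted for brevity:

# Sketch of the resize-longest-side + pad-to-640x640 step of preprocess().
import numpy as np
import torch
from PIL import Image

size = 640
img = np.random.randint(0, 256, (480, 360, 3), dtype=np.uint8)  # dummy H*W*3 image
h, w, _ = img.shape
ratio = 1.0 * size / max(h, w)                    # 640 / 480
crop_h, crop_w = int(ratio * h), int(ratio * w)   # 640, 480
resized = np.asarray(
    Image.fromarray(img).resize((crop_w, crop_h), Image.BILINEAR),
    dtype=np.float32) / 255.
pad = np.zeros((size, size, 3), dtype=np.float32)
pad[:crop_h, :crop_w] = resized                   # rest of the canvas stays zero
tensor = torch.from_numpy(pad).permute(2, 0, 1).unsqueeze(0)
print(tensor.shape)  # torch.Size([1, 3, 640, 640])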

+197 -0   modelscope/models/cv/text_driven_segmentation/lseg_net.py   View File

@@ -0,0 +1,197 @@
"""
Adapted from https://github.com/isl-org/lang-seg.
Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org.
"""

import numpy as np
import torch
import torch.nn as nn

from . import clip
from .lseg_blocks import (FeatureFusionBlock, FeatureFusionBlock_custom,
Interpolate, _make_encoder, forward_vit)
from .simple_tokenizer import SimpleTokenizer


class depthwise_clipseg_conv(nn.Module):

def __init__(self):
super(depthwise_clipseg_conv, self).__init__()
self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1)

def depthwise_clipseg(self, x, channels):
x = torch.cat(
[self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)],
dim=1)
return x

def forward(self, x):
channels = x.shape[1]
out = self.depthwise_clipseg(x, channels)
return out


class depthwise_conv(nn.Module):

def __init__(self, kernel_size=3, stride=1, padding=1):
super(depthwise_conv, self).__init__()
self.depthwise = nn.Conv2d(
1, 1, kernel_size=kernel_size, stride=stride, padding=padding)

def forward(self, x):
# support for 4D tensor with NCHW
C, H, W = x.shape[1:]
x = x.reshape(-1, 1, H, W)
x = self.depthwise(x)
x = x.view(-1, C, H, W)
return x


class depthwise_block(nn.Module):

def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
super(depthwise_block, self).__init__()
self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1)
if activation == 'relu':
self.activation = nn.ReLU()
elif activation == 'lrelu':
self.activation = nn.LeakyReLU()
elif activation == 'tanh':
self.activation = nn.Tanh()

def forward(self, x, act=True):
x = self.depthwise(x)
if act:
x = self.activation(x)
return x


class bottleneck_block(nn.Module):

def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
super(bottleneck_block, self).__init__()
self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1)
if activation == 'relu':
self.activation = nn.ReLU()
elif activation == 'lrelu':
self.activation = nn.LeakyReLU()
elif activation == 'tanh':
self.activation = nn.Tanh()

def forward(self, x, act=True):
sum_layer = x.max(dim=1, keepdim=True)[0]
x = self.depthwise(x)
x = x + sum_layer
if act:
x = self.activation(x)
return x


class BaseModel(torch.nn.Module):

def load(self, path):
"""Load model from file.
Args:
path (str): file path
"""
parameters = torch.load(path, map_location=torch.device('cpu'))

if 'optimizer' in parameters:
parameters = parameters['model']

self.load_state_dict(parameters)


def _make_fusion_block(features, use_bn):
return FeatureFusionBlock_custom(
features,
activation=nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
)


class LSeg(BaseModel):

def __init__(
self,
features=256,
backbone='clip_vitl16_384',
readout='project',
use_bn=True,
model_dir=None,
):
super(LSeg, self).__init__()
hooks = {
'clip_vitl16_384': [5, 11, 17, 23],
}

# Instantiate backbone and reassemble blocks
self.clip_pretrained, self.pretrained, self.scratch = _make_encoder(
backbone,
features,
groups=1,
expand=False,
exportable=False,
hooks=hooks[backbone],
use_readout=readout,
)

self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

self.logit_scale = nn.Parameter(torch.ones([])
* np.log(1 / 0.07)).exp()
self.out_c = 512
self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1)

self.scratch.output_conv = nn.Sequential(
Interpolate(scale_factor=2, mode='bilinear', align_corners=True), )

self.tau = 0.07
self.model_dir = model_dir
self.tokenizer = SimpleTokenizer(model_dir
+ '/bpe_simple_vocab_16e6.txt.gz')

def forward(self, x, labelset=''):
text = clip.tokenize(self.tokenizer, labelset)

layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)

path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

text = text.to(x.device)
text_features = self.clip_pretrained.encode_text(text)

image_features = self.scratch.head1(path_1)

imshape = image_features.shape
image_features = image_features.permute(0, 2, 3,
1).reshape(-1, self.out_c)

# normalized features
image_features = image_features / image_features.norm(
dim=-1, keepdim=True)
text_features = text_features / text_features.norm(
dim=-1, keepdim=True)

logits_per_image = image_features @ text_features.t() / self.tau

out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3],
-1).permute(0, 3, 1, 2)

out = self.scratch.output_conv(out)

return out
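
The core of LSeg.forward above is a per-pixel cosine similarity between the dense image features (after scratch.head1, out_c = 512 channels) and the CLIP text embeddings of the label set, scaled by 1/tau. A shape-level sketch with random tensors (the spatial size below is illustrative only):

# Shape-level sketch of the pixel-text matching in LSeg.forward (random tensors).
import torch

n, c, h, w = 1, 512, 80, 80     # illustrative spatial size; c matches out_c above
num_labels, tau = 2, 0.07       # ['others', <query text>]

image_features = torch.randn(n, c, h, w)
text_features = torch.randn(num_labels, c)

img = image_features.permute(0, 2, 3, 1).reshape(-1, c)
img = img / img.norm(dim=-1, keepdim=True)
txt = text_features / text_features.norm(dim=-1, keepdim=True)

logits = img @ txt.t() / tau    # cosine similarity per pixel and per label
out = logits.view(n, h, w, num_labels).permute(0, 3, 1, 2)
print(out.shape)  # torch.Size([1, 2, 80, 80]); output_conv then upsamples 2x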

+543 -0   modelscope/models/cv/text_driven_segmentation/lseg_vit.py   View File

@@ -0,0 +1,543 @@
"""
Adapted from https://github.com/isl-org/lang-seg.
Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org.
"""

import math
import types

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from . import clip

activations = {}


def get_activation(name):

def hook(model, input, output):
activations[name] = output

return hook


attention = {}


def get_attention(name):

def hook(module, input, output):
x = input[0]
B, N, C = x.shape
qkv = (
module.qkv(x).reshape(B, N, 3, module.num_heads,
C // module.num_heads).permute(
2, 0, 3, 1, 4))
q, k, _ = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)

attn = (q @ k.transpose(-2, -1)) * module.scale

attn = attn.softmax(dim=-1) # [:,:,1,1:]
attention[name] = attn

return hook


def get_mean_attention_map(attn, token, shape):
attn = attn[:, :, token, 1:]
attn = attn.unflatten(2, torch.Size([shape[2] // 16,
shape[3] // 16])).float()
attn = torch.nn.functional.interpolate(
attn, size=shape[2:], mode='bicubic', align_corners=False).squeeze(0)

all_attn = torch.mean(attn, 0)

return all_attn


class Slice(nn.Module):

def __init__(self, start_index=1):
super(Slice, self).__init__()
self.start_index = start_index

def forward(self, x):
return x[:, self.start_index:]


class AddReadout(nn.Module):

def __init__(self, start_index=1):
super(AddReadout, self).__init__()
self.start_index = start_index

def forward(self, x):
if self.start_index == 2:
readout = (x[:, 0] + x[:, 1]) / 2
else:
readout = x[:, 0]
return x[:, self.start_index:] + readout.unsqueeze(1)


class ProjectReadout(nn.Module):

def __init__(self, in_features, start_index=1):
super(ProjectReadout, self).__init__()
self.start_index = start_index

self.project = nn.Sequential(
nn.Linear(2 * in_features, in_features), nn.GELU())

def forward(self, x):
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
features = torch.cat((x[:, self.start_index:], readout), -1)

return self.project(features)


class Transpose(nn.Module):

def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1

def forward(self, x):
x = x.transpose(self.dim0, self.dim1)
return x


def forward_vit(pretrained, x):
b, c, h, w = x.shape

# encoder
_ = pretrained.model.forward_flex(x)

layer_1 = pretrained.activations['1']
layer_2 = pretrained.activations['2']
layer_3 = pretrained.activations['3']
layer_4 = pretrained.activations['4']

layer_1 = pretrained.act_postprocess1[0:2](layer_1)
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
layer_4 = pretrained.act_postprocess4[0:2](layer_4)

unflatten = nn.Sequential(
nn.Unflatten(
2,
torch.Size([
h // pretrained.model.patch_size[1],
w // pretrained.model.patch_size[0],
]),
))

if layer_1.ndim == 3:
layer_1 = unflatten(layer_1)
if layer_2.ndim == 3:
layer_2 = unflatten(layer_2)
if layer_3.ndim == 3:
layer_3 = unflatten(layer_3)
if layer_4.ndim == 3:
layer_4 = unflatten(layer_4)

layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)](
layer_1)
layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)](
layer_2)
layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)](
layer_3)
layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)](
layer_4)

return layer_1, layer_2, layer_3, layer_4


def _resize_pos_embed(self, posemb, gs_h, gs_w):
posemb_tok, posemb_grid = (
posemb[:, :self.start_index],
posemb[0, self.start_index:],
)

gs_old = int(math.sqrt(len(posemb_grid)))

posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
-1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(
posemb_grid, size=(gs_h, gs_w), mode='bilinear')
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

return posemb


def forward_flex(self, x):
b, c, h, w = x.shape

pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1],
w // self.patch_size[0])

B = x.shape[0]

if hasattr(self.patch_embed, 'backbone'):
x = self.patch_embed.backbone(x)
if isinstance(x, (list, tuple)):
x = x[
-1] # last feature if backbone outputs list/tuple of features
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

if getattr(self, 'dist_token', None) is not None:
cls_tokens = self.cls_token.expand(
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
else:
cls_tokens = self.cls_token.expand(
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)

x = x + pos_embed
x = self.pos_drop(x)

gradient_checkpoint = False
for blk in self.blocks:
if gradient_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)

x = self.norm(x)

return x


def get_readout_oper(vit_features, features, use_readout, start_index=1):
if use_readout == 'ignore':
readout_oper = [Slice(start_index)] * len(features)
elif use_readout == 'add':
readout_oper = [AddReadout(start_index)] * len(features)
elif use_readout == 'project':
readout_oper = [
ProjectReadout(vit_features, start_index) for out_feat in features
]
else:
assert (
False
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"

return readout_oper


def adapt_input_conv(in_chans, conv_weight):
conv_type = conv_weight.dtype
conv_weight = conv_weight.float(
) # Some weights are in torch.half, ensure it's float for sum on CPU
O, II, J, K = conv_weight.shape
if in_chans == 1:
if II > 3:
assert conv_weight.shape[1] % 3 == 0
# For models with space2depth stems
conv_weight = conv_weight.reshape(O, II // 3, 3, J, K)
conv_weight = conv_weight.sum(dim=2, keepdim=False)
else:
conv_weight = conv_weight.sum(dim=1, keepdim=True)
elif in_chans != 3:
if II != 3:
raise NotImplementedError(
'Weight format not supported by conversion.')
else:
# NOTE this strategy should be better than random init, but there could be other combinations of
# the original RGB input layer weights that'd work better for specific cases.
repeat = int(math.ceil(in_chans / 3))
conv_weight = conv_weight.repeat(1, repeat, 1,
1)[:, :in_chans, :, :]
conv_weight *= (3 / float(in_chans))
conv_weight = conv_weight.to(conv_type)
return conv_weight


@torch.no_grad()
def _load_weights(model, checkpoint_path, prefix=''):
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
"""
import numpy as np

def _n2p(w, t=True):
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
if w.ndim == 4:
w = w.transpose([3, 2, 0, 1])
elif w.ndim == 3:
w = w.transpose([2, 0, 1])
elif w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w)

w = np.load(checkpoint_path)
if not prefix and 'opt/target/embedding/kernel' in w:
prefix = 'opt/target/'

if hasattr(model.patch_embed, 'backbone'):
# hybrid
backbone = model.patch_embed.backbone
stem_only = not hasattr(backbone, 'stem')
stem = backbone if stem_only else backbone.stem
stem.conv.weight.copy_(
adapt_input_conv(stem.conv.weight.shape[1],
_n2p(w[f'{prefix}conv_root/kernel'])))
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
if not stem_only:
for i, stage in enumerate(backbone.stages):
for j, block in enumerate(stage.blocks):
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
for r in range(3):
getattr(block, f'conv{r + 1}').weight.copy_(
_n2p(w[f'{bp}conv{r + 1}/kernel']))
getattr(block, f'norm{r + 1}').weight.copy_(
_n2p(w[f'{bp}gn{r + 1}/scale']))
getattr(block, f'norm{r + 1}').bias.copy_(
_n2p(w[f'{bp}gn{r + 1}/bias']))
if block.downsample is not None:
block.downsample.conv.weight.copy_(
_n2p(w[f'{bp}conv_proj/kernel']))
block.downsample.norm.weight.copy_(
_n2p(w[f'{bp}gn_proj/scale']))
block.downsample.norm.bias.copy_(
_n2p(w[f'{bp}gn_proj/bias']))
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
else:
embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1],
_n2p(w[f'{prefix}embedding/kernel']))
model.patch_embed.proj.weight.copy_(embed_conv_w)
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
pos_embed_w = _n2p(
w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w, model.pos_embed, getattr(model, 'num_prefix_tokens',
1),
model.patch_embed.grid_size)
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
if isinstance(
model.head, nn.Linear
) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
# NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.qkv.weight.copy_(
torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T
for n in ('query', 'key', 'value')
]))
block.attn.qkv.bias.copy_(
torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1)
for n in ('query', 'key', 'value')
]))
block.attn.proj.weight.copy_(
_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(
_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(
_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))


def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
ntok_new = posemb_new.shape[1]
if num_prefix_tokens:
posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[
0, num_prefix_tokens:]
ntok_new -= num_prefix_tokens
else:
posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))] * 2
assert len(gs_new) >= 2
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
-1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(
posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
posemb_grid = posemb_grid.permute(0, 2, 3,
1).reshape(1, gs_new[0] * gs_new[1], -1)
posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
return posemb


def _make_pretrained_clip_vitl16_384(pretrained,
use_readout='ignore',
hooks=None,
enable_attention_hooks=False):
clip_pretrained, _ = clip.load('ViT-B/32', device='cpu', jit=False)

# model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
model = timm.create_model('vit_large_patch16_384', pretrained=False)
hooks = [5, 11, 17, 23] if hooks is None else hooks
pretrained = _make_vit_b16_backbone(
model,
features=[256, 512, 1024, 1024],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
enable_attention_hooks=enable_attention_hooks,
)
return clip_pretrained, pretrained


def _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
size=[384, 384],
hooks=[2, 5, 8, 11],
vit_features=768,
use_readout='ignore',
start_index=1,
enable_attention_hooks=False,
):
pretrained = nn.Module()

pretrained.model = model
pretrained.model.blocks[hooks[0]].register_forward_hook(
get_activation('1'))
pretrained.model.blocks[hooks[1]].register_forward_hook(
get_activation('2'))
pretrained.model.blocks[hooks[2]].register_forward_hook(
get_activation('3'))
pretrained.model.blocks[hooks[3]].register_forward_hook(
get_activation('4'))

pretrained.activations = activations

if enable_attention_hooks:
pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
get_attention('attn_1'))
pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
get_attention('attn_2'))
pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
get_attention('attn_3'))
pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
get_attention('attn_4'))
pretrained.attention = attention

readout_oper = get_readout_oper(vit_features, features, use_readout,
start_index)

# 32, 48, 136, 384
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)

pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)

pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)

pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)

pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]

# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex,
pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model)

return pretrained
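
forward_flex and _resize_pos_embed above are injected because the timm vit_large_patch16_384 backbone ships position embeddings for a 24x24 patch grid (384/16), while this model runs at 640x640 inputs (a 40x40 grid). A small sketch of the bilinear resize with a random tensor (vit_features = 1024 as above):

# Sketch of the positional-embedding resize performed by _resize_pos_embed:
# 1 class token + 24*24 grid  ->  1 class token + 40*40 grid.
import math

import torch
import torch.nn.functional as F

posemb = torch.randn(1, 1 + 24 * 24, 1024)   # pretrained ViT-L/16 @ 384
gs_h = gs_w = 640 // 16                      # target grid for 640x640 inputs

tok, grid = posemb[:, :1], posemb[0, 1:]
gs_old = int(math.sqrt(grid.shape[0]))       # 24
grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
grid = F.interpolate(grid, size=(gs_h, gs_w), mode='bilinear')
grid = grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
resized = torch.cat([tok, grid], dim=1)
print(resized.shape)  # torch.Size([1, 1601, 1024])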

+458 -0   modelscope/models/cv/text_driven_segmentation/model.py   View File

@@ -0,0 +1,458 @@
"""
Adapted from https://github.com/isl-org/lang-seg.
Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org.
"""

from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
expansion = 4

def __init__(self, inplanes, planes, stride=1):
super().__init__()

# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)

self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = nn.ReLU(inplace=True)

self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu3 = nn.ReLU(inplace=True)

self.downsample = None
self.stride = stride

if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(
OrderedDict([('-1', nn.AvgPool2d(stride)),
('0',
nn.Conv2d(
inplanes,
planes * self.expansion,
1,
stride=1,
bias=False)),
('1', nn.BatchNorm2d(planes * self.expansion))]))

def forward(self, x: torch.Tensor):
identity = x

out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))

if self.downsample is not None:
identity = self.downsample(x)

out += identity
out = self.relu3(out)
return out


class AttentionPool2d(nn.Module):

def __init__(self,
spacial_dim: int,
embed_dim: int,
num_heads: int,
output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads

def forward(self, x):
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1],
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False)
return x.squeeze(0)


class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""

def __init__(self,
layers,
output_dim,
heads,
input_resolution=224,
width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution

# the 3-layer stem
self.conv1 = nn.Conv2d(
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(
width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(
width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)

# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
heads, output_dim)

def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]

self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))

return nn.Sequential(*layers)

def forward(self, x):

def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x

x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)

return x


class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""

def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)


class QuickGELU(nn.Module):

def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

def __init__(self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None):
super().__init__()

self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model))]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask

def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(
dtype=x.dtype,
device=x.device) if self.attn_mask is not None else None
return self.attn(
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x


class Transformer(nn.Module):

def __init__(self, width, layers, heads, attn_mask=None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[
ResidualAttentionBlock(width, heads, attn_mask)
for _ in range(layers)
])

def forward(self, x: torch.Tensor):
return self.resblocks(x)


class VisionTransformer(nn.Module):

def __init__(self, input_resolution: int, patch_size: int, width: int,
layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)

scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)

self.transformer = Transformer(width, layers, heads)

self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x1 = self.class_embedding.to(x.dtype)
x2 = torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
x = torch.cat([x1 + x2, x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD

x = self.ln_post(x[:, 0, :])

if self.proj is not None:
x = x @ self.proj

return x


class CLIP(nn.Module):

def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int):
super().__init__()

self.context_length = context_length

if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width)
else:
vision_heads = vision_width // 64
self.visual = VisionTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim)

self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask())

self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)

self.text_projection = nn.Parameter(
torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

self.initialize_parameters()

def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)

if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

for resnet_block in [
self.visual.layer1, self.visual.layer2, self.visual.layer3,
self.visual.layer4
]:
for name, param in resnet_block.named_parameters():
if name.endswith('bn3.weight'):
nn.init.zeros_(param)

proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers)**-0.5)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width)**-0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

if self.text_projection is not None:
nn.init.normal_(
self.text_projection, std=self.transformer.width**-0.5)

def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float('-inf'))
mask.triu_(1) # zero out the lower diagonal
return mask

@property
def dtype(self):
return self.visual.conv1.weight.dtype

def encode_image(self, image):
return self.visual(image.type(self.dtype))

def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype)
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
x = x[torch.arange(x.shape[0]),
text.argmax(dim=-1)] @ self.text_projection
return x

def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)

# normalized features
image_features = image_features / image_features.norm(
dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)

# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()

# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""

def _convert_weights_to_fp16(ll):
if isinstance(ll, (nn.Conv1d, nn.Conv2d, nn.Linear)):
ll.weight.data = ll.weight.data.half()
if ll.bias is not None:
ll.bias.data = ll.bias.data.half()

if isinstance(ll, nn.MultiheadAttention):
for attr in [
*[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
'in_proj_bias', 'bias_k', 'bias_v'
]:
tensor = getattr(ll, attr)
if tensor is not None:
tensor.data = tensor.data.half()

for name in ['text_projection', 'proj']:
if hasattr(ll, name):
attr = getattr(ll, name)
if attr is not None:
attr.data = attr.data.half()

model.apply(_convert_weights_to_fp16)


def build_model():
model = CLIP(512, 224, 12, 768, 32, 77, 49408, 512, 8, 12)
convert_weights(model)
return model.eval()

+156 -0   modelscope/models/cv/text_driven_segmentation/simple_tokenizer.py   View File

@@ -0,0 +1,156 @@
""" CLIP
Adapted from https://github.com/openai/CLIP.
Originally MIT License, Copyright (c) 2021 OpenAI.
"""

import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
return os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'bpe_simple_vocab_16e6.txt.gz')


@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord('!'),
ord('~') + 1)) + list(range(
ord('¡'),
ord('¬') + 1)) + list(range(ord('®'),
ord('ÿ') + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))


def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs


def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()


def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text


class SimpleTokenizer(object):

def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {
'<|startoftext|>': '<|startoftext|>',
'<|endoftext|>': '<|endoftext|>'
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE)

def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word)

if not pairs:
return token + '</w>'

while True:
bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
error_list = []
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except Exception as err:
new_word.extend(word[i:])
error_list.append(err)
break

if word[i] == first and i < len(word) - 1 and word[
i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word

def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b]
for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token]
for bpe_token in self.bpe(token).split(' '))
return bpe_tokens

def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
'utf-8', errors='replace').replace('</w>', ' ')
return text

+7 -0   modelscope/outputs.py   View File

@@ -243,6 +243,13 @@ TASK_OUTPUTS = {
# "output_img": np.ndarray with shape [height, width, 3]
# }
Tasks.virtual_try_on: [OutputKeys.OUTPUT_IMG],
# text driven segmentation result for a single sample
# {
# "masks": [
# np.array # 2D array containing only 0, 255
# ]
# }
Tasks.text_driven_segmentation: [OutputKeys.MASKS],

# movie scene segmentation result for a single video
# {


+3 -0   modelscope/pipelines/builder.py   View File

@@ -149,6 +149,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_vitb_video-single-object-tracking_ostrack'),
Tasks.image_reid_person: (Pipelines.image_reid_person,
'damo/cv_passvitb_image-reid-person_market'),
Tasks.text_driven_segmentation:
(Pipelines.text_driven_segmentation,
'damo/cv_vitl16_segmentation_text-driven-seg'),
Tasks.movie_scene_segmentation:
(Pipelines.movie_scene_segmentation,
'damo/cv_resnet50-bert_video-scene-segmentation_movienet')


+3 -0   modelscope/pipelines/cv/__init__.py   View File

@@ -44,6 +44,7 @@ if TYPE_CHECKING:
from .video_category_pipeline import VideoCategoryPipeline
from .virtual_try_on_pipeline import VirtualTryonPipeline
from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline
from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline
from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline

else:
@@ -97,6 +98,8 @@ else:
'virtual_try_on_pipeline': ['VirtualTryonPipeline'],
'easycv_pipeline':
['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'],
'text_driven_segmentation_pipleline':
['TextDrivenSegmentationPipeline'],
'movie_scene_segmentation_pipeline':
['MovieSceneSegmentationPipeline'],
}


+51 -0   modelscope/pipelines/cv/text_driven_segmentation_pipleline.py   View File

@@ -0,0 +1,51 @@
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
Tasks.text_driven_segmentation,
module_name=Pipelines.text_driven_segmentation)
class TextDrivenSegmentationPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
model: model id on modelscope hub.
"""
super().__init__(model=model, auto_collate=False, **kwargs)

def preprocess(self, input: Dict) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input['image'])
img_tensor, ori_h, ori_w, crop_h, crop_w = self.model.preprocess(img)
result = {
'img': img_tensor,
'ori_h': ori_h,
'ori_w': ori_w,
'crop_h': crop_h,
'crop_w': crop_w,
'text': input['text'],
}
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
outputs = self.model.inference(input['img'], input['text'])
result = {
'data': outputs,
'ori_h': input['ori_h'],
'ori_w': input['ori_w'],
'crop_h': input['crop_h'],
'crop_w': input['crop_w'],
}
return result

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
data = self.model.postprocess(inputs['data'], inputs['crop_h'],
inputs['crop_w'], inputs['ori_h'],
inputs['ori_w'])
outputs = {OutputKeys.MASKS: data}
return outputs

+1 -0   modelscope/utils/constant.py   View File

@@ -36,6 +36,7 @@ class CVTasks(object):

image_segmentation = 'image-segmentation'
portrait_matting = 'portrait-matting'
text_driven_segmentation = 'text-driven-segmentation'

# image editing
skin_retouching = 'skin-retouching'


+28 -0   tests/pipelines/test_text_driven_segmentation.py   View File

@@ -0,0 +1,28 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class TextDrivenSegmentationTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_text_driven_segmentation(self):
input_location = 'data/test/images/text_driven_segmentation.jpg'
test_input = {
'image': input_location,
'text': 'bear',
}
model_id = 'damo/cv_vitl16_segmentation_text-driven-seg'
shop_seg = pipeline(Tasks.text_driven_segmentation, model=model_id)
result = shop_seg(test_input)
import cv2
# result[OutputKeys.MASKS] is the segmentation map; other keys are not used
cv2.imwrite(input_location + '_lseg.jpg', result[OutputKeys.MASKS])


if __name__ == '__main__':
unittest.main()
