Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9592902 (target branch: master)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2d68cfcaa7cc7b8276877c2dfa022deebe82076bc178ece1bfe7fd5423cd5b99
size 60009
@@ -101,6 +101,7 @@ class Pipelines(object):
    image2image_translation = 'image-to-image-translation'
    live_category = 'live-category'
    video_category = 'video-category'
    ocr_recognition = 'convnextTiny-ocr-recognition'
    image_portrait_enhancement = 'gpen-image-portrait-enhancement'
    image_to_image_generation = 'image-to-image-generation'
    skin_retouching = 'unet-skin-retouching'
@@ -47,6 +47,12 @@ TASK_OUTPUTS = {
    # }
    Tasks.ocr_detection: [OutputKeys.POLYGONS],

    # ocr recognition result for single sample
    # {
    #     "text": "电子元器件提供BOM配单"
    # }
    Tasks.ocr_recognition: [OutputKeys.TEXT],

    # face detection result for single sample
    # {
    #     "scores": [0.9, 0.1, 0.05, 0.05]
@@ -119,6 +119,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.image_classification:
    (Pipelines.daily_image_classification,
     'damo/cv_vit-base_image-classification_Dailylife-labels'),
    Tasks.ocr_recognition: (Pipelines.ocr_recognition,
                            'damo/cv_convnextTiny_ocr-recognition_damo'),
    Tasks.skin_retouching: (Pipelines.skin_retouching,
                            'damo/cv_unet_skin-retouching'),
}
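With this default-model entry, the new task can be exercised end to end through the generic pipeline() factory. A minimal usage sketch, not part of the diff (the image path is a placeholder; the sample output mirrors the TASK_OUTPUTS example above):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ocr = pipeline(Tasks.ocr_recognition,
               model='damo/cv_convnextTiny_ocr-recognition_damo')
result = ocr('path/to/text_line_crop.jpg')  # placeholder path to a cropped text line
print(result[OutputKeys.TEXT])              # e.g. '电子元器件提供BOM配单'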
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
    from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline
    from .live_category_pipeline import LiveCategoryPipeline
    from .ocr_detection_pipeline import OCRDetectionPipeline
    from .ocr_recognition_pipeline import OCRRecognitionPipeline
    from .skin_retouching_pipeline import SkinRetouchingPipeline
    from .tinynas_classification_pipeline import TinynasClassificationPipeline
    from .video_category_pipeline import VideoCategoryPipeline
@@ -65,6 +66,7 @@ else:
        'image_to_image_generation_pipeline':
        ['Image2ImageGenerationPipeline'],
        'ocr_detection_pipeline': ['OCRDetectionPipeline'],
        'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
        'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
        'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],
        'video_category_pipeline': ['VideoCategoryPipeline'],
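Thanks to the lazy-import mapping above, the class is also importable directly from the cv pipelines package; a one-line sketch (package path assumed from this file's location, not part of the diff):

from modelscope.pipelines.cv import OCRRecognitionPipeline  # resolved lazily at first access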
@@ -0,0 +1,131 @@
import math
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.ocr_utils.model_convnext_transformer import \
    OCRRecModel
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

# constants shared by preprocess and postprocess: the recognizer consumes
# 32 x 300 grayscale crops and emits PRED_LENTH = 75 frames per crop
# (one frame per 4 pixels of width); adjacent crops overlap by PRED_PAD frames.
NUM_CLASSES = 7644
IMG_HEIGHT = 32
IMG_WIDTH = 300
PRED_LENTH = 75
PRED_PAD = 6


@PIPELINES.register_module(
    Tasks.ocr_recognition, module_name=Pipelines.ocr_recognition)
class OCRRecognitionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        label_path = osp.join(self.model, 'label_dict.txt')
        logger.info(f'loading model from {model_path}')
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.infer_model = OCRRecModel(NUM_CLASSES).to(self.device)
        self.infer_model.eval()
        self.infer_model.load_state_dict(
            torch.load(model_path, map_location=self.device))
        # character ids start at 2: index 0 is used as the CTC blank in
        # postprocess and index 1 is reserved
        self.labelMapping = dict()
        with open(label_path, 'r') as f:
            lines = f.readlines()
            cnt = 2
            for line in lines:
                line = line.strip('\n')
                self.labelMapping[cnt] = line
                cnt += 1

    def preprocess(self, input: Input) -> Dict[str, Any]:
        if isinstance(input, str):
            img = np.array(load_image(input).convert('L'))
        elif isinstance(input, PIL.Image.Image):
            img = np.array(input.convert('L'))
        elif isinstance(input, np.ndarray):
            img = input
            if len(img.shape) == 3:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        else:
            raise TypeError(f'input should be either str, PIL.Image or'
                            f' np.ndarray, but got {type(input)}')

        data = []
        img_h, img_w = img.shape
        wh_ratio = img_w / img_h
        true_w = int(IMG_HEIGHT * wh_ratio)
        # short lines fit into a single 32 x 300 crop; longer lines are split
        # into overlapping crops of IMG_WIDTH pixels with a stride of
        # (PRED_LENTH - PRED_PAD) * 4 = 276 pixels
        split_batch_cnt = 1
        if true_w < IMG_WIDTH * 1.2:
            img = cv2.resize(img, (min(true_w, IMG_WIDTH), IMG_HEIGHT))
        else:
            split_batch_cnt = math.ceil((true_w - 48) * 1.0 / 252)
            img = cv2.resize(img, (true_w, IMG_HEIGHT))
        if split_batch_cnt == 1:
            mask = np.zeros((IMG_HEIGHT, IMG_WIDTH))
            mask[:, :img.shape[1]] = img
            data.append(mask)
        else:
            for idx in range(split_batch_cnt):
                mask = np.zeros((IMG_HEIGHT, IMG_WIDTH))
                left = (PRED_LENTH * 4 - PRED_PAD * 4) * idx
                trunk_img = img[:, left:min(left + PRED_LENTH * 4, true_w)]
                mask[:, :trunk_img.shape[1]] = trunk_img
                data.append(mask)
        data = torch.FloatTensor(np.array(data)).view(
            len(data), 1, IMG_HEIGHT, IMG_WIDTH).to(self.device) / 255.
        result = {'img': data}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        with torch.no_grad():
            pred = self.infer_model(input['img'])
        return {'results': pred}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        preds = inputs['results']
        batchSize, length = preds.shape
        pred_idx = []
        if batchSize == 1:
            pred_idx = preds[0].cpu().data.tolist()
        else:
            # drop the frames that fall into the overlap between adjacent crops
            for idx in range(batchSize):
                if idx == 0:
                    pred_idx.extend(preds[idx].cpu().data[:PRED_LENTH
                                                          - PRED_PAD].tolist())
                elif idx == batchSize - 1:
                    pred_idx.extend(preds[idx].cpu().data[PRED_PAD:].tolist())
                else:
                    pred_idx.extend(preds[idx].cpu().data[PRED_PAD:PRED_LENTH
                                                          - PRED_PAD].tolist())

        # greedy CTC-style decoding: merge repeats and skip the blank (index 0)
        last_p = 0
        str_pred = []
        for p in pred_idx:
            if p != last_p and p != 0:
                str_pred.append(self.labelMapping[p])
            last_p = p
        final_str = ''.join(str_pred)
        result = {OutputKeys.TEXT: final_str}
        return result
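For reviewers unfamiliar with the decoding step, here is a small standalone sketch of the greedy CTC-style collapse implemented in postprocess(): repeated indices are merged and index 0 is treated as the blank. The label_mapping entries are hypothetical stand-ins for label_dict.txt; the sketch is not part of the diff.

label_mapping = {2: '电', 3: '子', 4: '元'}       # hypothetical subset of the label file
pred_idx = [0, 2, 2, 0, 3, 3, 3, 0, 0, 4, 4]      # toy per-frame argmax indices

last_p, chars = 0, []
for p in pred_idx:
    if p != last_p and p != 0:                    # skip blanks and repeated frames
        chars.append(label_mapping[p])
    last_p = p

print(''.join(chars))                             # -> 电子元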
@@ -0,0 +1,23 @@
import torch
import torch.nn as nn

from .ocr_modules.convnext import convnext_tiny
from .ocr_modules.vitstr import vitstr_tiny


class OCRRecModel(nn.Module):
    """ConvNeXt-tiny backbone followed by a ViTSTR-tiny head.

    Input:  (N, 1, 32, 300) grayscale crops.
    Output: (N, 75) per-frame class indices.
    """

    def __init__(self, num_classes):
        super(OCRRecModel, self).__init__()
        self.cnn_model = convnext_tiny()
        self.num_classes = num_classes
        self.vitstr = vitstr_tiny(num_tokens=num_classes)

    def forward(self, input):
        # feature extraction: (N, 1, 32, 300) -> (N, 512, 1, 75)
        features = self.cnn_model(input)
        # sequence prediction: (N, 512, 1, 75) -> (N, 75, num_classes)
        prediction = self.vitstr(features)
        prediction = torch.nn.functional.softmax(prediction, dim=-1)
        # greedy per-frame decision; the CTC-style collapse happens in the pipeline
        output = torch.argmax(prediction, -1)
        return output
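A quick shape sanity check of the two-stage design (ConvNeXt backbone followed by a ViTSTR head), run with randomly initialised weights; the import path is the one introduced by this PR and the sketch is not part of the diff:

import torch
from modelscope.pipelines.cv.ocr_utils.model_convnext_transformer import OCRRecModel

model = OCRRecModel(num_classes=7644).eval()
x = torch.randn(2, 1, 32, 300)      # two grayscale crops of IMG_HEIGHT x IMG_WIDTH
with torch.no_grad():
    out = model(x)
print(out.shape)                    # torch.Size([2, 75]): one class index per frame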
@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .convnext import convnext_tiny
    from .vitstr import vitstr_tiny
else:
    _import_structure = {
        'convnext': ['convnext_tiny'],
        'vitstr': ['vitstr_tiny']
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,169 @@ | |||
""" Contains various versions of ConvNext Networks. | |||
ConvNext Networks (ConvNext) were proposed in: | |||
Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell and Saining Xie | |||
A ConvNet for the 2020s. CVPR 2022. | |||
Compared to https://github.com/facebookresearch/ConvNeXt, | |||
we obtain different ConvNext variants by changing the network depth, width, | |||
feature number, and downsample ratio. | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from .timm_tinyc import DropPath | |||
class Block(nn.Module): | |||
r""" ConvNeXt Block. There are two equivalent implementations: | |||
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) | |||
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back | |||
We use (2) as we find it slightly faster in PyTorch | |||
Args: | |||
dim (int): Number of input channels. | |||
drop_path (float): Stochastic depth rate. Default: 0.0 | |||
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. | |||
""" | |||
def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): | |||
super().__init__() | |||
self.dwconv = nn.Conv2d( | |||
dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv | |||
self.norm = LayerNorm(dim, eps=1e-6) | |||
self.pwconv1 = nn.Linear( | |||
dim, | |||
4 * dim) # pointwise/1x1 convs, implemented with linear layers | |||
self.act = nn.GELU() | |||
self.pwconv2 = nn.Linear(4 * dim, dim) | |||
self.gamma = nn.Parameter( | |||
layer_scale_init_value * torch.ones((dim)), | |||
requires_grad=True) if layer_scale_init_value > 0 else None | |||
self.drop_path = DropPath( | |||
drop_path) if drop_path > 0. else nn.Identity() | |||
def forward(self, x): | |||
input = x | |||
x = self.dwconv(x) | |||
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) | |||
x = self.norm(x) | |||
x = self.pwconv1(x) | |||
x = self.act(x) | |||
x = self.pwconv2(x) | |||
if self.gamma is not None: | |||
x = self.gamma * x | |||
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) | |||
x = input + self.drop_path(x) | |||
return x | |||
class ConvNeXt(nn.Module): | |||
r""" ConvNeXt | |||
A PyTorch impl of : `A ConvNet for the 2020s` - | |||
https://arxiv.org/pdf/2201.03545.pdf | |||
Args: | |||
in_chans (int): Number of input image channels. Default: 3 | |||
num_classes (int): Number of classes for classification head. Default: 1000 | |||
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] | |||
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] | |||
drop_path_rate (float): Stochastic depth rate. Default: 0. | |||
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. | |||
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. | |||
""" | |||
def __init__( | |||
self, | |||
in_chans=1, | |||
num_classes=1000, | |||
depths=[3, 3, 9, 3], | |||
dims=[96, 192, 384, 768], | |||
drop_path_rate=0., | |||
layer_scale_init_value=1e-6, | |||
head_init_scale=1., | |||
): | |||
super().__init__() | |||
self.downsample_layers = nn.ModuleList( | |||
) # stem and 3 intermediate downsampling conv layers | |||
stem = nn.Sequential( | |||
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), | |||
LayerNorm(dims[0], eps=1e-6, data_format='channels_first')) | |||
self.downsample_layers.append(stem) | |||
        # the three later downsampling layers halve the height only
        # (kernel/stride (2, 1)), preserving horizontal resolution for the
        # character sequence
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format='channels_first'),
                nn.Conv2d(
                    dims[i], dims[i + 1], kernel_size=(2, 1), stride=(2, 1)),
            )
            self.downsample_layers.append(downsample_layer)
self.stages = nn.ModuleList( | |||
) # 4 feature resolution stages, each consisting of multiple residual blocks | |||
dp_rates = [ | |||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) | |||
] | |||
cur = 0 | |||
for i in range(4): | |||
stage = nn.Sequential(*[ | |||
Block( | |||
dim=dims[i], | |||
drop_path=dp_rates[cur + j], | |||
layer_scale_init_value=layer_scale_init_value) | |||
for j in range(depths[i]) | |||
]) | |||
self.stages.append(stage) | |||
cur += depths[i] | |||
    def _init_weights(self, m):
        # trunc_normal_ is not defined in this module; use the PyTorch
        # built-in truncated-normal initializer instead
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x.contiguous())
            x = self.stages[i](x.contiguous())
        return x  # feature map of shape (N, C, H/32, W/4); no global pooling here

    def forward(self, x):
        x = self.forward_features(x.contiguous())
        return x.contiguous()
class LayerNorm(nn.Module): | |||
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. | |||
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with | |||
shape (batch_size, height, width, channels) while channels_first corresponds to inputs | |||
with shape (batch_size, channels, height, width). | |||
""" | |||
def __init__(self, | |||
normalized_shape, | |||
eps=1e-6, | |||
data_format='channels_last'): | |||
super().__init__() | |||
self.weight = nn.Parameter(torch.ones(normalized_shape)) | |||
self.bias = nn.Parameter(torch.zeros(normalized_shape)) | |||
self.eps = eps | |||
self.data_format = data_format | |||
if self.data_format not in ['channels_last', 'channels_first']: | |||
raise NotImplementedError | |||
self.normalized_shape = (normalized_shape, ) | |||
def forward(self, x): | |||
if self.data_format == 'channels_last': | |||
return F.layer_norm(x, self.normalized_shape, self.weight, | |||
self.bias, self.eps) | |||
elif self.data_format == 'channels_first': | |||
u = x.mean(1, keepdim=True) | |||
s = (x - u).pow(2).mean(1, keepdim=True) | |||
x = (x - u) / torch.sqrt(s + self.eps) | |||
x = self.weight[:, None, None] * x + self.bias[:, None, None] | |||
return x | |||
def convnext_tiny():
    # lighter OCR variant: depths [3, 3, 8, 3] and dims [96, 192, 256, 512]
    # (the class defaults above correspond to the official ConvNeXt-T:
    # [3, 3, 9, 3] and [96, 192, 384, 768])
    model = ConvNeXt(depths=[3, 3, 8, 3], dims=[96, 192, 256, 512])
    return model
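A numerical check (not part of the diff) that the channels_first branch of the custom LayerNorm matches PyTorch's built-in layer_norm applied over the channel dimension after a permute; the import path is assumed from this PR's layout:

import torch
import torch.nn.functional as F
from modelscope.pipelines.cv.ocr_utils.ocr_modules.convnext import LayerNorm  # assumed path

x = torch.randn(2, 96, 8, 75)
ln = LayerNorm(96, eps=1e-6, data_format='channels_first')
ref = F.layer_norm(x.permute(0, 2, 3, 1), (96,), ln.weight, ln.bias,
                   ln.eps).permute(0, 3, 1, 2)
print(torch.allclose(ln(x), ref, atol=1e-5))  # True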
@@ -0,0 +1,334 @@ | |||
'''Referenced from rwightman's pytorch-image-models(timm). | |||
Github: https://github.com/rwightman/pytorch-image-models | |||
We use some modules and modify the parameters according to our network. | |||
''' | |||
import collections.abc | |||
import logging | |||
import math | |||
from collections import OrderedDict | |||
from copy import deepcopy | |||
from functools import partial | |||
from itertools import repeat | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
def _ntuple(n): | |||
def parse(x): | |||
if isinstance(x, collections.abc.Iterable): | |||
return x | |||
return tuple(repeat(x, n)) | |||
return parse | |||
class PatchEmbed(nn.Module): | |||
""" 2D Image to Patch Embedding | |||
""" | |||
def __init__(self, | |||
img_size=224, | |||
patch_size=16, | |||
in_chans=3, | |||
embed_dim=768, | |||
norm_layer=None, | |||
flatten=True): | |||
        super().__init__()
        # the img_size argument is ignored here: the embedding is hard-wired to
        # the 1 x 75 feature map produced by the ConvNeXt backbone
        img_size = (1, 75)
        to_2tuple = _ntuple(2)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten | |||
self.proj = nn.Conv2d( | |||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) | |||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() | |||
def forward(self, x): | |||
B, C, H, W = x.shape | |||
assert H == self.img_size[0] and W == self.img_size[1], \ | |||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." | |||
x = self.proj(x) | |||
x = x.permute(0, 1, 3, 2) | |||
if self.flatten: | |||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC | |||
x = self.norm(x) | |||
return x | |||
class Mlp(nn.Module): | |||
""" MLP as used in Vision Transformer, MLP-Mixer and related networks | |||
""" | |||
def __init__(self, | |||
in_features, | |||
hidden_features=None, | |||
out_features=None, | |||
act_layer=nn.GELU, | |||
drop=0.): | |||
super().__init__() | |||
out_features = out_features or in_features | |||
hidden_features = hidden_features or in_features | |||
self.fc1 = nn.Linear(in_features, hidden_features) | |||
self.act = act_layer() | |||
self.fc2 = nn.Linear(hidden_features, out_features) | |||
self.drop = nn.Dropout(drop) | |||
def forward(self, x): | |||
x = self.fc1(x) | |||
x = self.act(x) | |||
x = self.drop(x) | |||
x = self.fc2(x) | |||
x = self.drop(x) | |||
return x | |||
def drop_path(x, drop_prob: float = 0., training: bool = False): | |||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, | |||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... | |||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for | |||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use | |||
'survival rate' as the argument. | |||
""" | |||
if drop_prob == 0. or not training: | |||
return x | |||
keep_prob = 1 - drop_prob | |||
shape = (x.shape[0], ) + (1, ) * ( | |||
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets | |||
random_tensor = keep_prob + torch.rand( | |||
shape, dtype=x.dtype, device=x.device) | |||
random_tensor.floor_() # binarize | |||
output = x.div(keep_prob) * random_tensor | |||
return output | |||
class DropPath(nn.Module): | |||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
""" | |||
def __init__(self, drop_prob=None): | |||
super(DropPath, self).__init__() | |||
self.drop_prob = drop_prob | |||
def forward(self, x): | |||
return drop_path(x, self.drop_prob, self.training) | |||
class Attention(nn.Module): | |||
def __init__(self, | |||
dim, | |||
num_heads=8, | |||
qkv_bias=False, | |||
attn_drop=0.1, | |||
proj_drop=0.1): | |||
super().__init__() | |||
self.num_heads = num_heads | |||
head_dim = dim // num_heads | |||
self.scale = head_dim**-0.5 | |||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |||
self.attn_drop = nn.Dropout(attn_drop) | |||
self.proj = nn.Linear(dim, dim) | |||
self.proj_drop = nn.Dropout(proj_drop) | |||
def forward(self, x): | |||
B, N, C = x.shape | |||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, | |||
C // self.num_heads).permute(2, 0, 3, 1, 4) | |||
q, k, v = qkv[0], qkv[1], qkv[ | |||
2] # make torchscript happy (cannot use tensor as tuple) | |||
attn = (q @ k.transpose(-2, -1)) * self.scale | |||
attn = attn.softmax(dim=-1) | |||
attn = self.attn_drop(attn) | |||
x = (attn @ v).transpose(1, 2).reshape(B, N, C) | |||
x = self.proj(x) | |||
x = self.proj_drop(x) | |||
return x | |||
class Block(nn.Module): | |||
def __init__(self, | |||
dim, | |||
num_heads, | |||
mlp_ratio=4., | |||
qkv_bias=False, | |||
drop=0., | |||
attn_drop=0., | |||
drop_path=0., | |||
act_layer=nn.GELU, | |||
norm_layer=nn.LayerNorm): | |||
super().__init__() | |||
self.norm1 = norm_layer(dim) | |||
self.attn = Attention( | |||
dim, | |||
num_heads=num_heads, | |||
qkv_bias=qkv_bias, | |||
attn_drop=attn_drop, | |||
proj_drop=drop) | |||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here | |||
self.drop_path = DropPath( | |||
drop_path) if drop_path > 0. else nn.Identity() | |||
self.norm2 = norm_layer(dim) | |||
mlp_hidden_dim = int(dim * mlp_ratio) | |||
self.mlp = Mlp( | |||
in_features=dim, | |||
hidden_features=mlp_hidden_dim, | |||
act_layer=act_layer, | |||
drop=drop) | |||
def forward(self, x): | |||
x = x + self.drop_path(self.attn(self.norm1(x))) | |||
x = x + self.drop_path(self.mlp(self.norm2(x))) | |||
return x | |||
class VisionTransformer(nn.Module): | |||
def __init__(self, | |||
img_size=224, | |||
patch_size=16, | |||
in_chans=3, | |||
num_classes=1000, | |||
embed_dim=768, | |||
depth=12, | |||
num_heads=12, | |||
mlp_ratio=4., | |||
qkv_bias=True, | |||
representation_size=None, | |||
distilled=False, | |||
drop_rate=0.1, | |||
attn_drop_rate=0.1, | |||
drop_path_rate=0., | |||
embed_layer=PatchEmbed, | |||
norm_layer=None, | |||
act_layer=None, | |||
weight_init=''): | |||
""" | |||
Args: | |||
img_size (int, tuple): input image size | |||
patch_size (int, tuple): patch size | |||
in_chans (int): number of input channels | |||
num_classes (int): number of classes for classification head | |||
embed_dim (int): embedding dimension | |||
depth (int): depth of transformer | |||
num_heads (int): number of attention heads | |||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim | |||
qkv_bias (bool): enable bias for qkv if True | |||
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set | |||
distilled (bool): model includes a distillation token and head as in DeiT models | |||
drop_rate (float): dropout rate | |||
attn_drop_rate (float): attention dropout rate | |||
drop_path_rate (float): stochastic depth rate | |||
embed_layer (nn.Module): patch embedding layer | |||
norm_layer: (nn.Module): normalization layer | |||
weight_init: (str): weight init scheme | |||
""" | |||
super().__init__() | |||
self.num_classes = num_classes | |||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models | |||
self.num_tokens = 2 if distilled else 1 | |||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) | |||
act_layer = act_layer or nn.GELU | |||
self.patch_embed = embed_layer( | |||
img_size=img_size, | |||
patch_size=patch_size, | |||
in_chans=in_chans, | |||
embed_dim=embed_dim) | |||
num_patches = self.patch_embed.num_patches | |||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) | |||
self.dist_token = nn.Parameter(torch.zeros( | |||
1, 1, embed_dim)) if distilled else None | |||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) | |||
self.pos_drop = nn.Dropout(p=drop_rate) | |||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) | |||
] # stochastic depth decay rule | |||
self.blocks = nn.Sequential(*[ | |||
Block( | |||
dim=embed_dim, | |||
num_heads=num_heads, | |||
mlp_ratio=mlp_ratio, | |||
qkv_bias=qkv_bias, | |||
drop=drop_rate, | |||
attn_drop=attn_drop_rate, | |||
drop_path=dpr[i], | |||
norm_layer=norm_layer, | |||
act_layer=act_layer) for i in range(depth) | |||
]) | |||
self.norm = norm_layer(embed_dim) | |||
# Representation layer | |||
if representation_size and not distilled: | |||
self.num_features = representation_size | |||
self.pre_logits = nn.Sequential( | |||
OrderedDict([('fc', nn.Linear(embed_dim, representation_size)), | |||
('act', nn.Tanh())])) | |||
else: | |||
self.pre_logits = nn.Identity() | |||
# Classifier head(s) | |||
self.head = nn.Linear( | |||
self.num_features, | |||
num_classes) if num_classes > 0 else nn.Identity() | |||
self.head_dist = None | |||
if distilled: | |||
self.head_dist = nn.Linear( | |||
self.embed_dim, | |||
self.num_classes) if num_classes > 0 else nn.Identity() | |||
def reset_classifier(self, num_classes, global_pool=''): | |||
self.num_classes = num_classes | |||
self.head = nn.Linear( | |||
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() | |||
if self.num_tokens == 2: | |||
self.head_dist = nn.Linear( | |||
self.embed_dim, | |||
self.num_classes) if num_classes > 0 else nn.Identity() | |||
def forward_features(self, x): | |||
x = self.patch_embed(x) | |||
cls_token = self.cls_token.expand( | |||
x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks | |||
if self.dist_token is None: | |||
x = torch.cat((cls_token, x), dim=1) | |||
else: | |||
x = torch.cat( | |||
(cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), | |||
dim=1) | |||
x = self.pos_drop(x + self.pos_embed) | |||
x = self.blocks(x) | |||
x = self.norm(x) | |||
if self.dist_token is None: | |||
return self.pre_logits(x[:, 0]) | |||
else: | |||
return x[:, 0], x[:, 1] | |||
    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(
                x[1])  # x must be a tuple
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
            return x
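Because PatchEmbed hard-codes img_size to (1, 75), its OCR configuration turns the 512-channel backbone feature map into a 75-token sequence. A short check (not part of the diff; import path assumed):

import torch
from modelscope.pipelines.cv.ocr_utils.ocr_modules.timm_tinyc import PatchEmbed  # assumed path

embed = PatchEmbed(patch_size=1, in_chans=512, embed_dim=192)
tokens = embed(torch.randn(2, 512, 1, 75))
print(tokens.shape)  # torch.Size([2, 75, 192])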
@@ -0,0 +1,63 @@
""" Contains various versions of ViTSTR.
ViTSTR was proposed in:
    Rowel Atienza
    Vision transformer for fast and efficient scene text recognition. ICDAR 2021.
Compared to https://github.com/roatienza/deep-text-recognition-benchmark,
we obtain different ViTSTR variants by changing the network patch_size and in_chans.
"""
from __future__ import absolute_import, division, print_function

import logging
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

from .timm_tinyc import VisionTransformer


class ViTSTR(VisionTransformer):
    '''
    ViTSTR is basically a ViT that uses DeiT weights.
    Modified head to support a sequence of characters prediction for STR.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def reset_classifier(self, num_classes):
        self.num_classes = num_classes
        self.head = nn.Linear(
            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        b, s, e = x.size()
        x = x.reshape(b * s, e)
        x = self.head(x).view(b, s, self.num_classes)
        return x


def vitstr_tiny(num_tokens):
    vitstr = ViTSTR(
        patch_size=1,
        in_chans=512,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True)
    vitstr.reset_classifier(num_classes=num_tokens)
    return vitstr
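vitstr_tiny consumes the backbone feature map directly and emits per-frame logits over the character set; a minimal sketch with random weights (not part of the diff; the lazy ocr_modules import path is assumed to resolve):

import torch
from modelscope.pipelines.cv.ocr_utils.ocr_modules import vitstr_tiny  # assumed path

head = vitstr_tiny(num_tokens=7644).eval()
with torch.no_grad():
    logits = head(torch.randn(2, 512, 1, 75))
print(logits.shape)  # torch.Size([2, 75, 7644])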
@@ -0,0 +1,47 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import shutil
import sys
import tempfile
import unittest
from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
import PIL

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class OCRRecognitionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_convnextTiny_ocr-recognition_damo'
        self.test_image = 'data/test/images/ocr_recognition.jpg'

    def pipeline_inference(self, pipeline: Pipeline, input_location: str):
        result = pipeline(input_location)
        # basic sanity check on the output schema in addition to printing
        self.assertIn(OutputKeys.TEXT, result)
        print('ocr recognition results: ', result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        ocr_recognition = pipeline(Tasks.ocr_recognition, model=self.model_id)
        self.pipeline_inference(ocr_recognition, self.test_image)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub_PILinput(self):
        ocr_recognition = pipeline(Tasks.ocr_recognition, model=self.model_id)
        imagePIL = PIL.Image.open(self.test_image)
        self.pipeline_inference(ocr_recognition, imagePIL)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        ocr_recognition = pipeline(Tasks.ocr_recognition)
        self.pipeline_inference(ocr_recognition, self.test_image)


if __name__ == '__main__':
    unittest.main()