Commit 69c57e0f55, authored by yaoxiong.hyx and committed by yingda.chen, 3 years ago
12 changed files with 804 additions and 0 deletions:

  1. data/test/images/ocr_recognition.jpg (+3, -0)
  2. modelscope/metainfo.py (+1, -0)
  3. modelscope/outputs.py (+6, -0)
  4. modelscope/pipelines/builder.py (+2, -0)
  5. modelscope/pipelines/cv/__init__.py (+2, -0)
  6. modelscope/pipelines/cv/ocr_recognition_pipeline.py (+131, -0)
  7. modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py (+23, -0)
  8. modelscope/pipelines/cv/ocr_utils/ocr_modules/__init__.py (+23, -0)
  9. modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py (+169, -0)
 10. modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py (+334, -0)
 11. modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py (+63, -0)
 12. tests/pipelines/test_ocr_recognition.py (+47, -0)

data/test/images/ocr_recognition.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2d68cfcaa7cc7b8276877c2dfa022deebe82076bc178ece1bfe7fd5423cd5b99
size 60009

modelscope/metainfo.py (+1, -0)

@@ -101,6 +101,7 @@ class Pipelines(object):
image2image_translation = 'image-to-image-translation'
live_category = 'live-category'
video_category = 'video-category'
ocr_recognition = 'convnextTiny-ocr-recognition'
image_portrait_enhancement = 'gpen-image-portrait-enhancement'
image_to_image_generation = 'image-to-image-generation'
skin_retouching = 'unet-skin-retouching'


modelscope/outputs.py (+6, -0)

@@ -47,6 +47,12 @@ TASK_OUTPUTS = {
# }
Tasks.ocr_detection: [OutputKeys.POLYGONS],

# ocr recognition result for single sample
# {
# "text": "电子元器件提供BOM配单"
# }
Tasks.ocr_recognition: [OutputKeys.TEXT],

# face detection result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]


modelscope/pipelines/builder.py (+2, -0)

@@ -119,6 +119,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.image_classification:
(Pipelines.daily_image_classification,
'damo/cv_vit-base_image-classification_Dailylife-labels'),
Tasks.ocr_recognition: (Pipelines.ocr_recognition,
'damo/cv_convnextTiny_ocr-recognition_damo'),
Tasks.skin_retouching: (Pipelines.skin_retouching,
'damo/cv_unet_skin-retouching'),
}
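With this entry in place, the task name alone is enough to build the pipeline; the builder resolves it to the default model. A minimal usage sketch (assuming a working ModelScope install with hub access; the output dict shape follows the TASK_OUTPUTS entry above):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# No model argument: DEFAULT_MODEL_FOR_PIPELINE supplies
# 'damo/cv_convnextTiny_ocr-recognition_damo'.
ocr = pipeline(Tasks.ocr_recognition)
print(ocr('data/test/images/ocr_recognition.jpg'))  # e.g. {'text': '...'}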


modelscope/pipelines/cv/__init__.py (+2, -0)

@@ -29,6 +29,7 @@ if TYPE_CHECKING:
from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline
from .live_category_pipeline import LiveCategoryPipeline
from .ocr_detection_pipeline import OCRDetectionPipeline
from .ocr_recognition_pipeline import OCRRecognitionPipeline
from .skin_retouching_pipeline import SkinRetouchingPipeline
from .tinynas_classification_pipeline import TinynasClassificationPipeline
from .video_category_pipeline import VideoCategoryPipeline
@@ -65,6 +66,7 @@ else:
'image_to_image_generation_pipeline':
['Image2ImageGenerationPipeline'],
'ocr_detection_pipeline': ['OCRDetectionPipeline'],
'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],
'video_category_pipeline': ['VideoCategoryPipeline'],


modelscope/pipelines/cv/ocr_recognition_pipeline.py (+131, -0)

@@ -0,0 +1,131 @@
import math
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.ocr_utils.model_convnext_transformer import \
OCRRecModel
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

# constants
NUM_CLASSES = 7644  # size of the output alphabet (including reserved indices)
IMG_HEIGHT = 32  # model input height
IMG_WIDTH = 300  # model input width per window
PRED_LENGTH = 75  # prediction positions per window (IMG_WIDTH / 4)
PRED_PAD = 6  # overlapping prediction positions between adjacent windows


@PIPELINES.register_module(
Tasks.ocr_recognition, module_name=Pipelines.ocr_recognition)
class OCRRecognitionPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
label_path = osp.join(self.model, 'label_dict.txt')
logger.info(f'loading model from {model_path}')

self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.infer_model = OCRRecModel(NUM_CLASSES).to(self.device)
self.infer_model.eval()
self.infer_model.load_state_dict(
torch.load(model_path, map_location=self.device))
self.labelMapping = dict()
with open(label_path, 'r') as f:
    lines = f.readlines()
    # label indices start at 2; index 0 is treated as the CTC blank in postprocess
    cnt = 2
    for line in lines:
        line = line.strip('\n')
        self.labelMapping[cnt] = line
        cnt += 1

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
    img = np.array(load_image(input).convert('L'))
elif isinstance(input, PIL.Image.Image):
    img = np.array(input.convert('L'))
elif isinstance(input, np.ndarray):
    if len(input.shape) == 3:
        img = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY)
    else:
        img = input  # already single-channel
else:
    raise TypeError(f'input should be either str, PIL.Image.Image'
                    f' or np.ndarray, but got {type(input)}')
data = []
img_h, img_w = img.shape
wh_ratio = img_w / img_h
true_w = int(IMG_HEIGHT * wh_ratio)  # width after resizing the height to IMG_HEIGHT
split_batch_cnt = 1
if true_w < IMG_WIDTH * 1.2:
    # short line: one (possibly padded) window is enough
    img = cv2.resize(img, (min(true_w, IMG_WIDTH), IMG_HEIGHT))
else:
    # long line: keep the full width and split it into overlapping windows below
    split_batch_cnt = math.ceil((true_w - 48) * 1.0 / 252)
    img = cv2.resize(img, (true_w, IMG_HEIGHT))

if split_batch_cnt == 1:
    mask = np.zeros((IMG_HEIGHT, IMG_WIDTH))
    mask[:, :img.shape[1]] = img  # right-pad with zeros to IMG_WIDTH
    data.append(mask)
else:
    for idx in range(split_batch_cnt):
        mask = np.zeros((IMG_HEIGHT, IMG_WIDTH))
        # window stride: 4 px per prediction position, minus the PRED_PAD overlap
        left = (PRED_LENGTH * 4 - PRED_PAD * 4) * idx
        trunk_img = img[:, left:min(left + PRED_LENGTH * 4, true_w)]
        mask[:, :trunk_img.shape[1]] = trunk_img
        data.append(mask)

# move to the model's device; a hard-coded .cuda() would fail on CPU-only hosts
data = torch.FloatTensor(data).view(
    len(data), 1, IMG_HEIGHT, IMG_WIDTH).to(self.device) / 255.

result = {'img': data}

return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
pred = self.infer_model(input['img'])
return {'results': pred}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
preds = inputs['results']
batchSize, length = preds.shape
pred_idx = []
if batchSize == 1:
    pred_idx = preds[0].cpu().data.tolist()
else:
    # stitch the overlapping windows back together: drop PRED_PAD
    # positions at each shared boundary so nothing is counted twice
    for idx in range(batchSize):
        if idx == 0:
            pred_idx.extend(
                preds[idx].cpu().data[:PRED_LENGTH - PRED_PAD].tolist())
        elif idx == batchSize - 1:
            pred_idx.extend(preds[idx].cpu().data[PRED_PAD:].tolist())
        else:
            pred_idx.extend(preds[idx].cpu().data[PRED_PAD:PRED_LENGTH
                                                  - PRED_PAD].tolist())

# greedy CTC decode: collapse repeated indices and drop blanks (index 0)
last_p = 0
str_pred = []
for p in pred_idx:
if p != last_p and p != 0:
str_pred.append(self.labelMapping[p])
last_p = p

final_str = ''.join(str_pred)
result = {OutputKeys.TEXT: final_str}
return result

modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py (+23, -0)

@@ -0,0 +1,23 @@
import torch
import torch.nn as nn

from .ocr_modules.convnext import convnext_tiny
from .ocr_modules.vitstr import vitstr_tiny


class OCRRecModel(nn.Module):

def __init__(self, num_classes):
super(OCRRecModel, self).__init__()
self.cnn_model = convnext_tiny()
self.num_classes = num_classes
self.vitstr = vitstr_tiny(num_tokens=num_classes)

def forward(self, input):
    """Run the ConvNeXt backbone and the ViTSTR head; return per-position class indices."""
features = self.cnn_model(input)
prediction = self.vitstr(features)
prediction = torch.nn.functional.softmax(prediction, dim=-1)

output = torch.argmax(prediction, -1)
return output
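A quick shape check shows how the two stages fit: a (B, 1, 32, 300) grayscale window becomes a (B, 512, 1, 75) ConvNeXt feature map, which ViTSTR turns into one class index for each of the 75 positions. A sketch with random weights (assumes the package is importable; no checkpoint is loaded):

import torch

from modelscope.pipelines.cv.ocr_utils.model_convnext_transformer import \
    OCRRecModel

model = OCRRecModel(num_classes=7644).eval()
with torch.no_grad():
    out = model(torch.zeros(1, 1, 32, 300))  # (batch, channel, height, width)
print(out.shape)  # expected: torch.Size([1, 75])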

modelscope/pipelines/cv/ocr_utils/ocr_modules/__init__.py (+23, -0)

@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .convnext import convnext_tiny
from .vitstr import vitstr_tiny
else:
_import_structure = {
'convnext': ['convnext_tiny'],
'vitstr': ['vitstr_tiny']
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py (+169, -0)

@@ -0,0 +1,169 @@
""" Contains various versions of ConvNext Networks.
ConvNext Networks (ConvNext) were proposed in:
Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell and Saining Xie
A ConvNet for the 2020s. CVPR 2022.
Compared to https://github.com/facebookresearch/ConvNeXt,
we obtain different ConvNext variants by changing the network depth, width,
feature number, and downsample ratio.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from .timm_tinyc import DropPath


class Block(nn.Module):
r""" ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch

Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""

def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
super().__init__()
self.dwconv = nn.Conv2d(
dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim,
4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, dim)
self.gamma = nn.Parameter(
layer_scale_init_value * torch.ones((dim)),
requires_grad=True) if layer_scale_init_value > 0 else None
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()

def forward(self, x):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

x = input + self.drop_path(x)
return x


class ConvNeXt(nn.Module):
r""" ConvNeXt
A PyTorch impl of : `A ConvNet for the 2020s` -
https://arxiv.org/pdf/2201.03545.pdf

Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_path_rate (float): Stochastic depth rate. Default: 0.
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""

def __init__(
self,
in_chans=1,
num_classes=1000,
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
drop_path_rate=0.,
layer_scale_init_value=1e-6,
head_init_scale=1.,
):
super().__init__()

self.downsample_layers = nn.ModuleList(
) # stem and 3 intermediate downsampling conv layers
stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
LayerNorm(dims[0], eps=1e-6, data_format='channels_first'))
self.downsample_layers.append(stem)
for i in range(3):
downsample_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format='channels_first'),
nn.Conv2d(
dims[i], dims[i + 1], kernel_size=(2, 1), stride=(2, 1)),
)
self.downsample_layers.append(downsample_layer)

self.stages = nn.ModuleList(
) # 4 feature resolution stages, each consisting of multiple residual blocks
dp_rates = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
cur = 0
for i in range(4):
stage = nn.Sequential(*[
Block(
dim=dims[i],
drop_path=dp_rates[cur + j],
layer_scale_init_value=layer_scale_init_value)
for j in range(depths[i])
])
self.stages.append(stage)
cur += depths[i]

def _init_weights(self, m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # trunc_normal_ was undefined here; use the PyTorch built-in equivalent
        nn.init.trunc_normal_(m.weight, std=.02)
        nn.init.constant_(m.bias, 0)

def forward_features(self, x):
for i in range(4):
x = self.downsample_layers[i](x.contiguous())
x = self.stages[i](x.contiguous())
return x  # keep the (N, C, H, W) feature map; no global pooling for OCR

def forward(self, x):
x = self.forward_features(x.contiguous())

return x.contiguous()


class LayerNorm(nn.Module):
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""

def __init__(self,
normalized_shape,
eps=1e-6,
data_format='channels_last'):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ['channels_last', 'channels_first']:
raise NotImplementedError
self.normalized_shape = (normalized_shape, )

def forward(self, x):
if self.data_format == 'channels_last':
return F.layer_norm(x, self.normalized_shape, self.weight,
self.bias, self.eps)
elif self.data_format == 'channels_first':
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x


def convnext_tiny():
model = ConvNeXt(depths=[3, 3, 8, 3], dims=[96, 192, 256, 512])
return model
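The (2, 1) downsampling kernels are the OCR-specific change: each stage halves only the height, so a 32 x 300 input keeps all 75 feature columns for the sequence head. A shape sketch with random weights (assumes the package is importable):

import torch

from modelscope.pipelines.cv.ocr_utils.ocr_modules.convnext import convnext_tiny

backbone = convnext_tiny().eval()
with torch.no_grad():
    feats = backbone(torch.zeros(1, 1, 32, 300))
# stem divides H and W by 4 -> (96, 8, 75); three (2, 1) stages halve only H
print(feats.shape)  # expected: torch.Size([1, 512, 1, 75])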

modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py (+334, -0)

@@ -0,0 +1,334 @@
'''Referenced from rwightman's pytorch-image-models (timm).
Github: https://github.com/rwightman/pytorch-image-models
We use some modules and modify the parameters according to our network.
'''
import collections.abc
from collections import OrderedDict
from functools import partial
from itertools import repeat

import torch
import torch.nn as nn


def _ntuple(n):

def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))

return parse


class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""

def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True):
super().__init__()
img_size = (1, 75)  # fixed to the ConvNeXt output feature map size; the img_size argument is ignored
to_2tuple = _ntuple(2)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten

self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
x = x.permute(0, 1, 3, 2)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x


class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""

def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)

def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x


def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.

"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0], ) + (1, ) * (
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output


class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""

def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob

def forward(self, x):
return drop_path(x, self.drop_prob, self.training)


class Attention(nn.Module):

def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.1,
proj_drop=0.1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)

attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x


class Block(nn.Module):

def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)

def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x


class VisionTransformer(nn.Module):

def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
representation_size=None,
distilled=False,
drop_rate=0.1,
attn_drop_rate=0.1,
drop_path_rate=0.,
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
weight_init=''):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
distilled (bool): model includes a distillation token and head as in DeiT models
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
weight_init: (str): weight init scheme
"""
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU

self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(
1, 1, embed_dim)) if distilled else None
# sized for the patch tokens only (no cls slot); ViTSTR.forward_features
# adds it directly to the patch embeddings
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer) for i in range(depth)
])
self.norm = norm_layer(embed_dim)

# Representation layer
if representation_size and not distilled:
self.num_features = representation_size
self.pre_logits = nn.Sequential(
OrderedDict([('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh())]))
else:
self.pre_logits = nn.Identity()

# Classifier head(s)
self.head = nn.Linear(
self.num_features,
num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(
self.embed_dim,
self.num_classes) if num_classes > 0 else nn.Identity()

def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.num_tokens == 2:
self.head_dist = nn.Linear(
self.embed_dim,
self.num_classes) if num_classes > 0 else nn.Identity()

def forward_features(self, x):
x = self.patch_embed(x)
cls_token = self.cls_token.expand(
x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
if self.dist_token is None:
x = torch.cat((cls_token, x), dim=1)
else:
x = torch.cat(
(cls_token, self.dist_token.expand(x.shape[0], -1, -1), x),
dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
x = self.norm(x)
if self.dist_token is None:
return self.pre_logits(x[:, 0])
else:
return x[:, 0], x[:, 1]

def forward(self, x):
x = self.forward_features(x)
if self.head_dist is not None:
    x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x is a (token, dist_token) tuple here
    if self.training and not torch.jit.is_scripting():
        return x, x_dist
    else:
        # during inference, return the average of both classifier predictions
        return (x + x_dist) / 2
else:
    x = self.head(x)
return x
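Of the modules above, DropPath is the easiest to misread: during training it zeroes entire samples with probability drop_prob and rescales the survivors by 1 / (1 - drop_prob), while in eval mode it is the identity. A small demonstration (assumes the module is importable):

import torch

from modelscope.pipelines.cv.ocr_utils.ocr_modules.timm_tinyc import DropPath

dp = DropPath(drop_prob=0.2)
x = torch.ones(8, 4)

dp.train()
y = dp(x)  # each row is either all zeros or all 1.25 (= 1 / 0.8)

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference time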

modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py (+63, -0)

@@ -0,0 +1,63 @@
""" Contains various versions of ViTSTR.
ViTSTR were proposed in:
Rowel Atienza
Vision transformer for fast and efficient scene text recognition. ICDAR 2021.
Compared to https://github.com/roatienza/deep-text-recognition-benchmark,
we obtain different ViTSTR variants by changing the network patch_size and in_chans.
"""
from __future__ import absolute_import, division, print_function

import torch.nn as nn

from .timm_tinyc import VisionTransformer


class ViTSTR(VisionTransformer):
'''
ViTSTR is basically a ViT that uses DeiT weights.
The head is modified to predict a sequence of characters for scene text recognition (STR).
'''

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def reset_classifier(self, num_classes):
self.num_classes = num_classes
self.head = nn.Linear(
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

def forward_features(self, x):
x = self.patch_embed(x)

x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)

x = self.norm(x)
return x

def forward(self, x):
x = self.forward_features(x)
b, s, e = x.size()
x = x.reshape(b * s, e)
x = self.head(x).view(b, s, self.num_classes)
return x


def vitstr_tiny(num_tokens):
vitstr = ViTSTR(
patch_size=1,
in_chans=512,
embed_dim=192,
depth=12,
num_heads=3,
mlp_ratio=4,
qkv_bias=True)
vitstr.reset_classifier(num_classes=num_tokens)
return vitstr
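With patch_size=1 and in_chans=512, vitstr_tiny consumes the ConvNeXt feature map directly: each of the 75 feature columns becomes one 192-dimensional token, and the head emits logits over num_tokens classes per position. A shape sketch with random weights (assumes the package is importable):

import torch

from modelscope.pipelines.cv.ocr_utils.ocr_modules.vitstr import vitstr_tiny

head = vitstr_tiny(num_tokens=7644).eval()
with torch.no_grad():
    logits = head(torch.zeros(1, 512, 1, 75))  # ConvNeXt feature map
print(logits.shape)  # expected: torch.Size([1, 75, 7644])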

tests/pipelines/test_ocr_recognition.py (+47, -0)

@@ -0,0 +1,47 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import PIL

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class OCRRecognitionTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_convnextTiny_ocr-recognition_damo'
self.test_image = 'data/test/images/ocr_recognition.jpg'

def pipeline_inference(self, pipeline: Pipeline, input_location: str):
result = pipeline(input_location)
print('ocr recognition results: ', result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
ocr_recognition = pipeline(Tasks.ocr_recognition, model=self.model_id)
self.pipeline_inference(ocr_recognition, self.test_image)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub_PILinput(self):
ocr_recognition = pipeline(Tasks.ocr_recognition, model=self.model_id)
imagePIL = PIL.Image.open(self.test_image)
self.pipeline_inference(ocr_recognition, imagePIL)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
ocr_recognition = pipeline(Tasks.ocr_recognition)
self.pipeline_inference(ocr_recognition, self.test_image)


if __name__ == '__main__':
unittest.main()
