Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10917315 (master^2)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:209f6ba7f15c9c34a02801b4c6ef33a979f3086702b5229d2e7975eb403c3e15
size 45819
@@ -157,6 +157,7 @@ class Pipelines(object):
    person_image_cartoon = 'unet-person-image-cartoon'
    ocr_detection = 'resnet18-ocr-detection'
    table_recognition = 'dla34-table-recognition'
    license_plate_detection = 'resnet18-license-plate-detection'
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognition = 'resnet101-animal-recognition'
    general_recognition = 'resnet101-general-recognition'
@@ -62,6 +62,7 @@ TASK_OUTPUTS = {
    # }
    Tasks.ocr_detection: [OutputKeys.POLYGONS],
    Tasks.table_recognition: [OutputKeys.POLYGONS],
    Tasks.license_plate_detection: [OutputKeys.POLYGONS, OutputKeys.TEXT],

    # ocr recognition result for single sample
    # {
@@ -85,6 +85,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.table_recognition:
    (Pipelines.table_recognition,
     'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
    Tasks.license_plate_detection:
    (Pipelines.license_plate_detection,
     'damo/cv_resnet18_license-plate-detection_damo'),
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.feature_extraction: (Pipelines.feature_extraction,
                               'damo/pert_feature-extraction_base-test'),
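With the default-model mapping above in place, the pipeline should be constructible from the task name alone. A minimal usage sketch (hedged: assumes the hub model and the LFS test image added in this PR are available):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Resolves to 'damo/cv_resnet18_license-plate-detection_damo' through
# DEFAULT_MODEL_FOR_PIPELINE; pass model=... to override.
detector = pipeline(Tasks.license_plate_detection)
result = detector('data/test/images/license_plate_detection.jpg')
print(result)  # dict with OutputKeys.POLYGONS and OutputKeys.TEXT
```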
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
    from .live_category_pipeline import LiveCategoryPipeline
    from .ocr_detection_pipeline import OCRDetectionPipeline
    from .ocr_recognition_pipeline import OCRRecognitionPipeline
    from .license_plate_detection_pipeline import LicensePlateDetectionPipeline
    from .table_recognition_pipeline import TableRecognitionPipeline
    from .skin_retouching_pipeline import SkinRetouchingPipeline
    from .tinynas_classification_pipeline import TinynasClassificationPipeline
@@ -109,6 +110,7 @@ else:
        'image_inpainting_pipeline': ['ImageInpaintingPipeline'],
        'ocr_detection_pipeline': ['OCRDetectionPipeline'],
        'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
        'license_plate_detection_pipeline': ['LicensePlateDetectionPipeline'],
        'table_recognition_pipeline': ['TableRecognitionPipeline'],
        'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
        'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],
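Both the eager (TYPE_CHECKING) and the lazy registration above name the class `LicensePlateDetectionPipeline`, so the new module must define exactly that symbol. A quick, hedged import check (assuming the standard modelscope package layout):

```python
# Triggers the lazy path at runtime via LazyImportModule; the
# TYPE_CHECKING branch covers static analysis only.
from modelscope.pipelines.cv import LicensePlateDetectionPipeline
print(LicensePlateDetectionPipeline)
```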
@@ -0,0 +1,122 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.ocr_utils.model_resnet18_half import \
    LicensePlateDet
from modelscope.pipelines.cv.ocr_utils.table_process import (
    bbox_decode, bbox_post_process, decode_by_ind, get_affine_transform, nms)
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()
@PIPELINES.register_module(
    Tasks.license_plate_detection,
    module_name=Pipelines.license_plate_detection)
class LicensePlateDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Create a license plate detection pipeline.

        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading model from {model_path}')
        self.cfg = Config.from_file(config_path)
        self.K = self.cfg.K  # max number of detections kept per image
        self.car_type = self.cfg.Type  # plate-type label names
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.infer_model = LicensePlateDet()
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'state_dict' in checkpoint:
            self.infer_model.load_state_dict(checkpoint['state_dict'])
        else:
            self.infer_model.load_state_dict(checkpoint)
        self.infer_model = self.infer_model.to(self.device).eval()
    def preprocess(self, input: Input) -> Dict[str, Any]:
        # LoadImage yields RGB; flip to BGR to match the training data.
        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
        mean = np.array([0.408, 0.447, 0.470],
                        dtype=np.float32).reshape(1, 1, 3)
        std = np.array([0.289, 0.274, 0.278],
                       dtype=np.float32).reshape(1, 1, 3)
        height, width = img.shape[0:2]
        inp_height, inp_width = 512, 512
        # Map the full image, centered at c and scaled by its longer side s,
        # onto the 512x512 network input.
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(height, width) * 1.0
        trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
        # Same-size resize kept from the reference implementation; it also
        # returns a contiguous copy after the channel flip above.
        resized_image = cv2.resize(img, (width, height))
        inp_image = cv2.warpAffine(
            resized_image,
            trans_input, (inp_width, inp_height),
            flags=cv2.INTER_LINEAR)
        inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)
        images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
                                                      inp_width)
        images = torch.from_numpy(images).to(self.device)
        meta = {
            'c': c,
            's': s,
            'input_height': inp_height,
            'input_width': inp_width,
            # The heads are predicted at 1/4 of the input resolution.
            'out_height': inp_height // 4,
            'out_width': inp_width // 4
        }
        result = {'img': images, 'meta': meta}
        return result
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        pred = self.infer_model(input['img'])
        return {'results': pred, 'meta': input['meta']}
    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        output = inputs['results'][0]
        meta = inputs['meta']
        hm = output['hm'].sigmoid_()  # center heatmap
        ftype = output['ftype'].sigmoid_()  # plate-type scores
        wh = output['wh']  # offsets from center to the four corners
        reg = output['reg']  # sub-pixel center offset
        bbox, inds = bbox_decode(hm, wh, reg=reg, K=self.K)
        car_type = decode_by_ind(ftype, inds, K=self.K).detach().cpu().numpy()
        bbox = bbox.detach().cpu().numpy()
        for i in range(bbox.shape[1]):
            bbox[0][i][9] = car_type[0][i]
        bbox = nms(bbox, 0.3)
        # meta values were auto-collated to tensors by the base pipeline,
        # hence the .cpu().numpy() round-trip here.
        bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()],
                                 [meta['s']], meta['out_height'],
                                 meta['out_width'])
        res, types = [], []
        for box in bbox[0]:
            if box[8] > 0.3:  # keep detections above the score threshold
                res.append(box[0:8])  # quadrilateral: x1, y1, ..., x4, y4
                types.append(self.car_type[int(box[9])])
        result = {OutputKeys.POLYGONS: np.array(res), OutputKeys.TEXT: types}
        return result
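For downstream use, the POLYGONS output holds one 8-value quadrilateral per plate and TEXT the matching plate-type label. A hedged visualization sketch (the helper name and file paths are illustrative, not part of this PR):

```python
import cv2
import numpy as np

from modelscope.outputs import OutputKeys


def visualize_plates(image_path, result, out_path='vis.jpg'):
    # result follows TASK_OUTPUTS: POLYGONS is (N, 8), TEXT a list of N labels.
    img = cv2.imread(image_path)
    for quad, plate_type in zip(result[OutputKeys.POLYGONS],
                                result[OutputKeys.TEXT]):
        pts = np.array(quad, dtype=np.int32).reshape(-1, 2)
        cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
        cv2.putText(img, str(plate_type), tuple(map(int, pts[0])),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
    cv2.imwrite(out_path, img)
```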
@@ -0,0 +1,275 @@
# ------------------------------------------------------------------------------
# The implementation is adopted from CenterNet,
# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn

BN_MOMENTUM = 0.1
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride
        self.planes = planes

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(residual)
        out += residual
        out = self.relu(out)
        return out
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(
            planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class PoseResNet(nn.Module):

    def __init__(self, block, layers, head_conv=64, **kwargs):
        self.inplanes = 64
        self.deconv_with_bias = False
        # Output heads, all predicted at 1/4 input resolution: 'hm' is the
        # center heatmap, 'wh' the offsets to the four corners, 'reg' the
        # sub-pixel center offset, 'ftype' the plate-type logits; 'cls' is
        # not consumed by the pipeline's postprocess.
        self.heads = {'hm': 1, 'cls': 4, 'ftype': 11, 'wh': 8, 'reg': 2}
        super(PoseResNet, self).__init__()
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 256, layers[3], stride=2)
        # 1x1 lateral convs adapting encoder features for top-down fusion.
        self.adaption3 = nn.Conv2d(
            256, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption2 = nn.Conv2d(
            128, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption1 = nn.Conv2d(
            64, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption0 = nn.Conv2d(
            64, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaptionU1 = nn.Conv2d(
            256, 256, kernel_size=1, stride=1, padding=0, bias=False)
        # Four 2x upsampling stages (kernel 4, stride 2).
        self.deconv_layers1 = self._make_deconv_layer(1, [256], [4])
        self.deconv_layers2 = self._make_deconv_layer(1, [256], [4])
        self.deconv_layers3 = self._make_deconv_layer(1, [256], [4])
        self.deconv_layers4 = self._make_deconv_layer(1, [256], [4])
        for head in sorted(self.heads):
            num_output = self.heads[head]
            inchannel = 256
            if head_conv > 0:
                fc = nn.Sequential(
                    nn.Conv2d(
                        inchannel,
                        head_conv,
                        kernel_size=3,
                        padding=1,
                        bias=True), nn.ReLU(inplace=True),
                    nn.Conv2d(
                        head_conv,
                        num_output,
                        kernel_size=1,
                        stride=1,
                        padding=0))
            else:
                fc = nn.Conv2d(
                    in_channels=inchannel,
                    out_channels=num_output,
                    kernel_size=1,
                    stride=1,
                    padding=0)
            self.__setattr__(head, fc)
    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)
    def _get_deconv_cfg(self, deconv_kernel, index):
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        elif deconv_kernel == 7:
            padding = 3
            output_padding = 0
        else:
            raise ValueError(
                f'unsupported deconv kernel size: {deconv_kernel}')
        return deconv_kernel, padding, output_padding
    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i], i)
            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes
        return nn.Sequential(*layers)
    def forward(self, x):
        # Stem: stride-2 conv + stride-2 maxpool.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x0 = self.maxpool(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        # Top-down path: upsample 2x at each stage and fuse with the
        # laterally adapted encoder feature of matching resolution.
        x3_ = self.deconv_layers1(x4)
        x3_ = self.adaption3(x3) + x3_
        x2_ = self.deconv_layers2(x3_)
        x2_ = self.adaption2(x2) + x2_
        x1_ = self.deconv_layers3(x2_)
        x1_ = self.adaption1(x1) + x1_
        x0_ = self.deconv_layers4(x1_) + self.adaption0(x0)
        x0_ = self.adaptionU1(x0_)
        ret = {}
        for head in self.heads:
            ret[head] = self.__getattr__(head)(x0_)
        return [ret]
resnet_spec = {
    18: (BasicBlock, [2, 2, 2, 2]),
    34: (BasicBlock, [3, 4, 6, 3]),
    50: (Bottleneck, [3, 4, 6, 3]),
    101: (Bottleneck, [3, 4, 23, 3]),
    152: (Bottleneck, [3, 8, 36, 3])
}


def LicensePlateDet(num_layers=18):
    block_class, layers = resnet_spec[num_layers]
    model = PoseResNet(block_class, layers)
    return model
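A hedged smoke test of this file (shapes follow from the code above: a stride-2 stem plus maxpool and four stride-2 stages, then four 2x deconv stages, give an output stride of 4):

```python
import torch

model = LicensePlateDet(num_layers=18).eval()
with torch.no_grad():
    heads = model(torch.randn(1, 3, 512, 512))[0]
for name, t in sorted(heads.items()):
    print(name, tuple(t.shape))
# Expected, per the heads dict:
#   cls (1, 4, 128, 128), ftype (1, 11, 128, 128), hm (1, 1, 128, 128),
#   reg (1, 2, 128, 128), wh (1, 8, 128, 128)
```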
@@ -129,6 +129,14 @@ def _topk(scores, K=40):
    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs


def decode_by_ind(heat, inds, K=100):
    batch, cat, height, width = heat.size()
    score = _tranpose_and_gather_feat(heat, inds)
    score = score.view(batch, K, cat)
    _, Type = torch.max(score, 2)
    return Type


def bbox_decode(heat, wh, reg=None, K=100):
    batch, cat, height, width = heat.size()

@@ -163,7 +171,7 @@ def bbox_decode(heat, wh, reg=None, K=100):
    )
    detections = torch.cat([bboxes, scores, clses], dim=2)
    return detections, keep
    return detections, inds


def gbox_decode(mk, st_reg, reg=None, K=400):
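With `bbox_decode` now returning the top-K indices (`inds`) rather than `keep`, `decode_by_ind` can look up the per-channel `ftype` scores at exactly those K detections and arg-max over channels. A self-contained, hedged illustration of that gather-and-argmax on dummy tensors (plain torch ops stand in for `_tranpose_and_gather_feat`):

```python
import torch

# Dummy stand-ins: an 'ftype' map with C=11 channels on the 128x128 output
# grid, and K=5 flattened indices as produced by _topk inside bbox_decode.
batch, C, H, W, K = 1, 11, 128, 128, 5
ftype = torch.rand(batch, C, H, W)
inds = torch.randint(0, H * W, (batch, K))

# Equivalent of _tranpose_and_gather_feat + view + max in decode_by_ind:
flat = ftype.view(batch, C, H * W).permute(0, 2, 1)            # (B, HW, C)
scores = flat.gather(1, inds.unsqueeze(-1).expand(-1, -1, C))  # (B, K, C)
plate_type = scores.max(dim=2).indices                         # (B, K)
print(plate_type)
```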
@@ -50,7 +50,7 @@ class TableRecognitionPipeline(Pipeline):
        self.infer_model.load_state_dict(checkpoint)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
        mean = np.array([0.408, 0.447, 0.470],
                        dtype=np.float32).reshape(1, 1, 3)
@@ -17,6 +17,7 @@ class CVTasks(object):
    ocr_detection = 'ocr-detection'
    ocr_recognition = 'ocr-recognition'
    table_recognition = 'table-recognition'
    license_plate_detection = 'license-plate-detection'

    # human face body related
    animal_recognition = 'animal-recognition'
@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class LicensePlateDetectionTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet18_license-plate-detection_damo'
        self.test_image = 'data/test/images/license_plate_detection.jpg'
        self.task = Tasks.license_plate_detection

    def pipeline_inference(self, pipe: Pipeline, input_location: str):
        result = pipe(input_location)
        print('license plate detection results: ')
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        license_plate_detection = pipeline(
            Tasks.license_plate_detection, model=self.model_id)
        self.pipeline_inference(license_plate_detection, self.test_image)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        license_plate_detection = pipeline(Tasks.license_plate_detection)
        self.pipeline_inference(license_plate_detection, self.test_image)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()
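To run just this module locally, something like the following should work (hedged: the module path follows ModelScope's tests/pipelines convention, and TEST_LEVEL gating is assumed from test_utils, with level 2 also enabling the default-model case):

```python
import os
import subprocess

# TEST_LEVEL must be set before the module is imported, since the
# skipUnless decorators evaluate test_level() at import time.
subprocess.run(
    ['python', '-m', 'unittest',
     'tests.pipelines.test_license_plate_detection'],
    env={**os.environ, 'TEST_LEVEL': '2'},
    check=True)
```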