Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10917315
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:209f6ba7f15c9c34a02801b4c6ef33a979f3086702b5229d2e7975eb403c3e15
size 45819
@@ -157,6 +157,7 @@ class Pipelines(object):
     person_image_cartoon = 'unet-person-image-cartoon'
     ocr_detection = 'resnet18-ocr-detection'
     table_recognition = 'dla34-table-recognition'
+    license_plate_detection = 'resnet18-license-plate-detection'
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognition = 'resnet101-animal-recognition'
     general_recognition = 'resnet101-general-recognition'
@@ -62,6 +62,7 @@ TASK_OUTPUTS = {
     # }
     Tasks.ocr_detection: [OutputKeys.POLYGONS],
     Tasks.table_recognition: [OutputKeys.POLYGONS],
+    Tasks.license_plate_detection: [OutputKeys.POLYGONS, OutputKeys.TEXT],
     # ocr recognition result for single sample
     # {
@@ -85,6 +85,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.table_recognition:
     (Pipelines.table_recognition,
      'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
+    Tasks.license_plate_detection:
+    (Pipelines.license_plate_detection,
+     'damo/cv_resnet18_license-plate-detection_damo'),
     Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
     Tasks.feature_extraction: (Pipelines.feature_extraction,
                                'damo/pert_feature-extraction_base-test'),
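With the task name, pipeline name, and default model registered above, the new pipeline becomes reachable through the standard `pipeline()` factory. A minimal usage sketch (model id and test image taken from the test file at the end of this review):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# The model id can also be omitted thanks to the DEFAULT_MODEL_FOR_PIPELINE entry.
lpd = pipeline(Tasks.license_plate_detection,
               model='damo/cv_resnet18_license-plate-detection_damo')
result = lpd('data/test/images/license_plate_detection.jpg')
# Per the TASK_OUTPUTS registration, result carries OutputKeys.POLYGONS and OutputKeys.TEXT.
print(result)
```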
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
     from .live_category_pipeline import LiveCategoryPipeline
     from .ocr_detection_pipeline import OCRDetectionPipeline
     from .ocr_recognition_pipeline import OCRRecognitionPipeline
+    from .license_plate_detection_pipeline import LicensePlateDetectionPipeline
     from .table_recognition_pipeline import TableRecognitionPipeline
     from .skin_retouching_pipeline import SkinRetouchingPipeline
     from .tinynas_classification_pipeline import TinynasClassificationPipeline
@@ -109,6 +110,7 @@ else:
         'image_inpainting_pipeline': ['ImageInpaintingPipeline'],
         'ocr_detection_pipeline': ['OCRDetectionPipeline'],
         'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
+        'license_plate_detection_pipeline': ['LicensePlateDetectionPipeline'],
         'table_recognition_pipeline': ['TableRecognitionPipeline'],
         'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
         'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],
@@ -0,0 +1,122 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.ocr_utils.model_resnet18_half import \
    LicensePlateDet
from modelscope.pipelines.cv.ocr_utils.table_process import (
    bbox_decode, bbox_post_process, decode_by_ind, get_affine_transform, nms)
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.license_plate_detection,
    module_name=Pipelines.license_plate_detection)
class LicensePlateDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading model from {model_path}')
        self.cfg = Config.from_file(config_path)
        self.K = self.cfg.K
        self.car_type = self.cfg.Type
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.infer_model = LicensePlateDet()
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'state_dict' in checkpoint:
            self.infer_model.load_state_dict(checkpoint['state_dict'])
        else:
            self.infer_model.load_state_dict(checkpoint)
        self.infer_model = self.infer_model.to(self.device).eval()

    def preprocess(self, input: Input) -> Dict[str, Any]:
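        # LoadImage yields RGB; the [:, :, ::-1] slice reverses the channel
        # axis to BGR, the order the normalization constants and pretrained
        # detector below expect.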
        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
        mean = np.array([0.408, 0.447, 0.470],
                        dtype=np.float32).reshape(1, 1, 3)
        std = np.array([0.289, 0.274, 0.278],
                       dtype=np.float32).reshape(1, 1, 3)
        height, width = img.shape[0:2]
        inp_height, inp_width = 512, 512
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(height, width) * 1.0
        trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
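        # get_affine_transform builds the 2x3 matrix mapping the centered
        # max(h, w) square of the source image onto the 512x512 network input,
        # preserving aspect ratio (e.g. a 1280x720 frame: c=(640, 360), s=1280).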
        resized_image = cv2.resize(img, (width, height))
        inp_image = cv2.warpAffine(
            resized_image,
            trans_input, (inp_width, inp_height),
            flags=cv2.INTER_LINEAR)
        inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)
        images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
                                                      inp_width)
        images = torch.from_numpy(images).to(self.device)
        meta = {
            'c': c,
            's': s,
            'input_height': inp_height,
            'input_width': inp_width,
            'out_height': inp_height // 4,
            'out_width': inp_width // 4
        }
        result = {'img': images, 'meta': meta}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        pred = self.infer_model(input['img'])
        return {'results': pred, 'meta': input['meta']}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        output = inputs['results'][0]
        meta = inputs['meta']
        hm = output['hm'].sigmoid_()
        ftype = output['ftype'].sigmoid_()
        wh = output['wh']
        reg = output['reg']
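        # hm: 1-channel detection heatmap; wh: the 8 offsets forming a 4-corner
        # polygon; reg: 2-channel sub-pixel refinement; ftype: 11-way
        # plate-type scores, gathered at the same top-K peaks by decode_by_ind.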
        bbox, inds = bbox_decode(hm, wh, reg=reg, K=self.K)
        car_type = decode_by_ind(ftype, inds, K=self.K).detach().cpu().numpy()
        bbox = bbox.detach().cpu().numpy()
        for i in range(bbox.shape[1]):
            bbox[0][i][9] = car_type[0][i]
        bbox = nms(bbox, 0.3)
        bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()],
                                 [meta['s']], meta['out_height'],
                                 meta['out_width'])
        res, types = [], []
        for box in bbox[0]:
            if box[8] > 0.3:
                res.append(box[0:8])
                types.append(self.car_type[int(box[9])])
        result = {OutputKeys.POLYGONS: np.array(res), OutputKeys.TEXT: types}
        return result
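Downstream of postprocess, the two output keys can be consumed like this (a sketch reusing the default model and the test image from this review; plate-type strings come from the model's `cfg.Type` list, and bbox_post_process maps the corners back to source-image coordinates):

```python
import numpy as np
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

result = pipeline(Tasks.license_plate_detection)(
    'data/test/images/license_plate_detection.jpg')
polygons = result[OutputKeys.POLYGONS]  # shape (N, 8): four (x, y) corners per plate
plate_types = result[OutputKeys.TEXT]   # N strings drawn from cfg.Type
for quad, ptype in zip(polygons, plate_types):
    print(ptype, np.asarray(quad).reshape(4, 2))
```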
@@ -0,0 +1,275 @@
# ------------------------------------------------------------------------------
# The implementation is adopted from CenterNet,
# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn

BN_MOMENTUM = 0.1


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride
        self.planes = planes

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(residual)
        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(
            planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class PoseResNet(nn.Module):

    def __init__(self, block, layers, head_conv=64, **kwargs):
        self.inplanes = 64
        self.deconv_with_bias = False
        self.heads = {'hm': 1, 'cls': 4, 'ftype': 11, 'wh': 8, 'reg': 2}
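        # Output heads (channels): hm 1-channel detection heatmap, wh the 8
        # corner offsets of a quadrilateral box, reg the 2-channel sub-pixel
        # refinement, ftype 11 plate-type logits; cls is an auxiliary head not
        # consumed by the pipeline's postprocess.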
        super(PoseResNet, self).__init__()
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 256, layers[3], stride=2)
        self.adaption3 = nn.Conv2d(
            256, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption2 = nn.Conv2d(
            128, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption1 = nn.Conv2d(
            64, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption0 = nn.Conv2d(
            64, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaptionU1 = nn.Conv2d(
            256, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.deconv_layers1 = self._make_deconv_layer(1, [256], [4])
        self.deconv_layers2 = self._make_deconv_layer(1, [256], [4])
        self.deconv_layers3 = self._make_deconv_layer(1, [256], [4])
        self.deconv_layers4 = self._make_deconv_layer(1, [256], [4])
        for head in sorted(self.heads):
            num_output = self.heads[head]
            if head_conv > 0:
                inchannel = 256
                fc = nn.Sequential(
                    nn.Conv2d(
                        inchannel,
                        head_conv,
                        kernel_size=3,
                        padding=1,
                        bias=True), nn.ReLU(inplace=True),
                    nn.Conv2d(
                        head_conv,
                        num_output,
                        kernel_size=1,
                        stride=1,
                        padding=0))
            else:
                inchannel = 256
                fc = nn.Conv2d(
                    in_channels=inchannel,
                    out_channels=num_output,
                    kernel_size=1,
                    stride=1,
                    padding=0)
            self.__setattr__(head, fc)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def _get_deconv_cfg(self, deconv_kernel, index):
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        elif deconv_kernel == 7:
            padding = 3
            output_padding = 0
        else:
            raise ValueError(
                f'unsupported deconv kernel size: {deconv_kernel}')
        return deconv_kernel, padding, output_padding

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i], i)
            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x0 = self.maxpool(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
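        # Top-down decoder: each deconv stage doubles resolution and is fused
        # with a 1x1 "adaption" lateral from the matching encoder stage (an
        # FPN-style path), ending at stride 4 of the input.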
        x3_ = self.deconv_layers1(x4)
        x3_ = self.adaption3(x3) + x3_
        x2_ = self.deconv_layers2(x3_)
        x2_ = self.adaption2(x2) + x2_
        x1_ = self.deconv_layers3(x2_)
        x1_ = self.adaption1(x1) + x1_
        x0_ = self.deconv_layers4(x1_) + self.adaption0(x0)
        x0_ = self.adaptionU1(x0_)
        ret = {}
        for head in self.heads:
            ret[head] = self.__getattr__(head)(x0_)
        return [ret]


resnet_spec = {
    18: (BasicBlock, [2, 2, 2, 2]),
    34: (BasicBlock, [3, 4, 6, 3]),
    50: (Bottleneck, [3, 4, 6, 3]),
    101: (Bottleneck, [3, 4, 23, 3]),
    152: (Bottleneck, [3, 8, 36, 3])
}


def LicensePlateDet(num_layers=18):
    block_class, layers = resnet_spec[num_layers]
    model = PoseResNet(block_class, layers)
    return model
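A quick smoke test of the backbone's output geometry (a sketch; with a 512x512 input every head comes out at stride 4, matching the `out_height`/`out_width = inp // 4` used by the pipeline's preprocess):

```python
import torch

model = LicensePlateDet(num_layers=18).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 512))[0]  # forward returns [ret]
for name, tensor in out.items():
    print(name, tuple(tensor.shape))
# hm (1, 1, 128, 128), cls (1, 4, 128, 128), ftype (1, 11, 128, 128),
# wh (1, 8, 128, 128), reg (1, 2, 128, 128)
```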
@@ -129,6 +129,14 @@ def _topk(scores, K=40):
     return topk_score, topk_inds, topk_clses, topk_ys, topk_xs


+def decode_by_ind(heat, inds, K=100):
+    batch, cat, height, width = heat.size()
+    score = _tranpose_and_gather_feat(heat, inds)
+    score = score.view(batch, K, cat)
+    _, Type = torch.max(score, 2)
+    return Type
+
+
 def bbox_decode(heat, wh, reg=None, K=100):
     batch, cat, height, width = heat.size()
@@ -163,7 +171,7 @@ def bbox_decode(heat, wh, reg=None, K=100):
     )
     detections = torch.cat([bboxes, scores, clses], dim=2)
-    return detections, keep
+    return detections, inds


 def gbox_decode(mk, st_reg, reg=None, K=400):
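The return-value change matters because the pipeline's postprocess feeds these `inds` straight into the new `decode_by_ind`, which gathers each top-K peak's per-channel scores and returns the argmax channel. A toy check of that behavior (a sketch):

```python
import torch
from modelscope.pipelines.cv.ocr_utils.table_process import decode_by_ind

heat = torch.zeros(1, 3, 4, 4)         # batch=1, 3 type channels, 4x4 map
heat[0, 2, 1, 1] = 0.9                 # channel 2 is hottest at (y=1, x=1)
inds = torch.tensor([[1 * 4 + 1]])     # flattened index of that peak
print(decode_by_ind(heat, inds, K=1))  # -> tensor([[2]])
```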
@@ -50,7 +50,7 @@ class TableRecognitionPipeline(Pipeline):
         self.infer_model.load_state_dict(checkpoint)

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        img = LoadImage.convert_to_ndarray(input)
+        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
         mean = np.array([0.408, 0.447, 0.470],
                         dtype=np.float32).reshape(1, 1, 3)
@@ -17,6 +17,7 @@ class CVTasks(object):
     ocr_detection = 'ocr-detection'
     ocr_recognition = 'ocr-recognition'
     table_recognition = 'table-recognition'
+    license_plate_detection = 'license-plate-detection'

     # human face body related
     animal_recognition = 'animal-recognition'
@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class LicensePlateDetectionTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet18_license-plate-detection_damo'
        self.test_image = 'data/test/images/license_plate_detection.jpg'
        self.task = Tasks.license_plate_detection

    def pipeline_inference(self, pipe: Pipeline, input_location: str):
        result = pipe(input_location)
        print('license plate detection results: ')
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        license_plate_detection = pipeline(
            Tasks.license_plate_detection, model=self.model_id)
        self.pipeline_inference(license_plate_detection, self.test_image)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        license_plate_detection = pipeline(Tasks.license_plate_detection)
        self.pipeline_inference(license_plate_detection, self.test_image)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()