Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10917315 (master^2)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:209f6ba7f15c9c34a02801b4c6ef33a979f3086702b5229d2e7975eb403c3e15
size 45819
@@ -157,6 +157,7 @@ class Pipelines(object):
    person_image_cartoon = 'unet-person-image-cartoon'
    ocr_detection = 'resnet18-ocr-detection'
    table_recognition = 'dla34-table-recognition'
    license_plate_detection = 'resnet18-license-plate-detection'
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognition = 'resnet101-animal-recognition'
    general_recognition = 'resnet101-general-recognition'
@@ -62,6 +62,7 @@ TASK_OUTPUTS = {
    # }
    Tasks.ocr_detection: [OutputKeys.POLYGONS],
    Tasks.table_recognition: [OutputKeys.POLYGONS],
    Tasks.license_plate_detection: [OutputKeys.POLYGONS, OutputKeys.TEXT],

    # ocr recognition result for single sample
    # {
@@ -85,6 +85,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.table_recognition:
    (Pipelines.table_recognition,
     'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
    Tasks.license_plate_detection:
    (Pipelines.license_plate_detection,
     'damo/cv_resnet18_license-plate-detection_damo'),
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.feature_extraction: (Pipelines.feature_extraction,
                               'damo/pert_feature-extraction_base-test'),
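For reviewers, a minimal usage sketch of the new default pipeline (assuming the model id registered above is live on the ModelScope hub and the test image from this PR is available locally):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Model id and test image are the ones added in this PR.
plate_detector = pipeline(
    Tasks.license_plate_detection,
    model='damo/cv_resnet18_license-plate-detection_damo')
result = plate_detector('data/test/images/license_plate_detection.jpg')
print(result)  # polygons: (N, 8) corner array; text: plate-type labels
```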
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
    from .live_category_pipeline import LiveCategoryPipeline
    from .ocr_detection_pipeline import OCRDetectionPipeline
    from .ocr_recognition_pipeline import OCRRecognitionPipeline
    from .license_plate_detection_pipeline import LicensePlateDetectionPipeline
    from .table_recognition_pipeline import TableRecognitionPipeline
    from .skin_retouching_pipeline import SkinRetouchingPipeline
    from .tinynas_classification_pipeline import TinynasClassificationPipeline
@@ -109,6 +110,7 @@ else:
        'image_inpainting_pipeline': ['ImageInpaintingPipeline'],
        'ocr_detection_pipeline': ['OCRDetectionPipeline'],
        'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
        'license_plate_detection_pipeline': ['LicensePlateDetectionPipeline'],
        'table_recognition_pipeline': ['TableRecognitionPipeline'],
        'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
        'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],
@@ -0,0 +1,122 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.ocr_utils.model_resnet18_half import \
    LicensePlateDet
from modelscope.pipelines.cv.ocr_utils.table_process import (
    bbox_decode, bbox_post_process, decode_by_ind, get_affine_transform, nms)
from modelscope.preprocessors import load_image
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.license_plate_detection,
    module_name=Pipelines.license_plate_detection)
class LicensePlateDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading model from {model_path}')
        self.cfg = Config.from_file(config_path)
        self.K = self.cfg.K
        self.car_type = self.cfg.Type
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.infer_model = LicensePlateDet()
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'state_dict' in checkpoint:
            self.infer_model.load_state_dict(checkpoint['state_dict'])
        else:
            self.infer_model.load_state_dict(checkpoint)
        self.infer_model = self.infer_model.to(self.device).eval()
    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
        mean = np.array([0.408, 0.447, 0.470],
                        dtype=np.float32).reshape(1, 1, 3)
        std = np.array([0.289, 0.274, 0.278],
                       dtype=np.float32).reshape(1, 1, 3)
        height, width = img.shape[0:2]
        inp_height, inp_width = 512, 512
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(height, width) * 1.0
        trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
        resized_image = cv2.resize(img, (width, height))
        inp_image = cv2.warpAffine(
            resized_image,
            trans_input, (inp_width, inp_height),
            flags=cv2.INTER_LINEAR)
        inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)
        images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
                                                      inp_width)
        images = torch.from_numpy(images).to(self.device)
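        # CenterNet-style meta: original-image center/scale plus the stride-4
        # output resolution; postprocess uses these to map detections back to
        # input-image coordinates.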
        meta = {
            'c': c,
            's': s,
            'input_height': inp_height,
            'input_width': inp_width,
            'out_height': inp_height // 4,
            'out_width': inp_width // 4
        }
        result = {'img': images, 'meta': meta}
        return result
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        pred = self.infer_model(input['img'])
        return {'results': pred, 'meta': input['meta']}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        output = inputs['results'][0]
        meta = inputs['meta']
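        # 'hm' (center heatmap) and 'ftype' (plate-type map) are squashed to
        # probabilities in-place before peak decoding.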
        hm = output['hm'].sigmoid_()
        ftype = output['ftype'].sigmoid_()
        wh = output['wh']
        reg = output['reg']
        bbox, inds = bbox_decode(hm, wh, reg=reg, K=self.K)
        car_type = decode_by_ind(ftype, inds, K=self.K).detach().cpu().numpy()
        bbox = bbox.detach().cpu().numpy()
        for i in range(bbox.shape[1]):
            bbox[0][i][9] = car_type[0][i]
        bbox = nms(bbox, 0.3)
        bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()],
                                 [meta['s']], meta['out_height'],
                                 meta['out_width'])
        res, Type = [], []
        for box in bbox[0]:
            if box[8] > 0.3:
                res.append(box[0:8])
                Type.append(self.car_type[int(box[9])])
        result = {OutputKeys.POLYGONS: np.array(res), OutputKeys.TEXT: Type}
        return result
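Each returned polygon is a flat array of 8 values, i.e. 4 (x, y) corners, so detections can be overlaid along these lines (a sketch reusing `plate_detector` from the snippet above; `plates_vis.jpg` is an arbitrary output name):

```python
import cv2
import numpy as np

from modelscope.outputs import OutputKeys

img = cv2.imread('data/test/images/license_plate_detection.jpg')
result = plate_detector('data/test/images/license_plate_detection.jpg')
for quad, plate_type in zip(result[OutputKeys.POLYGONS],
                            result[OutputKeys.TEXT]):
    pts = np.asarray(quad, dtype=np.int32).reshape(4, 2)
    cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
cv2.imwrite('plates_vis.jpg', img)
```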
@@ -0,0 +1,275 @@
# ------------------------------------------------------------------------------
# The implementation is adopted from CenterNet, made publicly available under
# the MIT License at https://github.com/xingyizhou/CenterNet.git
# ------------------------------------------------------------------------------
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

BN_MOMENTUM = 0.1


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride
        self.planes = planes

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(residual)
        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(
            planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class PoseResNet(nn.Module):

    def __init__(self, block, layers, head_conv=64, **kwargs):
        self.inplanes = 64
        self.deconv_with_bias = False
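        # Output heads: 'hm' center heatmap, 'ftype' plate type (11 classes,
        # indexing cfg.Type in the pipeline), 'wh' the 8 center-to-corner
        # offsets, 'reg' sub-pixel center offset; 'cls' is unused by the
        # pipeline's postprocess.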
        self.heads = {'hm': 1, 'cls': 4, 'ftype': 11, 'wh': 8, 'reg': 2}
        super(PoseResNet, self).__init__()
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 256, layers[3], stride=2)
        self.adaption3 = nn.Conv2d(
            256, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption2 = nn.Conv2d(
            128, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption1 = nn.Conv2d(
            64, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption0 = nn.Conv2d(
            64, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaptionU1 = nn.Conv2d(
            256, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.deconv_layers1 = self._make_deconv_layer(1, [256], [4])
        self.deconv_layers2 = self._make_deconv_layer(1, [256], [4])
        self.deconv_layers3 = self._make_deconv_layer(1, [256], [4])
        self.deconv_layers4 = self._make_deconv_layer(1, [256], [4])
        for head in sorted(self.heads):
            num_output = self.heads[head]
            if head_conv > 0:
                inchannel = 256
                fc = nn.Sequential(
                    nn.Conv2d(
                        inchannel,
                        head_conv,
                        kernel_size=3,
                        padding=1,
                        bias=True), nn.ReLU(inplace=True),
                    nn.Conv2d(
                        head_conv,
                        num_output,
                        kernel_size=1,
                        stride=1,
                        padding=0))
            else:
                inchannel = 256
                fc = nn.Conv2d(
                    in_channels=inchannel,
                    out_channels=num_output,
                    kernel_size=1,
                    stride=1,
                    padding=0)
            self.__setattr__(head, fc)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def _get_deconv_cfg(self, deconv_kernel, index):
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        elif deconv_kernel == 7:
            padding = 3
            output_padding = 0
        return deconv_kernel, padding, output_padding
    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i], i)
            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x0 = self.maxpool(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x3_ = self.deconv_layers1(x4)
        x3_ = self.adaption3(x3) + x3_
        x2_ = self.deconv_layers2(x3_)
        x2_ = self.adaption2(x2) + x2_
        x1_ = self.deconv_layers3(x2_)
        x1_ = self.adaption1(x1) + x1_
        x0_ = self.deconv_layers4(x1_) + self.adaption0(x0)
        x0_ = self.adaptionU1(x0_)
        ret = {}
        for head in self.heads:
            ret[head] = self.__getattr__(head)(x0_)
        return [ret]


resnet_spec = {
    18: (BasicBlock, [2, 2, 2, 2]),
    34: (BasicBlock, [3, 4, 6, 3]),
    50: (Bottleneck, [3, 4, 6, 3]),
    101: (Bottleneck, [3, 4, 23, 3]),
    152: (Bottleneck, [3, 8, 36, 3])
}


def LicensePlateDet(num_layers=18):
    block_class, layers = resnet_spec[num_layers]
    model = PoseResNet(block_class, layers)
    return model
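A quick shape sanity check for the backbone (a sketch; the 512×512 input matches the pipeline's preprocess and the output stride is 4):

```python
import torch

from modelscope.pipelines.cv.ocr_utils.model_resnet18_half import \
    LicensePlateDet

model = LicensePlateDet(num_layers=18).eval()
with torch.no_grad():
    heads = model(torch.randn(1, 3, 512, 512))[0]
for name, tensor in heads.items():
    print(name, tuple(tensor.shape))
# hm (1, 1, 128, 128), cls (1, 4, 128, 128), ftype (1, 11, 128, 128),
# wh (1, 8, 128, 128), reg (1, 2, 128, 128)
```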
@@ -129,6 +129,14 @@ def _topk(scores, K=40):
    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs


def decode_by_ind(heat, inds, K=100):
    batch, cat, height, width = heat.size()
    score = _tranpose_and_gather_feat(heat, inds)
    score = score.view(batch, K, cat)
    _, Type = torch.max(score, 2)
    return Type
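For intuition: `decode_by_ind` gathers each detection's per-class scores at the top-K peak indices from `bbox_decode` and argmaxes over classes. With dummy tensors (hypothetical shapes):

```python
import torch

from modelscope.pipelines.cv.ocr_utils.table_process import decode_by_ind

heat = torch.rand(1, 11, 128, 128)           # e.g. the 11-channel 'ftype' map
inds = torch.randint(0, 128 * 128, (1, 10))  # flattened top-K peak indices
plate_type = decode_by_ind(heat, inds, K=10)
print(plate_type.shape)  # torch.Size([1, 10]): one class index per detection
```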
def bbox_decode(heat, wh, reg=None, K=100):
    batch, cat, height, width = heat.size()


@@ -163,7 +171,7 @@ def bbox_decode(heat, wh, reg=None, K=100):
    )
    detections = torch.cat([bboxes, scores, clses], dim=2)
-    return detections, keep
+    return detections, inds


def gbox_decode(mk, st_reg, reg=None, K=400):
@@ -50,7 +50,7 @@ class TableRecognitionPipeline(Pipeline):
            self.infer_model.load_state_dict(checkpoint)

    def preprocess(self, input: Input) -> Dict[str, Any]:
-        img = LoadImage.convert_to_ndarray(input)
+        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
        mean = np.array([0.408, 0.447, 0.470],
                        dtype=np.float32).reshape(1, 1, 3)
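This mirrors the channel flip in the new license-plate pipeline: `convert_to_ndarray` yields RGB, while the checkpoints apparently expect OpenCV-style BGR. A toy check of what `[:, :, ::-1]` does:

```python
import numpy as np

rgb = np.zeros((1, 1, 3), dtype=np.uint8)
rgb[..., 0] = 255           # pure red in RGB order
bgr = rgb[:, :, ::-1]       # reverse the channel axis
assert bgr[0, 0, 2] == 255  # red now sits in the last (R) slot of BGR
```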
@@ -17,6 +17,7 @@ class CVTasks(object):
    ocr_detection = 'ocr-detection'
    ocr_recognition = 'ocr-recognition'
    table_recognition = 'table-recognition'
    license_plate_detection = 'license-plate-detection'

    # human face body related
    animal_recognition = 'animal-recognition'
@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class LicensePlateDetectionTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet18_license-plate-detection_damo'
        self.test_image = 'data/test/images/license_plate_detection.jpg'
        self.task = Tasks.license_plate_detection

    def pipeline_inference(self, pipe: Pipeline, input_location: str):
        result = pipe(input_location)
        print('license plate detection results:')
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        license_plate_detection = pipeline(
            Tasks.license_plate_detection, model=self.model_id)
        self.pipeline_inference(license_plate_detection, self.test_image)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        license_plate_detection = pipeline(Tasks.license_plate_detection)
        self.pipeline_inference(license_plate_detection, self.test_image)

    @unittest.skip('demo compatibility test is only enabled on a needed basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()