From 9d8eb5b0b3b685fbec975d833be4260ebde1c1a2 Mon Sep 17 00:00:00 2001
From: "rujiao.lrj"
Date: Thu, 1 Dec 2022 19:48:06 +0800
Subject: [PATCH] support license plate detection

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10917315
---
 data/test/images/license_plate_detection.jpg  |   3 +
 modelscope/metainfo.py                        |   1 +
 modelscope/outputs/outputs.py                 |   1 +
 modelscope/pipelines/builder.py               |   3 +
 modelscope/pipelines/cv/__init__.py           |   2 +
 .../cv/license_plate_detection_pipeline.py    | 122 ++++++++
 .../cv/ocr_utils/model_resnet18_half.py       | 275 ++++++++++++++++++
 .../pipelines/cv/ocr_utils/table_process.py   |  10 +-
 .../cv/table_recognition_pipeline.py          |   2 +-
 modelscope/utils/constant.py                  |   1 +
 .../pipelines/test_license_plate_detection.py |  41 +++
 11 files changed, 459 insertions(+), 2 deletions(-)
 create mode 100644 data/test/images/license_plate_detection.jpg
 create mode 100644 modelscope/pipelines/cv/license_plate_detection_pipeline.py
 create mode 100644 modelscope/pipelines/cv/ocr_utils/model_resnet18_half.py
 create mode 100644 tests/pipelines/test_license_plate_detection.py

diff --git a/data/test/images/license_plate_detection.jpg b/data/test/images/license_plate_detection.jpg
new file mode 100644
index 00000000..e61e54f1
--- /dev/null
+++ b/data/test/images/license_plate_detection.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:209f6ba7f15c9c34a02801b4c6ef33a979f3086702b5229d2e7975eb403c3e15
+size 45819
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index cc3ff3e7..1fccb46e 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -157,6 +157,7 @@ class Pipelines(object):
     person_image_cartoon = 'unet-person-image-cartoon'
     ocr_detection = 'resnet18-ocr-detection'
     table_recognition = 'dla34-table-recognition'
+    license_plate_detection = 'resnet18-license-plate-detection'
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognition = 'resnet101-animal-recognition'
     general_recognition = 'resnet101-general-recognition'
diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py
index b9ee0239..dbd1ec3c 100644
--- a/modelscope/outputs/outputs.py
+++ b/modelscope/outputs/outputs.py
@@ -62,6 +62,7 @@ TASK_OUTPUTS = {
     # }
     Tasks.ocr_detection: [OutputKeys.POLYGONS],
     Tasks.table_recognition: [OutputKeys.POLYGONS],
+    Tasks.license_plate_detection: [OutputKeys.POLYGONS, OutputKeys.TEXT],

     # ocr recognition result for single sample
     # {
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index c1634a9c..dac6011d 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -85,6 +85,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.table_recognition:
     (Pipelines.table_recognition,
      'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
+    Tasks.license_plate_detection:
+    (Pipelines.license_plate_detection,
+     'damo/cv_resnet18_license-plate-detection_damo'),
     Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
     Tasks.feature_extraction: (Pipelines.feature_extraction,
                                'damo/pert_feature-extraction_base-test'),
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index e196e8f7..e5bebe5f 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
     from .live_category_pipeline import LiveCategoryPipeline
     from .ocr_detection_pipeline import OCRDetectionPipeline
     from .ocr_recognition_pipeline import OCRRecognitionPipeline
+    from .license_plate_detection_pipeline import LicensePlateDetectionPipeline
     from .table_recognition_pipeline import TableRecognitionPipeline
     from .skin_retouching_pipeline import SkinRetouchingPipeline
     from .tinynas_classification_pipeline import TinynasClassificationPipeline
@@ -109,6 +110,7 @@ else:
         'image_inpainting_pipeline': ['ImageInpaintingPipeline'],
         'ocr_detection_pipeline': ['OCRDetectionPipeline'],
         'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
+        'license_plate_detection_pipeline': ['LicensePlateDetectionPipeline'],
         'table_recognition_pipeline': ['TableRecognitionPipeline'],
         'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
         'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],
diff --git a/modelscope/pipelines/cv/license_plate_detection_pipeline.py b/modelscope/pipelines/cv/license_plate_detection_pipeline.py
new file mode 100644
index 00000000..a2ba4203
--- /dev/null
+++ b/modelscope/pipelines/cv/license_plate_detection_pipeline.py
@@ -0,0 +1,122 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import math
+import os.path as osp
+from typing import Any, Dict
+
+import cv2
+import numpy as np
+import PIL
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.pipelines.cv.ocr_utils.model_resnet18_half import \
+    LicensePlateDet
+from modelscope.pipelines.cv.ocr_utils.table_process import (
+    bbox_decode, bbox_post_process, decode_by_ind, get_affine_transform, nms)
+from modelscope.preprocessors import load_image
+from modelscope.preprocessors.image import LoadImage
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.license_plate_detection,
+    module_name=Pipelines.license_plate_detection)
+class LicensePlateDetectionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
+        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
+        logger.info(f'loading model from {model_path}')
+
+        self.cfg = Config.from_file(config_path)
+        self.K = self.cfg.K
+        self.car_type = self.cfg.Type
+        self.device = torch.device(
+            'cuda' if torch.cuda.is_available() else 'cpu')
+        self.infer_model = LicensePlateDet()
+        checkpoint = torch.load(model_path, map_location=self.device)
+        if 'state_dict' in checkpoint:
+            self.infer_model.load_state_dict(checkpoint['state_dict'])
+        else:
+            self.infer_model.load_state_dict(checkpoint)
+        self.infer_model = self.infer_model.to(self.device)
+        self.infer_model.eval()
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
+
+        mean = np.array([0.408, 0.447, 0.470],
+                        dtype=np.float32).reshape(1, 1, 3)
+        std = np.array([0.289, 0.274, 0.278],
+                       dtype=np.float32).reshape(1, 1, 3)
+        height, width = img.shape[0:2]
+        inp_height, inp_width = 512, 512
+        c = np.array([width / 2., height / 2.], dtype=np.float32)
+        s = max(height, width) * 1.0
+
+        trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
+        resized_image = cv2.resize(img, (width, height))
+        inp_image = cv2.warpAffine(
+            resized_image,
+            trans_input, (inp_width, inp_height),
+            flags=cv2.INTER_LINEAR)
+        inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)
+
+        images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
+                                                      inp_width)
+        images = torch.from_numpy(images).to(self.device)
+        meta = {
+            'c': c,
+            's': s,
+            'input_height': inp_height,
+            'input_width': inp_width,
+            'out_height': inp_height // 4,
+            'out_width': inp_width // 4
+        }
+
+        result = {'img': images, 'meta': meta}
+
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        pred = self.infer_model(input['img'])
+        return {'results': pred, 'meta': input['meta']}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        output = inputs['results'][0]
+        meta = inputs['meta']
+        hm = output['hm'].sigmoid_()
+        ftype = output['ftype'].sigmoid_()
+        wh = output['wh']
+        reg = output['reg']
+
+        bbox, inds = bbox_decode(hm, wh, reg=reg, K=self.K)
+        car_type = decode_by_ind(ftype, inds, K=self.K).detach().cpu().numpy()
+        bbox = bbox.detach().cpu().numpy()
+        for i in range(bbox.shape[1]):
+            bbox[0][i][9] = car_type[0][i]
+        bbox = nms(bbox, 0.3)
+        bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()],
+                                 [meta['s']], meta['out_height'],
+                                 meta['out_width'])
+
+        res, Type = [], []
+        for box in bbox[0]:
+            if box[8] > 0.3:
+                res.append(box[0:8])
+                Type.append(self.car_type[int(box[9])])
+
+        result = {OutputKeys.POLYGONS: np.array(res), OutputKeys.TEXT: Type}
+        return result
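For reference, a minimal usage sketch of the new pipeline (the task constant, default model id and test image below are the ones introduced in this patch; the snippet itself is illustrative and not part of the change):

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Omitting `model` falls back to the default
    # 'damo/cv_resnet18_license-plate-detection_damo' registered in builder.py.
    detector = pipeline(
        Tasks.license_plate_detection,
        model='damo/cv_resnet18_license-plate-detection_damo')

    result = detector('data/test/images/license_plate_detection.jpg')
    polygons = result[OutputKeys.POLYGONS]   # (N, 8) array: 4 corner points per plate
    plate_types = result[OutputKeys.TEXT]    # per-plate category from the model config's Type list
    print(polygons, plate_types)
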
diff --git a/modelscope/pipelines/cv/ocr_utils/model_resnet18_half.py b/modelscope/pipelines/cv/ocr_utils/model_resnet18_half.py
new file mode 100644
index 00000000..2d771eb4
--- /dev/null
+++ b/modelscope/pipelines/cv/ocr_utils/model_resnet18_half.py
@@ -0,0 +1,275 @@
+# ------------------------------------------------------------------------------
+# The implementation is adopted from CenterNet,
+# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
+# ------------------------------------------------------------------------------
+
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+BN_MOMENTUM = 0.1
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(
+            inplanes, planes, kernel_size=3, stride=stride, padding=1)
+        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.downsample = downsample
+        self.stride = stride
+        self.planes = planes
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        if self.downsample is not None:
+            residual = self.downsample(residual)
+
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.conv2 = nn.Conv2d(
+            planes,
+            planes,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            bias=False)
+        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.conv3 = nn.Conv2d(
+            planes, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(
+            planes * self.expansion, momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class PoseResNet(nn.Module):
+
+    def __init__(self, block, layers, head_conv=64, **kwargs):
+        self.inplanes = 64
+        self.deconv_with_bias = False
+        self.heads = {'hm': 1, 'cls': 4, 'ftype': 11, 'wh': 8, 'reg': 2}
+
+        super(PoseResNet, self).__init__()
+        self.conv1 = nn.Conv2d(
+            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 256, layers[3], stride=2)
+
+        self.adaption3 = nn.Conv2d(
+            256, 256, kernel_size=1, stride=1, padding=0, bias=False)
+        self.adaption2 = nn.Conv2d(
+            128, 256, kernel_size=1, stride=1, padding=0, bias=False)
+        self.adaption1 = nn.Conv2d(
+            64, 256, kernel_size=1, stride=1, padding=0, bias=False)
+        self.adaption0 = nn.Conv2d(
+            64, 256, kernel_size=1, stride=1, padding=0, bias=False)
+
+        self.adaptionU1 = nn.Conv2d(
+            256, 256, kernel_size=1, stride=1, padding=0, bias=False)
+
+        self.deconv_layers1 = self._make_deconv_layer(
+            1,
+            [256],
+            [4],
+        )
+        self.deconv_layers2 = self._make_deconv_layer(
+            1,
+            [256],
+            [4],
+        )
+        self.deconv_layers3 = self._make_deconv_layer(
+            1,
+            [256],
+            [4],
+        )
+        self.deconv_layers4 = self._make_deconv_layer(
+            1,
+            [256],
+            [4],
+        )
+
+        for head in sorted(self.heads):
+            num_output = self.heads[head]
+            if head_conv > 0:
+                inchannel = 256
+                fc = nn.Sequential(
+                    nn.Conv2d(
+                        inchannel,
+                        head_conv,
+                        kernel_size=3,
+                        padding=1,
+                        bias=True), nn.ReLU(inplace=True),
+                    nn.Conv2d(
+                        head_conv,
+                        num_output,
+                        kernel_size=1,
+                        stride=1,
+                        padding=0))
+            else:
+                inchannel = 256
+                fc = nn.Conv2d(
+                    in_channels=inchannel,
+                    out_channels=num_output,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0)
+            self.__setattr__(head, fc)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False),
+                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def _get_deconv_cfg(self, deconv_kernel, index):
+        if deconv_kernel == 4:
+            padding = 1
+            output_padding = 0
+        elif deconv_kernel == 3:
+            padding = 1
+            output_padding = 1
+        elif deconv_kernel == 2:
+            padding = 0
+            output_padding = 0
+        elif deconv_kernel == 7:
+            padding = 3
+            output_padding = 0
+
+        return deconv_kernel, padding, output_padding
+
+    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
+        assert num_layers == len(num_filters), \
+            'ERROR: num_deconv_layers is different len(num_deconv_filters)'
+        assert num_layers == len(num_kernels), \
+            'ERROR: num_deconv_layers is different len(num_deconv_kernels)'
+
+        layers = []
+        for i in range(num_layers):
+            kernel, padding, output_padding = \
+                self._get_deconv_cfg(num_kernels[i], i)
+
+            planes = num_filters[i]
+            layers.append(
+                nn.ConvTranspose2d(
+                    in_channels=self.inplanes,
+                    out_channels=planes,
+                    kernel_size=kernel,
+                    stride=2,
+                    padding=padding,
+                    output_padding=output_padding,
+                    bias=self.deconv_with_bias))
+            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
+            layers.append(nn.ReLU(inplace=True))
+            self.inplanes = planes
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x0 = self.maxpool(x)
+        x1 = self.layer1(x0)
+        x2 = self.layer2(x1)
+        x3 = self.layer3(x2)
+        x4 = self.layer4(x3)
+
+        x3_ = self.deconv_layers1(x4)
+        x3_ = self.adaption3(x3) + x3_
+
+        x2_ = self.deconv_layers2(x3_)
+        x2_ = self.adaption2(x2) + x2_
+
+        x1_ = self.deconv_layers3(x2_)
+        x1_ = self.adaption1(x1) + x1_
+
+        x0_ = self.deconv_layers4(x1_) + self.adaption0(x0)
+        x0_ = self.adaptionU1(x0_)
+
+        ret = {}
+
+        for head in self.heads:
+            ret[head] = self.__getattr__(head)(x0_)
+        return [ret]
+
+
+resnet_spec = {
+    18: (BasicBlock, [2, 2, 2, 2]),
+    34: (BasicBlock, [3, 4, 6, 3]),
+    50: (Bottleneck, [3, 4, 6, 3]),
+    101: (Bottleneck, [3, 4, 23, 3]),
+    152: (Bottleneck, [3, 8, 36, 3])
+}
+
+
+def LicensePlateDet(num_layers=18):
+    block_class, layers = resnet_spec[num_layers]
+    model = PoseResNet(block_class, layers)
+    return model
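A quick shape sanity check for this backbone (an illustrative sketch, not part of the patch): the stem plus four stride-2 stages reduce a 512x512 input to 1/64 resolution, and the four deconv stages bring it back to 1/4, so every head predicts on a 128x128 grid with the channel counts from self.heads.

    import torch

    from modelscope.pipelines.cv.ocr_utils.model_resnet18_half import LicensePlateDet

    model = LicensePlateDet().eval()   # randomly initialised ResNet18-half defined above
    with torch.no_grad():
        out = model(torch.randn(1, 3, 512, 512))[0]  # forward returns a one-element list

    # Expected: hm (1, 1, 128, 128), cls (1, 4, 128, 128), ftype (1, 11, 128, 128),
    #           wh (1, 8, 128, 128), reg (1, 2, 128, 128)
    for name, feat in out.items():
        print(name, tuple(feat.shape))
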
diff --git a/modelscope/pipelines/cv/ocr_utils/table_process.py b/modelscope/pipelines/cv/ocr_utils/table_process.py
index 864ec71d..3bf28e84 100644
--- a/modelscope/pipelines/cv/ocr_utils/table_process.py
+++ b/modelscope/pipelines/cv/ocr_utils/table_process.py
@@ -129,6 +129,14 @@ def _topk(scores, K=40):
     return topk_score, topk_inds, topk_clses, topk_ys, topk_xs


+def decode_by_ind(heat, inds, K=100):
+    batch, cat, height, width = heat.size()
+    score = _tranpose_and_gather_feat(heat, inds)
+    score = score.view(batch, K, cat)
+    _, Type = torch.max(score, 2)
+    return Type
+
+
 def bbox_decode(heat, wh, reg=None, K=100):
     batch, cat, height, width = heat.size()

@@ -163,7 +171,7 @@ def bbox_decode(heat, wh, reg=None, K=100):
     )

     detections = torch.cat([bboxes, scores, clses], dim=2)
-    return detections, keep
+    return detections, inds


 def gbox_decode(mk, st_reg, reg=None, K=400):
diff --git a/modelscope/pipelines/cv/table_recognition_pipeline.py b/modelscope/pipelines/cv/table_recognition_pipeline.py
index 1ee9a4f0..8608cd06 100644
--- a/modelscope/pipelines/cv/table_recognition_pipeline.py
+++ b/modelscope/pipelines/cv/table_recognition_pipeline.py
@@ -50,7 +50,7 @@ class TableRecognitionPipeline(Pipeline):
         self.infer_model.load_state_dict(checkpoint)

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        img = LoadImage.convert_to_ndarray(input)
+        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]

         mean = np.array([0.408, 0.447, 0.470],
                         dtype=np.float32).reshape(1, 1, 3)
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 46817703..8376c971 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -17,6 +17,7 @@ class CVTasks(object):
     ocr_detection = 'ocr-detection'
     ocr_recognition = 'ocr-recognition'
     table_recognition = 'table-recognition'
+    license_plate_detection = 'license-plate-detection'

     # human face body related
     animal_recognition = 'animal-recognition'
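The new decode_by_ind reuses the CenterNet-style gather helper already present in table_process.py: for each of the K detections chosen by bbox_decode, it reads the ftype heat-map values at that detection's center index and takes the argmax over channels. A standalone sketch of the same idea, with hypothetical toy sizes and a local helper rather than the patch's own code:

    import torch

    def gather_types(heat: torch.Tensor, inds: torch.Tensor) -> torch.Tensor:
        # heat: (B, C, H, W) class scores; inds: (B, K) flattened y * W + x center indices
        b, c, h, w = heat.size()
        flat = heat.view(b, c, h * w).permute(0, 2, 1)                   # (B, H*W, C)
        gathered = flat.gather(1, inds.unsqueeze(2).expand(-1, -1, c))   # (B, K, C)
        return gathered.argmax(dim=2)                                    # (B, K) type per detection

    heat = torch.rand(1, 11, 128, 128)            # e.g. the 'ftype' head
    inds = torch.randint(0, 128 * 128, (1, 5))
    print(gather_types(heat, inds).shape)         # torch.Size([1, 5])
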
diff --git a/tests/pipelines/test_license_plate_detection.py b/tests/pipelines/test_license_plate_detection.py
new file mode 100644
index 00000000..70cdb820
--- /dev/null
+++ b/tests/pipelines/test_license_plate_detection.py
@@ -0,0 +1,41 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class LicensePlateDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_resnet18_license-plate-detection_damo'
+        self.test_image = 'data/test/images/license_plate_detection.jpg'
+        self.task = Tasks.license_plate_detection
+
+    def pipeline_inference(self, pipe: Pipeline, input_location: str):
+        result = pipe(input_location)
+        print('license plate detection results: ')
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        license_plate_detection = pipeline(
+            Tasks.license_plate_detection, model=self.model_id)
+        self.pipeline_inference(license_plate_detection, self.test_image)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        license_plate_detection = pipeline(Tasks.license_plate_detection)
+        self.pipeline_inference(license_plate_detection, self.test_image)
+
+    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
+    def test_demo_compatibility(self):
+        self.compatibility_check()
+
+
+if __name__ == '__main__':
+    unittest.main()
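To eyeball the detections printed by the test above, the returned polygons can be drawn back onto the image. A small sketch, assuming each row of POLYGONS is [x1, y1, x2, y2, x3, y3, x4, y4] as produced by postprocess:

    import cv2
    import numpy as np

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    img_path = 'data/test/images/license_plate_detection.jpg'
    result = pipeline(Tasks.license_plate_detection)(img_path)

    img = cv2.imread(img_path)
    for poly, plate_type in zip(result[OutputKeys.POLYGONS], result[OutputKeys.TEXT]):
        pts = np.asarray(poly, dtype=np.int32).reshape(4, 2)   # 4 corner points
        cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
        cv2.putText(img, str(plate_type), (int(pts[0][0]), int(pts[0][1])),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
    cv2.imwrite('license_plate_detection_vis.jpg', img)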