rujiao.lrj · yingda.chen · 2 years ago
parent commit 9d8eb5b0b3
11 changed files with 459 additions and 2 deletions
  1. +3   -0  data/test/images/license_plate_detection.jpg
  2. +1   -0  modelscope/metainfo.py
  3. +1   -0  modelscope/outputs/outputs.py
  4. +3   -0  modelscope/pipelines/builder.py
  5. +2   -0  modelscope/pipelines/cv/__init__.py
  6. +122 -0  modelscope/pipelines/cv/license_plate_detection_pipeline.py
  7. +275 -0  modelscope/pipelines/cv/ocr_utils/model_resnet18_half.py
  8. +9   -1  modelscope/pipelines/cv/ocr_utils/table_process.py
  9. +1   -1  modelscope/pipelines/cv/table_recognition_pipeline.py
  10. +1  -0  modelscope/utils/constant.py
  11. +41 -0  tests/pipelines/test_license_plate_detection.py

+3 -0  data/test/images/license_plate_detection.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:209f6ba7f15c9c34a02801b4c6ef33a979f3086702b5229d2e7975eb403c3e15
size 45819

+1 -0  modelscope/metainfo.py

@@ -157,6 +157,7 @@ class Pipelines(object):
    person_image_cartoon = 'unet-person-image-cartoon'
    ocr_detection = 'resnet18-ocr-detection'
    table_recognition = 'dla34-table-recognition'
    license_plate_detection = 'resnet18-license-plate-detection'
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognition = 'resnet101-animal-recognition'
    general_recognition = 'resnet101-general-recognition'


+1 -0  modelscope/outputs/outputs.py

@@ -62,6 +62,7 @@ TASK_OUTPUTS = {
    # }
    Tasks.ocr_detection: [OutputKeys.POLYGONS],
    Tasks.table_recognition: [OutputKeys.POLYGONS],
    Tasks.license_plate_detection: [OutputKeys.POLYGONS, OutputKeys.TEXT],

    # ocr recognition result for single sample
    # {

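Note (illustrative, not part of the diff): the new entry declares that license-plate detection returns both polygons and text. Based on the pipeline's postprocess later in this commit, a caller could read the result roughly as below; the (N, 8) corner layout is an assumption drawn from that code.

import numpy as np
from modelscope.outputs import OutputKeys

def print_plates(result):
    # POLYGONS: (N, 8) array, four (x, y) corner points per detected plate (assumed layout)
    # TEXT: N plate-type labels, taken from the model's configured Type list
    polygons = np.asarray(result[OutputKeys.POLYGONS])
    types = result[OutputKeys.TEXT]
    for poly, plate_type in zip(polygons, types):
        print(plate_type, poly.reshape(4, 2).tolist())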

+3 -0  modelscope/pipelines/builder.py

@@ -85,6 +85,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.table_recognition:
    (Pipelines.table_recognition,
     'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
    Tasks.license_plate_detection:
    (Pipelines.license_plate_detection,
     'damo/cv_resnet18_license-plate-detection_damo'),
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.feature_extraction: (Pipelines.feature_extraction,
                               'damo/pert_feature-extraction_base-test'),

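Note (usage sketch, not part of the diff): registering a default model means the task name alone is enough to build the pipeline; passing the explicit model id, as the tests in this commit do, works the same way.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Resolves to 'damo/cv_resnet18_license-plate-detection_damo' via the mapping above.
detector = pipeline(Tasks.license_plate_detection)
result = detector('data/test/images/license_plate_detection.jpg')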

+2 -0  modelscope/pipelines/cv/__init__.py

@@ -41,6 +41,7 @@ if TYPE_CHECKING:
    from .live_category_pipeline import LiveCategoryPipeline
    from .ocr_detection_pipeline import OCRDetectionPipeline
    from .ocr_recognition_pipeline import OCRRecognitionPipeline
    from .license_plate_detection_pipeline import LicensePlateDetectionPipeline
    from .table_recognition_pipeline import TableRecognitionPipeline
    from .skin_retouching_pipeline import SkinRetouchingPipeline
    from .tinynas_classification_pipeline import TinynasClassificationPipeline
@@ -109,6 +110,7 @@ else:
        'image_inpainting_pipeline': ['ImageInpaintingPipeline'],
        'ocr_detection_pipeline': ['OCRDetectionPipeline'],
        'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
        'license_plate_detection_pipeline': ['LicensePlateDetectionPipeline'],
        'table_recognition_pipeline': ['TableRecognitionPipeline'],
        'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
        'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],


+122 -0  modelscope/pipelines/cv/license_plate_detection_pipeline.py

@@ -0,0 +1,122 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.ocr_utils.model_resnet18_half import \
    LicensePlateDet
from modelscope.pipelines.cv.ocr_utils.table_process import (
    bbox_decode, bbox_post_process, decode_by_ind, get_affine_transform, nms)
from modelscope.preprocessors import load_image
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.license_plate_detection,
    module_name=Pipelines.license_plate_detection)
class LicensePlateDetection(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading model from {model_path}')

        self.cfg = Config.from_file(config_path)
        self.K = self.cfg.K
        self.car_type = self.cfg.Type
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.infer_model = LicensePlateDet()
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'state_dict' in checkpoint:
            self.infer_model.load_state_dict(checkpoint['state_dict'])
        else:
            self.infer_model.load_state_dict(checkpoint)
        self.infer_model = self.infer_model.to(self.device)
        self.infer_model.to(self.device).eval()

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]

        mean = np.array([0.408, 0.447, 0.470],
                        dtype=np.float32).reshape(1, 1, 3)
        std = np.array([0.289, 0.274, 0.278],
                       dtype=np.float32).reshape(1, 1, 3)
        height, width = img.shape[0:2]
        inp_height, inp_width = 512, 512
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(height, width) * 1.0

        trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
        resized_image = cv2.resize(img, (width, height))
        inp_image = cv2.warpAffine(
            resized_image,
            trans_input, (inp_width, inp_height),
            flags=cv2.INTER_LINEAR)
        inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)

        images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
                                                      inp_width)
        images = torch.from_numpy(images).to(self.device)
        meta = {
            'c': c,
            's': s,
            'input_height': inp_height,
            'input_width': inp_width,
            'out_height': inp_height // 4,
            'out_width': inp_width // 4
        }

        result = {'img': images, 'meta': meta}

        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        pred = self.infer_model(input['img'])
        return {'results': pred, 'meta': input['meta']}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        output = inputs['results'][0]
        meta = inputs['meta']
        hm = output['hm'].sigmoid_()
        ftype = output['ftype'].sigmoid_()
        wh = output['wh']
        reg = output['reg']

        bbox, inds = bbox_decode(hm, wh, reg=reg, K=self.K)
        car_type = decode_by_ind(ftype, inds, K=self.K).detach().cpu().numpy()
        bbox = bbox.detach().cpu().numpy()
        for i in range(bbox.shape[1]):
            bbox[0][i][9] = car_type[0][i]
        bbox = nms(bbox, 0.3)
        bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()],
                                 [meta['s']], meta['out_height'],
                                 meta['out_width'])

        res, Type = [], []
        for box in bbox[0]:
            if box[8] > 0.3:
                res.append(box[0:8])
                Type.append(self.car_type[int(box[9])])

        result = {OutputKeys.POLYGONS: np.array(res), OutputKeys.TEXT: Type}
        return result
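Note (illustrative follow-up, not part of the commit): since bbox_post_process maps the kept boxes back to original-image coordinates, the returned polygons can be drawn directly on the input image. The sketch below assumes that coordinate convention; the output filename and drawing style are arbitrary.

import cv2
import numpy as np
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detector = pipeline(Tasks.license_plate_detection,
                    model='damo/cv_resnet18_license-plate-detection_damo')
image_path = 'data/test/images/license_plate_detection.jpg'
result = detector(image_path)

img = cv2.imread(image_path)
for poly in result[OutputKeys.POLYGONS]:
    # Each polygon is 8 numbers: four corner points of one plate.
    pts = np.array(poly, dtype=np.int32).reshape(-1, 1, 2)
    cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
cv2.imwrite('license_plate_detection_vis.jpg', img)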

+275 -0  modelscope/pipelines/cv/ocr_utils/model_resnet18_half.py

@@ -0,0 +1,275 @@
# ------------------------------------------------------------------------------
# The implementation is adopted from CenterNet,
# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
# ------------------------------------------------------------------------------

import os

import torch
import torch.nn as nn
import torch.nn.functional as F

BN_MOMENTUM = 0.1


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride
        self.planes = planes

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(residual)

        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(
            planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class PoseResNet(nn.Module):

    def __init__(self, block, layers, head_conv=64, **kwargs):
        self.inplanes = 64
        self.deconv_with_bias = False
        self.heads = {'hm': 1, 'cls': 4, 'ftype': 11, 'wh': 8, 'reg': 2}

        super(PoseResNet, self).__init__()
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 256, layers[3], stride=2)

        self.adaption3 = nn.Conv2d(
            256, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption2 = nn.Conv2d(
            128, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption1 = nn.Conv2d(
            64, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.adaption0 = nn.Conv2d(
            64, 256, kernel_size=1, stride=1, padding=0, bias=False)

        self.adaptionU1 = nn.Conv2d(
            256, 256, kernel_size=1, stride=1, padding=0, bias=False)

        self.deconv_layers1 = self._make_deconv_layer(
            1,
            [256],
            [4],
        )
        self.deconv_layers2 = self._make_deconv_layer(
            1,
            [256],
            [4],
        )
        self.deconv_layers3 = self._make_deconv_layer(
            1,
            [256],
            [4],
        )
        self.deconv_layers4 = self._make_deconv_layer(
            1,
            [256],
            [4],
        )

        for head in sorted(self.heads):
            num_output = self.heads[head]
            if head_conv > 0:
                inchannel = 256
                fc = nn.Sequential(
                    nn.Conv2d(
                        inchannel,
                        head_conv,
                        kernel_size=3,
                        padding=1,
                        bias=True), nn.ReLU(inplace=True),
                    nn.Conv2d(
                        head_conv,
                        num_output,
                        kernel_size=1,
                        stride=1,
                        padding=0))
            else:
                inchannel = 256
                fc = nn.Conv2d(
                    in_channels=inchannel,
                    out_channels=num_output,
                    kernel_size=1,
                    stride=1,
                    padding=0)
            self.__setattr__(head, fc)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _get_deconv_cfg(self, deconv_kernel, index):
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        elif deconv_kernel == 7:
            padding = 3
            output_padding = 0

        return deconv_kernel, padding, output_padding

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different len(num_deconv_filters)'

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i], i)

            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x0 = self.maxpool(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        x3_ = self.deconv_layers1(x4)
        x3_ = self.adaption3(x3) + x3_

        x2_ = self.deconv_layers2(x3_)
        x2_ = self.adaption2(x2) + x2_

        x1_ = self.deconv_layers3(x2_)
        x1_ = self.adaption1(x1) + x1_

        x0_ = self.deconv_layers4(x1_) + self.adaption0(x0)
        x0_ = self.adaptionU1(x0_)

        ret = {}

        for head in self.heads:
            ret[head] = self.__getattr__(head)(x0_)
        return [ret]


resnet_spec = {
    18: (BasicBlock, [2, 2, 2, 2]),
    34: (BasicBlock, [3, 4, 6, 3]),
    50: (Bottleneck, [3, 4, 6, 3]),
    101: (Bottleneck, [3, 4, 23, 3]),
    152: (Bottleneck, [3, 8, 36, 3])
}


def LicensePlateDet(num_layers=18):
    block_class, layers = resnet_spec[num_layers]
    model = PoseResNet(block_class, layers)
    return model

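Note (shape check, not part of the commit): the network is a ResNet-18 trunk with the deeper stages capped at 256 channels, a chain of four stride-2 deconvolutions merged with 1x1 adaption convs on the skip features, and five prediction heads. For a 512x512 input the heads come out at 1/4 resolution, which matches out_height/out_width = 128 in the pipeline's preprocess meta. A quick sketch to confirm the shapes with random weights:

import torch

from modelscope.pipelines.cv.ocr_utils.model_resnet18_half import LicensePlateDet

model = LicensePlateDet(num_layers=18).eval()
with torch.no_grad():
    outputs = model(torch.randn(1, 3, 512, 512))[0]
for name, tensor in outputs.items():
    print(name, tuple(tensor.shape))
# hm: (1, 1, 128, 128), cls: (1, 4, 128, 128), ftype: (1, 11, 128, 128),
# wh: (1, 8, 128, 128), reg: (1, 2, 128, 128)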
+9 -1  modelscope/pipelines/cv/ocr_utils/table_process.py

@@ -129,6 +129,14 @@ def _topk(scores, K=40):
    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs


def decode_by_ind(heat, inds, K=100):
    batch, cat, height, width = heat.size()
    score = _tranpose_and_gather_feat(heat, inds)
    score = score.view(batch, K, cat)
    _, Type = torch.max(score, 2)
    return Type


def bbox_decode(heat, wh, reg=None, K=100):
    batch, cat, height, width = heat.size()

@@ -163,7 +171,7 @@ def bbox_decode(heat, wh, reg=None, K=100):
    )
    detections = torch.cat([bboxes, scores, clses], dim=2)

-    return detections, keep
+    return detections, inds


def gbox_decode(mk, st_reg, reg=None, K=400):

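Note (toy sketch, not part of the diff): bbox_decode now returns the flattened top-K peak indices instead of the keep mask, so decode_by_ind can gather the 'ftype' scores at exactly those peaks and take the per-peak argmax. The snippet below re-implements that gather-and-argmax idea without the _tranpose_and_gather_feat helper, purely for illustration.

import torch

batch, cat, height, width, K = 1, 11, 4, 4, 3
heat = torch.rand(batch, cat, height, width)   # e.g. the 'ftype' head
inds = torch.tensor([[0, 5, 10]])              # K flattened peak indices from bbox_decode

flat = heat.view(batch, cat, height * width)                     # (B, C, H*W)
score = flat.gather(2, inds.unsqueeze(1).expand(batch, cat, K))  # (B, C, K)
_, plate_type = torch.max(score.permute(0, 2, 1), 2)             # (B, K): class per peak
print(plate_type)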

+1 -1  modelscope/pipelines/cv/table_recognition_pipeline.py

@@ -50,7 +50,7 @@ class TableRecognitionPipeline(Pipeline):
            self.infer_model.load_state_dict(checkpoint)

    def preprocess(self, input: Input) -> Dict[str, Any]:
-        img = LoadImage.convert_to_ndarray(input)
+        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]

        mean = np.array([0.408, 0.447, 0.470],
                        dtype=np.float32).reshape(1, 1, 3)

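Note (illustrative, not part of the diff): this one-line change reverses the channel order of the loaded image, the same slice the new license-plate pipeline uses, presumably so the mean/std normalization and the trained weights see the channel order they expect. The slice itself simply flips the last axis:

import numpy as np

img = np.zeros((2, 2, 3), dtype=np.uint8)
img[..., 0] = 255                # only the first channel is set
flipped = img[:, :, ::-1]        # channel order reversed
print(img[0, 0], flipped[0, 0])  # [255 0 0] -> [0 0 255]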

+1 -0  modelscope/utils/constant.py

@@ -17,6 +17,7 @@ class CVTasks(object):
    ocr_detection = 'ocr-detection'
    ocr_recognition = 'ocr-recognition'
    table_recognition = 'table-recognition'
    license_plate_detection = 'license-plate-detection'

    # human face body related
    animal_recognition = 'animal-recognition'


+41 -0  tests/pipelines/test_license_plate_detection.py

@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class LicensePlateDetectionTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet18_license-plate-detection_damo'
        self.test_image = 'data/test/images/license_plate_detection.jpg'
        self.task = Tasks.license_plate_detection

    def pipeline_inference(self, pipe: Pipeline, input_location: str):
        result = pipe(input_location)
        print('license plate detection results: ')
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        license_plate_detection = pipeline(
            Tasks.license_plate_detection, model=self.model_id)
        self.pipeline_inference(license_plate_detection, self.test_image)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        license_plate_detection = pipeline(Tasks.license_plate_detection)
        self.pipeline_inference(license_plate_detection, self.test_image)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()
