Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10917315master^2
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:209f6ba7f15c9c34a02801b4c6ef33a979f3086702b5229d2e7975eb403c3e15 | |||||
| size 45819 | |||||
| @@ -157,6 +157,7 @@ class Pipelines(object): | |||||
| person_image_cartoon = 'unet-person-image-cartoon' | person_image_cartoon = 'unet-person-image-cartoon' | ||||
| ocr_detection = 'resnet18-ocr-detection' | ocr_detection = 'resnet18-ocr-detection' | ||||
| table_recognition = 'dla34-table-recognition' | table_recognition = 'dla34-table-recognition' | ||||
| license_plate_detection = 'resnet18-license-plate-detection' | |||||
| action_recognition = 'TAdaConv_action-recognition' | action_recognition = 'TAdaConv_action-recognition' | ||||
| animal_recognition = 'resnet101-animal-recognition' | animal_recognition = 'resnet101-animal-recognition' | ||||
| general_recognition = 'resnet101-general-recognition' | general_recognition = 'resnet101-general-recognition' | ||||
| @@ -62,6 +62,7 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.ocr_detection: [OutputKeys.POLYGONS], | Tasks.ocr_detection: [OutputKeys.POLYGONS], | ||||
| Tasks.table_recognition: [OutputKeys.POLYGONS], | Tasks.table_recognition: [OutputKeys.POLYGONS], | ||||
| Tasks.license_plate_detection: [OutputKeys.POLYGONS, OutputKeys.TEXT], | |||||
| # ocr recognition result for single sample | # ocr recognition result for single sample | ||||
| # { | # { | ||||
| @@ -85,6 +85,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.table_recognition: | Tasks.table_recognition: | ||||
| (Pipelines.table_recognition, | (Pipelines.table_recognition, | ||||
| 'damo/cv_dla34_table-structure-recognition_cycle-centernet'), | 'damo/cv_dla34_table-structure-recognition_cycle-centernet'), | ||||
| Tasks.license_plate_detection: | |||||
| (Pipelines.license_plate_detection, | |||||
| 'damo/cv_resnet18_license-plate-detection_damo'), | |||||
| Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'), | Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'), | ||||
| Tasks.feature_extraction: (Pipelines.feature_extraction, | Tasks.feature_extraction: (Pipelines.feature_extraction, | ||||
| 'damo/pert_feature-extraction_base-test'), | 'damo/pert_feature-extraction_base-test'), | ||||
| @@ -41,6 +41,7 @@ if TYPE_CHECKING: | |||||
| from .live_category_pipeline import LiveCategoryPipeline | from .live_category_pipeline import LiveCategoryPipeline | ||||
| from .ocr_detection_pipeline import OCRDetectionPipeline | from .ocr_detection_pipeline import OCRDetectionPipeline | ||||
| from .ocr_recognition_pipeline import OCRRecognitionPipeline | from .ocr_recognition_pipeline import OCRRecognitionPipeline | ||||
| from .license_plate_detection_pipeline import LicensePlateDetectionPipeline | |||||
| from .table_recognition_pipeline import TableRecognitionPipeline | from .table_recognition_pipeline import TableRecognitionPipeline | ||||
| from .skin_retouching_pipeline import SkinRetouchingPipeline | from .skin_retouching_pipeline import SkinRetouchingPipeline | ||||
| from .tinynas_classification_pipeline import TinynasClassificationPipeline | from .tinynas_classification_pipeline import TinynasClassificationPipeline | ||||
| @@ -109,6 +110,7 @@ else: | |||||
| 'image_inpainting_pipeline': ['ImageInpaintingPipeline'], | 'image_inpainting_pipeline': ['ImageInpaintingPipeline'], | ||||
| 'ocr_detection_pipeline': ['OCRDetectionPipeline'], | 'ocr_detection_pipeline': ['OCRDetectionPipeline'], | ||||
| 'ocr_recognition_pipeline': ['OCRRecognitionPipeline'], | 'ocr_recognition_pipeline': ['OCRRecognitionPipeline'], | ||||
| 'license_plate_detection_pipeline': ['LicensePlateDetectionPipeline'], | |||||
| 'table_recognition_pipeline': ['TableRecognitionPipeline'], | 'table_recognition_pipeline': ['TableRecognitionPipeline'], | ||||
| 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], | 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], | ||||
| 'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], | 'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], | ||||
| @@ -0,0 +1,122 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import math | |||||
| import os.path as osp | |||||
| from typing import Any, Dict | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import PIL | |||||
| import torch | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.pipelines.cv.ocr_utils.model_resnet18_half import \ | |||||
| LicensePlateDet | |||||
| from modelscope.pipelines.cv.ocr_utils.table_process import ( | |||||
| bbox_decode, bbox_post_process, decode_by_ind, get_affine_transform, nms) | |||||
| from modelscope.preprocessors import load_image | |||||
| from modelscope.preprocessors.image import LoadImage | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.license_plate_detection, | |||||
| module_name=Pipelines.license_plate_detection) | |||||
| class LicensePlateDetection(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model, **kwargs) | |||||
| model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) | |||||
| config_path = osp.join(self.model, ModelFile.CONFIGURATION) | |||||
| logger.info(f'loading model from {model_path}') | |||||
| self.cfg = Config.from_file(config_path) | |||||
| self.K = self.cfg.K | |||||
| self.car_type = self.cfg.Type | |||||
| self.device = torch.device( | |||||
| 'cuda' if torch.cuda.is_available() else 'cpu') | |||||
| self.infer_model = LicensePlateDet() | |||||
| checkpoint = torch.load(model_path, map_location=self.device) | |||||
| if 'state_dict' in checkpoint: | |||||
| self.infer_model.load_state_dict(checkpoint['state_dict']) | |||||
| else: | |||||
| self.infer_model.load_state_dict(checkpoint) | |||||
| self.infer_model = self.infer_model.to(self.device) | |||||
| self.infer_model.to(self.device).eval() | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| img = LoadImage.convert_to_ndarray(input)[:, :, ::-1] | |||||
| mean = np.array([0.408, 0.447, 0.470], | |||||
| dtype=np.float32).reshape(1, 1, 3) | |||||
| std = np.array([0.289, 0.274, 0.278], | |||||
| dtype=np.float32).reshape(1, 1, 3) | |||||
| height, width = img.shape[0:2] | |||||
| inp_height, inp_width = 512, 512 | |||||
| c = np.array([width / 2., height / 2.], dtype=np.float32) | |||||
| s = max(height, width) * 1.0 | |||||
| trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height]) | |||||
| resized_image = cv2.resize(img, (width, height)) | |||||
| inp_image = cv2.warpAffine( | |||||
| resized_image, | |||||
| trans_input, (inp_width, inp_height), | |||||
| flags=cv2.INTER_LINEAR) | |||||
| inp_image = ((inp_image / 255. - mean) / std).astype(np.float32) | |||||
| images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height, | |||||
| inp_width) | |||||
| images = torch.from_numpy(images).to(self.device) | |||||
| meta = { | |||||
| 'c': c, | |||||
| 's': s, | |||||
| 'input_height': inp_height, | |||||
| 'input_width': inp_width, | |||||
| 'out_height': inp_height // 4, | |||||
| 'out_width': inp_width // 4 | |||||
| } | |||||
| result = {'img': images, 'meta': meta} | |||||
| return result | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| pred = self.infer_model(input['img']) | |||||
| return {'results': pred, 'meta': input['meta']} | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| output = inputs['results'][0] | |||||
| meta = inputs['meta'] | |||||
| hm = output['hm'].sigmoid_() | |||||
| ftype = output['ftype'].sigmoid_() | |||||
| wh = output['wh'] | |||||
| reg = output['reg'] | |||||
| bbox, inds = bbox_decode(hm, wh, reg=reg, K=self.K) | |||||
| car_type = decode_by_ind(ftype, inds, K=self.K).detach().cpu().numpy() | |||||
| bbox = bbox.detach().cpu().numpy() | |||||
| for i in range(bbox.shape[1]): | |||||
| bbox[0][i][9] = car_type[0][i] | |||||
| bbox = nms(bbox, 0.3) | |||||
| bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()], | |||||
| [meta['s']], meta['out_height'], | |||||
| meta['out_width']) | |||||
| res, Type = [], [] | |||||
| for box in bbox[0]: | |||||
| if box[8] > 0.3: | |||||
| res.append(box[0:8]) | |||||
| Type.append(self.car_type[int(box[9])]) | |||||
| result = {OutputKeys.POLYGONS: np.array(res), OutputKeys.TEXT: Type} | |||||
| return result | |||||
| @@ -0,0 +1,275 @@ | |||||
| # ------------------------------------------------------------------------------ | |||||
| # The implementation is adopted from CenterNet, | |||||
| # made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git | |||||
| # ------------------------------------------------------------------------------ | |||||
| import os | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| BN_MOMENTUM = 0.1 | |||||
| class BasicBlock(nn.Module): | |||||
| expansion = 1 | |||||
| def __init__(self, inplanes, planes, stride=1, downsample=None): | |||||
| super(BasicBlock, self).__init__() | |||||
| self.conv1 = nn.Conv2d( | |||||
| inplanes, planes, kernel_size=3, stride=stride, padding=1) | |||||
| self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) | |||||
| self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||||
| self.downsample = downsample | |||||
| self.stride = stride | |||||
| self.planes = planes | |||||
| def forward(self, x): | |||||
| residual = x | |||||
| out = self.conv1(x) | |||||
| out = self.bn1(out) | |||||
| out = self.relu(out) | |||||
| out = self.conv2(out) | |||||
| out = self.bn2(out) | |||||
| if self.downsample is not None: | |||||
| residual = self.downsample(residual) | |||||
| out += residual | |||||
| out = self.relu(out) | |||||
| return out | |||||
| class Bottleneck(nn.Module): | |||||
| expansion = 4 | |||||
| def __init__(self, inplanes, planes, stride=1, downsample=None): | |||||
| super(Bottleneck, self).__init__() | |||||
| self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) | |||||
| self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||||
| self.conv2 = nn.Conv2d( | |||||
| planes, | |||||
| planes, | |||||
| kernel_size=3, | |||||
| stride=stride, | |||||
| padding=1, | |||||
| bias=False) | |||||
| self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||||
| self.conv3 = nn.Conv2d( | |||||
| planes, planes * self.expansion, kernel_size=1, bias=False) | |||||
| self.bn3 = nn.BatchNorm2d( | |||||
| planes * self.expansion, momentum=BN_MOMENTUM) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| self.downsample = downsample | |||||
| self.stride = stride | |||||
| def forward(self, x): | |||||
| residual = x | |||||
| out = self.conv1(x) | |||||
| out = self.bn1(out) | |||||
| out = self.relu(out) | |||||
| out = self.conv2(out) | |||||
| out = self.bn2(out) | |||||
| out = self.relu(out) | |||||
| out = self.conv3(out) | |||||
| out = self.bn3(out) | |||||
| if self.downsample is not None: | |||||
| residual = self.downsample(x) | |||||
| out += residual | |||||
| out = self.relu(out) | |||||
| return out | |||||
| class PoseResNet(nn.Module): | |||||
| def __init__(self, block, layers, head_conv=64, **kwargs): | |||||
| self.inplanes = 64 | |||||
| self.deconv_with_bias = False | |||||
| self.heads = {'hm': 1, 'cls': 4, 'ftype': 11, 'wh': 8, 'reg': 2} | |||||
| super(PoseResNet, self).__init__() | |||||
| self.conv1 = nn.Conv2d( | |||||
| 3, 64, kernel_size=7, stride=2, padding=3, bias=False) | |||||
| self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||||
| self.layer1 = self._make_layer(block, 64, layers[0], stride=2) | |||||
| self.layer2 = self._make_layer(block, 128, layers[1], stride=2) | |||||
| self.layer3 = self._make_layer(block, 256, layers[2], stride=2) | |||||
| self.layer4 = self._make_layer(block, 256, layers[3], stride=2) | |||||
| self.adaption3 = nn.Conv2d( | |||||
| 256, 256, kernel_size=1, stride=1, padding=0, bias=False) | |||||
| self.adaption2 = nn.Conv2d( | |||||
| 128, 256, kernel_size=1, stride=1, padding=0, bias=False) | |||||
| self.adaption1 = nn.Conv2d( | |||||
| 64, 256, kernel_size=1, stride=1, padding=0, bias=False) | |||||
| self.adaption0 = nn.Conv2d( | |||||
| 64, 256, kernel_size=1, stride=1, padding=0, bias=False) | |||||
| self.adaptionU1 = nn.Conv2d( | |||||
| 256, 256, kernel_size=1, stride=1, padding=0, bias=False) | |||||
| self.deconv_layers1 = self._make_deconv_layer( | |||||
| 1, | |||||
| [256], | |||||
| [4], | |||||
| ) | |||||
| self.deconv_layers2 = self._make_deconv_layer( | |||||
| 1, | |||||
| [256], | |||||
| [4], | |||||
| ) | |||||
| self.deconv_layers3 = self._make_deconv_layer( | |||||
| 1, | |||||
| [256], | |||||
| [4], | |||||
| ) | |||||
| self.deconv_layers4 = self._make_deconv_layer( | |||||
| 1, | |||||
| [256], | |||||
| [4], | |||||
| ) | |||||
| for head in sorted(self.heads): | |||||
| num_output = self.heads[head] | |||||
| if head_conv > 0: | |||||
| inchannel = 256 | |||||
| fc = nn.Sequential( | |||||
| nn.Conv2d( | |||||
| inchannel, | |||||
| head_conv, | |||||
| kernel_size=3, | |||||
| padding=1, | |||||
| bias=True), nn.ReLU(inplace=True), | |||||
| nn.Conv2d( | |||||
| head_conv, | |||||
| num_output, | |||||
| kernel_size=1, | |||||
| stride=1, | |||||
| padding=0)) | |||||
| else: | |||||
| inchannel = 256 | |||||
| fc = nn.Conv2d( | |||||
| in_channels=inchannel, | |||||
| out_channels=num_output, | |||||
| kernel_size=1, | |||||
| stride=1, | |||||
| padding=0) | |||||
| self.__setattr__(head, fc) | |||||
| def _make_layer(self, block, planes, blocks, stride=1): | |||||
| downsample = None | |||||
| if stride != 1 or self.inplanes != planes * block.expansion: | |||||
| downsample = nn.Sequential( | |||||
| nn.Conv2d( | |||||
| self.inplanes, | |||||
| planes * block.expansion, | |||||
| kernel_size=1, | |||||
| stride=stride, | |||||
| bias=False), | |||||
| nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), | |||||
| ) | |||||
| layers = [] | |||||
| layers.append(block(self.inplanes, planes, stride, downsample)) | |||||
| self.inplanes = planes * block.expansion | |||||
| for i in range(1, blocks): | |||||
| layers.append(block(self.inplanes, planes)) | |||||
| return nn.Sequential(*layers) | |||||
| def _get_deconv_cfg(self, deconv_kernel, index): | |||||
| if deconv_kernel == 4: | |||||
| padding = 1 | |||||
| output_padding = 0 | |||||
| elif deconv_kernel == 3: | |||||
| padding = 1 | |||||
| output_padding = 1 | |||||
| elif deconv_kernel == 2: | |||||
| padding = 0 | |||||
| output_padding = 0 | |||||
| elif deconv_kernel == 7: | |||||
| padding = 3 | |||||
| output_padding = 0 | |||||
| return deconv_kernel, padding, output_padding | |||||
| def _make_deconv_layer(self, num_layers, num_filters, num_kernels): | |||||
| assert num_layers == len(num_filters), \ | |||||
| 'ERROR: num_deconv_layers is different len(num_deconv_filters)' | |||||
| assert num_layers == len(num_kernels), \ | |||||
| 'ERROR: num_deconv_layers is different len(num_deconv_filters)' | |||||
| layers = [] | |||||
| for i in range(num_layers): | |||||
| kernel, padding, output_padding = \ | |||||
| self._get_deconv_cfg(num_kernels[i], i) | |||||
| planes = num_filters[i] | |||||
| layers.append( | |||||
| nn.ConvTranspose2d( | |||||
| in_channels=self.inplanes, | |||||
| out_channels=planes, | |||||
| kernel_size=kernel, | |||||
| stride=2, | |||||
| padding=padding, | |||||
| output_padding=output_padding, | |||||
| bias=self.deconv_with_bias)) | |||||
| layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) | |||||
| layers.append(nn.ReLU(inplace=True)) | |||||
| self.inplanes = planes | |||||
| return nn.Sequential(*layers) | |||||
| def forward(self, x): | |||||
| x = self.conv1(x) | |||||
| x = self.bn1(x) | |||||
| x = self.relu(x) | |||||
| x0 = self.maxpool(x) | |||||
| x1 = self.layer1(x0) | |||||
| x2 = self.layer2(x1) | |||||
| x3 = self.layer3(x2) | |||||
| x4 = self.layer4(x3) | |||||
| x3_ = self.deconv_layers1(x4) | |||||
| x3_ = self.adaption3(x3) + x3_ | |||||
| x2_ = self.deconv_layers2(x3_) | |||||
| x2_ = self.adaption2(x2) + x2_ | |||||
| x1_ = self.deconv_layers3(x2_) | |||||
| x1_ = self.adaption1(x1) + x1_ | |||||
| x0_ = self.deconv_layers4(x1_) + self.adaption0(x0) | |||||
| x0_ = self.adaptionU1(x0_) | |||||
| ret = {} | |||||
| for head in self.heads: | |||||
| ret[head] = self.__getattr__(head)(x0_) | |||||
| return [ret] | |||||
| resnet_spec = { | |||||
| 18: (BasicBlock, [2, 2, 2, 2]), | |||||
| 34: (BasicBlock, [3, 4, 6, 3]), | |||||
| 50: (Bottleneck, [3, 4, 6, 3]), | |||||
| 101: (Bottleneck, [3, 4, 23, 3]), | |||||
| 152: (Bottleneck, [3, 8, 36, 3]) | |||||
| } | |||||
| def LicensePlateDet(num_layers=18): | |||||
| block_class, layers = resnet_spec[num_layers] | |||||
| model = PoseResNet(block_class, layers) | |||||
| return model | |||||
| @@ -129,6 +129,14 @@ def _topk(scores, K=40): | |||||
| return topk_score, topk_inds, topk_clses, topk_ys, topk_xs | return topk_score, topk_inds, topk_clses, topk_ys, topk_xs | ||||
| def decode_by_ind(heat, inds, K=100): | |||||
| batch, cat, height, width = heat.size() | |||||
| score = _tranpose_and_gather_feat(heat, inds) | |||||
| score = score.view(batch, K, cat) | |||||
| _, Type = torch.max(score, 2) | |||||
| return Type | |||||
| def bbox_decode(heat, wh, reg=None, K=100): | def bbox_decode(heat, wh, reg=None, K=100): | ||||
| batch, cat, height, width = heat.size() | batch, cat, height, width = heat.size() | ||||
| @@ -163,7 +171,7 @@ def bbox_decode(heat, wh, reg=None, K=100): | |||||
| ) | ) | ||||
| detections = torch.cat([bboxes, scores, clses], dim=2) | detections = torch.cat([bboxes, scores, clses], dim=2) | ||||
| return detections, keep | |||||
| return detections, inds | |||||
| def gbox_decode(mk, st_reg, reg=None, K=400): | def gbox_decode(mk, st_reg, reg=None, K=400): | ||||
| @@ -50,7 +50,7 @@ class TableRecognitionPipeline(Pipeline): | |||||
| self.infer_model.load_state_dict(checkpoint) | self.infer_model.load_state_dict(checkpoint) | ||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | def preprocess(self, input: Input) -> Dict[str, Any]: | ||||
| img = LoadImage.convert_to_ndarray(input) | |||||
| img = LoadImage.convert_to_ndarray(input)[:, :, ::-1] | |||||
| mean = np.array([0.408, 0.447, 0.470], | mean = np.array([0.408, 0.447, 0.470], | ||||
| dtype=np.float32).reshape(1, 1, 3) | dtype=np.float32).reshape(1, 1, 3) | ||||
| @@ -17,6 +17,7 @@ class CVTasks(object): | |||||
| ocr_detection = 'ocr-detection' | ocr_detection = 'ocr-detection' | ||||
| ocr_recognition = 'ocr-recognition' | ocr_recognition = 'ocr-recognition' | ||||
| table_recognition = 'table-recognition' | table_recognition = 'table-recognition' | ||||
| license_plate_detection = 'license-plate-detection' | |||||
| # human face body related | # human face body related | ||||
| animal_recognition = 'animal-recognition' | animal_recognition = 'animal-recognition' | ||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class LicensePlateDectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| def setUp(self) -> None: | |||||
| self.model_id = 'damo/cv_resnet18_license-plate-detection_damo' | |||||
| self.test_image = 'data/test/images/license_plate_detection.jpg' | |||||
| self.task = Tasks.license_plate_detection | |||||
| def pipeline_inference(self, pipe: Pipeline, input_location: str): | |||||
| result = pipe(input_location) | |||||
| print('license plate recognition results: ') | |||||
| print(result) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| license_plate_detection = pipeline( | |||||
| Tasks.license_plate_detection, model=self.model_id) | |||||
| self.pipeline_inference(license_plate_detection, self.test_image) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_modelhub_default_model(self): | |||||
| license_plate_detection = pipeline(Tasks.license_plate_detection) | |||||
| self.pipeline_inference(license_plate_detection, self.test_image) | |||||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||||
| def test_demo_compatibility(self): | |||||
| self.compatibility_check() | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||