From 9d43823f366360a7b98ddad6334a006e2f6dbb29 Mon Sep 17 00:00:00 2001 From: ly261666 Date: Mon, 5 Dec 2022 12:01:26 +0800 Subject: [PATCH] [to #42322933] add TinyMogFace face detector Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10792564 --- modelscope/metainfo.py | 2 + .../models/cv/face_detection/__init__.py | 3 +- .../cv/face_detection/scrfd/__init__.py | 1 + .../mmdet_patch/models/backbones/__init__.py | 3 +- .../mmdet_patch/models/backbones/mobilenet.py | 99 ++++++++++++ .../mmdet_patch/models/detectors/__init__.py | 3 +- .../mmdet_patch/models/detectors/tinymog.py | 148 ++++++++++++++++++ .../cv/face_detection/scrfd/tinymog_detect.py | 67 ++++++++ .../pipelines/cv/face_detection_pipeline.py | 11 +- .../pipelines/test_tinymog_face_detection.py | 57 +++++++ 10 files changed, 389 insertions(+), 5 deletions(-) create mode 100644 modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/mobilenet.py create mode 100755 modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/tinymog.py create mode 100644 modelscope/models/cv/face_detection/scrfd/tinymog_detect.py create mode 100644 tests/pipelines/test_tinymog_face_detection.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 2a05035a..50f8ac34 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -47,6 +47,7 @@ class Models(object): ulfd = 'ulfd' arcface = 'arcface' facemask = 'facemask' + tinymog = 'tinymog' video_inpainting = 'video-inpainting' human_wholebody_keypoint = 'human-wholebody-keypoint' hand_static = 'hand-static' @@ -182,6 +183,7 @@ class Pipelines(object): face_detection = 'resnet-face-detection-scrfd10gkps' card_detection = 'resnet-card-detection-scrfd34gkps' ulfd_face_detection = 'manual-face-detection-ulfd' + tinymog_face_detection = 'manual-face-detection-tinymog' facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' retina_face_detection = 'resnet50-face-detection-retinaface' mog_face_detection = 'resnet101-face-detection-cvpr22papermogface' diff --git a/modelscope/models/cv/face_detection/__init__.py b/modelscope/models/cv/face_detection/__init__.py index 27d1bd4c..85d2e5fb 100644 --- a/modelscope/models/cv/face_detection/__init__.py +++ b/modelscope/models/cv/face_detection/__init__.py @@ -9,13 +9,14 @@ if TYPE_CHECKING: from .retinaface import RetinaFaceDetection from .ulfd_slim import UlfdFaceDetector from .scrfd import ScrfdDetect + from .scrfd import TinyMogDetect else: _import_structure = { 'ulfd_slim': ['UlfdFaceDetector'], 'retinaface': ['RetinaFaceDetection'], 'mtcnn': ['MtcnnFaceDetector'], 'mogface': ['MogFaceDetector'], - 'scrfd': ['ScrfdDetect'] + 'scrfd': ['TinyMogDetect', 'ScrfdDetect'], } import sys diff --git a/modelscope/models/cv/face_detection/scrfd/__init__.py b/modelscope/models/cv/face_detection/scrfd/__init__.py index 92f81f7a..e1d096a3 100644 --- a/modelscope/models/cv/face_detection/scrfd/__init__.py +++ b/modelscope/models/cv/face_detection/scrfd/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .scrfd_detect import ScrfdDetect +from .tinymog_detect import TinyMogDetect diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py index 5c3b190e..653bd3ef 100755 --- a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py @@ -2,6 +2,7 @@ The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones """ +from .mobilenet import MobileNetV1 from .resnet import ResNetV1e -__all__ = ['ResNetV1e'] +__all__ = ['ResNetV1e', 'MobileNetV1'] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/mobilenet.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/mobilenet.py new file mode 100644 index 00000000..600f0434 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/mobilenet.py @@ -0,0 +1,99 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones/mobilenet.py +""" + +import torch +import torch.nn as nn +from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer, + constant_init, kaiming_init) +from mmcv.runner import load_checkpoint +from mmdet.models.builder import BACKBONES +from mmdet.utils import get_root_logger +from torch.nn.modules.batchnorm import _BatchNorm + + +@BACKBONES.register_module() +class MobileNetV1(nn.Module): + + def __init__(self, + in_channels=3, + block_cfg=None, + num_stages=4, + out_indices=(0, 1, 2, 3)): + super(MobileNetV1, self).__init__() + self.out_indices = out_indices + + def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), nn.ReLU(inplace=True)) + + def conv_dw(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.ReLU(inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True), + ) + + if block_cfg is None: + stage_planes = [8, 16, 32, 64, 128, 256] + stage_blocks = [2, 4, 4, 2] + else: + stage_planes = block_cfg['stage_planes'] + stage_blocks = block_cfg['stage_blocks'] + assert len(stage_planes) == 6 + assert len(stage_blocks) == 4 + self.stem = nn.Sequential( + conv_bn(3, stage_planes[0], 2), + conv_dw(stage_planes[0], stage_planes[1], 1), + ) + self.stage_layers = [] + for i, num_blocks in enumerate(stage_blocks): + _layers = [] + for n in range(num_blocks): + if n == 0: + _layer = conv_dw(stage_planes[i + 1], stage_planes[i + 2], + 2) + else: + _layer = conv_dw(stage_planes[i + 2], stage_planes[i + 2], + 1) + _layers.append(_layer) + + _block = nn.Sequential(*_layers) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, _block) + self.stage_layers.append(layer_name) + + def forward(self, x): + output = [] + x = self.stem(x) + for i, layer_name in enumerate(self.stage_layers): + stage_layer = getattr(self, layer_name) + x = stage_layer(x) + if i in self.out_indices: + output.append(x) + + return tuple(output) + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py index 7935606a..c1ed8f16 100755 --- a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py @@ -3,5 +3,6 @@ The implementation here is modified based on insightface, originally MIT license https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors """ from .scrfd import SCRFD +from .tinymog import TinyMog -__all__ = ['SCRFD'] +__all__ = ['SCRFD', 'TinyMog'] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/tinymog.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/tinymog.py new file mode 100755 index 00000000..a0b51753 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/tinymog.py @@ -0,0 +1,148 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors/scrfd.py +""" +import torch +from mmdet.models.builder import DETECTORS +from mmdet.models.detectors.single_stage import SingleStageDetector + +from ....mmdet_patch.core.bbox import bbox2result + + +@DETECTORS.register_module() +class TinyMog(SingleStageDetector): + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(TinyMog, self).__init__(backbone, neck, bbox_head, train_cfg, + test_cfg, pretrained) + + def forward_train(self, + img, + img_metas, + gt_bboxes, + gt_labels, + gt_keypointss=None, + gt_bboxes_ignore=None): + """ + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + :class:`mmdet.datasets.pipelines.Collect`. + gt_bboxes (list[Tensor]): Each item are the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): Class indices corresponding to each box + gt_bboxes_ignore (None | list[Tensor]): Specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + super(SingleStageDetector, self).forward_train(img, img_metas) + x = self.extract_feat(img) + losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, + gt_labels, gt_keypointss, + gt_bboxes_ignore) + return losses + + def simple_test(self, + img, + img_metas, + rescale=False, + repeat_head=1, + output_kps_var=0, + output_results=1): + """Test function without test time augmentation. + + Args: + imgs (list[torch.Tensor]): List of multiple images + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + repeat_head (int): repeat inference times in head + output_kps_var (int): whether output kps var to calculate quality + output_results (int): 0: nothing 1: bbox 2: both bbox and kps + + Returns: + list[list[np.ndarray]]: BBox results of each image and classes. + The outer list corresponds to each image. The inner list + corresponds to each class. + """ + x = self.extract_feat(img) + assert repeat_head >= 1 + kps_out0 = [] + kps_out1 = [] + kps_out2 = [] + for i in range(repeat_head): + outs = self.bbox_head(x) + kps_out0 += [outs[2][0].detach().cpu().numpy()] + kps_out1 += [outs[2][1].detach().cpu().numpy()] + kps_out2 += [outs[2][2].detach().cpu().numpy()] + if output_kps_var: + var0 = np.var(np.vstack(kps_out0), axis=0).mean() + var1 = np.var(np.vstack(kps_out1), axis=0).mean() + var2 = np.var(np.vstack(kps_out2), axis=0).mean() + var = np.mean([var0, var1, var2]) + else: + var = None + + if output_results > 0: + if torch.onnx.is_in_onnx_export(): + cls_score, bbox_pred, kps_pred = outs + for c in cls_score: + print(c.shape) + for c in bbox_pred: + print(c.shape) + if self.bbox_head.use_kps: + for c in kps_pred: + print(c.shape) + return (cls_score, bbox_pred, kps_pred) + else: + return (cls_score, bbox_pred) + bbox_list = self.bbox_head.get_bboxes( + *outs, img_metas, rescale=rescale) + + # return kps if use_kps + if len(bbox_list[0]) == 2: + bbox_results = [ + bbox2result(det_bboxes, det_labels, + self.bbox_head.num_classes) + for det_bboxes, det_labels in bbox_list + ] + elif len(bbox_list[0]) == 3: + if output_results == 2: + bbox_results = [ + bbox2result( + det_bboxes, + det_labels, + self.bbox_head.num_classes, + kps=det_kps, + num_kps=self.bbox_head.NK) + for det_bboxes, det_labels, det_kps in bbox_list + ] + elif output_results == 1: + bbox_results = [ + bbox2result(det_bboxes, det_labels, + self.bbox_head.num_classes) + for det_bboxes, det_labels, _ in bbox_list + ] + else: + bbox_results = None + if var is not None: + return bbox_results, var + else: + return bbox_results + + def feature_test(self, img): + x = self.extract_feat(img) + outs = self.bbox_head(x) + return outs diff --git a/modelscope/models/cv/face_detection/scrfd/tinymog_detect.py b/modelscope/models/cv/face_detection/scrfd/tinymog_detect.py new file mode 100644 index 00000000..17d61871 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/tinymog_detect.py @@ -0,0 +1,67 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from copy import deepcopy +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['TinyMogDetect'] + + +@MODELS.register_module(Tasks.face_detection, module_name=Models.tinymog) +class TinyMogDetect(TorchModel): + + def __init__(self, model_dir, *args, **kwargs): + """ + initialize the tinymog face detection model from the `model_dir` path. + """ + super().__init__(model_dir) + from mmcv import Config + from mmcv.parallel import MMDataParallel + from mmcv.runner import load_checkpoint + from mmdet.models import build_detector + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD + cfg = Config.fromfile(osp.join(model_dir, 'mmcv_tinymog.py')) + ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + cfg.model.test_cfg.score_thr = kwargs.get('score_thr', 0.3) + detector = build_detector(cfg.model) + logger.info(f'loading model from {ckpt_path}') + load_checkpoint(detector, ckpt_path, map_location='cpu') + detector = MMDataParallel(detector) + detector.eval() + self.detector = detector + logger.info('load model done') + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = self.detector( + return_loss=False, + rescale=True, + img=[input['img'][0].unsqueeze(0)], + img_metas=[[dict(input['img_metas'][0].data)]], + output_results=2) + assert result is not None + result = result[0][0] + bboxes = result[:, :4].tolist() + kpss = result[:, 5:].tolist() + scores = result[:, 4].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: kpss + } + + def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: + return input diff --git a/modelscope/pipelines/cv/face_detection_pipeline.py b/modelscope/pipelines/cv/face_detection_pipeline.py index 608567a4..3b17d830 100644 --- a/modelscope/pipelines/cv/face_detection_pipeline.py +++ b/modelscope/pipelines/cv/face_detection_pipeline.py @@ -8,11 +8,12 @@ import PIL import torch from modelscope.metainfo import Pipelines -from modelscope.models.cv.face_detection import ScrfdDetect +from modelscope.models.cv.face_detection import ScrfdDetect, TinyMogDetect from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import LoadImage +from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger @@ -30,7 +31,13 @@ class FaceDetectionPipeline(Pipeline): model: model id on modelscope hub. """ super().__init__(model=model, **kwargs) - detector = ScrfdDetect(model_dir=model, **kwargs) + config_path = osp.join(model, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + cfg_model = getattr(cfg, 'model', None) + if cfg_model is None: + detector = ScrfdDetect(model_dir=model, **kwargs) + elif cfg_model.type == 'tinymog': + detector = self.model.to(self.device) self.detector = detector def preprocess(self, input: Input) -> Dict[str, Any]: diff --git a/tests/pipelines/test_tinymog_face_detection.py b/tests/pipelines/test_tinymog_face_detection.py new file mode 100644 index 00000000..e80fa482 --- /dev/null +++ b/tests/pipelines/test_tinymog_face_detection.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_face_detection_result +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TinyMogFaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.face_detection + self.model_id = 'damo/cv_manual_face-detection_tinymog' + self.img_path = 'data/test/images/mog_face_detection.jpg' + + def show_result(self, img_path, detection_result): + img = draw_face_detection_result(img_path, detection_result) + cv2.imwrite('result.png', img) + print(f'output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_dataset(self): + input_location = ['data/test/images/mog_face_detection.jpg'] + + dataset = MsDataset.load(input_location, target='image') + face_detection = pipeline(Tasks.face_detection, model=self.model_id) + # note that for dataset output, the inference-output is a Generator that can be iterated. + result = face_detection(dataset) + result = next(result) + self.show_result(input_location[0], result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_detection = pipeline(Tasks.face_detection, model=self.model_id) + + result = face_detection(self.img_path) + self.show_result(self.img_path, result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + face_detection = pipeline(Tasks.face_detection) + result = face_detection(self.img_path) + self.show_result(self.img_path, result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main()