diff --git a/data/test/images/card_detection.jpg b/data/test/images/card_detection.jpg new file mode 100644 index 00000000..86728c2c --- /dev/null +++ b/data/test/images/card_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ecbc9d0827cfb92e93e7d75868b1724142685dc20d3b32023c3c657a7b688a9c +size 254845 diff --git a/data/test/images/face_detection2.jpeg b/data/test/images/face_detection2.jpeg new file mode 100644 index 00000000..7f6025fa --- /dev/null +++ b/data/test/images/face_detection2.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d510ab26ddc58ffea882c8ef850c1f9bd4444772f2bce7ebea3e76944536c3ae +size 48909 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 759f1688..0917bf3e 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -148,6 +148,7 @@ class Pipelines(object): salient_detection = 'u2net-salient-detection' image_classification = 'image-classification' face_detection = 'resnet-face-detection-scrfd10gkps' + card_detection = 'resnet-card-detection-scrfd34gkps' ulfd_face_detection = 'manual-face-detection-ulfd' facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' retina_face_detection = 'resnet50-face-detection-retinaface' @@ -270,6 +271,8 @@ class Trainers(object): image_portrait_enhancement = 'image-portrait-enhancement' video_summarization = 'video-summarization' movie_scene_segmentation = 'movie-scene-segmentation' + face_detection_scrfd = 'face-detection-scrfd' + card_detection_scrfd = 'card-detection-scrfd' image_inpainting = 'image-inpainting' # nlp trainers diff --git a/modelscope/models/cv/face_detection/__init__.py b/modelscope/models/cv/face_detection/__init__.py index a2a845d2..27d1bd4c 100644 --- a/modelscope/models/cv/face_detection/__init__.py +++ b/modelscope/models/cv/face_detection/__init__.py @@ -8,12 +8,14 @@ if TYPE_CHECKING: from .mtcnn import MtcnnFaceDetector from .retinaface import RetinaFaceDetection from .ulfd_slim import UlfdFaceDetector + from .scrfd import ScrfdDetect else: _import_structure = { 'ulfd_slim': ['UlfdFaceDetector'], 'retinaface': ['RetinaFaceDetection'], 'mtcnn': ['MtcnnFaceDetector'], - 'mogface': ['MogFaceDetector'] + 'mogface': ['MogFaceDetector'], + 'scrfd': ['ScrfdDetect'] } import sys diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py b/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py deleted file mode 100755 index 241f2c0e..00000000 --- a/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py +++ /dev/null @@ -1,189 +0,0 @@ -""" -The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at -https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py -""" -import numpy as np -from mmdet.datasets.builder import PIPELINES -from numpy import random - - -@PIPELINES.register_module() -class RandomSquareCrop(object): - """Random crop the image & bboxes, the cropped patches have minimum IoU - requirement with original image & bboxes, the IoU threshold is randomly - selected from min_ious. - - Args: - min_ious (tuple): minimum IoU threshold for all intersections with - bounding boxes - min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, - where a >= min_crop_size). - - Note: - The keys for bboxes, labels and masks should be paired. 
That is, \ - `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ - `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. - """ - - def __init__(self, - crop_ratio_range=None, - crop_choice=None, - bbox_clip_border=True): - - self.crop_ratio_range = crop_ratio_range - self.crop_choice = crop_choice - self.bbox_clip_border = bbox_clip_border - - assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) - if self.crop_ratio_range is not None: - self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range - - self.bbox2label = { - 'gt_bboxes': 'gt_labels', - 'gt_bboxes_ignore': 'gt_labels_ignore' - } - self.bbox2mask = { - 'gt_bboxes': 'gt_masks', - 'gt_bboxes_ignore': 'gt_masks_ignore' - } - - def __call__(self, results): - """Call function to crop images and bounding boxes with minimum IoU - constraint. - - Args: - results (dict): Result dict from loading pipeline. - - Returns: - dict: Result dict with images and bounding boxes cropped, \ - 'img_shape' key is updated. - """ - - if 'img_fields' in results: - assert results['img_fields'] == ['img'], \ - 'Only single img_fields is allowed' - img = results['img'] - assert 'bbox_fields' in results - assert 'gt_bboxes' in results - boxes = results['gt_bboxes'] - h, w, c = img.shape - scale_retry = 0 - if self.crop_ratio_range is not None: - max_scale = self.crop_ratio_max - else: - max_scale = np.amax(self.crop_choice) - while True: - scale_retry += 1 - - if scale_retry == 1 or max_scale > 1.0: - if self.crop_ratio_range is not None: - scale = np.random.uniform(self.crop_ratio_min, - self.crop_ratio_max) - elif self.crop_choice is not None: - scale = np.random.choice(self.crop_choice) - else: - scale = scale * 1.2 - - for i in range(250): - short_side = min(w, h) - cw = int(scale * short_side) - ch = cw - - # TODO +1 - if w == cw: - left = 0 - elif w > cw: - left = random.randint(0, w - cw) - else: - left = random.randint(w - cw, 0) - if h == ch: - top = 0 - elif h > ch: - top = random.randint(0, h - ch) - else: - top = random.randint(h - ch, 0) - - patch = np.array( - (int(left), int(top), int(left + cw), int(top + ch)), - dtype=np.int) - - # center of boxes should inside the crop img - # only adjust boxes and instance masks when the gt is not empty - # adjust boxes - def is_center_of_bboxes_in_patch(boxes, patch): - # TODO >= - center = (boxes[:, :2] + boxes[:, 2:]) / 2 - mask = \ - ((center[:, 0] > patch[0]) - * (center[:, 1] > patch[1]) - * (center[:, 0] < patch[2]) - * (center[:, 1] < patch[3])) - return mask - - mask = is_center_of_bboxes_in_patch(boxes, patch) - if not mask.any(): - continue - for key in results.get('bbox_fields', []): - boxes = results[key].copy() - mask = is_center_of_bboxes_in_patch(boxes, patch) - boxes = boxes[mask] - if self.bbox_clip_border: - boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) - boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) - boxes -= np.tile(patch[:2], 2) - - results[key] = boxes - # labels - label_key = self.bbox2label.get(key) - if label_key in results: - results[label_key] = results[label_key][mask] - - # keypoints field - if key == 'gt_bboxes': - for kps_key in results.get('keypoints_fields', []): - keypointss = results[kps_key].copy() - keypointss = keypointss[mask, :, :] - if self.bbox_clip_border: - keypointss[:, :, : - 2] = keypointss[:, :, :2].clip( - max=patch[2:]) - keypointss[:, :, : - 2] = keypointss[:, :, :2].clip( - min=patch[:2]) - keypointss[:, :, 0] -= patch[0] - keypointss[:, :, 1] -= patch[1] - results[kps_key] = keypointss - - # mask fields - 
mask_key = self.bbox2mask.get(key) - if mask_key in results: - results[mask_key] = results[mask_key][mask.nonzero() - [0]].crop(patch) - - # adjust the img no matter whether the gt is empty before crop - rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 - patch_from = patch.copy() - patch_from[0] = max(0, patch_from[0]) - patch_from[1] = max(0, patch_from[1]) - patch_from[2] = min(img.shape[1], patch_from[2]) - patch_from[3] = min(img.shape[0], patch_from[3]) - patch_to = patch.copy() - patch_to[0] = max(0, patch_to[0] * -1) - patch_to[1] = max(0, patch_to[1] * -1) - patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) - patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) - rimg[patch_to[1]:patch_to[3], - patch_to[0]:patch_to[2], :] = img[ - patch_from[1]:patch_from[3], - patch_from[0]:patch_from[2], :] - img = rimg - results['img'] = img - results['img_shape'] = img.shape - - return results - - def __repr__(self): - repr_str = self.__class__.__name__ - repr_str += f'(min_ious={self.min_iou}, ' - repr_str += f'crop_size={self.crop_size})' - return repr_str diff --git a/modelscope/models/cv/face_detection/scrfd/__init__.py b/modelscope/models/cv/face_detection/scrfd/__init__.py new file mode 100644 index 00000000..92f81f7a --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .scrfd_detect import ScrfdDetect diff --git a/modelscope/models/cv/face_detection/mmdet_patch/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/core/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/core/bbox/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/transforms.py similarity index 94% rename from modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/transforms.py index d65480eb..75e32d85 100755 --- a/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/transforms.py @@ -6,7 +6,7 @@ import numpy as np import torch -def bbox2result(bboxes, labels, num_classes, kps=None): +def bbox2result(bboxes, labels, num_classes, kps=None, num_kps=5): """Convert detection results to a list of numpy arrays. 
Args: @@ -17,7 +17,7 @@ def bbox2result(bboxes, labels, num_classes, kps=None): Returns: list(ndarray): bbox results of each class """ - bbox_len = 5 if kps is None else 5 + 10 # if has kps, add 10 kps into bbox + bbox_len = 5 if kps is None else 5 + num_kps * 2 # if has kps, add num_kps*2 into bbox if bboxes.shape[0] == 0: return [ np.zeros((0, bbox_len), dtype=np.float32) diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/bbox_nms.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py similarity index 89% rename from modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/bbox_nms.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py index 7a4f5b3a..697b7338 100644 --- a/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/bbox_nms.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py @@ -17,6 +17,7 @@ def multiclass_nms(multi_bboxes, Args: multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_kps (Tensor): shape (n, #class*num_kps*2) or (n, num_kps*2) multi_scores (Tensor): shape (n, #class), where the last column contains scores of the background class, but this will be ignored. score_thr (float): bbox threshold, bboxes with scores lower than it @@ -36,16 +37,18 @@ def multiclass_nms(multi_bboxes, num_classes = multi_scores.size(1) - 1 # exclude background category kps = None + if multi_kps is not None: + num_kps = int((multi_kps.shape[1] / num_classes) / 2) if multi_bboxes.shape[1] > 4: bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) if multi_kps is not None: - kps = multi_kps.view(multi_scores.size(0), -1, 10) + kps = multi_kps.view(multi_scores.size(0), -1, num_kps * 2) else: bboxes = multi_bboxes[:, None].expand( multi_scores.size(0), num_classes, 4) if multi_kps is not None: kps = multi_kps[:, None].expand( - multi_scores.size(0), num_classes, 10) + multi_scores.size(0), num_classes, num_kps * 2) scores = multi_scores[:, :-1] if score_factors is not None: @@ -56,7 +59,7 @@ def multiclass_nms(multi_bboxes, bboxes = bboxes.reshape(-1, 4) if kps is not None: - kps = kps.reshape(-1, 10) + kps = kps.reshape(-1, num_kps * 2) scores = scores.reshape(-1) labels = labels.reshape(-1) diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/datasets/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/__init__.py similarity index 53% rename from modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/__init__.py index 85288910..a2cafd1a 100755 --- 
a/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/__init__.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/__init__.py @@ -2,6 +2,12 @@ The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines """ +from .auto_augment import RotateV2 +from .formating import DefaultFormatBundleV2 +from .loading import LoadAnnotationsV2 from .transforms import RandomSquareCrop -__all__ = ['RandomSquareCrop'] +__all__ = [ + 'RandomSquareCrop', 'LoadAnnotationsV2', 'RotateV2', + 'DefaultFormatBundleV2' +] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/auto_augment.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/auto_augment.py new file mode 100644 index 00000000..ee60c2e0 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/auto_augment.py @@ -0,0 +1,271 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/auto_augment.py +""" +import copy + +import cv2 +import mmcv +import numpy as np +from mmdet.datasets.builder import PIPELINES + +_MAX_LEVEL = 10 + + +def level_to_value(level, max_value): + """Map from level to values based on max_value.""" + return (level / _MAX_LEVEL) * max_value + + +def random_negative(value, random_negative_prob): + """Randomly negate value based on random_negative_prob.""" + return -value if np.random.rand() < random_negative_prob else value + + +def bbox2fields(): + """The key correspondence from bboxes to labels, masks and + segmentations.""" + bbox2label = { + 'gt_bboxes': 'gt_labels', + 'gt_bboxes_ignore': 'gt_labels_ignore' + } + bbox2mask = { + 'gt_bboxes': 'gt_masks', + 'gt_bboxes_ignore': 'gt_masks_ignore' + } + bbox2seg = { + 'gt_bboxes': 'gt_semantic_seg', + } + return bbox2label, bbox2mask, bbox2seg + + +@PIPELINES.register_module() +class RotateV2(object): + """Apply Rotate Transformation to image (and its corresponding bbox, mask, + segmentation). + + Args: + level (int | float): The level should be in range (0,_MAX_LEVEL]. + scale (int | float): Isotropic scale factor. Same in + ``mmcv.imrotate``. + center (int | float | tuple[float]): Center point (w, h) of the + rotation in the source image. If None, the center of the + image will be used. Same in ``mmcv.imrotate``. + img_fill_val (int | float | tuple): The fill value for image border. + If float, the same value will be used for all the three + channels of image. If tuple, the should be 3 elements (e.g. + equals the number of channels for image). + seg_ignore_label (int): The fill value used for segmentation map. + Note this value must equals ``ignore_label`` in ``semantic_head`` + of the corresponding config. Default 255. + prob (float): The probability for perform transformation and + should be in range 0 to 1. + max_rotate_angle (int | float): The maximum angles for rotate + transformation. + random_negative_prob (float): The probability that turns the + offset negative. + """ + + def __init__(self, + level, + scale=1, + center=None, + img_fill_val=128, + seg_ignore_label=255, + prob=0.5, + max_rotate_angle=30, + random_negative_prob=0.5): + assert isinstance(level, (int, float)), \ + f'The level must be type int or float. got {type(level)}.' 
+ assert 0 <= level <= _MAX_LEVEL, \ + f'The level should be in range (0,{_MAX_LEVEL}]. got {level}.' + assert isinstance(scale, (int, float)), \ + f'The scale must be type int or float. got type {type(scale)}.' + if isinstance(center, (int, float)): + center = (center, center) + elif isinstance(center, tuple): + assert len(center) == 2, 'center with type tuple must have '\ + f'2 elements. got {len(center)} elements.' + else: + assert center is None, 'center must be None or type int, '\ + f'float or tuple, got type {type(center)}.' + if isinstance(img_fill_val, (float, int)): + img_fill_val = tuple([float(img_fill_val)] * 3) + elif isinstance(img_fill_val, tuple): + assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\ + f'have 3 elements. got {len(img_fill_val)}.' + img_fill_val = tuple([float(val) for val in img_fill_val]) + else: + raise ValueError( + 'img_fill_val must be float or tuple with 3 elements.') + assert np.all([0 <= val <= 255 for val in img_fill_val]), \ + 'all elements of img_fill_val should between range [0,255]. '\ + f'got {img_fill_val}.' + assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\ + f'got {prob}.' + assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\ + f'should be type int or float. got type {type(max_rotate_angle)}.' + self.level = level + self.scale = scale + # Rotation angle in degrees. Positive values mean + # clockwise rotation. + self.angle = level_to_value(level, max_rotate_angle) + self.center = center + self.img_fill_val = img_fill_val + self.seg_ignore_label = seg_ignore_label + self.prob = prob + self.max_rotate_angle = max_rotate_angle + self.random_negative_prob = random_negative_prob + + def _rotate_img(self, results, angle, center=None, scale=1.0): + """Rotate the image. + + Args: + results (dict): Result dict from loading pipeline. + angle (float): Rotation angle in degrees, positive values + mean clockwise rotation. Same in ``mmcv.imrotate``. + center (tuple[float], optional): Center point (w, h) of the + rotation. Same in ``mmcv.imrotate``. + scale (int | float): Isotropic scale factor. Same in + ``mmcv.imrotate``. 
+ """ + for key in results.get('img_fields', ['img']): + img = results[key].copy() + img_rotated = mmcv.imrotate( + img, angle, center, scale, border_value=self.img_fill_val) + results[key] = img_rotated.astype(img.dtype) + results['img_shape'] = results[key].shape + + def _rotate_bboxes(self, results, rotate_matrix): + """Rotate the bboxes.""" + h, w, c = results['img_shape'] + for key in results.get('bbox_fields', []): + min_x, min_y, max_x, max_y = np.split( + results[key], results[key].shape[-1], axis=-1) + coordinates = np.stack([[min_x, min_y], [max_x, min_y], + [min_x, max_y], + [max_x, max_y]]) # [4, 2, nb_bbox, 1] + # pad 1 to convert from format [x, y] to homogeneous + # coordinates format [x, y, 1] + coordinates = np.concatenate( + (coordinates, + np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)), + axis=1) # [4, 3, nb_bbox, 1] + coordinates = coordinates.transpose( + (2, 0, 1, 3)) # [nb_bbox, 4, 3, 1] + rotated_coords = np.matmul(rotate_matrix, + coordinates) # [nb_bbox, 4, 2, 1] + rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2] + min_x, min_y = np.min( + rotated_coords[:, :, 0], axis=1), np.min( + rotated_coords[:, :, 1], axis=1) + max_x, max_y = np.max( + rotated_coords[:, :, 0], axis=1), np.max( + rotated_coords[:, :, 1], axis=1) + results[key] = np.stack([min_x, min_y, max_x, max_y], + axis=-1).astype(results[key].dtype) + + def _rotate_keypoints90(self, results, angle): + """Rotate the keypoints, only valid when angle in [-90,90,-180,180]""" + if angle not in [-90, 90, 180, -180 + ] or self.scale != 1 or self.center is not None: + return + for key in results.get('keypoints_fields', []): + k = results[key] + if angle == 90: + w, h, c = results['img'].shape + new = np.stack([h - k[..., 1], k[..., 0], k[..., 2]], axis=-1) + elif angle == -90: + w, h, c = results['img'].shape + new = np.stack([k[..., 1], w - k[..., 0], k[..., 2]], axis=-1) + else: + h, w, c = results['img'].shape + new = np.stack([w - k[..., 0], h - k[..., 1], k[..., 2]], + axis=-1) + # a kps is invalid if thrid value is -1 + kps_invalid = new[..., -1][:, -1] == -1 + new[kps_invalid] = np.zeros(new.shape[1:]) - 1 + results[key] = new + + def _rotate_masks(self, + results, + angle, + center=None, + scale=1.0, + fill_val=0): + """Rotate the masks.""" + h, w, c = results['img_shape'] + for key in results.get('mask_fields', []): + masks = results[key] + results[key] = masks.rotate((h, w), angle, center, scale, fill_val) + + def _rotate_seg(self, + results, + angle, + center=None, + scale=1.0, + fill_val=255): + """Rotate the segmentation map.""" + for key in results.get('seg_fields', []): + seg = results[key].copy() + results[key] = mmcv.imrotate( + seg, angle, center, scale, + border_value=fill_val).astype(seg.dtype) + + def _filter_invalid(self, results, min_bbox_size=0): + """Filter bboxes and corresponding masks too small after rotate + augmentation.""" + bbox2label, bbox2mask, _ = bbox2fields() + for key in results.get('bbox_fields', []): + bbox_w = results[key][:, 2] - results[key][:, 0] + bbox_h = results[key][:, 3] - results[key][:, 1] + valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size) + valid_inds = np.nonzero(valid_inds)[0] + results[key] = results[key][valid_inds] + # label fields. e.g. gt_labels and gt_labels_ignore + label_key = bbox2label.get(key) + if label_key in results: + results[label_key] = results[label_key][valid_inds] + # mask fields, e.g. 
gt_masks and gt_masks_ignore + mask_key = bbox2mask.get(key) + if mask_key in results: + results[mask_key] = results[mask_key][valid_inds] + + def __call__(self, results): + """Call function to rotate images, bounding boxes, masks and semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + if np.random.rand() > self.prob: + return results + h, w = results['img'].shape[:2] + center = self.center + if center is None: + center = ((w - 1) * 0.5, (h - 1) * 0.5) + angle = random_negative(self.angle, self.random_negative_prob) + self._rotate_img(results, angle, center, self.scale) + rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale) + self._rotate_bboxes(results, rotate_matrix) + self._rotate_keypoints90(results, angle) + self._rotate_masks(results, angle, center, self.scale, fill_val=0) + self._rotate_seg( + results, angle, center, self.scale, fill_val=self.seg_ignore_label) + self._filter_invalid(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(level={self.level}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'center={self.center}, ' + repr_str += f'img_fill_val={self.img_fill_val}, ' + repr_str += f'seg_ignore_label={self.seg_ignore_label}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'max_rotate_angle={self.max_rotate_angle}, ' + repr_str += f'random_negative_prob={self.random_negative_prob})' + return repr_str diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/formating.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/formating.py new file mode 100644 index 00000000..bd2394a8 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/formating.py @@ -0,0 +1,113 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/formating.py +""" +import numpy as np +import torch +from mmcv.parallel import DataContainer as DC +from mmdet.datasets.builder import PIPELINES + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +@PIPELINES.register_module() +class DefaultFormatBundleV2(object): + """Default formatting bundle. + + It simplifies the pipeline of formatting common fields, including "img", + "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg". + These fields are formatted as follows. 
+ + - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) + - proposals: (1)to tensor, (2)to DataContainer + - gt_bboxes: (1)to tensor, (2)to DataContainer + - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer + - gt_labels: (1)to tensor, (2)to DataContainer + - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True) + - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \ + (3)to DataContainer (stack=True) + """ + + def __call__(self, results): + """Call function to transform and format common fields in results. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with \ + default bundle. + """ + + if 'img' in results: + img = results['img'] + # add default meta keys + results = self._add_default_meta_keys(results) + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + results['img'] = DC(to_tensor(img), stack=True) + for key in [ + 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_keypointss', + 'gt_labels' + ]: + if key not in results: + continue + results[key] = DC(to_tensor(results[key])) + if 'gt_masks' in results: + results['gt_masks'] = DC(results['gt_masks'], cpu_only=True) + if 'gt_semantic_seg' in results: + results['gt_semantic_seg'] = DC( + to_tensor(results['gt_semantic_seg'][None, ...]), stack=True) + return results + + def _add_default_meta_keys(self, results): + """Add default meta keys. + + We set default meta keys including `pad_shape`, `scale_factor` and + `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and + `Pad` are implemented during the whole pipeline. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + results (dict): Updated result dict contains the data to convert. + """ + img = results['img'] + results.setdefault('pad_shape', img.shape) + results.setdefault('scale_factor', 1.0) + num_channels = 1 if len(img.shape) < 3 else img.shape[2] + results.setdefault( + 'img_norm_cfg', + dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False)) + return results + + def __repr__(self): + return self.__class__.__name__ diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/loading.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/loading.py new file mode 100644 index 00000000..b4c2a385 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/loading.py @@ -0,0 +1,225 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/loading.py +""" +import os.path as osp + +import numpy as np +import pycocotools.mask as maskUtils +from mmdet.core import BitmapMasks, PolygonMasks +from mmdet.datasets.builder import PIPELINES + + +@PIPELINES.register_module() +class LoadAnnotationsV2(object): + """Load mutiple types of annotations. + + Args: + with_bbox (bool): Whether to parse and load the bbox annotation. + Default: True. + with_label (bool): Whether to parse and load the label annotation. + Default: True. + with_keypoints (bool): Whether to parse and load the keypoints annotation. + Default: False. + with_mask (bool): Whether to parse and load the mask annotation. + Default: False. + with_seg (bool): Whether to parse and load the semantic segmentation + annotation. 
Default: False. + poly2mask (bool): Whether to convert the instance masks from polygons + to bitmaps. Default: True. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + """ + + def __init__(self, + with_bbox=True, + with_label=True, + with_keypoints=False, + with_mask=False, + with_seg=False, + poly2mask=True, + file_client_args=dict(backend='disk')): + self.with_bbox = with_bbox + self.with_label = with_label + self.with_keypoints = with_keypoints + self.with_mask = with_mask + self.with_seg = with_seg + self.poly2mask = poly2mask + self.file_client_args = file_client_args.copy() + self.file_client = None + + def _load_bboxes(self, results): + """Private function to load bounding box annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box annotations. + """ + + ann_info = results['ann_info'] + results['gt_bboxes'] = ann_info['bboxes'].copy() + + gt_bboxes_ignore = ann_info.get('bboxes_ignore', None) + if gt_bboxes_ignore is not None: + results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy() + results['bbox_fields'].append('gt_bboxes_ignore') + results['bbox_fields'].append('gt_bboxes') + return results + + def _load_keypoints(self, results): + """Private function to load bounding box annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box annotations. + """ + + ann_info = results['ann_info'] + results['gt_keypointss'] = ann_info['keypointss'].copy() + + results['keypoints_fields'] = ['gt_keypointss'] + return results + + def _load_labels(self, results): + """Private function to load label annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded label annotations. + """ + + results['gt_labels'] = results['ann_info']['labels'].copy() + return results + + def _poly2mask(self, mask_ann, img_h, img_w): + """Private function to convert masks represented with polygon to + bitmaps. + + Args: + mask_ann (list | dict): Polygon mask annotation input. + img_h (int): The height of output mask. + img_w (int): The width of output mask. + + Returns: + numpy.ndarray: The decode bitmap mask of shape (img_h, img_w). + """ + + if isinstance(mask_ann, list): + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = maskUtils.frPyObjects(mask_ann, img_h, img_w) + rle = maskUtils.merge(rles) + elif isinstance(mask_ann['counts'], list): + # uncompressed RLE + rle = maskUtils.frPyObjects(mask_ann, img_h, img_w) + else: + # rle + rle = mask_ann + mask = maskUtils.decode(rle) + return mask + + def process_polygons(self, polygons): + """Convert polygons to list of ndarray and filter invalid polygons. + + Args: + polygons (list[list]): Polygons of one instance. + + Returns: + list[numpy.ndarray]: Processed polygons. + """ + + polygons = [np.array(p) for p in polygons] + valid_polygons = [] + for polygon in polygons: + if len(polygon) % 2 == 0 and len(polygon) >= 6: + valid_polygons.append(polygon) + return valid_polygons + + def _load_masks(self, results): + """Private function to load mask annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded mask annotations. 
+ If ``self.poly2mask`` is set ``True``, `gt_mask` will contain + :obj:`PolygonMasks`. Otherwise, :obj:`BitmapMasks` is used. + """ + + h, w = results['img_info']['height'], results['img_info']['width'] + gt_masks = results['ann_info']['masks'] + if self.poly2mask: + gt_masks = BitmapMasks( + [self._poly2mask(mask, h, w) for mask in gt_masks], h, w) + else: + gt_masks = PolygonMasks( + [self.process_polygons(polygons) for polygons in gt_masks], h, + w) + results['gt_masks'] = gt_masks + results['mask_fields'].append('gt_masks') + return results + + def _load_semantic_seg(self, results): + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + import mmcv + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + filename = osp.join(results['seg_prefix'], + results['ann_info']['seg_map']) + img_bytes = self.file_client.get(filename) + results['gt_semantic_seg'] = mmcv.imfrombytes( + img_bytes, flag='unchanged').squeeze() + results['seg_fields'].append('gt_semantic_seg') + return results + + def __call__(self, results): + """Call function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box, label, mask and + semantic segmentation annotations. + """ + + if self.with_bbox: + results = self._load_bboxes(results) + if results is None: + return None + if self.with_label: + results = self._load_labels(results) + if self.with_keypoints: + results = self._load_keypoints(results) + if self.with_mask: + results = self._load_masks(results) + if self.with_seg: + results = self._load_semantic_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(with_bbox={self.with_bbox}, ' + repr_str += f'with_label={self.with_label}, ' + repr_str += f'with_keypoints={self.with_keypoints}, ' + repr_str += f'with_mask={self.with_mask}, ' + repr_str += f'with_seg={self.with_seg})' + repr_str += f'poly2mask={self.poly2mask})' + repr_str += f'poly2mask={self.file_client_args})' + return repr_str diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/transforms.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/transforms.py new file mode 100755 index 00000000..270c34da --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/transforms.py @@ -0,0 +1,737 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py +""" +import mmcv +import numpy as np +from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps +from mmdet.datasets.builder import PIPELINES +from numpy import random + + +@PIPELINES.register_module() +class ResizeV2(object): + """Resize images & bbox & mask &kps. + + This transform resizes the input image to some scale. Bboxes and masks are + then resized with the same scale factor. If the input dict contains the key + "scale", then the scale in the input dict is used, otherwise the specified + scale in the init method is used. 
If the input dict contains the key + "scale_factor" (if MultiScaleFlipAug does not give img_scale but + scale_factor), the actual scale will be computed by image shape and + scale_factor. + + `img_scale` can either be a tuple (single-scale) or a list of tuple + (multi-scale). There are 3 multiscale modes: + + - ``ratio_range is not None``: randomly sample a ratio from the ratio \ + range and multiply it with the image scale. + - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \ + sample a scale from the multiscale range. + - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \ + sample a scale from multiple scales. + + Args: + img_scale (tuple or list[tuple]): Images scales for resizing. + multiscale_mode (str): Either "range" or "value". + ratio_range (tuple[float]): (min_ratio, max_ratio) + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + override (bool, optional): Whether to override `scale` and + `scale_factor` so as to call resize twice. Default False. If True, + after the first resizing, the existed `scale` and `scale_factor` + will be ignored so the second resizing can be allowed. + This option is a work-around for multiple times of resize in DETR. + Defaults to False. + """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True, + bbox_clip_border=True, + backend='cv2', + override=False): + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) + + if ratio_range is not None: + # mode 1: given a scale and a range of image ratio + assert len(self.img_scale) == 1 + else: + # mode 2: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.backend = backend + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + # TODO: refactor the override option in Resize + self.override = override + self.bbox_clip_border = bbox_clip_border + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, \ + where ``img_scale`` is the selected image scale and \ + ``scale_idx`` is the selected index in the given candidates. + """ + + assert mmcv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and uper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where \ + ``img_scale`` is sampled scale and None is just a placeholder \ + to be consistent with :func:`random_select`. 
+ """ + + assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where \ + ``scale`` is sampled ratio multiplied with ``img_scale`` and \ + None is just a placeholder to be consistent with \ + :func:`random_select`. + """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, results): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into \ + ``results``, which would be used by subsequent pipelines. 
+ """ + + if self.ratio_range is not None: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + results['scale'] = scale + results['scale_idx'] = scale_idx + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + for key in results.get('img_fields', ['img']): + if self.keep_ratio: + img, scale_factor = mmcv.imrescale( + results[key], + results['scale'], + return_scale=True, + backend=self.backend) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results[key].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize( + results[key], + results['scale'], + return_scale=True, + backend=self.backend) + results[key] = img + + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + results['img_shape'] = img.shape + # in case that there is no padding + results['pad_shape'] = img.shape + results['scale_factor'] = scale_factor + results['keep_ratio'] = self.keep_ratio + + def _resize_bboxes(self, results): + """Resize bounding boxes with ``results['scale_factor']``.""" + for key in results.get('bbox_fields', []): + bboxes = results[key] * results['scale_factor'] + if self.bbox_clip_border: + img_shape = results['img_shape'] + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) + results[key] = bboxes + + def _resize_keypoints(self, results): + """Resize keypoints with ``results['scale_factor']``.""" + for key in results.get('keypoints_fields', []): + keypointss = results[key].copy() + factors = results['scale_factor'] + assert factors[0] == factors[2] + assert factors[1] == factors[3] + keypointss[:, :, 0] *= factors[0] + keypointss[:, :, 1] *= factors[1] + if self.bbox_clip_border: + img_shape = results['img_shape'] + keypointss[:, :, 0] = np.clip(keypointss[:, :, 0], 0, + img_shape[1]) + keypointss[:, :, 1] = np.clip(keypointss[:, :, 1], 0, + img_shape[0]) + results[key] = keypointss + + def _resize_masks(self, results): + """Resize masks with ``results['scale']``""" + for key in results.get('mask_fields', []): + if results[key] is None: + continue + if self.keep_ratio: + results[key] = results[key].rescale(results['scale']) + else: + results[key] = results[key].resize(results['img_shape'][:2]) + + def _resize_seg(self, results): + """Resize semantic segmentation map with ``results['scale']``.""" + for key in results.get('seg_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + else: + gt_seg = mmcv.imresize( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + results['gt_semantic_seg'] = gt_seg + + def __call__(self, results): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \ + 'keep_ratio' keys are added into result dict. 
+ """ + + if 'scale' not in results: + if 'scale_factor' in results: + img_shape = results['img'].shape[:2] + scale_factor = results['scale_factor'] + assert isinstance(scale_factor, float) + results['scale'] = tuple( + [int(x * scale_factor) for x in img_shape][::-1]) + else: + self._random_scale(results) + else: + if not self.override: + assert 'scale_factor' not in results, ( + 'scale and scale_factor cannot be both set.') + else: + results.pop('scale') + if 'scale_factor' in results: + results.pop('scale_factor') + self._random_scale(results) + + self._resize_img(results) + self._resize_bboxes(results) + self._resize_keypoints(results) + self._resize_masks(results) + self._resize_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(img_scale={self.img_scale}, ' + repr_str += f'multiscale_mode={self.multiscale_mode}, ' + repr_str += f'ratio_range={self.ratio_range}, ' + repr_str += f'keep_ratio={self.keep_ratio})' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + +@PIPELINES.register_module() +class RandomFlipV2(object): + """Flip the image & bbox & mask & kps. + + If the input dict contains the key "flip", then the flag will be used, + otherwise it will be randomly decided by a ratio specified in the init + method. + + When random flip is enabled, ``flip_ratio``/``direction`` can either be a + float/string or tuple of float/string. There are 3 flip modes: + + - ``flip_ratio`` is float, ``direction`` is string: the image will be + ``direction``ly flipped with probability of ``flip_ratio`` . + E.g., ``flip_ratio=0.5``, ``direction='horizontal'``, + then image will be horizontally flipped with probability of 0.5. + - ``flip_ratio`` is float, ``direction`` is list of string: the image wil + be ``direction[i]``ly flipped with probability of + ``flip_ratio/len(direction)``. + E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``, + then image will be horizontally flipped with probability of 0.25, + vertically with probability of 0.25. + - ``flip_ratio`` is list of float, ``direction`` is list of string: + given ``len(flip_ratio) == len(direction)``, the image wil + be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``. + E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal', + 'vertical']``, then image will be horizontally flipped with probability + of 0.3, vertically with probability of 0.5 + + Args: + flip_ratio (float | list[float], optional): The flipping probability. + Default: None. + direction(str | list[str], optional): The flipping direction. Options + are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'. + If input is a list, the length must equal ``flip_ratio``. Each + element in ``flip_ratio`` indicates the flip probability of + corresponding direction. 
+ """ + + def __init__(self, flip_ratio=None, direction='horizontal'): + if isinstance(flip_ratio, list): + assert mmcv.is_list_of(flip_ratio, float) + assert 0 <= sum(flip_ratio) <= 1 + elif isinstance(flip_ratio, float): + assert 0 <= flip_ratio <= 1 + elif flip_ratio is None: + pass + else: + raise ValueError('flip_ratios must be None, float, ' + 'or list of float') + self.flip_ratio = flip_ratio + + valid_directions = ['horizontal', 'vertical', 'diagonal'] + if isinstance(direction, str): + assert direction in valid_directions + elif isinstance(direction, list): + assert mmcv.is_list_of(direction, str) + assert set(direction).issubset(set(valid_directions)) + else: + raise ValueError('direction must be either str or list of str') + self.direction = direction + + if isinstance(flip_ratio, list): + assert len(self.flip_ratio) == len(self.direction) + self.count = 0 + + def bbox_flip(self, bboxes, img_shape, direction): + """Flip bboxes horizontally. + + Args: + bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k) + img_shape (tuple[int]): Image shape (height, width) + direction (str): Flip direction. Options are 'horizontal', + 'vertical'. + + Returns: + numpy.ndarray: Flipped bounding boxes. + """ + + assert bboxes.shape[-1] % 4 == 0 + flipped = bboxes.copy() + if direction == 'horizontal': + w = img_shape[1] + flipped[..., 0::4] = w - bboxes[..., 2::4] + flipped[..., 2::4] = w - bboxes[..., 0::4] + elif direction == 'vertical': + h = img_shape[0] + flipped[..., 1::4] = h - bboxes[..., 3::4] + flipped[..., 3::4] = h - bboxes[..., 1::4] + elif direction == 'diagonal': + w = img_shape[1] + h = img_shape[0] + flipped[..., 0::4] = w - bboxes[..., 2::4] + flipped[..., 1::4] = h - bboxes[..., 3::4] + flipped[..., 2::4] = w - bboxes[..., 0::4] + flipped[..., 3::4] = h - bboxes[..., 1::4] + else: + raise ValueError(f"Invalid flipping direction '{direction}'") + return flipped + + def keypoints_flip(self, keypointss, img_shape, direction): + """Flip keypoints horizontally.""" + + assert direction == 'horizontal' + assert keypointss.shape[-1] == 3 + num_kps = keypointss.shape[1] + assert num_kps in [4, 5], f'Only Support num_kps=4 or 5, got:{num_kps}' + assert keypointss.ndim == 3 + flipped = keypointss.copy() + if num_kps == 5: + flip_order = [1, 0, 2, 4, 3] + elif num_kps == 4: + flip_order = [3, 2, 1, 0] + for idx, a in enumerate(flip_order): + flipped[:, idx, :] = keypointss[:, a, :] + w = img_shape[1] + flipped[..., 0] = w - flipped[..., 0] + return flipped + + def __call__(self, results): + """Call function to flip bounding boxes, masks, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Flipped results, 'flip', 'flip_direction' keys are added \ + into result dict. 
+ """ + if 'flip' not in results: + if isinstance(self.direction, list): + # None means non-flip + direction_list = self.direction + [None] + else: + # None means non-flip + direction_list = [self.direction, None] + + if isinstance(self.flip_ratio, list): + non_flip_ratio = 1 - sum(self.flip_ratio) + flip_ratio_list = self.flip_ratio + [non_flip_ratio] + else: + non_flip_ratio = 1 - self.flip_ratio + # exclude non-flip + single_ratio = self.flip_ratio / (len(direction_list) - 1) + flip_ratio_list = [single_ratio] * (len(direction_list) + - 1) + [non_flip_ratio] + + cur_dir = np.random.choice(direction_list, p=flip_ratio_list) + + results['flip'] = cur_dir is not None + if 'flip_direction' not in results: + results['flip_direction'] = cur_dir + if results['flip']: + # flip image + for key in results.get('img_fields', ['img']): + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']) + # flip bboxes + for key in results.get('bbox_fields', []): + results[key] = self.bbox_flip(results[key], + results['img_shape'], + results['flip_direction']) + # flip kps + for key in results.get('keypoints_fields', []): + results[key] = self.keypoints_flip(results[key], + results['img_shape'], + results['flip_direction']) + # flip masks + for key in results.get('mask_fields', []): + results[key] = results[key].flip(results['flip_direction']) + + # flip segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' + + +@PIPELINES.register_module() +class RandomSquareCrop(object): + """Random crop the image & bboxes, the cropped patches have minimum IoU + requirement with original image & bboxes, the IoU threshold is randomly + selected from min_ious. + + Args: + min_ious (tuple): minimum IoU threshold for all intersections with + bounding boxes + min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, + where a >= min_crop_size). + + Note: + The keys for bboxes, labels and masks should be paired. That is, \ + `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ + `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. + """ + + def __init__(self, + crop_ratio_range=None, + crop_choice=None, + bbox_clip_border=True, + big_face_ratio=0, + big_face_crop_choice=None): + + self.crop_ratio_range = crop_ratio_range + self.crop_choice = crop_choice + self.big_face_crop_choice = big_face_crop_choice + self.bbox_clip_border = bbox_clip_border + + assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) + if self.crop_ratio_range is not None: + self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range + + self.bbox2label = { + 'gt_bboxes': 'gt_labels', + 'gt_bboxes_ignore': 'gt_labels_ignore' + } + self.bbox2mask = { + 'gt_bboxes': 'gt_masks', + 'gt_bboxes_ignore': 'gt_masks_ignore' + } + assert big_face_ratio >= 0 and big_face_ratio <= 1.0 + self.big_face_ratio = big_face_ratio + + def __call__(self, results): + """Call function to crop images and bounding boxes with minimum IoU + constraint. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images and bounding boxes cropped, \ + 'img_shape' key is updated. 
+ """ + + if 'img_fields' in results: + assert results['img_fields'] == ['img'], \ + 'Only single img_fields is allowed' + img = results['img'] + assert 'bbox_fields' in results + assert 'gt_bboxes' in results + # try augment big face images + find_bigface = False + if np.random.random() < self.big_face_ratio: + min_size = 100 # h and w + expand_ratio = 0.3 # expand ratio of croped face alongwith both w and h + bbox = results['gt_bboxes'].copy() + lmks = results['gt_keypointss'].copy() + label = results['gt_labels'].copy() + # filter small faces + size_mask = ((bbox[:, 2] - bbox[:, 0]) > min_size) * ( + (bbox[:, 3] - bbox[:, 1]) > min_size) + bbox = bbox[size_mask] + lmks = lmks[size_mask] + label = label[size_mask] + # randomly choose a face that has no overlap with others + if len(bbox) > 0: + overlaps = bbox_overlaps(bbox, bbox) + overlaps -= np.eye(overlaps.shape[0]) + iou_mask = np.sum(overlaps, axis=1) == 0 + bbox = bbox[iou_mask] + lmks = lmks[iou_mask] + label = label[iou_mask] + if len(bbox) > 0: + choice = np.random.randint(len(bbox)) + bbox = bbox[choice] + lmks = lmks[choice] + label = [label[choice]] + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + x1 = bbox[0] - w * expand_ratio + x2 = bbox[2] + w * expand_ratio + y1 = bbox[1] - h * expand_ratio + y2 = bbox[3] + h * expand_ratio + x1, x2 = np.clip([x1, x2], 0, img.shape[1]) + y1, y2 = np.clip([y1, y2], 0, img.shape[0]) + bbox -= np.tile([x1, y1], 2) + lmks -= (x1, y1, 0) + + find_bigface = True + img = img[int(y1):int(y2), int(x1):int(x2), :] + results['gt_bboxes'] = np.expand_dims(bbox, axis=0) + results['gt_keypointss'] = np.expand_dims(lmks, axis=0) + results['gt_labels'] = np.array(label) + results['img'] = img + + boxes = results['gt_bboxes'] + h, w, c = img.shape + + if self.crop_ratio_range is not None: + max_scale = self.crop_ratio_max + else: + max_scale = np.amax(self.crop_choice) + scale_retry = 0 + while True: + scale_retry += 1 + if scale_retry == 1 or max_scale > 1.0: + if self.crop_ratio_range is not None: + scale = np.random.uniform(self.crop_ratio_min, + self.crop_ratio_max) + elif self.crop_choice is not None: + scale = np.random.choice(self.crop_choice) + else: + scale = scale * 1.2 + + if find_bigface: + # select a scale from big_face_crop_choice if in big_face mode + scale = np.random.choice(self.big_face_crop_choice) + + for i in range(250): + long_side = max(w, h) + cw = int(scale * long_side) + ch = cw + + # TODO +1 + if w == cw: + left = 0 + elif w > cw: + left = random.randint(0, w - cw) + else: + left = random.randint(w - cw, 0) + if h == ch: + top = 0 + elif h > ch: + top = random.randint(0, h - ch) + else: + top = random.randint(h - ch, 0) + + patch = np.array( + (int(left), int(top), int(left + cw), int(top + ch)), + dtype=np.int32) + + # center of boxes should inside the crop img + # only adjust boxes and instance masks when the gt is not empty + # adjust boxes + def is_center_of_bboxes_in_patch(boxes, patch): + # TODO >= + center = (boxes[:, :2] + boxes[:, 2:]) / 2 + mask = \ + ((center[:, 0] > patch[0]) + * (center[:, 1] > patch[1]) + * (center[:, 0] < patch[2]) + * (center[:, 1] < patch[3])) + return mask + + mask = is_center_of_bboxes_in_patch(boxes, patch) + if not mask.any(): + continue + for key in results.get('bbox_fields', []): + boxes = results[key].copy() + mask = is_center_of_bboxes_in_patch(boxes, patch) + boxes = boxes[mask] + if self.bbox_clip_border: + boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) + boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) + boxes -= np.tile(patch[:2], 
2) + + results[key] = boxes + # labels + label_key = self.bbox2label.get(key) + if label_key in results: + results[label_key] = results[label_key][mask] + + # keypoints field + if key == 'gt_bboxes': + for kps_key in results.get('keypoints_fields', []): + keypointss = results[kps_key].copy() + keypointss = keypointss[mask, :, :] + if self.bbox_clip_border: + keypointss[:, :, : + 2] = keypointss[:, :, :2].clip( + max=patch[2:]) + keypointss[:, :, : + 2] = keypointss[:, :, :2].clip( + min=patch[:2]) + keypointss[:, :, 0] -= patch[0] + keypointss[:, :, 1] -= patch[1] + results[kps_key] = keypointss + + # mask fields + mask_key = self.bbox2mask.get(key) + if mask_key in results: + results[mask_key] = results[mask_key][mask.nonzero() + [0]].crop(patch) + + # adjust the img no matter whether the gt is empty before crop + rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 + patch_from = patch.copy() + patch_from[0] = max(0, patch_from[0]) + patch_from[1] = max(0, patch_from[1]) + patch_from[2] = min(img.shape[1], patch_from[2]) + patch_from[3] = min(img.shape[0], patch_from[3]) + patch_to = patch.copy() + patch_to[0] = max(0, patch_to[0] * -1) + patch_to[1] = max(0, patch_to[1] * -1) + patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) + patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) + rimg[patch_to[1]:patch_to[3], + patch_to[0]:patch_to[2], :] = img[ + patch_from[1]:patch_from[3], + patch_from[0]:patch_from[2], :] + img = rimg + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_ious={self.min_iou}, ' + repr_str += f'crop_size={self.crop_size})' + return repr_str diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/retinaface.py similarity index 97% rename from modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/retinaface.py index bbacd9be..40c440b9 100755 --- a/modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/retinaface.py @@ -13,7 +13,7 @@ class RetinaFaceDataset(CustomDataset): CLASSES = ('FG', ) def __init__(self, min_size=None, **kwargs): - self.NK = 5 + self.NK = kwargs.pop('num_kps', 5) self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} self.min_size = min_size self.gt_path = kwargs.get('gt_path') @@ -33,7 +33,8 @@ class RetinaFaceDataset(CustomDataset): if len(values) > 4: if len(values) > 5: kps = np.array( - values[4:19], dtype=np.float32).reshape((self.NK, 3)) + values[4:4 + self.NK * 3], dtype=np.float32).reshape( + (self.NK, 3)) for li in range(kps.shape[0]): if (kps[li, :] == -1).all(): kps[li][2] = 0.0 # weight = 0, ignore diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/models/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/backbones/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/models/backbones/__init__.py rename to 
modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/backbones/resnet.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/resnet.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/models/backbones/resnet.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/resnet.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/scrfd_head.py similarity index 99% rename from modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/scrfd_head.py index acc45670..77ec99cf 100755 --- a/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/scrfd_head.py @@ -103,6 +103,7 @@ class SCRFDHead(AnchorHead): scale_mode=1, dw_conv=False, use_kps=False, + num_kps=5, loss_kps=dict( type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.1), **kwargs): @@ -116,7 +117,7 @@ class SCRFDHead(AnchorHead): self.scale_mode = scale_mode self.use_dfl = True self.dw_conv = dw_conv - self.NK = 5 + self.NK = num_kps self.extra_flops = 0.0 if loss_dfl is None or not loss_dfl: self.use_dfl = False @@ -323,8 +324,8 @@ class SCRFDHead(AnchorHead): batch_size, -1, self.cls_out_channels).sigmoid() bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) - kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 10) - + kps_pred = kps_pred.permute(0, 2, 3, + 1).reshape(batch_size, -1, self.NK * 2) return cls_score, bbox_pred, kps_pred def forward_train(self, @@ -788,7 +789,7 @@ class SCRFDHead(AnchorHead): if self.use_dfl: kps_pred = self.integral(kps_pred) * stride[0] else: - kps_pred = kps_pred.reshape((-1, 10)) * stride[0] + kps_pred = kps_pred.reshape((-1, self.NK * 2)) * stride[0] nms_pre = cfg.get('nms_pre', -1) if nms_pre > 0 and scores.shape[0] > nms_pre: @@ -815,7 +816,7 @@ class SCRFDHead(AnchorHead): mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) if mlvl_kps is not None: scale_factor2 = torch.tensor( - [scale_factor[0], scale_factor[1]] * 5) + [scale_factor[0], scale_factor[1]] * self.NK) mlvl_kps /= scale_factor2.to(mlvl_kps.device) mlvl_scores = torch.cat(mlvl_scores) diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py similarity index 100% rename from modelscope/models/cv/face_detection/mmdet_patch/models/detectors/__init__.py rename to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/scrfd.py similarity index 50% rename from modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py rename 
to modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/scrfd.py index a5f5cac2..18b46be1 100755 --- a/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/scrfd.py @@ -54,7 +54,13 @@ class SCRFD(SingleStageDetector): gt_bboxes_ignore) return losses - def simple_test(self, img, img_metas, rescale=False): + def simple_test(self, + img, + img_metas, + rescale=False, + repeat_head=1, + output_kps_var=0, + output_results=1): """Test function without test time augmentation. Args: @@ -62,6 +68,9 @@ class SCRFD(SingleStageDetector): img_metas (list[dict]): List of image information. rescale (bool, optional): Whether to rescale the results. Defaults to False. + repeat_head (int): repeat inference times in head + output_kps_var (int): whether output kps var to calculate quality + output_results (int): 0: nothing 1: bbox 2: both bbox and kps Returns: list[list[np.ndarray]]: BBox results of each image and classes. @@ -69,40 +78,71 @@ class SCRFD(SingleStageDetector): corresponds to each class. """ x = self.extract_feat(img) - outs = self.bbox_head(x) - if torch.onnx.is_in_onnx_export(): - print('single_stage.py in-onnx-export') - print(outs.__class__) - cls_score, bbox_pred, kps_pred = outs - for c in cls_score: - print(c.shape) - for c in bbox_pred: - print(c.shape) - if self.bbox_head.use_kps: - for c in kps_pred: + assert repeat_head >= 1 + kps_out0 = [] + kps_out1 = [] + kps_out2 = [] + for i in range(repeat_head): + outs = self.bbox_head(x) + kps_out0 += [outs[2][0].detach().cpu().numpy()] + kps_out1 += [outs[2][1].detach().cpu().numpy()] + kps_out2 += [outs[2][2].detach().cpu().numpy()] + if output_kps_var: + var0 = np.var(np.vstack(kps_out0), axis=0).mean() + var1 = np.var(np.vstack(kps_out1), axis=0).mean() + var2 = np.var(np.vstack(kps_out2), axis=0).mean() + var = np.mean([var0, var1, var2]) + else: + var = None + + if output_results > 0: + if torch.onnx.is_in_onnx_export(): + print('single_stage.py in-onnx-export') + print(outs.__class__) + cls_score, bbox_pred, kps_pred = outs + for c in cls_score: + print(c.shape) + for c in bbox_pred: print(c.shape) - return (cls_score, bbox_pred, kps_pred) - else: - return (cls_score, bbox_pred) - bbox_list = self.bbox_head.get_bboxes( - *outs, img_metas, rescale=rescale) + if self.bbox_head.use_kps: + for c in kps_pred: + print(c.shape) + return (cls_score, bbox_pred, kps_pred) + else: + return (cls_score, bbox_pred) + bbox_list = self.bbox_head.get_bboxes( + *outs, img_metas, rescale=rescale) - # return kps if use_kps - if len(bbox_list[0]) == 2: - bbox_results = [ - bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) - for det_bboxes, det_labels in bbox_list - ] - elif len(bbox_list[0]) == 3: - bbox_results = [ - bbox2result( - det_bboxes, - det_labels, - self.bbox_head.num_classes, - kps=det_kps) - for det_bboxes, det_labels, det_kps in bbox_list - ] - return bbox_results + # return kps if use_kps + if len(bbox_list[0]) == 2: + bbox_results = [ + bbox2result(det_bboxes, det_labels, + self.bbox_head.num_classes) + for det_bboxes, det_labels in bbox_list + ] + elif len(bbox_list[0]) == 3: + if output_results == 2: + bbox_results = [ + bbox2result( + det_bboxes, + det_labels, + self.bbox_head.num_classes, + kps=det_kps, + num_kps=self.bbox_head.NK) + for det_bboxes, det_labels, det_kps in bbox_list + ] + elif output_results == 1: + bbox_results = [ + bbox2result(det_bboxes, det_labels, + self.bbox_head.num_classes) + for 
det_bboxes, det_labels, _ in bbox_list + ] + else: + bbox_results = None + if var is not None: + return bbox_results, var + else: + return bbox_results def feature_test(self, img): x = self.extract_feat(img) diff --git a/modelscope/models/cv/face_detection/scrfd/scrfd_detect.py b/modelscope/models/cv/face_detection/scrfd/scrfd_detect.py new file mode 100644 index 00000000..59611604 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/scrfd_detect.py @@ -0,0 +1,71 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from copy import deepcopy +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ScrfdDetect'] + + +@MODELS.register_module(Tasks.face_detection, module_name=Models.scrfd) +class ScrfdDetect(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the face detection model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + from mmcv import Config + from mmcv.parallel import MMDataParallel + from mmcv.runner import load_checkpoint + from mmdet.models import build_detector + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD + cfg = Config.fromfile(osp.join(model_dir, 'mmcv_scrfd.py')) + ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + cfg.model.test_cfg.score_thr = kwargs.get('score_thr', 0.3) + detector = build_detector(cfg.model) + logger.info(f'loading model from {ckpt_path}') + device = torch.device( + f'cuda:{0}' if torch.cuda.is_available() else 'cpu') + load_checkpoint(detector, ckpt_path, map_location=device) + detector = MMDataParallel(detector, device_ids=[0]) + detector.eval() + self.detector = detector + logger.info('load model done') + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = self.detector( + return_loss=False, + rescale=True, + img=[input['img'][0].unsqueeze(0)], + img_metas=[[dict(input['img_metas'][0].data)]], + output_results=2) + assert result is not None + result = result[0][0] + bboxes = result[:, :4].tolist() + kpss = result[:, 5:].tolist() + scores = result[:, 4].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: kpss + } + + def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: + return input diff --git a/modelscope/outputs.py b/modelscope/outputs.py index ab3ea54a..3001c03c 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -90,6 +90,25 @@ TASK_OUTPUTS = { Tasks.face_detection: [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], + # card detection result for single sample + # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # "keypoints": [ + # [x1, y1, 
x2, y2, x3, y3, x4, y4], + # [x1, y1, x2, y2, x3, y3, x4, y4], + # [x1, y1, x2, y2, x3, y3, x4, y4], + # [x1, y1, x2, y2, x3, y3, x4, y4], + # ], + # } + Tasks.card_detection: + [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], + # facial expression recognition result for single sample # { # "scores": [0.9, 0.1, 0.02, 0.02, 0.02, 0.02, 0.02], diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index bc9073bc..174d10b1 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -116,6 +116,10 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.hand_2d_keypoints: (Pipelines.hand_2d_keypoints, 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'), + Tasks.face_detection: (Pipelines.face_detection, + 'damo/cv_resnet_facedetection_scrfd10gkps'), + Tasks.card_detection: (Pipelines.card_detection, + 'damo/cv_resnet_carddetection_scrfd34gkps'), Tasks.face_detection: (Pipelines.face_detection, 'damo/cv_resnet101_face-detection_cvpr22papermogface'), diff --git a/modelscope/pipelines/cv/card_detection_pipeline.py b/modelscope/pipelines/cv/card_detection_pipeline.py new file mode 100644 index 00000000..00b18024 --- /dev/null +++ b/modelscope/pipelines/cv/card_detection_pipeline.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Pipelines +from modelscope.pipelines.builder import PIPELINES +from modelscope.pipelines.cv.face_detection_pipeline import \ + FaceDetectionPipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.card_detection, module_name=Pipelines.card_detection) +class CardDetectionPipeline(FaceDetectionPipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a card detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + thr = 0.45 # card/face detect use different threshold + super().__init__(model=model, score_thr=thr, **kwargs) diff --git a/modelscope/pipelines/cv/face_detection_pipeline.py b/modelscope/pipelines/cv/face_detection_pipeline.py index eff5b70f..608567a4 100644 --- a/modelscope/pipelines/cv/face_detection_pipeline.py +++ b/modelscope/pipelines/cv/face_detection_pipeline.py @@ -8,6 +8,7 @@ import PIL import torch from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_detection import ScrfdDetect from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES @@ -29,27 +30,8 @@ class FaceDetectionPipeline(Pipeline): model: model id on modelscope hub. 
""" super().__init__(model=model, **kwargs) - from mmcv import Config - from mmcv.parallel import MMDataParallel - from mmcv.runner import load_checkpoint - from mmdet.models import build_detector - from modelscope.models.cv.face_detection.mmdet_patch.datasets import RetinaFaceDataset - from modelscope.models.cv.face_detection.mmdet_patch.datasets.pipelines import RandomSquareCrop - from modelscope.models.cv.face_detection.mmdet_patch.models.backbones import ResNetV1e - from modelscope.models.cv.face_detection.mmdet_patch.models.dense_heads import SCRFDHead - from modelscope.models.cv.face_detection.mmdet_patch.models.detectors import SCRFD - cfg = Config.fromfile(osp.join(model, 'mmcv_scrfd_10g_bnkps.py')) - detector = build_detector( - cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) - ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE) - logger.info(f'loading model from {ckpt_path}') - device = torch.device( - f'cuda:{0}' if torch.cuda.is_available() else 'cpu') - load_checkpoint(detector, ckpt_path, map_location=device) - detector = MMDataParallel(detector, device_ids=[0]) - detector.eval() + detector = ScrfdDetect(model_dir=model, **kwargs) self.detector = detector - logger.info('load model done') def preprocess(self, input: Input) -> Dict[str, Any]: img = LoadImage.convert_to_ndarray(input) @@ -85,22 +67,7 @@ class FaceDetectionPipeline(Pipeline): return result def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - - result = self.detector( - return_loss=False, - rescale=True, - img=[input['img'][0].unsqueeze(0)], - img_metas=[[dict(input['img_metas'][0].data)]]) - assert result is not None - result = result[0][0] - bboxes = result[:, :4].tolist() - kpss = result[:, 5:].tolist() - scores = result[:, 4].tolist() - return { - OutputKeys.SCORES: scores, - OutputKeys.BOXES: bboxes, - OutputKeys.KEYPOINTS: kpss - } + return self.detector(input) def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs diff --git a/modelscope/pipelines/cv/face_recognition_pipeline.py b/modelscope/pipelines/cv/face_recognition_pipeline.py index 873e4a1f..abae69d4 100644 --- a/modelscope/pipelines/cv/face_recognition_pipeline.py +++ b/modelscope/pipelines/cv/face_recognition_pipeline.py @@ -49,7 +49,7 @@ class FaceRecognitionPipeline(Pipeline): # face detect pipeline det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' self.face_detection = pipeline( - Tasks.face_detection, model=det_model_id) + Tasks.face_detection, model=det_model_id, model_revision='v2') def _choose_face(self, det_result, diff --git a/modelscope/trainers/cv/card_detection_scrfd_trainer.py b/modelscope/trainers/cv/card_detection_scrfd_trainer.py new file mode 100644 index 00000000..e1f81bcf --- /dev/null +++ b/modelscope/trainers/cv/card_detection_scrfd_trainer.py @@ -0,0 +1,18 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.cv.face_detection_scrfd_trainer import \ + FaceDetectionScrfdTrainer + + +@TRAINERS.register_module(module_name=Trainers.card_detection_scrfd) +class CardDetectionScrfdTrainer(FaceDetectionScrfdTrainer): + + def __init__(self, cfg_file: str, *args, **kwargs): + """ High-level finetune api for SCRFD. + + Args: + cfg_file: Path to configuration file. 
+ """ + # card/face dataset use different img folder names + super().__init__(cfg_file, imgdir_name='', **kwargs) diff --git a/modelscope/trainers/cv/face_detection_scrfd_trainer.py b/modelscope/trainers/cv/face_detection_scrfd_trainer.py new file mode 100644 index 00000000..9cfae7dd --- /dev/null +++ b/modelscope/trainers/cv/face_detection_scrfd_trainer.py @@ -0,0 +1,154 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import copy +import os +import os.path as osp +import time +from typing import Callable, Dict, Optional + +from modelscope.metainfo import Trainers +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS + + +@TRAINERS.register_module(module_name=Trainers.face_detection_scrfd) +class FaceDetectionScrfdTrainer(BaseTrainer): + + def __init__(self, + cfg_file: str, + cfg_modify_fn: Optional[Callable] = None, + *args, + **kwargs): + """ High-level finetune api for SCRFD. + + Args: + cfg_file: Path to configuration file. + cfg_modify_fn: An input fn which is used to modify the cfg read out of the file. + """ + import mmcv + from mmcv.runner import get_dist_info, init_dist + from mmcv.utils import get_git_hash + from mmdet.utils import collect_env, get_root_logger + from mmdet.apis import set_random_seed + from mmdet.models import build_detector + from mmdet.datasets import build_dataset + from mmdet import __version__ + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import DefaultFormatBundleV2 + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import LoadAnnotationsV2 + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RotateV2 + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD + super().__init__(cfg_file) + cfg = self.cfg + if 'work_dir' in kwargs: + cfg.work_dir = kwargs['work_dir'] + else: + # use config filename as default work_dir if work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(cfg_file))[0]) + mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + + if 'resume_from' in kwargs: # pretrain model for finetune + cfg.resume_from = kwargs['resume_from'] + cfg.device = 'cuda' + if 'gpu_ids' in kwargs: + cfg.gpu_ids = kwargs['gpu_ids'] + else: + cfg.gpu_ids = range(1) + labelfile_name = kwargs.pop('labelfile_name', 'labelv2.txt') + imgdir_name = kwargs.pop('imgdir_name', 'images/') + if 'train_root' in kwargs: + cfg.data.train.ann_file = kwargs['train_root'] + labelfile_name + cfg.data.train.img_prefix = kwargs['train_root'] + imgdir_name + if 'val_root' in kwargs: + cfg.data.val.ann_file = kwargs['val_root'] + labelfile_name + cfg.data.val.img_prefix = kwargs['val_root'] + imgdir_name + if 'total_epochs' in kwargs: + cfg.total_epochs = kwargs['total_epochs'] + if cfg_modify_fn is not None: + cfg = cfg_modify_fn(cfg) + if 'launcher' in kwargs: + distributed = True + init_dist(kwargs['launcher'], **cfg.dist_params) + # re-set gpu_ids with distributed training mode + _, world_size = get_dist_info() + cfg.gpu_ids = range(world_size) + else: + distributed = False + # no_validate=True 
will not evaluate checkpoint during training + cfg.no_validate = kwargs.get('no_validate', False) + # init the logger before other steps + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + log_file = osp.join(cfg.work_dir, f'{timestamp}.log') + logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) + dash_line = '-' * 60 + '\n' + logger.info('Environment info:\n' + dash_line + env_info + '\n' + + dash_line) + meta['env_info'] = env_info + meta['config'] = cfg.pretty_text + # log some basic info + logger.info(f'Distributed training: {distributed}') + logger.info(f'Config:\n{cfg.pretty_text}') + + # set random seeds + if 'seed' in kwargs: + cfg.seed = kwargs['seed'] + _deterministic = kwargs.get('deterministic', False) + logger.info(f'Set random seed to {kwargs["seed"]}, ' + f'deterministic: {_deterministic}') + set_random_seed(kwargs['seed'], deterministic=_deterministic) + else: + cfg.seed = None + meta['seed'] = cfg.seed + meta['exp_name'] = osp.basename(cfg_file) + + model = build_detector(cfg.model) + model.init_weights() + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + val_dataset = copy.deepcopy(cfg.data.val) + val_dataset.pipeline = cfg.data.train.pipeline + datasets.append(build_dataset(val_dataset)) + if cfg.checkpoint_config is not None: + # save mmdet version, config file content and class names in + # checkpoints as meta data + cfg.checkpoint_config.meta = dict( + mmdet_version=__version__ + get_git_hash()[:7], + CLASSES=datasets[0].CLASSES) + # add an attribute for visualization convenience + model.CLASSES = datasets[0].CLASSES + + self.cfg = cfg + self.datasets = datasets + self.model = model + self.distributed = distributed + self.timestamp = timestamp + self.meta = meta + self.logger = logger + + def train(self, *args, **kwargs): + from mmdet.apis import train_detector + train_detector( + self.model, + self.datasets, + self.cfg, + distributed=self.distributed, + validate=(not self.cfg.no_validate), + timestamp=self.timestamp, + meta=self.meta) + + def evaluate(self, + checkpoint_path: str = None, + *args, + **kwargs) -> Dict[str, float]: + cfg = self.cfg.evaluation + self.logger.info(f'eval cfg {cfg}') + self.logger.info(f'checkpoint_path {checkpoint_path}') diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 4fa3d766..5f0532ce 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -19,6 +19,7 @@ class CVTasks(object): # human face body related animal_recognition = 'animal-recognition' face_detection = 'face-detection' + card_detection = 'card-detection' face_recognition = 'face-recognition' facial_expression_recognition = 'facial-expression-recognition' face_2d_keypoints = 'face-2d-keypoints' diff --git a/modelscope/utils/cv/image_utils.py b/modelscope/utils/cv/image_utils.py index 06a9bbaa..2d420892 100644 --- a/modelscope/utils/cv/image_utils.py +++ b/modelscope/utils/cv/image_utils.py @@ -154,6 +154,54 @@ def draw_face_detection_result(img_path, detection_result): return img +def draw_card_detection_result(img_path, detection_result): + + def warp_img(src_img, kps, ratio): + short_size = 500 + if ratio > 1: + obj_h = short_size + obj_w = int(obj_h * ratio) + else: + obj_w = short_size + obj_h = int(obj_w / ratio) + 
input_pts = np.float32([kps[0], kps[1], kps[2], kps[3]]) + output_pts = np.float32([[0, obj_h - 1], [0, 0], [obj_w - 1, 0], + [obj_w - 1, obj_h - 1]]) + M = cv2.getPerspectiveTransform(input_pts, output_pts) + obj_img = cv2.warpPerspective(src_img, M, (obj_w, obj_h)) + return obj_img + + bboxes = np.array(detection_result[OutputKeys.BOXES]) + kpss = np.array(detection_result[OutputKeys.KEYPOINTS]) + scores = np.array(detection_result[OutputKeys.SCORES]) + img_list = [] + ver_col = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 255)] + img = cv2.imread(img_path) + img_list += [img] + assert img is not None, f"Can't read img: {img_path}" + for i in range(len(scores)): + bbox = bboxes[i].astype(np.int32) + kps = kpss[i].reshape(-1, 2).astype(np.int32) + _w = (kps[0][0] - kps[3][0])**2 + (kps[0][1] - kps[3][1])**2 + _h = (kps[0][0] - kps[1][0])**2 + (kps[0][1] - kps[1][1])**2 + ratio = 1.59 if _w >= _h else 1 / 1.59 + card_img = warp_img(img, kps, ratio) + img_list += [card_img] + score = scores[i] + x1, y1, x2, y2 = bbox + cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 4) + for k, kp in enumerate(kps): + cv2.circle(img, tuple(kp), 1, color=ver_col[k], thickness=10) + cv2.putText( + img, + f'{score:.2f}', (x1, y2), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + return img_list + + def created_boxed_image(image_in, box): image = load_image(image_in) img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) diff --git a/tests/pipelines/test_card_detection.py b/tests/pipelines/test_card_detection.py new file mode 100644 index 00000000..d913f494 --- /dev/null +++ b/tests/pipelines/test_card_detection.py @@ -0,0 +1,66 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_card_detection_result +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class CardDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.card_detection + self.model_id = 'damo/cv_resnet_carddetection_scrfd34gkps' + + def show_result(self, img_path, detection_result): + img_list = draw_card_detection_result(img_path, detection_result) + for i, img in enumerate(img_list): + if i == 0: + cv2.imwrite('result.jpg', img_list[0]) + print( + f'Found {len(img_list)-1} cards, output written to {osp.abspath("result.jpg")}' + ) + else: + cv2.imwrite(f'card_{i}.jpg', img_list[i]) + save_path = osp.abspath(f'card_{i}.jpg') + print(f'detect card_{i}: {save_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_dataset(self): + input_location = ['data/test/images/card_detection.jpg'] + + dataset = MsDataset.load(input_location, target='image') + card_detection = pipeline(Tasks.card_detection, model=self.model_id) + # note that for dataset output, the inference-output is a Generator that can be iterated. 
+ result = card_detection(dataset) + result = next(result) + self.show_result(input_location[0], result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + card_detection = pipeline(Tasks.card_detection, model=self.model_id) + img_path = 'data/test/images/card_detection.jpg' + + result = card_detection(img_path) + self.show_result(img_path, result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + card_detection = pipeline(Tasks.card_detection) + img_path = 'data/test/images/card_detection.jpg' + result = card_detection(img_path) + self.show_result(img_path, result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_face_detection.py b/tests/pipelines/test_face_detection.py index f89e9a94..31ae403e 100644 --- a/tests/pipelines/test_face_detection.py +++ b/tests/pipelines/test_face_detection.py @@ -25,10 +25,11 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_dataset(self): - input_location = ['data/test/images/face_detection.png'] + input_location = ['data/test/images/face_detection2.jpeg'] dataset = MsDataset.load(input_location, target='image') - face_detection = pipeline(Tasks.face_detection, model=self.model_id) + face_detection = pipeline( + Tasks.face_detection, model=self.model_id, model_revision='v2') # note that for dataset output, the inference-output is a Generator that can be iterated. result = face_detection(dataset) result = next(result) @@ -36,8 +37,9 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_modelhub(self): - face_detection = pipeline(Tasks.face_detection, model=self.model_id) - img_path = 'data/test/images/face_detection.png' + face_detection = pipeline( + Tasks.face_detection, model=self.model_id, model_revision='v2') + img_path = 'data/test/images/face_detection2.jpeg' result = face_detection(img_path) self.show_result(img_path, result) @@ -45,7 +47,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_modelhub_default_model(self): face_detection = pipeline(Tasks.face_detection) - img_path = 'data/test/images/face_detection.png' + img_path = 'data/test/images/face_detection2.jpeg' result = face_detection(img_path) self.show_result(img_path, result) diff --git a/tests/trainers/test_card_detection_scrfd_trainer.py b/tests/trainers/test_card_detection_scrfd_trainer.py new file mode 100644 index 00000000..af87000b --- /dev/null +++ b/tests/trainers/test_card_detection_scrfd_trainer.py @@ -0,0 +1,151 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import DistributedTestCase, test_level + + +def _setup(): + model_id = 'damo/cv_resnet_carddetection_scrfd34gkps' + # mini dataset only for unit test, remove '_mini' for full dataset. + ms_ds_syncards = MsDataset.load( + 'SyntheticCards_mini', namespace='shaoxuan') + + data_path = ms_ds_syncards.config_kwargs['split_config'] + train_dir = data_path['train'] + val_dir = data_path['validation'] + train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/' + val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' + max_epochs = 1 # run epochs in unit test + + cache_path = snapshot_download(model_id) + + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + return train_root, val_root, max_epochs, cache_path, tmp_dir + + +def train_func(**kwargs): + trainer = build_trainer( + name=Trainers.card_detection_scrfd, default_args=kwargs) + trainer.train() + + +class TestCardDetectionScrfdTrainerSingleGPU(unittest.TestCase): + + def setUp(self): + print(('SingleGPU Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + def _cfg_modify_fn(self, cfg): + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 # batch size + return cfg + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_from_scratch(self): + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.card_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.card_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +@unittest.skipIf(not torch.cuda.is_available() + or torch.cuda.device_count() <= 1, 'distributed unittest') +class TestCardDetectionScrfdTrainerMultiGpus(DistributedTestCase): + + def setUp(self): + print(('MultiGPUs 
Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + cfg_file_path = os.path.join(self.cache_path, 'mmcv_scrfd.py') + cfg = Config.from_file(cfg_file_path) + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 + cfg.dump(cfg_file_path) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_multi_gpus_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + launcher='pytorch') + self.start(train_func, num_gpus=2, **kwargs) + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_face_detection_scrfd_trainer.py b/tests/trainers/test_face_detection_scrfd_trainer.py new file mode 100644 index 00000000..eb9440ef --- /dev/null +++ b/tests/trainers/test_face_detection_scrfd_trainer.py @@ -0,0 +1,150 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import DistributedTestCase, test_level + + +def _setup(): + model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' + # mini dataset only for unit test, remove '_mini' for full dataset. 
+ ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan') + + data_path = ms_ds_widerface.config_kwargs['split_config'] + train_dir = data_path['train'] + val_dir = data_path['validation'] + train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/' + val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' + max_epochs = 1 # run epochs in unit test + + cache_path = snapshot_download(model_id, revision='v2') + + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + return train_root, val_root, max_epochs, cache_path, tmp_dir + + +def train_func(**kwargs): + trainer = build_trainer( + name=Trainers.face_detection_scrfd, default_args=kwargs) + trainer.train() + + +class TestFaceDetectionScrfdTrainerSingleGPU(unittest.TestCase): + + def setUp(self): + print(('SingleGPU Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + def _cfg_modify_fn(self, cfg): + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 # batch size + return cfg + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_from_scratch(self): + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.face_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.face_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +@unittest.skipIf(not torch.cuda.is_available() + or torch.cuda.device_count() <= 1, 'distributed unittest') +class TestFaceDetectionScrfdTrainerMultiGpus(DistributedTestCase): + + def setUp(self): + print(('MultiGPUs Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + cfg_file_path = os.path.join(self.cache_path, 'mmcv_scrfd.py') + cfg = Config.from_file(cfg_file_path) + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 + cfg.dump(cfg_file_path) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current 
test level') + def test_multi_gpus_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + launcher='pytorch') + self.start(train_func, num_gpus=2, **kwargs) + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main()
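
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the card-detection pipeline added above is expected to be driven, mirroring tests/pipelines/test_card_detection.py. The task name, default model id, output keys and helper function all come from this patch; the only assumption is that 'damo/cv_resnet_carddetection_scrfd34gkps' is available on the hub in a compatible revision.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_card_detection_result

# Build the pipeline registered above; this model id is the default wired into
# modelscope/pipelines/builder.py by this patch.
card_detection = pipeline(
    Tasks.card_detection, model='damo/cv_resnet_carddetection_scrfd34gkps')

result = card_detection('data/test/images/card_detection.jpg')

# Output schema follows TASK_OUTPUTS in modelscope/outputs.py: per-card scores,
# boxes as [x1, y1, x2, y2] and keypoints as four corner points per card.
print(result[OutputKeys.SCORES])
print(result[OutputKeys.BOXES])
print(result[OutputKeys.KEYPOINTS])

# draw_card_detection_result returns the annotated input image followed by one
# perspective-rectified crop per detected card (via warp_img above).
imgs = draw_card_detection_result('data/test/images/card_detection.jpg', result)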