diff --git a/.gitignore b/.gitignore index 8a0db7fa..de086eea 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,7 @@ source.sh tensorboard.sh .DS_Store replace.sh +result.png # Pytorch *.pth diff --git a/data/test/images/face_detection.png b/data/test/images/face_detection.png new file mode 100644 index 00000000..3b572877 --- /dev/null +++ b/data/test/images/face_detection.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa3963d1c54e6d3d46e9a59872a99ed955d4050092f5cfe5f591e03d740b7042 +size 653006 diff --git a/data/test/images/face_recognition_1.png b/data/test/images/face_recognition_1.png new file mode 100644 index 00000000..eefe2138 --- /dev/null +++ b/data/test/images/face_recognition_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48e541daeb2692907efef47018e41abb5ae6bcd88eb5ff58290d7fe5dc8b2a13 +size 462584 diff --git a/data/test/images/face_recognition_2.png b/data/test/images/face_recognition_2.png new file mode 100644 index 00000000..1292d8cb --- /dev/null +++ b/data/test/images/face_recognition_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9565b43d9f65361b9bad6553b327c2c6f02fd063a4c8dc0f461e88ea461989d +size 357166 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 3e31f422..5efc724c 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -10,6 +10,7 @@ class Models(object): Model name should only contain model info but not task info. """ # vision models + scrfd = 'scrfd' classification_model = 'ClassificationModel' nafnet = 'nafnet' csrnet = 'csrnet' @@ -67,6 +68,7 @@ class Pipelines(object): action_recognition = 'TAdaConv_action-recognition' animal_recognation = 'resnet101-animal_recog' cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' + face_detection = 'resnet-face-detection-scrfd10gkps' live_category = 'live-category' general_image_classification = 'vit-base_image-classification_ImageNet-labels' daily_image_classification = 'vit-base_image-classification_Dailylife-labels' @@ -76,6 +78,7 @@ class Pipelines(object): image_super_resolution = 'rrdb-image-super-resolution' face_image_generation = 'gan-face-image-generation' style_transfer = 'AAMS-style-transfer' + face_recognition = 'ir101-face-recognition-cfglint' image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' image2image_translation = 'image-to-image-translation' live_category = 'live-category' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 88177746..076e1f4e 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from . 
import (action_recognition, animal_recognition, cartoon, - cmdssl_video_embedding, face_generation, image_classification, - image_color_enhance, image_colorization, image_denoise, - image_instance_segmentation, super_resolution, virual_tryon) + cmdssl_video_embedding, face_detection, face_generation, + image_classification, image_color_enhance, image_colorization, + image_denoise, image_instance_segmentation, + image_to_image_translation, super_resolution, virual_tryon) diff --git a/modelscope/models/cv/face_detection/__init__.py b/modelscope/models/cv/face_detection/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_detection/mmdet_patch/__init__.py b/modelscope/models/cv/face_detection/mmdet_patch/__init__.py new file mode 100755 index 00000000..921bdc08 --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/__init__.py @@ -0,0 +1,5 @@ +""" +mmdet_patch is based on +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet, +all duplicate functions from official mmdetection are removed. +""" diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/__init__.py b/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/__init__.py new file mode 100644 index 00000000..8375649c --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/__init__.py @@ -0,0 +1,3 @@ +from .transforms import bbox2result, distance2kps, kps2distance + +__all__ = ['bbox2result', 'distance2kps', 'kps2distance'] diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py b/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py new file mode 100755 index 00000000..26278837 --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py @@ -0,0 +1,86 @@ +""" +based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/bbox/transforms.py +""" +import numpy as np +import torch + + +def bbox2result(bboxes, labels, num_classes, kps=None): + """Convert detection results to a list of numpy arrays. + + Args: + bboxes (torch.Tensor | np.ndarray): shape (n, 5) + labels (torch.Tensor | np.ndarray): shape (n, ) + num_classes (int): class number, including background class + + Returns: + list(ndarray): bbox results of each class + """ + bbox_len = 5 if kps is None else 5 + 10 # if has kps, add 10 kps into bbox + if bboxes.shape[0] == 0: + return [ + np.zeros((0, bbox_len), dtype=np.float32) + for i in range(num_classes) + ] + else: + if isinstance(bboxes, torch.Tensor): + bboxes = bboxes.detach().cpu().numpy() + labels = labels.detach().cpu().numpy() + if kps is None: + return [bboxes[labels == i, :] for i in range(num_classes)] + else: # with kps + if isinstance(kps, torch.Tensor): + kps = kps.detach().cpu().numpy() + return [ + np.hstack([bboxes[labels == i, :], kps[labels == i, :]]) + for i in range(num_classes) + ] + + +def distance2kps(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded kps. 
+ """ + preds = [] + for i in range(0, distance.shape[1], 2): + px = points[:, i % 2] + distance[:, i] + py = points[:, i % 2 + 1] + distance[:, i + 1] + if max_shape is not None: + px = px.clamp(min=0, max=max_shape[1]) + py = py.clamp(min=0, max=max_shape[0]) + preds.append(px) + preds.append(py) + return torch.stack(preds, -1) + + +def kps2distance(points, kps, max_dis=None, eps=0.1): + """Decode bounding box based on distances. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + kps (Tensor): Shape (n, K), "xyxy" format + max_dis (float): Upper bound of the distance. + eps (float): a small value to ensure target < max_dis, instead <= + + Returns: + Tensor: Decoded distances. + """ + + preds = [] + for i in range(0, kps.shape[1], 2): + px = kps[:, i] - points[:, i % 2] + py = kps[:, i + 1] - points[:, i % 2 + 1] + if max_dis is not None: + px = px.clamp(min=0, max=max_dis - eps) + py = py.clamp(min=0, max=max_dis - eps) + preds.append(px) + preds.append(py) + return torch.stack(preds, -1) diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/__init__.py b/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/__init__.py new file mode 100755 index 00000000..8cd31348 --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/__init__.py @@ -0,0 +1,3 @@ +from .bbox_nms import multiclass_nms + +__all__ = ['multiclass_nms'] diff --git a/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/bbox_nms.py b/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/bbox_nms.py new file mode 100644 index 00000000..efe8813f --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/bbox_nms.py @@ -0,0 +1,85 @@ +""" +based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/post_processing/bbox_nms.py +""" +import torch + + +def multiclass_nms(multi_bboxes, + multi_scores, + score_thr, + nms_cfg, + max_num=-1, + score_factors=None, + return_inds=False, + multi_kps=None): + """NMS for multi-class bboxes. + + Args: + multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_scores (Tensor): shape (n, #class), where the last column + contains scores of the background class, but this will be ignored. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + nms_thr (float): NMS IoU threshold + max_num (int, optional): if there are more than max_num bboxes after + NMS, only top max_num will be kept. Default to -1. + score_factors (Tensor, optional): The factors multiplied to scores + before applying NMS. Default to None. + return_inds (bool, optional): Whether return the indices of kept + bboxes. Default to False. + + Returns: + tuple: (bboxes, labels, indices (optional)), tensors of shape (k, 5), + (k), and (k). Labels are 0-based. 
+ """ + num_classes = multi_scores.size(1) - 1 + # exclude background category + kps = None + if multi_bboxes.shape[1] > 4: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) + if multi_kps is not None: + kps = multi_kps.view(multi_scores.size(0), -1, 10) + else: + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, 4) + if multi_kps is not None: + kps = multi_kps[:, None].expand( + multi_scores.size(0), num_classes, 10) + + scores = multi_scores[:, :-1] + if score_factors is not None: + scores = scores * score_factors[:, None] + + labels = torch.arange(num_classes, dtype=torch.long) + labels = labels.view(1, -1).expand_as(scores) + + bboxes = bboxes.reshape(-1, 4) + if kps is not None: + kps = kps.reshape(-1, 10) + scores = scores.reshape(-1) + labels = labels.reshape(-1) + + # remove low scoring boxes + valid_mask = scores > score_thr + inds = valid_mask.nonzero(as_tuple=False).squeeze(1) + bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] + if kps is not None: + kps = kps[inds] + if inds.numel() == 0: + if torch.onnx.is_in_onnx_export(): + raise RuntimeError('[ONNX Error] Can not record NMS ' + 'as it has not been executed this time') + return bboxes, labels, kps + + # TODO: add size check before feed into batched_nms + from mmcv.ops.nms import batched_nms + dets, keep = batched_nms(bboxes, scores, labels, nms_cfg) + + if max_num > 0: + dets = dets[:max_num] + keep = keep[:max_num] + + if return_inds: + return dets, labels[keep], kps[keep], keep + else: + return dets, labels[keep], kps[keep] diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/__init__.py b/modelscope/models/cv/face_detection/mmdet_patch/datasets/__init__.py new file mode 100644 index 00000000..07a45208 --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/datasets/__init__.py @@ -0,0 +1,3 @@ +from .retinaface import RetinaFaceDataset + +__all__ = ['RetinaFaceDataset'] diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/__init__.py b/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/__init__.py new file mode 100755 index 00000000..979212a3 --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/__init__.py @@ -0,0 +1,3 @@ +from .transforms import RandomSquareCrop + +__all__ = ['RandomSquareCrop'] diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py b/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py new file mode 100755 index 00000000..3048cefa --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py @@ -0,0 +1,188 @@ +""" +based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py +""" +import numpy as np +from mmdet.datasets.builder import PIPELINES +from numpy import random + + +@PIPELINES.register_module() +class RandomSquareCrop(object): + """Random crop the image & bboxes, the cropped patches have minimum IoU + requirement with original image & bboxes, the IoU threshold is randomly + selected from min_ious. + + Args: + min_ious (tuple): minimum IoU threshold for all intersections with + bounding boxes + min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, + where a >= min_crop_size). + + Note: + The keys for bboxes, labels and masks should be paired. 
That is, \ + `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ + `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. + """ + + def __init__(self, + crop_ratio_range=None, + crop_choice=None, + bbox_clip_border=True): + + self.crop_ratio_range = crop_ratio_range + self.crop_choice = crop_choice + self.bbox_clip_border = bbox_clip_border + + assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) + if self.crop_ratio_range is not None: + self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range + + self.bbox2label = { + 'gt_bboxes': 'gt_labels', + 'gt_bboxes_ignore': 'gt_labels_ignore' + } + self.bbox2mask = { + 'gt_bboxes': 'gt_masks', + 'gt_bboxes_ignore': 'gt_masks_ignore' + } + + def __call__(self, results): + """Call function to crop images and bounding boxes with minimum IoU + constraint. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images and bounding boxes cropped, \ + 'img_shape' key is updated. + """ + + if 'img_fields' in results: + assert results['img_fields'] == ['img'], \ + 'Only single img_fields is allowed' + img = results['img'] + assert 'bbox_fields' in results + assert 'gt_bboxes' in results + boxes = results['gt_bboxes'] + h, w, c = img.shape + scale_retry = 0 + if self.crop_ratio_range is not None: + max_scale = self.crop_ratio_max + else: + max_scale = np.amax(self.crop_choice) + while True: + scale_retry += 1 + + if scale_retry == 1 or max_scale > 1.0: + if self.crop_ratio_range is not None: + scale = np.random.uniform(self.crop_ratio_min, + self.crop_ratio_max) + elif self.crop_choice is not None: + scale = np.random.choice(self.crop_choice) + else: + scale = scale * 1.2 + + for i in range(250): + short_side = min(w, h) + cw = int(scale * short_side) + ch = cw + + # TODO +1 + if w == cw: + left = 0 + elif w > cw: + left = random.randint(0, w - cw) + else: + left = random.randint(w - cw, 0) + if h == ch: + top = 0 + elif h > ch: + top = random.randint(0, h - ch) + else: + top = random.randint(h - ch, 0) + + patch = np.array( + (int(left), int(top), int(left + cw), int(top + ch)), + dtype=np.int) + + # center of boxes should inside the crop img + # only adjust boxes and instance masks when the gt is not empty + # adjust boxes + def is_center_of_bboxes_in_patch(boxes, patch): + # TODO >= + center = (boxes[:, :2] + boxes[:, 2:]) / 2 + mask = \ + ((center[:, 0] > patch[0]) + * (center[:, 1] > patch[1]) + * (center[:, 0] < patch[2]) + * (center[:, 1] < patch[3])) + return mask + + mask = is_center_of_bboxes_in_patch(boxes, patch) + if not mask.any(): + continue + for key in results.get('bbox_fields', []): + boxes = results[key].copy() + mask = is_center_of_bboxes_in_patch(boxes, patch) + boxes = boxes[mask] + if self.bbox_clip_border: + boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) + boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) + boxes -= np.tile(patch[:2], 2) + + results[key] = boxes + # labels + label_key = self.bbox2label.get(key) + if label_key in results: + results[label_key] = results[label_key][mask] + + # keypoints field + if key == 'gt_bboxes': + for kps_key in results.get('keypoints_fields', []): + keypointss = results[kps_key].copy() + keypointss = keypointss[mask, :, :] + if self.bbox_clip_border: + keypointss[:, :, : + 2] = keypointss[:, :, :2].clip( + max=patch[2:]) + keypointss[:, :, : + 2] = keypointss[:, :, :2].clip( + min=patch[:2]) + keypointss[:, :, 0] -= patch[0] + keypointss[:, :, 1] -= patch[1] + results[kps_key] = keypointss + + # mask fields + 
mask_key = self.bbox2mask.get(key) + if mask_key in results: + results[mask_key] = results[mask_key][mask.nonzero() + [0]].crop(patch) + + # adjust the img no matter whether the gt is empty before crop + rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 + patch_from = patch.copy() + patch_from[0] = max(0, patch_from[0]) + patch_from[1] = max(0, patch_from[1]) + patch_from[2] = min(img.shape[1], patch_from[2]) + patch_from[3] = min(img.shape[0], patch_from[3]) + patch_to = patch.copy() + patch_to[0] = max(0, patch_to[0] * -1) + patch_to[1] = max(0, patch_to[1] * -1) + patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) + patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) + rimg[patch_to[1]:patch_to[3], + patch_to[0]:patch_to[2], :] = img[ + patch_from[1]:patch_from[3], + patch_from[0]:patch_from[2], :] + img = rimg + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_ious={self.min_iou}, ' + repr_str += f'crop_size={self.crop_size})' + return repr_str diff --git a/modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py b/modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py new file mode 100755 index 00000000..bf20764b --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py @@ -0,0 +1,151 @@ +""" +based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/retinaface.py +""" +import numpy as np +from mmdet.datasets.builder import DATASETS +from mmdet.datasets.custom import CustomDataset + + +@DATASETS.register_module() +class RetinaFaceDataset(CustomDataset): + + CLASSES = ('FG', ) + + def __init__(self, min_size=None, **kwargs): + self.NK = 5 + self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} + self.min_size = min_size + self.gt_path = kwargs.get('gt_path') + super(RetinaFaceDataset, self).__init__(**kwargs) + + def _parse_ann_line(self, line): + values = [float(x) for x in line.strip().split()] + bbox = np.array(values[0:4], dtype=np.float32) + kps = np.zeros((self.NK, 3), dtype=np.float32) + ignore = False + if self.min_size is not None: + assert not self.test_mode + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + if w < self.min_size or h < self.min_size: + ignore = True + if len(values) > 4: + if len(values) > 5: + kps = np.array( + values[4:19], dtype=np.float32).reshape((self.NK, 3)) + for li in range(kps.shape[0]): + if (kps[li, :] == -1).all(): + kps[li][2] = 0.0 # weight = 0, ignore + else: + assert kps[li][2] >= 0 + kps[li][2] = 1.0 # weight + else: # len(values)==5 + if not ignore: + ignore = (values[4] == 1) + else: + assert self.test_mode + + return dict(bbox=bbox, kps=kps, ignore=ignore, cat='FG') + + def load_annotations(self, ann_file): + """Load annotation from COCO style annotation file. + + Args: + ann_file (str): Path of annotation file. + 20220711@tyx: ann_file is list of img paths is supported + + Returns: + list[dict]: Annotation info from COCO api. 
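For reference, the ground-truth layout RetinaFaceDataset reads is the insightface "retinaface"-style label file: a '#'-prefixed header per image (name, width, height) followed by one object per line, x1 y1 x2 y2 plus five landmarks as x y visibility triplets. A rough sketch assuming mmdet 2.x CustomDataset behaviour; the file name and all coordinates below are invented:

from modelscope.models.cv.face_detection.mmdet_patch.datasets import RetinaFaceDataset

with open('/tmp/label.txt', 'w') as f:
    f.write('# 0_Parade_marchingband_1_849.jpg 1024 678\n'
            '449.0 330.0 571.0 479.0 '
            '488.9 373.6 0.0 542.1 376.5 0.0 515.0 412.9 0.0 '
            '492.7 436.5 0.0 538.5 439.0 0.0\n')

dataset = RetinaFaceDataset(ann_file='/tmp/label.txt', pipeline=[])
ann = dataset.get_ann_info(0)
print(ann['bboxes'].shape, ann['keypointss'].shape)   # (1, 4) (1, 5, 3)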
+ """ + if isinstance(ann_file, list): + data_infos = [] + for line in ann_file: + name = line + objs = [0, 0, 0, 0] + data_infos.append( + dict(filename=name, width=0, height=0, objs=objs)) + else: + name = None + bbox_map = {} + for line in open(ann_file, 'r'): + line = line.strip() + if line.startswith('#'): + value = line[1:].strip().split() + name = value[0] + width = int(value[1]) + height = int(value[2]) + + bbox_map[name] = dict(width=width, height=height, objs=[]) + continue + assert name is not None + assert name in bbox_map + bbox_map[name]['objs'].append(line) + print('origin image size', len(bbox_map)) + data_infos = [] + for name in bbox_map: + item = bbox_map[name] + width = item['width'] + height = item['height'] + vals = item['objs'] + objs = [] + for line in vals: + data = self._parse_ann_line(line) + if data is None: + continue + objs.append(data) # data is (bbox, kps, cat) + if len(objs) == 0 and not self.test_mode: + continue + data_infos.append( + dict(filename=name, width=width, height=height, objs=objs)) + return data_infos + + def get_ann_info(self, idx): + """Get COCO annotation by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + data_info = self.data_infos[idx] + + bboxes = [] + keypointss = [] + labels = [] + bboxes_ignore = [] + labels_ignore = [] + for obj in data_info['objs']: + label = self.cat2label[obj['cat']] + bbox = obj['bbox'] + keypoints = obj['kps'] + ignore = obj['ignore'] + if ignore: + bboxes_ignore.append(bbox) + labels_ignore.append(label) + else: + bboxes.append(bbox) + labels.append(label) + keypointss.append(keypoints) + if not bboxes: + bboxes = np.zeros((0, 4)) + labels = np.zeros((0, )) + keypointss = np.zeros((0, self.NK, 3)) + else: + # bboxes = np.array(bboxes, ndmin=2) - 1 + bboxes = np.array(bboxes, ndmin=2) + labels = np.array(labels) + keypointss = np.array(keypointss, ndmin=3) + if not bboxes_ignore: + bboxes_ignore = np.zeros((0, 4)) + labels_ignore = np.zeros((0, )) + else: + bboxes_ignore = np.array(bboxes_ignore, ndmin=2) + labels_ignore = np.array(labels_ignore) + ann = dict( + bboxes=bboxes.astype(np.float32), + labels=labels.astype(np.int64), + keypointss=keypointss.astype(np.float32), + bboxes_ignore=bboxes_ignore.astype(np.float32), + labels_ignore=labels_ignore.astype(np.int64)) + return ann diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/__init__.py b/modelscope/models/cv/face_detection/mmdet_patch/models/__init__.py new file mode 100755 index 00000000..38c8ff5b --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/models/__init__.py @@ -0,0 +1,2 @@ +from .dense_heads import * # noqa: F401,F403 +from .detectors import * # noqa: F401,F403 diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/backbones/__init__.py b/modelscope/models/cv/face_detection/mmdet_patch/models/backbones/__init__.py new file mode 100755 index 00000000..2d930bf4 --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/models/backbones/__init__.py @@ -0,0 +1,3 @@ +from .resnet import ResNetV1e + +__all__ = ['ResNetV1e'] diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/backbones/resnet.py b/modelscope/models/cv/face_detection/mmdet_patch/models/backbones/resnet.py new file mode 100644 index 00000000..54bcb127 --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/models/backbones/resnet.py @@ -0,0 +1,412 @@ +""" +based on 
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones/resnet.py +""" +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer, + constant_init, kaiming_init) +from mmcv.runner import load_checkpoint +from mmdet.models.backbones.resnet import BasicBlock, Bottleneck +from mmdet.models.builder import BACKBONES +from mmdet.models.utils import ResLayer +from mmdet.utils import get_root_logger +from torch.nn.modules.batchnorm import _BatchNorm + + +class ResNet(nn.Module): + """ResNet backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + stem_channels (int | None): Number of stem channels. If not specified, + it will be the same as `base_channels`. Default: None. + base_channels (int): Number of base channels of res layer. Default: 64. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Resnet stages. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + - position (str, required): Position inside block to insert + plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + + Example: + >>> from mmdet.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... 
print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 0: (BasicBlock, (2, 2, 2, 2)), + 18: (BasicBlock, (2, 2, 2, 2)), + 19: (BasicBlock, (2, 4, 4, 1)), + 20: (BasicBlock, (2, 3, 2, 2)), + 22: (BasicBlock, (2, 4, 3, 1)), + 24: (BasicBlock, (2, 4, 4, 1)), + 26: (BasicBlock, (2, 4, 4, 2)), + 28: (BasicBlock, (2, 5, 4, 2)), + 29: (BasicBlock, (2, 6, 3, 2)), + 30: (BasicBlock, (2, 5, 5, 2)), + 32: (BasicBlock, (2, 6, 5, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 35: (BasicBlock, (3, 6, 4, 3)), + 38: (BasicBlock, (3, 8, 4, 3)), + 40: (BasicBlock, (3, 8, 5, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 56: (Bottleneck, (3, 8, 4, 3)), + 68: (Bottleneck, (3, 10, 6, 3)), + 74: (Bottleneck, (3, 12, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=None, + base_channels=64, + num_stages=4, + block_cfg=None, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + no_pool33=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + with_cp=False, + zero_init_residual=True): + super(ResNet, self).__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + self.depth = depth + if stem_channels is None: + stem_channels = base_channels + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.no_pool33 = no_pool33 + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.zero_init_residual = zero_init_residual + if block_cfg is None: + self.block, stage_blocks = self.arch_settings[depth] + else: + self.block = BasicBlock if block_cfg[ + 'block'] == 'BasicBlock' else Bottleneck + stage_blocks = block_cfg['stage_blocks'] + assert len(stage_blocks) >= num_stages + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + if block_cfg is not None and 'stage_planes' in block_cfg: + stage_planes = block_cfg['stage_planes'] + else: + stage_planes = [base_channels * 2**i for i in range(num_stages)] + + # print('resnet cfg:', stage_blocks, stage_planes) + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + stage_plugins = self.make_stage_plugins(plugins, i) + else: + stage_plugins = None + planes = stage_planes[i] + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + 
plugins=stage_plugins) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2**( + len(self.stage_blocks) - 1) + + def make_stage_plugins(self, plugins, stage_idx): + """Make plugins for ResNet ``stage_idx`` th stage. + + Currently we support to insert ``context_block``, + ``empirical_attention_block``, ``nonlocal_block`` into the backbone + like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of + Bottleneck. + + An example of plugins format could be: + + Examples: + >>> plugins=[ + ... dict(cfg=dict(type='xxx', arg1='xxx'), + ... stages=(False, True, True, True), + ... position='after_conv2'), + ... dict(cfg=dict(type='yyy'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='1'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='2'), + ... stages=(True, True, True, True), + ... position='after_conv3') + ... ] + >>> self = ResNet(depth=18) + >>> stage_plugins = self.make_stage_plugins(plugins, 0) + >>> assert len(stage_plugins) == 3 + + Suppose ``stage_idx=0``, the structure of blocks in the stage would be: + + .. code-block:: none + + conv1-> conv2->conv3->yyy->zzz1->zzz2 + + Suppose 'stage_idx=1', the structure of blocks in the stage would be: + + .. code-block:: none + + conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 + + If stages is missing, the plugin would be applied to all stages. + + Args: + plugins (list[dict]): List of plugins cfg to build. The postfix is + required if multiple same type plugins are inserted. + stage_idx (int): Index of stage to build + + Returns: + list[dict]: Plugins for current stage + """ + stage_plugins = [] + for plugin in plugins: + plugin = plugin.copy() + stages = plugin.pop('stages', None) + assert stages is None or len(stages) == self.num_stages + # whether to insert plugin into current stage + if stages is None or stages[stage_idx]: + stage_plugins.append(plugin) + + return stage_plugins + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels)[1], + nn.ReLU(inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + if self.no_pool33: + assert self.deep_stem + 
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + else: + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + + if self.dcn is not None: + for m in self.modules(): + if isinstance(m, Bottleneck) and hasattr( + m.conv2, 'conv_offset'): + constant_init(m.conv2.conv_offset, 0) + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super(ResNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@BACKBONES.register_module() +class ResNetV1e(ResNet): + r"""ResNetV1d variant described in `Bag of Tricks + `_. + + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. 
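To see the backbone defined here in isolation, a rough sketch of building it directly; SCRFD models normally configure it through mmdet config files, and the block_cfg numbers below are illustrative, not the shipped SCRFD-10G configuration:

import torch

from modelscope.models.cv.face_detection.mmdet_patch.models.backbones import ResNetV1e

backbone = ResNetV1e(
    depth=0,                             # any key in arch_settings; block_cfg overrides it
    base_channels=56,
    block_cfg=dict(block='BasicBlock',
                   stage_blocks=(3, 4, 2, 3),
                   stage_planes=[56, 88, 88, 224]),
    out_indices=(0, 1, 2, 3))
backbone.init_weights()
feats = backbone(torch.rand(1, 3, 640, 640))
print([tuple(f.shape) for f in feats])   # four pyramid levels at strides 4/8/16/32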
+ + Compared with ResNetV1d, ResNetV1e change maxpooling from 3x3 to 2x2 pad=1 + """ + + def __init__(self, **kwargs): + super(ResNetV1e, self).__init__( + deep_stem=True, avg_down=True, no_pool33=True, **kwargs) diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/__init__.py b/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/__init__.py new file mode 100755 index 00000000..e67031bc --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/__init__.py @@ -0,0 +1,3 @@ +from .scrfd_head import SCRFDHead + +__all__ = ['SCRFDHead'] diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py b/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py new file mode 100755 index 00000000..1667f29f --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py @@ -0,0 +1,1068 @@ +""" +based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/dense_heads/scrfd_head.py +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, Scale, + bias_init_with_prob, constant_init, kaiming_init, + normal_init) +from mmcv.runner import force_fp32 +from mmdet.core import (anchor_inside_flags, bbox2distance, bbox_overlaps, + build_assigner, build_sampler, distance2bbox, + images_to_levels, multi_apply, reduce_mean, unmap) +from mmdet.models.builder import HEADS, build_loss +from mmdet.models.dense_heads.anchor_head import AnchorHead + +from ....mmdet_patch.core.bbox import distance2kps, kps2distance +from ....mmdet_patch.core.post_processing import multiclass_nms + + +class Integral(nn.Module): + """A fixed layer for calculating integral result from distribution. + + This layer calculates the target location by :math: `sum{P(y_i) * y_i}`, + P(y_i) denotes the softmax vector that represents the discrete distribution + y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max} + + Args: + reg_max (int): The maximal value of the discrete set. Default: 16. You + may want to reset it according to your new dataset or related + settings. + """ + + def __init__(self, reg_max=16): + super(Integral, self).__init__() + self.reg_max = reg_max + self.register_buffer('project', + torch.linspace(0, self.reg_max, self.reg_max + 1)) + + def forward(self, x): + """Forward feature from the regression head to get integral result of + bounding box location. + + Args: + x (Tensor): Features of the regression head, shape (N, 4*(n+1)), + n is self.reg_max. + + Returns: + x (Tensor): Integral result of box locations, i.e., distance + offsets from the box center in four directions, shape (N, 4). + """ + x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1) + x = F.linear(x, self.project.type_as(x)).reshape(-1, 4) + return x + + +@HEADS.register_module() +class SCRFDHead(AnchorHead): + """Generalized Focal Loss: Learning Qualified and Distributed Bounding + Boxes for Dense Object Detection. + + GFL head structure is similar with ATSS, however GFL uses + 1) joint representation for classification and localization quality, and + 2) flexible General distribution for bounding box locations, + which are supervised by + Quality Focal Loss (QFL) and Distribution Focal Loss (DFL), respectively + + https://arxiv.org/abs/2006.04388 + + Args: + num_classes (int): Number of categories excluding the background + category. 
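A toy example of what the Integral layer above computes: a softmax over the reg_max + 1 bins for each box side, followed by the expectation over the bin indices (reg_max=8 here, matching this head's default):

import torch
import torch.nn.functional as F

reg_max = 8
logits = torch.zeros(1, 4 * (reg_max + 1))   # one location, 4 sides, 9 bins per side
logits[0, 7] = 10.0                          # put almost all mass of the first side on bin 7

probs = F.softmax(logits.reshape(-1, reg_max + 1), dim=1)
bins = torch.linspace(0, reg_max, reg_max + 1)
offsets = (probs * bins).sum(dim=1).reshape(-1, 4)
print(offsets)   # ~[[7.0, 4.0, 4.0, 4.0]]; uniform logits give the mid-bin value 4.0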
+ in_channels (int): Number of channels in the input feature map. + stacked_convs (int): Number of conv layers in cls and reg tower. + Default: 4. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='GN', num_groups=32, requires_grad=True). + loss_qfl (dict): Config of Quality Focal Loss (QFL). + reg_max (int): Max value of integral set :math: `{0, ..., reg_max}` + in QFL setting. Default: 16. + Example: + >>> self = GFLHead(11, 7) + >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]] + >>> cls_quality_score, bbox_pred = self.forward(feats) + >>> assert len(cls_quality_score) == len(self.scales) + """ + + def __init__(self, + num_classes, + in_channels, + stacked_convs=4, + feat_mults=None, + conv_cfg=None, + norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), + loss_dfl=None, + reg_max=8, + cls_reg_share=False, + strides_share=True, + scale_mode=1, + dw_conv=False, + use_kps=False, + loss_kps=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.1), + **kwargs): + self.stacked_convs = stacked_convs + self.feat_mults = feat_mults + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.reg_max = reg_max + self.cls_reg_share = cls_reg_share + self.strides_share = strides_share + self.scale_mode = scale_mode + self.use_dfl = True + self.dw_conv = dw_conv + self.NK = 5 + self.extra_flops = 0.0 + if loss_dfl is None or not loss_dfl: + self.use_dfl = False + self.use_scale = False + self.use_kps = use_kps + if self.scale_mode > 0 and (self.strides_share + or self.scale_mode == 2): + self.use_scale = True + super(SCRFDHead, self).__init__(num_classes, in_channels, **kwargs) + + self.sampling = False + if self.train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + # SSD sampling=False so use PseudoSampler + sampler_cfg = dict(type='PseudoSampler') + self.sampler = build_sampler(sampler_cfg, context=self) + + self.integral = Integral(self.reg_max) + if self.use_dfl: + self.loss_dfl = build_loss(loss_dfl) + self.loss_kps = build_loss(loss_kps) + self.loss_kps_std = 1.0 + self.train_step = 0 + self.pos_count = {} + self.gtgroup_count = {} + for stride in self.anchor_generator.strides: + self.pos_count[stride[0]] = 0 + + def _get_conv_module(self, in_channel, out_channel): + if not self.dw_conv: + conv = ConvModule( + in_channel, + out_channel, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + else: + conv = DepthwiseSeparableConvModule( + in_channel, + out_channel, + 3, + stride=1, + padding=1, + pw_norm_cfg=self.norm_cfg, + dw_norm_cfg=self.norm_cfg) + return conv + + def _init_layers(self): + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + conv_strides = [0] if self.strides_share else \ + self.anchor_generator.strides + self.cls_stride_convs = nn.ModuleDict() + self.reg_stride_convs = nn.ModuleDict() + self.stride_cls = nn.ModuleDict() + self.stride_reg = nn.ModuleDict() + if self.use_kps: + self.stride_kps = nn.ModuleDict() + for stride_idx, conv_stride in enumerate(conv_strides): + key = str(conv_stride) + cls_convs = nn.ModuleList() + reg_convs = nn.ModuleList() + stacked_convs = self.stacked_convs[stride_idx] if \ + isinstance(self.stacked_convs, (list, tuple)) else \ + self.stacked_convs + feat_mult = self.feat_mults[stride_idx] if \ + self.feat_mults is not None else 1 + feat_ch = int(self.feat_channels * feat_mult) + last_feat_ch = 0 + for i in range(stacked_convs): + 
chn = self.in_channels if i == 0 else last_feat_ch + cls_convs.append(self._get_conv_module(chn, feat_ch)) + if not self.cls_reg_share: + reg_convs.append(self._get_conv_module(chn, feat_ch)) + last_feat_ch = feat_ch + self.cls_stride_convs[key] = cls_convs + self.reg_stride_convs[key] = reg_convs + self.stride_cls[key] = nn.Conv2d( + feat_ch, + self.cls_out_channels * self.num_anchors, + 3, + padding=1) + if not self.use_dfl: + self.stride_reg[key] = nn.Conv2d( + feat_ch, 4 * self.num_anchors, 3, padding=1) + else: + self.stride_reg[key] = nn.Conv2d( + feat_ch, + 4 * (self.reg_max + 1) * self.num_anchors, + 3, + padding=1) + if self.use_kps: + self.stride_kps[key] = nn.Conv2d( + feat_ch, self.NK * 2 * self.num_anchors, 3, padding=1) + if self.use_scale: + self.scales = nn.ModuleList( + [Scale(1.0) for _ in self.anchor_generator.strides]) + else: + self.scales = [None for _ in self.anchor_generator.strides] + + def init_weights(self): + """Initialize weights of the head.""" + for stride, cls_convs in self.cls_stride_convs.items(): + for m in cls_convs: + if not self.dw_conv: + try: + normal_init(m.conv, std=0.01) + except Exception: + pass + else: + normal_init(m.depthwise_conv.conv, std=0.01) + normal_init(m.pointwise_conv.conv, std=0.01) + for stride, reg_convs in self.reg_stride_convs.items(): + for m in reg_convs: + if not self.dw_conv: + normal_init(m.conv, std=0.01) + else: + normal_init(m.depthwise_conv.conv, std=0.01) + normal_init(m.pointwise_conv.conv, std=0.01) + bias_cls = -4.595 + for stride, conv in self.stride_cls.items(): + normal_init(conv, std=0.01, bias=bias_cls) + for stride, conv in self.stride_reg.items(): + normal_init(conv, std=0.01) + if self.use_kps: + for stride, conv in self.stride_kps.items(): + normal_init(conv, std=0.01) + + def forward(self, feats): + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + cls_scores (list[Tensor]): Classification and quality (IoU) + joint scores for all scale levels, each is a 4D-tensor, + the channel number is num_classes. + bbox_preds (list[Tensor]): Box distribution logits for all + scale levels, each is a 4D-tensor, the channel number is + 4*(n+1), n is max value of integral set. + """ + return multi_apply(self.forward_single, feats, self.scales, + self.anchor_generator.strides) + + def forward_single(self, x, scale, stride): + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. + + Returns: + tuple: + cls_score (Tensor): Cls and quality joint scores for a single + scale level the channel number is num_classes. + bbox_pred (Tensor): Box distribution logits for a single scale + level, the channel number is 4*(n+1), n is max value of + integral set. 
+ """ + cls_feat = x + reg_feat = x + cls_convs = self.cls_stride_convs[ + '0'] if self.strides_share else self.cls_stride_convs[str(stride)] + for cls_conv in cls_convs: + cls_feat = cls_conv(cls_feat) + if not self.cls_reg_share: + reg_convs = self.reg_stride_convs[ + '0'] if self.strides_share else self.reg_stride_convs[str( + stride)] + for reg_conv in reg_convs: + reg_feat = reg_conv(reg_feat) + else: + reg_feat = cls_feat + cls_pred_module = self.stride_cls[ + '0'] if self.strides_share else self.stride_cls[str(stride)] + cls_score = cls_pred_module(cls_feat) + reg_pred_module = self.stride_reg[ + '0'] if self.strides_share else self.stride_reg[str(stride)] + _bbox_pred = reg_pred_module(reg_feat) + if self.use_scale: + bbox_pred = scale(_bbox_pred) + else: + bbox_pred = _bbox_pred + if self.use_kps: + kps_pred_module = self.stride_kps[ + '0'] if self.strides_share else self.stride_kps[str(stride)] + kps_pred = kps_pred_module(reg_feat) + else: + kps_pred = bbox_pred.new_zeros( + (bbox_pred.shape[0], self.NK * 2, bbox_pred.shape[2], + bbox_pred.shape[3])) + if torch.onnx.is_in_onnx_export(): + assert not self.use_dfl + print('in-onnx-export', cls_score.shape, bbox_pred.shape) + # Add output batch dim, based on pull request #1593 + batch_size = cls_score.shape[0] + cls_score = cls_score.permute(0, 2, 3, 1).reshape( + batch_size, -1, self.cls_out_channels).sigmoid() + bbox_pred = bbox_pred.permute(0, 2, 3, + 1).reshape(batch_size, -1, 4) + kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 10) + + return cls_score, bbox_pred, kps_pred + + def forward_train(self, + x, + img_metas, + gt_bboxes, + gt_labels=None, + gt_keypointss=None, + gt_bboxes_ignore=None, + proposal_cfg=None, + **kwargs): + """ + Args: + x (list[Tensor]): Features from FPN. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes (Tensor): Ground truth bboxes of the image, + shape (num_gts, 4). + gt_labels (Tensor): Ground truth labels of each box, + shape (num_gts,). + gt_bboxes_ignore (Tensor): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + proposal_cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used + + Returns: + tuple: + losses: (dict[str, Tensor]): A dictionary of loss components. + proposal_list (list[Tensor]): Proposals of each image. + """ + outs = self(x) + if gt_labels is None: + loss_inputs = outs + (gt_bboxes, img_metas) + else: + loss_inputs = outs + (gt_bboxes, gt_labels, gt_keypointss, + img_metas) + losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) + if proposal_cfg is None: + return losses + else: + proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg) + return losses, proposal_list + + def get_anchors(self, featmap_sizes, img_metas, device='cuda'): + """Get anchors according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + img_metas (list[dict]): Image meta info. + device (torch.device | str): Device for returned tensors + + Returns: + tuple: + anchor_list (list[Tensor]): Anchors of each image. + valid_flag_list (list[Tensor]): Valid flags of each image. 
+ """ + num_imgs = len(img_metas) + + # since feature map sizes of all images are the same, we only compute + # anchors for one time + multi_level_anchors = self.anchor_generator.grid_anchors( + featmap_sizes, device) + anchor_list = [multi_level_anchors for _ in range(num_imgs)] + + # for each image, we compute valid flags of multi level anchors + valid_flag_list = [] + for img_id, img_meta in enumerate(img_metas): + multi_level_flags = self.anchor_generator.valid_flags( + featmap_sizes, img_meta['pad_shape'], device) + valid_flag_list.append(multi_level_flags) + + return anchor_list, valid_flag_list + + def anchor_center(self, anchors): + """Get anchor centers from anchors. + + Args: + anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format. + + Returns: + Tensor: Anchor centers with shape (N, 2), "xy" format. + """ + anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 + anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + return torch.stack([anchors_cx, anchors_cy], dim=-1) + + def loss_single(self, anchors, cls_score, bbox_pred, kps_pred, labels, + label_weights, bbox_targets, kps_targets, kps_weights, + stride, num_total_samples): + """Compute loss of a single scale level. + + Args: + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + cls_score (Tensor): Cls and quality joint scores for each scale + level has shape (N, num_classes, H, W). + bbox_pred (Tensor): Box distribution logits for each scale + level with shape (N, 4*(n+1), H, W), n is max value of integral + set. + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors) + bbox_targets (Tensor): BBox regression targets of each anchor wight + shape (N, num_total_anchors, 4). + stride (tuple): Stride in this scale level. + num_total_samples (int): Number of positive samples that is + reduced over all GPUs. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert stride[0] == stride[1], 'h stride is not equal to w stride!' 
+ use_qscore = True + anchors = anchors.reshape(-1, 4) + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(-1, self.cls_out_channels) + if not self.use_dfl: + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + else: + bbox_pred = bbox_pred.permute(0, 2, 3, 1) + bbox_pred = bbox_pred.reshape(-1, 4 * (self.reg_max + 1)) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + + if self.use_kps: + kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(-1, self.NK * 2) + kps_targets = kps_targets.reshape((-1, self.NK * 2)) + kps_weights = kps_weights.reshape((-1, self.NK * 2)) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + score = label_weights.new_zeros(labels.shape) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0] + + weight_targets = cls_score.detach().sigmoid() + weight_targets = weight_targets.max(dim=1)[0][pos_inds] + pos_decode_bbox_targets = pos_bbox_targets / stride[0] + + if self.use_dfl: + pos_bbox_pred_corners = self.integral(pos_bbox_pred) + pos_decode_bbox_pred = distance2bbox(pos_anchor_centers, + pos_bbox_pred_corners) + else: + pos_decode_bbox_pred = distance2bbox(pos_anchor_centers, + pos_bbox_pred) + if self.use_kps: + pos_kps_targets = kps_targets[pos_inds] + pos_kps_pred = kps_pred[pos_inds] + pos_kps_weights = kps_weights.max( + dim=1)[0][pos_inds] * weight_targets + pos_kps_weights = pos_kps_weights.reshape((-1, 1)) + pos_decode_kps_targets = kps2distance( + pos_anchor_centers, pos_kps_targets / stride[0]) + pos_decode_kps_pred = pos_kps_pred + if use_qscore: + score[pos_inds] = bbox_overlaps( + pos_decode_bbox_pred.detach(), + pos_decode_bbox_targets, + is_aligned=True) + else: + score[pos_inds] = 1.0 + + # regression loss + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_decode_bbox_targets, + weight=weight_targets, + avg_factor=1.0) + + if self.use_kps: + loss_kps = self.loss_kps( + pos_decode_kps_pred * self.loss_kps_std, + pos_decode_kps_targets * self.loss_kps_std, + weight=pos_kps_weights, + avg_factor=1.0) + else: + loss_kps = kps_pred.sum() * 0 + + # dfl loss + if self.use_dfl: + pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1) + target_corners = bbox2distance(pos_anchor_centers, + pos_decode_bbox_targets, + self.reg_max).reshape(-1) + loss_dfl = self.loss_dfl( + pred_corners, + target_corners, + weight=weight_targets[:, None].expand(-1, 4).reshape(-1), + avg_factor=4.0) + else: + loss_dfl = bbox_pred.sum() * 0 + else: + loss_bbox = bbox_pred.sum() * 0 + loss_dfl = bbox_pred.sum() * 0 + loss_kps = kps_pred.sum() * 0 + weight_targets = torch.tensor(0).cuda() + + loss_cls = self.loss_cls( + cls_score, (labels, score), + weight=label_weights, + avg_factor=num_total_samples) + return loss_cls, loss_bbox, loss_dfl, loss_kps, weight_targets.sum() + + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def loss(self, + cls_scores, + bbox_preds, + kps_preds, + gt_bboxes, + gt_labels, + gt_keypointss, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Cls and quality scores for each scale + level has shape (N, num_classes, H, W). 
+ bbox_preds (list[Tensor]): Box distribution logits for each scale + level with shape (N, 4*(n+1), H, W), n is max value of integral + set. + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (list[Tensor] | None): specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.anchor_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, img_metas, device=device) + label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + gt_bboxes, + gt_keypointss, + img_metas, + gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + label_channels=label_channels) + if cls_reg_targets is None: + return None + + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, keypoints_targets_list, keypoints_weights_list, + num_total_pos, num_total_neg) = cls_reg_targets + + num_total_samples = reduce_mean( + torch.tensor(num_total_pos, dtype=torch.float, + device=device)).item() + num_total_samples = max(num_total_samples, 1.0) + + losses_cls, losses_bbox, losses_dfl, losses_kps,\ + avg_factor = multi_apply( + self.loss_single, + anchor_list, + cls_scores, + bbox_preds, + kps_preds, + labels_list, + label_weights_list, + bbox_targets_list, + keypoints_targets_list, + keypoints_weights_list, + self.anchor_generator.strides, + num_total_samples=num_total_samples) + + avg_factor = sum(avg_factor) + avg_factor = reduce_mean(avg_factor).item() + losses_bbox = list(map(lambda x: x / avg_factor, losses_bbox)) + losses = dict(loss_cls=losses_cls, loss_bbox=losses_bbox) + if self.use_kps: + losses_kps = list(map(lambda x: x / avg_factor, losses_kps)) + losses['loss_kps'] = losses_kps + if self.use_dfl: + losses_dfl = list(map(lambda x: x / avg_factor, losses_dfl)) + losses['loss_dfl'] = losses_dfl + return losses + + @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'kps_preds')) + def get_bboxes(self, + cls_scores, + bbox_preds, + kps_preds, + img_metas, + cfg=None, + rescale=False, + with_nms=True): + """Transform network output for a batch into bbox predictions. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + cfg (mmcv.Config | None): Test / postprocessing configuration, + if None, test_cfg would be used + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is an (n, 5) tensor, where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. 
The second item is a + (n,) tensor where each item is the predicted class labelof the + corresponding box. + + Example: + >>> import mmcv + >>> self = AnchorHead( + >>> num_classes=9, + >>> in_channels=1, + >>> anchor_generator=dict( + >>> type='AnchorGenerator', + >>> scales=[8], + >>> ratios=[0.5, 1.0, 2.0], + >>> strides=[4,])) + >>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}] + >>> cfg = mmcv.Config(dict( + >>> score_thr=0.00, + >>> nms=dict(type='nms', iou_thr=1.0), + >>> max_per_img=10)) + >>> feat = torch.rand(1, 1, 3, 3) + >>> cls_score, bbox_pred = self.forward_single(feat) + >>> # note the input lists are over different levels, not images + >>> cls_scores, bbox_preds = [cls_score], [bbox_pred] + >>> result_list = self.get_bboxes(cls_scores, bbox_preds, + >>> img_metas, cfg) + >>> det_bboxes, det_labels = result_list[0] + >>> assert len(result_list) == 1 + >>> assert det_bboxes.shape[1] == 5 + >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img + """ + assert len(cls_scores) == len(bbox_preds) + num_levels = len(cls_scores) + + device = cls_scores[0].device + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_anchors = self.anchor_generator.grid_anchors( + featmap_sizes, device=device) + + result_list = [] + # bbox_preds and kps_preds are list of 3 tensor, each tensor is NCHW + # corresponding to a stage, C is 8 for bbox and 20 for kps + for img_id in range(len(img_metas)): + cls_score_list = [ + cls_scores[i][img_id].detach() for i in range(num_levels) + ] + bbox_pred_list = [ + bbox_preds[i][img_id].detach() for i in range(num_levels) + ] + if self.use_kps: + kps_pred_list = [ + kps_preds[i][img_id].detach() for i in range(num_levels) + ] + else: + kps_pred_list = [None for i in range(num_levels)] + img_shape = img_metas[img_id]['img_shape'] + scale_factor = img_metas[img_id]['scale_factor'] + if with_nms: + # some heads don't support with_nms argument + proposals = self._get_bboxes_single(cls_score_list, + bbox_pred_list, + kps_pred_list, + mlvl_anchors, img_shape, + scale_factor, cfg, rescale) + else: + proposals = self._get_bboxes_single(cls_score_list, + bbox_pred_list, + kps_pred_list, + mlvl_anchors, img_shape, + scale_factor, cfg, rescale, + with_nms) + result_list.append(proposals) + return result_list + + def _get_bboxes_single(self, + cls_scores, + bbox_preds, + kps_preds, + mlvl_anchors, + img_shape, + scale_factor, + cfg, + rescale=False, + with_nms=True): + """Transform outputs for a single batch item into labeled boxes. + + Args: + cls_scores (list[Tensor]): Box scores for a single scale level + has shape (num_classes, H, W). + bbox_preds (list[Tensor]): Box distribution logits for a single + scale level with shape (4*(n+1), H, W), n is max value of + integral set. + mlvl_anchors (list[Tensor]): Box reference for a single scale level + with shape (num_total_anchors, 4). + img_shape (tuple[int]): Shape of the input image, + (height, width, 3). + scale_factor (ndarray): Scale factor of the image arange as + (w_scale, h_scale, w_scale, h_scale). + cfg (mmcv.Config | None): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + tuple(Tensor): + det_bboxes (Tensor): Bbox predictions in shape (N, 5), where + the first 4 columns are bounding box positions + (tl_x, tl_y, br_x, br_y) and the 5-th column is a score + between 0 and 1. 
+ det_labels (Tensor): A (N,) tensor where each item is the + predicted class label of the corresponding box. + """ + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_kps = [] + for cls_score, bbox_pred, kps_pred, stride, anchors in zip( + cls_scores, bbox_preds, kps_preds, + self.anchor_generator.strides, mlvl_anchors): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + assert stride[0] == stride[1] + + scores = cls_score.permute(1, 2, 0).reshape( + -1, self.cls_out_channels).sigmoid() + bbox_pred = bbox_pred.permute(1, 2, 0) + if self.use_dfl: + bbox_pred = self.integral(bbox_pred) * stride[0] + else: + bbox_pred = bbox_pred.reshape((-1, 4)) * stride[0] + if kps_pred is not None: + kps_pred = kps_pred.permute(1, 2, 0) + if self.use_dfl: + kps_pred = self.integral(kps_pred) * stride[0] + else: + kps_pred = kps_pred.reshape((-1, 10)) * stride[0] + + nms_pre = cfg.get('nms_pre', -1) + if nms_pre > 0 and scores.shape[0] > nms_pre: + max_scores, _ = scores.max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + anchors = anchors[topk_inds, :] + bbox_pred = bbox_pred[topk_inds, :] + scores = scores[topk_inds, :] + if kps_pred is not None: + kps_pred = kps_pred[topk_inds, :] + + bboxes = distance2bbox( + self.anchor_center(anchors), bbox_pred, max_shape=img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + if kps_pred is not None: + kps = distance2kps(self.anchor_center(anchors), kps_pred) + mlvl_kps.append(kps) + + mlvl_bboxes = torch.cat(mlvl_bboxes) + if mlvl_kps is not None: + mlvl_kps = torch.cat(mlvl_kps) + if rescale: + mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) + if mlvl_kps is not None: + scale_factor2 = torch.tensor( + [scale_factor[0], scale_factor[1]] * 5) + mlvl_kps /= scale_factor2.to(mlvl_kps.device) + + mlvl_scores = torch.cat(mlvl_scores) + # Add a dummy background class to the backend when using sigmoid + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class + padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + + if with_nms: + det_bboxes, det_labels, det_kps = multiclass_nms( + mlvl_bboxes, + mlvl_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + multi_kps=mlvl_kps) + if det_kps is not None: + return det_bboxes, det_labels, det_kps + else: + return det_bboxes, det_labels + else: + if mlvl_kps is not None: + return mlvl_bboxes, mlvl_scores, mlvl_kps + else: + return mlvl_bboxes, mlvl_scores + + def get_targets(self, + anchor_list, + valid_flag_list, + gt_bboxes_list, + gt_keypointss_list, + img_metas, + gt_bboxes_ignore_list=None, + gt_labels_list=None, + label_channels=1, + unmap_outputs=True): + """Get targets for GFL head. + + This method is almost the same as `AnchorHead.get_targets()`. Besides + returning the targets as the parent method does, it also returns the + anchors as the first element of the returned tuple. 
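+            Here the full tuple is (anchors_list, labels_list,
+            label_weights_list, bbox_targets_list, bbox_weights_list,
+            keypoints_targets_list, keypoints_weights_list, num_total_pos,
+            num_total_neg); every ``*_list`` is split per feature-map level
+            by ``images_to_levels``.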
+ """ + num_imgs = len(img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + num_level_anchors_list = [num_level_anchors] * num_imgs + + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = torch.cat(anchor_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [None for _ in range(num_imgs)] + if gt_labels_list is None: + gt_labels_list = [None for _ in range(num_imgs)] + if gt_keypointss_list is None: + gt_keypointss_list = [None for _ in range(num_imgs)] + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_bbox_weights, all_keypoints_targets, all_keypoints_weights, + pos_inds_list, neg_inds_list) = multi_apply( + self._get_target_single, + anchor_list, + valid_flag_list, + num_level_anchors_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + gt_keypointss_list, + img_metas, + label_channels=label_channels, + unmap_outputs=unmap_outputs) + # no valid anchors + if any([labels is None for labels in all_labels]): + return None + # sampled anchors of all images + num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) + num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) + # split targets to a list w.r.t. multiple levels + anchors_list = images_to_levels(all_anchors, num_level_anchors) + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors) + bbox_weights_list = images_to_levels(all_bbox_weights, + num_level_anchors) + keypoints_targets_list = images_to_levels(all_keypoints_targets, + num_level_anchors) + keypoints_weights_list = images_to_levels(all_keypoints_weights, + num_level_anchors) + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, bbox_weights_list, keypoints_targets_list, + keypoints_weights_list, num_total_pos, num_total_neg) + + def _get_target_single(self, + flat_anchors, + valid_flags, + num_level_anchors, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + gt_keypointss, + img_meta, + label_channels=1, + unmap_outputs=True): + """Compute regression, classification targets for anchors in a single + image. + + Args: + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors, 4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors,). + num_level_anchors Tensor): Number of anchors of each scale level. + gt_bboxes (Tensor): Ground truth bboxes of the image, + shape (num_gts, 4). + gt_bboxes_ignore (Tensor): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + gt_labels (Tensor): Ground truth labels of each box, + shape (num_gts,). + img_meta (dict): Meta info of the image. + label_channels (int): Channel of label. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: N is the number of total anchors in the image. + anchors (Tensor): All anchors in the image with shape (N, 4). + labels (Tensor): Labels of all anchors in the image with shape + (N,). 
+ label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). + bbox_weights (Tensor): BBox weights of all anchors in the + image with shape (N, 4). + pos_inds (Tensor): Indices of postive anchor with shape + (num_pos,). + neg_inds (Tensor): Indices of negative anchor with shape + (num_neg,). + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg.allowed_border) + if not inside_flags.any(): + return (None, ) * 7 + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + + num_level_anchors_inside = self.get_num_level_anchors_inside( + num_level_anchors, inside_flags) + if self.assigner.__class__.__name__ == 'ATSSAssigner': + assign_result = self.assigner.assign(anchors, + num_level_anchors_inside, + gt_bboxes, gt_bboxes_ignore, + gt_labels) + else: + assign_result = self.assigner.assign(anchors, gt_bboxes, + gt_bboxes_ignore, gt_labels) + + sampling_result = self.sampler.sample(assign_result, anchors, + gt_bboxes) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + kps_targets = anchors.new_zeros(size=(anchors.shape[0], self.NK * 2)) + kps_weights = anchors.new_zeros(size=(anchors.shape[0], self.NK * 2)) + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + pos_bbox_targets = sampling_result.pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + if self.use_kps: + pos_assigned_gt_inds = sampling_result.pos_assigned_gt_inds + kps_targets[pos_inds, :] = gt_keypointss[ + pos_assigned_gt_inds, :, :2].reshape((-1, self.NK * 2)) + kps_weights[pos_inds, :] = torch.mean( + gt_keypointss[pos_assigned_gt_inds, :, 2], + dim=1, + keepdims=True) + if gt_labels is None: + # Only rpn gives gt_labels as None + # Foreground is the first class + labels[pos_inds] = 0 + else: + labels[pos_inds] = gt_labels[ + sampling_result.pos_assigned_gt_inds] + if self.train_cfg.pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg.pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + anchors = unmap(anchors, num_total_anchors, inside_flags) + labels = unmap( + labels, num_total_anchors, inside_flags, fill=self.num_classes) + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + if self.use_kps: + kps_targets = unmap(kps_targets, num_total_anchors, + inside_flags) + kps_weights = unmap(kps_weights, num_total_anchors, + inside_flags) + + return (anchors, labels, label_weights, bbox_targets, bbox_weights, + kps_targets, kps_weights, pos_inds, neg_inds) + + def get_num_level_anchors_inside(self, num_level_anchors, inside_flags): + split_inside_flags = torch.split(inside_flags, num_level_anchors) + num_level_anchors_inside = [ + int(flags.sum()) for flags in split_inside_flags + ] + return num_level_anchors_inside + + def aug_test(self, feats, img_metas, rescale=False): + """Test function with 
test time augmentation. + + Args: + feats (list[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains features for all images in the batch. + img_metas (list[list[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. each dict has image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[ndarray]: bbox results of each class + """ + return self.aug_test_bboxes(feats, img_metas, rescale=rescale) diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/__init__.py b/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/__init__.py new file mode 100755 index 00000000..1c16028f --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/__init__.py @@ -0,0 +1,3 @@ +from .scrfd import SCRFD + +__all__ = ['SCRFD'] diff --git a/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py b/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py new file mode 100755 index 00000000..98b6702c --- /dev/null +++ b/modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py @@ -0,0 +1,109 @@ +""" +based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors/scrfd.py +""" +import torch +from mmdet.models.builder import DETECTORS +from mmdet.models.detectors.single_stage import SingleStageDetector + +from ....mmdet_patch.core.bbox import bbox2result + + +@DETECTORS.register_module() +class SCRFD(SingleStageDetector): + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(SCRFD, self).__init__(backbone, neck, bbox_head, train_cfg, + test_cfg, pretrained) + + def forward_train(self, + img, + img_metas, + gt_bboxes, + gt_labels, + gt_keypointss=None, + gt_bboxes_ignore=None): + """ + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + :class:`mmdet.datasets.pipelines.Collect`. + gt_bboxes (list[Tensor]): Each item are the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): Class indices corresponding to each box + gt_bboxes_ignore (None | list[Tensor]): Specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + super(SingleStageDetector, self).forward_train(img, img_metas) + x = self.extract_feat(img) + losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, + gt_labels, gt_keypointss, + gt_bboxes_ignore) + return losses + + def simple_test(self, img, img_metas, rescale=False): + """Test function without test time augmentation. + + Args: + imgs (list[torch.Tensor]): List of multiple images + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[list[np.ndarray]]: BBox results of each image and classes. + The outer list corresponds to each image. The inner list + corresponds to each class. 
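+                When the bbox head is built with ``use_kps=True``,
+                ``get_bboxes`` additionally returns the decoded five-point
+                landmarks, which are appended to each per-class array via
+                ``bbox2result(..., kps=det_kps)``.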
+ """ + x = self.extract_feat(img) + outs = self.bbox_head(x) + if torch.onnx.is_in_onnx_export(): + print('single_stage.py in-onnx-export') + print(outs.__class__) + cls_score, bbox_pred, kps_pred = outs + for c in cls_score: + print(c.shape) + for c in bbox_pred: + print(c.shape) + if self.bbox_head.use_kps: + for c in kps_pred: + print(c.shape) + return (cls_score, bbox_pred, kps_pred) + else: + return (cls_score, bbox_pred) + bbox_list = self.bbox_head.get_bboxes( + *outs, img_metas, rescale=rescale) + + # return kps if use_kps + if len(bbox_list[0]) == 2: + bbox_results = [ + bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) + for det_bboxes, det_labels in bbox_list + ] + elif len(bbox_list[0]) == 3: + bbox_results = [ + bbox2result( + det_bboxes, + det_labels, + self.bbox_head.num_classes, + kps=det_kps) + for det_bboxes, det_labels, det_kps in bbox_list + ] + return bbox_results + + def feature_test(self, img): + x = self.extract_feat(img) + outs = self.bbox_head(x) + return outs diff --git a/modelscope/models/cv/face_recognition/__init__.py b/modelscope/models/cv/face_recognition/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_recognition/align_face.py b/modelscope/models/cv/face_recognition/align_face.py new file mode 100644 index 00000000..a6469a10 --- /dev/null +++ b/modelscope/models/cv/face_recognition/align_face.py @@ -0,0 +1,50 @@ +import cv2 +import numpy as np +from skimage import transform as trans + + +def align_face(image, size, lmks): + dst_w = size[1] + dst_h = size[0] + # landmark calculation of dst images + base_w = 96 + base_h = 112 + assert (dst_w >= base_w) + assert (dst_h >= base_h) + base_lmk = [ + 30.2946, 51.6963, 65.5318, 51.5014, 48.0252, 71.7366, 33.5493, 92.3655, + 62.7299, 92.2041 + ] + + dst_lmk = np.array(base_lmk).reshape((5, 2)).astype(np.float32) + if dst_w != base_w: + slide = (dst_w - base_w) / 2 + dst_lmk[:, 0] += slide + + if dst_h != base_h: + slide = (dst_h - base_h) / 2 + dst_lmk[:, 1] += slide + + src_lmk = lmks + # using skimage method + tform = trans.SimilarityTransform() + tform.estimate(src_lmk, dst_lmk) + t = tform.params[0:2, :] + + assert (image.shape[2] == 3) + + dst_image = cv2.warpAffine(image.copy(), t, (dst_w, dst_h)) + dst_pts = GetAffinePoints(src_lmk, t) + return dst_image, dst_pts + + +def GetAffinePoints(pts_in, trans): + pts_out = pts_in.copy() + assert (pts_in.shape[1] == 2) + + for k in range(pts_in.shape[0]): + pts_out[k, 0] = pts_in[k, 0] * trans[0, 0] + pts_in[k, 1] * trans[ + 0, 1] + trans[0, 2] + pts_out[k, 1] = pts_in[k, 0] * trans[1, 0] + pts_in[k, 1] * trans[ + 1, 1] + trans[1, 2] + return pts_out diff --git a/modelscope/models/cv/face_recognition/torchkit/__init__.py b/modelscope/models/cv/face_recognition/torchkit/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_recognition/torchkit/backbone/__init__.py b/modelscope/models/cv/face_recognition/torchkit/backbone/__init__.py new file mode 100755 index 00000000..a58d8e17 --- /dev/null +++ b/modelscope/models/cv/face_recognition/torchkit/backbone/__init__.py @@ -0,0 +1,31 @@ +from .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, + IR_SE_101, IR_SE_152, IR_SE_200) +from .model_resnet import ResNet_50, ResNet_101, ResNet_152 + +_model_dict = { + 'ResNet_50': ResNet_50, + 'ResNet_101': ResNet_101, + 'ResNet_152': ResNet_152, + 'IR_18': IR_18, + 'IR_34': IR_34, + 'IR_50': IR_50, + 'IR_101': IR_101, + 'IR_152': IR_152, + 'IR_200': 
IR_200, + 'IR_SE_50': IR_SE_50, + 'IR_SE_101': IR_SE_101, + 'IR_SE_152': IR_SE_152, + 'IR_SE_200': IR_SE_200 +} + + +def get_model(key): + """ Get different backbone network by key, + support ResNet50, ResNet_101, ResNet_152 + IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, + IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200. + """ + if key in _model_dict.keys(): + return _model_dict[key] + else: + raise KeyError('not support model {}'.format(key)) diff --git a/modelscope/models/cv/face_recognition/torchkit/backbone/common.py b/modelscope/models/cv/face_recognition/torchkit/backbone/common.py new file mode 100755 index 00000000..426d2591 --- /dev/null +++ b/modelscope/models/cv/face_recognition/torchkit/backbone/common.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Linear, Module, ReLU, + Sigmoid) + + +def initialize_weights(modules): + """ Weight initilize, conv2d and linear is initialized with kaiming_normal + """ + for m in modules: + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + + +class Flatten(Module): + """ Flat tensor + """ + + def forward(self, input): + return input.view(input.size(0), -1) + + +class SEModule(Module): + """ SE block + """ + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = Conv2d( + channels, + channels // reduction, + kernel_size=1, + padding=0, + bias=False) + + nn.init.xavier_uniform_(self.fc1.weight.data) + + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d( + channels // reduction, + channels, + kernel_size=1, + padding=0, + bias=False) + + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + + return module_input * x diff --git a/modelscope/models/cv/face_recognition/torchkit/backbone/model_irse.py b/modelscope/models/cv/face_recognition/torchkit/backbone/model_irse.py new file mode 100755 index 00000000..4fb7ee9c --- /dev/null +++ b/modelscope/models/cv/face_recognition/torchkit/backbone/model_irse.py @@ -0,0 +1,279 @@ +# based on: +# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_irse.py +from collections import namedtuple + +from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, + MaxPool2d, Module, PReLU, Sequential) + +from .common import Flatten, SEModule, initialize_weights + + +class BasicBlockIR(Module): + """ BasicBlock for IRNet + """ + + def __init__(self, in_channel, depth, stride): + super(BasicBlockIR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + BatchNorm2d(depth), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + + return res + shortcut + + +class BottleneckIR(Module): + """ BasicBlock 
with bottleneck for IRNet + """ + + def __init__(self, in_channel, depth, stride): + super(BottleneckIR, self).__init__() + reduction_channel = depth // 4 + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d( + in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False), + BatchNorm2d(reduction_channel), PReLU(reduction_channel), + Conv2d( + reduction_channel, + reduction_channel, (3, 3), (1, 1), + 1, + bias=False), BatchNorm2d(reduction_channel), + PReLU(reduction_channel), + Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + + return res + shortcut + + +class BasicBlockIRSE(BasicBlockIR): + + def __init__(self, in_channel, depth, stride): + super(BasicBlockIRSE, self).__init__(in_channel, depth, stride) + self.res_layer.add_module('se_block', SEModule(depth, 16)) + + +class BottleneckIRSE(BottleneckIR): + + def __init__(self, in_channel, depth, stride): + super(BottleneckIRSE, self).__init__(in_channel, depth, stride) + self.res_layer.add_module('se_block', SEModule(depth, 16)) + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + '''A named tuple describing a ResNet block.''' + + +def get_block(in_channel, depth, num_units, stride=2): + + return [Bottleneck(in_channel, depth, stride)] +\ + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 18: + blocks = [ + get_block(in_channel=64, depth=64, num_units=2), + get_block(in_channel=64, depth=128, num_units=2), + get_block(in_channel=128, depth=256, num_units=2), + get_block(in_channel=256, depth=512, num_units=2) + ] + elif num_layers == 34: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=6), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=256, num_units=3), + get_block(in_channel=256, depth=512, num_units=8), + get_block(in_channel=512, depth=1024, num_units=36), + get_block(in_channel=1024, depth=2048, num_units=3) + ] + elif num_layers == 200: + blocks = [ + get_block(in_channel=64, depth=256, num_units=3), + get_block(in_channel=256, depth=512, num_units=24), + get_block(in_channel=512, depth=1024, num_units=36), + get_block(in_channel=1024, depth=2048, num_units=3) + ] + + return blocks + + +class Backbone(Module): + + def __init__(self, input_size, num_layers, mode='ir'): + """ Args: + input_size: input_size of backbone + num_layers: num_layers of backbone + mode: support ir or irse + """ + super(Backbone, self).__init__() + assert input_size[0] in [112, 224], \ + 'input_size should be [112, 112] or [224, 224]' + assert num_layers 
in [18, 34, 50, 100, 152, 200], \ + 'num_layers should be 18, 34, 50, 100 or 152' + assert mode in ['ir', 'ir_se'], \ + 'mode should be ir or ir_se' + self.input_layer = Sequential( + Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), + PReLU(64)) + blocks = get_blocks(num_layers) + if num_layers <= 100: + if mode == 'ir': + unit_module = BasicBlockIR + elif mode == 'ir_se': + unit_module = BasicBlockIRSE + output_channel = 512 + else: + if mode == 'ir': + unit_module = BottleneckIR + elif mode == 'ir_se': + unit_module = BottleneckIRSE + output_channel = 2048 + + if input_size[0] == 112: + self.output_layer = Sequential( + BatchNorm2d(output_channel), Dropout(0.4), Flatten(), + Linear(output_channel * 7 * 7, 512), + BatchNorm1d(512, affine=False)) + else: + self.output_layer = Sequential( + BatchNorm2d(output_channel), Dropout(0.4), Flatten(), + Linear(output_channel * 14 * 14, 512), + BatchNorm1d(512, affine=False)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append( + unit_module(bottleneck.in_channel, bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + initialize_weights(self.modules()) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return x + + +def IR_18(input_size): + """ Constructs a ir-18 model. + """ + model = Backbone(input_size, 18, 'ir') + + return model + + +def IR_34(input_size): + """ Constructs a ir-34 model. + """ + model = Backbone(input_size, 34, 'ir') + + return model + + +def IR_50(input_size): + """ Constructs a ir-50 model. + """ + model = Backbone(input_size, 50, 'ir') + + return model + + +def IR_101(input_size): + """ Constructs a ir-101 model. + """ + model = Backbone(input_size, 100, 'ir') + + return model + + +def IR_152(input_size): + """ Constructs a ir-152 model. + """ + model = Backbone(input_size, 152, 'ir') + + return model + + +def IR_200(input_size): + """ Constructs a ir-200 model. + """ + model = Backbone(input_size, 200, 'ir') + + return model + + +def IR_SE_50(input_size): + """ Constructs a ir_se-50 model. + """ + model = Backbone(input_size, 50, 'ir_se') + + return model + + +def IR_SE_101(input_size): + """ Constructs a ir_se-101 model. + """ + model = Backbone(input_size, 100, 'ir_se') + + return model + + +def IR_SE_152(input_size): + """ Constructs a ir_se-152 model. + """ + model = Backbone(input_size, 152, 'ir_se') + + return model + + +def IR_SE_200(input_size): + """ Constructs a ir_se-200 model. 
+ """ + model = Backbone(input_size, 200, 'ir_se') + + return model diff --git a/modelscope/models/cv/face_recognition/torchkit/backbone/model_resnet.py b/modelscope/models/cv/face_recognition/torchkit/backbone/model_resnet.py new file mode 100755 index 00000000..7072f384 --- /dev/null +++ b/modelscope/models/cv/face_recognition/torchkit/backbone/model_resnet.py @@ -0,0 +1,162 @@ +# based on: +# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_resnet.py +import torch.nn as nn +from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, + MaxPool2d, Module, ReLU, Sequential) + +from .common import initialize_weights + + +def conv3x3(in_planes, out_planes, stride=1): + """ 3x3 convolution with padding + """ + return Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """ 1x1 convolution + """ + return Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class Bottleneck(Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = conv1x1(inplanes, planes) + self.bn1 = BatchNorm2d(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn2 = BatchNorm2d(planes) + self.conv3 = conv1x1(planes, planes * self.expansion) + self.bn3 = BatchNorm2d(planes * self.expansion) + self.relu = ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(Module): + """ ResNet backbone + """ + + def __init__(self, input_size, block, layers, zero_init_residual=True): + """ Args: + input_size: input_size of backbone + block: block function + layers: layers in each block + """ + super(ResNet, self).__init__() + assert input_size[0] in [112, 224],\ + 'input_size should be [112, 112] or [224, 224]' + self.inplanes = 64 + self.conv1 = Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = BatchNorm2d(64) + self.relu = ReLU(inplace=True) + self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + self.bn_o1 = BatchNorm2d(2048) + self.dropout = Dropout() + if input_size[0] == 112: + self.fc = Linear(2048 * 4 * 4, 512) + else: + self.fc = Linear(2048 * 7 * 7, 512) + self.bn_o2 = BatchNorm1d(512) + + initialize_weights(self.modules) + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, 
planes)) + + return Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.bn_o1(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + x = self.bn_o2(x) + + return x + + +def ResNet_50(input_size, **kwargs): + """ Constructs a ResNet-50 model. + """ + model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs) + + return model + + +def ResNet_101(input_size, **kwargs): + """ Constructs a ResNet-101 model. + """ + model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs) + + return model + + +def ResNet_152(input_size, **kwargs): + """ Constructs a ResNet-152 model. + """ + model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs) + + return model diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 783142c6..cffbc05f 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -13,6 +13,7 @@ class OutputKeys(object): POSES = 'poses' CAPTION = 'caption' BOXES = 'boxes' + KEYPOINTS = 'keypoints' MASKS = 'masks' TEXT = 'text' POLYGONS = 'polygons' @@ -55,6 +56,31 @@ TASK_OUTPUTS = { Tasks.object_detection: [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES], + # face detection result for single sample + # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # "keypoints": [ + # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], + # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], + # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], + # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], + # ], + # } + Tasks.face_detection: + [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], + + # face recognition result for single sample + # { + # "img_embedding": np.array with shape [1, D], + # } + Tasks.face_recognition: [OutputKeys.IMG_EMBEDDING], + # instance segmentation result for single sample # { # "scores": [0.9, 0.1, 0.05, 0.05], diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index ca8f5a85..6e2d6bc7 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -255,7 +255,11 @@ class Pipeline(ABC): elif isinstance(data, InputFeatures): return data else: - raise ValueError(f'Unsupported data type {type(data)}') + import mmcv + if isinstance(data, mmcv.parallel.data_container.DataContainer): + return data + else: + raise ValueError(f'Unsupported data type {type(data)}') def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: preprocess_params = kwargs.get('preprocess_params') diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 06f435ff..cf8b1147 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -80,6 +80,10 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.text_to_image_synthesis: (Pipelines.text_to_image_synthesis, 'damo/cv_imagen_text-to-image-synthesis_tiny'), + Tasks.face_detection: (Pipelines.face_detection, + 'damo/cv_resnet_facedetection_scrfd10gkps'), + Tasks.face_recognition: (Pipelines.face_recognition, + 'damo/cv_ir101_facerecognition_cfglint'), Tasks.video_multi_modal_embedding: (Pipelines.video_multi_modal_embedding, 'damo/multi_modal_clip_vtretrival_msrvtt_53'), diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index d183c889..abfefcca 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -5,44 
+5,50 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .action_recognition_pipeline import ActionRecognitionPipeline - from .animal_recog_pipeline import AnimalRecogPipeline - from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline - from .live_category_pipeline import LiveCategoryPipeline - from .image_classification_pipeline import GeneralImageClassificationPipeline + from .animal_recognition_pipeline import AnimalRecognitionPipeline + from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline + from .face_detection_pipeline import FaceDetectionPipeline + from .face_recognition_pipeline import FaceRecognitionPipeline from .face_image_generation_pipeline import FaceImageGenerationPipeline from .image_cartoon_pipeline import ImageCartoonPipeline + from .image_classification_pipeline import GeneralImageClassificationPipeline from .image_denoise_pipeline import ImageDenoisePipeline from .image_color_enhance_pipeline import ImageColorEnhancePipeline from .image_colorization_pipeline import ImageColorizationPipeline from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline - from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline - from .video_category_pipeline import VideoCategoryPipeline from .image_matting_pipeline import ImageMattingPipeline from .image_super_resolution_pipeline import ImageSuperResolutionPipeline + from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline from .style_transfer_pipeline import StyleTransferPipeline + from .live_category_pipeline import LiveCategoryPipeline from .ocr_detection_pipeline import OCRDetectionPipeline + from .video_category_pipeline import VideoCategoryPipeline from .virtual_tryon_pipeline import VirtualTryonPipeline else: _import_structure = { 'action_recognition_pipeline': ['ActionRecognitionPipeline'], - 'animal_recog_pipeline': ['AnimalRecogPipeline'], - 'cmdssl_video_embedding_pipleline': ['CMDSSLVideoEmbeddingPipeline'], + 'animal_recognition_pipeline': ['AnimalRecognitionPipeline'], + 'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'], + 'face_detection_pipeline': ['FaceDetectionPipeline'], + 'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], + 'face_recognition_pipeline': ['FaceRecognitionPipeline'], 'image_classification_pipeline': ['GeneralImageClassificationPipeline'], + 'image_cartoon_pipeline': ['ImageCartoonPipeline'], + 'image_denoise_pipeline': ['ImageDenoisePipeline'], 'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'], - 'virtual_tryon_pipeline': ['VirtualTryonPipeline'], 'image_colorization_pipeline': ['ImageColorizationPipeline'], - 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], - 'image_denoise_pipeline': ['ImageDenoisePipeline'], - 'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], - 'image_cartoon_pipeline': ['ImageCartoonPipeline'], - 'image_matting_pipeline': ['ImageMattingPipeline'], - 'style_transfer_pipeline': ['StyleTransferPipeline'], - 'ocr_detection_pipeline': ['OCRDetectionPipeline'], 'image_instance_segmentation_pipeline': ['ImageInstanceSegmentationPipeline'], - 'video_category_pipeline': ['VideoCategoryPipeline'], + 'image_matting_pipeline': ['ImageMattingPipeline'], + 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], + 'image_to_image_translation_pipeline': + ['Image2ImageTranslationPipeline'], 'live_category_pipeline': ['LiveCategoryPipeline'], + 
'ocr_detection_pipeline': ['OCRDetectionPipeline'], + 'style_transfer_pipeline': ['StyleTransferPipeline'], + 'video_category_pipeline': ['VideoCategoryPipeline'], + 'virtual_tryon_pipeline': ['VirtualTryonPipeline'], } import sys diff --git a/modelscope/pipelines/cv/action_recognition_pipeline.py b/modelscope/pipelines/cv/action_recognition_pipeline.py index a2e7c0a8..087548f0 100644 --- a/modelscope/pipelines/cv/action_recognition_pipeline.py +++ b/modelscope/pipelines/cv/action_recognition_pipeline.py @@ -23,7 +23,7 @@ class ActionRecognitionPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` and `preprocessor` to create a kws pipeline for prediction + use `model` to create a action recognition pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/pipelines/cv/animal_recog_pipeline.py b/modelscope/pipelines/cv/animal_recognition_pipeline.py similarity index 97% rename from modelscope/pipelines/cv/animal_recog_pipeline.py rename to modelscope/pipelines/cv/animal_recognition_pipeline.py index fd3903ec..ab0232bd 100644 --- a/modelscope/pipelines/cv/animal_recog_pipeline.py +++ b/modelscope/pipelines/cv/animal_recognition_pipeline.py @@ -22,11 +22,11 @@ logger = get_logger() @PIPELINES.register_module( Tasks.image_classification, module_name=Pipelines.animal_recognation) -class AnimalRecogPipeline(Pipeline): +class AnimalRecognitionPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` and `preprocessor` to create a kws pipeline for prediction + use `model` to create a animal recognition pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py b/modelscope/pipelines/cv/cmdssl_video_embedding_pipeline.py similarity index 98% rename from modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py rename to modelscope/pipelines/cv/cmdssl_video_embedding_pipeline.py index 1a80fbb8..f29d766c 100644 --- a/modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py +++ b/modelscope/pipelines/cv/cmdssl_video_embedding_pipeline.py @@ -24,7 +24,7 @@ class CMDSSLVideoEmbeddingPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` and `preprocessor` to create a kws pipeline for prediction + use `model` to create a CMDSSL Video Embedding pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/pipelines/cv/face_detection_pipeline.py b/modelscope/pipelines/cv/face_detection_pipeline.py new file mode 100644 index 00000000..8fda5b46 --- /dev/null +++ b/modelscope/pipelines/cv/face_detection_pipeline.py @@ -0,0 +1,105 @@ +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_detection, module_name=Pipelines.face_detection) +class FaceDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a face detection pipeline for prediction + Args: + model: model id on modelscope hub. 
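+
+        Example (usage as in tests/pipelines/test_face_detection.py):
+            >>> from modelscope.outputs import OutputKeys
+            >>> from modelscope.pipelines import pipeline
+            >>> from modelscope.utils.constant import Tasks
+            >>> face_detection = pipeline(
+            >>>     Tasks.face_detection,
+            >>>     model='damo/cv_resnet_facedetection_scrfd10gkps')
+            >>> result = face_detection('data/test/images/face_detection.png')
+            >>> scores = result[OutputKeys.SCORES]
+            >>> boxes = result[OutputKeys.BOXES]
+            >>> keypoints = result[OutputKeys.KEYPOINTS]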
+ """ + super().__init__(model=model, **kwargs) + from mmcv import Config + from mmcv.parallel import MMDataParallel + from mmcv.runner import load_checkpoint + from mmdet.models import build_detector + from modelscope.models.cv.face_detection.mmdet_patch.datasets import RetinaFaceDataset + from modelscope.models.cv.face_detection.mmdet_patch.datasets.pipelines import RandomSquareCrop + from modelscope.models.cv.face_detection.mmdet_patch.models.backbones import ResNetV1e + from modelscope.models.cv.face_detection.mmdet_patch.models.dense_heads import SCRFDHead + from modelscope.models.cv.face_detection.mmdet_patch.models.detectors import SCRFD + cfg = Config.fromfile(osp.join(model, 'mmcv_scrfd_10g_bnkps.py')) + detector = build_detector( + cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) + ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE) + logger.info(f'loading model from {ckpt_path}') + device = torch.device( + f'cuda:{0}' if torch.cuda.is_available() else 'cpu') + load_checkpoint(detector, ckpt_path, map_location=device) + detector = MMDataParallel(detector, device_ids=[0]) + detector.eval() + self.detector = detector + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img.astype(np.float32) + pre_pipeline = [ + dict( + type='MultiScaleFlipAug', + img_scale=(640, 640), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.0), + dict( + type='Normalize', + mean=[127.5, 127.5, 127.5], + std=[128.0, 128.0, 128.0], + to_rgb=False), + dict(type='Pad', size=(640, 640), pad_val=0), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ] + from mmdet.datasets.pipelines import Compose + pipeline = Compose(pre_pipeline) + result = {} + result['filename'] = '' + result['ori_filename'] = '' + result['img'] = img + result['img_shape'] = img.shape + result['ori_shape'] = img.shape + result['img_fields'] = ['img'] + result = pipeline(result) + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + result = self.detector( + return_loss=False, + rescale=True, + img=[input['img'][0].unsqueeze(0)], + img_metas=[[dict(input['img_metas'][0].data)]]) + assert result is not None + result = result[0][0] + bboxes = result[:, :4].tolist() + kpss = result[:, 5:].tolist() + scores = result[:, 4].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: kpss + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/face_image_generation_pipeline.py b/modelscope/pipelines/cv/face_image_generation_pipeline.py index e3aa0777..31c97b30 100644 --- a/modelscope/pipelines/cv/face_image_generation_pipeline.py +++ b/modelscope/pipelines/cv/face_image_generation_pipeline.py @@ -24,7 +24,7 @@ class FaceImageGenerationPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` to create a kws pipeline for prediction + use `model` to create a face image generation pipeline for prediction Args: model: model id on modelscope hub. 
""" diff --git a/modelscope/pipelines/cv/face_recognition_pipeline.py b/modelscope/pipelines/cv/face_recognition_pipeline.py new file mode 100644 index 00000000..3779b055 --- /dev/null +++ b/modelscope/pipelines/cv/face_recognition_pipeline.py @@ -0,0 +1,130 @@ +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_recognition.align_face import align_face +from modelscope.models.cv.face_recognition.torchkit.backbone import get_model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_recognition, module_name=Pipelines.face_recognition) +class FaceRecognitionPipeline(Pipeline): + + def __init__(self, model: str, face_detection: Pipeline, **kwargs): + """ + use `model` to create a face recognition pipeline for prediction + Args: + model: model id on modelscope hub. + face_detecion: pipeline for face detection and face alignment before recognition + """ + + # face recong model + super().__init__(model=model, **kwargs) + device = torch.device( + f'cuda:{0}' if torch.cuda.is_available() else 'cpu') + self.device = device + face_model = get_model('IR_101')([112, 112]) + face_model.load_state_dict( + torch.load( + osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=device)) + face_model = face_model.to(device) + face_model.eval() + self.face_model = face_model + logger.info('face recognition model loaded!') + # face detect pipeline + self.face_detection = face_detection + + def _choose_face(self, + det_result, + min_face=10, + top_face=1, + center_face=False): + ''' + choose face with maximum area + Args: + det_result: output of face detection pipeline + min_face: minimum size of valid face w/h + top_face: take faces with top max areas + center_face: choose the most centerd face from multi faces, only valid if top_face > 1 + ''' + bboxes = np.array(det_result[OutputKeys.BOXES]) + landmarks = np.array(det_result[OutputKeys.KEYPOINTS]) + # scores = np.array(det_result[OutputKeys.SCORES]) + if bboxes.shape[0] == 0: + logger.info('No face detected!') + return None + # face idx with enough size + face_idx = [] + for i in range(bboxes.shape[0]): + box = bboxes[i] + if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face: + face_idx += [i] + if len(face_idx) == 0: + logger.info( + f'Face size not enough, less than {min_face}x{min_face}!') + return None + bboxes = bboxes[face_idx] + landmarks = landmarks[face_idx] + # find max faces + boxes = np.array(bboxes) + area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + sort_idx = np.argsort(area)[-top_face:] + # find center face + if top_face > 1 and center_face and bboxes.shape[0] > 1: + img_center = [img.shape[1] // 2, img.shape[0] // 2] + min_dist = float('inf') + sel_idx = -1 + for _idx in sort_idx: + box = boxes[_idx] + dist = np.square( + np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square( + np.abs((box[1] + box[3]) / 2 - img_center[1])) + if dist < min_dist: + min_dist = dist + sel_idx = _idx + sort_idx = [sel_idx] + main_idx = sort_idx[-1] + return bboxes[main_idx], landmarks[main_idx] + + def preprocess(self, input: 
Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img[:, :, ::-1] + det_result = self.face_detection(img.copy()) + rtn = self._choose_face(det_result) + face_img = None + if rtn is not None: + _, face_lmks = rtn + face_lmks = face_lmks.reshape(5, 2) + align_img, _ = align_face(img, (112, 112), face_lmks) + face_img = align_img[:, :, ::-1] # to rgb + face_img = np.transpose(face_img, axes=(2, 0, 1)) + face_img = (face_img / 255. - 0.5) / 0.5 + face_img = face_img.astype(np.float32) + result = {} + result['img'] = face_img + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + assert input['img'] is not None + img = input['img'].unsqueeze(0) + emb = self.face_model(img).detach().cpu().numpy() + emb /= np.sqrt(np.sum(emb**2, -1, keepdims=True)) # l2 norm + return {OutputKeys.IMG_EMBEDDING: emb} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py index 5ea76f6a..46a30ad0 100644 --- a/modelscope/pipelines/cv/image_cartoon_pipeline.py +++ b/modelscope/pipelines/cv/image_cartoon_pipeline.py @@ -30,7 +30,7 @@ class ImageCartoonPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` and `preprocessor` to create a kws pipeline for prediction + use `model` to create a image cartoon pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/pipelines/cv/image_color_enhance_pipeline.py b/modelscope/pipelines/cv/image_color_enhance_pipeline.py index 6e3ece68..b9007f77 100644 --- a/modelscope/pipelines/cv/image_color_enhance_pipeline.py +++ b/modelscope/pipelines/cv/image_color_enhance_pipeline.py @@ -27,7 +27,7 @@ class ImageColorEnhancePipeline(Pipeline): ImageColorEnhanceFinetunePreprocessor] = None, **kwargs): """ - use `model` and `preprocessor` to create a kws pipeline for prediction + use `model` and `preprocessor` to create a image color enhance pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/pipelines/cv/image_colorization_pipeline.py b/modelscope/pipelines/cv/image_colorization_pipeline.py index 3f5bd706..838ccab5 100644 --- a/modelscope/pipelines/cv/image_colorization_pipeline.py +++ b/modelscope/pipelines/cv/image_colorization_pipeline.py @@ -25,7 +25,7 @@ class ImageColorizationPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` to create a kws pipeline for prediction + use `model` to create a image colorization pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index 3166c9a8..2faaec37 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -21,7 +21,7 @@ class ImageMattingPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` and `preprocessor` to create a kws pipeline for prediction + use `model` to create a image matting pipeline for prediction Args: model: model id on modelscope hub. 
""" diff --git a/modelscope/pipelines/cv/image_super_resolution_pipeline.py b/modelscope/pipelines/cv/image_super_resolution_pipeline.py index a9839281..6464fe69 100644 --- a/modelscope/pipelines/cv/image_super_resolution_pipeline.py +++ b/modelscope/pipelines/cv/image_super_resolution_pipeline.py @@ -23,7 +23,7 @@ class ImageSuperResolutionPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` to create a kws pipeline for prediction + use `model` to create a image super resolution pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py index c95e0c9f..32209c1e 100644 --- a/modelscope/pipelines/cv/ocr_detection_pipeline.py +++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py @@ -41,7 +41,7 @@ class OCRDetectionPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` and `preprocessor` to create a kws pipeline for prediction + use `model` to create a OCR detection pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/pipelines/cv/style_transfer_pipeline.py b/modelscope/pipelines/cv/style_transfer_pipeline.py index 687f0d40..efafc2a7 100644 --- a/modelscope/pipelines/cv/style_transfer_pipeline.py +++ b/modelscope/pipelines/cv/style_transfer_pipeline.py @@ -21,7 +21,7 @@ class StyleTransferPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` and `preprocessor` to create a kws pipeline for prediction + use `model` to create a style transfer pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/pipelines/cv/virtual_tryon_pipeline.py b/modelscope/pipelines/cv/virtual_tryon_pipeline.py index f29ab351..afd5ad1a 100644 --- a/modelscope/pipelines/cv/virtual_tryon_pipeline.py +++ b/modelscope/pipelines/cv/virtual_tryon_pipeline.py @@ -25,7 +25,7 @@ class VirtualTryonPipeline(Pipeline): def __init__(self, model: str, **kwargs): """ - use `model` to create a kws pipeline for prediction + use `model` to create a virtual tryon pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 666872ba..eececd8d 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -28,6 +28,8 @@ class CVTasks(object): ocr_detection = 'ocr-detection' action_recognition = 'action-recognition' video_embedding = 'video-embedding' + face_detection = 'face-detection' + face_recognition = 'face-recognition' image_color_enhance = 'image-color-enhance' virtual_tryon = 'virtual-tryon' image_colorization = 'image-colorization' diff --git a/tests/pipelines/test_face_detection.py b/tests/pipelines/test_face_detection.py new file mode 100644 index 00000000..23fda2c5 --- /dev/null +++ b/tests/pipelines/test_face_detection.py @@ -0,0 +1,84 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
diff --git a/tests/pipelines/test_face_detection.py b/tests/pipelines/test_face_detection.py
new file mode 100644
index 00000000..23fda2c5
--- /dev/null
+++ b/tests/pipelines/test_face_detection.py
@@ -0,0 +1,84 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+import tempfile
+import unittest
+
+import cv2
+import numpy as np
+
+from modelscope.fileio import File
+from modelscope.msdatasets import MsDataset
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class FaceDetectionTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_resnet_facedetection_scrfd10gkps'
+
+    def show_result(self, img_path, bboxes, kpss, scores):
+        bboxes = np.array(bboxes)
+        kpss = np.array(kpss)
+        scores = np.array(scores)
+        img = cv2.imread(img_path)
+        assert img is not None, f"Can't read img: {img_path}"
+        for i in range(len(scores)):
+            bbox = bboxes[i].astype(np.int32)
+            kps = kpss[i].reshape(-1, 2).astype(np.int32)
+            score = scores[i]
+            x1, y1, x2, y2 = bbox
+            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
+            for kp in kps:
+                cv2.circle(img, tuple(kp), 1, (0, 0, 255), 1)
+            cv2.putText(
+                img,
+                f'{score:.2f}', (x1, y2),
+                1,
+                1.0, (0, 255, 0),
+                thickness=1,
+                lineType=8)
+        cv2.imwrite('result.png', img)
+        print(
+            f'Found {len(scores)} faces, output written to {osp.abspath("result.png")}'
+        )
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_dataset(self):
+        input_location = ['data/test/images/face_detection.png']
+        # alternatively:
+        # input_location = '/dir/to/images'
+
+        dataset = MsDataset.load(input_location, target='image')
+        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
+        # note that for dataset input, the inference output is a Generator that can be iterated.
+        result = face_detection(dataset)
+        result = next(result)
+        self.show_result(input_location[0], result[OutputKeys.BOXES],
+                         result[OutputKeys.KEYPOINTS],
+                         result[OutputKeys.SCORES])
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
+        img_path = 'data/test/images/face_detection.png'
+
+        result = face_detection(img_path)
+        self.show_result(img_path, result[OutputKeys.BOXES],
+                         result[OutputKeys.KEYPOINTS],
+                         result[OutputKeys.SCORES])
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        face_detection = pipeline(Tasks.face_detection)
+        img_path = 'data/test/images/face_detection.png'
+        result = face_detection(img_path)
+        self.show_result(img_path, result[OutputKeys.BOXES],
+                         result[OutputKeys.KEYPOINTS],
+                         result[OutputKeys.SCORES])
+
+
+if __name__ == '__main__':
+    unittest.main()
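The tests above read the detector output as three parallel lists: OutputKeys.BOXES with one [x1, y1, x2, y2] box per face, OutputKeys.KEYPOINTS with the five landmarks flattened to ten values, and OutputKeys.SCORES with one confidence per face. A small sketch of thresholding that output follows; the helper name and the 0.5 cut-off are illustrative, not part of the patch.

# Sketch only: filter the face detection pipeline output by detection score.
import numpy as np

from modelscope.outputs import OutputKeys


def filter_faces(result, score_thr=0.5):
    """Keep detections whose score is at least `score_thr` (illustrative threshold)."""
    boxes = np.asarray(result[OutputKeys.BOXES]).reshape(-1, 4)
    kpss = np.asarray(result[OutputKeys.KEYPOINTS]).reshape(-1, 5, 2)
    scores = np.asarray(result[OutputKeys.SCORES]).reshape(-1)
    keep = scores >= score_thr
    return boxes[keep], kpss[keep], scores[keep]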
diff --git a/tests/pipelines/test_face_recognition.py b/tests/pipelines/test_face_recognition.py
new file mode 100644
index 00000000..a41de3e3
--- /dev/null
+++ b/tests/pipelines/test_face_recognition.py
@@ -0,0 +1,42 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+import tempfile
+import unittest
+
+import cv2
+import numpy as np
+
+from modelscope.fileio import File
+from modelscope.msdatasets import MsDataset
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class FaceRecognitionTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.recog_model_id = 'damo/cv_ir101_facerecognition_cfglint'
+        self.det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps'
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_face_compare(self):
+        img1 = 'data/test/images/face_recognition_1.png'
+        img2 = 'data/test/images/face_recognition_2.png'
+
+        face_detection = pipeline(
+            Tasks.face_detection, model=self.det_model_id)
+        face_recognition = pipeline(
+            Tasks.face_recognition,
+            face_detection=face_detection,
+            model=self.recog_model_id)
+        # the embeddings are l2-normalized, so the dot product below is a cosine similarity.
+        emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
+        emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
+        sim = np.dot(emb1[0], emb2[0])
+        print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')
+
+
+if __name__ == '__main__':
+    unittest.main()
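Because forward() l2-normalizes its embeddings, the dot product printed by this test is a cosine similarity in [-1, 1]. A minimal sketch of turning that score into a same-person decision follows; the 0.3 threshold is illustrative only and would need calibrating on a validation set, it is not part of the patch.

# Sketch only: decide whether two pipeline embeddings belong to the same person.
import numpy as np


def is_same_person(emb1, emb2, threshold=0.3):
    """Embeddings are already l2-normalized, so the dot product is the cosine similarity."""
    sim = float(np.dot(emb1[0], emb2[0]))
    return sim >= threshold, sim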