1. support FaceDetectionPipeline inference
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9470723
master
| @@ -121,6 +121,7 @@ source.sh | |||
| tensorboard.sh | |||
| .DS_Store | |||
| replace.sh | |||
| result.png | |||
| # Pytorch | |||
| *.pth | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:aa3963d1c54e6d3d46e9a59872a99ed955d4050092f5cfe5f591e03d740b7042 | |||
| size 653006 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:48e541daeb2692907efef47018e41abb5ae6bcd88eb5ff58290d7fe5dc8b2a13 | |||
| size 462584 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:e9565b43d9f65361b9bad6553b327c2c6f02fd063a4c8dc0f461e88ea461989d | |||
| size 357166 | |||
| @@ -10,6 +10,7 @@ class Models(object): | |||
| Model name should only contain model info but not task info. | |||
| """ | |||
| # vision models | |||
| scrfd = 'scrfd' | |||
| classification_model = 'ClassificationModel' | |||
| nafnet = 'nafnet' | |||
| csrnet = 'csrnet' | |||
| @@ -67,6 +68,7 @@ class Pipelines(object): | |||
| action_recognition = 'TAdaConv_action-recognition' | |||
| animal_recognation = 'resnet101-animal_recog' | |||
| cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | |||
| face_detection = 'resnet-face-detection-scrfd10gkps' | |||
| live_category = 'live-category' | |||
| general_image_classification = 'vit-base_image-classification_ImageNet-labels' | |||
| daily_image_classification = 'vit-base_image-classification_Dailylife-labels' | |||
| @@ -76,6 +78,7 @@ class Pipelines(object): | |||
| image_super_resolution = 'rrdb-image-super-resolution' | |||
| face_image_generation = 'gan-face-image-generation' | |||
| style_transfer = 'AAMS-style-transfer' | |||
| face_recognition = 'ir101-face-recognition-cfglint' | |||
| image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' | |||
| image2image_translation = 'image-to-image-translation' | |||
| live_category = 'live-category' | |||
| @@ -1,5 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from . import (action_recognition, animal_recognition, cartoon, | |||
| cmdssl_video_embedding, face_generation, image_classification, | |||
| image_color_enhance, image_colorization, image_denoise, | |||
| image_instance_segmentation, super_resolution, virual_tryon) | |||
| cmdssl_video_embedding, face_detection, face_generation, | |||
| image_classification, image_color_enhance, image_colorization, | |||
| image_denoise, image_instance_segmentation, | |||
| image_to_image_translation, super_resolution, virual_tryon) | |||
| @@ -0,0 +1,5 @@ | |||
| """ | |||
| mmdet_patch is based on | |||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet, | |||
| all duplicate functions from official mmdetection are removed. | |||
| """ | |||
| @@ -0,0 +1,3 @@ | |||
| from .transforms import bbox2result, distance2kps, kps2distance | |||
| __all__ = ['bbox2result', 'distance2kps', 'kps2distance'] | |||
| @@ -0,0 +1,86 @@ | |||
| """ | |||
| based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/bbox/transforms.py | |||
| """ | |||
| import numpy as np | |||
| import torch | |||
| def bbox2result(bboxes, labels, num_classes, kps=None): | |||
| """Convert detection results to a list of numpy arrays. | |||
| Args: | |||
| bboxes (torch.Tensor | np.ndarray): shape (n, 5) | |||
| labels (torch.Tensor | np.ndarray): shape (n, ) | |||
| num_classes (int): class number, including background class | |||
| kps (torch.Tensor | np.ndarray, optional): face keypoints of shape | |||
| (n, 10); when given they are appended to each bbox row | |||
| Returns: | |||
| list(ndarray): bbox results of each class | |||
| """ | |||
| bbox_len = 5 if kps is None else 5 + 10  # with kps, append the 10 keypoint values to each bbox | |||
| if bboxes.shape[0] == 0: | |||
| return [ | |||
| np.zeros((0, bbox_len), dtype=np.float32) | |||
| for i in range(num_classes) | |||
| ] | |||
| else: | |||
| if isinstance(bboxes, torch.Tensor): | |||
| bboxes = bboxes.detach().cpu().numpy() | |||
| labels = labels.detach().cpu().numpy() | |||
| if kps is None: | |||
| return [bboxes[labels == i, :] for i in range(num_classes)] | |||
| else: # with kps | |||
| if isinstance(kps, torch.Tensor): | |||
| kps = kps.detach().cpu().numpy() | |||
| return [ | |||
| np.hstack([bboxes[labels == i, :], kps[labels == i, :]]) | |||
| for i in range(num_classes) | |||
| ] | |||
| def distance2kps(points, distance, max_shape=None): | |||
| """Decode distance prediction to bounding box. | |||
| Args: | |||
| points (Tensor): Shape (n, 2), [x, y]. | |||
| distance (Tensor): Distance from the given point to 4 | |||
| boundaries (left, top, right, bottom). | |||
| max_shape (tuple): Shape of the image. | |||
| Returns: | |||
| Tensor: Decoded kps. | |||
| """ | |||
| preds = [] | |||
| for i in range(0, distance.shape[1], 2): | |||
| px = points[:, i % 2] + distance[:, i] | |||
| py = points[:, i % 2 + 1] + distance[:, i + 1] | |||
| if max_shape is not None: | |||
| px = px.clamp(min=0, max=max_shape[1]) | |||
| py = py.clamp(min=0, max=max_shape[0]) | |||
| preds.append(px) | |||
| preds.append(py) | |||
| return torch.stack(preds, -1) | |||
| def kps2distance(points, kps, max_dis=None, eps=0.1): | |||
| """Decode bounding box based on distances. | |||
| Args: | |||
| points (Tensor): Shape (n, 2), [x, y]. | |||
| kps (Tensor): Shape (n, K), "xyxy" format | |||
| max_dis (float): Upper bound of the distance. | |||
| eps (float): a small value to ensure target < max_dis, instead <= | |||
| Returns: | |||
| Tensor: Decoded distances. | |||
| """ | |||
| preds = [] | |||
| for i in range(0, kps.shape[1], 2): | |||
| px = kps[:, i] - points[:, i % 2] | |||
| py = kps[:, i + 1] - points[:, i % 2 + 1] | |||
| if max_dis is not None: | |||
| px = px.clamp(min=0, max=max_dis - eps) | |||
| py = py.clamp(min=0, max=max_dis - eps) | |||
| preds.append(px) | |||
| preds.append(py) | |||
| return torch.stack(preds, -1) | |||
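A minimal round-trip sketch of the two keypoint helpers above (the import path is illustrative; with no clamping applied, encode/decode is exact):

import torch
from transforms import distance2kps, kps2distance  # illustrative import path

# two anchor points and 5 facial keypoints per point (K = 10 values per row)
points = torch.tensor([[32., 32.], [64., 64.]])
kps = torch.rand(2, 10) * 96
offsets = kps2distance(points, kps)      # (2, 10) per-point offsets
decoded = distance2kps(points, offsets)  # back to absolute coordinates
assert torch.allclose(decoded, kps, atol=1e-5)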
| @@ -0,0 +1,3 @@ | |||
| from .bbox_nms import multiclass_nms | |||
| __all__ = ['multiclass_nms'] | |||
| @@ -0,0 +1,85 @@ | |||
| """ | |||
| based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/post_processing/bbox_nms.py | |||
| """ | |||
| import torch | |||
| def multiclass_nms(multi_bboxes, | |||
| multi_scores, | |||
| score_thr, | |||
| nms_cfg, | |||
| max_num=-1, | |||
| score_factors=None, | |||
| return_inds=False, | |||
| multi_kps=None): | |||
| """NMS for multi-class bboxes. | |||
| Args: | |||
| multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) | |||
| multi_scores (Tensor): shape (n, #class), where the last column | |||
| contains scores of the background class, but this will be ignored. | |||
| score_thr (float): bbox threshold, bboxes with scores lower than it | |||
| will not be considered. | |||
| nms_cfg (dict): NMS config passed to mmcv.ops.batched_nms. | |||
| max_num (int, optional): if there are more than max_num bboxes after | |||
| NMS, only top max_num will be kept. Default to -1. | |||
| score_factors (Tensor, optional): The factors multiplied to scores | |||
| before applying NMS. Default to None. | |||
| return_inds (bool, optional): Whether to return the indices of kept | |||
| bboxes. Default to False. | |||
| multi_kps (Tensor, optional): shape (n, #class*10) or (n, 10), | |||
| face keypoint coordinates. Default to None. | |||
| Returns: | |||
| tuple: (dets, labels, kps[, inds]); dets have shape (k, 5), labels | |||
| shape (k, ), kps shape (k, 10) or None. Labels are 0-based. | |||
| """ | |||
| num_classes = multi_scores.size(1) - 1 | |||
| # exclude background category | |||
| kps = None | |||
| if multi_bboxes.shape[1] > 4: | |||
| bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) | |||
| if multi_kps is not None: | |||
| kps = multi_kps.view(multi_scores.size(0), -1, 10) | |||
| else: | |||
| bboxes = multi_bboxes[:, None].expand( | |||
| multi_scores.size(0), num_classes, 4) | |||
| if multi_kps is not None: | |||
| kps = multi_kps[:, None].expand( | |||
| multi_scores.size(0), num_classes, 10) | |||
| scores = multi_scores[:, :-1] | |||
| if score_factors is not None: | |||
| scores = scores * score_factors[:, None] | |||
| labels = torch.arange(num_classes, dtype=torch.long) | |||
| labels = labels.view(1, -1).expand_as(scores) | |||
| bboxes = bboxes.reshape(-1, 4) | |||
| if kps is not None: | |||
| kps = kps.reshape(-1, 10) | |||
| scores = scores.reshape(-1) | |||
| labels = labels.reshape(-1) | |||
| # remove low scoring boxes | |||
| valid_mask = scores > score_thr | |||
| inds = valid_mask.nonzero(as_tuple=False).squeeze(1) | |||
| bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] | |||
| if kps is not None: | |||
| kps = kps[inds] | |||
| if inds.numel() == 0: | |||
| if torch.onnx.is_in_onnx_export(): | |||
| raise RuntimeError('[ONNX Error] Can not record NMS ' | |||
| 'as it has not been executed this time') | |||
| return bboxes, labels, kps | |||
| # TODO: add size check before feed into batched_nms | |||
| from mmcv.ops.nms import batched_nms | |||
| dets, keep = batched_nms(bboxes, scores, labels, nms_cfg) | |||
| if max_num > 0: | |||
| dets = dets[:max_num] | |||
| keep = keep[:max_num] | |||
| if return_inds: | |||
| # kps stays None when multi_kps was not provided | |||
| return dets, labels[keep], kps[keep] if kps is not None else None, keep | |||
| else: | |||
| return dets, labels[keep], kps[keep] if kps is not None else None | |||
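An illustrative call of multiclass_nms with keypoints, assuming mmcv with compiled NMS ops is available; shapes follow the docstring (one foreground class plus background) and all values are synthetic:

import torch
from bbox_nms import multiclass_nms  # illustrative import path

n = 100
xy = torch.rand(n, 2) * 180
wh = torch.rand(n, 2) * 20 + 1
multi_bboxes = torch.cat([xy, xy + wh], dim=1)   # (n, 4), shared by all classes
multi_scores = torch.rand(n, 2)                  # (n, 2): 1 fg class + background
multi_kps = torch.rand(n, 10) * 200              # (n, 10): 5 (x, y) keypoints
dets, labels, kps = multiclass_nms(
    multi_bboxes, multi_scores, score_thr=0.3,
    nms_cfg=dict(type='nms', iou_threshold=0.45),
    max_num=50, multi_kps=multi_kps)
# dets: (k, 5) boxes with scores, labels: (k,), kps: (k, 10)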
| @@ -0,0 +1,3 @@ | |||
| from .retinaface import RetinaFaceDataset | |||
| __all__ = ['RetinaFaceDataset'] | |||
| @@ -0,0 +1,3 @@ | |||
| from .transforms import RandomSquareCrop | |||
| __all__ = ['RandomSquareCrop'] | |||
| @@ -0,0 +1,188 @@ | |||
| """ | |||
| based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py | |||
| """ | |||
| import numpy as np | |||
| from mmdet.datasets.builder import PIPELINES | |||
| from numpy import random | |||
| @PIPELINES.register_module() | |||
| class RandomSquareCrop(object): | |||
| """Random crop the image & bboxes, the cropped patches have minimum IoU | |||
| requirement with original image & bboxes, the IoU threshold is randomly | |||
| selected from min_ious. | |||
| Args: | |||
| min_ious (tuple): minimum IoU threshold for all intersections with | |||
| bounding boxes | |||
| min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, | |||
| where a >= min_crop_size). | |||
| Note: | |||
| The keys for bboxes, labels and masks should be paired. That is, \ | |||
| `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ | |||
| `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. | |||
| """ | |||
| def __init__(self, | |||
| crop_ratio_range=None, | |||
| crop_choice=None, | |||
| bbox_clip_border=True): | |||
| self.crop_ratio_range = crop_ratio_range | |||
| self.crop_choice = crop_choice | |||
| self.bbox_clip_border = bbox_clip_border | |||
| assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) | |||
| if self.crop_ratio_range is not None: | |||
| self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range | |||
| self.bbox2label = { | |||
| 'gt_bboxes': 'gt_labels', | |||
| 'gt_bboxes_ignore': 'gt_labels_ignore' | |||
| } | |||
| self.bbox2mask = { | |||
| 'gt_bboxes': 'gt_masks', | |||
| 'gt_bboxes_ignore': 'gt_masks_ignore' | |||
| } | |||
| def __call__(self, results): | |||
| """Call function to crop images and bounding boxes with minimum IoU | |||
| constraint. | |||
| Args: | |||
| results (dict): Result dict from loading pipeline. | |||
| Returns: | |||
| dict: Result dict with images and bounding boxes cropped, \ | |||
| 'img_shape' key is updated. | |||
| """ | |||
| if 'img_fields' in results: | |||
| assert results['img_fields'] == ['img'], \ | |||
| 'Only single img_fields is allowed' | |||
| img = results['img'] | |||
| assert 'bbox_fields' in results | |||
| assert 'gt_bboxes' in results | |||
| boxes = results['gt_bboxes'] | |||
| h, w, c = img.shape | |||
| scale_retry = 0 | |||
| if self.crop_ratio_range is not None: | |||
| max_scale = self.crop_ratio_max | |||
| else: | |||
| max_scale = np.amax(self.crop_choice) | |||
| while True: | |||
| scale_retry += 1 | |||
| if scale_retry == 1 or max_scale > 1.0: | |||
| if self.crop_ratio_range is not None: | |||
| scale = np.random.uniform(self.crop_ratio_min, | |||
| self.crop_ratio_max) | |||
| elif self.crop_choice is not None: | |||
| scale = np.random.choice(self.crop_choice) | |||
| else: | |||
| scale = scale * 1.2 | |||
| for i in range(250): | |||
| short_side = min(w, h) | |||
| cw = int(scale * short_side) | |||
| ch = cw | |||
| # TODO +1 | |||
| if w == cw: | |||
| left = 0 | |||
| elif w > cw: | |||
| left = random.randint(0, w - cw) | |||
| else: | |||
| left = random.randint(w - cw, 0) | |||
| if h == ch: | |||
| top = 0 | |||
| elif h > ch: | |||
| top = random.randint(0, h - ch) | |||
| else: | |||
| top = random.randint(h - ch, 0) | |||
| patch = np.array( | |||
| (int(left), int(top), int(left + cw), int(top + ch)), | |||
| dtype=np.int64)  # np.int is deprecated in recent NumPy | |||
| # center of boxes should inside the crop img | |||
| # only adjust boxes and instance masks when the gt is not empty | |||
| # adjust boxes | |||
| def is_center_of_bboxes_in_patch(boxes, patch): | |||
| # TODO >= | |||
| center = (boxes[:, :2] + boxes[:, 2:]) / 2 | |||
| mask = \ | |||
| ((center[:, 0] > patch[0]) | |||
| * (center[:, 1] > patch[1]) | |||
| * (center[:, 0] < patch[2]) | |||
| * (center[:, 1] < patch[3])) | |||
| return mask | |||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
| if not mask.any(): | |||
| continue | |||
| for key in results.get('bbox_fields', []): | |||
| boxes = results[key].copy() | |||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
| boxes = boxes[mask] | |||
| if self.bbox_clip_border: | |||
| boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) | |||
| boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) | |||
| boxes -= np.tile(patch[:2], 2) | |||
| results[key] = boxes | |||
| # labels | |||
| label_key = self.bbox2label.get(key) | |||
| if label_key in results: | |||
| results[label_key] = results[label_key][mask] | |||
| # keypoints field | |||
| if key == 'gt_bboxes': | |||
| for kps_key in results.get('keypoints_fields', []): | |||
| keypointss = results[kps_key].copy() | |||
| keypointss = keypointss[mask, :, :] | |||
| if self.bbox_clip_border: | |||
| keypointss[:, :, :2] = keypointss[:, :, :2].clip(max=patch[2:]) | |||
| keypointss[:, :, :2] = keypointss[:, :, :2].clip(min=patch[:2]) | |||
| keypointss[:, :, 0] -= patch[0] | |||
| keypointss[:, :, 1] -= patch[1] | |||
| results[kps_key] = keypointss | |||
| # mask fields | |||
| mask_key = self.bbox2mask.get(key) | |||
| if mask_key in results: | |||
| results[mask_key] = results[mask_key][mask.nonzero() | |||
| [0]].crop(patch) | |||
| # adjust the img no matter whether the gt is empty before crop | |||
| rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 | |||
| patch_from = patch.copy() | |||
| patch_from[0] = max(0, patch_from[0]) | |||
| patch_from[1] = max(0, patch_from[1]) | |||
| patch_from[2] = min(img.shape[1], patch_from[2]) | |||
| patch_from[3] = min(img.shape[0], patch_from[3]) | |||
| patch_to = patch.copy() | |||
| patch_to[0] = max(0, patch_to[0] * -1) | |||
| patch_to[1] = max(0, patch_to[1] * -1) | |||
| patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) | |||
| patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) | |||
| rimg[patch_to[1]:patch_to[3], | |||
| patch_to[0]:patch_to[2], :] = img[ | |||
| patch_from[1]:patch_from[3], | |||
| patch_from[0]:patch_from[2], :] | |||
| img = rimg | |||
| results['img'] = img | |||
| results['img_shape'] = img.shape | |||
| return results | |||
| def __repr__(self): | |||
| repr_str = self.__class__.__name__ | |||
| repr_str += f'(crop_ratio_range={self.crop_ratio_range}, ' | |||
| repr_str += f'crop_choice={self.crop_choice}, ' | |||
| repr_str += f'bbox_clip_border={self.bbox_clip_border})' | |||
| return repr_str | |||
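For context, RandomSquareCrop is registered into mmdet's PIPELINES, so it is driven from a data pipeline config. A fragment of how such a config could look (transform order and ratio values are illustrative, not copied from the shipped SCRFD config):

train_pipeline = [
    dict(type='LoadImageFromFile'),
    # square crop whose side is a ratio of the short image side
    dict(type='RandomSquareCrop', crop_choice=[0.3, 0.45, 0.6, 0.8, 1.0]),
    dict(type='Resize', img_scale=(640, 640), keep_ratio=False),
]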
| @@ -0,0 +1,151 @@ | |||
| """ | |||
| based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/retinaface.py | |||
| """ | |||
| import numpy as np | |||
| from mmdet.datasets.builder import DATASETS | |||
| from mmdet.datasets.custom import CustomDataset | |||
| @DATASETS.register_module() | |||
| class RetinaFaceDataset(CustomDataset): | |||
| CLASSES = ('FG', ) | |||
| def __init__(self, min_size=None, **kwargs): | |||
| self.NK = 5 | |||
| self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} | |||
| self.min_size = min_size | |||
| self.gt_path = kwargs.get('gt_path') | |||
| super(RetinaFaceDataset, self).__init__(**kwargs) | |||
| def _parse_ann_line(self, line): | |||
| values = [float(x) for x in line.strip().split()] | |||
| bbox = np.array(values[0:4], dtype=np.float32) | |||
| kps = np.zeros((self.NK, 3), dtype=np.float32) | |||
| ignore = False | |||
| if self.min_size is not None: | |||
| assert not self.test_mode | |||
| w = bbox[2] - bbox[0] | |||
| h = bbox[3] - bbox[1] | |||
| if w < self.min_size or h < self.min_size: | |||
| ignore = True | |||
| if len(values) > 4: | |||
| if len(values) > 5: | |||
| kps = np.array( | |||
| values[4:19], dtype=np.float32).reshape((self.NK, 3)) | |||
| for li in range(kps.shape[0]): | |||
| if (kps[li, :] == -1).all(): | |||
| kps[li][2] = 0.0 # weight = 0, ignore | |||
| else: | |||
| assert kps[li][2] >= 0 | |||
| kps[li][2] = 1.0 # weight | |||
| else: # len(values)==5 | |||
| if not ignore: | |||
| ignore = (values[4] == 1) | |||
| else: | |||
| assert self.test_mode | |||
| return dict(bbox=bbox, kps=kps, ignore=ignore, cat='FG') | |||
| def load_annotations(self, ann_file): | |||
| """Load annotation from COCO style annotation file. | |||
| Args: | |||
| ann_file (str): Path of annotation file. | |||
| 20220711@tyx: ann_file is list of img paths is supported | |||
| Returns: | |||
| list[dict]: Annotation info from COCO api. | |||
| """ | |||
| if isinstance(ann_file, list): | |||
| data_infos = [] | |||
| for line in ann_file: | |||
| name = line | |||
| objs = [0, 0, 0, 0] | |||
| data_infos.append( | |||
| dict(filename=name, width=0, height=0, objs=objs)) | |||
| else: | |||
| name = None | |||
| bbox_map = {} | |||
| for line in open(ann_file, 'r'): | |||
| line = line.strip() | |||
| if line.startswith('#'): | |||
| value = line[1:].strip().split() | |||
| name = value[0] | |||
| width = int(value[1]) | |||
| height = int(value[2]) | |||
| bbox_map[name] = dict(width=width, height=height, objs=[]) | |||
| continue | |||
| assert name is not None | |||
| assert name in bbox_map | |||
| bbox_map[name]['objs'].append(line) | |||
| print('origin image size', len(bbox_map)) | |||
| data_infos = [] | |||
| for name in bbox_map: | |||
| item = bbox_map[name] | |||
| width = item['width'] | |||
| height = item['height'] | |||
| vals = item['objs'] | |||
| objs = [] | |||
| for line in vals: | |||
| data = self._parse_ann_line(line) | |||
| if data is None: | |||
| continue | |||
| objs.append(data) # data is (bbox, kps, cat) | |||
| if len(objs) == 0 and not self.test_mode: | |||
| continue | |||
| data_infos.append( | |||
| dict(filename=name, width=width, height=height, objs=objs)) | |||
| return data_infos | |||
| def get_ann_info(self, idx): | |||
| """Get COCO annotation by index. | |||
| Args: | |||
| idx (int): Index of data. | |||
| Returns: | |||
| dict: Annotation info of specified index. | |||
| """ | |||
| data_info = self.data_infos[idx] | |||
| bboxes = [] | |||
| keypointss = [] | |||
| labels = [] | |||
| bboxes_ignore = [] | |||
| labels_ignore = [] | |||
| for obj in data_info['objs']: | |||
| label = self.cat2label[obj['cat']] | |||
| bbox = obj['bbox'] | |||
| keypoints = obj['kps'] | |||
| ignore = obj['ignore'] | |||
| if ignore: | |||
| bboxes_ignore.append(bbox) | |||
| labels_ignore.append(label) | |||
| else: | |||
| bboxes.append(bbox) | |||
| labels.append(label) | |||
| keypointss.append(keypoints) | |||
| if not bboxes: | |||
| bboxes = np.zeros((0, 4)) | |||
| labels = np.zeros((0, )) | |||
| keypointss = np.zeros((0, self.NK, 3)) | |||
| else: | |||
| # bboxes = np.array(bboxes, ndmin=2) - 1 | |||
| bboxes = np.array(bboxes, ndmin=2) | |||
| labels = np.array(labels) | |||
| keypointss = np.array(keypointss, ndmin=3) | |||
| if not bboxes_ignore: | |||
| bboxes_ignore = np.zeros((0, 4)) | |||
| labels_ignore = np.zeros((0, )) | |||
| else: | |||
| bboxes_ignore = np.array(bboxes_ignore, ndmin=2) | |||
| labels_ignore = np.array(labels_ignore) | |||
| ann = dict( | |||
| bboxes=bboxes.astype(np.float32), | |||
| labels=labels.astype(np.int64), | |||
| keypointss=keypointss.astype(np.float32), | |||
| bboxes_ignore=bboxes_ignore.astype(np.float32), | |||
| labels_ignore=labels_ignore.astype(np.int64)) | |||
| return ann | |||
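As a reading aid, _parse_ann_line and load_annotations above consume a RetinaFace-style text label file. A sketch of the expected format, with a made-up file name and coordinates:

# '# <filename> <width> <height>' starts a new image; each following line is
# 'x1 y1 x2 y2' optionally followed by 5 keypoints as (x, y, flag) triplets.
# A point whose whole (x, y, flag) triplet is -1 is treated as unannotated
# (its weight is set to 0 when parsing).
sample_label = """\
# images/0001.jpg 1024 768
449.0 330.0 570.0 479.0 488.9 373.6 0.0 542.0 376.9 0.0 515.5 412.3 0.0 492.2 436.3 0.0 535.1 438.9 0.0
"""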
| @@ -0,0 +1,2 @@ | |||
| from .dense_heads import * # noqa: F401,F403 | |||
| from .detectors import * # noqa: F401,F403 | |||
| @@ -0,0 +1,3 @@ | |||
| from .resnet import ResNetV1e | |||
| __all__ = ['ResNetV1e'] | |||
| @@ -0,0 +1,412 @@ | |||
| """ | |||
| based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones/resnet.py | |||
| """ | |||
| import torch.nn as nn | |||
| import torch.utils.checkpoint as cp | |||
| from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer, | |||
| constant_init, kaiming_init) | |||
| from mmcv.runner import load_checkpoint | |||
| from mmdet.models.backbones.resnet import BasicBlock, Bottleneck | |||
| from mmdet.models.builder import BACKBONES | |||
| from mmdet.models.utils import ResLayer | |||
| from mmdet.utils import get_root_logger | |||
| from torch.nn.modules.batchnorm import _BatchNorm | |||
| class ResNet(nn.Module): | |||
| """ResNet backbone. | |||
| Args: | |||
| depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. | |||
| stem_channels (int | None): Number of stem channels. If not specified, | |||
| it will be the same as `base_channels`. Default: None. | |||
| base_channels (int): Number of base channels of res layer. Default: 64. | |||
| in_channels (int): Number of input image channels. Default: 3. | |||
| num_stages (int): Resnet stages. Default: 4. | |||
| strides (Sequence[int]): Strides of the first block of each stage. | |||
| dilations (Sequence[int]): Dilation of each stage. | |||
| out_indices (Sequence[int]): Output from which stages. | |||
| style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two | |||
| layer is the 3x3 conv layer, otherwise the stride-two layer is | |||
| the first 1x1 conv layer. | |||
| deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv | |||
| avg_down (bool): Use AvgPool instead of stride conv when | |||
| downsampling in the bottleneck. | |||
| frozen_stages (int): Stages to be frozen (stop grad and set eval mode). | |||
| -1 means not freezing any parameters. | |||
| norm_cfg (dict): Dictionary to construct and config norm layer. | |||
| norm_eval (bool): Whether to set norm layers to eval mode, namely, | |||
| freeze running stats (mean and var). Note: Effect on Batch Norm | |||
| and its variants only. | |||
| plugins (list[dict]): List of plugins for stages, each dict contains: | |||
| - cfg (dict, required): Cfg dict to build plugin. | |||
| - position (str, required): Position inside block to insert | |||
| plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. | |||
| - stages (tuple[bool], optional): Stages to apply plugin, length | |||
| should be same as 'num_stages'. | |||
| with_cp (bool): Use checkpoint or not. Using checkpoint will save some | |||
| memory while slowing down the training speed. | |||
| zero_init_residual (bool): Whether to use zero init for last norm layer | |||
| in resblocks to let them behave as identity. | |||
| Example: | |||
| >>> from mmdet.models import ResNet | |||
| >>> import torch | |||
| >>> self = ResNet(depth=18) | |||
| >>> self.eval() | |||
| >>> inputs = torch.rand(1, 3, 32, 32) | |||
| >>> level_outputs = self.forward(inputs) | |||
| >>> for level_out in level_outputs: | |||
| ... print(tuple(level_out.shape)) | |||
| (1, 64, 8, 8) | |||
| (1, 128, 4, 4) | |||
| (1, 256, 2, 2) | |||
| (1, 512, 1, 1) | |||
| """ | |||
| arch_settings = { | |||
| 0: (BasicBlock, (2, 2, 2, 2)), | |||
| 18: (BasicBlock, (2, 2, 2, 2)), | |||
| 19: (BasicBlock, (2, 4, 4, 1)), | |||
| 20: (BasicBlock, (2, 3, 2, 2)), | |||
| 22: (BasicBlock, (2, 4, 3, 1)), | |||
| 24: (BasicBlock, (2, 4, 4, 1)), | |||
| 26: (BasicBlock, (2, 4, 4, 2)), | |||
| 28: (BasicBlock, (2, 5, 4, 2)), | |||
| 29: (BasicBlock, (2, 6, 3, 2)), | |||
| 30: (BasicBlock, (2, 5, 5, 2)), | |||
| 32: (BasicBlock, (2, 6, 5, 2)), | |||
| 34: (BasicBlock, (3, 4, 6, 3)), | |||
| 35: (BasicBlock, (3, 6, 4, 3)), | |||
| 38: (BasicBlock, (3, 8, 4, 3)), | |||
| 40: (BasicBlock, (3, 8, 5, 3)), | |||
| 50: (Bottleneck, (3, 4, 6, 3)), | |||
| 56: (Bottleneck, (3, 8, 4, 3)), | |||
| 68: (Bottleneck, (3, 10, 6, 3)), | |||
| 74: (Bottleneck, (3, 12, 6, 3)), | |||
| 101: (Bottleneck, (3, 4, 23, 3)), | |||
| 152: (Bottleneck, (3, 8, 36, 3)) | |||
| } | |||
| def __init__(self, | |||
| depth, | |||
| in_channels=3, | |||
| stem_channels=None, | |||
| base_channels=64, | |||
| num_stages=4, | |||
| block_cfg=None, | |||
| strides=(1, 2, 2, 2), | |||
| dilations=(1, 1, 1, 1), | |||
| out_indices=(0, 1, 2, 3), | |||
| style='pytorch', | |||
| deep_stem=False, | |||
| avg_down=False, | |||
| no_pool33=False, | |||
| frozen_stages=-1, | |||
| conv_cfg=None, | |||
| norm_cfg=dict(type='BN', requires_grad=True), | |||
| norm_eval=True, | |||
| dcn=None, | |||
| stage_with_dcn=(False, False, False, False), | |||
| plugins=None, | |||
| with_cp=False, | |||
| zero_init_residual=True): | |||
| super(ResNet, self).__init__() | |||
| if depth not in self.arch_settings: | |||
| raise KeyError(f'invalid depth {depth} for resnet') | |||
| self.depth = depth | |||
| if stem_channels is None: | |||
| stem_channels = base_channels | |||
| self.stem_channels = stem_channels | |||
| self.base_channels = base_channels | |||
| self.num_stages = num_stages | |||
| assert num_stages >= 1 and num_stages <= 4 | |||
| self.strides = strides | |||
| self.dilations = dilations | |||
| assert len(strides) == len(dilations) == num_stages | |||
| self.out_indices = out_indices | |||
| assert max(out_indices) < num_stages | |||
| self.style = style | |||
| self.deep_stem = deep_stem | |||
| self.avg_down = avg_down | |||
| self.no_pool33 = no_pool33 | |||
| self.frozen_stages = frozen_stages | |||
| self.conv_cfg = conv_cfg | |||
| self.norm_cfg = norm_cfg | |||
| self.with_cp = with_cp | |||
| self.norm_eval = norm_eval | |||
| self.dcn = dcn | |||
| self.stage_with_dcn = stage_with_dcn | |||
| if dcn is not None: | |||
| assert len(stage_with_dcn) == num_stages | |||
| self.plugins = plugins | |||
| self.zero_init_residual = zero_init_residual | |||
| if block_cfg is None: | |||
| self.block, stage_blocks = self.arch_settings[depth] | |||
| else: | |||
| self.block = BasicBlock if block_cfg[ | |||
| 'block'] == 'BasicBlock' else Bottleneck | |||
| stage_blocks = block_cfg['stage_blocks'] | |||
| assert len(stage_blocks) >= num_stages | |||
| self.stage_blocks = stage_blocks[:num_stages] | |||
| self.inplanes = stem_channels | |||
| self._make_stem_layer(in_channels, stem_channels) | |||
| if block_cfg is not None and 'stage_planes' in block_cfg: | |||
| stage_planes = block_cfg['stage_planes'] | |||
| else: | |||
| stage_planes = [base_channels * 2**i for i in range(num_stages)] | |||
| # print('resnet cfg:', stage_blocks, stage_planes) | |||
| self.res_layers = [] | |||
| for i, num_blocks in enumerate(self.stage_blocks): | |||
| stride = strides[i] | |||
| dilation = dilations[i] | |||
| dcn = self.dcn if self.stage_with_dcn[i] else None | |||
| if plugins is not None: | |||
| stage_plugins = self.make_stage_plugins(plugins, i) | |||
| else: | |||
| stage_plugins = None | |||
| planes = stage_planes[i] | |||
| res_layer = self.make_res_layer( | |||
| block=self.block, | |||
| inplanes=self.inplanes, | |||
| planes=planes, | |||
| num_blocks=num_blocks, | |||
| stride=stride, | |||
| dilation=dilation, | |||
| style=self.style, | |||
| avg_down=self.avg_down, | |||
| with_cp=with_cp, | |||
| conv_cfg=conv_cfg, | |||
| norm_cfg=norm_cfg, | |||
| dcn=dcn, | |||
| plugins=stage_plugins) | |||
| self.inplanes = planes * self.block.expansion | |||
| layer_name = f'layer{i + 1}' | |||
| self.add_module(layer_name, res_layer) | |||
| self.res_layers.append(layer_name) | |||
| self._freeze_stages() | |||
| self.feat_dim = self.block.expansion * base_channels * 2**( | |||
| len(self.stage_blocks) - 1) | |||
| def make_stage_plugins(self, plugins, stage_idx): | |||
| """Make plugins for ResNet ``stage_idx`` th stage. | |||
| Currently we support inserting ``context_block``, | |||
| ``empirical_attention_block``, ``nonlocal_block`` into the backbone | |||
| like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of | |||
| Bottleneck. | |||
| An example of plugins format could be: | |||
| Examples: | |||
| >>> plugins=[ | |||
| ... dict(cfg=dict(type='xxx', arg1='xxx'), | |||
| ... stages=(False, True, True, True), | |||
| ... position='after_conv2'), | |||
| ... dict(cfg=dict(type='yyy'), | |||
| ... stages=(True, True, True, True), | |||
| ... position='after_conv3'), | |||
| ... dict(cfg=dict(type='zzz', postfix='1'), | |||
| ... stages=(True, True, True, True), | |||
| ... position='after_conv3'), | |||
| ... dict(cfg=dict(type='zzz', postfix='2'), | |||
| ... stages=(True, True, True, True), | |||
| ... position='after_conv3') | |||
| ... ] | |||
| >>> self = ResNet(depth=18) | |||
| >>> stage_plugins = self.make_stage_plugins(plugins, 0) | |||
| >>> assert len(stage_plugins) == 3 | |||
| Suppose ``stage_idx=0``, the structure of blocks in the stage would be: | |||
| .. code-block:: none | |||
| conv1-> conv2->conv3->yyy->zzz1->zzz2 | |||
| Suppose 'stage_idx=1', the structure of blocks in the stage would be: | |||
| .. code-block:: none | |||
| conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 | |||
| If stages is missing, the plugin would be applied to all stages. | |||
| Args: | |||
| plugins (list[dict]): List of plugins cfg to build. The postfix is | |||
| required if multiple same type plugins are inserted. | |||
| stage_idx (int): Index of stage to build | |||
| Returns: | |||
| list[dict]: Plugins for current stage | |||
| """ | |||
| stage_plugins = [] | |||
| for plugin in plugins: | |||
| plugin = plugin.copy() | |||
| stages = plugin.pop('stages', None) | |||
| assert stages is None or len(stages) == self.num_stages | |||
| # whether to insert plugin into current stage | |||
| if stages is None or stages[stage_idx]: | |||
| stage_plugins.append(plugin) | |||
| return stage_plugins | |||
| def make_res_layer(self, **kwargs): | |||
| """Pack all blocks in a stage into a ``ResLayer``.""" | |||
| return ResLayer(**kwargs) | |||
| @property | |||
| def norm1(self): | |||
| """nn.Module: the normalization layer named "norm1" """ | |||
| return getattr(self, self.norm1_name) | |||
| def _make_stem_layer(self, in_channels, stem_channels): | |||
| if self.deep_stem: | |||
| self.stem = nn.Sequential( | |||
| build_conv_layer( | |||
| self.conv_cfg, | |||
| in_channels, | |||
| stem_channels // 2, | |||
| kernel_size=3, | |||
| stride=2, | |||
| padding=1, | |||
| bias=False), | |||
| build_norm_layer(self.norm_cfg, stem_channels // 2)[1], | |||
| nn.ReLU(inplace=True), | |||
| build_conv_layer( | |||
| self.conv_cfg, | |||
| stem_channels // 2, | |||
| stem_channels // 2, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| bias=False), | |||
| build_norm_layer(self.norm_cfg, stem_channels // 2)[1], | |||
| nn.ReLU(inplace=True), | |||
| build_conv_layer( | |||
| self.conv_cfg, | |||
| stem_channels // 2, | |||
| stem_channels, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| bias=False), | |||
| build_norm_layer(self.norm_cfg, stem_channels)[1], | |||
| nn.ReLU(inplace=True)) | |||
| else: | |||
| self.conv1 = build_conv_layer( | |||
| self.conv_cfg, | |||
| in_channels, | |||
| stem_channels, | |||
| kernel_size=7, | |||
| stride=2, | |||
| padding=3, | |||
| bias=False) | |||
| self.norm1_name, norm1 = build_norm_layer( | |||
| self.norm_cfg, stem_channels, postfix=1) | |||
| self.add_module(self.norm1_name, norm1) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| if self.no_pool33: | |||
| assert self.deep_stem | |||
| self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) | |||
| else: | |||
| self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
| def _freeze_stages(self): | |||
| if self.frozen_stages >= 0: | |||
| if self.deep_stem: | |||
| self.stem.eval() | |||
| for param in self.stem.parameters(): | |||
| param.requires_grad = False | |||
| else: | |||
| self.norm1.eval() | |||
| for m in [self.conv1, self.norm1]: | |||
| for param in m.parameters(): | |||
| param.requires_grad = False | |||
| for i in range(1, self.frozen_stages + 1): | |||
| m = getattr(self, f'layer{i}') | |||
| m.eval() | |||
| for param in m.parameters(): | |||
| param.requires_grad = False | |||
| def init_weights(self, pretrained=None): | |||
| """Initialize the weights in backbone. | |||
| Args: | |||
| pretrained (str, optional): Path to pre-trained weights. | |||
| Defaults to None. | |||
| """ | |||
| if isinstance(pretrained, str): | |||
| logger = get_root_logger() | |||
| load_checkpoint(self, pretrained, strict=False, logger=logger) | |||
| elif pretrained is None: | |||
| for m in self.modules(): | |||
| if isinstance(m, nn.Conv2d): | |||
| kaiming_init(m) | |||
| elif isinstance(m, (_BatchNorm, nn.GroupNorm)): | |||
| constant_init(m, 1) | |||
| if self.dcn is not None: | |||
| for m in self.modules(): | |||
| if isinstance(m, Bottleneck) and hasattr( | |||
| m.conv2, 'conv_offset'): | |||
| constant_init(m.conv2.conv_offset, 0) | |||
| if self.zero_init_residual: | |||
| for m in self.modules(): | |||
| if isinstance(m, Bottleneck): | |||
| constant_init(m.norm3, 0) | |||
| elif isinstance(m, BasicBlock): | |||
| constant_init(m.norm2, 0) | |||
| else: | |||
| raise TypeError('pretrained must be a str or None') | |||
| def forward(self, x): | |||
| """Forward function.""" | |||
| if self.deep_stem: | |||
| x = self.stem(x) | |||
| else: | |||
| x = self.conv1(x) | |||
| x = self.norm1(x) | |||
| x = self.relu(x) | |||
| x = self.maxpool(x) | |||
| outs = [] | |||
| for i, layer_name in enumerate(self.res_layers): | |||
| res_layer = getattr(self, layer_name) | |||
| x = res_layer(x) | |||
| if i in self.out_indices: | |||
| outs.append(x) | |||
| return tuple(outs) | |||
| def train(self, mode=True): | |||
| """Convert the model into training mode while keep normalization layer | |||
| freezed.""" | |||
| super(ResNet, self).train(mode) | |||
| self._freeze_stages() | |||
| if mode and self.norm_eval: | |||
| for m in self.modules(): | |||
| # trick: eval() only has an effect on BatchNorm layers here | |||
| if isinstance(m, _BatchNorm): | |||
| m.eval() | |||
| @BACKBONES.register_module() | |||
| class ResNetV1e(ResNet): | |||
| r"""ResNetV1d variant described in `Bag of Tricks | |||
| <https://arxiv.org/pdf/1812.01187.pdf>`_. | |||
| Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in | |||
| the input stem with three 3x3 convs. And in the downsampling block, a 2x2 | |||
| avg_pool with stride 2 is added before conv, whose stride is changed to 1. | |||
| Compared with ResNetV1d, ResNetV1e changes the stem max pooling from | |||
| 3x3 with padding 1 to 2x2 with padding 0. | |||
| """ | |||
| def __init__(self, **kwargs): | |||
| super(ResNetV1e, self).__init__( | |||
| deep_stem=True, avg_down=True, no_pool33=True, **kwargs) | |||
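A smoke-test sketch for this backbone, assuming mmcv and mmdet are installed so the builder utilities imported at the top of the file resolve:

import torch

backbone = ResNetV1e(depth=50, out_indices=(0, 1, 2, 3),
                     norm_cfg=dict(type='BN', requires_grad=True))
backbone.init_weights()
backbone.eval()
feats = backbone(torch.rand(1, 3, 224, 224))
for f in feats:
    print(tuple(f.shape))  # strides 4/8/16/32 -> 56, 28, 14, 7 spatial sizes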
| @@ -0,0 +1,3 @@ | |||
| from .scrfd_head import SCRFDHead | |||
| __all__ = ['SCRFDHead'] | |||
| @@ -0,0 +1,3 @@ | |||
| from .scrfd import SCRFD | |||
| __all__ = ['SCRFD'] | |||
| @@ -0,0 +1,109 @@ | |||
| """ | |||
| based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors/scrfd.py | |||
| """ | |||
| import torch | |||
| from mmdet.models.builder import DETECTORS | |||
| from mmdet.models.detectors.single_stage import SingleStageDetector | |||
| from ....mmdet_patch.core.bbox import bbox2result | |||
| @DETECTORS.register_module() | |||
| class SCRFD(SingleStageDetector): | |||
| def __init__(self, | |||
| backbone, | |||
| neck, | |||
| bbox_head, | |||
| train_cfg=None, | |||
| test_cfg=None, | |||
| pretrained=None): | |||
| super(SCRFD, self).__init__(backbone, neck, bbox_head, train_cfg, | |||
| test_cfg, pretrained) | |||
| def forward_train(self, | |||
| img, | |||
| img_metas, | |||
| gt_bboxes, | |||
| gt_labels, | |||
| gt_keypointss=None, | |||
| gt_bboxes_ignore=None): | |||
| """ | |||
| Args: | |||
| img (Tensor): Input images of shape (N, C, H, W). | |||
| Typically these should be mean centered and std scaled. | |||
| img_metas (list[dict]): A List of image info dict where each dict | |||
| has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
| 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
| For details on the values of these keys see | |||
| :class:`mmdet.datasets.pipelines.Collect`. | |||
| gt_bboxes (list[Tensor]): Ground-truth boxes for each image, | |||
| in [tl_x, tl_y, br_x, br_y] format. | |||
| gt_labels (list[Tensor]): Class indices corresponding to each box | |||
| gt_keypointss (None | list[Tensor]): 5-point facial keypoints of | |||
| each box, with a per-point visibility weight. | |||
| gt_bboxes_ignore (None | list[Tensor]): Specify which bounding | |||
| boxes can be ignored when computing the loss. | |||
| Returns: | |||
| dict[str, Tensor]: A dictionary of loss components. | |||
| """ | |||
| super(SingleStageDetector, self).forward_train(img, img_metas) | |||
| x = self.extract_feat(img) | |||
| losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, | |||
| gt_labels, gt_keypointss, | |||
| gt_bboxes_ignore) | |||
| return losses | |||
| def simple_test(self, img, img_metas, rescale=False): | |||
| """Test function without test time augmentation. | |||
| Args: | |||
| img (torch.Tensor): Input image batch of shape (N, C, H, W). | |||
| img_metas (list[dict]): List of image information. | |||
| rescale (bool, optional): Whether to rescale the results. | |||
| Defaults to False. | |||
| Returns: | |||
| list[list[np.ndarray]]: BBox results of each image and classes. | |||
| The outer list corresponds to each image. The inner list | |||
| corresponds to each class. | |||
| """ | |||
| x = self.extract_feat(img) | |||
| outs = self.bbox_head(x) | |||
| if torch.onnx.is_in_onnx_export(): | |||
| print('single_stage.py in-onnx-export') | |||
| print(outs.__class__) | |||
| cls_score, bbox_pred, kps_pred = outs | |||
| for c in cls_score: | |||
| print(c.shape) | |||
| for c in bbox_pred: | |||
| print(c.shape) | |||
| if self.bbox_head.use_kps: | |||
| for c in kps_pred: | |||
| print(c.shape) | |||
| return (cls_score, bbox_pred, kps_pred) | |||
| else: | |||
| return (cls_score, bbox_pred) | |||
| bbox_list = self.bbox_head.get_bboxes( | |||
| *outs, img_metas, rescale=rescale) | |||
| # return kps if use_kps | |||
| if len(bbox_list[0]) == 2: | |||
| bbox_results = [ | |||
| bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) | |||
| for det_bboxes, det_labels in bbox_list | |||
| ] | |||
| elif len(bbox_list[0]) == 3: | |||
| bbox_results = [ | |||
| bbox2result( | |||
| det_bboxes, | |||
| det_labels, | |||
| self.bbox_head.num_classes, | |||
| kps=det_kps) | |||
| for det_bboxes, det_labels, det_kps in bbox_list | |||
| ] | |||
| return bbox_results | |||
| def feature_test(self, img): | |||
| x = self.extract_feat(img) | |||
| outs = self.bbox_head(x) | |||
| return outs | |||
| @@ -0,0 +1,50 @@ | |||
| import cv2 | |||
| import numpy as np | |||
| from skimage import transform as trans | |||
| def align_face(image, size, lmks): | |||
| dst_w = size[1] | |||
| dst_h = size[0] | |||
| # reference landmark positions in the destination image | |||
| base_w = 96 | |||
| base_h = 112 | |||
| assert (dst_w >= base_w) | |||
| assert (dst_h >= base_h) | |||
| base_lmk = [ | |||
| 30.2946, 51.6963, 65.5318, 51.5014, 48.0252, 71.7366, 33.5493, 92.3655, | |||
| 62.7299, 92.2041 | |||
| ] | |||
| dst_lmk = np.array(base_lmk).reshape((5, 2)).astype(np.float32) | |||
| if dst_w != base_w: | |||
| slide = (dst_w - base_w) / 2 | |||
| dst_lmk[:, 0] += slide | |||
| if dst_h != base_h: | |||
| slide = (dst_h - base_h) / 2 | |||
| dst_lmk[:, 1] += slide | |||
| src_lmk = lmks | |||
| # using skimage method | |||
| tform = trans.SimilarityTransform() | |||
| tform.estimate(src_lmk, dst_lmk) | |||
| t = tform.params[0:2, :] | |||
| assert (image.shape[2] == 3) | |||
| dst_image = cv2.warpAffine(image.copy(), t, (dst_w, dst_h)) | |||
| dst_pts = GetAffinePoints(src_lmk, t) | |||
| return dst_image, dst_pts | |||
| def GetAffinePoints(pts_in, trans): | |||
| pts_out = pts_in.copy() | |||
| assert (pts_in.shape[1] == 2) | |||
| for k in range(pts_in.shape[0]): | |||
| pts_out[k, 0] = pts_in[k, 0] * trans[0, 0] + pts_in[k, 1] * trans[ | |||
| 0, 1] + trans[0, 2] | |||
| pts_out[k, 1] = pts_in[k, 0] * trans[1, 0] + pts_in[k, 1] * trans[ | |||
| 1, 1] + trans[1, 2] | |||
| return pts_out | |||
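An illustrative call of align_face (requires opencv-python and scikit-image); the landmarks below are made-up 5-point coordinates in the usual left-eye, right-eye, nose, mouth-corners order:

import numpy as np

image = np.zeros((256, 256, 3), dtype=np.uint8)      # stand-in BGR image
lmks = np.array([[98.0, 112.0], [158.0, 110.0], [128.0, 150.0],
                 [105.0, 185.0], [152.0, 183.0]], dtype=np.float32)
aligned, warped_pts = align_face(image, (112, 112), lmks)
print(aligned.shape)     # (112, 112, 3) aligned crop
print(warped_pts.shape)  # (5, 2) landmarks mapped into the crop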
| @@ -0,0 +1,31 @@ | |||
| from .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, | |||
| IR_SE_101, IR_SE_152, IR_SE_200) | |||
| from .model_resnet import ResNet_50, ResNet_101, ResNet_152 | |||
| _model_dict = { | |||
| 'ResNet_50': ResNet_50, | |||
| 'ResNet_101': ResNet_101, | |||
| 'ResNet_152': ResNet_152, | |||
| 'IR_18': IR_18, | |||
| 'IR_34': IR_34, | |||
| 'IR_50': IR_50, | |||
| 'IR_101': IR_101, | |||
| 'IR_152': IR_152, | |||
| 'IR_200': IR_200, | |||
| 'IR_SE_50': IR_SE_50, | |||
| 'IR_SE_101': IR_SE_101, | |||
| 'IR_SE_152': IR_SE_152, | |||
| 'IR_SE_200': IR_SE_200 | |||
| } | |||
| def get_model(key): | |||
| """ Get different backbone network by key, | |||
| supports ResNet_50, ResNet_101, ResNet_152, | |||
| IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, | |||
| IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200. | |||
| """ | |||
| if key in _model_dict.keys(): | |||
| return _model_dict[key] | |||
| else: | |||
| raise KeyError('not support model {}'.format(key)) | |||
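A small lookup sketch: keys not in _model_dict raise KeyError, and each entry is a constructor taking the input size.

backbone_cls = get_model('IR_101')   # constructor for the IR-101 backbone
net = backbone_cls([112, 112])       # input_size, see model_irse.py below
# get_model('IR_999') would raise KeyError('not support model IR_999')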
| @@ -0,0 +1,68 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Linear, Module, ReLU, | |||
| Sigmoid) | |||
| def initialize_weights(modules): | |||
| """ Weight initilize, conv2d and linear is initialized with kaiming_normal | |||
| """ | |||
| for m in modules: | |||
| if isinstance(m, nn.Conv2d): | |||
| nn.init.kaiming_normal_( | |||
| m.weight, mode='fan_out', nonlinearity='relu') | |||
| if m.bias is not None: | |||
| m.bias.data.zero_() | |||
| elif isinstance(m, nn.BatchNorm2d): | |||
| m.weight.data.fill_(1) | |||
| m.bias.data.zero_() | |||
| elif isinstance(m, nn.Linear): | |||
| nn.init.kaiming_normal_( | |||
| m.weight, mode='fan_out', nonlinearity='relu') | |||
| if m.bias is not None: | |||
| m.bias.data.zero_() | |||
| class Flatten(Module): | |||
| """ Flat tensor | |||
| """ | |||
| def forward(self, input): | |||
| return input.view(input.size(0), -1) | |||
| class SEModule(Module): | |||
| """ SE block | |||
| """ | |||
| def __init__(self, channels, reduction): | |||
| super(SEModule, self).__init__() | |||
| self.avg_pool = nn.AdaptiveAvgPool2d(1) | |||
| self.fc1 = Conv2d( | |||
| channels, | |||
| channels // reduction, | |||
| kernel_size=1, | |||
| padding=0, | |||
| bias=False) | |||
| nn.init.xavier_uniform_(self.fc1.weight.data) | |||
| self.relu = ReLU(inplace=True) | |||
| self.fc2 = Conv2d( | |||
| channels // reduction, | |||
| channels, | |||
| kernel_size=1, | |||
| padding=0, | |||
| bias=False) | |||
| self.sigmoid = Sigmoid() | |||
| def forward(self, x): | |||
| module_input = x | |||
| x = self.avg_pool(x) | |||
| x = self.fc1(x) | |||
| x = self.relu(x) | |||
| x = self.fc2(x) | |||
| x = self.sigmoid(x) | |||
| return module_input * x | |||
| @@ -0,0 +1,279 @@ | |||
| # based on: | |||
| # https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_irse.py | |||
| from collections import namedtuple | |||
| from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, | |||
| MaxPool2d, Module, PReLU, Sequential) | |||
| from .common import Flatten, SEModule, initialize_weights | |||
| class BasicBlockIR(Module): | |||
| """ BasicBlock for IRNet | |||
| """ | |||
| def __init__(self, in_channel, depth, stride): | |||
| super(BasicBlockIR, self).__init__() | |||
| if in_channel == depth: | |||
| self.shortcut_layer = MaxPool2d(1, stride) | |||
| else: | |||
| self.shortcut_layer = Sequential( | |||
| Conv2d(in_channel, depth, (1, 1), stride, bias=False), | |||
| BatchNorm2d(depth)) | |||
| self.res_layer = Sequential( | |||
| BatchNorm2d(in_channel), | |||
| Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), | |||
| BatchNorm2d(depth), PReLU(depth), | |||
| Conv2d(depth, depth, (3, 3), stride, 1, bias=False), | |||
| BatchNorm2d(depth)) | |||
| def forward(self, x): | |||
| shortcut = self.shortcut_layer(x) | |||
| res = self.res_layer(x) | |||
| return res + shortcut | |||
| class BottleneckIR(Module): | |||
| """ BasicBlock with bottleneck for IRNet | |||
| """ | |||
| def __init__(self, in_channel, depth, stride): | |||
| super(BottleneckIR, self).__init__() | |||
| reduction_channel = depth // 4 | |||
| if in_channel == depth: | |||
| self.shortcut_layer = MaxPool2d(1, stride) | |||
| else: | |||
| self.shortcut_layer = Sequential( | |||
| Conv2d(in_channel, depth, (1, 1), stride, bias=False), | |||
| BatchNorm2d(depth)) | |||
| self.res_layer = Sequential( | |||
| BatchNorm2d(in_channel), | |||
| Conv2d( | |||
| in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False), | |||
| BatchNorm2d(reduction_channel), PReLU(reduction_channel), | |||
| Conv2d( | |||
| reduction_channel, | |||
| reduction_channel, (3, 3), (1, 1), | |||
| 1, | |||
| bias=False), BatchNorm2d(reduction_channel), | |||
| PReLU(reduction_channel), | |||
| Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False), | |||
| BatchNorm2d(depth)) | |||
| def forward(self, x): | |||
| shortcut = self.shortcut_layer(x) | |||
| res = self.res_layer(x) | |||
| return res + shortcut | |||
| class BasicBlockIRSE(BasicBlockIR): | |||
| def __init__(self, in_channel, depth, stride): | |||
| super(BasicBlockIRSE, self).__init__(in_channel, depth, stride) | |||
| self.res_layer.add_module('se_block', SEModule(depth, 16)) | |||
| class BottleneckIRSE(BottleneckIR): | |||
| def __init__(self, in_channel, depth, stride): | |||
| super(BottleneckIRSE, self).__init__(in_channel, depth, stride) | |||
| self.res_layer.add_module('se_block', SEModule(depth, 16)) | |||
| class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): | |||
| '''A named tuple describing a ResNet block.''' | |||
| def get_block(in_channel, depth, num_units, stride=2): | |||
| return [Bottleneck(in_channel, depth, stride)] +\ | |||
| [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] | |||
| def get_blocks(num_layers): | |||
| if num_layers == 18: | |||
| blocks = [ | |||
| get_block(in_channel=64, depth=64, num_units=2), | |||
| get_block(in_channel=64, depth=128, num_units=2), | |||
| get_block(in_channel=128, depth=256, num_units=2), | |||
| get_block(in_channel=256, depth=512, num_units=2) | |||
| ] | |||
| elif num_layers == 34: | |||
| blocks = [ | |||
| get_block(in_channel=64, depth=64, num_units=3), | |||
| get_block(in_channel=64, depth=128, num_units=4), | |||
| get_block(in_channel=128, depth=256, num_units=6), | |||
| get_block(in_channel=256, depth=512, num_units=3) | |||
| ] | |||
| elif num_layers == 50: | |||
| blocks = [ | |||
| get_block(in_channel=64, depth=64, num_units=3), | |||
| get_block(in_channel=64, depth=128, num_units=4), | |||
| get_block(in_channel=128, depth=256, num_units=14), | |||
| get_block(in_channel=256, depth=512, num_units=3) | |||
| ] | |||
| elif num_layers == 100: | |||
| blocks = [ | |||
| get_block(in_channel=64, depth=64, num_units=3), | |||
| get_block(in_channel=64, depth=128, num_units=13), | |||
| get_block(in_channel=128, depth=256, num_units=30), | |||
| get_block(in_channel=256, depth=512, num_units=3) | |||
| ] | |||
| elif num_layers == 152: | |||
| blocks = [ | |||
| get_block(in_channel=64, depth=256, num_units=3), | |||
| get_block(in_channel=256, depth=512, num_units=8), | |||
| get_block(in_channel=512, depth=1024, num_units=36), | |||
| get_block(in_channel=1024, depth=2048, num_units=3) | |||
| ] | |||
| elif num_layers == 200: | |||
| blocks = [ | |||
| get_block(in_channel=64, depth=256, num_units=3), | |||
| get_block(in_channel=256, depth=512, num_units=24), | |||
| get_block(in_channel=512, depth=1024, num_units=36), | |||
| get_block(in_channel=1024, depth=2048, num_units=3) | |||
| ] | |||
| return blocks | |||
| class Backbone(Module): | |||
| def __init__(self, input_size, num_layers, mode='ir'): | |||
| """ Args: | |||
| input_size: input_size of backbone | |||
| num_layers: num_layers of backbone | |||
| mode: 'ir' or 'ir_se' | |||
| """ | |||
| super(Backbone, self).__init__() | |||
| assert input_size[0] in [112, 224], \ | |||
| 'input_size should be [112, 112] or [224, 224]' | |||
| assert num_layers in [18, 34, 50, 100, 152, 200], \ | |||
| 'num_layers should be 18, 34, 50, 100, 152 or 200' | |||
| assert mode in ['ir', 'ir_se'], \ | |||
| 'mode should be ir or ir_se' | |||
| self.input_layer = Sequential( | |||
| Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), | |||
| PReLU(64)) | |||
| blocks = get_blocks(num_layers) | |||
| if num_layers <= 100: | |||
| if mode == 'ir': | |||
| unit_module = BasicBlockIR | |||
| elif mode == 'ir_se': | |||
| unit_module = BasicBlockIRSE | |||
| output_channel = 512 | |||
| else: | |||
| if mode == 'ir': | |||
| unit_module = BottleneckIR | |||
| elif mode == 'ir_se': | |||
| unit_module = BottleneckIRSE | |||
| output_channel = 2048 | |||
| if input_size[0] == 112: | |||
| self.output_layer = Sequential( | |||
| BatchNorm2d(output_channel), Dropout(0.4), Flatten(), | |||
| Linear(output_channel * 7 * 7, 512), | |||
| BatchNorm1d(512, affine=False)) | |||
| else: | |||
| self.output_layer = Sequential( | |||
| BatchNorm2d(output_channel), Dropout(0.4), Flatten(), | |||
| Linear(output_channel * 14 * 14, 512), | |||
| BatchNorm1d(512, affine=False)) | |||
| modules = [] | |||
| for block in blocks: | |||
| for bottleneck in block: | |||
| modules.append( | |||
| unit_module(bottleneck.in_channel, bottleneck.depth, | |||
| bottleneck.stride)) | |||
| self.body = Sequential(*modules) | |||
| initialize_weights(self.modules()) | |||
| def forward(self, x): | |||
| x = self.input_layer(x) | |||
| x = self.body(x) | |||
| x = self.output_layer(x) | |||
| return x | |||
| def IR_18(input_size): | |||
| """ Constructs a ir-18 model. | |||
| """ | |||
| model = Backbone(input_size, 18, 'ir') | |||
| return model | |||
| def IR_34(input_size): | |||
| """ Constructs a ir-34 model. | |||
| """ | |||
| model = Backbone(input_size, 34, 'ir') | |||
| return model | |||
| def IR_50(input_size): | |||
| """ Constructs a ir-50 model. | |||
| """ | |||
| model = Backbone(input_size, 50, 'ir') | |||
| return model | |||
| def IR_101(input_size): | |||
| """ Constructs a ir-101 model. | |||
| """ | |||
| model = Backbone(input_size, 100, 'ir') | |||
| return model | |||
| def IR_152(input_size): | |||
| """ Constructs a ir-152 model. | |||
| """ | |||
| model = Backbone(input_size, 152, 'ir') | |||
| return model | |||
| def IR_200(input_size): | |||
| """ Constructs a ir-200 model. | |||
| """ | |||
| model = Backbone(input_size, 200, 'ir') | |||
| return model | |||
| def IR_SE_50(input_size): | |||
| """ Constructs a ir_se-50 model. | |||
| """ | |||
| model = Backbone(input_size, 50, 'ir_se') | |||
| return model | |||
| def IR_SE_101(input_size): | |||
| """ Constructs a ir_se-101 model. | |||
| """ | |||
| model = Backbone(input_size, 100, 'ir_se') | |||
| return model | |||
| def IR_SE_152(input_size): | |||
| """ Constructs a ir_se-152 model. | |||
| """ | |||
| model = Backbone(input_size, 152, 'ir_se') | |||
| return model | |||
| def IR_SE_200(input_size): | |||
| """ Constructs a ir_se-200 model. | |||
| """ | |||
| model = Backbone(input_size, 200, 'ir_se') | |||
| return model | |||
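An end-to-end sketch for the IR backbones defined above: building IR_101 (num_layers=100 internally) and producing a 512-d embedding for a 112x112 face crop:

import torch

net = IR_101([112, 112])
net.eval()
with torch.no_grad():
    emb = net(torch.rand(1, 3, 112, 112))
print(emb.shape)  # torch.Size([1, 512]) face embedding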
| @@ -0,0 +1,162 @@ | |||
| # based on: | |||
| # https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_resnet.py | |||
| import torch.nn as nn | |||
| from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, | |||
| MaxPool2d, Module, ReLU, Sequential) | |||
| from .common import initialize_weights | |||
| def conv3x3(in_planes, out_planes, stride=1): | |||
| """ 3x3 convolution with padding | |||
| """ | |||
| return Conv2d( | |||
| in_planes, | |||
| out_planes, | |||
| kernel_size=3, | |||
| stride=stride, | |||
| padding=1, | |||
| bias=False) | |||
| def conv1x1(in_planes, out_planes, stride=1): | |||
| """ 1x1 convolution | |||
| """ | |||
| return Conv2d( | |||
| in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||
| class Bottleneck(Module): | |||
| expansion = 4 | |||
| def __init__(self, inplanes, planes, stride=1, downsample=None): | |||
| super(Bottleneck, self).__init__() | |||
| self.conv1 = conv1x1(inplanes, planes) | |||
| self.bn1 = BatchNorm2d(planes) | |||
| self.conv2 = conv3x3(planes, planes, stride) | |||
| self.bn2 = BatchNorm2d(planes) | |||
| self.conv3 = conv1x1(planes, planes * self.expansion) | |||
| self.bn3 = BatchNorm2d(planes * self.expansion) | |||
| self.relu = ReLU(inplace=True) | |||
| self.downsample = downsample | |||
| self.stride = stride | |||
| def forward(self, x): | |||
| identity = x | |||
| out = self.conv1(x) | |||
| out = self.bn1(out) | |||
| out = self.relu(out) | |||
| out = self.conv2(out) | |||
| out = self.bn2(out) | |||
| out = self.relu(out) | |||
| out = self.conv3(out) | |||
| out = self.bn3(out) | |||
| if self.downsample is not None: | |||
| identity = self.downsample(x) | |||
| out += identity | |||
| out = self.relu(out) | |||
| return out | |||
| class ResNet(Module): | |||
| """ ResNet backbone | |||
| """ | |||
| def __init__(self, input_size, block, layers, zero_init_residual=True): | |||
| """ Args: | |||
| input_size: input_size of backbone | |||
| block: block function | |||
| layers: layers in each block | |||
| """ | |||
| super(ResNet, self).__init__() | |||
| assert input_size[0] in [112, 224],\ | |||
| 'input_size should be [112, 112] or [224, 224]' | |||
| self.inplanes = 64 | |||
| self.conv1 = Conv2d( | |||
| 3, 64, kernel_size=7, stride=2, padding=3, bias=False) | |||
| self.bn1 = BatchNorm2d(64) | |||
| self.relu = ReLU(inplace=True) | |||
| self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
| self.layer1 = self._make_layer(block, 64, layers[0]) | |||
| self.layer2 = self._make_layer(block, 128, layers[1], stride=2) | |||
| self.layer3 = self._make_layer(block, 256, layers[2], stride=2) | |||
| self.layer4 = self._make_layer(block, 512, layers[3], stride=2) | |||
| self.bn_o1 = BatchNorm2d(2048) | |||
| self.dropout = Dropout() | |||
| if input_size[0] == 112: | |||
| self.fc = Linear(2048 * 4 * 4, 512) | |||
| else: | |||
| self.fc = Linear(2048 * 7 * 7, 512) | |||
| self.bn_o2 = BatchNorm1d(512) | |||
| initialize_weights(self.modules())  # pass the module iterator, not the bound method | |||
| if zero_init_residual: | |||
| for m in self.modules(): | |||
| if isinstance(m, Bottleneck): | |||
| nn.init.constant_(m.bn3.weight, 0) | |||
| def _make_layer(self, block, planes, blocks, stride=1): | |||
| downsample = None | |||
| if stride != 1 or self.inplanes != planes * block.expansion: | |||
| downsample = Sequential( | |||
| conv1x1(self.inplanes, planes * block.expansion, stride), | |||
| BatchNorm2d(planes * block.expansion), | |||
| ) | |||
| layers = [] | |||
| layers.append(block(self.inplanes, planes, stride, downsample)) | |||
| self.inplanes = planes * block.expansion | |||
| for _ in range(1, blocks): | |||
| layers.append(block(self.inplanes, planes)) | |||
| return Sequential(*layers) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = self.bn1(x) | |||
| x = self.relu(x) | |||
| x = self.maxpool(x) | |||
| x = self.layer1(x) | |||
| x = self.layer2(x) | |||
| x = self.layer3(x) | |||
| x = self.layer4(x) | |||
| x = self.bn_o1(x) | |||
| x = self.dropout(x) | |||
| x = x.view(x.size(0), -1) | |||
| x = self.fc(x) | |||
| x = self.bn_o2(x) | |||
| return x | |||
| def ResNet_50(input_size, **kwargs): | |||
| """ Constructs a ResNet-50 model. | |||
| """ | |||
| model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs) | |||
| return model | |||
| def ResNet_101(input_size, **kwargs): | |||
| """ Constructs a ResNet-101 model. | |||
| """ | |||
| model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs) | |||
| return model | |||
| def ResNet_152(input_size, **kwargs): | |||
| """ Constructs a ResNet-152 model. | |||
| """ | |||
| model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs) | |||
| return model | |||
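Reviewer note: a minimal shape check for the backbone above; the import path is an assumption about where this file lands in the PR, not something shown in the diff.

import torch
# Assumed module path for the ResNet definitions above.
from modelscope.models.cv.face_recognition.torchkit.backbone.model_resnet import ResNet_101

model = ResNet_101(input_size=[112, 112])
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 112, 112)  # stand-in for an aligned 112x112 face crop
    emb = model(dummy)                   # flatten -> fc -> bn_o2 yields a 512-d embedding
assert emb.shape == (1, 512)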
| @@ -13,6 +13,7 @@ class OutputKeys(object): | |||
| POSES = 'poses' | |||
| CAPTION = 'caption' | |||
| BOXES = 'boxes' | |||
| KEYPOINTS = 'keypoints' | |||
| MASKS = 'masks' | |||
| TEXT = 'text' | |||
| POLYGONS = 'polygons' | |||
| @@ -55,6 +56,31 @@ TASK_OUTPUTS = { | |||
| Tasks.object_detection: | |||
| [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES], | |||
| # face detection result for single sample | |||
| # { | |||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||
| # "boxes": [ | |||
| # [x1, y1, x2, y2], | |||
| # [x1, y1, x2, y2], | |||
| # [x1, y1, x2, y2], | |||
| # [x1, y1, x2, y2], | |||
| # ], | |||
| # "keypoints": [ | |||
| # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], | |||
| # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], | |||
| # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], | |||
| # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], | |||
| # ], | |||
| # } | |||
| Tasks.face_detection: | |||
| [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], | |||
| # face recognition result for single sample | |||
| # { | |||
| # "img_embedding": np.array with shape [1, D], | |||
| # } | |||
| Tasks.face_recognition: [OutputKeys.IMG_EMBEDDING], | |||
| # instance segmentation result for single sample | |||
| # { | |||
| # "scores": [0.9, 0.1, 0.05, 0.05], | |||
| @@ -255,7 +255,11 @@ class Pipeline(ABC): | |||
| elif isinstance(data, InputFeatures): | |||
| return data | |||
| else: | |||
| raise ValueError(f'Unsupported data type {type(data)}') | |||
| import mmcv | |||
| if isinstance(data, mmcv.parallel.data_container.DataContainer): | |||
| return data | |||
| else: | |||
| raise ValueError(f'Unsupported data type {type(data)}') | |||
| def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: | |||
| preprocess_params = kwargs.get('preprocess_params') | |||
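Reviewer note: context for this change — the mmdet test-time transforms used by FaceDetectionPipeline wrap `img_metas` in an mmcv `DataContainer`, which the previous type dispatch rejected with ValueError. A minimal illustration of the object now being passed through (assuming mmcv is installed):

from mmcv.parallel import DataContainer

meta = DataContainer({'img_shape': (640, 640, 3)}, cpu_only=True)
# Objects like `meta` are now returned unchanged and unwrapped later in
# FaceDetectionPipeline.forward via input['img_metas'][0].data.
print(type(meta), meta.data)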
| @@ -80,6 +80,10 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.text_to_image_synthesis: | |||
| (Pipelines.text_to_image_synthesis, | |||
| 'damo/cv_imagen_text-to-image-synthesis_tiny'), | |||
| Tasks.face_detection: (Pipelines.face_detection, | |||
| 'damo/cv_resnet_facedetection_scrfd10gkps'), | |||
| Tasks.face_recognition: (Pipelines.face_recognition, | |||
| 'damo/cv_ir101_facerecognition_cfglint'), | |||
| Tasks.video_multi_modal_embedding: | |||
| (Pipelines.video_multi_modal_embedding, | |||
| 'damo/multi_modal_clip_vtretrival_msrvtt_53'), | |||
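Reviewer note: with these defaults registered, the task name alone is enough to build the detector; a sketch assuming the hub models are reachable (face recognition still needs a detector instance for alignment, see FaceRecognitionPipeline below):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_detection = pipeline(Tasks.face_detection)  # resolves to damo/cv_resnet_facedetection_scrfd10gkps
face_recognition = pipeline(
    Tasks.face_recognition, face_detection=face_detection)  # resolves to damo/cv_ir101_facerecognition_cfglint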
| @@ -5,44 +5,50 @@ from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .action_recognition_pipeline import ActionRecognitionPipeline | |||
| from .animal_recog_pipeline import AnimalRecogPipeline | |||
| from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline | |||
| from .live_category_pipeline import LiveCategoryPipeline | |||
| from .image_classification_pipeline import GeneralImageClassificationPipeline | |||
| from .animal_recognition_pipeline import AnimalRecognitionPipeline | |||
| from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline | |||
| from .face_detection_pipeline import FaceDetectionPipeline | |||
| from .face_recognition_pipeline import FaceRecognitionPipeline | |||
| from .face_image_generation_pipeline import FaceImageGenerationPipeline | |||
| from .image_cartoon_pipeline import ImageCartoonPipeline | |||
| from .image_classification_pipeline import GeneralImageClassificationPipeline | |||
| from .image_denoise_pipeline import ImageDenoisePipeline | |||
| from .image_color_enhance_pipeline import ImageColorEnhancePipeline | |||
| from .image_colorization_pipeline import ImageColorizationPipeline | |||
| from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | |||
| from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | |||
| from .video_category_pipeline import VideoCategoryPipeline | |||
| from .image_matting_pipeline import ImageMattingPipeline | |||
| from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | |||
| from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | |||
| from .style_transfer_pipeline import StyleTransferPipeline | |||
| from .live_category_pipeline import LiveCategoryPipeline | |||
| from .ocr_detection_pipeline import OCRDetectionPipeline | |||
| from .video_category_pipeline import VideoCategoryPipeline | |||
| from .virtual_tryon_pipeline import VirtualTryonPipeline | |||
| else: | |||
| _import_structure = { | |||
| 'action_recognition_pipeline': ['ActionRecognitionPipeline'], | |||
| 'animal_recog_pipeline': ['AnimalRecogPipeline'], | |||
| 'cmdssl_video_embedding_pipleline': ['CMDSSLVideoEmbeddingPipeline'], | |||
| 'animal_recognition_pipeline': ['AnimalRecognitionPipeline'], | |||
| 'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'], | |||
| 'face_detection_pipeline': ['FaceDetectionPipeline'], | |||
| 'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], | |||
| 'face_recognition_pipeline': ['FaceRecognitionPipeline'], | |||
| 'image_classification_pipeline': | |||
| ['GeneralImageClassificationPipeline'], | |||
| 'image_cartoon_pipeline': ['ImageCartoonPipeline'], | |||
| 'image_denoise_pipeline': ['ImageDenoisePipeline'], | |||
| 'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'], | |||
| 'virtual_tryon_pipeline': ['VirtualTryonPipeline'], | |||
| 'image_colorization_pipeline': ['ImageColorizationPipeline'], | |||
| 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], | |||
| 'image_denoise_pipeline': ['ImageDenoisePipeline'], | |||
| 'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], | |||
| 'image_cartoon_pipeline': ['ImageCartoonPipeline'], | |||
| 'image_matting_pipeline': ['ImageMattingPipeline'], | |||
| 'style_transfer_pipeline': ['StyleTransferPipeline'], | |||
| 'ocr_detection_pipeline': ['OCRDetectionPipeline'], | |||
| 'image_instance_segmentation_pipeline': | |||
| ['ImageInstanceSegmentationPipeline'], | |||
| 'video_category_pipeline': ['VideoCategoryPipeline'], | |||
| 'image_matting_pipeline': ['ImageMattingPipeline'], | |||
| 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], | |||
| 'image_to_image_translation_pipeline': | |||
| ['Image2ImageTranslationPipeline'], | |||
| 'live_category_pipeline': ['LiveCategoryPipeline'], | |||
| 'ocr_detection_pipeline': ['OCRDetectionPipeline'], | |||
| 'style_transfer_pipeline': ['StyleTransferPipeline'], | |||
| 'video_category_pipeline': ['VideoCategoryPipeline'], | |||
| 'virtual_tryon_pipeline': ['VirtualTryonPipeline'], | |||
| } | |||
| import sys | |||
| @@ -23,7 +23,7 @@ class ActionRecognitionPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||
| use `model` to create an action recognition pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -22,11 +22,11 @@ logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_classification, module_name=Pipelines.animal_recognation) | |||
| class AnimalRecogPipeline(Pipeline): | |||
| class AnimalRecognitionPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||
| use `model` to create an animal recognition pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -24,7 +24,7 @@ class CMDSSLVideoEmbeddingPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||
| use `model` to create a CMDSSL Video Embedding pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -0,0 +1,105 @@ | |||
| import os.path as osp | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.face_detection, module_name=Pipelines.face_detection) | |||
| class FaceDetectionPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` to create a face detection pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, **kwargs) | |||
| from mmcv import Config | |||
| from mmcv.parallel import MMDataParallel | |||
| from mmcv.runner import load_checkpoint | |||
| from mmdet.models import build_detector | |||
| from modelscope.models.cv.face_detection.mmdet_patch.datasets import RetinaFaceDataset | |||
| from modelscope.models.cv.face_detection.mmdet_patch.datasets.pipelines import RandomSquareCrop | |||
| from modelscope.models.cv.face_detection.mmdet_patch.models.backbones import ResNetV1e | |||
| from modelscope.models.cv.face_detection.mmdet_patch.models.dense_heads import SCRFDHead | |||
| from modelscope.models.cv.face_detection.mmdet_patch.models.detectors import SCRFD | |||
| cfg = Config.fromfile(osp.join(model, 'mmcv_scrfd_10g_bnkps.py')) | |||
| detector = build_detector( | |||
| cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) | |||
| ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE) | |||
| logger.info(f'loading model from {ckpt_path}') | |||
| device = torch.device( | |||
| f'cuda:{0}' if torch.cuda.is_available() else 'cpu') | |||
| load_checkpoint(detector, ckpt_path, map_location=device) | |||
| detector = MMDataParallel(detector, device_ids=[0]) | |||
| detector.eval() | |||
| self.detector = detector | |||
| logger.info('load model done') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| img = LoadImage.convert_to_ndarray(input) | |||
| img = img.astype(np.float32) | |||
| pre_pipeline = [ | |||
| dict( | |||
| type='MultiScaleFlipAug', | |||
| img_scale=(640, 640), | |||
| flip=False, | |||
| transforms=[ | |||
| dict(type='Resize', keep_ratio=True), | |||
| dict(type='RandomFlip', flip_ratio=0.0), | |||
| dict( | |||
| type='Normalize', | |||
| mean=[127.5, 127.5, 127.5], | |||
| std=[128.0, 128.0, 128.0], | |||
| to_rgb=False), | |||
| dict(type='Pad', size=(640, 640), pad_val=0), | |||
| dict(type='ImageToTensor', keys=['img']), | |||
| dict(type='Collect', keys=['img']) | |||
| ]) | |||
| ] | |||
| from mmdet.datasets.pipelines import Compose | |||
| pipeline = Compose(pre_pipeline) | |||
| result = {} | |||
| result['filename'] = '' | |||
| result['ori_filename'] = '' | |||
| result['img'] = img | |||
| result['img_shape'] = img.shape | |||
| result['ori_shape'] = img.shape | |||
| result['img_fields'] = ['img'] | |||
| result = pipeline(result) | |||
| return result | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| result = self.detector( | |||
| return_loss=False, | |||
| rescale=True, | |||
| img=[input['img'][0].unsqueeze(0)], | |||
| img_metas=[[dict(input['img_metas'][0].data)]]) | |||
| assert result is not None | |||
| result = result[0][0] | |||
| bboxes = result[:, :4].tolist() | |||
| kpss = result[:, 5:].tolist() | |||
| scores = result[:, 4].tolist() | |||
| return { | |||
| OutputKeys.SCORES: scores, | |||
| OutputKeys.BOXES: bboxes, | |||
| OutputKeys.KEYPOINTS: kpss | |||
| } | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
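Reviewer note: `forward` above assumes each detection row from the SCRFD detector is laid out as 4 box coordinates, 1 score, then 10 keypoint values, which is what the slicing relies on. A small sketch of that layout with made-up numbers:

import numpy as np

row = np.array([10., 20., 110., 140.,  # x1, y1, x2, y2
                0.92,                  # confidence score
                30., 50., 80., 50., 55., 80., 35., 110., 78., 110.])  # 5 landmarks
bbox, score, kps = row[:4], row[4], row[5:]
print(bbox.tolist(), float(score), kps.reshape(5, 2).tolist())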
| @@ -24,7 +24,7 @@ class FaceImageGenerationPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` to create a kws pipeline for prediction | |||
| use `model` to create a face image generation pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -0,0 +1,130 @@ | |||
| import os.path as osp | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.face_recognition.align_face import align_face | |||
| from modelscope.models.cv.face_recognition.torchkit.backbone import get_model | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.face_recognition, module_name=Pipelines.face_recognition) | |||
| class FaceRecognitionPipeline(Pipeline): | |||
| def __init__(self, model: str, face_detection: Pipeline, **kwargs): | |||
| """ | |||
| use `model` to create a face recognition pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| face_detection: pipeline for face detection and face alignment before recognition | |||
| """ | |||
| # face recognition model | |||
| super().__init__(model=model, **kwargs) | |||
| device = torch.device( | |||
| f'cuda:{0}' if torch.cuda.is_available() else 'cpu') | |||
| self.device = device | |||
| face_model = get_model('IR_101')([112, 112]) | |||
| face_model.load_state_dict( | |||
| torch.load( | |||
| osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE), | |||
| map_location=device)) | |||
| face_model = face_model.to(device) | |||
| face_model.eval() | |||
| self.face_model = face_model | |||
| logger.info('face recognition model loaded!') | |||
| # face detect pipeline | |||
| self.face_detection = face_detection | |||
| def _choose_face(self, | |||
| det_result, | |||
| min_face=10, | |||
| top_face=1, | |||
| center_face=False, | |||
| img_shape=None): | |||
| ''' | |||
| Choose the face with the maximum area. | |||
| Args: | |||
| det_result: output of the face detection pipeline | |||
| min_face: minimum valid face width/height in pixels | |||
| top_face: number of largest-area faces to keep | |||
| center_face: choose the face closest to the image center, only valid if top_face > 1 | |||
| img_shape: shape of the original image, used when center_face is True | |||
| ''' | |||
| bboxes = np.array(det_result[OutputKeys.BOXES]) | |||
| landmarks = np.array(det_result[OutputKeys.KEYPOINTS]) | |||
| # scores = np.array(det_result[OutputKeys.SCORES]) | |||
| if bboxes.shape[0] == 0: | |||
| logger.info('No face detected!') | |||
| return None | |||
| # face idx with enough size | |||
| face_idx = [] | |||
| for i in range(bboxes.shape[0]): | |||
| box = bboxes[i] | |||
| if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face: | |||
| face_idx += [i] | |||
| if len(face_idx) == 0: | |||
| logger.info( | |||
| f'Face size not enough, less than {min_face}x{min_face}!') | |||
| return None | |||
| bboxes = bboxes[face_idx] | |||
| landmarks = landmarks[face_idx] | |||
| # find max faces | |||
| boxes = np.array(bboxes) | |||
| area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) | |||
| sort_idx = np.argsort(area)[-top_face:] | |||
| # find center face | |||
| if top_face > 1 and center_face and bboxes.shape[0] > 1: | |||
| img_center = [img_shape[1] // 2, img_shape[0] // 2] | |||
| min_dist = float('inf') | |||
| sel_idx = -1 | |||
| for _idx in sort_idx: | |||
| box = boxes[_idx] | |||
| dist = np.square( | |||
| np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square( | |||
| np.abs((box[1] + box[3]) / 2 - img_center[1])) | |||
| if dist < min_dist: | |||
| min_dist = dist | |||
| sel_idx = _idx | |||
| sort_idx = [sel_idx] | |||
| main_idx = sort_idx[-1] | |||
| return bboxes[main_idx], landmarks[main_idx] | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| img = LoadImage.convert_to_ndarray(input) | |||
| img = img[:, :, ::-1] | |||
| det_result = self.face_detection(img.copy()) | |||
| rtn = self._choose_face(det_result, img_shape=img.shape) | |||
| face_img = None | |||
| if rtn is not None: | |||
| _, face_lmks = rtn | |||
| face_lmks = face_lmks.reshape(5, 2) | |||
| align_img, _ = align_face(img, (112, 112), face_lmks) | |||
| face_img = align_img[:, :, ::-1] # to rgb | |||
| face_img = np.transpose(face_img, axes=(2, 0, 1)) | |||
| face_img = (face_img / 255. - 0.5) / 0.5 | |||
| face_img = face_img.astype(np.float32) | |||
| result = {} | |||
| result['img'] = face_img | |||
| return result | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| assert input['img'] is not None | |||
| img = input['img'].unsqueeze(0) | |||
| emb = self.face_model(img).detach().cpu().numpy() | |||
| emb /= np.sqrt(np.sum(emb**2, -1, keepdims=True)) # l2 norm | |||
| return {OutputKeys.IMG_EMBEDDING: emb} | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
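Reviewer note: since `forward` L2-normalizes the embedding, comparing two faces reduces to a dot product; a minimal sketch (the decision threshold below is an assumption, not part of this PR):

import numpy as np

def cos_sim(emb1: np.ndarray, emb2: np.ndarray) -> float:
    # emb1, emb2: [1, D] L2-normalized embeddings from FaceRecognitionPipeline
    return float(np.dot(emb1[0], emb2[0]))

# same_person = cos_sim(emb_a, emb_b) > 0.3  # hypothetical threshold, tune on held-out pairs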
| @@ -30,7 +30,7 @@ class ImageCartoonPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||
| use `model` to create an image cartoon pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -27,7 +27,7 @@ class ImageColorEnhancePipeline(Pipeline): | |||
| ImageColorEnhanceFinetunePreprocessor] = None, | |||
| **kwargs): | |||
| """ | |||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||
| use `model` and `preprocessor` to create an image color enhance pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -25,7 +25,7 @@ class ImageColorizationPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` to create a kws pipeline for prediction | |||
| use `model` to create an image colorization pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -21,7 +21,7 @@ class ImageMattingPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||
| use `model` to create an image matting pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -23,7 +23,7 @@ class ImageSuperResolutionPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` to create a kws pipeline for prediction | |||
| use `model` to create an image super-resolution pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -41,7 +41,7 @@ class OCRDetectionPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||
| use `model` to create an OCR detection pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -21,7 +21,7 @@ class StyleTransferPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||
| use `model` to create a style transfer pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -25,7 +25,7 @@ class VirtualTryonPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` to create a kws pipeline for prediction | |||
| use `model` to create a virtual tryon pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| @@ -28,6 +28,8 @@ class CVTasks(object): | |||
| ocr_detection = 'ocr-detection' | |||
| action_recognition = 'action-recognition' | |||
| video_embedding = 'video-embedding' | |||
| face_detection = 'face-detection' | |||
| face_recognition = 'face-recognition' | |||
| image_color_enhance = 'image-color-enhance' | |||
| virtual_tryon = 'virtual-tryon' | |||
| image_colorization = 'image-colorization' | |||
| @@ -0,0 +1,84 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| import tempfile | |||
| import unittest | |||
| import cv2 | |||
| import numpy as np | |||
| from modelscope.fileio import File | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class FaceDetectionTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | |||
| def show_result(self, img_path, bboxes, kpss, scores): | |||
| bboxes = np.array(bboxes) | |||
| kpss = np.array(kpss) | |||
| scores = np.array(scores) | |||
| img = cv2.imread(img_path) | |||
| assert img is not None, f"Can't read img: {img_path}" | |||
| for i in range(len(scores)): | |||
| bbox = bboxes[i].astype(np.int32) | |||
| kps = kpss[i].reshape(-1, 2).astype(np.int32) | |||
| score = scores[i] | |||
| x1, y1, x2, y2 = bbox | |||
| cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2) | |||
| for kp in kps: | |||
| cv2.circle(img, tuple(kp), 1, (0, 0, 255), 1) | |||
| cv2.putText( | |||
| img, | |||
| f'{score:.2f}', (x1, y2), | |||
| 1, | |||
| 1.0, (0, 255, 0), | |||
| thickness=1, | |||
| lineType=8) | |||
| cv2.imwrite('result.png', img) | |||
| print( | |||
| f'Found {len(scores)} faces, output written to {osp.abspath("result.png")}' | |||
| ) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_dataset(self): | |||
| input_location = ['data/test/images/face_detection.png'] | |||
| # alternatively: | |||
| # input_location = '/dir/to/images' | |||
| dataset = MsDataset.load(input_location, target='image') | |||
| face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||
| # note that when the input is a dataset, the pipeline output is a generator that must be iterated. | |||
| result = face_detection(dataset) | |||
| result = next(result) | |||
| self.show_result(input_location[0], result[OutputKeys.BOXES], | |||
| result[OutputKeys.KEYPOINTS], | |||
| result[OutputKeys.SCORES]) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||
| img_path = 'data/test/images/face_detection.png' | |||
| result = face_detection(img_path) | |||
| self.show_result(img_path, result[OutputKeys.BOXES], | |||
| result[OutputKeys.KEYPOINTS], | |||
| result[OutputKeys.SCORES]) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_modelhub_default_model(self): | |||
| face_detection = pipeline(Tasks.face_detection) | |||
| img_path = 'data/test/images/face_detection.png' | |||
| result = face_detection(img_path) | |||
| self.show_result(img_path, result[OutputKeys.BOXES], | |||
| result[OutputKeys.KEYPOINTS], | |||
| result[OutputKeys.SCORES]) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| import tempfile | |||
| import unittest | |||
| import cv2 | |||
| import numpy as np | |||
| from modelscope.fileio import File | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class FaceRecognitionTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.recog_model_id = 'damo/cv_ir101_facerecognition_cfglint' | |||
| self.det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_face_compare(self): | |||
| img1 = 'data/test/images/face_recognition_1.png' | |||
| img2 = 'data/test/images/face_recognition_2.png' | |||
| face_detection = pipeline( | |||
| Tasks.face_detection, model=self.det_model_id) | |||
| face_recognition = pipeline( | |||
| Tasks.face_recognition, | |||
| face_detection=face_detection, | |||
| model=self.recog_model_id) | |||
| emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING] | |||
| emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING] | |||
| sim = np.dot(emb1[0], emb2[0]) | |||
| print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||