1. support FaceDetectionPipeline inference Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9470723 (master)
@@ -121,6 +121,7 @@ source.sh | |||
tensorboard.sh | |||
.DS_Store | |||
replace.sh | |||
result.png | |||
# Pytorch | |||
*.pth | |||
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:aa3963d1c54e6d3d46e9a59872a99ed955d4050092f5cfe5f591e03d740b7042 | |||
size 653006 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:48e541daeb2692907efef47018e41abb5ae6bcd88eb5ff58290d7fe5dc8b2a13 | |||
size 462584 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:e9565b43d9f65361b9bad6553b327c2c6f02fd063a4c8dc0f461e88ea461989d | |||
size 357166 |
@@ -10,6 +10,7 @@ class Models(object): | |||
Model name should only contain model info but not task info. | |||
""" | |||
# vision models | |||
scrfd = 'scrfd' | |||
classification_model = 'ClassificationModel' | |||
nafnet = 'nafnet' | |||
csrnet = 'csrnet' | |||
@@ -67,6 +68,7 @@ class Pipelines(object): | |||
action_recognition = 'TAdaConv_action-recognition' | |||
animal_recognation = 'resnet101-animal_recog' | |||
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | |||
face_detection = 'resnet-face-detection-scrfd10gkps' | |||
live_category = 'live-category' | |||
general_image_classification = 'vit-base_image-classification_ImageNet-labels' | |||
daily_image_classification = 'vit-base_image-classification_Dailylife-labels' | |||
@@ -76,6 +78,7 @@ class Pipelines(object): | |||
image_super_resolution = 'rrdb-image-super-resolution' | |||
face_image_generation = 'gan-face-image-generation' | |||
style_transfer = 'AAMS-style-transfer' | |||
face_recognition = 'ir101-face-recognition-cfglint' | |||
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' | |||
image2image_translation = 'image-to-image-translation' | |||
live_category = 'live-category' | |||
@@ -1,5 +1,6 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from . import (action_recognition, animal_recognition, cartoon, | |||
cmdssl_video_embedding, face_generation, image_classification, | |||
image_color_enhance, image_colorization, image_denoise, | |||
image_instance_segmentation, super_resolution, virual_tryon) | |||
cmdssl_video_embedding, face_detection, face_generation, | |||
image_classification, image_color_enhance, image_colorization, | |||
image_denoise, image_instance_segmentation, | |||
image_to_image_translation, super_resolution, virual_tryon) |
@@ -0,0 +1,5 @@ | |||
""" | |||
mmdet_patch is based on | |||
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet; | |||
all functions duplicated from the official mmdetection package are removed. | |||
""" |
@@ -0,0 +1,3 @@ | |||
from .transforms import bbox2result, distance2kps, kps2distance | |||
__all__ = ['bbox2result', 'distance2kps', 'kps2distance'] |
@@ -0,0 +1,86 @@ | |||
""" | |||
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/bbox/transforms.py | |||
""" | |||
import numpy as np | |||
import torch | |||
def bbox2result(bboxes, labels, num_classes, kps=None): | |||
"""Convert detection results to a list of numpy arrays. | |||
Args: | |||
bboxes (torch.Tensor | np.ndarray): shape (n, 5) | |||
labels (torch.Tensor | np.ndarray): shape (n, ) | |||
num_classes (int): class number, including background class | |||
kps (torch.Tensor | np.ndarray, optional): shape (n, 10), five (x, y) landmarks per box | |||
Returns: | |||
list(ndarray): per-class bbox results, with the 10 keypoint values appended when kps is given | |||
""" | |||
bbox_len = 5 if kps is None else 5 + 10 # with kps, 10 keypoint values are appended to each bbox row | |||
if bboxes.shape[0] == 0: | |||
return [ | |||
np.zeros((0, bbox_len), dtype=np.float32) | |||
for i in range(num_classes) | |||
] | |||
else: | |||
if isinstance(bboxes, torch.Tensor): | |||
bboxes = bboxes.detach().cpu().numpy() | |||
labels = labels.detach().cpu().numpy() | |||
if kps is None: | |||
return [bboxes[labels == i, :] for i in range(num_classes)] | |||
else: # with kps | |||
if isinstance(kps, torch.Tensor): | |||
kps = kps.detach().cpu().numpy() | |||
return [ | |||
np.hstack([bboxes[labels == i, :], kps[labels == i, :]]) | |||
for i in range(num_classes) | |||
] | |||
def distance2kps(points, distance, max_shape=None): | |||
"""Decode distance prediction to bounding box. | |||
Args: | |||
points (Tensor): Shape (n, 2), [x, y]. | |||
distance (Tensor): Distance from the given point to 4 | |||
boundaries (left, top, right, bottom). | |||
max_shape (tuple): Shape of the image. | |||
Returns: | |||
Tensor: Decoded kps. | |||
""" | |||
preds = [] | |||
for i in range(0, distance.shape[1], 2): | |||
px = points[:, i % 2] + distance[:, i] | |||
py = points[:, i % 2 + 1] + distance[:, i + 1] | |||
if max_shape is not None: | |||
px = px.clamp(min=0, max=max_shape[1]) | |||
py = py.clamp(min=0, max=max_shape[0]) | |||
preds.append(px) | |||
preds.append(py) | |||
return torch.stack(preds, -1) | |||
def kps2distance(points, kps, max_dis=None, eps=0.1): | |||
"""Decode bounding box based on distances. | |||
Args: | |||
points (Tensor): Shape (n, 2), [x, y]. | |||
kps (Tensor): Shape (n, 2K), keypoint coordinates [x1, y1, x2, y2, ...]. | |||
max_dis (float): Upper bound of the distance. | |||
eps (float): a small value to ensure target < max_dis instead of <=. | |||
Returns: | |||
Tensor: Decoded distances. | |||
""" | |||
preds = [] | |||
for i in range(0, kps.shape[1], 2): | |||
px = kps[:, i] - points[:, i % 2] | |||
py = kps[:, i + 1] - points[:, i % 2 + 1] | |||
if max_dis is not None: | |||
px = px.clamp(min=0, max=max_dis - eps) | |||
py = py.clamp(min=0, max=max_dis - eps) | |||
preds.append(px) | |||
preds.append(py) | |||
return torch.stack(preds, -1) |
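[Editor note] A minimal sketch of how these helpers behave, with synthetic tensors and the two functions above assumed to be in scope: distance2kps adds per-keypoint (dx, dy) offsets to each anchor-center point, and kps2distance inverts it when max_dis is None.

import torch
points = torch.tensor([[10., 20.]])                                   # one anchor center (x, y)
offsets = torch.tensor([[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]])   # 5 x (dx, dy) offsets
kps = distance2kps(points, offsets)                                   # tensor([[11., 22., 13., 24., 15., 26., 17., 28., 19., 30.]])
assert torch.allclose(kps2distance(points, kps), offsets)             # round-trips when no clamping is applied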
@@ -0,0 +1,3 @@ | |||
from .bbox_nms import multiclass_nms | |||
__all__ = ['multiclass_nms'] |
@@ -0,0 +1,85 @@ | |||
""" | |||
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/post_processing/bbox_nms.py | |||
""" | |||
import torch | |||
def multiclass_nms(multi_bboxes, | |||
multi_scores, | |||
score_thr, | |||
nms_cfg, | |||
max_num=-1, | |||
score_factors=None, | |||
return_inds=False, | |||
multi_kps=None): | |||
"""NMS for multi-class bboxes. | |||
Args: | |||
multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) | |||
multi_scores (Tensor): shape (n, #class), where the last column | |||
contains scores of the background class, but this will be ignored. | |||
score_thr (float): bbox threshold, bboxes with scores lower than it | |||
will not be considered. | |||
nms_cfg (dict): Config dict for NMS (type and IoU threshold). | |||
max_num (int, optional): if there are more than max_num bboxes after | |||
NMS, only top max_num will be kept. Default to -1. | |||
score_factors (Tensor, optional): The factors multiplied to scores | |||
before applying NMS. Default to None. | |||
return_inds (bool, optional): Whether to return the indices of kept | |||
bboxes. Default to False. | |||
multi_kps (Tensor, optional): shape (n, #class*10) or (n, 10), five | |||
(x, y) landmarks per box. Default to None. | |||
Returns: | |||
tuple: (dets, labels, kps, indices (optional)), tensors of shape (k, 5), | |||
(k,), (k, 10) and (k,). Labels are 0-based. | |||
""" | |||
num_classes = multi_scores.size(1) - 1 | |||
# exclude background category | |||
kps = None | |||
if multi_bboxes.shape[1] > 4: | |||
bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) | |||
if multi_kps is not None: | |||
kps = multi_kps.view(multi_scores.size(0), -1, 10) | |||
else: | |||
bboxes = multi_bboxes[:, None].expand( | |||
multi_scores.size(0), num_classes, 4) | |||
if multi_kps is not None: | |||
kps = multi_kps[:, None].expand( | |||
multi_scores.size(0), num_classes, 10) | |||
scores = multi_scores[:, :-1] | |||
if score_factors is not None: | |||
scores = scores * score_factors[:, None] | |||
labels = torch.arange(num_classes, dtype=torch.long) | |||
labels = labels.view(1, -1).expand_as(scores) | |||
bboxes = bboxes.reshape(-1, 4) | |||
if kps is not None: | |||
kps = kps.reshape(-1, 10) | |||
scores = scores.reshape(-1) | |||
labels = labels.reshape(-1) | |||
# remove low scoring boxes | |||
valid_mask = scores > score_thr | |||
inds = valid_mask.nonzero(as_tuple=False).squeeze(1) | |||
bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] | |||
if kps is not None: | |||
kps = kps[inds] | |||
if inds.numel() == 0: | |||
if torch.onnx.is_in_onnx_export(): | |||
raise RuntimeError('[ONNX Error] Can not record NMS ' | |||
'as it has not been executed this time') | |||
return bboxes, labels, kps | |||
# TODO: add size check before feed into batched_nms | |||
from mmcv.ops.nms import batched_nms | |||
dets, keep = batched_nms(bboxes, scores, labels, nms_cfg) | |||
if max_num > 0: | |||
dets = dets[:max_num] | |||
keep = keep[:max_num] | |||
if return_inds: | |||
return dets, labels[keep], kps[keep] if kps is not None else None, keep | |||
else: | |||
return dets, labels[keep], kps[keep] if kps is not None else None |
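[Editor note] A hedged usage sketch (shapes and nms_cfg values are illustrative, and mmcv must be installed for batched_nms): one 'face' class plus a background column, with 5-point landmarks carried through NMS.

import torch
multi_bboxes = torch.tensor([[10., 10., 50., 50.],
                             [12., 12., 52., 52.]])       # (n, 4), shared across classes
multi_scores = torch.tensor([[0.9, 0.1],
                             [0.8, 0.2]])                 # (n, 2): face score + background column
multi_kps = torch.rand(2, 10)                             # 5 (x, y) landmarks per box
dets, labels, kps = multiclass_nms(
    multi_bboxes, multi_scores, score_thr=0.02,
    nms_cfg=dict(type='nms', iou_threshold=0.45),
    max_num=200, multi_kps=multi_kps)
# dets: (k, 5) as [x1, y1, x2, y2, score]; labels: (k,); kps: (k, 10)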
@@ -0,0 +1,3 @@ | |||
from .retinaface import RetinaFaceDataset | |||
__all__ = ['RetinaFaceDataset'] |
@@ -0,0 +1,3 @@ | |||
from .transforms import RandomSquareCrop | |||
__all__ = ['RandomSquareCrop'] |
@@ -0,0 +1,188 @@ | |||
""" | |||
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py | |||
""" | |||
import numpy as np | |||
from mmdet.datasets.builder import PIPELINES | |||
from numpy import random | |||
@PIPELINES.register_module() | |||
class RandomSquareCrop(object): | |||
"""Random crop the image & bboxes, the cropped patches have minimum IoU | |||
requirement with original image & bboxes, the IoU threshold is randomly | |||
selected from min_ious. | |||
Args: | |||
min_ious (tuple): minimum IoU threshold for all intersections with | |||
bounding boxes | |||
min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, | |||
where a >= min_crop_size). | |||
Note: | |||
The keys for bboxes, labels and masks should be paired. That is, \ | |||
`gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ | |||
`gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. | |||
""" | |||
def __init__(self, | |||
crop_ratio_range=None, | |||
crop_choice=None, | |||
bbox_clip_border=True): | |||
self.crop_ratio_range = crop_ratio_range | |||
self.crop_choice = crop_choice | |||
self.bbox_clip_border = bbox_clip_border | |||
assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) | |||
if self.crop_ratio_range is not None: | |||
self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range | |||
self.bbox2label = { | |||
'gt_bboxes': 'gt_labels', | |||
'gt_bboxes_ignore': 'gt_labels_ignore' | |||
} | |||
self.bbox2mask = { | |||
'gt_bboxes': 'gt_masks', | |||
'gt_bboxes_ignore': 'gt_masks_ignore' | |||
} | |||
def __call__(self, results): | |||
"""Call function to crop images and bounding boxes with minimum IoU | |||
constraint. | |||
Args: | |||
results (dict): Result dict from loading pipeline. | |||
Returns: | |||
dict: Result dict with images and bounding boxes cropped, \ | |||
'img_shape' key is updated. | |||
""" | |||
if 'img_fields' in results: | |||
assert results['img_fields'] == ['img'], \ | |||
'Only single img_fields is allowed' | |||
img = results['img'] | |||
assert 'bbox_fields' in results | |||
assert 'gt_bboxes' in results | |||
boxes = results['gt_bboxes'] | |||
h, w, c = img.shape | |||
scale_retry = 0 | |||
if self.crop_ratio_range is not None: | |||
max_scale = self.crop_ratio_max | |||
else: | |||
max_scale = np.amax(self.crop_choice) | |||
while True: | |||
scale_retry += 1 | |||
if scale_retry == 1 or max_scale > 1.0: | |||
if self.crop_ratio_range is not None: | |||
scale = np.random.uniform(self.crop_ratio_min, | |||
self.crop_ratio_max) | |||
elif self.crop_choice is not None: | |||
scale = np.random.choice(self.crop_choice) | |||
else: | |||
scale = scale * 1.2 | |||
for i in range(250): | |||
short_side = min(w, h) | |||
cw = int(scale * short_side) | |||
ch = cw | |||
# TODO +1 | |||
if w == cw: | |||
left = 0 | |||
elif w > cw: | |||
left = random.randint(0, w - cw) | |||
else: | |||
left = random.randint(w - cw, 0) | |||
if h == ch: | |||
top = 0 | |||
elif h > ch: | |||
top = random.randint(0, h - ch) | |||
else: | |||
top = random.randint(h - ch, 0) | |||
patch = np.array( | |||
(int(left), int(top), int(left + cw), int(top + ch)), | |||
dtype=np.int64) | |||
# center of boxes should inside the crop img | |||
# only adjust boxes and instance masks when the gt is not empty | |||
# adjust boxes | |||
def is_center_of_bboxes_in_patch(boxes, patch): | |||
# TODO >= | |||
center = (boxes[:, :2] + boxes[:, 2:]) / 2 | |||
mask = \ | |||
((center[:, 0] > patch[0]) | |||
* (center[:, 1] > patch[1]) | |||
* (center[:, 0] < patch[2]) | |||
* (center[:, 1] < patch[3])) | |||
return mask | |||
mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
if not mask.any(): | |||
continue | |||
for key in results.get('bbox_fields', []): | |||
boxes = results[key].copy() | |||
mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
boxes = boxes[mask] | |||
if self.bbox_clip_border: | |||
boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) | |||
boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) | |||
boxes -= np.tile(patch[:2], 2) | |||
results[key] = boxes | |||
# labels | |||
label_key = self.bbox2label.get(key) | |||
if label_key in results: | |||
results[label_key] = results[label_key][mask] | |||
# keypoints field | |||
if key == 'gt_bboxes': | |||
for kps_key in results.get('keypoints_fields', []): | |||
keypointss = results[kps_key].copy() | |||
keypointss = keypointss[mask, :, :] | |||
if self.bbox_clip_border: | |||
keypointss[:, :, : | |||
2] = keypointss[:, :, :2].clip( | |||
max=patch[2:]) | |||
keypointss[:, :, : | |||
2] = keypointss[:, :, :2].clip( | |||
min=patch[:2]) | |||
keypointss[:, :, 0] -= patch[0] | |||
keypointss[:, :, 1] -= patch[1] | |||
results[kps_key] = keypointss | |||
# mask fields | |||
mask_key = self.bbox2mask.get(key) | |||
if mask_key in results: | |||
results[mask_key] = results[mask_key][mask.nonzero() | |||
[0]].crop(patch) | |||
# adjust the img no matter whether the gt is empty before crop | |||
rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 | |||
patch_from = patch.copy() | |||
patch_from[0] = max(0, patch_from[0]) | |||
patch_from[1] = max(0, patch_from[1]) | |||
patch_from[2] = min(img.shape[1], patch_from[2]) | |||
patch_from[3] = min(img.shape[0], patch_from[3]) | |||
patch_to = patch.copy() | |||
patch_to[0] = max(0, patch_to[0] * -1) | |||
patch_to[1] = max(0, patch_to[1] * -1) | |||
patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) | |||
patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) | |||
rimg[patch_to[1]:patch_to[3], | |||
patch_to[0]:patch_to[2], :] = img[ | |||
patch_from[1]:patch_from[3], | |||
patch_from[0]:patch_from[2], :] | |||
img = rimg | |||
results['img'] = img | |||
results['img_shape'] = img.shape | |||
return results | |||
def __repr__(self): | |||
repr_str = self.__class__.__name__ | |||
repr_str += f'(crop_ratio_range={self.crop_ratio_range}, ' | |||
repr_str += f'crop_choice={self.crop_choice}, ' | |||
repr_str += f'bbox_clip_border={self.bbox_clip_border})' | |||
return repr_str |
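[Editor note] Because the transform registers itself in mmdet's PIPELINES registry, it is referenced from a dataset pipeline config by type name. An illustrative entry (the crop_choice values are examples, not necessarily those used by the released SCRFD config); exactly one of crop_ratio_range / crop_choice must be set:

dict(type='RandomSquareCrop', crop_choice=[0.3, 0.45, 0.6, 0.8, 1.0])  # illustrative ratios of the short side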
@@ -0,0 +1,151 @@ | |||
""" | |||
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/retinaface.py | |||
""" | |||
import numpy as np | |||
from mmdet.datasets.builder import DATASETS | |||
from mmdet.datasets.custom import CustomDataset | |||
@DATASETS.register_module() | |||
class RetinaFaceDataset(CustomDataset): | |||
CLASSES = ('FG', ) | |||
def __init__(self, min_size=None, **kwargs): | |||
self.NK = 5 | |||
self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} | |||
self.min_size = min_size | |||
self.gt_path = kwargs.get('gt_path') | |||
super(RetinaFaceDataset, self).__init__(**kwargs) | |||
def _parse_ann_line(self, line): | |||
values = [float(x) for x in line.strip().split()] | |||
bbox = np.array(values[0:4], dtype=np.float32) | |||
kps = np.zeros((self.NK, 3), dtype=np.float32) | |||
ignore = False | |||
if self.min_size is not None: | |||
assert not self.test_mode | |||
w = bbox[2] - bbox[0] | |||
h = bbox[3] - bbox[1] | |||
if w < self.min_size or h < self.min_size: | |||
ignore = True | |||
if len(values) > 4: | |||
if len(values) > 5: | |||
kps = np.array( | |||
values[4:19], dtype=np.float32).reshape((self.NK, 3)) | |||
for li in range(kps.shape[0]): | |||
if (kps[li, :] == -1).all(): | |||
kps[li][2] = 0.0 # weight = 0, ignore | |||
else: | |||
assert kps[li][2] >= 0 | |||
kps[li][2] = 1.0 # weight | |||
else: # len(values)==5 | |||
if not ignore: | |||
ignore = (values[4] == 1) | |||
else: | |||
assert self.test_mode | |||
return dict(bbox=bbox, kps=kps, ignore=ignore, cat='FG') | |||
def load_annotations(self, ann_file): | |||
"""Load annotation from COCO style annotation file. | |||
Args: | |||
ann_file (str): Path of annotation file. | |||
20220711@tyx: ann_file is list of img paths is supported | |||
Returns: | |||
list[dict]: Annotation info from COCO api. | |||
""" | |||
if isinstance(ann_file, list): | |||
data_infos = [] | |||
for line in ann_file: | |||
name = line | |||
objs = [0, 0, 0, 0] | |||
data_infos.append( | |||
dict(filename=name, width=0, height=0, objs=objs)) | |||
else: | |||
name = None | |||
bbox_map = {} | |||
for line in open(ann_file, 'r'): | |||
line = line.strip() | |||
if line.startswith('#'): | |||
value = line[1:].strip().split() | |||
name = value[0] | |||
width = int(value[1]) | |||
height = int(value[2]) | |||
bbox_map[name] = dict(width=width, height=height, objs=[]) | |||
continue | |||
assert name is not None | |||
assert name in bbox_map | |||
bbox_map[name]['objs'].append(line) | |||
print('origin image size', len(bbox_map)) | |||
data_infos = [] | |||
for name in bbox_map: | |||
item = bbox_map[name] | |||
width = item['width'] | |||
height = item['height'] | |||
vals = item['objs'] | |||
objs = [] | |||
for line in vals: | |||
data = self._parse_ann_line(line) | |||
if data is None: | |||
continue | |||
objs.append(data) # data is (bbox, kps, cat) | |||
if len(objs) == 0 and not self.test_mode: | |||
continue | |||
data_infos.append( | |||
dict(filename=name, width=width, height=height, objs=objs)) | |||
return data_infos | |||
def get_ann_info(self, idx): | |||
"""Get COCO annotation by index. | |||
Args: | |||
idx (int): Index of data. | |||
Returns: | |||
dict: Annotation info of specified index. | |||
""" | |||
data_info = self.data_infos[idx] | |||
bboxes = [] | |||
keypointss = [] | |||
labels = [] | |||
bboxes_ignore = [] | |||
labels_ignore = [] | |||
for obj in data_info['objs']: | |||
label = self.cat2label[obj['cat']] | |||
bbox = obj['bbox'] | |||
keypoints = obj['kps'] | |||
ignore = obj['ignore'] | |||
if ignore: | |||
bboxes_ignore.append(bbox) | |||
labels_ignore.append(label) | |||
else: | |||
bboxes.append(bbox) | |||
labels.append(label) | |||
keypointss.append(keypoints) | |||
if not bboxes: | |||
bboxes = np.zeros((0, 4)) | |||
labels = np.zeros((0, )) | |||
keypointss = np.zeros((0, self.NK, 3)) | |||
else: | |||
# bboxes = np.array(bboxes, ndmin=2) - 1 | |||
bboxes = np.array(bboxes, ndmin=2) | |||
labels = np.array(labels) | |||
keypointss = np.array(keypointss, ndmin=3) | |||
if not bboxes_ignore: | |||
bboxes_ignore = np.zeros((0, 4)) | |||
labels_ignore = np.zeros((0, )) | |||
else: | |||
bboxes_ignore = np.array(bboxes_ignore, ndmin=2) | |||
labels_ignore = np.array(labels_ignore) | |||
ann = dict( | |||
bboxes=bboxes.astype(np.float32), | |||
labels=labels.astype(np.int64), | |||
keypointss=keypointss.astype(np.float32), | |||
bboxes_ignore=bboxes_ignore.astype(np.float32), | |||
labels_ignore=labels_ignore.astype(np.int64)) | |||
return ann |
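[Editor note] For reference, the text format that load_annotations and _parse_ann_line expect looks roughly like the excerpt below (file name and numbers are made up): a '#' line gives the image path, width and height; each following line gives one face as x1 y1 x2 y2, optionally followed by five landmarks as (x, y, visibility) triplets (a landmark of all -1 gets weight 0), and a line with only a fifth value of 1 marks the face as ignored.

# 0--Parade/0_Parade_marchingband_1_849.jpg 1024 768
449.0 330.0 570.0 479.0 488.9 373.6 1.0 542.1 376.5 1.0 515.2 416.7 1.0 492.7 439.8 1.0 538.1 445.8 1.0
620.0 120.0 700.0 210.0
300.0 40.0 330.0 80.0 1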
@@ -0,0 +1,2 @@ | |||
from .dense_heads import * # noqa: F401,F403 | |||
from .detectors import * # noqa: F401,F403 |
@@ -0,0 +1,3 @@ | |||
from .resnet import ResNetV1e | |||
__all__ = ['ResNetV1e'] |
@@ -0,0 +1,412 @@ | |||
""" | |||
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones/resnet.py | |||
""" | |||
import torch.nn as nn | |||
import torch.utils.checkpoint as cp | |||
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer, | |||
constant_init, kaiming_init) | |||
from mmcv.runner import load_checkpoint | |||
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck | |||
from mmdet.models.builder import BACKBONES | |||
from mmdet.models.utils import ResLayer | |||
from mmdet.utils import get_root_logger | |||
from torch.nn.modules.batchnorm import _BatchNorm | |||
class ResNet(nn.Module): | |||
"""ResNet backbone. | |||
Args: | |||
depth (int): Depth of resnet; any key of ``arch_settings`` below (the | |||
standard 18/34/50/101/152 plus several SCRFD-specific variants). | |||
stem_channels (int | None): Number of stem channels. If not specified, | |||
it will be the same as `base_channels`. Default: None. | |||
base_channels (int): Number of base channels of res layer. Default: 64. | |||
in_channels (int): Number of input image channels. Default: 3. | |||
num_stages (int): Resnet stages. Default: 4. | |||
strides (Sequence[int]): Strides of the first block of each stage. | |||
dilations (Sequence[int]): Dilation of each stage. | |||
out_indices (Sequence[int]): Output from which stages. | |||
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two | |||
layer is the 3x3 conv layer, otherwise the stride-two layer is | |||
the first 1x1 conv layer. | |||
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv | |||
avg_down (bool): Use AvgPool instead of stride conv when | |||
downsampling in the bottleneck. | |||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode). | |||
-1 means not freezing any parameters. | |||
norm_cfg (dict): Dictionary to construct and config norm layer. | |||
norm_eval (bool): Whether to set norm layers to eval mode, namely, | |||
freeze running stats (mean and var). Note: Effect on Batch Norm | |||
and its variants only. | |||
plugins (list[dict]): List of plugins for stages, each dict contains: | |||
- cfg (dict, required): Cfg dict to build plugin. | |||
- position (str, required): Position inside block to insert | |||
plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. | |||
- stages (tuple[bool], optional): Stages to apply plugin, length | |||
should be same as 'num_stages'. | |||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some | |||
memory while slowing down the training speed. | |||
zero_init_residual (bool): Whether to use zero init for last norm layer | |||
in resblocks to let them behave as identity. | |||
Example: | |||
>>> from mmdet.models import ResNet | |||
>>> import torch | |||
>>> self = ResNet(depth=18) | |||
>>> self.eval() | |||
>>> inputs = torch.rand(1, 3, 32, 32) | |||
>>> level_outputs = self.forward(inputs) | |||
>>> for level_out in level_outputs: | |||
... print(tuple(level_out.shape)) | |||
(1, 64, 8, 8) | |||
(1, 128, 4, 4) | |||
(1, 256, 2, 2) | |||
(1, 512, 1, 1) | |||
""" | |||
arch_settings = { | |||
0: (BasicBlock, (2, 2, 2, 2)), | |||
18: (BasicBlock, (2, 2, 2, 2)), | |||
19: (BasicBlock, (2, 4, 4, 1)), | |||
20: (BasicBlock, (2, 3, 2, 2)), | |||
22: (BasicBlock, (2, 4, 3, 1)), | |||
24: (BasicBlock, (2, 4, 4, 1)), | |||
26: (BasicBlock, (2, 4, 4, 2)), | |||
28: (BasicBlock, (2, 5, 4, 2)), | |||
29: (BasicBlock, (2, 6, 3, 2)), | |||
30: (BasicBlock, (2, 5, 5, 2)), | |||
32: (BasicBlock, (2, 6, 5, 2)), | |||
34: (BasicBlock, (3, 4, 6, 3)), | |||
35: (BasicBlock, (3, 6, 4, 3)), | |||
38: (BasicBlock, (3, 8, 4, 3)), | |||
40: (BasicBlock, (3, 8, 5, 3)), | |||
50: (Bottleneck, (3, 4, 6, 3)), | |||
56: (Bottleneck, (3, 8, 4, 3)), | |||
68: (Bottleneck, (3, 10, 6, 3)), | |||
74: (Bottleneck, (3, 12, 6, 3)), | |||
101: (Bottleneck, (3, 4, 23, 3)), | |||
152: (Bottleneck, (3, 8, 36, 3)) | |||
} | |||
def __init__(self, | |||
depth, | |||
in_channels=3, | |||
stem_channels=None, | |||
base_channels=64, | |||
num_stages=4, | |||
block_cfg=None, | |||
strides=(1, 2, 2, 2), | |||
dilations=(1, 1, 1, 1), | |||
out_indices=(0, 1, 2, 3), | |||
style='pytorch', | |||
deep_stem=False, | |||
avg_down=False, | |||
no_pool33=False, | |||
frozen_stages=-1, | |||
conv_cfg=None, | |||
norm_cfg=dict(type='BN', requires_grad=True), | |||
norm_eval=True, | |||
dcn=None, | |||
stage_with_dcn=(False, False, False, False), | |||
plugins=None, | |||
with_cp=False, | |||
zero_init_residual=True): | |||
super(ResNet, self).__init__() | |||
if depth not in self.arch_settings: | |||
raise KeyError(f'invalid depth {depth} for resnet') | |||
self.depth = depth | |||
if stem_channels is None: | |||
stem_channels = base_channels | |||
self.stem_channels = stem_channels | |||
self.base_channels = base_channels | |||
self.num_stages = num_stages | |||
assert num_stages >= 1 and num_stages <= 4 | |||
self.strides = strides | |||
self.dilations = dilations | |||
assert len(strides) == len(dilations) == num_stages | |||
self.out_indices = out_indices | |||
assert max(out_indices) < num_stages | |||
self.style = style | |||
self.deep_stem = deep_stem | |||
self.avg_down = avg_down | |||
self.no_pool33 = no_pool33 | |||
self.frozen_stages = frozen_stages | |||
self.conv_cfg = conv_cfg | |||
self.norm_cfg = norm_cfg | |||
self.with_cp = with_cp | |||
self.norm_eval = norm_eval | |||
self.dcn = dcn | |||
self.stage_with_dcn = stage_with_dcn | |||
if dcn is not None: | |||
assert len(stage_with_dcn) == num_stages | |||
self.plugins = plugins | |||
self.zero_init_residual = zero_init_residual | |||
if block_cfg is None: | |||
self.block, stage_blocks = self.arch_settings[depth] | |||
else: | |||
self.block = BasicBlock if block_cfg[ | |||
'block'] == 'BasicBlock' else Bottleneck | |||
stage_blocks = block_cfg['stage_blocks'] | |||
assert len(stage_blocks) >= num_stages | |||
self.stage_blocks = stage_blocks[:num_stages] | |||
self.inplanes = stem_channels | |||
self._make_stem_layer(in_channels, stem_channels) | |||
if block_cfg is not None and 'stage_planes' in block_cfg: | |||
stage_planes = block_cfg['stage_planes'] | |||
else: | |||
stage_planes = [base_channels * 2**i for i in range(num_stages)] | |||
# print('resnet cfg:', stage_blocks, stage_planes) | |||
self.res_layers = [] | |||
for i, num_blocks in enumerate(self.stage_blocks): | |||
stride = strides[i] | |||
dilation = dilations[i] | |||
dcn = self.dcn if self.stage_with_dcn[i] else None | |||
if plugins is not None: | |||
stage_plugins = self.make_stage_plugins(plugins, i) | |||
else: | |||
stage_plugins = None | |||
planes = stage_planes[i] | |||
res_layer = self.make_res_layer( | |||
block=self.block, | |||
inplanes=self.inplanes, | |||
planes=planes, | |||
num_blocks=num_blocks, | |||
stride=stride, | |||
dilation=dilation, | |||
style=self.style, | |||
avg_down=self.avg_down, | |||
with_cp=with_cp, | |||
conv_cfg=conv_cfg, | |||
norm_cfg=norm_cfg, | |||
dcn=dcn, | |||
plugins=stage_plugins) | |||
self.inplanes = planes * self.block.expansion | |||
layer_name = f'layer{i + 1}' | |||
self.add_module(layer_name, res_layer) | |||
self.res_layers.append(layer_name) | |||
self._freeze_stages() | |||
self.feat_dim = self.block.expansion * base_channels * 2**( | |||
len(self.stage_blocks) - 1) | |||
def make_stage_plugins(self, plugins, stage_idx): | |||
"""Make plugins for ResNet ``stage_idx`` th stage. | |||
Currently we support to insert ``context_block``, | |||
``empirical_attention_block``, ``nonlocal_block`` into the backbone | |||
like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of | |||
Bottleneck. | |||
An example of plugins format could be: | |||
Examples: | |||
>>> plugins=[ | |||
... dict(cfg=dict(type='xxx', arg1='xxx'), | |||
... stages=(False, True, True, True), | |||
... position='after_conv2'), | |||
... dict(cfg=dict(type='yyy'), | |||
... stages=(True, True, True, True), | |||
... position='after_conv3'), | |||
... dict(cfg=dict(type='zzz', postfix='1'), | |||
... stages=(True, True, True, True), | |||
... position='after_conv3'), | |||
... dict(cfg=dict(type='zzz', postfix='2'), | |||
... stages=(True, True, True, True), | |||
... position='after_conv3') | |||
... ] | |||
>>> self = ResNet(depth=18) | |||
>>> stage_plugins = self.make_stage_plugins(plugins, 0) | |||
>>> assert len(stage_plugins) == 3 | |||
Suppose ``stage_idx=0``, the structure of blocks in the stage would be: | |||
.. code-block:: none | |||
conv1-> conv2->conv3->yyy->zzz1->zzz2 | |||
Suppose 'stage_idx=1', the structure of blocks in the stage would be: | |||
.. code-block:: none | |||
conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 | |||
If stages is missing, the plugin would be applied to all stages. | |||
Args: | |||
plugins (list[dict]): List of plugins cfg to build. The postfix is | |||
required if multiple same type plugins are inserted. | |||
stage_idx (int): Index of stage to build | |||
Returns: | |||
list[dict]: Plugins for current stage | |||
""" | |||
stage_plugins = [] | |||
for plugin in plugins: | |||
plugin = plugin.copy() | |||
stages = plugin.pop('stages', None) | |||
assert stages is None or len(stages) == self.num_stages | |||
# whether to insert plugin into current stage | |||
if stages is None or stages[stage_idx]: | |||
stage_plugins.append(plugin) | |||
return stage_plugins | |||
def make_res_layer(self, **kwargs): | |||
"""Pack all blocks in a stage into a ``ResLayer``.""" | |||
return ResLayer(**kwargs) | |||
@property | |||
def norm1(self): | |||
"""nn.Module: the normalization layer named "norm1" """ | |||
return getattr(self, self.norm1_name) | |||
def _make_stem_layer(self, in_channels, stem_channels): | |||
if self.deep_stem: | |||
self.stem = nn.Sequential( | |||
build_conv_layer( | |||
self.conv_cfg, | |||
in_channels, | |||
stem_channels // 2, | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
bias=False), | |||
build_norm_layer(self.norm_cfg, stem_channels // 2)[1], | |||
nn.ReLU(inplace=True), | |||
build_conv_layer( | |||
self.conv_cfg, | |||
stem_channels // 2, | |||
stem_channels // 2, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=False), | |||
build_norm_layer(self.norm_cfg, stem_channels // 2)[1], | |||
nn.ReLU(inplace=True), | |||
build_conv_layer( | |||
self.conv_cfg, | |||
stem_channels // 2, | |||
stem_channels, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=False), | |||
build_norm_layer(self.norm_cfg, stem_channels)[1], | |||
nn.ReLU(inplace=True)) | |||
else: | |||
self.conv1 = build_conv_layer( | |||
self.conv_cfg, | |||
in_channels, | |||
stem_channels, | |||
kernel_size=7, | |||
stride=2, | |||
padding=3, | |||
bias=False) | |||
self.norm1_name, norm1 = build_norm_layer( | |||
self.norm_cfg, stem_channels, postfix=1) | |||
self.add_module(self.norm1_name, norm1) | |||
self.relu = nn.ReLU(inplace=True) | |||
if self.no_pool33: | |||
assert self.deep_stem | |||
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) | |||
else: | |||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
def _freeze_stages(self): | |||
if self.frozen_stages >= 0: | |||
if self.deep_stem: | |||
self.stem.eval() | |||
for param in self.stem.parameters(): | |||
param.requires_grad = False | |||
else: | |||
self.norm1.eval() | |||
for m in [self.conv1, self.norm1]: | |||
for param in m.parameters(): | |||
param.requires_grad = False | |||
for i in range(1, self.frozen_stages + 1): | |||
m = getattr(self, f'layer{i}') | |||
m.eval() | |||
for param in m.parameters(): | |||
param.requires_grad = False | |||
def init_weights(self, pretrained=None): | |||
"""Initialize the weights in backbone. | |||
Args: | |||
pretrained (str, optional): Path to pre-trained weights. | |||
Defaults to None. | |||
""" | |||
if isinstance(pretrained, str): | |||
logger = get_root_logger() | |||
load_checkpoint(self, pretrained, strict=False, logger=logger) | |||
elif pretrained is None: | |||
for m in self.modules(): | |||
if isinstance(m, nn.Conv2d): | |||
kaiming_init(m) | |||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)): | |||
constant_init(m, 1) | |||
if self.dcn is not None: | |||
for m in self.modules(): | |||
if isinstance(m, Bottleneck) and hasattr( | |||
m.conv2, 'conv_offset'): | |||
constant_init(m.conv2.conv_offset, 0) | |||
if self.zero_init_residual: | |||
for m in self.modules(): | |||
if isinstance(m, Bottleneck): | |||
constant_init(m.norm3, 0) | |||
elif isinstance(m, BasicBlock): | |||
constant_init(m.norm2, 0) | |||
else: | |||
raise TypeError('pretrained must be a str or None') | |||
def forward(self, x): | |||
"""Forward function.""" | |||
if self.deep_stem: | |||
x = self.stem(x) | |||
else: | |||
x = self.conv1(x) | |||
x = self.norm1(x) | |||
x = self.relu(x) | |||
x = self.maxpool(x) | |||
outs = [] | |||
for i, layer_name in enumerate(self.res_layers): | |||
res_layer = getattr(self, layer_name) | |||
x = res_layer(x) | |||
if i in self.out_indices: | |||
outs.append(x) | |||
return tuple(outs) | |||
def train(self, mode=True): | |||
"""Convert the model into training mode while keep normalization layer | |||
freezed.""" | |||
super(ResNet, self).train(mode) | |||
self._freeze_stages() | |||
if mode and self.norm_eval: | |||
for m in self.modules(): | |||
# trick: eval have effect on BatchNorm only | |||
if isinstance(m, _BatchNorm): | |||
m.eval() | |||
@BACKBONES.register_module() | |||
class ResNetV1e(ResNet): | |||
r"""ResNetV1d variant described in `Bag of Tricks | |||
<https://arxiv.org/pdf/1812.01187.pdf>`_. | |||
Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in | |||
the input stem with three 3x3 convs. And in the downsampling block, a 2x2 | |||
avg_pool with stride 2 is added before conv, whose stride is changed to 1. | |||
Compared with ResNetV1d, ResNetV1e changes the stem max pooling from a 3x3, pad=1 kernel to a 2x2, pad=0 kernel. | |||
""" | |||
def __init__(self, **kwargs): | |||
super(ResNetV1e, self).__init__( | |||
deep_stem=True, avg_down=True, no_pool33=True, **kwargs) |
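[Editor note] A minimal sketch of building the backbone directly (assumes mmcv/mmdet are installed and ResNetV1e is imported from the backbones package above; depth=18 is only for illustration, the SCRFD models configure the stages through block_cfg instead):

import torch
backbone = ResNetV1e(depth=18, out_indices=(0, 1, 2, 3))
backbone.init_weights()
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 64, 64))
print([tuple(f.shape) for f in feats])
# strides 4/8/16/32 -> [(1, 64, 16, 16), (1, 128, 8, 8), (1, 256, 4, 4), (1, 512, 2, 2)]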
@@ -0,0 +1,3 @@ | |||
from .scrfd_head import SCRFDHead | |||
__all__ = ['SCRFDHead'] |
@@ -0,0 +1,3 @@ | |||
from .scrfd import SCRFD | |||
__all__ = ['SCRFD'] |
@@ -0,0 +1,109 @@ | |||
""" | |||
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors/scrfd.py | |||
""" | |||
import torch | |||
from mmdet.models.builder import DETECTORS | |||
from mmdet.models.detectors.single_stage import SingleStageDetector | |||
from ....mmdet_patch.core.bbox import bbox2result | |||
@DETECTORS.register_module() | |||
class SCRFD(SingleStageDetector): | |||
def __init__(self, | |||
backbone, | |||
neck, | |||
bbox_head, | |||
train_cfg=None, | |||
test_cfg=None, | |||
pretrained=None): | |||
super(SCRFD, self).__init__(backbone, neck, bbox_head, train_cfg, | |||
test_cfg, pretrained) | |||
def forward_train(self, | |||
img, | |||
img_metas, | |||
gt_bboxes, | |||
gt_labels, | |||
gt_keypointss=None, | |||
gt_bboxes_ignore=None): | |||
""" | |||
Args: | |||
img (Tensor): Input images of shape (N, C, H, W). | |||
Typically these should be mean centered and std scaled. | |||
img_metas (list[dict]): A List of image info dict where each dict | |||
has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
For details on the values of these keys see | |||
:class:`mmdet.datasets.pipelines.Collect`. | |||
gt_bboxes (list[Tensor]): Each item contains the ground-truth boxes of one | |||
image in [tl_x, tl_y, br_x, br_y] format. | |||
gt_labels (list[Tensor]): Class indices corresponding to each box. | |||
gt_keypointss (None | list[Tensor]): Ground-truth 5-point landmarks for each box. | |||
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding | |||
boxes can be ignored when computing the loss. | |||
Returns: | |||
dict[str, Tensor]: A dictionary of loss components. | |||
""" | |||
super(SingleStageDetector, self).forward_train(img, img_metas) | |||
x = self.extract_feat(img) | |||
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, | |||
gt_labels, gt_keypointss, | |||
gt_bboxes_ignore) | |||
return losses | |||
def simple_test(self, img, img_metas, rescale=False): | |||
"""Test function without test time augmentation. | |||
Args: | |||
imgs (list[torch.Tensor]): List of multiple images | |||
img_metas (list[dict]): List of image information. | |||
rescale (bool, optional): Whether to rescale the results. | |||
Defaults to False. | |||
Returns: | |||
list[list[np.ndarray]]: BBox results of each image and classes. | |||
The outer list corresponds to each image. The inner list | |||
corresponds to each class. | |||
""" | |||
x = self.extract_feat(img) | |||
outs = self.bbox_head(x) | |||
if torch.onnx.is_in_onnx_export(): | |||
print('single_stage.py in-onnx-export') | |||
print(outs.__class__) | |||
cls_score, bbox_pred, kps_pred = outs | |||
for c in cls_score: | |||
print(c.shape) | |||
for c in bbox_pred: | |||
print(c.shape) | |||
if self.bbox_head.use_kps: | |||
for c in kps_pred: | |||
print(c.shape) | |||
return (cls_score, bbox_pred, kps_pred) | |||
else: | |||
return (cls_score, bbox_pred) | |||
bbox_list = self.bbox_head.get_bboxes( | |||
*outs, img_metas, rescale=rescale) | |||
# return kps if use_kps | |||
if len(bbox_list[0]) == 2: | |||
bbox_results = [ | |||
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) | |||
for det_bboxes, det_labels in bbox_list | |||
] | |||
elif len(bbox_list[0]) == 3: | |||
bbox_results = [ | |||
bbox2result( | |||
det_bboxes, | |||
det_labels, | |||
self.bbox_head.num_classes, | |||
kps=det_kps) | |||
for det_bboxes, det_labels, det_kps in bbox_list | |||
] | |||
return bbox_results | |||
def feature_test(self, img): | |||
x = self.extract_feat(img) | |||
outs = self.bbox_head(x) | |||
return outs |
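[Editor note] For downstream use, the per-image result of simple_test is a list with one ndarray per class; with keypoints enabled each row packs box, score and five landmarks. A hedged sketch of unpacking it (model, img and img_metas are assumed to come from a built SCRFD detector and its data pipeline):

results = model.simple_test(img, img_metas)   # one entry per image
faces = results[0][0]                         # single 'face' class -> ndarray of shape (k, 15)
boxes = faces[:, :4]                          # [x1, y1, x2, y2]
scores = faces[:, 4]
kps = faces[:, 5:].reshape(-1, 5, 2)          # five (x, y) landmarks per face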
@@ -0,0 +1,50 @@ | |||
import cv2 | |||
import numpy as np | |||
from skimage import transform as trans | |||
def align_face(image, size, lmks): | |||
dst_w = size[1] | |||
dst_h = size[0] | |||
# landmark calculation of dst images | |||
base_w = 96 | |||
base_h = 112 | |||
assert (dst_w >= base_w) | |||
assert (dst_h >= base_h) | |||
base_lmk = [ | |||
30.2946, 51.6963, 65.5318, 51.5014, 48.0252, 71.7366, 33.5493, 92.3655, | |||
62.7299, 92.2041 | |||
] | |||
dst_lmk = np.array(base_lmk).reshape((5, 2)).astype(np.float32) | |||
if dst_w != base_w: | |||
slide = (dst_w - base_w) / 2 | |||
dst_lmk[:, 0] += slide | |||
if dst_h != base_h: | |||
slide = (dst_h - base_h) / 2 | |||
dst_lmk[:, 1] += slide | |||
src_lmk = lmks | |||
# using skimage method | |||
tform = trans.SimilarityTransform() | |||
tform.estimate(src_lmk, dst_lmk) | |||
t = tform.params[0:2, :] | |||
assert (image.shape[2] == 3) | |||
dst_image = cv2.warpAffine(image.copy(), t, (dst_w, dst_h)) | |||
dst_pts = GetAffinePoints(src_lmk, t) | |||
return dst_image, dst_pts | |||
def GetAffinePoints(pts_in, trans): | |||
pts_out = pts_in.copy() | |||
assert (pts_in.shape[1] == 2) | |||
for k in range(pts_in.shape[0]): | |||
pts_out[k, 0] = pts_in[k, 0] * trans[0, 0] + pts_in[k, 1] * trans[ | |||
0, 1] + trans[0, 2] | |||
pts_out[k, 1] = pts_in[k, 0] * trans[1, 0] + pts_in[k, 1] * trans[ | |||
1, 1] + trans[1, 2] | |||
return pts_out |
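[Editor note] A usage sketch for align_face (the image path and landmark coordinates are placeholders): given the five landmarks predicted by the detector, the face is warped to the recognition input size.

import cv2
import numpy as np
img = cv2.imread('face.jpg')                                      # placeholder path, BGR image
lmks = np.array([[193., 240.], [262., 238.], [228., 286.],
                 [200., 322.], [258., 320.]], dtype=np.float32)   # eyes, nose, mouth corners
aligned, warped_lmks = align_face(img, (112, 112), lmks)
# aligned: 112x112x3 crop fed to the recognition backbone; warped_lmks: landmarks in crop coordinates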
@@ -0,0 +1,31 @@ | |||
from .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, | |||
IR_SE_101, IR_SE_152, IR_SE_200) | |||
from .model_resnet import ResNet_50, ResNet_101, ResNet_152 | |||
_model_dict = { | |||
'ResNet_50': ResNet_50, | |||
'ResNet_101': ResNet_101, | |||
'ResNet_152': ResNet_152, | |||
'IR_18': IR_18, | |||
'IR_34': IR_34, | |||
'IR_50': IR_50, | |||
'IR_101': IR_101, | |||
'IR_152': IR_152, | |||
'IR_200': IR_200, | |||
'IR_SE_50': IR_SE_50, | |||
'IR_SE_101': IR_SE_101, | |||
'IR_SE_152': IR_SE_152, | |||
'IR_SE_200': IR_SE_200 | |||
} | |||
def get_model(key): | |||
""" Get different backbone network by key, | |||
support ResNet50, ResNet_101, ResNet_152 | |||
IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, | |||
IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200. | |||
""" | |||
if key in _model_dict.keys(): | |||
return _model_dict[key] | |||
else: | |||
raise KeyError('unsupported model {}'.format(key)) |
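[Editor note] get_model returns the constructor rather than an instance, so typical use is:

Backbone = get_model('IR_101')     # constructor for the IR-101 backbone
net = Backbone([112, 112])         # input_size must start with 112 or 224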
@@ -0,0 +1,68 @@ | |||
import torch | |||
import torch.nn as nn | |||
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Linear, Module, ReLU, | |||
Sigmoid) | |||
def initialize_weights(modules): | |||
""" Weight initilize, conv2d and linear is initialized with kaiming_normal | |||
""" | |||
for m in modules: | |||
if isinstance(m, nn.Conv2d): | |||
nn.init.kaiming_normal_( | |||
m.weight, mode='fan_out', nonlinearity='relu') | |||
if m.bias is not None: | |||
m.bias.data.zero_() | |||
elif isinstance(m, nn.BatchNorm2d): | |||
m.weight.data.fill_(1) | |||
m.bias.data.zero_() | |||
elif isinstance(m, nn.Linear): | |||
nn.init.kaiming_normal_( | |||
m.weight, mode='fan_out', nonlinearity='relu') | |||
if m.bias is not None: | |||
m.bias.data.zero_() | |||
class Flatten(Module): | |||
""" Flat tensor | |||
""" | |||
def forward(self, input): | |||
return input.view(input.size(0), -1) | |||
class SEModule(Module): | |||
""" SE block | |||
""" | |||
def __init__(self, channels, reduction): | |||
super(SEModule, self).__init__() | |||
self.avg_pool = nn.AdaptiveAvgPool2d(1) | |||
self.fc1 = Conv2d( | |||
channels, | |||
channels // reduction, | |||
kernel_size=1, | |||
padding=0, | |||
bias=False) | |||
nn.init.xavier_uniform_(self.fc1.weight.data) | |||
self.relu = ReLU(inplace=True) | |||
self.fc2 = Conv2d( | |||
channels // reduction, | |||
channels, | |||
kernel_size=1, | |||
padding=0, | |||
bias=False) | |||
self.sigmoid = Sigmoid() | |||
def forward(self, x): | |||
module_input = x | |||
x = self.avg_pool(x) | |||
x = self.fc1(x) | |||
x = self.relu(x) | |||
x = self.fc2(x) | |||
x = self.sigmoid(x) | |||
return module_input * x |
@@ -0,0 +1,279 @@ | |||
# based on: | |||
# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_irse.py | |||
from collections import namedtuple | |||
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, | |||
MaxPool2d, Module, PReLU, Sequential) | |||
from .common import Flatten, SEModule, initialize_weights | |||
class BasicBlockIR(Module): | |||
""" BasicBlock for IRNet | |||
""" | |||
def __init__(self, in_channel, depth, stride): | |||
super(BasicBlockIR, self).__init__() | |||
if in_channel == depth: | |||
self.shortcut_layer = MaxPool2d(1, stride) | |||
else: | |||
self.shortcut_layer = Sequential( | |||
Conv2d(in_channel, depth, (1, 1), stride, bias=False), | |||
BatchNorm2d(depth)) | |||
self.res_layer = Sequential( | |||
BatchNorm2d(in_channel), | |||
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), | |||
BatchNorm2d(depth), PReLU(depth), | |||
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), | |||
BatchNorm2d(depth)) | |||
def forward(self, x): | |||
shortcut = self.shortcut_layer(x) | |||
res = self.res_layer(x) | |||
return res + shortcut | |||
class BottleneckIR(Module): | |||
""" BasicBlock with bottleneck for IRNet | |||
""" | |||
def __init__(self, in_channel, depth, stride): | |||
super(BottleneckIR, self).__init__() | |||
reduction_channel = depth // 4 | |||
if in_channel == depth: | |||
self.shortcut_layer = MaxPool2d(1, stride) | |||
else: | |||
self.shortcut_layer = Sequential( | |||
Conv2d(in_channel, depth, (1, 1), stride, bias=False), | |||
BatchNorm2d(depth)) | |||
self.res_layer = Sequential( | |||
BatchNorm2d(in_channel), | |||
Conv2d( | |||
in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False), | |||
BatchNorm2d(reduction_channel), PReLU(reduction_channel), | |||
Conv2d( | |||
reduction_channel, | |||
reduction_channel, (3, 3), (1, 1), | |||
1, | |||
bias=False), BatchNorm2d(reduction_channel), | |||
PReLU(reduction_channel), | |||
Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False), | |||
BatchNorm2d(depth)) | |||
def forward(self, x): | |||
shortcut = self.shortcut_layer(x) | |||
res = self.res_layer(x) | |||
return res + shortcut | |||
class BasicBlockIRSE(BasicBlockIR): | |||
def __init__(self, in_channel, depth, stride): | |||
super(BasicBlockIRSE, self).__init__(in_channel, depth, stride) | |||
self.res_layer.add_module('se_block', SEModule(depth, 16)) | |||
class BottleneckIRSE(BottleneckIR): | |||
def __init__(self, in_channel, depth, stride): | |||
super(BottleneckIRSE, self).__init__(in_channel, depth, stride) | |||
self.res_layer.add_module('se_block', SEModule(depth, 16)) | |||
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): | |||
'''A named tuple describing a ResNet block.''' | |||
def get_block(in_channel, depth, num_units, stride=2): | |||
return [Bottleneck(in_channel, depth, stride)] +\ | |||
[Bottleneck(depth, depth, 1) for i in range(num_units - 1)] | |||
def get_blocks(num_layers): | |||
if num_layers == 18: | |||
blocks = [ | |||
get_block(in_channel=64, depth=64, num_units=2), | |||
get_block(in_channel=64, depth=128, num_units=2), | |||
get_block(in_channel=128, depth=256, num_units=2), | |||
get_block(in_channel=256, depth=512, num_units=2) | |||
] | |||
elif num_layers == 34: | |||
blocks = [ | |||
get_block(in_channel=64, depth=64, num_units=3), | |||
get_block(in_channel=64, depth=128, num_units=4), | |||
get_block(in_channel=128, depth=256, num_units=6), | |||
get_block(in_channel=256, depth=512, num_units=3) | |||
] | |||
elif num_layers == 50: | |||
blocks = [ | |||
get_block(in_channel=64, depth=64, num_units=3), | |||
get_block(in_channel=64, depth=128, num_units=4), | |||
get_block(in_channel=128, depth=256, num_units=14), | |||
get_block(in_channel=256, depth=512, num_units=3) | |||
] | |||
elif num_layers == 100: | |||
blocks = [ | |||
get_block(in_channel=64, depth=64, num_units=3), | |||
get_block(in_channel=64, depth=128, num_units=13), | |||
get_block(in_channel=128, depth=256, num_units=30), | |||
get_block(in_channel=256, depth=512, num_units=3) | |||
] | |||
elif num_layers == 152: | |||
blocks = [ | |||
get_block(in_channel=64, depth=256, num_units=3), | |||
get_block(in_channel=256, depth=512, num_units=8), | |||
get_block(in_channel=512, depth=1024, num_units=36), | |||
get_block(in_channel=1024, depth=2048, num_units=3) | |||
] | |||
elif num_layers == 200: | |||
blocks = [ | |||
get_block(in_channel=64, depth=256, num_units=3), | |||
get_block(in_channel=256, depth=512, num_units=24), | |||
get_block(in_channel=512, depth=1024, num_units=36), | |||
get_block(in_channel=1024, depth=2048, num_units=3) | |||
] | |||
return blocks | |||
class Backbone(Module): | |||
def __init__(self, input_size, num_layers, mode='ir'): | |||
""" Args: | |||
input_size: input_size of backbone | |||
num_layers: num_layers of backbone | |||
mode: 'ir' or 'ir_se' | |||
""" | |||
super(Backbone, self).__init__() | |||
assert input_size[0] in [112, 224], \ | |||
'input_size should be [112, 112] or [224, 224]' | |||
assert num_layers in [18, 34, 50, 100, 152, 200], \ | |||
'num_layers should be 18, 34, 50, 100, 152 or 200' | |||
assert mode in ['ir', 'ir_se'], \ | |||
'mode should be ir or ir_se' | |||
self.input_layer = Sequential( | |||
Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), | |||
PReLU(64)) | |||
blocks = get_blocks(num_layers) | |||
if num_layers <= 100: | |||
if mode == 'ir': | |||
unit_module = BasicBlockIR | |||
elif mode == 'ir_se': | |||
unit_module = BasicBlockIRSE | |||
output_channel = 512 | |||
else: | |||
if mode == 'ir': | |||
unit_module = BottleneckIR | |||
elif mode == 'ir_se': | |||
unit_module = BottleneckIRSE | |||
output_channel = 2048 | |||
if input_size[0] == 112: | |||
self.output_layer = Sequential( | |||
BatchNorm2d(output_channel), Dropout(0.4), Flatten(), | |||
Linear(output_channel * 7 * 7, 512), | |||
BatchNorm1d(512, affine=False)) | |||
else: | |||
self.output_layer = Sequential( | |||
BatchNorm2d(output_channel), Dropout(0.4), Flatten(), | |||
Linear(output_channel * 14 * 14, 512), | |||
BatchNorm1d(512, affine=False)) | |||
modules = [] | |||
for block in blocks: | |||
for bottleneck in block: | |||
modules.append( | |||
unit_module(bottleneck.in_channel, bottleneck.depth, | |||
bottleneck.stride)) | |||
self.body = Sequential(*modules) | |||
initialize_weights(self.modules()) | |||
def forward(self, x): | |||
x = self.input_layer(x) | |||
x = self.body(x) | |||
x = self.output_layer(x) | |||
return x | |||
def IR_18(input_size): | |||
""" Constructs a ir-18 model. | |||
""" | |||
model = Backbone(input_size, 18, 'ir') | |||
return model | |||
def IR_34(input_size): | |||
""" Constructs a ir-34 model. | |||
""" | |||
model = Backbone(input_size, 34, 'ir') | |||
return model | |||
def IR_50(input_size): | |||
""" Constructs a ir-50 model. | |||
""" | |||
model = Backbone(input_size, 50, 'ir') | |||
return model | |||
def IR_101(input_size): | |||
""" Constructs a ir-101 model. | |||
""" | |||
model = Backbone(input_size, 100, 'ir') | |||
return model | |||
def IR_152(input_size): | |||
""" Constructs a ir-152 model. | |||
""" | |||
model = Backbone(input_size, 152, 'ir') | |||
return model | |||
def IR_200(input_size): | |||
""" Constructs a ir-200 model. | |||
""" | |||
model = Backbone(input_size, 200, 'ir') | |||
return model | |||
def IR_SE_50(input_size): | |||
""" Constructs a ir_se-50 model. | |||
""" | |||
model = Backbone(input_size, 50, 'ir_se') | |||
return model | |||
def IR_SE_101(input_size): | |||
""" Constructs a ir_se-101 model. | |||
""" | |||
model = Backbone(input_size, 100, 'ir_se') | |||
return model | |||
def IR_SE_152(input_size): | |||
""" Constructs a ir_se-152 model. | |||
""" | |||
model = Backbone(input_size, 152, 'ir_se') | |||
return model | |||
def IR_SE_200(input_size): | |||
""" Constructs a ir_se-200 model. | |||
""" | |||
model = Backbone(input_size, 200, 'ir_se') | |||
return model |
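[Editor note] A shape-check sketch for the IR backbones (eval mode matters here because BatchNorm1d cannot run on a batch of one in training mode):

import torch
net = IR_101([112, 112])
net.eval()
with torch.no_grad():
    emb = net(torch.rand(1, 3, 112, 112))
print(emb.shape)   # torch.Size([1, 512]) - the 512-d face embedding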
@@ -0,0 +1,162 @@ | |||
# based on: | |||
# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_resnet.py | |||
import torch.nn as nn | |||
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, | |||
MaxPool2d, Module, ReLU, Sequential) | |||
from .common import initialize_weights | |||
def conv3x3(in_planes, out_planes, stride=1): | |||
""" 3x3 convolution with padding | |||
""" | |||
return Conv2d( | |||
in_planes, | |||
out_planes, | |||
kernel_size=3, | |||
stride=stride, | |||
padding=1, | |||
bias=False) | |||
def conv1x1(in_planes, out_planes, stride=1): | |||
""" 1x1 convolution | |||
""" | |||
return Conv2d( | |||
in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||
class Bottleneck(Module): | |||
expansion = 4 | |||
def __init__(self, inplanes, planes, stride=1, downsample=None): | |||
super(Bottleneck, self).__init__() | |||
self.conv1 = conv1x1(inplanes, planes) | |||
self.bn1 = BatchNorm2d(planes) | |||
self.conv2 = conv3x3(planes, planes, stride) | |||
self.bn2 = BatchNorm2d(planes) | |||
self.conv3 = conv1x1(planes, planes * self.expansion) | |||
self.bn3 = BatchNorm2d(planes * self.expansion) | |||
self.relu = ReLU(inplace=True) | |||
self.downsample = downsample | |||
self.stride = stride | |||
def forward(self, x): | |||
identity = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
out = self.relu(out) | |||
out = self.conv3(out) | |||
out = self.bn3(out) | |||
if self.downsample is not None: | |||
identity = self.downsample(x) | |||
out += identity | |||
out = self.relu(out) | |||
return out | |||
class ResNet(Module): | |||
""" ResNet backbone | |||
""" | |||
def __init__(self, input_size, block, layers, zero_init_residual=True): | |||
""" Args: | |||
input_size: input_size of backbone | |||
block: block function | |||
layers: layers in each block | |||
""" | |||
super(ResNet, self).__init__() | |||
assert input_size[0] in [112, 224],\ | |||
'input_size should be [112, 112] or [224, 224]' | |||
self.inplanes = 64 | |||
self.conv1 = Conv2d( | |||
3, 64, kernel_size=7, stride=2, padding=3, bias=False) | |||
self.bn1 = BatchNorm2d(64) | |||
self.relu = ReLU(inplace=True) | |||
self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
self.layer1 = self._make_layer(block, 64, layers[0]) | |||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2) | |||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2) | |||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2) | |||
self.bn_o1 = BatchNorm2d(2048) | |||
self.dropout = Dropout() | |||
if input_size[0] == 112: | |||
self.fc = Linear(2048 * 4 * 4, 512) | |||
else: | |||
self.fc = Linear(2048 * 7 * 7, 512) | |||
self.bn_o2 = BatchNorm1d(512) | |||
initialize_weights(self.modules()) | |||
if zero_init_residual: | |||
for m in self.modules(): | |||
if isinstance(m, Bottleneck): | |||
nn.init.constant_(m.bn3.weight, 0) | |||
def _make_layer(self, block, planes, blocks, stride=1): | |||
downsample = None | |||
if stride != 1 or self.inplanes != planes * block.expansion: | |||
downsample = Sequential( | |||
conv1x1(self.inplanes, planes * block.expansion, stride), | |||
BatchNorm2d(planes * block.expansion), | |||
) | |||
layers = [] | |||
layers.append(block(self.inplanes, planes, stride, downsample)) | |||
self.inplanes = planes * block.expansion | |||
for _ in range(1, blocks): | |||
layers.append(block(self.inplanes, planes)) | |||
return Sequential(*layers) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.bn1(x) | |||
x = self.relu(x) | |||
x = self.maxpool(x) | |||
x = self.layer1(x) | |||
x = self.layer2(x) | |||
x = self.layer3(x) | |||
x = self.layer4(x) | |||
x = self.bn_o1(x) | |||
x = self.dropout(x) | |||
x = x.view(x.size(0), -1) | |||
x = self.fc(x) | |||
x = self.bn_o2(x) | |||
return x | |||
def ResNet_50(input_size, **kwargs): | |||
""" Constructs a ResNet-50 model. | |||
""" | |||
model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs) | |||
return model | |||
def ResNet_101(input_size, **kwargs): | |||
""" Constructs a ResNet-101 model. | |||
""" | |||
model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs) | |||
return model | |||
def ResNet_152(input_size, **kwargs): | |||
""" Constructs a ResNet-152 model. | |||
""" | |||
model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs) | |||
return model |
@@ -13,6 +13,7 @@ class OutputKeys(object): | |||
POSES = 'poses' | |||
CAPTION = 'caption' | |||
BOXES = 'boxes' | |||
KEYPOINTS = 'keypoints' | |||
MASKS = 'masks' | |||
TEXT = 'text' | |||
POLYGONS = 'polygons' | |||
@@ -55,6 +56,31 @@ TASK_OUTPUTS = { | |||
Tasks.object_detection: | |||
[OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES], | |||
# face detection result for single sample | |||
# { | |||
# "scores": [0.9, 0.1, 0.05, 0.05] | |||
# "boxes": [ | |||
# [x1, y1, x2, y2], | |||
# [x1, y1, x2, y2], | |||
# [x1, y1, x2, y2], | |||
# [x1, y1, x2, y2], | |||
# ], | |||
# "keypoints": [ | |||
# [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], | |||
# [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], | |||
# [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], | |||
# [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], | |||
# ], | |||
# } | |||
Tasks.face_detection: | |||
[OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], | |||
# face recognition result for single sample | |||
# { | |||
# "img_embedding": np.array with shape [1, D], | |||
# } | |||
Tasks.face_recognition: [OutputKeys.IMG_EMBEDDING], | |||
# instance segmentation result for single sample | |||
# { | |||
# "scores": [0.9, 0.1, 0.05, 0.05], | |||
@@ -255,7 +255,11 @@ class Pipeline(ABC): | |||
elif isinstance(data, InputFeatures): | |||
return data | |||
else: | |||
raise ValueError(f'Unsupported data type {type(data)}') | |||
import mmcv | |||
if isinstance(data, mmcv.parallel.data_container.DataContainer): | |||
return data | |||
else: | |||
raise ValueError(f'Unsupported data type {type(data)}') | |||
def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: | |||
preprocess_params = kwargs.get('preprocess_params') | |||
@@ -80,6 +80,10 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
Tasks.text_to_image_synthesis: | |||
(Pipelines.text_to_image_synthesis, | |||
'damo/cv_imagen_text-to-image-synthesis_tiny'), | |||
Tasks.face_detection: (Pipelines.face_detection, | |||
'damo/cv_resnet_facedetection_scrfd10gkps'), | |||
Tasks.face_recognition: (Pipelines.face_recognition, | |||
'damo/cv_ir101_facerecognition_cfglint'), | |||
Tasks.video_multi_modal_embedding: | |||
(Pipelines.video_multi_modal_embedding, | |||
'damo/multi_modal_clip_vtretrival_msrvtt_53'), | |||
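[Editor note] With the default mappings above in place, the new pipeline can be exercised end to end; a hedged sketch of the intended call (the image path is a placeholder; the output keys follow the TASK_OUTPUTS entry earlier in this review):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_detection = pipeline(
    Tasks.face_detection, model='damo/cv_resnet_facedetection_scrfd10gkps')
result = face_detection('face.jpg')   # placeholder image path
print(result['scores'], result['boxes'], result['keypoints'])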
@@ -5,44 +5,50 @@ from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .action_recognition_pipeline import ActionRecognitionPipeline | |||
from .animal_recog_pipeline import AnimalRecogPipeline | |||
from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline | |||
from .live_category_pipeline import LiveCategoryPipeline | |||
from .image_classification_pipeline import GeneralImageClassificationPipeline | |||
from .animal_recognition_pipeline import AnimalRecognitionPipeline | |||
from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline | |||
from .face_detection_pipeline import FaceDetectionPipeline | |||
from .face_recognition_pipeline import FaceRecognitionPipeline | |||
from .face_image_generation_pipeline import FaceImageGenerationPipeline | |||
from .image_cartoon_pipeline import ImageCartoonPipeline | |||
from .image_classification_pipeline import GeneralImageClassificationPipeline | |||
from .image_denoise_pipeline import ImageDenoisePipeline | |||
from .image_color_enhance_pipeline import ImageColorEnhancePipeline | |||
from .image_colorization_pipeline import ImageColorizationPipeline | |||
from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | |||
from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | |||
from .video_category_pipeline import VideoCategoryPipeline | |||
from .image_matting_pipeline import ImageMattingPipeline | |||
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | |||
from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | |||
from .style_transfer_pipeline import StyleTransferPipeline | |||
from .live_category_pipeline import LiveCategoryPipeline | |||
from .ocr_detection_pipeline import OCRDetectionPipeline | |||
from .video_category_pipeline import VideoCategoryPipeline | |||
from .virtual_tryon_pipeline import VirtualTryonPipeline | |||
else: | |||
_import_structure = { | |||
'action_recognition_pipeline': ['ActionRecognitionPipeline'], | |||
'animal_recog_pipeline': ['AnimalRecogPipeline'], | |||
'cmdssl_video_embedding_pipleline': ['CMDSSLVideoEmbeddingPipeline'], | |||
'animal_recognition_pipeline': ['AnimalRecognitionPipeline'], | |||
'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'], | |||
'face_detection_pipeline': ['FaceDetectionPipeline'], | |||
'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], | |||
'face_recognition_pipeline': ['FaceRecognitionPipeline'], | |||
'image_classification_pipeline': | |||
['GeneralImageClassificationPipeline'], | |||
'image_cartoon_pipeline': ['ImageCartoonPipeline'], | |||
'image_denoise_pipeline': ['ImageDenoisePipeline'], | |||
'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'], | |||
'virtual_tryon_pipeline': ['VirtualTryonPipeline'], | |||
'image_colorization_pipeline': ['ImageColorizationPipeline'], | |||
'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], | |||
'image_denoise_pipeline': ['ImageDenoisePipeline'], | |||
'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], | |||
'image_cartoon_pipeline': ['ImageCartoonPipeline'], | |||
'image_matting_pipeline': ['ImageMattingPipeline'], | |||
'style_transfer_pipeline': ['StyleTransferPipeline'], | |||
'ocr_detection_pipeline': ['OCRDetectionPipeline'], | |||
'image_instance_segmentation_pipeline': | |||
['ImageInstanceSegmentationPipeline'], | |||
'video_category_pipeline': ['VideoCategoryPipeline'], | |||
'image_matting_pipeline': ['ImageMattingPipeline'], | |||
'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], | |||
'image_to_image_translation_pipeline': | |||
['Image2ImageTranslationPipeline'], | |||
'live_category_pipeline': ['LiveCategoryPipeline'], | |||
'ocr_detection_pipeline': ['OCRDetectionPipeline'], | |||
'style_transfer_pipeline': ['StyleTransferPipeline'], | |||
'video_category_pipeline': ['VideoCategoryPipeline'], | |||
'virtual_tryon_pipeline': ['VirtualTryonPipeline'], | |||
} | |||
import sys | |||
@@ -23,7 +23,7 @@ class ActionRecognitionPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` and `preprocessor` to create a kws pipeline for prediction | |||
use `model` to create an action recognition pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
@@ -22,11 +22,11 @@ logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.image_classification, module_name=Pipelines.animal_recognation) | |||
class AnimalRecogPipeline(Pipeline): | |||
class AnimalRecognitionPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` and `preprocessor` to create a kws pipeline for prediction | |||
use `model` to create an animal recognition pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" |
@@ -24,7 +24,7 @@ class CMDSSLVideoEmbeddingPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` and `preprocessor` to create a kws pipeline for prediction | |||
use `model` to create a CMDSSL Video Embedding pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" |
@@ -0,0 +1,105 @@ | |||
import os.path as osp | |||
from typing import Any, Dict | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.face_detection, module_name=Pipelines.face_detection) | |||
class FaceDetectionPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` to create a face detection pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__(model=model, **kwargs) | |||
from mmcv import Config | |||
from mmcv.parallel import MMDataParallel | |||
from mmcv.runner import load_checkpoint | |||
from mmdet.models import build_detector | |||
from modelscope.models.cv.face_detection.mmdet_patch.datasets import RetinaFaceDataset | |||
from modelscope.models.cv.face_detection.mmdet_patch.datasets.pipelines import RandomSquareCrop | |||
from modelscope.models.cv.face_detection.mmdet_patch.models.backbones import ResNetV1e | |||
from modelscope.models.cv.face_detection.mmdet_patch.models.dense_heads import SCRFDHead | |||
from modelscope.models.cv.face_detection.mmdet_patch.models.detectors import SCRFD | |||
cfg = Config.fromfile(osp.join(model, 'mmcv_scrfd_10g_bnkps.py')) | |||
detector = build_detector( | |||
cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) | |||
ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE) | |||
logger.info(f'loading model from {ckpt_path}') | |||
device = torch.device( | |||
f'cuda:{0}' if torch.cuda.is_available() else 'cpu') | |||
load_checkpoint(detector, ckpt_path, map_location=device) | |||
detector = MMDataParallel(detector, device_ids=[0]) | |||
detector.eval() | |||
self.detector = detector | |||
logger.info('load model done') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
img = LoadImage.convert_to_ndarray(input) | |||
img = img.astype(np.float32) | |||
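# mmdet test-time transforms: keep-ratio resize towards 640x640, normalize with mean 127.5 / std 128, pad to 640x640, then convert the image to a tensor | |||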
pre_pipeline = [ | |||
dict( | |||
type='MultiScaleFlipAug', | |||
img_scale=(640, 640), | |||
flip=False, | |||
transforms=[ | |||
dict(type='Resize', keep_ratio=True), | |||
dict(type='RandomFlip', flip_ratio=0.0), | |||
dict( | |||
type='Normalize', | |||
mean=[127.5, 127.5, 127.5], | |||
std=[128.0, 128.0, 128.0], | |||
to_rgb=False), | |||
dict(type='Pad', size=(640, 640), pad_val=0), | |||
dict(type='ImageToTensor', keys=['img']), | |||
dict(type='Collect', keys=['img']) | |||
]) | |||
] | |||
from mmdet.datasets.pipelines import Compose | |||
pipeline = Compose(pre_pipeline) | |||
result = {} | |||
result['filename'] = '' | |||
result['ori_filename'] = '' | |||
result['img'] = img | |||
result['img_shape'] = img.shape | |||
result['ori_shape'] = img.shape | |||
result['img_fields'] = ['img'] | |||
result = pipeline(result) | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
result = self.detector( | |||
return_loss=False, | |||
rescale=True, | |||
img=[input['img'][0].unsqueeze(0)], | |||
img_metas=[[dict(input['img_metas'][0].data)]]) | |||
assert result is not None | |||
result = result[0][0] | |||
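# each row of `result` is [x1, y1, x2, y2, score, x1, y1, ..., x5, y5]: box corners, confidence, then 5 keypoint (x, y) pairs | |||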
bboxes = result[:, :4].tolist() | |||
kpss = result[:, 5:].tolist() | |||
scores = result[:, 4].tolist() | |||
return { | |||
OutputKeys.SCORES: scores, | |||
OutputKeys.BOXES: bboxes, | |||
OutputKeys.KEYPOINTS: kpss | |||
} | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
@@ -24,7 +24,7 @@ class FaceImageGenerationPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` to create a kws pipeline for prediction | |||
use `model` to create a face image generation pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
@@ -0,0 +1,130 @@ | |||
import os.path as osp | |||
from typing import Any, Dict | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.cv.face_recognition.align_face import align_face | |||
from modelscope.models.cv.face_recognition.torchkit.backbone import get_model | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.face_recognition, module_name=Pipelines.face_recognition) | |||
class FaceRecognitionPipeline(Pipeline): | |||
def __init__(self, model: str, face_detection: Pipeline, **kwargs): | |||
""" | |||
use `model` to create a face recognition pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
face_detection: pipeline for face detection and face alignment before recognition | |||
""" | |||
# face recognition model | |||
super().__init__(model=model, **kwargs) | |||
device = torch.device( | |||
f'cuda:{0}' if torch.cuda.is_available() else 'cpu') | |||
self.device = device | |||
face_model = get_model('IR_101')([112, 112]) | |||
face_model.load_state_dict( | |||
torch.load( | |||
osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE), | |||
map_location=device)) | |||
face_model = face_model.to(device) | |||
face_model.eval() | |||
self.face_model = face_model | |||
logger.info('face recognition model loaded!') | |||
# face detect pipeline | |||
self.face_detection = face_detection | |||
def _choose_face(self, | |||
det_result, | |||
img_shape=None, | |||
min_face=10, | |||
top_face=1, | |||
center_face=False): | |||
''' | |||
Choose the face(s) used for recognition, preferring the largest area. | |||
Args: | |||
det_result: output of the face detection pipeline | |||
img_shape: (h, w, c) shape of the original image, required when center_face is True | |||
min_face: minimum width/height (in pixels) of a valid face | |||
top_face: number of largest-area faces to keep | |||
center_face: among the kept faces, pick the one closest to the image center; only effective when top_face > 1 | |||
''' | |||
bboxes = np.array(det_result[OutputKeys.BOXES]) | |||
landmarks = np.array(det_result[OutputKeys.KEYPOINTS]) | |||
# scores = np.array(det_result[OutputKeys.SCORES]) | |||
if bboxes.shape[0] == 0: | |||
logger.info('No face detected!') | |||
return None | |||
# face idx with enough size | |||
face_idx = [] | |||
for i in range(bboxes.shape[0]): | |||
box = bboxes[i] | |||
if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face: | |||
face_idx += [i] | |||
if len(face_idx) == 0: | |||
logger.info( | |||
f'Face size not enough, less than {min_face}x{min_face}!') | |||
return None | |||
bboxes = bboxes[face_idx] | |||
landmarks = landmarks[face_idx] | |||
# find max faces | |||
boxes = np.array(bboxes) | |||
area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) | |||
sort_idx = np.argsort(area)[-top_face:] | |||
# find center face | |||
if top_face > 1 and center_face and bboxes.shape[0] > 1: | |||
# `img` itself is not in scope here, so use the image shape supplied by the caller | |||
img_center = [img_shape[1] // 2, img_shape[0] // 2] | |||
min_dist = float('inf') | |||
sel_idx = -1 | |||
for _idx in sort_idx: | |||
box = boxes[_idx] | |||
dist = np.square( | |||
np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square( | |||
np.abs((box[1] + box[3]) / 2 - img_center[1])) | |||
if dist < min_dist: | |||
min_dist = dist | |||
sel_idx = _idx | |||
sort_idx = [sel_idx] | |||
main_idx = sort_idx[-1] | |||
return bboxes[main_idx], landmarks[main_idx] | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
img = LoadImage.convert_to_ndarray(input) | |||
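# reverse the channel order before detection/alignment; the aligned crop is flipped back ('to rgb') below | |||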
img = img[:, :, ::-1] | |||
det_result = self.face_detection(img.copy()) | |||
rtn = self._choose_face(det_result, img_shape=img.shape) | |||
face_img = None | |||
if rtn is not None: | |||
_, face_lmks = rtn | |||
face_lmks = face_lmks.reshape(5, 2) | |||
align_img, _ = align_face(img, (112, 112), face_lmks) | |||
face_img = align_img[:, :, ::-1] # to rgb | |||
face_img = np.transpose(face_img, axes=(2, 0, 1)) | |||
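# scale pixel values from [0, 255] to roughly [-1, 1] before feeding the embedding model | |||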
face_img = (face_img / 255. - 0.5) / 0.5 | |||
face_img = face_img.astype(np.float32) | |||
result = {} | |||
result['img'] = face_img | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
assert input['img'] is not None | |||
img = input['img'].unsqueeze(0) | |||
emb = self.face_model(img).detach().cpu().numpy() | |||
emb /= np.sqrt(np.sum(emb**2, -1, keepdims=True)) # l2 norm | |||
return {OutputKeys.IMG_EMBEDDING: emb} | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
@@ -30,7 +30,7 @@ class ImageCartoonPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` and `preprocessor` to create a kws pipeline for prediction | |||
use `model` to create an image cartoon pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
@@ -27,7 +27,7 @@ class ImageColorEnhancePipeline(Pipeline): | |||
ImageColorEnhanceFinetunePreprocessor] = None, | |||
**kwargs): | |||
""" | |||
use `model` and `preprocessor` to create a kws pipeline for prediction | |||
use `model` and `preprocessor` to create an image color enhance pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
@@ -25,7 +25,7 @@ class ImageColorizationPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` to create a kws pipeline for prediction | |||
use `model` to create an image colorization pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
@@ -21,7 +21,7 @@ class ImageMattingPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` and `preprocessor` to create a kws pipeline for prediction | |||
use `model` to create an image matting pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
@@ -23,7 +23,7 @@ class ImageSuperResolutionPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` to create a kws pipeline for prediction | |||
use `model` to create an image super resolution pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
@@ -41,7 +41,7 @@ class OCRDetectionPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` and `preprocessor` to create a kws pipeline for prediction | |||
use `model` to create an OCR detection pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
@@ -21,7 +21,7 @@ class StyleTransferPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` and `preprocessor` to create a kws pipeline for prediction | |||
use `model` to create a style transfer pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
@@ -25,7 +25,7 @@ class VirtualTryonPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` to create a kws pipeline for prediction | |||
use `model` to create a virtual tryon pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
@@ -28,6 +28,8 @@ class CVTasks(object): | |||
ocr_detection = 'ocr-detection' | |||
action_recognition = 'action-recognition' | |||
video_embedding = 'video-embedding' | |||
face_detection = 'face-detection' | |||
face_recognition = 'face-recognition' | |||
image_color_enhance = 'image-color-enhance' | |||
virtual_tryon = 'virtual-tryon' | |||
image_colorization = 'image-colorization' | |||
@@ -0,0 +1,84 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
import tempfile | |||
import unittest | |||
import cv2 | |||
import numpy as np | |||
from modelscope.fileio import File | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class FaceDetectionTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | |||
def show_result(self, img_path, bboxes, kpss, scores): | |||
bboxes = np.array(bboxes) | |||
kpss = np.array(kpss) | |||
scores = np.array(scores) | |||
img = cv2.imread(img_path) | |||
assert img is not None, f"Can't read img: {img_path}" | |||
for i in range(len(scores)): | |||
bbox = bboxes[i].astype(np.int32) | |||
kps = kpss[i].reshape(-1, 2).astype(np.int32) | |||
score = scores[i] | |||
x1, y1, x2, y2 = bbox | |||
cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2) | |||
for kp in kps: | |||
cv2.circle(img, tuple(kp), 1, (0, 0, 255), 1) | |||
cv2.putText( | |||
img, | |||
f'{score:.2f}', (x1, y2), | |||
1, | |||
1.0, (0, 255, 0), | |||
thickness=1, | |||
lineType=8) | |||
cv2.imwrite('result.png', img) | |||
print( | |||
f'Found {len(scores)} faces, output written to {osp.abspath("result.png")}' | |||
) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_dataset(self): | |||
input_location = ['data/test/images/face_detection.png'] | |||
# alternatively: | |||
# input_location = '/dir/to/images' | |||
dataset = MsDataset.load(input_location, target='image') | |||
face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||
# note that for dataset output, the inference-output is a Generator that can be iterated. | |||
result = face_detection(dataset) | |||
result = next(result) | |||
self.show_result(input_location[0], result[OutputKeys.BOXES], | |||
result[OutputKeys.KEYPOINTS], | |||
result[OutputKeys.SCORES]) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_modelhub(self): | |||
face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||
img_path = 'data/test/images/face_detection.png' | |||
result = face_detection(img_path) | |||
self.show_result(img_path, result[OutputKeys.BOXES], | |||
result[OutputKeys.KEYPOINTS], | |||
result[OutputKeys.SCORES]) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_modelhub_default_model(self): | |||
face_detection = pipeline(Tasks.face_detection) | |||
img_path = 'data/test/images/face_detection.png' | |||
result = face_detection(img_path) | |||
self.show_result(img_path, result[OutputKeys.BOXES], | |||
result[OutputKeys.KEYPOINTS], | |||
result[OutputKeys.SCORES]) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,42 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
import tempfile | |||
import unittest | |||
import cv2 | |||
import numpy as np | |||
from modelscope.fileio import File | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class FaceRecognitionTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.recog_model_id = 'damo/cv_ir101_facerecognition_cfglint' | |||
self.det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_face_compare(self): | |||
img1 = 'data/test/images/face_recognition_1.png' | |||
img2 = 'data/test/images/face_recognition_2.png' | |||
face_detection = pipeline( | |||
Tasks.face_detection, model=self.det_model_id) | |||
face_recognition = pipeline( | |||
Tasks.face_recognition, | |||
face_detection=face_detection, | |||
model=self.recog_model_id) | |||
# each call returns a dict; the face embedding is available under OutputKeys.IMG_EMBEDDING | |||
emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING] | |||
emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING] | |||
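# embeddings are L2-normalized by the recognition pipeline, so the dot product equals cosine similarity | |||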
sim = np.dot(emb1[0], emb2[0]) | |||
print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}') | |||
if __name__ == '__main__': | |||
unittest.main() |