
[to #42322933] add cv-faceDetection and cv-faceRecognition

1. support FaceDetectionPipeline inference
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9470723
master · yuxiang.tyx, yingda.chen · 3 years ago · parent commit 7e0af3dddc
51 changed files with 3168 additions and 34 deletions
  1. .gitignore (+1 -0)
  2. data/test/images/face_detection.png (+3 -0)
  3. data/test/images/face_recognition_1.png (+3 -0)
  4. data/test/images/face_recognition_2.png (+3 -0)
  5. modelscope/metainfo.py (+3 -0)
  6. modelscope/models/cv/__init__.py (+4 -3)
  7. modelscope/models/cv/face_detection/__init__.py (+0 -0)
  8. modelscope/models/cv/face_detection/mmdet_patch/__init__.py (+5 -0)
  9. modelscope/models/cv/face_detection/mmdet_patch/core/bbox/__init__.py (+3 -0)
  10. modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py (+86 -0)
  11. modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/__init__.py (+3 -0)
  12. modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/bbox_nms.py (+85 -0)
  13. modelscope/models/cv/face_detection/mmdet_patch/datasets/__init__.py (+3 -0)
  14. modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/__init__.py (+3 -0)
  15. modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py (+188 -0)
  16. modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py (+151 -0)
  17. modelscope/models/cv/face_detection/mmdet_patch/models/__init__.py (+2 -0)
  18. modelscope/models/cv/face_detection/mmdet_patch/models/backbones/__init__.py (+3 -0)
  19. modelscope/models/cv/face_detection/mmdet_patch/models/backbones/resnet.py (+412 -0)
  20. modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/__init__.py (+3 -0)
  21. modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py (+1068 -0)
  22. modelscope/models/cv/face_detection/mmdet_patch/models/detectors/__init__.py (+3 -0)
  23. modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py (+109 -0)
  24. modelscope/models/cv/face_recognition/__init__.py (+0 -0)
  25. modelscope/models/cv/face_recognition/align_face.py (+50 -0)
  26. modelscope/models/cv/face_recognition/torchkit/__init__.py (+0 -0)
  27. modelscope/models/cv/face_recognition/torchkit/backbone/__init__.py (+31 -0)
  28. modelscope/models/cv/face_recognition/torchkit/backbone/common.py (+68 -0)
  29. modelscope/models/cv/face_recognition/torchkit/backbone/model_irse.py (+279 -0)
  30. modelscope/models/cv/face_recognition/torchkit/backbone/model_resnet.py (+162 -0)
  31. modelscope/outputs.py (+26 -0)
  32. modelscope/pipelines/base.py (+5 -1)
  33. modelscope/pipelines/builder.py (+4 -0)
  34. modelscope/pipelines/cv/__init__.py (+23 -17)
  35. modelscope/pipelines/cv/action_recognition_pipeline.py (+1 -1)
  36. modelscope/pipelines/cv/animal_recognition_pipeline.py (+2 -2)
  37. modelscope/pipelines/cv/cmdssl_video_embedding_pipeline.py (+1 -1)
  38. modelscope/pipelines/cv/face_detection_pipeline.py (+105 -0)
  39. modelscope/pipelines/cv/face_image_generation_pipeline.py (+1 -1)
  40. modelscope/pipelines/cv/face_recognition_pipeline.py (+130 -0)
  41. modelscope/pipelines/cv/image_cartoon_pipeline.py (+1 -1)
  42. modelscope/pipelines/cv/image_color_enhance_pipeline.py (+1 -1)
  43. modelscope/pipelines/cv/image_colorization_pipeline.py (+1 -1)
  44. modelscope/pipelines/cv/image_matting_pipeline.py (+1 -1)
  45. modelscope/pipelines/cv/image_super_resolution_pipeline.py (+1 -1)
  46. modelscope/pipelines/cv/ocr_detection_pipeline.py (+1 -1)
  47. modelscope/pipelines/cv/style_transfer_pipeline.py (+1 -1)
  48. modelscope/pipelines/cv/virtual_tryon_pipeline.py (+1 -1)
  49. modelscope/utils/constant.py (+2 -0)
  50. tests/pipelines/test_face_detection.py (+84 -0)
  51. tests/pipelines/test_face_recognition.py (+42 -0)

.gitignore (+1 -0)

@@ -121,6 +121,7 @@ source.sh
tensorboard.sh
.DS_Store
replace.sh
result.png

# Pytorch
*.pth


data/test/images/face_detection.png (+3 -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa3963d1c54e6d3d46e9a59872a99ed955d4050092f5cfe5f591e03d740b7042
size 653006

data/test/images/face_recognition_1.png (+3 -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:48e541daeb2692907efef47018e41abb5ae6bcd88eb5ff58290d7fe5dc8b2a13
size 462584

data/test/images/face_recognition_2.png (+3 -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e9565b43d9f65361b9bad6553b327c2c6f02fd063a4c8dc0f461e88ea461989d
size 357166

modelscope/metainfo.py (+3 -0)

@@ -10,6 +10,7 @@ class Models(object):
Model name should only contain model info but not task info.
"""
# vision models
scrfd = 'scrfd'
classification_model = 'ClassificationModel'
nafnet = 'nafnet'
csrnet = 'csrnet'
@@ -67,6 +68,7 @@ class Pipelines(object):
action_recognition = 'TAdaConv_action-recognition'
animal_recognation = 'resnet101-animal_recog'
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
face_detection = 'resnet-face-detection-scrfd10gkps'
live_category = 'live-category'
general_image_classification = 'vit-base_image-classification_ImageNet-labels'
daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
@@ -76,6 +78,7 @@ class Pipelines(object):
image_super_resolution = 'rrdb-image-super-resolution'
face_image_generation = 'gan-face-image-generation'
style_transfer = 'AAMS-style-transfer'
face_recognition = 'ir101-face-recognition-cfglint'
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
image2image_translation = 'image-to-image-translation'
live_category = 'live-category'


modelscope/models/cv/__init__.py (+4 -3)

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import (action_recognition, animal_recognition, cartoon,
cmdssl_video_embedding, face_generation, image_classification,
image_color_enhance, image_colorization, image_denoise,
image_instance_segmentation, super_resolution, virual_tryon)
cmdssl_video_embedding, face_detection, face_generation,
image_classification, image_color_enhance, image_colorization,
image_denoise, image_instance_segmentation,
image_to_image_translation, super_resolution, virual_tryon)

modelscope/models/cv/face_detection/__init__.py (+0 -0)


modelscope/models/cv/face_detection/mmdet_patch/__init__.py (+5 -0)

@@ -0,0 +1,5 @@
"""
mmdet_patch is based on
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet,
all duplicate functions from official mmdetection are removed.
"""

modelscope/models/cv/face_detection/mmdet_patch/core/bbox/__init__.py (+3 -0)

@@ -0,0 +1,3 @@
from .transforms import bbox2result, distance2kps, kps2distance

__all__ = ['bbox2result', 'distance2kps', 'kps2distance']

modelscope/models/cv/face_detection/mmdet_patch/core/bbox/transforms.py (+86 -0)

@@ -0,0 +1,86 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/bbox/transforms.py
"""
import numpy as np
import torch


def bbox2result(bboxes, labels, num_classes, kps=None):
"""Convert detection results to a list of numpy arrays.

Args:
bboxes (torch.Tensor | np.ndarray): shape (n, 5)
labels (torch.Tensor | np.ndarray): shape (n, )
num_classes (int): class number, including background class
kps (torch.Tensor | np.ndarray, optional): shape (n, 10), five
(x, y) keypoints per box. Default: None.

Returns:
list(ndarray): bbox (and keypoint) results of each class
"""
bbox_len = 5 if kps is None else 5 + 10  # with kps, each row gains 10 keypoint values
if bboxes.shape[0] == 0:
return [
np.zeros((0, bbox_len), dtype=np.float32)
for i in range(num_classes)
]
else:
if isinstance(bboxes, torch.Tensor):
bboxes = bboxes.detach().cpu().numpy()
labels = labels.detach().cpu().numpy()
if kps is None:
return [bboxes[labels == i, :] for i in range(num_classes)]
else: # with kps
if isinstance(kps, torch.Tensor):
kps = kps.detach().cpu().numpy()
return [
np.hstack([bboxes[labels == i, :], kps[labels == i, :]])
for i in range(num_classes)
]


def distance2kps(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.

Args:
points (Tensor): Shape (n, 2), [x, y].
distance (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom).
max_shape (tuple): Shape of the image.

Returns:
Tensor: Decoded kps.
"""
preds = []
for i in range(0, distance.shape[1], 2):
px = points[:, i % 2] + distance[:, i]
py = points[:, i % 2 + 1] + distance[:, i + 1]
if max_shape is not None:
px = px.clamp(min=0, max=max_shape[1])
py = py.clamp(min=0, max=max_shape[0])
preds.append(px)
preds.append(py)
return torch.stack(preds, -1)


def kps2distance(points, kps, max_dis=None, eps=0.1):
"""Decode bounding box based on distances.

Args:
points (Tensor): Shape (n, 2), [x, y].
kps (Tensor): Shape (n, K), "xyxy" format
max_dis (float): Upper bound of the distance.
eps (float): a small value to ensure target < max_dis, instead <=

Returns:
Tensor: Decoded distances.
"""

preds = []
for i in range(0, kps.shape[1], 2):
px = kps[:, i] - points[:, i % 2]
py = kps[:, i + 1] - points[:, i % 2 + 1]
if max_dis is not None:
px = px.clamp(min=0, max=max_dis - eps)
py = py.clamp(min=0, max=max_dis - eps)
preds.append(px)
preds.append(py)
return torch.stack(preds, -1)
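As a sanity check on the two helpers above, here is a minimal round-trip sketch (not part of the diff); it assumes unclamped offsets, i.e. max_shape and max_dis left as None:

# Round-trip sketch for kps2distance / distance2kps (illustrative only).
import torch

points = torch.tensor([[16., 16.], [48., 16.]])   # (n, 2) anchor centers
kps = torch.rand(2, 10) * 64                      # (n, 2K): 5 (x, y) keypoints
dist = kps2distance(points, kps)                  # keypoints -> offsets
decoded = distance2kps(points, dist)              # offsets -> keypoints
assert torch.allclose(decoded, kps)               # exact without clamping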

modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/__init__.py (+3 -0)

@@ -0,0 +1,3 @@
from .bbox_nms import multiclass_nms

__all__ = ['multiclass_nms']

modelscope/models/cv/face_detection/mmdet_patch/core/post_processing/bbox_nms.py (+85 -0)

@@ -0,0 +1,85 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/post_processing/bbox_nms.py
"""
import torch


def multiclass_nms(multi_bboxes,
multi_scores,
score_thr,
nms_cfg,
max_num=-1,
score_factors=None,
return_inds=False,
multi_kps=None):
"""NMS for multi-class bboxes.

Args:
multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
multi_scores (Tensor): shape (n, #class), where the last column
contains scores of the background class, but this will be ignored.
score_thr (float): bbox threshold, bboxes with scores lower than it
will not be considered.
nms_cfg (dict): NMS config (e.g. type and IoU threshold), passed
through to mmcv's `batched_nms`.
max_num (int, optional): if there are more than max_num bboxes after
NMS, only top max_num will be kept. Default to -1.
score_factors (Tensor, optional): The factors multiplied to scores
before applying NMS. Default to None.
return_inds (bool, optional): Whether return the indices of kept
bboxes. Default to False.
multi_kps (Tensor, optional): shape (n, #class*10) or (n, 10),
five (x, y) keypoints per box. Default to None.

Returns:
tuple: (bboxes, labels, kps, indices (optional)), tensors of shape
(k, 5), (k,), (k, 10) and (k,). Labels are 0-based.
"""
num_classes = multi_scores.size(1) - 1
# exclude background category
kps = None
if multi_bboxes.shape[1] > 4:
bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
if multi_kps is not None:
kps = multi_kps.view(multi_scores.size(0), -1, 10)
else:
bboxes = multi_bboxes[:, None].expand(
multi_scores.size(0), num_classes, 4)
if multi_kps is not None:
kps = multi_kps[:, None].expand(
multi_scores.size(0), num_classes, 10)

scores = multi_scores[:, :-1]
if score_factors is not None:
scores = scores * score_factors[:, None]

labels = torch.arange(num_classes, dtype=torch.long)
labels = labels.view(1, -1).expand_as(scores)

bboxes = bboxes.reshape(-1, 4)
if kps is not None:
kps = kps.reshape(-1, 10)
scores = scores.reshape(-1)
labels = labels.reshape(-1)

# remove low scoring boxes
valid_mask = scores > score_thr
inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
if kps is not None:
kps = kps[inds]
if inds.numel() == 0:
if torch.onnx.is_in_onnx_export():
raise RuntimeError('[ONNX Error] Can not record NMS '
'as it has not been executed this time')
return bboxes, labels, kps

# TODO: add size check before feed into batched_nms
from mmcv.ops.nms import batched_nms
dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)

if max_num > 0:
dets = dets[:max_num]
keep = keep[:max_num]

if return_inds:
return dets, labels[keep], kps[keep], keep
else:
return dets, labels[keep], kps[keep]
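A hedged usage sketch (not part of the diff) for the single-foreground-class case this patch targets; the score threshold and nms_cfg values below are illustrative:

import torch

n = 100
xy = torch.rand(n, 2) * 80
multi_bboxes = torch.cat([xy, xy + 20], dim=1)    # (n, 4) well-formed boxes
multi_scores = torch.rand(n, 2)                   # (n, 2): [face, background]
multi_kps = torch.rand(n, 10) * 100               # (n, 10): 5 keypoints per box
dets, labels, kps = multiclass_nms(
    multi_bboxes, multi_scores, score_thr=0.02,
    nms_cfg=dict(type='nms', iou_threshold=0.45),
    max_num=200, multi_kps=multi_kps)
# dets: (k, 5) as [x1, y1, x2, y2, score]; labels: (k,); kps: (k, 10)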

modelscope/models/cv/face_detection/mmdet_patch/datasets/__init__.py (+3 -0)

@@ -0,0 +1,3 @@
from .retinaface import RetinaFaceDataset

__all__ = ['RetinaFaceDataset']

modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/__init__.py (+3 -0)

@@ -0,0 +1,3 @@
from .transforms import RandomSquareCrop

__all__ = ['RandomSquareCrop']

modelscope/models/cv/face_detection/mmdet_patch/datasets/pipelines/transforms.py (+188 -0)

@@ -0,0 +1,188 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py
"""
import numpy as np
from mmdet.datasets.builder import PIPELINES
from numpy import random


@PIPELINES.register_module()
class RandomSquareCrop(object):
"""Random crop the image & bboxes, the cropped patches have minimum IoU
requirement with original image & bboxes, the IoU threshold is randomly
selected from min_ious.

Args:
min_ious (tuple): minimum IoU threshold for all intersections with
bounding boxes
min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
where a >= min_crop_size).

Note:
The keys for bboxes, labels and masks should be paired. That is, \
`gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \
`gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`.
"""

def __init__(self,
crop_ratio_range=None,
crop_choice=None,
bbox_clip_border=True):

self.crop_ratio_range = crop_ratio_range
self.crop_choice = crop_choice
self.bbox_clip_border = bbox_clip_border

assert (self.crop_ratio_range is None) ^ (self.crop_choice is None)
if self.crop_ratio_range is not None:
self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range

self.bbox2label = {
'gt_bboxes': 'gt_labels',
'gt_bboxes_ignore': 'gt_labels_ignore'
}
self.bbox2mask = {
'gt_bboxes': 'gt_masks',
'gt_bboxes_ignore': 'gt_masks_ignore'
}

def __call__(self, results):
"""Call function to crop images and bounding boxes with minimum IoU
constraint.

Args:
results (dict): Result dict from loading pipeline.

Returns:
dict: Result dict with images and bounding boxes cropped, \
'img_shape' key is updated.
"""

if 'img_fields' in results:
assert results['img_fields'] == ['img'], \
'Only single img_fields is allowed'
img = results['img']
assert 'bbox_fields' in results
assert 'gt_bboxes' in results
boxes = results['gt_bboxes']
h, w, c = img.shape
scale_retry = 0
if self.crop_ratio_range is not None:
max_scale = self.crop_ratio_max
else:
max_scale = np.amax(self.crop_choice)
while True:
scale_retry += 1

if scale_retry == 1 or max_scale > 1.0:
if self.crop_ratio_range is not None:
scale = np.random.uniform(self.crop_ratio_min,
self.crop_ratio_max)
elif self.crop_choice is not None:
scale = np.random.choice(self.crop_choice)
else:
scale = scale * 1.2

for i in range(250):
short_side = min(w, h)
cw = int(scale * short_side)
ch = cw

# TODO +1
if w == cw:
left = 0
elif w > cw:
left = random.randint(0, w - cw)
else:
left = random.randint(w - cw, 0)
if h == ch:
top = 0
elif h > ch:
top = random.randint(0, h - ch)
else:
top = random.randint(h - ch, 0)

patch = np.array(
(int(left), int(top), int(left + cw), int(top + ch)),
dtype=np.int)

# center of boxes should inside the crop img
# only adjust boxes and instance masks when the gt is not empty
# adjust boxes
def is_center_of_bboxes_in_patch(boxes, patch):
# TODO >=
center = (boxes[:, :2] + boxes[:, 2:]) / 2
mask = \
((center[:, 0] > patch[0])
* (center[:, 1] > patch[1])
* (center[:, 0] < patch[2])
* (center[:, 1] < patch[3]))
return mask

mask = is_center_of_bboxes_in_patch(boxes, patch)
if not mask.any():
continue
for key in results.get('bbox_fields', []):
boxes = results[key].copy()
mask = is_center_of_bboxes_in_patch(boxes, patch)
boxes = boxes[mask]
if self.bbox_clip_border:
boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
boxes -= np.tile(patch[:2], 2)

results[key] = boxes
# labels
label_key = self.bbox2label.get(key)
if label_key in results:
results[label_key] = results[label_key][mask]

# keypoints field
if key == 'gt_bboxes':
for kps_key in results.get('keypoints_fields', []):
keypointss = results[kps_key].copy()
keypointss = keypointss[mask, :, :]
if self.bbox_clip_border:
keypointss[:, :, :
2] = keypointss[:, :, :2].clip(
max=patch[2:])
keypointss[:, :, :
2] = keypointss[:, :, :2].clip(
min=patch[:2])
keypointss[:, :, 0] -= patch[0]
keypointss[:, :, 1] -= patch[1]
results[kps_key] = keypointss

# mask fields
mask_key = self.bbox2mask.get(key)
if mask_key in results:
results[mask_key] = results[mask_key][mask.nonzero()
[0]].crop(patch)

# adjust the img no matter whether the gt is empty before crop
rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128
patch_from = patch.copy()
patch_from[0] = max(0, patch_from[0])
patch_from[1] = max(0, patch_from[1])
patch_from[2] = min(img.shape[1], patch_from[2])
patch_from[3] = min(img.shape[0], patch_from[3])
patch_to = patch.copy()
patch_to[0] = max(0, patch_to[0] * -1)
patch_to[1] = max(0, patch_to[1] * -1)
patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0])
patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1])
rimg[patch_to[1]:patch_to[3],
patch_to[0]:patch_to[2], :] = img[
patch_from[1]:patch_from[3],
patch_from[0]:patch_from[2], :]
img = rimg
results['img'] = img
results['img_shape'] = img.shape

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(crop_ratio_range={self.crop_ratio_range}, '
repr_str += f'crop_choice={self.crop_choice})'
return repr_str
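For orientation (not taken from this PR's configs), the transform would be referenced from an mmdet train pipeline by its registered type; the crop_choice values below are assumptions:

train_pipeline_fragment = [
    dict(type='LoadImageFromFile'),
    # exactly one of crop_ratio_range / crop_choice may be given (XOR assert above)
    dict(type='RandomSquareCrop', crop_choice=[0.3, 0.45, 0.6, 0.8, 1.0]),
]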

modelscope/models/cv/face_detection/mmdet_patch/datasets/retinaface.py (+151 -0)

@@ -0,0 +1,151 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/retinaface.py
"""
import numpy as np
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset


@DATASETS.register_module()
class RetinaFaceDataset(CustomDataset):

CLASSES = ('FG', )

def __init__(self, min_size=None, **kwargs):
self.NK = 5
self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
self.min_size = min_size
self.gt_path = kwargs.get('gt_path')
super(RetinaFaceDataset, self).__init__(**kwargs)

def _parse_ann_line(self, line):
values = [float(x) for x in line.strip().split()]
bbox = np.array(values[0:4], dtype=np.float32)
kps = np.zeros((self.NK, 3), dtype=np.float32)
ignore = False
if self.min_size is not None:
assert not self.test_mode
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
if w < self.min_size or h < self.min_size:
ignore = True
if len(values) > 4:
if len(values) > 5:
kps = np.array(
values[4:19], dtype=np.float32).reshape((self.NK, 3))
for li in range(kps.shape[0]):
if (kps[li, :] == -1).all():
kps[li][2] = 0.0 # weight = 0, ignore
else:
assert kps[li][2] >= 0
kps[li][2] = 1.0 # weight
else: # len(values)==5
if not ignore:
ignore = (values[4] == 1)
else:
assert self.test_mode

return dict(bbox=bbox, kps=kps, ignore=ignore, cat='FG')

def load_annotations(self, ann_file):
"""Load annotation from COCO style annotation file.

Args:
ann_file (str): Path of annotation file.
20220711@tyx: ann_file is list of img paths is supported

Returns:
list[dict]: Annotation info from COCO api.
"""
if isinstance(ann_file, list):
data_infos = []
for line in ann_file:
name = line
objs = [0, 0, 0, 0]
data_infos.append(
dict(filename=name, width=0, height=0, objs=objs))
else:
name = None
bbox_map = {}
for line in open(ann_file, 'r'):
line = line.strip()
if line.startswith('#'):
value = line[1:].strip().split()
name = value[0]
width = int(value[1])
height = int(value[2])

bbox_map[name] = dict(width=width, height=height, objs=[])
continue
assert name is not None
assert name in bbox_map
bbox_map[name]['objs'].append(line)
print('origin image size', len(bbox_map))
data_infos = []
for name in bbox_map:
item = bbox_map[name]
width = item['width']
height = item['height']
vals = item['objs']
objs = []
for line in vals:
data = self._parse_ann_line(line)
if data is None:
continue
objs.append(data) # data is (bbox, kps, cat)
if len(objs) == 0 and not self.test_mode:
continue
data_infos.append(
dict(filename=name, width=width, height=height, objs=objs))
return data_infos

def get_ann_info(self, idx):
"""Get COCO annotation by index.

Args:
idx (int): Index of data.

Returns:
dict: Annotation info of specified index.
"""
data_info = self.data_infos[idx]

bboxes = []
keypointss = []
labels = []
bboxes_ignore = []
labels_ignore = []
for obj in data_info['objs']:
label = self.cat2label[obj['cat']]
bbox = obj['bbox']
keypoints = obj['kps']
ignore = obj['ignore']
if ignore:
bboxes_ignore.append(bbox)
labels_ignore.append(label)
else:
bboxes.append(bbox)
labels.append(label)
keypointss.append(keypoints)
if not bboxes:
bboxes = np.zeros((0, 4))
labels = np.zeros((0, ))
keypointss = np.zeros((0, self.NK, 3))
else:
# bboxes = np.array(bboxes, ndmin=2) - 1
bboxes = np.array(bboxes, ndmin=2)
labels = np.array(labels)
keypointss = np.array(keypointss, ndmin=3)
if not bboxes_ignore:
bboxes_ignore = np.zeros((0, 4))
labels_ignore = np.zeros((0, ))
else:
bboxes_ignore = np.array(bboxes_ignore, ndmin=2)
labels_ignore = np.array(labels_ignore)
ann = dict(
bboxes=bboxes.astype(np.float32),
labels=labels.astype(np.int64),
keypointss=keypointss.astype(np.float32),
bboxes_ignore=bboxes_ignore.astype(np.float32),
labels_ignore=labels_ignore.astype(np.int64))
return ann
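For reference, a constructed sample (not shipped with this PR) of the annotation format `_parse_ann_line` and `load_annotations` expect:

# '#' lines give '<image path> <width> <height>'; object lines give a bbox,
# then either 15 keypoint values (5 points as x, y, visibility) or a single
# ignore flag. Paths and numbers below are made up for illustration.
sample_ann = '''\
# 0--Parade/0_Parade_marchingband_1_849.jpg 1024 768
449.0 330.0 571.0 479.0 488.9 373.6 0.0 542.1 376.5 0.0 515.0 412.8 0.0 492.7 436.8 0.0 535.1 441.0 0.0
78.0 221.0 136.0 289.0 1
'''
# line 2: x1 y1 x2 y2 followed by 5 keypoints as (x, y, visibility) triples
# line 3: x1 y1 x2 y2 followed by a flag; 1 marks the box as ignored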

modelscope/models/cv/face_detection/mmdet_patch/models/__init__.py (+2 -0)

@@ -0,0 +1,2 @@
from .dense_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403

modelscope/models/cv/face_detection/mmdet_patch/models/backbones/__init__.py (+3 -0)

@@ -0,0 +1,3 @@
from .resnet import ResNetV1e

__all__ = ['ResNetV1e']

modelscope/models/cv/face_detection/mmdet_patch/models/backbones/resnet.py (+412 -0)

@@ -0,0 +1,412 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones/resnet.py
"""
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet.models.builder import BACKBONES
from mmdet.models.utils import ResLayer
from mmdet.utils import get_root_logger
from torch.nn.modules.batchnorm import _BatchNorm


class ResNet(nn.Module):
"""ResNet backbone.

Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
stem_channels (int | None): Number of stem channels. If not specified,
it will be the same as `base_channels`. Default: None.
base_channels (int): Number of base channels of res layer. Default: 64.
in_channels (int): Number of input image channels. Default: 3.
num_stages (int): Resnet stages. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with three 3x3 convs.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
norm_cfg (dict): Dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
plugins (list[dict]): List of plugins for stages, each dict contains:

- cfg (dict, required): Cfg dict to build plugin.
- position (str, required): Position inside block to insert
plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity.

Example:
>>> from mmdet.models import ResNet
>>> import torch
>>> self = ResNet(depth=18)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 64, 8, 8)
(1, 128, 4, 4)
(1, 256, 2, 2)
(1, 512, 1, 1)
"""

arch_settings = {
0: (BasicBlock, (2, 2, 2, 2)),
18: (BasicBlock, (2, 2, 2, 2)),
19: (BasicBlock, (2, 4, 4, 1)),
20: (BasicBlock, (2, 3, 2, 2)),
22: (BasicBlock, (2, 4, 3, 1)),
24: (BasicBlock, (2, 4, 4, 1)),
26: (BasicBlock, (2, 4, 4, 2)),
28: (BasicBlock, (2, 5, 4, 2)),
29: (BasicBlock, (2, 6, 3, 2)),
30: (BasicBlock, (2, 5, 5, 2)),
32: (BasicBlock, (2, 6, 5, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
35: (BasicBlock, (3, 6, 4, 3)),
38: (BasicBlock, (3, 8, 4, 3)),
40: (BasicBlock, (3, 8, 5, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
56: (Bottleneck, (3, 8, 4, 3)),
68: (Bottleneck, (3, 10, 6, 3)),
74: (Bottleneck, (3, 12, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}

def __init__(self,
depth,
in_channels=3,
stem_channels=None,
base_channels=64,
num_stages=4,
block_cfg=None,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
style='pytorch',
deep_stem=False,
avg_down=False,
no_pool33=False,
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
dcn=None,
stage_with_dcn=(False, False, False, False),
plugins=None,
with_cp=False,
zero_init_residual=True):
super(ResNet, self).__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
self.depth = depth
if stem_channels is None:
stem_channels = base_channels
self.stem_channels = stem_channels
self.base_channels = base_channels
self.num_stages = num_stages
assert num_stages >= 1 and num_stages <= 4
self.strides = strides
self.dilations = dilations
assert len(strides) == len(dilations) == num_stages
self.out_indices = out_indices
assert max(out_indices) < num_stages
self.style = style
self.deep_stem = deep_stem
self.avg_down = avg_down
self.no_pool33 = no_pool33
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.with_cp = with_cp
self.norm_eval = norm_eval
self.dcn = dcn
self.stage_with_dcn = stage_with_dcn
if dcn is not None:
assert len(stage_with_dcn) == num_stages
self.plugins = plugins
self.zero_init_residual = zero_init_residual
if block_cfg is None:
self.block, stage_blocks = self.arch_settings[depth]
else:
self.block = BasicBlock if block_cfg[
'block'] == 'BasicBlock' else Bottleneck
stage_blocks = block_cfg['stage_blocks']
assert len(stage_blocks) >= num_stages
self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = stem_channels

self._make_stem_layer(in_channels, stem_channels)
if block_cfg is not None and 'stage_planes' in block_cfg:
stage_planes = block_cfg['stage_planes']
else:
stage_planes = [base_channels * 2**i for i in range(num_stages)]

# print('resnet cfg:', stage_blocks, stage_planes)
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
stride = strides[i]
dilation = dilations[i]
dcn = self.dcn if self.stage_with_dcn[i] else None
if plugins is not None:
stage_plugins = self.make_stage_plugins(plugins, i)
else:
stage_plugins = None
planes = stage_planes[i]
res_layer = self.make_res_layer(
block=self.block,
inplanes=self.inplanes,
planes=planes,
num_blocks=num_blocks,
stride=stride,
dilation=dilation,
style=self.style,
avg_down=self.avg_down,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
dcn=dcn,
plugins=stage_plugins)
self.inplanes = planes * self.block.expansion
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)

self._freeze_stages()

self.feat_dim = self.block.expansion * base_channels * 2**(
len(self.stage_blocks) - 1)

def make_stage_plugins(self, plugins, stage_idx):
"""Make plugins for ResNet ``stage_idx`` th stage.

Currently we support to insert ``context_block``,
``empirical_attention_block``, ``nonlocal_block`` into the backbone
like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
Bottleneck.

An example of plugins format could be:

Examples:
>>> plugins=[
... dict(cfg=dict(type='xxx', arg1='xxx'),
... stages=(False, True, True, True),
... position='after_conv2'),
... dict(cfg=dict(type='yyy'),
... stages=(True, True, True, True),
... position='after_conv3'),
... dict(cfg=dict(type='zzz', postfix='1'),
... stages=(True, True, True, True),
... position='after_conv3'),
... dict(cfg=dict(type='zzz', postfix='2'),
... stages=(True, True, True, True),
... position='after_conv3')
... ]
>>> self = ResNet(depth=18)
>>> stage_plugins = self.make_stage_plugins(plugins, 0)
>>> assert len(stage_plugins) == 3

Suppose ``stage_idx=0``, the structure of blocks in the stage would be:

.. code-block:: none

conv1-> conv2->conv3->yyy->zzz1->zzz2

Suppose 'stage_idx=1', the structure of blocks in the stage would be:

.. code-block:: none

conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2

If stages is missing, the plugin would be applied to all stages.

Args:
plugins (list[dict]): List of plugins cfg to build. The postfix is
required if multiple same type plugins are inserted.
stage_idx (int): Index of stage to build

Returns:
list[dict]: Plugins for current stage
"""
stage_plugins = []
for plugin in plugins:
plugin = plugin.copy()
stages = plugin.pop('stages', None)
assert stages is None or len(stages) == self.num_stages
# whether to insert plugin into current stage
if stages is None or stages[stage_idx]:
stage_plugins.append(plugin)

return stage_plugins

def make_res_layer(self, **kwargs):
"""Pack all blocks in a stage into a ``ResLayer``."""
return ResLayer(**kwargs)

@property
def norm1(self):
"""nn.Module: the normalization layer named "norm1" """
return getattr(self, self.norm1_name)

def _make_stem_layer(self, in_channels, stem_channels):
if self.deep_stem:
self.stem = nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels,
stem_channels // 2,
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
nn.ReLU(inplace=True),
build_conv_layer(
self.conv_cfg,
stem_channels // 2,
stem_channels // 2,
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
nn.ReLU(inplace=True),
build_conv_layer(
self.conv_cfg,
stem_channels // 2,
stem_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, stem_channels)[1],
nn.ReLU(inplace=True))
else:
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
stem_channels,
kernel_size=7,
stride=2,
padding=3,
bias=False)
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, stem_channels, postfix=1)
self.add_module(self.norm1_name, norm1)
self.relu = nn.ReLU(inplace=True)
if self.no_pool33:
assert self.deep_stem
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

def _freeze_stages(self):
if self.frozen_stages >= 0:
if self.deep_stem:
self.stem.eval()
for param in self.stem.parameters():
param.requires_grad = False
else:
self.norm1.eval()
for m in [self.conv1, self.norm1]:
for param in m.parameters():
param.requires_grad = False

for i in range(1, self.frozen_stages + 1):
m = getattr(self, f'layer{i}')
m.eval()
for param in m.parameters():
param.requires_grad = False

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)

if self.dcn is not None:
for m in self.modules():
if isinstance(m, Bottleneck) and hasattr(
m.conv2, 'conv_offset'):
constant_init(m.conv2.conv_offset, 0)

if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')

def forward(self, x):
"""Forward function."""
if self.deep_stem:
x = self.stem(x)
else:
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)

def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super(ResNet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()


@BACKBONES.register_module()
class ResNetV1e(ResNet):
r"""ResNetV1d variant described in `Bag of Tricks
<https://arxiv.org/pdf/1812.01187.pdf>`_.

Compared with the default ResNet (ResNetV1b), ResNetV1d replaces the 7x7
conv in the input stem with three 3x3 convs, and in the downsampling block
a 2x2 avg_pool with stride 2 is added before the conv, whose stride is
changed to 1.

Compared with ResNetV1d, ResNetV1e changes the stem max-pooling from 3x3
(stride 2, pad 1) to 2x2 (stride 2, pad 0).
"""

def __init__(self, **kwargs):
super(ResNetV1e, self).__init__(
deep_stem=True, avg_down=True, no_pool33=True, **kwargs)
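A minimal instantiation sketch (the depth and out_indices here are illustrative; the actual SCRFD configs supply their own depth/block_cfg):

import torch

backbone = ResNetV1e(depth=50, out_indices=(0, 1, 2, 3),
                     norm_cfg=dict(type='BN', requires_grad=True))
backbone.init_weights()
feats = backbone(torch.rand(1, 3, 224, 224))   # tuple of 4 feature maps
# with the deep stem and 2x2 pooling, strides stay 4/8/16/32 as in ResNetV1d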

modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/__init__.py (+3 -0)

@@ -0,0 +1,3 @@
from .scrfd_head import SCRFDHead

__all__ = ['SCRFDHead']

modelscope/models/cv/face_detection/mmdet_patch/models/dense_heads/scrfd_head.py (+1068 -0) · file diff suppressed because it is too large


modelscope/models/cv/face_detection/mmdet_patch/models/detectors/__init__.py (+3 -0)

@@ -0,0 +1,3 @@
from .scrfd import SCRFD

__all__ = ['SCRFD']

modelscope/models/cv/face_detection/mmdet_patch/models/detectors/scrfd.py (+109 -0)

@@ -0,0 +1,109 @@
"""
based on https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors/scrfd.py
"""
import torch
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.single_stage import SingleStageDetector

from ....mmdet_patch.core.bbox import bbox2result


@DETECTORS.register_module()
class SCRFD(SingleStageDetector):

def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(SCRFD, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)

def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_keypointss=None,
gt_bboxes_ignore=None):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
gt_bboxes (list[Tensor]): Each item are the truth boxes for each
image in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): Class indices corresponding to each box
gt_keypointss (None | list[Tensor]): Five (x, y, visibility)
keypoints for each box. Defaults to None.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
boxes can be ignored when computing the loss.

Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
super(SingleStageDetector, self).forward_train(img, img_metas)
x = self.extract_feat(img)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
gt_labels, gt_keypointss,
gt_bboxes_ignore)
return losses

def simple_test(self, img, img_metas, rescale=False):
"""Test function without test time augmentation.

Args:
imgs (list[torch.Tensor]): List of multiple images
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.

Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
x = self.extract_feat(img)
outs = self.bbox_head(x)
if torch.onnx.is_in_onnx_export():
print('single_stage.py in-onnx-export')
print(outs.__class__)
cls_score, bbox_pred, kps_pred = outs
for c in cls_score:
print(c.shape)
for c in bbox_pred:
print(c.shape)
if self.bbox_head.use_kps:
for c in kps_pred:
print(c.shape)
return (cls_score, bbox_pred, kps_pred)
else:
return (cls_score, bbox_pred)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)

# return kps if use_kps
if len(bbox_list[0]) == 2:
bbox_results = [
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
for det_bboxes, det_labels in bbox_list
]
elif len(bbox_list[0]) == 3:
bbox_results = [
bbox2result(
det_bboxes,
det_labels,
self.bbox_head.num_classes,
kps=det_kps)
for det_bboxes, det_labels, det_kps in bbox_list
]
return bbox_results

def feature_test(self, img):
x = self.extract_feat(img)
outs = self.bbox_head(x)
return outs
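How the detector is expected to be constructed, as a sketch under assumptions (the config and checkpoint filenames are placeholders; the imports mirror those used by FaceDetectionPipeline below):

from mmcv import Config
from mmcv.runner import load_checkpoint
from mmdet.models import build_detector

cfg = Config.fromfile('scrfd_config.py')        # placeholder config path
detector = build_detector(cfg.model)            # cfg.model.type == 'SCRFD'
load_checkpoint(detector, 'ckpt.pth', map_location='cpu')  # placeholder ckpt
detector.eval()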

modelscope/models/cv/face_recognition/__init__.py (+0 -0)


modelscope/models/cv/face_recognition/align_face.py (+50 -0)

@@ -0,0 +1,50 @@
import cv2
import numpy as np
from skimage import transform as trans


def align_face(image, size, lmks):
dst_w = size[1]
dst_h = size[0]
# landmark calculation of dst images
base_w = 96
base_h = 112
assert (dst_w >= base_w)
assert (dst_h >= base_h)
base_lmk = [
30.2946, 51.6963, 65.5318, 51.5014, 48.0252, 71.7366, 33.5493, 92.3655,
62.7299, 92.2041
]

dst_lmk = np.array(base_lmk).reshape((5, 2)).astype(np.float32)
if dst_w != base_w:
slide = (dst_w - base_w) / 2
dst_lmk[:, 0] += slide

if dst_h != base_h:
slide = (dst_h - base_h) / 2
dst_lmk[:, 1] += slide

src_lmk = lmks
# using skimage method
tform = trans.SimilarityTransform()
tform.estimate(src_lmk, dst_lmk)
t = tform.params[0:2, :]

assert (image.shape[2] == 3)

dst_image = cv2.warpAffine(image.copy(), t, (dst_w, dst_h))
dst_pts = GetAffinePoints(src_lmk, t)
return dst_image, dst_pts


def GetAffinePoints(pts_in, trans):
pts_out = pts_in.copy()
assert (pts_in.shape[1] == 2)

for k in range(pts_in.shape[0]):
pts_out[k, 0] = pts_in[k, 0] * trans[0, 0] + pts_in[k, 1] * trans[
0, 1] + trans[0, 2]
pts_out[k, 1] = pts_in[k, 0] * trans[1, 0] + pts_in[k, 1] * trans[
1, 1] + trans[1, 2]
return pts_out
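An illustrative call (not part of the diff); the image path is a placeholder, and in the pipeline the landmarks come from the detector's 5 keypoints reshaped to (5, 2):

import cv2
import numpy as np

img = cv2.imread('face.jpg')                     # placeholder input image
lmks = np.array([[38.3, 51.7], [73.5, 51.5], [56.0, 71.7],
                 [41.5, 92.4], [70.7, 92.2]], dtype=np.float32)  # 5 (x, y) points
aligned, warped_pts = align_face(img, (112, 112), lmks)   # size is (h, w)
# aligned: 112x112 BGR crop in the canonical 5-point frame used by the backbones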

modelscope/models/cv/face_recognition/torchkit/__init__.py (+0 -0)


modelscope/models/cv/face_recognition/torchkit/backbone/__init__.py (+31 -0)

@@ -0,0 +1,31 @@
from .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50,
IR_SE_101, IR_SE_152, IR_SE_200)
from .model_resnet import ResNet_50, ResNet_101, ResNet_152

_model_dict = {
'ResNet_50': ResNet_50,
'ResNet_101': ResNet_101,
'ResNet_152': ResNet_152,
'IR_18': IR_18,
'IR_34': IR_34,
'IR_50': IR_50,
'IR_101': IR_101,
'IR_152': IR_152,
'IR_200': IR_200,
'IR_SE_50': IR_SE_50,
'IR_SE_101': IR_SE_101,
'IR_SE_152': IR_SE_152,
'IR_SE_200': IR_SE_200
}


def get_model(key):
""" Get different backbone network by key,
support ResNet_50, ResNet_101, ResNet_152,
IR_18, IR_34, IR_50, IR_101, IR_152, IR_200,
IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200.
"""
if key in _model_dict.keys():
return _model_dict[key]
else:
raise KeyError('not support model {}'.format(key))
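A minimal sketch of using the factory (not part of the diff): IR_101 with a 112x112 input yields the 512-d embeddings consumed by the face recognition pipeline.

import torch

net = get_model('IR_101')([112, 112])        # input_size is [h, w]
net.eval()
with torch.no_grad():
    emb = net(torch.rand(1, 3, 112, 112))    # shape (1, 512)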

modelscope/models/cv/face_recognition/torchkit/backbone/common.py (+68 -0)

@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Linear, Module, ReLU,
Sigmoid)


def initialize_weights(modules):
""" Weight initilize, conv2d and linear is initialized with kaiming_normal
"""
for m in modules:
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
m.bias.data.zero_()


class Flatten(Module):
""" Flat tensor
"""

def forward(self, input):
return input.view(input.size(0), -1)


class SEModule(Module):
""" SE block
"""

def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(
channels,
channels // reduction,
kernel_size=1,
padding=0,
bias=False)

nn.init.xavier_uniform_(self.fc1.weight.data)

self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(
channels // reduction,
channels,
kernel_size=1,
padding=0,
bias=False)

self.sigmoid = Sigmoid()

def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)

return module_input * x

modelscope/models/cv/face_recognition/torchkit/backbone/model_irse.py (+279 -0)

@@ -0,0 +1,279 @@
# based on:
# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_irse.py
from collections import namedtuple

from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
MaxPool2d, Module, PReLU, Sequential)

from .common import Flatten, SEModule, initialize_weights


class BasicBlockIR(Module):
""" BasicBlock for IRNet
"""

def __init__(self, in_channel, depth, stride):
super(BasicBlockIR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
BatchNorm2d(depth), PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth))

def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)

return res + shortcut


class BottleneckIR(Module):
""" BasicBlock with bottleneck for IRNet
"""

def __init__(self, in_channel, depth, stride):
super(BottleneckIR, self).__init__()
reduction_channel = depth // 4
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(
in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False),
BatchNorm2d(reduction_channel), PReLU(reduction_channel),
Conv2d(
reduction_channel,
reduction_channel, (3, 3), (1, 1),
1,
bias=False), BatchNorm2d(reduction_channel),
PReLU(reduction_channel),
Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False),
BatchNorm2d(depth))

def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)

return res + shortcut


class BasicBlockIRSE(BasicBlockIR):

def __init__(self, in_channel, depth, stride):
super(BasicBlockIRSE, self).__init__(in_channel, depth, stride)
self.res_layer.add_module('se_block', SEModule(depth, 16))


class BottleneckIRSE(BottleneckIR):

def __init__(self, in_channel, depth, stride):
super(BottleneckIRSE, self).__init__(in_channel, depth, stride)
self.res_layer.add_module('se_block', SEModule(depth, 16))


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
'''A named tuple describing a ResNet block.'''


def get_block(in_channel, depth, num_units, stride=2):

return [Bottleneck(in_channel, depth, stride)] +\
[Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
if num_layers == 18:
blocks = [
get_block(in_channel=64, depth=64, num_units=2),
get_block(in_channel=64, depth=128, num_units=2),
get_block(in_channel=128, depth=256, num_units=2),
get_block(in_channel=256, depth=512, num_units=2)
]
elif num_layers == 34:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=6),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 50:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 100:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 152:
blocks = [
get_block(in_channel=64, depth=256, num_units=3),
get_block(in_channel=256, depth=512, num_units=8),
get_block(in_channel=512, depth=1024, num_units=36),
get_block(in_channel=1024, depth=2048, num_units=3)
]
elif num_layers == 200:
blocks = [
get_block(in_channel=64, depth=256, num_units=3),
get_block(in_channel=256, depth=512, num_units=24),
get_block(in_channel=512, depth=1024, num_units=36),
get_block(in_channel=1024, depth=2048, num_units=3)
]

return blocks


class Backbone(Module):

def __init__(self, input_size, num_layers, mode='ir'):
""" Args:
input_size: input_size of backbone
num_layers: num_layers of backbone
mode: support ir or irse
"""
super(Backbone, self).__init__()
assert input_size[0] in [112, 224], \
'input_size should be [112, 112] or [224, 224]'
assert num_layers in [18, 34, 50, 100, 152, 200], \
'num_layers should be 18, 34, 50, 100, 152 or 200'
assert mode in ['ir', 'ir_se'], \
'mode should be ir or ir_se'
self.input_layer = Sequential(
Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
PReLU(64))
blocks = get_blocks(num_layers)
if num_layers <= 100:
if mode == 'ir':
unit_module = BasicBlockIR
elif mode == 'ir_se':
unit_module = BasicBlockIRSE
output_channel = 512
else:
if mode == 'ir':
unit_module = BottleneckIR
elif mode == 'ir_se':
unit_module = BottleneckIRSE
output_channel = 2048

if input_size[0] == 112:
self.output_layer = Sequential(
BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
Linear(output_channel * 7 * 7, 512),
BatchNorm1d(512, affine=False))
else:
self.output_layer = Sequential(
BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
Linear(output_channel * 14 * 14, 512),
BatchNorm1d(512, affine=False))

modules = []
for block in blocks:
for bottleneck in block:
modules.append(
unit_module(bottleneck.in_channel, bottleneck.depth,
bottleneck.stride))
self.body = Sequential(*modules)

initialize_weights(self.modules())

def forward(self, x):
x = self.input_layer(x)
x = self.body(x)
x = self.output_layer(x)
return x


def IR_18(input_size):
""" Constructs a ir-18 model.
"""
model = Backbone(input_size, 18, 'ir')

return model


def IR_34(input_size):
""" Constructs a ir-34 model.
"""
model = Backbone(input_size, 34, 'ir')

return model


def IR_50(input_size):
""" Constructs a ir-50 model.
"""
model = Backbone(input_size, 50, 'ir')

return model


def IR_101(input_size):
""" Constructs a ir-101 model.
"""
model = Backbone(input_size, 100, 'ir')

return model


def IR_152(input_size):
""" Constructs a ir-152 model.
"""
model = Backbone(input_size, 152, 'ir')

return model


def IR_200(input_size):
""" Constructs a ir-200 model.
"""
model = Backbone(input_size, 200, 'ir')

return model


def IR_SE_50(input_size):
""" Constructs a ir_se-50 model.
"""
model = Backbone(input_size, 50, 'ir_se')

return model


def IR_SE_101(input_size):
""" Constructs a ir_se-101 model.
"""
model = Backbone(input_size, 100, 'ir_se')

return model


def IR_SE_152(input_size):
""" Constructs a ir_se-152 model.
"""
model = Backbone(input_size, 152, 'ir_se')

return model


def IR_SE_200(input_size):
""" Constructs a ir_se-200 model.
"""
model = Backbone(input_size, 200, 'ir_se')

return model

modelscope/models/cv/face_recognition/torchkit/backbone/model_resnet.py (+162 -0)

@@ -0,0 +1,162 @@
# based on:
# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_resnet.py
import torch.nn as nn
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
MaxPool2d, Module, ReLU, Sequential)

from .common import initialize_weights


def conv3x3(in_planes, out_planes, stride=1):
""" 3x3 convolution with padding
"""
return Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)


def conv1x1(in_planes, out_planes, stride=1):
""" 1x1 convolution
"""
return Conv2d(
in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(Module):
expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = BatchNorm2d(planes)
self.conv2 = conv3x3(planes, planes, stride)
self.bn2 = BatchNorm2d(planes)
self.conv3 = conv1x1(planes, planes * self.expansion)
self.bn3 = BatchNorm2d(planes * self.expansion)
self.relu = ReLU(inplace=True)
self.downsample = downsample
self.stride = stride

def forward(self, x):
identity = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)

out = self.conv3(out)
out = self.bn3(out)

if self.downsample is not None:
identity = self.downsample(x)

out += identity
out = self.relu(out)

return out


class ResNet(Module):
""" ResNet backbone
"""

def __init__(self, input_size, block, layers, zero_init_residual=True):
""" Args:
input_size: input_size of backbone
block: block function
layers: layers in each block
"""
super(ResNet, self).__init__()
assert input_size[0] in [112, 224],\
'input_size should be [112, 112] or [224, 224]'
self.inplanes = 64
self.conv1 = Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = BatchNorm2d(64)
self.relu = ReLU(inplace=True)
self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

self.bn_o1 = BatchNorm2d(2048)
self.dropout = Dropout()
if input_size[0] == 112:
self.fc = Linear(2048 * 4 * 4, 512)
else:
self.fc = Linear(2048 * 7 * 7, 512)
self.bn_o2 = BatchNorm1d(512)

initialize_weights(self.modules())
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)

def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
BatchNorm2d(planes * block.expansion),
)

layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))

return Sequential(*layers)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.bn_o1(x)
x = self.dropout(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
x = self.bn_o2(x)

return x


def ResNet_50(input_size, **kwargs):
""" Constructs a ResNet-50 model.
"""
model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs)

return model


def ResNet_101(input_size, **kwargs):
""" Constructs a ResNet-101 model.
"""
model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs)

return model


def ResNet_152(input_size, **kwargs):
""" Constructs a ResNet-152 model.
"""
model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs)

return model

modelscope/outputs.py (+26 -0)

@@ -13,6 +13,7 @@ class OutputKeys(object):
POSES = 'poses'
CAPTION = 'caption'
BOXES = 'boxes'
KEYPOINTS = 'keypoints'
MASKS = 'masks'
TEXT = 'text'
POLYGONS = 'polygons'
@@ -55,6 +56,31 @@ TASK_OUTPUTS = {
Tasks.object_detection:
[OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES],

# face detection result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]
# "boxes": [
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# ],
# "keypoints": [
# [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5],
# [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5],
# [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5],
# [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5],
# ],
# }
Tasks.face_detection:
[OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS],

# face recognition result for single sample
# {
# "img_embedding": np.array with shape [1, D],
# }
Tasks.face_recognition: [OutputKeys.IMG_EMBEDDING],

# instance segmentation result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05],


modelscope/pipelines/base.py (+5 -1)

@@ -255,7 +255,11 @@ class Pipeline(ABC):
elif isinstance(data, InputFeatures):
return data
else:
raise ValueError(f'Unsupported data type {type(data)}')
import mmcv
if isinstance(data, mmcv.parallel.data_container.DataContainer):
return data
else:
raise ValueError(f'Unsupported data type {type(data)}')

def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
preprocess_params = kwargs.get('preprocess_params')


modelscope/pipelines/builder.py (+4 -0)

@@ -80,6 +80,10 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.text_to_image_synthesis:
(Pipelines.text_to_image_synthesis,
'damo/cv_imagen_text-to-image-synthesis_tiny'),
Tasks.face_detection: (Pipelines.face_detection,
'damo/cv_resnet_facedetection_scrfd10gkps'),
Tasks.face_recognition: (Pipelines.face_recognition,
'damo/cv_ir101_facerecognition_cfglint'),
Tasks.video_multi_modal_embedding:
(Pipelines.video_multi_modal_embedding,
'damo/multi_modal_clip_vtretrival_msrvtt_53'),
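With these defaults in place, both new tasks are constructible without naming a model explicitly; a minimal sketch (the model is fetched from the hub on first use):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_det = pipeline(Tasks.face_detection)   # resolves to the SCRFD default above
result = face_det('data/test/images/face_detection.png')
print(result[OutputKeys.SCORES], result[OutputKeys.BOXES],
      result[OutputKeys.KEYPOINTS])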


modelscope/pipelines/cv/__init__.py (+23 -17)

@@ -5,44 +5,50 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .action_recognition_pipeline import ActionRecognitionPipeline
from .animal_recog_pipeline import AnimalRecogPipeline
from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
from .live_category_pipeline import LiveCategoryPipeline
from .image_classification_pipeline import GeneralImageClassificationPipeline
from .animal_recognition_pipeline import AnimalRecognitionPipeline
from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
from .face_detection_pipeline import FaceDetectionPipeline
from .face_recognition_pipeline import FaceRecognitionPipeline
from .face_image_generation_pipeline import FaceImageGenerationPipeline
from .image_cartoon_pipeline import ImageCartoonPipeline
from .image_classification_pipeline import GeneralImageClassificationPipeline
from .image_denoise_pipeline import ImageDenoisePipeline
from .image_color_enhance_pipeline import ImageColorEnhancePipeline
from .image_colorization_pipeline import ImageColorizationPipeline
from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
from .video_category_pipeline import VideoCategoryPipeline
from .image_matting_pipeline import ImageMattingPipeline
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
from .style_transfer_pipeline import StyleTransferPipeline
from .live_category_pipeline import LiveCategoryPipeline
from .ocr_detection_pipeline import OCRDetectionPipeline
from .video_category_pipeline import VideoCategoryPipeline
from .virtual_tryon_pipeline import VirtualTryonPipeline
else:
_import_structure = {
'action_recognition_pipeline': ['ActionRecognitionPipeline'],
'animal_recog_pipeline': ['AnimalRecogPipeline'],
'cmdssl_video_embedding_pipleline': ['CMDSSLVideoEmbeddingPipeline'],
'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
'face_detection_pipeline': ['FaceDetectionPipeline'],
'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
'face_recognition_pipeline': ['FaceRecognitionPipeline'],
'image_classification_pipeline':
['GeneralImageClassificationPipeline'],
'image_cartoon_pipeline': ['ImageCartoonPipeline'],
'image_denoise_pipeline': ['ImageDenoisePipeline'],
'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'],
'virtual_tryon_pipeline': ['VirtualTryonPipeline'],
'image_colorization_pipeline': ['ImageColorizationPipeline'],
'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'],
'image_denoise_pipeline': ['ImageDenoisePipeline'],
'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
'image_cartoon_pipeline': ['ImageCartoonPipeline'],
'image_matting_pipeline': ['ImageMattingPipeline'],
'style_transfer_pipeline': ['StyleTransferPipeline'],
'ocr_detection_pipeline': ['OCRDetectionPipeline'],
'image_instance_segmentation_pipeline':
['ImageInstanceSegmentationPipeline'],
'video_category_pipeline': ['VideoCategoryPipeline'],
'image_matting_pipeline': ['ImageMattingPipeline'],
'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'],
'image_to_image_translation_pipeline':
['Image2ImageTranslationPipeline'],
'live_category_pipeline': ['LiveCategoryPipeline'],
'ocr_detection_pipeline': ['OCRDetectionPipeline'],
'style_transfer_pipeline': ['StyleTransferPipeline'],
'video_category_pipeline': ['VideoCategoryPipeline'],
'virtual_tryon_pipeline': ['VirtualTryonPipeline'],
}

import sys
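
The TYPE_CHECKING branch keeps real imports visible to static type checkers, while at runtime `_import_structure` feeds a LazyImportModule, so heavy CV dependencies (mmcv, mmdet) are imported only when a pipeline class is first accessed. A sketch of how the module tail typically wires this up, following the LazyImportModule pattern used elsewhere in modelscope (the exact keyword arguments here are an assumption):

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )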


+ 1
- 1
modelscope/pipelines/cv/action_recognition_pipeline.py View File

@@ -23,7 +23,7 @@ class ActionRecognitionPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
use `model` to create an action recognition pipeline for prediction
Args:
model: model id on modelscope hub.
"""


modelscope/pipelines/cv/animal_recog_pipeline.py → modelscope/pipelines/cv/animal_recognition_pipeline.py View File

@@ -22,11 +22,11 @@ logger = get_logger()

@PIPELINES.register_module(
Tasks.image_classification, module_name=Pipelines.animal_recognation)
class AnimalRecogPipeline(Pipeline):
class AnimalRecognitionPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
use `model` to create an animal recognition pipeline for prediction
Args:
model: model id on modelscope hub.
"""

modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py → modelscope/pipelines/cv/cmdssl_video_embedding_pipeline.py View File

@@ -24,7 +24,7 @@ class CMDSSLVideoEmbeddingPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
use `model` to create a CMDSSL Video Embedding pipeline for prediction
Args:
model: model id on modelscope hub.
"""

+ 105
- 0
modelscope/pipelines/cv/face_detection_pipeline.py View File

@@ -0,0 +1,105 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.face_detection, module_name=Pipelines.face_detection)
class FaceDetectionPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` to create a face detection pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from mmdet.models import build_detector
from modelscope.models.cv.face_detection.mmdet_patch.datasets import RetinaFaceDataset
from modelscope.models.cv.face_detection.mmdet_patch.datasets.pipelines import RandomSquareCrop
from modelscope.models.cv.face_detection.mmdet_patch.models.backbones import ResNetV1e
from modelscope.models.cv.face_detection.mmdet_patch.models.dense_heads import SCRFDHead
from modelscope.models.cv.face_detection.mmdet_patch.models.detectors import SCRFD
cfg = Config.fromfile(osp.join(model, 'mmcv_scrfd_10g_bnkps.py'))
detector = build_detector(
cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE)
logger.info(f'loading model from {ckpt_path}')
device = torch.device(
f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
load_checkpoint(detector, ckpt_path, map_location=device)
detector = MMDataParallel(detector, device_ids=[0])
detector.eval()
self.detector = detector
logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input)
img = img.astype(np.float32)
pre_pipeline = [
dict(
type='MultiScaleFlipAug',
img_scale=(640, 640),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.0),
dict(
type='Normalize',
mean=[127.5, 127.5, 127.5],
std=[128.0, 128.0, 128.0],
to_rgb=False),
dict(type='Pad', size=(640, 640), pad_val=0),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
from mmdet.datasets.pipelines import Compose
pipeline = Compose(pre_pipeline)
result = {}
result['filename'] = ''
result['ori_filename'] = ''
result['img'] = img
result['img_shape'] = img.shape
result['ori_shape'] = img.shape
result['img_fields'] = ['img']
result = pipeline(result)
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

result = self.detector(
return_loss=False,
rescale=True,
img=[input['img'][0].unsqueeze(0)],
img_metas=[[dict(input['img_metas'][0].data)]])
assert result is not None
result = result[0][0]
bboxes = result[:, :4].tolist()
kpss = result[:, 5:].tolist()
scores = result[:, 4].tolist()
return {
OutputKeys.SCORES: scores,
OutputKeys.BOXES: bboxes,
OutputKeys.KEYPOINTS: kpss
}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
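
Each row of the detector output packs 15 values per face: x1, y1, x2, y2, a confidence score, and five (x, y) landmark pairs, which `forward` splits into the BOXES, SCORES and KEYPOINTS lists above. A minimal consumption sketch (the image path reuses the repo's test asset):

    import numpy as np

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    face_detection = pipeline(
        Tasks.face_detection, model='damo/cv_resnet_facedetection_scrfd10gkps')
    result = face_detection('data/test/images/face_detection.png')

    boxes = np.array(result[OutputKeys.BOXES])     # (N, 4): x1, y1, x2, y2
    scores = np.array(result[OutputKeys.SCORES])   # (N,)
    kpss = np.array(result[OutputKeys.KEYPOINTS])  # (N, 10): 5 landmarks as x, y pairs
    for box, score in zip(boxes, scores):
        print(f'face at {box.astype(int).tolist()}, confidence {score:.2f}')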

+ 1
- 1
modelscope/pipelines/cv/face_image_generation_pipeline.py View File

@@ -24,7 +24,7 @@ class FaceImageGenerationPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` to create a kws pipeline for prediction
use `model` to create a face image generation pipeline for prediction
Args:
model: model id on modelscope hub.
"""


+ 130
- 0
modelscope/pipelines/cv/face_recognition_pipeline.py View File

@@ -0,0 +1,130 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_recognition.align_face import align_face
from modelscope.models.cv.face_recognition.torchkit.backbone import get_model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.face_recognition, module_name=Pipelines.face_recognition)
class FaceRecognitionPipeline(Pipeline):

def __init__(self, model: str, face_detection: Pipeline, **kwargs):
"""
use `model` to create a face recognition pipeline for prediction
Args:
model: model id on modelscope hub.
face_detection: pipeline used for face detection and alignment before recognition
"""

# face recognition model
super().__init__(model=model, **kwargs)
device = torch.device(
f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
self.device = device
face_model = get_model('IR_101')([112, 112])
face_model.load_state_dict(
torch.load(
osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE),
map_location=device))
face_model = face_model.to(device)
face_model.eval()
self.face_model = face_model
logger.info('face recognition model loaded!')
# face detect pipeline
self.face_detection = face_detection

    def _choose_face(self,
                     det_result,
                     min_face=10,
                     top_face=1,
                     center_face=False,
                     img_shape=None):
        '''
        choose the face with the largest area
        Args:
            det_result: output of the face detection pipeline
            min_face: minimum valid face width/height in pixels
            top_face: keep the faces with the top-N largest areas
            center_face: choose the face closest to the image center; only valid if top_face > 1
            img_shape: (height, width) of the source image; required when center_face is True
        '''
bboxes = np.array(det_result[OutputKeys.BOXES])
landmarks = np.array(det_result[OutputKeys.KEYPOINTS])
# scores = np.array(det_result[OutputKeys.SCORES])
if bboxes.shape[0] == 0:
logger.info('No face detected!')
return None
# face idx with enough size
face_idx = []
for i in range(bboxes.shape[0]):
box = bboxes[i]
if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face:
face_idx += [i]
if len(face_idx) == 0:
            logger.info(
                f'All detected faces are smaller than {min_face}x{min_face}!')
return None
bboxes = bboxes[face_idx]
landmarks = landmarks[face_idx]
# find max faces
boxes = np.array(bboxes)
area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
sort_idx = np.argsort(area)[-top_face:]
# find center face
        if top_face > 1 and center_face and bboxes.shape[0] > 1:
            assert img_shape is not None, 'img_shape is required when center_face is True'
            img_center = [img_shape[1] // 2, img_shape[0] // 2]
min_dist = float('inf')
sel_idx = -1
for _idx in sort_idx:
box = boxes[_idx]
dist = np.square(
np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square(
np.abs((box[1] + box[3]) / 2 - img_center[1]))
if dist < min_dist:
min_dist = dist
sel_idx = _idx
sort_idx = [sel_idx]
main_idx = sort_idx[-1]
return bboxes[main_idx], landmarks[main_idx]

def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input)
img = img[:, :, ::-1]
det_result = self.face_detection(img.copy())
        rtn = self._choose_face(det_result, img_shape=img.shape)
face_img = None
if rtn is not None:
_, face_lmks = rtn
face_lmks = face_lmks.reshape(5, 2)
align_img, _ = align_face(img, (112, 112), face_lmks)
face_img = align_img[:, :, ::-1] # to rgb
face_img = np.transpose(face_img, axes=(2, 0, 1))
face_img = (face_img / 255. - 0.5) / 0.5
face_img = face_img.astype(np.float32)
result = {}
result['img'] = face_img
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
assert input['img'] is not None
img = input['img'].unsqueeze(0)
emb = self.face_model(img).detach().cpu().numpy()
emb /= np.sqrt(np.sum(emb**2, -1, keepdims=True)) # l2 norm
return {OutputKeys.IMG_EMBEDDING: emb}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
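
Because `forward` l2-normalizes the embedding, the cosine similarity of two faces reduces to a plain dot product; the similarity threshold for a same-identity decision is model- and data-dependent. An end-to-end sketch of comparing two images (the unit test further below exercises the same flow):

    import numpy as np

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    face_detection = pipeline(
        Tasks.face_detection, model='damo/cv_resnet_facedetection_scrfd10gkps')
    face_recognition = pipeline(
        Tasks.face_recognition,
        face_detection=face_detection,
        model='damo/cv_ir101_facerecognition_cfglint')

    emb1 = face_recognition('data/test/images/face_recognition_1.png')[
        OutputKeys.IMG_EMBEDDING]
    emb2 = face_recognition('data/test/images/face_recognition_2.png')[
        OutputKeys.IMG_EMBEDDING]
    # embeddings are unit-length, so the dot product equals cosine similarity
    sim = float(np.dot(emb1[0], emb2[0]))
    print(f'cosine similarity: {sim:.3f}')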

+ 1
- 1
modelscope/pipelines/cv/image_cartoon_pipeline.py View File

@@ -30,7 +30,7 @@ class ImageCartoonPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
use `model` to create an image cartoon pipeline for prediction
Args:
model: model id on modelscope hub.
"""


+ 1
- 1
modelscope/pipelines/cv/image_color_enhance_pipeline.py View File

@@ -27,7 +27,7 @@ class ImageColorEnhancePipeline(Pipeline):
ImageColorEnhanceFinetunePreprocessor] = None,
**kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
use `model` and `preprocessor` to create an image color-enhance pipeline for prediction
Args:
model: model id on modelscope hub.
"""


+ 1
- 1
modelscope/pipelines/cv/image_colorization_pipeline.py View File

@@ -25,7 +25,7 @@ class ImageColorizationPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` to create a kws pipeline for prediction
use `model` to create an image colorization pipeline for prediction
Args:
model: model id on modelscope hub.
"""


+ 1
- 1
modelscope/pipelines/cv/image_matting_pipeline.py View File

@@ -21,7 +21,7 @@ class ImageMattingPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
use `model` to create an image matting pipeline for prediction
Args:
model: model id on modelscope hub.
"""


+ 1
- 1
modelscope/pipelines/cv/image_super_resolution_pipeline.py View File

@@ -23,7 +23,7 @@ class ImageSuperResolutionPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` to create a kws pipeline for prediction
use `model` to create an image super-resolution pipeline for prediction
Args:
model: model id on modelscope hub.
"""


+ 1
- 1
modelscope/pipelines/cv/ocr_detection_pipeline.py View File

@@ -41,7 +41,7 @@ class OCRDetectionPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
use `model` to create an OCR detection pipeline for prediction
Args:
model: model id on modelscope hub.
"""


+ 1
- 1
modelscope/pipelines/cv/style_transfer_pipeline.py View File

@@ -21,7 +21,7 @@ class StyleTransferPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
use `model` to create a style transfer pipeline for prediction
Args:
model: model id on modelscope hub.
"""


+ 1
- 1
modelscope/pipelines/cv/virtual_tryon_pipeline.py View File

@@ -25,7 +25,7 @@ class VirtualTryonPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` to create a kws pipeline for prediction
use `model` to create a virtual tryon pipeline for prediction
Args:
model: model id on modelscope hub.
"""


+ 2
- 0
modelscope/utils/constant.py View File

@@ -28,6 +28,8 @@ class CVTasks(object):
ocr_detection = 'ocr-detection'
action_recognition = 'action-recognition'
video_embedding = 'video-embedding'
face_detection = 'face-detection'
face_recognition = 'face-recognition'
image_color_enhance = 'image-color-enhance'
virtual_tryon = 'virtual-tryon'
image_colorization = 'image-colorization'
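
These string values are the identifiers the `pipeline()` factory matches against, so a task can be referenced either through the constant or its raw name; a small sketch:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # the raw string and the constant are interchangeable
    assert Tasks.face_detection == 'face-detection'
    face_detection = pipeline('face-detection')  # same as pipeline(Tasks.face_detection)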


+ 84
- 0
tests/pipelines/test_face_detection.py View File

@@ -0,0 +1,84 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import tempfile
import unittest

import cv2
import numpy as np

from modelscope.fileio import File
from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class FaceDetectionTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_resnet_facedetection_scrfd10gkps'

def show_result(self, img_path, bboxes, kpss, scores):
bboxes = np.array(bboxes)
kpss = np.array(kpss)
scores = np.array(scores)
img = cv2.imread(img_path)
assert img is not None, f"Can't read img: {img_path}"
for i in range(len(scores)):
bbox = bboxes[i].astype(np.int32)
kps = kpss[i].reshape(-1, 2).astype(np.int32)
score = scores[i]
x1, y1, x2, y2 = bbox
cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
for kp in kps:
cv2.circle(img, tuple(kp), 1, (0, 0, 255), 1)
cv2.putText(
img,
f'{score:.2f}', (x1, y2),
1,
1.0, (0, 255, 0),
thickness=1,
lineType=8)
cv2.imwrite('result.png', img)
print(
f'Found {len(scores)} faces, output written to {osp.abspath("result.png")}'
)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_dataset(self):
input_location = ['data/test/images/face_detection.png']
# alternatively:
# input_location = '/dir/to/images'

dataset = MsDataset.load(input_location, target='image')
face_detection = pipeline(Tasks.face_detection, model=self.model_id)
# note that for dataset output, the inference-output is a Generator that can be iterated.
result = face_detection(dataset)
result = next(result)
self.show_result(input_location[0], result[OutputKeys.BOXES],
result[OutputKeys.KEYPOINTS],
result[OutputKeys.SCORES])

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
face_detection = pipeline(Tasks.face_detection, model=self.model_id)
img_path = 'data/test/images/face_detection.png'

result = face_detection(img_path)
self.show_result(img_path, result[OutputKeys.BOXES],
result[OutputKeys.KEYPOINTS],
result[OutputKeys.SCORES])

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
face_detection = pipeline(Tasks.face_detection)
img_path = 'data/test/images/face_detection.png'
result = face_detection(img_path)
self.show_result(img_path, result[OutputKeys.BOXES],
result[OutputKeys.KEYPOINTS],
result[OutputKeys.SCORES])


if __name__ == '__main__':
unittest.main()

+ 42
- 0
tests/pipelines/test_face_recognition.py View File

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import tempfile
import unittest

import cv2
import numpy as np

from modelscope.fileio import File
from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class FaceRecognitionTest(unittest.TestCase):

def setUp(self) -> None:
self.recog_model_id = 'damo/cv_ir101_facerecognition_cfglint'
self.det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps'

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_face_compare(self):
img1 = 'data/test/images/face_recognition_1.png'
img2 = 'data/test/images/face_recognition_2.png'

face_detection = pipeline(
Tasks.face_detection, model=self.det_model_id)
face_recognition = pipeline(
Tasks.face_recognition,
face_detection=face_detection,
model=self.recog_model_id)
emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
sim = np.dot(emb1[0], emb2[0])
print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')


if __name__ == '__main__':
unittest.main()
