1. Adjust the file hierarchy of face_detection (scrfd now sits at the same level as the other newly added face_detection methods); 2. add detection of extremely large and rotated faces, and update the new model; 3. support loading a dataset for finetuning and evaluation; 4. add a card_detection model that supports finetuning on datasethub datasets. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10244540 (branch: master)
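
A minimal sketch of how the new finetune/eval entry points might be driven from user code. The trainer name comes from Trainers.card_detection_scrfd below; the model id and dataset name are placeholders, not values taken from this diff:

from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer

# placeholder dataset; any datasethub detection dataset with the expected
# annotation format should work here
train_set = MsDataset.load('some_card_detection_dataset', split='train')
val_set = MsDataset.load('some_card_detection_dataset', split='validation')

kwargs = dict(
    model='damo/cv_resnet_carddetection_scrfd34gkps',  # placeholder model id
    train_dataset=train_set,
    eval_dataset=val_set,
    work_dir='./card_detection_finetune')
trainer = build_trainer(name='card-detection-scrfd', default_args=kwargs)
trainer.train()
metrics = trainer.evaluate()
print(metrics)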
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ecbc9d0827cfb92e93e7d75868b1724142685dc20d3b32023c3c657a7b688a9c
size 254845
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d510ab26ddc58ffea882c8ef850c1f9bd4444772f2bce7ebea3e76944536c3ae
size 48909
@@ -148,6 +148,7 @@ class Pipelines(object):
     salient_detection = 'u2net-salient-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
+    card_detection = 'resnet-card-detection-scrfd34gkps'
     ulfd_face_detection = 'manual-face-detection-ulfd'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
     retina_face_detection = 'resnet50-face-detection-retinaface'
@@ -270,6 +271,8 @@ class Trainers(object):
     image_portrait_enhancement = 'image-portrait-enhancement'
     video_summarization = 'video-summarization'
     movie_scene_segmentation = 'movie-scene-segmentation'
+    face_detection_scrfd = 'face-detection-scrfd'
+    card_detection_scrfd = 'card-detection-scrfd'
     image_inpainting = 'image-inpainting'

     # nlp trainers
@@ -8,12 +8,14 @@ if TYPE_CHECKING:
     from .mtcnn import MtcnnFaceDetector
     from .retinaface import RetinaFaceDetection
     from .ulfd_slim import UlfdFaceDetector
+    from .scrfd import ScrfdDetect
 else:
     _import_structure = {
         'ulfd_slim': ['UlfdFaceDetector'],
         'retinaface': ['RetinaFaceDetection'],
         'mtcnn': ['MtcnnFaceDetector'],
-        'mogface': ['MogFaceDetector']
+        'mogface': ['MogFaceDetector'],
+        'scrfd': ['ScrfdDetect']
     }

     import sys
@@ -1,189 +0,0 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py
"""
import numpy as np
from mmdet.datasets.builder import PIPELINES
from numpy import random


@PIPELINES.register_module()
class RandomSquareCrop(object):
    """Random crop the image & bboxes, the cropped patches have minimum IoU
    requirement with original image & bboxes, the IoU threshold is randomly
    selected from min_ious.

    Args:
        min_ious (tuple): minimum IoU threshold for all intersections with
            bounding boxes
        min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
            where a >= min_crop_size).

    Note:
        The keys for bboxes, labels and masks should be paired. That is, \
        `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \
        `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`.
    """

    def __init__(self,
                 crop_ratio_range=None,
                 crop_choice=None,
                 bbox_clip_border=True):
        self.crop_ratio_range = crop_ratio_range
        self.crop_choice = crop_choice
        self.bbox_clip_border = bbox_clip_border

        assert (self.crop_ratio_range is None) ^ (self.crop_choice is None)
        if self.crop_ratio_range is not None:
            self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range

        self.bbox2label = {
            'gt_bboxes': 'gt_labels',
            'gt_bboxes_ignore': 'gt_labels_ignore'
        }
        self.bbox2mask = {
            'gt_bboxes': 'gt_masks',
            'gt_bboxes_ignore': 'gt_masks_ignore'
        }

    def __call__(self, results):
        """Call function to crop images and bounding boxes with minimum IoU
        constraint.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images and bounding boxes cropped, \
                'img_shape' key is updated.
        """
        if 'img_fields' in results:
            assert results['img_fields'] == ['img'], \
                'Only single img_fields is allowed'
        img = results['img']
        assert 'bbox_fields' in results
        assert 'gt_bboxes' in results
        boxes = results['gt_bboxes']
        h, w, c = img.shape
        scale_retry = 0
        if self.crop_ratio_range is not None:
            max_scale = self.crop_ratio_max
        else:
            max_scale = np.amax(self.crop_choice)
        while True:
            scale_retry += 1
            if scale_retry == 1 or max_scale > 1.0:
                if self.crop_ratio_range is not None:
                    scale = np.random.uniform(self.crop_ratio_min,
                                              self.crop_ratio_max)
                elif self.crop_choice is not None:
                    scale = np.random.choice(self.crop_choice)
            else:
                scale = scale * 1.2
            for i in range(250):
                short_side = min(w, h)
                cw = int(scale * short_side)
                ch = cw
                # TODO +1
                if w == cw:
                    left = 0
                elif w > cw:
                    left = random.randint(0, w - cw)
                else:
                    left = random.randint(w - cw, 0)
                if h == ch:
                    top = 0
                elif h > ch:
                    top = random.randint(0, h - ch)
                else:
                    top = random.randint(h - ch, 0)

                patch = np.array(
                    (int(left), int(top), int(left + cw), int(top + ch)),
                    dtype=np.int)

                # center of boxes should inside the crop img
                # only adjust boxes and instance masks when the gt is not empty
                # adjust boxes
                def is_center_of_bboxes_in_patch(boxes, patch):
                    # TODO >=
                    center = (boxes[:, :2] + boxes[:, 2:]) / 2
                    mask = \
                        ((center[:, 0] > patch[0])
                         * (center[:, 1] > patch[1])
                         * (center[:, 0] < patch[2])
                         * (center[:, 1] < patch[3]))
                    return mask

                mask = is_center_of_bboxes_in_patch(boxes, patch)
                if not mask.any():
                    continue
                for key in results.get('bbox_fields', []):
                    boxes = results[key].copy()
                    mask = is_center_of_bboxes_in_patch(boxes, patch)
                    boxes = boxes[mask]
                    if self.bbox_clip_border:
                        boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
                        boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
                    boxes -= np.tile(patch[:2], 2)
                    results[key] = boxes
                    # labels
                    label_key = self.bbox2label.get(key)
                    if label_key in results:
                        results[label_key] = results[label_key][mask]
                    # keypoints field
                    if key == 'gt_bboxes':
                        for kps_key in results.get('keypoints_fields', []):
                            keypointss = results[kps_key].copy()
                            keypointss = keypointss[mask, :, :]
                            if self.bbox_clip_border:
                                keypointss[:, :, :
                                           2] = keypointss[:, :, :2].clip(
                                               max=patch[2:])
                                keypointss[:, :, :
                                           2] = keypointss[:, :, :2].clip(
                                               min=patch[:2])
                            keypointss[:, :, 0] -= patch[0]
                            keypointss[:, :, 1] -= patch[1]
                            results[kps_key] = keypointss
                    # mask fields
                    mask_key = self.bbox2mask.get(key)
                    if mask_key in results:
                        results[mask_key] = results[mask_key][mask.nonzero()
                                                              [0]].crop(patch)

                # adjust the img no matter whether the gt is empty before crop
                rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128
                patch_from = patch.copy()
                patch_from[0] = max(0, patch_from[0])
                patch_from[1] = max(0, patch_from[1])
                patch_from[2] = min(img.shape[1], patch_from[2])
                patch_from[3] = min(img.shape[0], patch_from[3])
                patch_to = patch.copy()
                patch_to[0] = max(0, patch_to[0] * -1)
                patch_to[1] = max(0, patch_to[1] * -1)
                patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0])
                patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1])
                rimg[patch_to[1]:patch_to[3],
                     patch_to[0]:patch_to[2], :] = img[
                         patch_from[1]:patch_from[3],
                         patch_from[0]:patch_from[2], :]
                img = rimg
                results['img'] = img
                results['img_shape'] = img.shape

                return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(min_ious={self.min_iou}, '
        repr_str += f'crop_size={self.crop_size})'
        return repr_str
@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .scrfd_detect import ScrfdDetect
@@ -6,7 +6,7 @@ import numpy as np
 import torch


-def bbox2result(bboxes, labels, num_classes, kps=None):
+def bbox2result(bboxes, labels, num_classes, kps=None, num_kps=5):
     """Convert detection results to a list of numpy arrays.

     Args:
@@ -17,7 +17,7 @@ def bbox2result(bboxes, labels, num_classes, kps=None):
     Returns:
         list(ndarray): bbox results of each class
     """
-    bbox_len = 5 if kps is None else 5 + 10  # if has kps, add 10 kps values into bbox
+    bbox_len = 5 if kps is None else 5 + num_kps * 2  # if has kps, add num_kps*2 kps values into bbox
     if bboxes.shape[0] == 0:
         return [
             np.zeros((0, bbox_len), dtype=np.float32)
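
A small sketch of what the generalized bbox_len means for the per-class result arrays; the row layout assumed here (x1, y1, x2, y2, score, then num_kps pairs of x, y) follows the comment above:

import numpy as np

def empty_results(num_classes, kps=None, num_kps=5):
    # mirrors the bbox_len logic in the hunk above
    bbox_len = 5 if kps is None else 5 + num_kps * 2
    return [
        np.zeros((0, bbox_len), dtype=np.float32)
        for _ in range(num_classes)
    ]

# faces keep 5 keypoints (row length 15); cards can use 4 corners (row length 13)
assert empty_results(1, kps=True, num_kps=5)[0].shape == (0, 15)
assert empty_results(1, kps=True, num_kps=4)[0].shape == (0, 13)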
@@ -17,6 +17,7 @@ def multiclass_nms(multi_bboxes,
     Args:
         multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
+        multi_kps (Tensor): shape (n, #class*num_kps*2) or (n, num_kps*2)
         multi_scores (Tensor): shape (n, #class), where the last column
             contains scores of the background class, but this will be ignored.
         score_thr (float): bbox threshold, bboxes with scores lower than it
@@ -36,16 +37,18 @@ def multiclass_nms(multi_bboxes,
     num_classes = multi_scores.size(1) - 1
     # exclude background category
     kps = None
+    if multi_kps is not None:
+        num_kps = int((multi_kps.shape[1] / num_classes) / 2)
     if multi_bboxes.shape[1] > 4:
         bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
         if multi_kps is not None:
-            kps = multi_kps.view(multi_scores.size(0), -1, 10)
+            kps = multi_kps.view(multi_scores.size(0), -1, num_kps * 2)
     else:
         bboxes = multi_bboxes[:, None].expand(
             multi_scores.size(0), num_classes, 4)
         if multi_kps is not None:
             kps = multi_kps[:, None].expand(
-                multi_scores.size(0), num_classes, 10)
+                multi_scores.size(0), num_classes, num_kps * 2)

     scores = multi_scores[:, :-1]
     if score_factors is not None:
@@ -56,7 +59,7 @@ def multiclass_nms(multi_bboxes,
     bboxes = bboxes.reshape(-1, 4)
     if kps is not None:
-        kps = kps.reshape(-1, 10)
+        kps = kps.reshape(-1, num_kps * 2)
     scores = scores.reshape(-1)
     labels = labels.reshape(-1)
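
A quick sketch of the num_kps recovery above, assuming multi_kps packs per-class keypoints as (n, #class * num_kps * 2):

import torch

num_classes = 1
multi_kps = torch.zeros(8, num_classes * 4 * 2)  # e.g. 4 card corners
num_kps = int((multi_kps.shape[1] / num_classes) / 2)
kps = multi_kps.view(8, -1, num_kps * 2)
assert num_kps == 4 and kps.shape == (8, 1, 8)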
@@ -2,6 +2,12 @@
 The implementation here is modified based on insightface, originally MIT license and publicly available at
 https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines
 """
+from .auto_augment import RotateV2
+from .formating import DefaultFormatBundleV2
+from .loading import LoadAnnotationsV2
 from .transforms import RandomSquareCrop

-__all__ = ['RandomSquareCrop']
+__all__ = [
+    'RandomSquareCrop', 'LoadAnnotationsV2', 'RotateV2',
+    'DefaultFormatBundleV2'
+]
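
Once this module is imported, the V2 transforms are registered in mmdet's PIPELINES registry and can be referenced by type name from a config. A hypothetical pipeline fragment (the parameter values are illustrative only, not from this diff):

train_pipeline = [
    dict(type='LoadAnnotationsV2', with_bbox=True, with_keypoints=True),
    dict(type='RotateV2', level=10, max_rotate_angle=90, prob=0.5),
    dict(type='DefaultFormatBundleV2'),
]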
@@ -0,0 +1,271 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/auto_augment.py
"""
import copy

import cv2
import mmcv
import numpy as np
from mmdet.datasets.builder import PIPELINES

_MAX_LEVEL = 10


def level_to_value(level, max_value):
    """Map from level to values based on max_value."""
    return (level / _MAX_LEVEL) * max_value


def random_negative(value, random_negative_prob):
    """Randomly negate value based on random_negative_prob."""
    return -value if np.random.rand() < random_negative_prob else value


def bbox2fields():
    """The key correspondence from bboxes to labels, masks and
    segmentations."""
    bbox2label = {
        'gt_bboxes': 'gt_labels',
        'gt_bboxes_ignore': 'gt_labels_ignore'
    }
    bbox2mask = {
        'gt_bboxes': 'gt_masks',
        'gt_bboxes_ignore': 'gt_masks_ignore'
    }
    bbox2seg = {
        'gt_bboxes': 'gt_semantic_seg',
    }
    return bbox2label, bbox2mask, bbox2seg


@PIPELINES.register_module()
class RotateV2(object):
    """Apply Rotate Transformation to image (and its corresponding bbox, mask,
    segmentation).

    Args:
        level (int | float): The level should be in range (0,_MAX_LEVEL].
        scale (int | float): Isotropic scale factor. Same in
            ``mmcv.imrotate``.
        center (int | float | tuple[float]): Center point (w, h) of the
            rotation in the source image. If None, the center of the
            image will be used. Same in ``mmcv.imrotate``.
        img_fill_val (int | float | tuple): The fill value for image border.
            If float, the same value will be used for all the three
            channels of image. If tuple, it should be 3 elements (e.g.
            equals the number of channels for image).
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Default 255.
        prob (float): The probability for performing the transformation and
            should be in range 0 to 1.
        max_rotate_angle (int | float): The maximum angle for rotate
            transformation.
        random_negative_prob (float): The probability that turns the
            offset negative.
    """

    def __init__(self,
                 level,
                 scale=1,
                 center=None,
                 img_fill_val=128,
                 seg_ignore_label=255,
                 prob=0.5,
                 max_rotate_angle=30,
                 random_negative_prob=0.5):
        assert isinstance(level, (int, float)), \
            f'The level must be type int or float. got {type(level)}.'
        assert 0 <= level <= _MAX_LEVEL, \
            f'The level should be in range (0,{_MAX_LEVEL}]. got {level}.'
        assert isinstance(scale, (int, float)), \
            f'The scale must be type int or float. got type {type(scale)}.'
        if isinstance(center, (int, float)):
            center = (center, center)
        elif isinstance(center, tuple):
            assert len(center) == 2, 'center with type tuple must have '\
                f'2 elements. got {len(center)} elements.'
        else:
            assert center is None, 'center must be None or type int, '\
                f'float or tuple, got type {type(center)}.'
        if isinstance(img_fill_val, (float, int)):
            img_fill_val = tuple([float(img_fill_val)] * 3)
        elif isinstance(img_fill_val, tuple):
            assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\
                f'have 3 elements. got {len(img_fill_val)}.'
            img_fill_val = tuple([float(val) for val in img_fill_val])
        else:
            raise ValueError(
                'img_fill_val must be float or tuple with 3 elements.')
        assert np.all([0 <= val <= 255 for val in img_fill_val]), \
            'all elements of img_fill_val should be within range [0,255]. '\
            f'got {img_fill_val}.'
        assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\
            f'got {prob}.'
        assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\
            f'should be type int or float. got type {type(max_rotate_angle)}.'
        self.level = level
        self.scale = scale
        # Rotation angle in degrees. Positive values mean
        # clockwise rotation.
        self.angle = level_to_value(level, max_rotate_angle)
        self.center = center
        self.img_fill_val = img_fill_val
        self.seg_ignore_label = seg_ignore_label
        self.prob = prob
        self.max_rotate_angle = max_rotate_angle
        self.random_negative_prob = random_negative_prob

    def _rotate_img(self, results, angle, center=None, scale=1.0):
        """Rotate the image.

        Args:
            results (dict): Result dict from loading pipeline.
            angle (float): Rotation angle in degrees, positive values
                mean clockwise rotation. Same in ``mmcv.imrotate``.
            center (tuple[float], optional): Center point (w, h) of the
                rotation. Same in ``mmcv.imrotate``.
            scale (int | float): Isotropic scale factor. Same in
                ``mmcv.imrotate``.
        """
        for key in results.get('img_fields', ['img']):
            img = results[key].copy()
            img_rotated = mmcv.imrotate(
                img, angle, center, scale, border_value=self.img_fill_val)
            results[key] = img_rotated.astype(img.dtype)
            results['img_shape'] = results[key].shape

    def _rotate_bboxes(self, results, rotate_matrix):
        """Rotate the bboxes."""
        h, w, c = results['img_shape']
        for key in results.get('bbox_fields', []):
            min_x, min_y, max_x, max_y = np.split(
                results[key], results[key].shape[-1], axis=-1)
            coordinates = np.stack([[min_x, min_y], [max_x, min_y],
                                    [min_x, max_y],
                                    [max_x, max_y]])  # [4, 2, nb_bbox, 1]
            # pad 1 to convert from format [x, y] to homogeneous
            # coordinates format [x, y, 1]
            coordinates = np.concatenate(
                (coordinates,
                 np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)),
                axis=1)  # [4, 3, nb_bbox, 1]
            coordinates = coordinates.transpose(
                (2, 0, 1, 3))  # [nb_bbox, 4, 3, 1]
            rotated_coords = np.matmul(rotate_matrix,
                                       coordinates)  # [nb_bbox, 4, 2, 1]
            rotated_coords = rotated_coords[..., 0]  # [nb_bbox, 4, 2]
            min_x, min_y = np.min(
                rotated_coords[:, :, 0], axis=1), np.min(
                    rotated_coords[:, :, 1], axis=1)
            max_x, max_y = np.max(
                rotated_coords[:, :, 0], axis=1), np.max(
                    rotated_coords[:, :, 1], axis=1)
            results[key] = np.stack([min_x, min_y, max_x, max_y],
                                    axis=-1).astype(results[key].dtype)

    def _rotate_keypoints90(self, results, angle):
        """Rotate the keypoints, only valid when angle in [-90,90,-180,180]"""
        if angle not in [-90, 90, 180, -180] \
                or self.scale != 1 or self.center is not None:
            return
        for key in results.get('keypoints_fields', []):
            k = results[key]
            if angle == 90:
                w, h, c = results['img'].shape
                new = np.stack([h - k[..., 1], k[..., 0], k[..., 2]], axis=-1)
            elif angle == -90:
                w, h, c = results['img'].shape
                new = np.stack([k[..., 1], w - k[..., 0], k[..., 2]], axis=-1)
            else:
                h, w, c = results['img'].shape
                new = np.stack([w - k[..., 0], h - k[..., 1], k[..., 2]],
                               axis=-1)
            # a kps is invalid if its third value is -1
            kps_invalid = new[..., -1][:, -1] == -1
            new[kps_invalid] = np.zeros(new.shape[1:]) - 1
            results[key] = new

    def _rotate_masks(self,
                      results,
                      angle,
                      center=None,
                      scale=1.0,
                      fill_val=0):
        """Rotate the masks."""
        h, w, c = results['img_shape']
        for key in results.get('mask_fields', []):
            masks = results[key]
            results[key] = masks.rotate((h, w), angle, center, scale, fill_val)

    def _rotate_seg(self,
                    results,
                    angle,
                    center=None,
                    scale=1.0,
                    fill_val=255):
        """Rotate the segmentation map."""
        for key in results.get('seg_fields', []):
            seg = results[key].copy()
            results[key] = mmcv.imrotate(
                seg, angle, center, scale,
                border_value=fill_val).astype(seg.dtype)

    def _filter_invalid(self, results, min_bbox_size=0):
        """Filter bboxes and corresponding masks too small after rotate
        augmentation."""
        bbox2label, bbox2mask, _ = bbox2fields()
        for key in results.get('bbox_fields', []):
            bbox_w = results[key][:, 2] - results[key][:, 0]
            bbox_h = results[key][:, 3] - results[key][:, 1]
            valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
            valid_inds = np.nonzero(valid_inds)[0]
            results[key] = results[key][valid_inds]
            # label fields. e.g. gt_labels and gt_labels_ignore
            label_key = bbox2label.get(key)
            if label_key in results:
                results[label_key] = results[label_key][valid_inds]
            # mask fields, e.g. gt_masks and gt_masks_ignore
            mask_key = bbox2mask.get(key)
            if mask_key in results:
                results[mask_key] = results[mask_key][valid_inds]

    def __call__(self, results):
        """Call function to rotate images, bounding boxes, masks and semantic
        segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Rotated results.
        """
        if np.random.rand() > self.prob:
            return results
        h, w = results['img'].shape[:2]
        center = self.center
        if center is None:
            center = ((w - 1) * 0.5, (h - 1) * 0.5)
        angle = random_negative(self.angle, self.random_negative_prob)
        self._rotate_img(results, angle, center, self.scale)
        rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale)
        self._rotate_bboxes(results, rotate_matrix)
        self._rotate_keypoints90(results, angle)
        self._rotate_masks(results, angle, center, self.scale, fill_val=0)
        self._rotate_seg(
            results, angle, center, self.scale, fill_val=self.seg_ignore_label)
        self._filter_invalid(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(level={self.level}, '
        repr_str += f'scale={self.scale}, '
        repr_str += f'center={self.center}, '
        repr_str += f'img_fill_val={self.img_fill_val}, '
        repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
        repr_str += f'prob={self.prob}, '
        repr_str += f'max_rotate_angle={self.max_rotate_angle}, '
        repr_str += f'random_negative_prob={self.random_negative_prob})'
        return repr_str
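
A usage note on the rotated-face support: with level=10 (the maximum) and max_rotate_angle=90, level_to_value always yields 90, so every applied rotation is +/-90 degrees and the keypoint branch in _rotate_keypoints90 stays valid. An illustrative config fragment (values are assumptions, not from this diff):

rotate_aug = dict(
    type='RotateV2',
    level=10,                  # level_to_value(10, 90) == 90
    max_rotate_angle=90,
    prob=0.5,                  # rotate half of the samples
    random_negative_prob=0.5)  # half of those rotate the other way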
@@ -0,0 +1,113 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/formating.py
"""
from collections.abc import Sequence

import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(f'type {type(data)} cannot be converted to tensor.')


@PIPELINES.register_module()
class DefaultFormatBundleV2(object):
    """Default formatting bundle.

    It simplifies the pipeline of formatting common fields, including "img",
    "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
    These fields are formatted as follows.

    - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
    - proposals: (1)to tensor, (2)to DataContainer
    - gt_bboxes: (1)to tensor, (2)to DataContainer
    - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
    - gt_keypointss: (1)to tensor, (2)to DataContainer
    - gt_labels: (1)to tensor, (2)to DataContainer
    - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
    - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \
        (3)to DataContainer (stack=True)
    """

    def __call__(self, results):
        """Call function to transform and format common fields in results.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data that is formatted with \
                default bundle.
        """
        if 'img' in results:
            img = results['img']
            # add default meta keys
            results = self._add_default_meta_keys(results)
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            img = np.ascontiguousarray(img.transpose(2, 0, 1))
            results['img'] = DC(to_tensor(img), stack=True)
        for key in [
                'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_keypointss',
                'gt_labels'
        ]:
            if key not in results:
                continue
            results[key] = DC(to_tensor(results[key]))
        if 'gt_masks' in results:
            results['gt_masks'] = DC(results['gt_masks'], cpu_only=True)
        if 'gt_semantic_seg' in results:
            results['gt_semantic_seg'] = DC(
                to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
        return results

    def _add_default_meta_keys(self, results):
        """Add default meta keys.

        We set default meta keys including `pad_shape`, `scale_factor` and
        `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and
        `Pad` are implemented during the whole pipeline.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            results (dict): Updated result dict contains the data to convert.
        """
        img = results['img']
        results.setdefault('pad_shape', img.shape)
        results.setdefault('scale_factor', 1.0)
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results.setdefault(
            'img_norm_cfg',
            dict(
                mean=np.zeros(num_channels, dtype=np.float32),
                std=np.ones(num_channels, dtype=np.float32),
                to_rgb=False))
        return results

    def __repr__(self):
        return self.__class__.__name__
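
A minimal sketch of the bundle in isolation; the key point is that gt_keypointss is tensorized alongside the stock fields, which the upstream DefaultFormatBundle does not do. Shapes are illustrative:

import numpy as np

results = dict(
    img=np.zeros((32, 32, 3), dtype=np.uint8),
    gt_bboxes=np.zeros((1, 4), dtype=np.float32),
    gt_keypointss=np.zeros((1, 5, 3), dtype=np.float32),
    gt_labels=np.zeros((1, ), dtype=np.int64))
out = DefaultFormatBundleV2()(results)
# out['img'] wraps a (3, 32, 32) tensor in a DataContainer (stack=True);
# out['gt_keypointss'] wraps a (1, 5, 3) tensor.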
@@ -0,0 +1,225 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/loading.py
"""
import os.path as osp

import mmcv
import numpy as np
import pycocotools.mask as maskUtils
from mmdet.core import BitmapMasks, PolygonMasks
from mmdet.datasets.builder import PIPELINES


@PIPELINES.register_module()
class LoadAnnotationsV2(object):
    """Load multiple types of annotations.

    Args:
        with_bbox (bool): Whether to parse and load the bbox annotation.
            Default: True.
        with_label (bool): Whether to parse and load the label annotation.
            Default: True.
        with_keypoints (bool): Whether to parse and load the keypoints
            annotation. Default: False.
        with_mask (bool): Whether to parse and load the mask annotation.
            Default: False.
        with_seg (bool): Whether to parse and load the semantic segmentation
            annotation. Default: False.
        poly2mask (bool): Whether to convert the instance masks from polygons
            to bitmaps. Default: True.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 with_bbox=True,
                 with_label=True,
                 with_keypoints=False,
                 with_mask=False,
                 with_seg=False,
                 poly2mask=True,
                 file_client_args=dict(backend='disk')):
        self.with_bbox = with_bbox
        self.with_label = with_label
        self.with_keypoints = with_keypoints
        self.with_mask = with_mask
        self.with_seg = with_seg
        self.poly2mask = poly2mask
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def _load_bboxes(self, results):
        """Private function to load bounding box annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded bounding box annotations.
        """
        ann_info = results['ann_info']
        results['gt_bboxes'] = ann_info['bboxes'].copy()

        gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
        if gt_bboxes_ignore is not None:
            results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy()
            results['bbox_fields'].append('gt_bboxes_ignore')
        results['bbox_fields'].append('gt_bboxes')
        return results

    def _load_keypoints(self, results):
        """Private function to load keypoints annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded keypoints annotations.
        """
        ann_info = results['ann_info']
        results['gt_keypointss'] = ann_info['keypointss'].copy()
        results['keypoints_fields'] = ['gt_keypointss']
        return results

    def _load_labels(self, results):
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded label annotations.
        """
        results['gt_labels'] = results['ann_info']['labels'].copy()
        return results

    def _poly2mask(self, mask_ann, img_h, img_w):
        """Private function to convert masks represented with polygon to
        bitmaps.

        Args:
            mask_ann (list | dict): Polygon mask annotation input.
            img_h (int): The height of output mask.
            img_w (int): The width of output mask.

        Returns:
            numpy.ndarray: The decode bitmap mask of shape (img_h, img_w).
        """
        if isinstance(mask_ann, list):
            # polygon -- a single object might consist of multiple parts
            # we merge all parts into one mask rle code
            rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
            rle = maskUtils.merge(rles)
        elif isinstance(mask_ann['counts'], list):
            # uncompressed RLE
            rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
        else:
            # rle
            rle = mask_ann
        mask = maskUtils.decode(rle)
        return mask

    def process_polygons(self, polygons):
        """Convert polygons to list of ndarray and filter invalid polygons.

        Args:
            polygons (list[list]): Polygons of one instance.

        Returns:
            list[numpy.ndarray]: Processed polygons.
        """
        polygons = [np.array(p) for p in polygons]
        valid_polygons = []
        for polygon in polygons:
            if len(polygon) % 2 == 0 and len(polygon) >= 6:
                valid_polygons.append(polygon)
        return valid_polygons

    def _load_masks(self, results):
        """Private function to load mask annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded mask annotations.
                If ``self.poly2mask`` is set ``True``, `gt_mask` will contain
                :obj:`PolygonMasks`. Otherwise, :obj:`BitmapMasks` is used.
        """
        h, w = results['img_info']['height'], results['img_info']['width']
        gt_masks = results['ann_info']['masks']
        if self.poly2mask:
            gt_masks = BitmapMasks(
                [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
        else:
            gt_masks = PolygonMasks(
                [self.process_polygons(polygons) for polygons in gt_masks], h,
                w)
        results['gt_masks'] = gt_masks
        results['mask_fields'].append('gt_masks')
        return results

    def _load_semantic_seg(self, results):
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`dataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        filename = osp.join(results['seg_prefix'],
                            results['ann_info']['seg_map'])
        img_bytes = self.file_client.get(filename)
        results['gt_semantic_seg'] = mmcv.imfrombytes(
            img_bytes, flag='unchanged').squeeze()
        results['seg_fields'].append('gt_semantic_seg')
        return results

    def __call__(self, results):
        """Call function to load multiple types of annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded bounding box, label, mask and
                semantic segmentation annotations.
        """
        if self.with_bbox:
            results = self._load_bboxes(results)
            if results is None:
                return None
        if self.with_label:
            results = self._load_labels(results)
        if self.with_keypoints:
            results = self._load_keypoints(results)
        if self.with_mask:
            results = self._load_masks(results)
        if self.with_seg:
            results = self._load_semantic_seg(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(with_bbox={self.with_bbox}, '
        repr_str += f'with_label={self.with_label}, '
        repr_str += f'with_keypoints={self.with_keypoints}, '
        repr_str += f'with_mask={self.with_mask}, '
        repr_str += f'with_seg={self.with_seg}, '
        repr_str += f'poly2mask={self.poly2mask}, '
        repr_str += f'file_client_args={self.file_client_args})'
        return repr_str
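
A sketch of the expected input contract, assuming ann_info['keypointss'] has shape (num_gts, num_kps, 3) with (x, y, visibility) per point; here 4 card corners:

import numpy as np

results = dict(
    ann_info=dict(
        bboxes=np.zeros((1, 4), dtype=np.float32),
        labels=np.zeros((1, ), dtype=np.int64),
        keypointss=np.zeros((1, 4, 3), dtype=np.float32)),
    bbox_fields=[])
out = LoadAnnotationsV2(
    with_bbox=True, with_label=True, with_keypoints=True)(results)
assert out['gt_keypointss'].shape == (1, 4, 3)
assert out['keypoints_fields'] == ['gt_keypointss']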
@@ -0,0 +1,737 @@ | |||
""" | |||
The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at | |||
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py | |||
""" | |||
import mmcv | |||
import numpy as np | |||
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps | |||
from mmdet.datasets.builder import PIPELINES | |||
from numpy import random | |||
@PIPELINES.register_module() | |||
class ResizeV2(object): | |||
"""Resize images & bbox & mask &kps. | |||
This transform resizes the input image to some scale. Bboxes and masks are | |||
then resized with the same scale factor. If the input dict contains the key | |||
"scale", then the scale in the input dict is used, otherwise the specified | |||
scale in the init method is used. If the input dict contains the key | |||
"scale_factor" (if MultiScaleFlipAug does not give img_scale but | |||
scale_factor), the actual scale will be computed by image shape and | |||
scale_factor. | |||
`img_scale` can either be a tuple (single-scale) or a list of tuple | |||
(multi-scale). There are 3 multiscale modes: | |||
- ``ratio_range is not None``: randomly sample a ratio from the ratio \ | |||
range and multiply it with the image scale. | |||
- ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \ | |||
sample a scale from the multiscale range. | |||
- ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \ | |||
sample a scale from multiple scales. | |||
Args: | |||
img_scale (tuple or list[tuple]): Images scales for resizing. | |||
multiscale_mode (str): Either "range" or "value". | |||
ratio_range (tuple[float]): (min_ratio, max_ratio) | |||
keep_ratio (bool): Whether to keep the aspect ratio when resizing the | |||
image. | |||
bbox_clip_border (bool, optional): Whether clip the objects outside | |||
the border of the image. Defaults to True. | |||
backend (str): Image resize backend, choices are 'cv2' and 'pillow'. | |||
These two backends generates slightly different results. Defaults | |||
to 'cv2'. | |||
override (bool, optional): Whether to override `scale` and | |||
`scale_factor` so as to call resize twice. Default False. If True, | |||
after the first resizing, the existed `scale` and `scale_factor` | |||
will be ignored so the second resizing can be allowed. | |||
This option is a work-around for multiple times of resize in DETR. | |||
Defaults to False. | |||
""" | |||
def __init__(self, | |||
img_scale=None, | |||
multiscale_mode='range', | |||
ratio_range=None, | |||
keep_ratio=True, | |||
bbox_clip_border=True, | |||
backend='cv2', | |||
override=False): | |||
if img_scale is None: | |||
self.img_scale = None | |||
else: | |||
if isinstance(img_scale, list): | |||
self.img_scale = img_scale | |||
else: | |||
self.img_scale = [img_scale] | |||
assert mmcv.is_list_of(self.img_scale, tuple) | |||
if ratio_range is not None: | |||
# mode 1: given a scale and a range of image ratio | |||
assert len(self.img_scale) == 1 | |||
else: | |||
# mode 2: given multiple scales or a range of scales | |||
assert multiscale_mode in ['value', 'range'] | |||
self.backend = backend | |||
self.multiscale_mode = multiscale_mode | |||
self.ratio_range = ratio_range | |||
self.keep_ratio = keep_ratio | |||
# TODO: refactor the override option in Resize | |||
self.override = override | |||
self.bbox_clip_border = bbox_clip_border | |||
@staticmethod | |||
def random_select(img_scales): | |||
"""Randomly select an img_scale from given candidates. | |||
Args: | |||
img_scales (list[tuple]): Images scales for selection. | |||
Returns: | |||
(tuple, int): Returns a tuple ``(img_scale, scale_dix)``, \ | |||
where ``img_scale`` is the selected image scale and \ | |||
``scale_idx`` is the selected index in the given candidates. | |||
""" | |||
assert mmcv.is_list_of(img_scales, tuple) | |||
scale_idx = np.random.randint(len(img_scales)) | |||
img_scale = img_scales[scale_idx] | |||
return img_scale, scale_idx | |||
@staticmethod | |||
def random_sample(img_scales): | |||
"""Randomly sample an img_scale when ``multiscale_mode=='range'``. | |||
Args: | |||
img_scales (list[tuple]): Images scale range for sampling. | |||
There must be two tuples in img_scales, which specify the lower | |||
and uper bound of image scales. | |||
Returns: | |||
(tuple, None): Returns a tuple ``(img_scale, None)``, where \ | |||
``img_scale`` is sampled scale and None is just a placeholder \ | |||
to be consistent with :func:`random_select`. | |||
""" | |||
assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 | |||
img_scale_long = [max(s) for s in img_scales] | |||
img_scale_short = [min(s) for s in img_scales] | |||
long_edge = np.random.randint( | |||
min(img_scale_long), | |||
max(img_scale_long) + 1) | |||
short_edge = np.random.randint( | |||
min(img_scale_short), | |||
max(img_scale_short) + 1) | |||
img_scale = (long_edge, short_edge) | |||
return img_scale, None | |||
@staticmethod | |||
def random_sample_ratio(img_scale, ratio_range): | |||
"""Randomly sample an img_scale when ``ratio_range`` is specified. | |||
A ratio will be randomly sampled from the range specified by | |||
``ratio_range``. Then it would be multiplied with ``img_scale`` to | |||
generate sampled scale. | |||
Args: | |||
img_scale (tuple): Images scale base to multiply with ratio. | |||
ratio_range (tuple[float]): The minimum and maximum ratio to scale | |||
the ``img_scale``. | |||
Returns: | |||
(tuple, None): Returns a tuple ``(scale, None)``, where \ | |||
``scale`` is sampled ratio multiplied with ``img_scale`` and \ | |||
None is just a placeholder to be consistent with \ | |||
:func:`random_select`. | |||
""" | |||
assert isinstance(img_scale, tuple) and len(img_scale) == 2 | |||
min_ratio, max_ratio = ratio_range | |||
assert min_ratio <= max_ratio | |||
ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio | |||
scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) | |||
return scale, None | |||
def _random_scale(self, results): | |||
"""Randomly sample an img_scale according to ``ratio_range`` and | |||
``multiscale_mode``. | |||
If ``ratio_range`` is specified, a ratio will be sampled and be | |||
multiplied with ``img_scale``. | |||
If multiple scales are specified by ``img_scale``, a scale will be | |||
sampled according to ``multiscale_mode``. | |||
Otherwise, single scale will be used. | |||
Args: | |||
results (dict): Result dict from :obj:`dataset`. | |||
Returns: | |||
dict: Two new keys 'scale` and 'scale_idx` are added into \ | |||
``results``, which would be used by subsequent pipelines. | |||
""" | |||
if self.ratio_range is not None: | |||
scale, scale_idx = self.random_sample_ratio( | |||
self.img_scale[0], self.ratio_range) | |||
elif len(self.img_scale) == 1: | |||
scale, scale_idx = self.img_scale[0], 0 | |||
elif self.multiscale_mode == 'range': | |||
scale, scale_idx = self.random_sample(self.img_scale) | |||
elif self.multiscale_mode == 'value': | |||
scale, scale_idx = self.random_select(self.img_scale) | |||
else: | |||
raise NotImplementedError | |||
results['scale'] = scale | |||
results['scale_idx'] = scale_idx | |||
def _resize_img(self, results): | |||
"""Resize images with ``results['scale']``.""" | |||
for key in results.get('img_fields', ['img']): | |||
if self.keep_ratio: | |||
img, scale_factor = mmcv.imrescale( | |||
results[key], | |||
results['scale'], | |||
return_scale=True, | |||
backend=self.backend) | |||
# the w_scale and h_scale has minor difference | |||
# a real fix should be done in the mmcv.imrescale in the future | |||
new_h, new_w = img.shape[:2] | |||
h, w = results[key].shape[:2] | |||
w_scale = new_w / w | |||
h_scale = new_h / h | |||
else: | |||
img, w_scale, h_scale = mmcv.imresize( | |||
results[key], | |||
results['scale'], | |||
return_scale=True, | |||
backend=self.backend) | |||
results[key] = img | |||
scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], | |||
dtype=np.float32) | |||
results['img_shape'] = img.shape | |||
# in case that there is no padding | |||
results['pad_shape'] = img.shape | |||
results['scale_factor'] = scale_factor | |||
results['keep_ratio'] = self.keep_ratio | |||
def _resize_bboxes(self, results): | |||
"""Resize bounding boxes with ``results['scale_factor']``.""" | |||
for key in results.get('bbox_fields', []): | |||
bboxes = results[key] * results['scale_factor'] | |||
if self.bbox_clip_border: | |||
img_shape = results['img_shape'] | |||
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) | |||
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) | |||
results[key] = bboxes | |||
def _resize_keypoints(self, results): | |||
"""Resize keypoints with ``results['scale_factor']``.""" | |||
for key in results.get('keypoints_fields', []): | |||
keypointss = results[key].copy() | |||
factors = results['scale_factor'] | |||
assert factors[0] == factors[2] | |||
assert factors[1] == factors[3] | |||
keypointss[:, :, 0] *= factors[0] | |||
keypointss[:, :, 1] *= factors[1] | |||
if self.bbox_clip_border: | |||
img_shape = results['img_shape'] | |||
keypointss[:, :, 0] = np.clip(keypointss[:, :, 0], 0, | |||
img_shape[1]) | |||
keypointss[:, :, 1] = np.clip(keypointss[:, :, 1], 0, | |||
img_shape[0]) | |||
results[key] = keypointss | |||
def _resize_masks(self, results): | |||
"""Resize masks with ``results['scale']``""" | |||
for key in results.get('mask_fields', []): | |||
if results[key] is None: | |||
continue | |||
if self.keep_ratio: | |||
results[key] = results[key].rescale(results['scale']) | |||
else: | |||
results[key] = results[key].resize(results['img_shape'][:2]) | |||
def _resize_seg(self, results): | |||
"""Resize semantic segmentation map with ``results['scale']``.""" | |||
for key in results.get('seg_fields', []): | |||
if self.keep_ratio: | |||
gt_seg = mmcv.imrescale( | |||
results[key], | |||
results['scale'], | |||
interpolation='nearest', | |||
backend=self.backend) | |||
else: | |||
gt_seg = mmcv.imresize( | |||
results[key], | |||
results['scale'], | |||
interpolation='nearest', | |||
backend=self.backend) | |||
results['gt_semantic_seg'] = gt_seg | |||
def __call__(self, results): | |||
"""Call function to resize images, bounding boxes, masks, semantic | |||
segmentation map. | |||
Args: | |||
results (dict): Result dict from loading pipeline. | |||
Returns: | |||
dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \ | |||
'keep_ratio' keys are added into result dict. | |||
""" | |||
if 'scale' not in results: | |||
if 'scale_factor' in results: | |||
img_shape = results['img'].shape[:2] | |||
scale_factor = results['scale_factor'] | |||
assert isinstance(scale_factor, float) | |||
results['scale'] = tuple( | |||
[int(x * scale_factor) for x in img_shape][::-1]) | |||
else: | |||
self._random_scale(results) | |||
else: | |||
if not self.override: | |||
assert 'scale_factor' not in results, ( | |||
'scale and scale_factor cannot be both set.') | |||
else: | |||
results.pop('scale') | |||
if 'scale_factor' in results: | |||
results.pop('scale_factor') | |||
self._random_scale(results) | |||
self._resize_img(results) | |||
self._resize_bboxes(results) | |||
self._resize_keypoints(results) | |||
self._resize_masks(results) | |||
self._resize_seg(results) | |||
return results | |||
def __repr__(self): | |||
repr_str = self.__class__.__name__ | |||
repr_str += f'(img_scale={self.img_scale}, ' | |||
repr_str += f'multiscale_mode={self.multiscale_mode}, ' | |||
repr_str += f'ratio_range={self.ratio_range}, ' | |||
repr_str += f'keep_ratio={self.keep_ratio})' | |||
repr_str += f'bbox_clip_border={self.bbox_clip_border})' | |||
return repr_str | |||
@PIPELINES.register_module() | |||
class RandomFlipV2(object): | |||
"""Flip the image & bbox & mask & kps. | |||
If the input dict contains the key "flip", then the flag will be used, | |||
otherwise it will be randomly decided by a ratio specified in the init | |||
method. | |||
When random flip is enabled, ``flip_ratio``/``direction`` can either be a | |||
float/string or tuple of float/string. There are 3 flip modes: | |||
- ``flip_ratio`` is float, ``direction`` is string: the image will be | |||
``direction``ly flipped with probability of ``flip_ratio`` . | |||
E.g., ``flip_ratio=0.5``, ``direction='horizontal'``, | |||
then image will be horizontally flipped with probability of 0.5. | |||
- ``flip_ratio`` is float, ``direction`` is list of string: the image wil | |||
be ``direction[i]``ly flipped with probability of | |||
``flip_ratio/len(direction)``. | |||
E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``, | |||
then image will be horizontally flipped with probability of 0.25, | |||
vertically with probability of 0.25. | |||
- ``flip_ratio`` is list of float, ``direction`` is list of string: | |||
given ``len(flip_ratio) == len(direction)``, the image wil | |||
be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``. | |||
E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal', | |||
'vertical']``, then image will be horizontally flipped with probability | |||
of 0.3, vertically with probability of 0.5 | |||
Args: | |||
flip_ratio (float | list[float], optional): The flipping probability. | |||
Default: None. | |||
direction(str | list[str], optional): The flipping direction. Options | |||
are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'. | |||
If input is a list, the length must equal ``flip_ratio``. Each | |||
element in ``flip_ratio`` indicates the flip probability of | |||
corresponding direction. | |||
""" | |||
def __init__(self, flip_ratio=None, direction='horizontal'): | |||
if isinstance(flip_ratio, list): | |||
assert mmcv.is_list_of(flip_ratio, float) | |||
assert 0 <= sum(flip_ratio) <= 1 | |||
elif isinstance(flip_ratio, float): | |||
assert 0 <= flip_ratio <= 1 | |||
elif flip_ratio is None: | |||
pass | |||
else: | |||
raise ValueError('flip_ratios must be None, float, ' | |||
'or list of float') | |||
self.flip_ratio = flip_ratio | |||
valid_directions = ['horizontal', 'vertical', 'diagonal'] | |||
if isinstance(direction, str): | |||
assert direction in valid_directions | |||
elif isinstance(direction, list): | |||
assert mmcv.is_list_of(direction, str) | |||
assert set(direction).issubset(set(valid_directions)) | |||
else: | |||
raise ValueError('direction must be either str or list of str') | |||
self.direction = direction | |||
if isinstance(flip_ratio, list): | |||
assert len(self.flip_ratio) == len(self.direction) | |||
self.count = 0 | |||
def bbox_flip(self, bboxes, img_shape, direction): | |||
"""Flip bboxes horizontally. | |||
Args: | |||
bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k) | |||
img_shape (tuple[int]): Image shape (height, width) | |||
direction (str): Flip direction. Options are 'horizontal', | |||
'vertical'. | |||
Returns: | |||
numpy.ndarray: Flipped bounding boxes. | |||
""" | |||
assert bboxes.shape[-1] % 4 == 0 | |||
flipped = bboxes.copy() | |||
if direction == 'horizontal': | |||
w = img_shape[1] | |||
flipped[..., 0::4] = w - bboxes[..., 2::4] | |||
flipped[..., 2::4] = w - bboxes[..., 0::4] | |||
elif direction == 'vertical': | |||
h = img_shape[0] | |||
flipped[..., 1::4] = h - bboxes[..., 3::4] | |||
flipped[..., 3::4] = h - bboxes[..., 1::4] | |||
elif direction == 'diagonal': | |||
w = img_shape[1] | |||
h = img_shape[0] | |||
flipped[..., 0::4] = w - bboxes[..., 2::4] | |||
flipped[..., 1::4] = h - bboxes[..., 3::4] | |||
flipped[..., 2::4] = w - bboxes[..., 0::4] | |||
flipped[..., 3::4] = h - bboxes[..., 1::4] | |||
else: | |||
raise ValueError(f"Invalid flipping direction '{direction}'") | |||
return flipped | |||
def keypoints_flip(self, keypointss, img_shape, direction): | |||
"""Flip keypoints horizontally.""" | |||
assert direction == 'horizontal' | |||
assert keypointss.shape[-1] == 3 | |||
num_kps = keypointss.shape[1] | |||
assert num_kps in [4, 5], f'Only Support num_kps=4 or 5, got:{num_kps}' | |||
assert keypointss.ndim == 3 | |||
flipped = keypointss.copy() | |||
if num_kps == 5: | |||
flip_order = [1, 0, 2, 4, 3] | |||
elif num_kps == 4: | |||
flip_order = [3, 2, 1, 0] | |||
for idx, a in enumerate(flip_order): | |||
flipped[:, idx, :] = keypointss[:, a, :] | |||
w = img_shape[1] | |||
flipped[..., 0] = w - flipped[..., 0] | |||
return flipped | |||
def __call__(self, results): | |||
"""Call function to flip bounding boxes, masks, semantic segmentation | |||
maps. | |||
Args: | |||
results (dict): Result dict from loading pipeline. | |||
Returns: | |||
dict: Flipped results, 'flip', 'flip_direction' keys are added \ | |||
into result dict. | |||
""" | |||
if 'flip' not in results: | |||
if isinstance(self.direction, list): | |||
# None means non-flip | |||
direction_list = self.direction + [None] | |||
else: | |||
# None means non-flip | |||
direction_list = [self.direction, None] | |||
if isinstance(self.flip_ratio, list): | |||
non_flip_ratio = 1 - sum(self.flip_ratio) | |||
flip_ratio_list = self.flip_ratio + [non_flip_ratio] | |||
else: | |||
non_flip_ratio = 1 - self.flip_ratio | |||
# exclude non-flip | |||
single_ratio = self.flip_ratio / (len(direction_list) - 1) | |||
flip_ratio_list = [single_ratio] * (len(direction_list) | |||
- 1) + [non_flip_ratio] | |||
cur_dir = np.random.choice(direction_list, p=flip_ratio_list) | |||
results['flip'] = cur_dir is not None | |||
if 'flip_direction' not in results: | |||
results['flip_direction'] = cur_dir | |||
if results['flip']: | |||
# flip image | |||
for key in results.get('img_fields', ['img']): | |||
results[key] = mmcv.imflip( | |||
results[key], direction=results['flip_direction']) | |||
# flip bboxes | |||
for key in results.get('bbox_fields', []): | |||
results[key] = self.bbox_flip(results[key], | |||
results['img_shape'], | |||
results['flip_direction']) | |||
# flip kps | |||
for key in results.get('keypoints_fields', []): | |||
results[key] = self.keypoints_flip(results[key], | |||
results['img_shape'], | |||
results['flip_direction']) | |||
# flip masks | |||
for key in results.get('mask_fields', []): | |||
results[key] = results[key].flip(results['flip_direction']) | |||
# flip segs | |||
for key in results.get('seg_fields', []): | |||
results[key] = mmcv.imflip( | |||
results[key], direction=results['flip_direction']) | |||
return results | |||
def __repr__(self): | |||
return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' | |||
@PIPELINES.register_module() | |||
class RandomSquareCrop(object): | |||
"""Random crop the image & bboxes, the cropped patches have minimum IoU | |||
requirement with original image & bboxes, the IoU threshold is randomly | |||
selected from min_ious. | |||
Args: | |||
min_ious (tuple): minimum IoU threshold for all intersections with | |||
bounding boxes | |||
min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, | |||
where a >= min_crop_size). | |||
Note: | |||
The keys for bboxes, labels and masks should be paired. That is, \ | |||
`gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ | |||
`gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. | |||
""" | |||
def __init__(self, | |||
crop_ratio_range=None, | |||
crop_choice=None, | |||
bbox_clip_border=True, | |||
big_face_ratio=0, | |||
big_face_crop_choice=None): | |||
self.crop_ratio_range = crop_ratio_range | |||
self.crop_choice = crop_choice | |||
self.big_face_crop_choice = big_face_crop_choice | |||
self.bbox_clip_border = bbox_clip_border | |||
assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) | |||
if self.crop_ratio_range is not None: | |||
self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range | |||
self.bbox2label = { | |||
'gt_bboxes': 'gt_labels', | |||
'gt_bboxes_ignore': 'gt_labels_ignore' | |||
} | |||
self.bbox2mask = { | |||
'gt_bboxes': 'gt_masks', | |||
'gt_bboxes_ignore': 'gt_masks_ignore' | |||
} | |||
assert big_face_ratio >= 0 and big_face_ratio <= 1.0 | |||
self.big_face_ratio = big_face_ratio | |||
def __call__(self, results): | |||
"""Call function to crop images and bounding boxes with minimum IoU | |||
constraint. | |||
Args: | |||
results (dict): Result dict from loading pipeline. | |||
Returns: | |||
dict: Result dict with images and bounding boxes cropped, \ | |||
'img_shape' key is updated. | |||
""" | |||
if 'img_fields' in results: | |||
assert results['img_fields'] == ['img'], \ | |||
'Only single img_fields is allowed' | |||
img = results['img'] | |||
assert 'bbox_fields' in results | |||
assert 'gt_bboxes' in results | |||
# try augment big face images | |||
find_bigface = False | |||
if np.random.random() < self.big_face_ratio: | |||
min_size = 100 # h and w | |||
expand_ratio = 0.3 # expand ratio of croped face alongwith both w and h | |||
bbox = results['gt_bboxes'].copy() | |||
lmks = results['gt_keypointss'].copy() | |||
label = results['gt_labels'].copy() | |||
# filter small faces | |||
size_mask = ((bbox[:, 2] - bbox[:, 0]) > min_size) * ( | |||
(bbox[:, 3] - bbox[:, 1]) > min_size) | |||
bbox = bbox[size_mask] | |||
lmks = lmks[size_mask] | |||
label = label[size_mask] | |||
# randomly choose a face that has no overlap with others | |||
if len(bbox) > 0: | |||
overlaps = bbox_overlaps(bbox, bbox) | |||
overlaps -= np.eye(overlaps.shape[0]) | |||
iou_mask = np.sum(overlaps, axis=1) == 0 | |||
bbox = bbox[iou_mask] | |||
lmks = lmks[iou_mask] | |||
label = label[iou_mask] | |||
if len(bbox) > 0: | |||
choice = np.random.randint(len(bbox)) | |||
bbox = bbox[choice] | |||
lmks = lmks[choice] | |||
label = [label[choice]] | |||
w = bbox[2] - bbox[0] | |||
h = bbox[3] - bbox[1] | |||
x1 = bbox[0] - w * expand_ratio | |||
x2 = bbox[2] + w * expand_ratio | |||
y1 = bbox[1] - h * expand_ratio | |||
y2 = bbox[3] + h * expand_ratio | |||
x1, x2 = np.clip([x1, x2], 0, img.shape[1]) | |||
y1, y2 = np.clip([y1, y2], 0, img.shape[0]) | |||
bbox -= np.tile([x1, y1], 2) | |||
lmks -= (x1, y1, 0) | |||
find_bigface = True | |||
img = img[int(y1):int(y2), int(x1):int(x2), :] | |||
results['gt_bboxes'] = np.expand_dims(bbox, axis=0) | |||
results['gt_keypointss'] = np.expand_dims(lmks, axis=0) | |||
results['gt_labels'] = np.array(label) | |||
results['img'] = img | |||
boxes = results['gt_bboxes'] | |||
h, w, c = img.shape | |||
if self.crop_ratio_range is not None: | |||
max_scale = self.crop_ratio_max | |||
else: | |||
max_scale = np.amax(self.crop_choice) | |||
scale_retry = 0 | |||
while True: | |||
scale_retry += 1 | |||
if scale_retry == 1 or max_scale > 1.0: | |||
if self.crop_ratio_range is not None: | |||
scale = np.random.uniform(self.crop_ratio_min, | |||
self.crop_ratio_max) | |||
elif self.crop_choice is not None: | |||
scale = np.random.choice(self.crop_choice) | |||
else: | |||
scale = scale * 1.2 | |||
if find_bigface: | |||
# select a scale from big_face_crop_choice if in big_face mode | |||
scale = np.random.choice(self.big_face_crop_choice) | |||
for i in range(250): | |||
long_side = max(w, h) | |||
cw = int(scale * long_side) | |||
ch = cw | |||
# TODO +1 | |||
if w == cw: | |||
left = 0 | |||
elif w > cw: | |||
left = random.randint(0, w - cw) | |||
else: | |||
left = random.randint(w - cw, 0) | |||
if h == ch: | |||
top = 0 | |||
elif h > ch: | |||
top = random.randint(0, h - ch) | |||
else: | |||
top = random.randint(h - ch, 0) | |||
patch = np.array( | |||
(int(left), int(top), int(left + cw), int(top + ch)), | |||
dtype=np.int32) | |||
# center of boxes should inside the crop img | |||
# only adjust boxes and instance masks when the gt is not empty | |||
# adjust boxes | |||
def is_center_of_bboxes_in_patch(boxes, patch): | |||
# TODO >= | |||
center = (boxes[:, :2] + boxes[:, 2:]) / 2 | |||
mask = \ | |||
((center[:, 0] > patch[0]) | |||
* (center[:, 1] > patch[1]) | |||
* (center[:, 0] < patch[2]) | |||
* (center[:, 1] < patch[3])) | |||
return mask | |||
mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
if not mask.any(): | |||
continue | |||
for key in results.get('bbox_fields', []): | |||
boxes = results[key].copy() | |||
mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
boxes = boxes[mask] | |||
if self.bbox_clip_border: | |||
boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) | |||
boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) | |||
boxes -= np.tile(patch[:2], 2) | |||
results[key] = boxes | |||
# labels | |||
label_key = self.bbox2label.get(key) | |||
if label_key in results: | |||
results[label_key] = results[label_key][mask] | |||
# keypoints field | |||
if key == 'gt_bboxes': | |||
for kps_key in results.get('keypoints_fields', []): | |||
keypointss = results[kps_key].copy() | |||
keypointss = keypointss[mask, :, :] | |||
if self.bbox_clip_border: | |||
keypointss[:, :, :2] = keypointss[:, :, :2].clip(max=patch[2:]) | |||
keypointss[:, :, :2] = keypointss[:, :, :2].clip(min=patch[:2]) | |||
keypointss[:, :, 0] -= patch[0] | |||
keypointss[:, :, 1] -= patch[1] | |||
results[kps_key] = keypointss | |||
# mask fields | |||
mask_key = self.bbox2mask.get(key) | |||
if mask_key in results: | |||
results[mask_key] = results[mask_key][mask.nonzero()[0]].crop(patch) | |||
# adjust the img regardless of whether the gt was empty before the crop | |||
rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 | |||
patch_from = patch.copy() | |||
patch_from[0] = max(0, patch_from[0]) | |||
patch_from[1] = max(0, patch_from[1]) | |||
patch_from[2] = min(img.shape[1], patch_from[2]) | |||
patch_from[3] = min(img.shape[0], patch_from[3]) | |||
patch_to = patch.copy() | |||
patch_to[0] = max(0, patch_to[0] * -1) | |||
patch_to[1] = max(0, patch_to[1] * -1) | |||
patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) | |||
patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) | |||
rimg[patch_to[1]:patch_to[3], | |||
patch_to[0]:patch_to[2], :] = img[ | |||
patch_from[1]:patch_from[3], | |||
patch_from[0]:patch_from[2], :] | |||
img = rimg | |||
results['img'] = img | |||
results['img_shape'] = img.shape | |||
return results | |||
def __repr__(self): | |||
repr_str = self.__class__.__name__ | |||
repr_str += f'(crop_ratio_range={self.crop_ratio_range}, ' | |||
repr_str += f'crop_choice={self.crop_choice})' | |||
return repr_str |
@@ -13,7 +13,7 @@ class RetinaFaceDataset(CustomDataset): | |||
CLASSES = ('FG', ) | |||
def __init__(self, min_size=None, **kwargs): | |||
self.NK = 5 | |||
self.NK = kwargs.pop('num_kps', 5) | |||
self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} | |||
self.min_size = min_size | |||
self.gt_path = kwargs.get('gt_path') | |||
@@ -33,7 +33,8 @@ class RetinaFaceDataset(CustomDataset): | |||
if len(values) > 4: | |||
if len(values) > 5: | |||
kps = np.array( | |||
values[4:19], dtype=np.float32).reshape((self.NK, 3)) | |||
values[4:4 + self.NK * 3], dtype=np.float32).reshape( | |||
(self.NK, 3)) | |||
for li in range(kps.shape[0]): | |||
if (kps[li, :] == -1).all(): | |||
kps[li][2] = 0.0 # weight = 0, ignore |
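With `NK` configurable, a retina-style label line now carries `4 + NK * 3` values per object: the box followed by `(x, y, visibility)` triples, where an all `-1` triple marks an unannotated landmark. A hedged sketch of parsing one such line (the full labelv2 format may carry extra fields; this covers only what the hunk above touches):

```python
import numpy as np

def parse_objects_line(line, num_kps=5):
    values = [float(v) for v in line.split()]
    bbox = np.array(values[:4], dtype=np.float32)
    kps = None
    if len(values) > 5:
        kps = np.array(values[4:4 + num_kps * 3],
                       dtype=np.float32).reshape((num_kps, 3))
        for li in range(kps.shape[0]):
            if (kps[li, :] == -1).all():
                kps[li][2] = 0.0  # weight = 0, ignore this landmark
    return bbox, kps

# A card box with 4 corner keypoints (num_kps=4 -> 4 + 4*3 = 16 values):
line = '10 20 200 150 10 150 1 10 20 1 200 20 1 200 150 1'
print(parse_objects_line(line, num_kps=4))
```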
@@ -103,6 +103,7 @@ class SCRFDHead(AnchorHead): | |||
scale_mode=1, | |||
dw_conv=False, | |||
use_kps=False, | |||
num_kps=5, | |||
loss_kps=dict( | |||
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.1), | |||
**kwargs): | |||
@@ -116,7 +117,7 @@ class SCRFDHead(AnchorHead): | |||
self.scale_mode = scale_mode | |||
self.use_dfl = True | |||
self.dw_conv = dw_conv | |||
self.NK = 5 | |||
self.NK = num_kps | |||
self.extra_flops = 0.0 | |||
if loss_dfl is None or not loss_dfl: | |||
self.use_dfl = False | |||
@@ -323,8 +324,8 @@ class SCRFDHead(AnchorHead): | |||
batch_size, -1, self.cls_out_channels).sigmoid() | |||
bbox_pred = bbox_pred.permute(0, 2, 3, | |||
1).reshape(batch_size, -1, 4) | |||
kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 10) | |||
kps_pred = kps_pred.permute(0, 2, 3, | |||
1).reshape(batch_size, -1, self.NK * 2) | |||
return cls_score, bbox_pred, kps_pred | |||
def forward_train(self, | |||
@@ -788,7 +789,7 @@ class SCRFDHead(AnchorHead): | |||
if self.use_dfl: | |||
kps_pred = self.integral(kps_pred) * stride[0] | |||
else: | |||
kps_pred = kps_pred.reshape((-1, 10)) * stride[0] | |||
kps_pred = kps_pred.reshape((-1, self.NK * 2)) * stride[0] | |||
nms_pre = cfg.get('nms_pre', -1) | |||
if nms_pre > 0 and scores.shape[0] > nms_pre: | |||
@@ -815,7 +816,7 @@ class SCRFDHead(AnchorHead): | |||
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) | |||
if mlvl_kps is not None: | |||
scale_factor2 = torch.tensor( | |||
[scale_factor[0], scale_factor[1]] * 5) | |||
[scale_factor[0], scale_factor[1]] * self.NK) | |||
mlvl_kps /= scale_factor2.to(mlvl_kps.device) | |||
mlvl_scores = torch.cat(mlvl_scores) |
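These head changes replace every hard-coded `10` (5 landmarks × 2 coords) with `self.NK * 2`, so a 4-corner card head emits 8 channels per anchor. A quick dummy-tensor check of the reshape performed in the hunk above:

```python
import torch

num_kps = 4  # card corners instead of the 5 face landmarks
batch, h, w = 2, 20, 20
kps_pred = torch.randn(batch, num_kps * 2, h, w)   # raw head output (B, NK*2, H, W)
kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(batch, -1, num_kps * 2)
print(kps_pred.shape)  # torch.Size([2, 400, 8])
```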
@@ -54,7 +54,13 @@ class SCRFD(SingleStageDetector): | |||
gt_bboxes_ignore) | |||
return losses | |||
def simple_test(self, img, img_metas, rescale=False): | |||
def simple_test(self, | |||
img, | |||
img_metas, | |||
rescale=False, | |||
repeat_head=1, | |||
output_kps_var=0, | |||
output_results=1): | |||
"""Test function without test time augmentation. | |||
Args: | |||
@@ -62,6 +68,9 @@ class SCRFD(SingleStageDetector): | |||
img_metas (list[dict]): List of image information. | |||
rescale (bool, optional): Whether to rescale the results. | |||
Defaults to False. | |||
repeat_head (int): number of times to run the detection head; repeated outputs are used to estimate keypoint variance. | |||
output_kps_var (int): whether to also return the keypoint variance as a quality measure. | |||
output_results (int): 0: return nothing; 1: bbox only; 2: both bbox and kps. | |||
Returns: | |||
list[list[np.ndarray]]: BBox results of each image and classes. | |||
@@ -69,40 +78,71 @@ class SCRFD(SingleStageDetector): | |||
corresponds to each class. | |||
""" | |||
x = self.extract_feat(img) | |||
outs = self.bbox_head(x) | |||
if torch.onnx.is_in_onnx_export(): | |||
print('single_stage.py in-onnx-export') | |||
print(outs.__class__) | |||
cls_score, bbox_pred, kps_pred = outs | |||
for c in cls_score: | |||
print(c.shape) | |||
for c in bbox_pred: | |||
print(c.shape) | |||
if self.bbox_head.use_kps: | |||
for c in kps_pred: | |||
assert repeat_head >= 1 | |||
kps_out0 = [] | |||
kps_out1 = [] | |||
kps_out2 = [] | |||
for i in range(repeat_head): | |||
outs = self.bbox_head(x) | |||
kps_out0 += [outs[2][0].detach().cpu().numpy()] | |||
kps_out1 += [outs[2][1].detach().cpu().numpy()] | |||
kps_out2 += [outs[2][2].detach().cpu().numpy()] | |||
if output_kps_var: | |||
var0 = np.var(np.vstack(kps_out0), axis=0).mean() | |||
var1 = np.var(np.vstack(kps_out1), axis=0).mean() | |||
var2 = np.var(np.vstack(kps_out2), axis=0).mean() | |||
var = np.mean([var0, var1, var2]) | |||
else: | |||
var = None | |||
if output_results > 0: | |||
if torch.onnx.is_in_onnx_export(): | |||
print('single_stage.py in-onnx-export') | |||
print(outs.__class__) | |||
cls_score, bbox_pred, kps_pred = outs | |||
for c in cls_score: | |||
print(c.shape) | |||
for c in bbox_pred: | |||
print(c.shape) | |||
return (cls_score, bbox_pred, kps_pred) | |||
else: | |||
return (cls_score, bbox_pred) | |||
bbox_list = self.bbox_head.get_bboxes( | |||
*outs, img_metas, rescale=rescale) | |||
if self.bbox_head.use_kps: | |||
for c in kps_pred: | |||
print(c.shape) | |||
return (cls_score, bbox_pred, kps_pred) | |||
else: | |||
return (cls_score, bbox_pred) | |||
bbox_list = self.bbox_head.get_bboxes( | |||
*outs, img_metas, rescale=rescale) | |||
# return kps if use_kps | |||
if len(bbox_list[0]) == 2: | |||
bbox_results = [ | |||
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) | |||
for det_bboxes, det_labels in bbox_list | |||
] | |||
elif len(bbox_list[0]) == 3: | |||
bbox_results = [ | |||
bbox2result( | |||
det_bboxes, | |||
det_labels, | |||
self.bbox_head.num_classes, | |||
kps=det_kps) | |||
for det_bboxes, det_labels, det_kps in bbox_list | |||
] | |||
return bbox_results | |||
# return kps if use_kps | |||
if len(bbox_list[0]) == 2: | |||
bbox_results = [ | |||
bbox2result(det_bboxes, det_labels, | |||
self.bbox_head.num_classes) | |||
for det_bboxes, det_labels in bbox_list | |||
] | |||
elif len(bbox_list[0]) == 3: | |||
if output_results == 2: | |||
bbox_results = [ | |||
bbox2result( | |||
det_bboxes, | |||
det_labels, | |||
self.bbox_head.num_classes, | |||
kps=det_kps, | |||
num_kps=self.bbox_head.NK) | |||
for det_bboxes, det_labels, det_kps in bbox_list | |||
] | |||
elif output_results == 1: | |||
bbox_results = [ | |||
bbox2result(det_bboxes, det_labels, | |||
self.bbox_head.num_classes) | |||
for det_bboxes, det_labels, _ in bbox_list | |||
] | |||
else: | |||
bbox_results = None | |||
if var is not None: | |||
return bbox_results, var | |||
else: | |||
return bbox_results | |||
def feature_test(self, img): | |||
x = self.extract_feat(img) |
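`repeat_head` runs the bbox head several times on the same feature maps, and `output_kps_var` averages the variance of the repeated keypoint outputs into a single quality score; the variance is only non-zero if something in the head is stochastic (e.g. dropout left active at test time). A hedged sketch of the idea detached from mmdet, assuming a head that returns a `(cls, bbox, kps)` tuple of per-level tensors:

```python
import numpy as np
import torch

def repeated_kps_variance(head, feats, repeats=5):
    """Mean variance of repeated keypoint predictions, a rough landmark-quality proxy."""
    runs = []
    with torch.no_grad():
        for _ in range(repeats):
            _, _, kps_pred = head(feats)  # assumed (cls, bbox, kps) output
            runs.append(torch.cat([lvl.flatten() for lvl in kps_pred]).cpu().numpy())
    return float(np.var(np.stack(runs), axis=0).mean())
```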
@@ -0,0 +1,71 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
from copy import deepcopy | |||
from typing import Any, Dict | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
__all__ = ['ScrfdDetect'] | |||
@MODELS.register_module(Tasks.face_detection, module_name=Models.scrfd) | |||
class ScrfdDetect(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the face detection model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
from mmcv import Config | |||
from mmcv.parallel import MMDataParallel | |||
from mmcv.runner import load_checkpoint | |||
from mmdet.models import build_detector | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD | |||
cfg = Config.fromfile(osp.join(model_dir, 'mmcv_scrfd.py')) | |||
ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) | |||
cfg.model.test_cfg.score_thr = kwargs.get('score_thr', 0.3) | |||
detector = build_detector(cfg.model) | |||
logger.info(f'loading model from {ckpt_path}') | |||
device = torch.device( | |||
f'cuda:{0}' if torch.cuda.is_available() else 'cpu') | |||
load_checkpoint(detector, ckpt_path, map_location=device) | |||
detector = MMDataParallel(detector, device_ids=[0]) | |||
detector.eval() | |||
self.detector = detector | |||
logger.info('load model done') | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
result = self.detector( | |||
return_loss=False, | |||
rescale=True, | |||
img=[input['img'][0].unsqueeze(0)], | |||
img_metas=[[dict(input['img_metas'][0].data)]], | |||
output_results=2) | |||
assert result is not None | |||
result = result[0][0] | |||
bboxes = result[:, :4].tolist() | |||
kpss = result[:, 5:].tolist() | |||
scores = result[:, 4].tolist() | |||
return { | |||
OutputKeys.SCORES: scores, | |||
OutputKeys.BOXES: bboxes, | |||
OutputKeys.KEYPOINTS: kpss | |||
} | |||
def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | |||
return input |
@@ -90,6 +90,25 @@ TASK_OUTPUTS = { | |||
Tasks.face_detection: | |||
[OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], | |||
# card detection result for single sample | |||
# { | |||
# "scores": [0.9, 0.1, 0.05, 0.05] | |||
# "boxes": [ | |||
# [x1, y1, x2, y2], | |||
# [x1, y1, x2, y2], | |||
# [x1, y1, x2, y2], | |||
# [x1, y1, x2, y2], | |||
# ], | |||
# "keypoints": [ | |||
# [x1, y1, x2, y2, x3, y3, x4, y4], | |||
# [x1, y1, x2, y2, x3, y3, x4, y4], | |||
# [x1, y1, x2, y2, x3, y3, x4, y4], | |||
# [x1, y1, x2, y2, x3, y3, x4, y4], | |||
# ], | |||
# } | |||
Tasks.card_detection: | |||
[OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], | |||
# facial expression recognition result for single sample | |||
# { | |||
# "scores": [0.9, 0.1, 0.02, 0.02, 0.02, 0.02, 0.02], | |||
@@ -116,6 +116,10 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
Tasks.hand_2d_keypoints: | |||
(Pipelines.hand_2d_keypoints, | |||
'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'), | |||
Tasks.face_detection: (Pipelines.face_detection, | |||
'damo/cv_resnet_facedetection_scrfd10gkps'), | |||
Tasks.card_detection: (Pipelines.card_detection, | |||
'damo/cv_resnet_carddetection_scrfd34gkps'), | |||
Tasks.face_detection: | |||
(Pipelines.face_detection, | |||
'damo/cv_resnet101_face-detection_cvpr22papermogface'), | |||
@@ -0,0 +1,23 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.pipelines.cv.face_detection_pipeline import \ | |||
FaceDetectionPipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.card_detection, module_name=Pipelines.card_detection) | |||
class CardDetectionPipeline(FaceDetectionPipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` to create a card detection pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
thr = 0.45 # card and face detection use different score thresholds | |||
super().__init__(model=model, score_thr=thr, **kwargs) |
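Usage mirrors the face pipeline; a short sketch based on the unit test added later in this review:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

card_detection = pipeline(Tasks.card_detection,
                          model='damo/cv_resnet_carddetection_scrfd34gkps')
result = card_detection('data/test/images/card_detection.jpg')
print(result['scores'])     # one confidence per detected card
print(result['keypoints'])  # 4 corner points per card: x1,y1,...,x4,y4
```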
@@ -8,6 +8,7 @@ import PIL | |||
import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.cv.face_detection import ScrfdDetect | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
@@ -29,27 +30,8 @@ class FaceDetectionPipeline(Pipeline): | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__(model=model, **kwargs) | |||
from mmcv import Config | |||
from mmcv.parallel import MMDataParallel | |||
from mmcv.runner import load_checkpoint | |||
from mmdet.models import build_detector | |||
from modelscope.models.cv.face_detection.mmdet_patch.datasets import RetinaFaceDataset | |||
from modelscope.models.cv.face_detection.mmdet_patch.datasets.pipelines import RandomSquareCrop | |||
from modelscope.models.cv.face_detection.mmdet_patch.models.backbones import ResNetV1e | |||
from modelscope.models.cv.face_detection.mmdet_patch.models.dense_heads import SCRFDHead | |||
from modelscope.models.cv.face_detection.mmdet_patch.models.detectors import SCRFD | |||
cfg = Config.fromfile(osp.join(model, 'mmcv_scrfd_10g_bnkps.py')) | |||
detector = build_detector( | |||
cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) | |||
ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE) | |||
logger.info(f'loading model from {ckpt_path}') | |||
device = torch.device( | |||
f'cuda:{0}' if torch.cuda.is_available() else 'cpu') | |||
load_checkpoint(detector, ckpt_path, map_location=device) | |||
detector = MMDataParallel(detector, device_ids=[0]) | |||
detector.eval() | |||
detector = ScrfdDetect(model_dir=model, **kwargs) | |||
self.detector = detector | |||
logger.info('load model done') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
img = LoadImage.convert_to_ndarray(input) | |||
@@ -85,22 +67,7 @@ class FaceDetectionPipeline(Pipeline): | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
result = self.detector( | |||
return_loss=False, | |||
rescale=True, | |||
img=[input['img'][0].unsqueeze(0)], | |||
img_metas=[[dict(input['img_metas'][0].data)]]) | |||
assert result is not None | |||
result = result[0][0] | |||
bboxes = result[:, :4].tolist() | |||
kpss = result[:, 5:].tolist() | |||
scores = result[:, 4].tolist() | |||
return { | |||
OutputKeys.SCORES: scores, | |||
OutputKeys.BOXES: bboxes, | |||
OutputKeys.KEYPOINTS: kpss | |||
} | |||
return self.detector(input) | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
@@ -49,7 +49,7 @@ class FaceRecognitionPipeline(Pipeline): | |||
# face detect pipeline | |||
det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | |||
self.face_detection = pipeline( | |||
Tasks.face_detection, model=det_model_id) | |||
Tasks.face_detection, model=det_model_id, model_revision='v2') | |||
def _choose_face(self, | |||
det_result, | |||
@@ -0,0 +1,18 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.metainfo import Trainers | |||
from modelscope.trainers.builder import TRAINERS | |||
from modelscope.trainers.cv.face_detection_scrfd_trainer import \ | |||
FaceDetectionScrfdTrainer | |||
@TRAINERS.register_module(module_name=Trainers.card_detection_scrfd) | |||
class CardDetectionScrfdTrainer(FaceDetectionScrfdTrainer): | |||
def __init__(self, cfg_file: str, *args, **kwargs): | |||
""" High-level finetune api for SCRFD. | |||
Args: | |||
cfg_file: Path to configuration file. | |||
""" | |||
# card and face datasets use different image folder names | |||
super().__init__(cfg_file, imgdir_name='', **kwargs) |
@@ -0,0 +1,154 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import copy | |||
import os | |||
import os.path as osp | |||
import time | |||
from typing import Callable, Dict, Optional | |||
from modelscope.metainfo import Trainers | |||
from modelscope.trainers.base import BaseTrainer | |||
from modelscope.trainers.builder import TRAINERS | |||
@TRAINERS.register_module(module_name=Trainers.face_detection_scrfd) | |||
class FaceDetectionScrfdTrainer(BaseTrainer): | |||
def __init__(self, | |||
cfg_file: str, | |||
cfg_modify_fn: Optional[Callable] = None, | |||
*args, | |||
**kwargs): | |||
""" High-level finetune api for SCRFD. | |||
Args: | |||
cfg_file: Path to configuration file. | |||
cfg_modify_fn: A function used to modify the cfg loaded from the file. | |||
""" | |||
import mmcv | |||
from mmcv.runner import get_dist_info, init_dist | |||
from mmcv.utils import get_git_hash | |||
from mmdet.utils import collect_env, get_root_logger | |||
from mmdet.apis import set_random_seed | |||
from mmdet.models import build_detector | |||
from mmdet.datasets import build_dataset | |||
from mmdet import __version__ | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import DefaultFormatBundleV2 | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import LoadAnnotationsV2 | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RotateV2 | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD | |||
super().__init__(cfg_file) | |||
cfg = self.cfg | |||
if 'work_dir' in kwargs: | |||
cfg.work_dir = kwargs['work_dir'] | |||
else: | |||
# use config filename as default work_dir if work_dir is None | |||
cfg.work_dir = osp.join('./work_dirs', | |||
osp.splitext(osp.basename(cfg_file))[0]) | |||
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) | |||
if 'resume_from' in kwargs: # pretrained model for finetuning | |||
cfg.resume_from = kwargs['resume_from'] | |||
cfg.device = 'cuda' | |||
if 'gpu_ids' in kwargs: | |||
cfg.gpu_ids = kwargs['gpu_ids'] | |||
else: | |||
cfg.gpu_ids = range(1) | |||
labelfile_name = kwargs.pop('labelfile_name', 'labelv2.txt') | |||
imgdir_name = kwargs.pop('imgdir_name', 'images/') | |||
if 'train_root' in kwargs: | |||
cfg.data.train.ann_file = kwargs['train_root'] + labelfile_name | |||
cfg.data.train.img_prefix = kwargs['train_root'] + imgdir_name | |||
if 'val_root' in kwargs: | |||
cfg.data.val.ann_file = kwargs['val_root'] + labelfile_name | |||
cfg.data.val.img_prefix = kwargs['val_root'] + imgdir_name | |||
if 'total_epochs' in kwargs: | |||
cfg.total_epochs = kwargs['total_epochs'] | |||
if cfg_modify_fn is not None: | |||
cfg = cfg_modify_fn(cfg) | |||
if 'launcher' in kwargs: | |||
distributed = True | |||
init_dist(kwargs['launcher'], **cfg.dist_params) | |||
# re-set gpu_ids with distributed training mode | |||
_, world_size = get_dist_info() | |||
cfg.gpu_ids = range(world_size) | |||
else: | |||
distributed = False | |||
# no_validate=True will not evaluate checkpoint during training | |||
cfg.no_validate = kwargs.get('no_validate', False) | |||
# init the logger before other steps | |||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) | |||
log_file = osp.join(cfg.work_dir, f'{timestamp}.log') | |||
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) | |||
# init the meta dict to record some important information such as | |||
# environment info and seed, which will be logged | |||
meta = dict() | |||
# log env info | |||
env_info_dict = collect_env() | |||
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) | |||
dash_line = '-' * 60 + '\n' | |||
logger.info('Environment info:\n' + dash_line + env_info + '\n' | |||
+ dash_line) | |||
meta['env_info'] = env_info | |||
meta['config'] = cfg.pretty_text | |||
# log some basic info | |||
logger.info(f'Distributed training: {distributed}') | |||
logger.info(f'Config:\n{cfg.pretty_text}') | |||
# set random seeds | |||
if 'seed' in kwargs: | |||
cfg.seed = kwargs['seed'] | |||
_deterministic = kwargs.get('deterministic', False) | |||
logger.info(f'Set random seed to {kwargs["seed"]}, ' | |||
f'deterministic: {_deterministic}') | |||
set_random_seed(kwargs['seed'], deterministic=_deterministic) | |||
else: | |||
cfg.seed = None | |||
meta['seed'] = cfg.seed | |||
meta['exp_name'] = osp.basename(cfg_file) | |||
model = build_detector(cfg.model) | |||
model.init_weights() | |||
datasets = [build_dataset(cfg.data.train)] | |||
if len(cfg.workflow) == 2: | |||
val_dataset = copy.deepcopy(cfg.data.val) | |||
val_dataset.pipeline = cfg.data.train.pipeline | |||
datasets.append(build_dataset(val_dataset)) | |||
if cfg.checkpoint_config is not None: | |||
# save mmdet version, config file content and class names in | |||
# checkpoints as meta data | |||
cfg.checkpoint_config.meta = dict( | |||
mmdet_version=__version__ + get_git_hash()[:7], | |||
CLASSES=datasets[0].CLASSES) | |||
# add an attribute for visualization convenience | |||
model.CLASSES = datasets[0].CLASSES | |||
self.cfg = cfg | |||
self.datasets = datasets | |||
self.model = model | |||
self.distributed = distributed | |||
self.timestamp = timestamp | |||
self.meta = meta | |||
self.logger = logger | |||
def train(self, *args, **kwargs): | |||
from mmdet.apis import train_detector | |||
train_detector( | |||
self.model, | |||
self.datasets, | |||
self.cfg, | |||
distributed=self.distributed, | |||
validate=(not self.cfg.no_validate), | |||
timestamp=self.timestamp, | |||
meta=self.meta) | |||
def evaluate(self, | |||
checkpoint_path: str = None, | |||
*args, | |||
**kwargs) -> Dict[str, float]: | |||
cfg = self.cfg.evaluation | |||
self.logger.info(f'eval cfg {cfg}') | |||
self.logger.info(f'checkpoint_path {checkpoint_path}') | |||
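Putting the trainer together: a hedged finetune sketch matching the unit tests added below, assuming a dataset root that contains `labelv2.txt` plus an `images/` folder (the `path/to/...` roots are placeholders):

```python
import os
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile

cache_path = snapshot_download('damo/cv_resnet_facedetection_scrfd10gkps',
                               revision='v2')
kwargs = dict(
    cfg_file=os.path.join(cache_path, 'mmcv_scrfd.py'),
    work_dir='./work_dirs/scrfd_finetune',
    train_root='path/to/train/',  # hypothetical; needs labelv2.txt + images/
    val_root='path/to/val/',
    total_epochs=641,  # released model was trained 640 epochs; add one more
    resume_from=os.path.join(cache_path, ModelFile.TORCH_MODEL_BIN_FILE))
trainer = build_trainer(name=Trainers.face_detection_scrfd, default_args=kwargs)
trainer.train()
```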
@@ -19,6 +19,7 @@ class CVTasks(object): | |||
# human face body related | |||
animal_recognition = 'animal-recognition' | |||
face_detection = 'face-detection' | |||
card_detection = 'card-detection' | |||
face_recognition = 'face-recognition' | |||
facial_expression_recognition = 'facial-expression-recognition' | |||
face_2d_keypoints = 'face-2d-keypoints' | |||
@@ -154,6 +154,54 @@ def draw_face_detection_result(img_path, detection_result): | |||
return img | |||
def draw_card_detection_result(img_path, detection_result): | |||
def warp_img(src_img, kps, ratio): | |||
short_size = 500 | |||
if ratio > 1: | |||
obj_h = short_size | |||
obj_w = int(obj_h * ratio) | |||
else: | |||
obj_w = short_size | |||
obj_h = int(obj_w / ratio) | |||
input_pts = np.float32([kps[0], kps[1], kps[2], kps[3]]) | |||
output_pts = np.float32([[0, obj_h - 1], [0, 0], [obj_w - 1, 0], | |||
[obj_w - 1, obj_h - 1]]) | |||
M = cv2.getPerspectiveTransform(input_pts, output_pts) | |||
obj_img = cv2.warpPerspective(src_img, M, (obj_w, obj_h)) | |||
return obj_img | |||
bboxes = np.array(detection_result[OutputKeys.BOXES]) | |||
kpss = np.array(detection_result[OutputKeys.KEYPOINTS]) | |||
scores = np.array(detection_result[OutputKeys.SCORES]) | |||
img_list = [] | |||
ver_col = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 255)] | |||
img = cv2.imread(img_path) | |||
assert img is not None, f"Can't read img: {img_path}" | |||
img_list += [img] | |||
for i in range(len(scores)): | |||
bbox = bboxes[i].astype(np.int32) | |||
kps = kpss[i].reshape(-1, 2).astype(np.int32) | |||
_w = (kps[0][0] - kps[3][0])**2 + (kps[0][1] - kps[3][1])**2 # squared bottom-edge length | |||
_h = (kps[0][0] - kps[1][0])**2 + (kps[0][1] - kps[1][1])**2 # squared left-edge length | |||
ratio = 1.59 if _w >= _h else 1 / 1.59 # landscape vs portrait card | |||
card_img = warp_img(img, kps, ratio) | |||
img_list += [card_img] | |||
score = scores[i] | |||
x1, y1, x2, y2 = bbox | |||
cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 4) | |||
for k, kp in enumerate(kps): | |||
cv2.circle(img, tuple(kp), 1, color=ver_col[k], thickness=10) | |||
cv2.putText( | |||
img, | |||
f'{score:.2f}', (x1, y2), | |||
1, | |||
1.0, (0, 255, 0), | |||
thickness=1, | |||
lineType=8) | |||
return img_list | |||
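The `1.59` constant approximates the ID-1 card aspect ratio (85.6 mm × 54.0 mm ≈ 1.585), and the squared edge lengths are compared directly since only their ordering matters. A tiny check with hypothetical corner points, ordered bottom-left, top-left, top-right, bottom-right as `warp_img` expects:

```python
import numpy as np

kps = np.array([[0, 100], [0, 0], [160, 0], [160, 100]])  # hypothetical corners
_w = np.sum((kps[0] - kps[3])**2)  # squared bottom-edge length
_h = np.sum((kps[0] - kps[1])**2)  # squared left-edge length
print(1.59 if _w >= _h else 1 / 1.59)  # 1.59: wider than tall -> landscape card
```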
def created_boxed_image(image_in, box): | |||
image = load_image(image_in) | |||
img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) | |||
@@ -0,0 +1,66 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
import unittest | |||
import cv2 | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.cv.image_utils import draw_card_detection_result | |||
from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
from modelscope.utils.test_utils import test_level | |||
class CardDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
def setUp(self) -> None: | |||
self.task = Tasks.card_detection | |||
self.model_id = 'damo/cv_resnet_carddetection_scrfd34gkps' | |||
def show_result(self, img_path, detection_result): | |||
img_list = draw_card_detection_result(img_path, detection_result) | |||
for i, img in enumerate(img_list): | |||
if i == 0: | |||
cv2.imwrite('result.jpg', img_list[0]) | |||
print( | |||
f'Found {len(img_list)-1} cards, output written to {osp.abspath("result.jpg")}' | |||
) | |||
else: | |||
cv2.imwrite(f'card_{i}.jpg', img_list[i]) | |||
save_path = osp.abspath(f'card_{i}.jpg') | |||
print(f'detect card_{i}: {save_path}') | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_dataset(self): | |||
input_location = ['data/test/images/card_detection.jpg'] | |||
dataset = MsDataset.load(input_location, target='image') | |||
card_detection = pipeline(Tasks.card_detection, model=self.model_id) | |||
# note that for dataset output, the inference-output is a Generator that can be iterated. | |||
result = card_detection(dataset) | |||
result = next(result) | |||
self.show_result(input_location[0], result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_modelhub(self): | |||
card_detection = pipeline(Tasks.card_detection, model=self.model_id) | |||
img_path = 'data/test/images/card_detection.jpg' | |||
result = card_detection(img_path) | |||
self.show_result(img_path, result) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_modelhub_default_model(self): | |||
card_detection = pipeline(Tasks.card_detection) | |||
img_path = 'data/test/images/card_detection.jpg' | |||
result = card_detection(img_path) | |||
self.show_result(img_path, result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_demo_compatibility(self): | |||
self.compatibility_check() | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -25,10 +25,11 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_dataset(self): | |||
input_location = ['data/test/images/face_detection.png'] | |||
input_location = ['data/test/images/face_detection2.jpeg'] | |||
dataset = MsDataset.load(input_location, target='image') | |||
face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||
face_detection = pipeline( | |||
Tasks.face_detection, model=self.model_id, model_revision='v2') | |||
# note that for dataset output, the inference-output is a Generator that can be iterated. | |||
result = face_detection(dataset) | |||
result = next(result) | |||
@@ -36,8 +37,9 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_modelhub(self): | |||
face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||
img_path = 'data/test/images/face_detection.png' | |||
face_detection = pipeline( | |||
Tasks.face_detection, model=self.model_id, model_revision='v2') | |||
img_path = 'data/test/images/face_detection2.jpeg' | |||
result = face_detection(img_path) | |||
self.show_result(img_path, result) | |||
@@ -45,7 +47,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_modelhub_default_model(self): | |||
face_detection = pipeline(Tasks.face_detection) | |||
img_path = 'data/test/images/face_detection.png' | |||
img_path = 'data/test/images/face_detection2.jpeg' | |||
result = face_detection(img_path) | |||
self.show_result(img_path, result) | |||
@@ -0,0 +1,151 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import glob | |||
import os | |||
import shutil | |||
import tempfile | |||
import unittest | |||
import torch | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.metainfo import Trainers | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.test_utils import DistributedTestCase, test_level | |||
def _setup(): | |||
model_id = 'damo/cv_resnet_carddetection_scrfd34gkps' | |||
# mini dataset only for unit test, remove '_mini' for full dataset. | |||
ms_ds_syncards = MsDataset.load( | |||
'SyntheticCards_mini', namespace='shaoxuan') | |||
data_path = ms_ds_syncards.config_kwargs['split_config'] | |||
train_dir = data_path['train'] | |||
val_dir = data_path['validation'] | |||
train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/' | |||
val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' | |||
max_epochs = 1 # run epochs in unit test | |||
cache_path = snapshot_download(model_id) | |||
tmp_dir = tempfile.TemporaryDirectory().name | |||
if not os.path.exists(tmp_dir): | |||
os.makedirs(tmp_dir) | |||
return train_root, val_root, max_epochs, cache_path, tmp_dir | |||
def train_func(**kwargs): | |||
trainer = build_trainer( | |||
name=Trainers.card_detection_scrfd, default_args=kwargs) | |||
trainer.train() | |||
class TestCardDetectionScrfdTrainerSingleGPU(unittest.TestCase): | |||
def setUp(self): | |||
print(('SingleGPU Testing %s.%s' % | |||
(type(self).__name__, self._testMethodName))) | |||
self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( | |||
) | |||
def tearDown(self): | |||
shutil.rmtree(self.tmp_dir) | |||
super().tearDown() | |||
def _cfg_modify_fn(self, cfg): | |||
cfg.checkpoint_config.interval = 1 | |||
cfg.log_config.interval = 10 | |||
cfg.evaluation.interval = 1 | |||
cfg.data.workers_per_gpu = 3 | |||
cfg.data.samples_per_gpu = 4 # batch size | |||
return cfg | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_trainer_from_scratch(self): | |||
kwargs = dict( | |||
cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||
work_dir=self.tmp_dir, | |||
train_root=self.train_root, | |||
val_root=self.val_root, | |||
total_epochs=self.max_epochs, | |||
cfg_modify_fn=self._cfg_modify_fn) | |||
trainer = build_trainer( | |||
name=Trainers.card_detection_scrfd, default_args=kwargs) | |||
trainer.train() | |||
results_files = os.listdir(self.tmp_dir) | |||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
for i in range(self.max_epochs): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_trainer_finetune(self): | |||
pretrain_epoch = 640 | |||
self.max_epochs += pretrain_epoch | |||
kwargs = dict( | |||
cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||
work_dir=self.tmp_dir, | |||
train_root=self.train_root, | |||
val_root=self.val_root, | |||
total_epochs=self.max_epochs, | |||
resume_from=os.path.join(self.cache_path, | |||
ModelFile.TORCH_MODEL_BIN_FILE), | |||
cfg_modify_fn=self._cfg_modify_fn) | |||
trainer = build_trainer( | |||
name=Trainers.card_detection_scrfd, default_args=kwargs) | |||
trainer.train() | |||
results_files = os.listdir(self.tmp_dir) | |||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
for i in range(pretrain_epoch, self.max_epochs): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
@unittest.skipIf(not torch.cuda.is_available() | |||
or torch.cuda.device_count() <= 1, 'distributed unittest') | |||
class TestCardDetectionScrfdTrainerMultiGpus(DistributedTestCase): | |||
def setUp(self): | |||
print(('MultiGPUs Testing %s.%s' % | |||
(type(self).__name__, self._testMethodName))) | |||
self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( | |||
) | |||
cfg_file_path = os.path.join(self.cache_path, 'mmcv_scrfd.py') | |||
cfg = Config.from_file(cfg_file_path) | |||
cfg.checkpoint_config.interval = 1 | |||
cfg.log_config.interval = 10 | |||
cfg.evaluation.interval = 1 | |||
cfg.data.workers_per_gpu = 3 | |||
cfg.data.samples_per_gpu = 4 | |||
cfg.dump(cfg_file_path) | |||
def tearDown(self): | |||
shutil.rmtree(self.tmp_dir) | |||
super().tearDown() | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_multi_gpus_finetune(self): | |||
pretrain_epoch = 640 | |||
self.max_epochs += pretrain_epoch | |||
kwargs = dict( | |||
cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||
work_dir=self.tmp_dir, | |||
train_root=self.train_root, | |||
val_root=self.val_root, | |||
total_epochs=self.max_epochs, | |||
resume_from=os.path.join(self.cache_path, | |||
ModelFile.TORCH_MODEL_BIN_FILE), | |||
launcher='pytorch') | |||
self.start(train_func, num_gpus=2, **kwargs) | |||
results_files = os.listdir(self.tmp_dir) | |||
json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) | |||
self.assertEqual(len(json_files), 1) | |||
for i in range(pretrain_epoch, self.max_epochs): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,150 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import glob | |||
import os | |||
import shutil | |||
import tempfile | |||
import unittest | |||
import torch | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.metainfo import Trainers | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.test_utils import DistributedTestCase, test_level | |||
def _setup(): | |||
model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | |||
# mini dataset only for unit test, remove '_mini' for full dataset. | |||
ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan') | |||
data_path = ms_ds_widerface.config_kwargs['split_config'] | |||
train_dir = data_path['train'] | |||
val_dir = data_path['validation'] | |||
train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/' | |||
val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' | |||
max_epochs = 1 # run epochs in unit test | |||
cache_path = snapshot_download(model_id, revision='v2') | |||
tmp_dir = tempfile.TemporaryDirectory().name | |||
if not os.path.exists(tmp_dir): | |||
os.makedirs(tmp_dir) | |||
return train_root, val_root, max_epochs, cache_path, tmp_dir | |||
def train_func(**kwargs): | |||
trainer = build_trainer( | |||
name=Trainers.face_detection_scrfd, default_args=kwargs) | |||
trainer.train() | |||
class TestFaceDetectionScrfdTrainerSingleGPU(unittest.TestCase): | |||
def setUp(self): | |||
print(('SingleGPU Testing %s.%s' % | |||
(type(self).__name__, self._testMethodName))) | |||
self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( | |||
) | |||
def tearDown(self): | |||
shutil.rmtree(self.tmp_dir) | |||
super().tearDown() | |||
def _cfg_modify_fn(self, cfg): | |||
cfg.checkpoint_config.interval = 1 | |||
cfg.log_config.interval = 10 | |||
cfg.evaluation.interval = 1 | |||
cfg.data.workers_per_gpu = 3 | |||
cfg.data.samples_per_gpu = 4 # batch size | |||
return cfg | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_trainer_from_scratch(self): | |||
kwargs = dict( | |||
cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||
work_dir=self.tmp_dir, | |||
train_root=self.train_root, | |||
val_root=self.val_root, | |||
total_epochs=self.max_epochs, | |||
cfg_modify_fn=self._cfg_modify_fn) | |||
trainer = build_trainer( | |||
name=Trainers.face_detection_scrfd, default_args=kwargs) | |||
trainer.train() | |||
results_files = os.listdir(self.tmp_dir) | |||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
for i in range(self.max_epochs): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_trainer_finetune(self): | |||
pretrain_epoch = 640 | |||
self.max_epochs += pretrain_epoch | |||
kwargs = dict( | |||
cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||
work_dir=self.tmp_dir, | |||
train_root=self.train_root, | |||
val_root=self.val_root, | |||
total_epochs=self.max_epochs, | |||
resume_from=os.path.join(self.cache_path, | |||
ModelFile.TORCH_MODEL_BIN_FILE), | |||
cfg_modify_fn=self._cfg_modify_fn) | |||
trainer = build_trainer( | |||
name=Trainers.face_detection_scrfd, default_args=kwargs) | |||
trainer.train() | |||
results_files = os.listdir(self.tmp_dir) | |||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
for i in range(pretrain_epoch, self.max_epochs): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
@unittest.skipIf(not torch.cuda.is_available() | |||
or torch.cuda.device_count() <= 1, 'distributed unittest') | |||
class TestFaceDetectionScrfdTrainerMultiGpus(DistributedTestCase): | |||
def setUp(self): | |||
print(('MultiGPUs Testing %s.%s' % | |||
(type(self).__name__, self._testMethodName))) | |||
self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( | |||
) | |||
cfg_file_path = os.path.join(self.cache_path, 'mmcv_scrfd.py') | |||
cfg = Config.from_file(cfg_file_path) | |||
cfg.checkpoint_config.interval = 1 | |||
cfg.log_config.interval = 10 | |||
cfg.evaluation.interval = 1 | |||
cfg.data.workers_per_gpu = 3 | |||
cfg.data.samples_per_gpu = 4 | |||
cfg.dump(cfg_file_path) | |||
def tearDown(self): | |||
shutil.rmtree(self.tmp_dir) | |||
super().tearDown() | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_multi_gpus_finetune(self): | |||
pretrain_epoch = 640 | |||
self.max_epochs += pretrain_epoch | |||
kwargs = dict( | |||
cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||
work_dir=self.tmp_dir, | |||
train_root=self.train_root, | |||
val_root=self.val_root, | |||
total_epochs=self.max_epochs, | |||
resume_from=os.path.join(self.cache_path, | |||
ModelFile.TORCH_MODEL_BIN_FILE), | |||
launcher='pytorch') | |||
self.start(train_func, num_gpus=2, **kwargs) | |||
results_files = os.listdir(self.tmp_dir) | |||
json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) | |||
self.assertEqual(len(json_files), 1) | |||
for i in range(pretrain_epoch, self.max_epochs): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
if __name__ == '__main__': | |||
unittest.main() |