
[to #42322933] Add RetinaFace face detector

1. Add the RetinaFace face detection model;
2. Complete the Maas-cv CR standard self-check
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9945188
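
For context, a minimal usage sketch of the pipeline this commit registers (model id and output keys are taken from the test and pipeline code below):

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    face_detection = pipeline(
        Tasks.face_detection, model='damo/cv_resnet50_face-detection_retinaface')
    result = face_detection('data/test/images/retina_face_detection.jpg')
    print(result[OutputKeys.BOXES])      # per-face [x1, y1, x2, y2]
    print(result[OutputKeys.SCORES])     # detection confidences
    print(result[OutputKeys.KEYPOINTS])  # 5 landmark points (10 values) per face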
master · ly261666, yingda.chen · 3 years ago · commit f508be8918
11 changed files with 647 additions and 1 deletion
  1. data/test/images/retina_face_detection.jpg (+3, -0)
  2. modelscope/metainfo.py (+2, -0)
  3. modelscope/models/cv/face_detection/retinaface/__init__.py (+0, -0)
  4. modelscope/models/cv/face_detection/retinaface/detection.py (+137, -0)
  5. modelscope/models/cv/face_detection/retinaface/models/__init__.py (+0, -0)
  6. modelscope/models/cv/face_detection/retinaface/models/net.py (+149, -0)
  7. modelscope/models/cv/face_detection/retinaface/models/retinaface.py (+145, -0)
  8. modelscope/models/cv/face_detection/retinaface/utils.py (+123, -0)
  9. modelscope/pipelines/base.py (+0, -1)
  10. modelscope/pipelines/cv/retina_face_detection_pipeline.py (+55, -0)
  11. tests/pipelines/test_retina_face_detection.py (+33, -0)

data/test/images/retina_face_detection.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
size 87228

modelscope/metainfo.py (+2, -0)

@@ -32,6 +32,7 @@ class Models(object):
     vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
     text_driven_segmentation = 'text-driven-segmentation'
     resnet50_bert = 'resnet50-bert'
+    retinaface = 'retinaface'
     shop_segmentation = 'shop-segmentation'


 # EasyCV models
@@ -118,6 +119,7 @@ class Pipelines(object):
     salient_detection = 'u2net-salient-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
+    retina_face_detection = 'resnet50-face-detection-retinaface'
     live_category = 'live-category'
     general_image_classification = 'vit-base_image-classification_ImageNet-labels'
     daily_image_classification = 'vit-base_image-classification_Dailylife-labels'


modelscope/models/cv/face_detection/retinaface/__init__.py (+0, -0)


modelscope/models/cv/face_detection/retinaface/detection.py (+137, -0)

@@ -0,0 +1,137 @@
# The implementation is based on Pytorch_Retinaface, available at https://github.com/biubug6/Pytorch_Retinaface
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from .models.retinaface import RetinaFace
from .utils import PriorBox, decode, decode_landm, py_cpu_nms


@MODELS.register_module(Tasks.face_detection, module_name=Models.retinaface)
class RetinaFaceDetection(TorchModel):

    def __init__(self, model_path, device='cuda'):
        super().__init__(model_path)
        torch.set_grad_enabled(False)
        cudnn.benchmark = True
        self.model_path = model_path
        self.cfg = Config.from_file(
            model_path.replace(ModelFile.TORCH_MODEL_FILE,
                               ModelFile.CONFIGURATION))['models']
        self.net = RetinaFace(cfg=self.cfg)
        self.load_model()
        self.device = device
        self.net = self.net.to(self.device)

        self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device)

    def check_keys(self, pretrained_state_dict):
        ckpt_keys = set(pretrained_state_dict.keys())
        model_keys = set(self.net.state_dict().keys())
        used_pretrained_keys = model_keys & ckpt_keys
        assert len(
            used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
        return True

    def remove_prefix(self, state_dict, prefix):
        new_state_dict = dict()
        for k, v in state_dict.items():
            if k.startswith(prefix):
                new_state_dict[k[len(prefix):]] = v
            else:
                new_state_dict[k] = v
        return new_state_dict

    def load_model(self, load_to_cpu=False):
        pretrained_dict = torch.load(
            self.model_path, map_location=torch.device('cpu'))
        if 'state_dict' in pretrained_dict.keys():
            pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'],
                                                 'module.')
        else:
            pretrained_dict = self.remove_prefix(pretrained_dict, 'module.')
        self.check_keys(pretrained_dict)
        self.net.load_state_dict(pretrained_dict, strict=False)
        self.net.eval()

    def forward(self, input):
        img_raw = input['img'].cpu().numpy()
        img = np.float32(img_raw)

        im_height, im_width = img.shape[:2]
        ss = 1.0
        # downscale very large images so the longer side is about 1000 px
        if max(im_height, im_width) > 1500:
            ss = 1000.0 / max(im_height, im_width)
            img = cv2.resize(img, (0, 0), fx=ss, fy=ss)
            im_height, im_width = img.shape[:2]

        scale = torch.Tensor(
            [img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
        img -= (104, 117, 123)
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.to(self.device)
        scale = scale.to(self.device)

        loc, conf, landms = self.net(img)  # forward pass
        del img

        confidence_threshold = 0.9
        nms_threshold = 0.4
        top_k = 5000
        keep_top_k = 750

        priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
        priors = priorbox.forward()
        priors = priors.to(self.device)
        prior_data = priors.data
        boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
        boxes = boxes * scale
        boxes = boxes.cpu().numpy()
        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
        landms = decode_landm(
            landms.data.squeeze(0), prior_data, self.cfg['variance'])
        scale1 = torch.Tensor([
            im_width, im_height, im_width, im_height, im_width, im_height,
            im_width, im_height, im_width, im_height
        ])
        scale1 = scale1.to(self.device)
        landms = landms * scale1
        landms = landms.cpu().numpy()

        # ignore low scores
        inds = np.where(scores > confidence_threshold)[0]
        boxes = boxes[inds]
        landms = landms[inds]
        scores = scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1][:top_k]
        boxes = boxes[order]
        landms = landms[order]
        scores = scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(
            np.float32, copy=False)
        keep = py_cpu_nms(dets, nms_threshold)
        dets = dets[keep, :]
        landms = landms[keep]

        # keep top-K after NMS
        dets = dets[:keep_top_k, :]
        landms = landms[:keep_top_k, :]

        landms = landms.reshape((-1, 5, 2))
        landms = landms.reshape((-1, 10))
        # map coordinates back to the original (pre-resize) image scale
        return dets / ss, landms / ss
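
For reference, a short sketch of unpacking forward()'s return value; `detector` and `img_tensor` are placeholders for a RetinaFaceDetection instance and an HxWx3 BGR image tensor:

    dets, landms = detector.forward({'img': img_tensor})
    # dets: (N, 5) float32 array, one row per face: [x1, y1, x2, y2, score],
    # already mapped back to the original image scale (the `/ ss` above).
    # landms: (N, 10) array, five (x, y) landmark points per face.
    for (x1, y1, x2, y2, score), pts in zip(dets, landms.reshape(-1, 5, 2)):
        print(f'face at ({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}), score {score:.2f}')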

modelscope/models/cv/face_detection/retinaface/models/__init__.py (+0, -0)


modelscope/models/cv/face_detection/retinaface/models/net.py (+149, -0)

@@ -0,0 +1,149 @@
# The implementation is based on Pytorch_Retinaface, available at https://github.com/biubug6/Pytorch_Retinaface
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn(inp, oup, stride=1, leaky=0):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True))


def conv_bn_no_relu(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
    )


def conv_bn1X1(inp, oup, stride, leaky=0):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
        nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True))


def conv_dw(inp, oup, stride, leaky=0.1):
    return nn.Sequential(
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    )


class SSH(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(SSH, self).__init__()
        assert out_channel % 4 == 0
        leaky = 0
        if out_channel <= 64:
            leaky = 0.1
        self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)

        self.conv5X5_1 = conv_bn(
            in_channel, out_channel // 4, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(
            out_channel // 4, out_channel // 4, stride=1)

        self.conv7X7_2 = conv_bn(
            out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(
            out_channel // 4, out_channel // 4, stride=1)

    def forward(self, input):
        conv3X3 = self.conv3X3(input)

        conv5X5_1 = self.conv5X5_1(input)
        conv5X5 = self.conv5X5_2(conv5X5_1)

        conv7X7_2 = self.conv7X7_2(conv5X5_1)
        conv7X7 = self.conv7x7_3(conv7X7_2)

        out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
        out = F.relu(out)
        return out


class FPN(nn.Module):

    def __init__(self, in_channels_list, out_channels):
        super(FPN, self).__init__()
        leaky = 0
        if out_channels <= 64:
            leaky = 0.1
        self.output1 = conv_bn1X1(
            in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(
            in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(
            in_channels_list[2], out_channels, stride=1, leaky=leaky)

        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        input = list(input.values())

        output1 = self.output1(input[0])
        output2 = self.output2(input[1])
        output3 = self.output3(input[2])

        up3 = F.interpolate(
            output3, size=[output2.size(2), output2.size(3)], mode='nearest')
        output2 = output2 + up3
        output2 = self.merge2(output2)

        up2 = F.interpolate(
            output2, size=[output1.size(2), output1.size(3)], mode='nearest')
        output1 = output1 + up2
        output1 = self.merge1(output1)

        out = [output1, output2, output3]
        return out


class MobileNetV1(nn.Module):

    def __init__(self):
        super(MobileNetV1, self).__init__()
        # trailing comments track the receptive field after each block
        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky=0.1),  # 3
            conv_dw(8, 16, 1),  # 7
            conv_dw(16, 32, 2),  # 11
            conv_dw(32, 32, 1),  # 19
            conv_dw(32, 64, 2),  # 27
            conv_dw(64, 64, 1),  # 43
        )
        self.stage2 = nn.Sequential(
            conv_dw(64, 128, 2),  # 43 + 16 = 59
            conv_dw(128, 128, 1),  # 59 + 32 = 91
            conv_dw(128, 128, 1),  # 91 + 32 = 123
            conv_dw(128, 128, 1),  # 123 + 32 = 155
            conv_dw(128, 128, 1),  # 155 + 32 = 187
            conv_dw(128, 128, 1),  # 187 + 32 = 219
        )
        self.stage3 = nn.Sequential(
            conv_dw(128, 256, 2),  # 219 + 32 = 251
            conv_dw(256, 256, 1),  # 251 + 64 = 315
        )
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 1000)

    def forward(self, x):
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.avg(x)
        x = x.view(-1, 256)
        x = self.fc(x)
        return x
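
A quick shape sanity check for FPN and SSH (channel sizes mirror the ResNet50 layer2-4 outputs wired up in retinaface.py below; the dict keys are arbitrary):

    import torch

    fpn = FPN(in_channels_list=[512, 1024, 2048], out_channels=256)
    feats = {
        'layer2': torch.randn(1, 512, 80, 80),   # stride 8
        'layer3': torch.randn(1, 1024, 40, 40),  # stride 16
        'layer4': torch.randn(1, 2048, 20, 20),  # stride 32
    }
    outs = fpn(feats)
    # three 256-channel maps; top-down merging preserves spatial size
    assert all(o.shape[1] == 256 for o in outs)
    ssh = SSH(256, 256)
    assert ssh(outs[0]).shape == (1, 256, 80, 80)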

modelscope/models/cv/face_detection/retinaface/models/retinaface.py (+145, -0)

@@ -0,0 +1,145 @@
# The implementation is based on Pytorch_Retinaface, available at https://github.com/biubug6/Pytorch_Retinaface
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.models._utils as _utils

from .net import FPN, SSH, MobileNetV1


class ClassHead(nn.Module):

    def __init__(self, inchannels=512, num_anchors=3):
        super(ClassHead, self).__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(
            inchannels,
            self.num_anchors * 2,
            kernel_size=(1, 1),
            stride=1,
            padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 2)


class BboxHead(nn.Module):

    def __init__(self, inchannels=512, num_anchors=3):
        super(BboxHead, self).__init__()
        self.conv1x1 = nn.Conv2d(
            inchannels,
            num_anchors * 4,
            kernel_size=(1, 1),
            stride=1,
            padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 4)


class LandmarkHead(nn.Module):

    def __init__(self, inchannels=512, num_anchors=3):
        super(LandmarkHead, self).__init__()
        self.conv1x1 = nn.Conv2d(
            inchannels,
            num_anchors * 10,
            kernel_size=(1, 1),
            stride=1,
            padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 10)


class RetinaFace(nn.Module):

    def __init__(self, cfg=None):
        """
        :param cfg: Network related settings.
        """
        super(RetinaFace, self).__init__()
        backbone = None
        if cfg['name'] == 'Resnet50':
            backbone = models.resnet50(pretrained=cfg['pretrain'])
        else:
            raise Exception('Invalid name')

        self.body = _utils.IntermediateLayerGetter(backbone,
                                                   cfg['return_layers'])
        in_channels_stage2 = cfg['in_channel']
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]
        out_channels = cfg['out_channel']
        self.fpn = FPN(in_channels_list, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = self._make_class_head(
            fpn_num=3, inchannels=cfg['out_channel'])
        self.BboxHead = self._make_bbox_head(
            fpn_num=3, inchannels=cfg['out_channel'])
        self.LandmarkHead = self._make_landmark_head(
            fpn_num=3, inchannels=cfg['out_channel'])

    def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        classhead = nn.ModuleList()
        for i in range(fpn_num):
            classhead.append(ClassHead(inchannels, anchor_num))
        return classhead

    def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        bboxhead = nn.ModuleList()
        for i in range(fpn_num):
            bboxhead.append(BboxHead(inchannels, anchor_num))
        return bboxhead

    def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        landmarkhead = nn.ModuleList()
        for i in range(fpn_num):
            landmarkhead.append(LandmarkHead(inchannels, anchor_num))
        return landmarkhead

    def forward(self, inputs):
        out = self.body(inputs)

        # FPN
        fpn = self.fpn(out)

        # SSH
        feature1 = self.ssh1(fpn[0])
        feature2 = self.ssh2(fpn[1])
        feature3 = self.ssh3(fpn[2])
        features = [feature1, feature2, feature3]

        bbox_regressions = torch.cat(
            [self.BboxHead[i](feature) for i, feature in enumerate(features)],
            dim=1)
        classifications = torch.cat(
            [self.ClassHead[i](feature) for i, feature in enumerate(features)],
            dim=1)
        ldm_regressions = torch.cat(
            [self.LandmarkHead[i](feat) for i, feat in enumerate(features)],
            dim=1)

        output = (bbox_regressions, F.softmax(classifications, dim=-1),
                  ldm_regressions)
        return output
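
The cfg dict consumed by RetinaFace (and by PriorBox/decode in utils.py) is read from the model's configuration file. A representative sketch with the keys this commit reads is below; the values shown are the usual upstream ResNet50 settings and are an assumption here, the authoritative values live in the model's configuration on the hub:

    cfg = {
        'name': 'Resnet50',
        'pretrain': False,
        'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
        'in_channel': 256,  # in_channels_list becomes [512, 1024, 2048]
        'out_channel': 256,
        'min_sizes': [[16, 32], [64, 128], [256, 512]],  # used by PriorBox
        'steps': [8, 16, 32],                            # used by PriorBox
        'clip': False,                                   # used by PriorBox
        'variance': [0.1, 0.2],           # used by decode / decode_landm
    }
    net = RetinaFace(cfg=cfg)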

modelscope/models/cv/face_detection/retinaface/utils.py (+123, -0)

@@ -0,0 +1,123 @@
# --------------------------------------------------------
# Modified from https://github.com/biubug6/Pytorch_Retinaface
# --------------------------------------------------------

from itertools import product
from math import ceil

import numpy as np
import torch


class PriorBox(object):

    def __init__(self, cfg, image_size=None, phase='train'):
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size
        self.feature_maps = [[
            ceil(self.image_size[0] / step),
            ceil(self.image_size[1] / step)
        ] for step in self.steps]
        self.name = 's'

    def forward(self):
        anchors = []
        for k, f in enumerate(self.feature_maps):
            min_sizes = self.min_sizes[k]
            for i, j in product(range(f[0]), range(f[1])):
                for min_size in min_sizes:
                    s_kx = min_size / self.image_size[1]
                    s_ky = min_size / self.image_size[0]
                    dense_cx = [
                        x * self.steps[k] / self.image_size[1]
                        for x in [j + 0.5]
                    ]
                    dense_cy = [
                        y * self.steps[k] / self.image_size[0]
                        for y in [i + 0.5]
                    ]
                    for cy, cx in product(dense_cy, dense_cx):
                        anchors += [cx, cy, s_kx, s_ky]

        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output


def py_cpu_nms(dets, thresh):
    """Pure Python NMS baseline."""
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors, 4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat(
        (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
         priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes


def decode_landm(pre, priors, variances):
    """Decode landm from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        pre (tensor): landm predictions for loc layers,
            Shape: [num_priors, 10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions
    """
    a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:]
    b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:]
    c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:]
    d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:]
    e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:]
    landms = torch.cat((a, b, c, d, e), dim=1)
    return landms
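
A tiny worked example of py_cpu_nms with the 0.4 IoU threshold used in detection.py: two heavily overlapping boxes and one disjoint box.

    import numpy as np

    dets = np.array([
        [10., 10., 50., 50., 0.95],      # kept: highest score
        [12., 12., 52., 52., 0.90],      # suppressed: IoU with box 0 is ~0.83 > 0.4
        [100., 100., 140., 140., 0.80],  # kept: no overlap with box 0
    ])
    print(py_cpu_nms(dets, thresh=0.4))  # -> [0, 2]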

modelscope/pipelines/base.py (+0, -1)

@@ -2,7 +2,6 @@

 import os.path as osp
 from abc import ABC, abstractmethod
-from contextlib import contextmanager
 from threading import Lock
 from typing import Any, Dict, Generator, List, Mapping, Union




modelscope/pipelines/cv/retina_face_detection_pipeline.py (+55, -0)

@@ -0,0 +1,55 @@
import os.path as osp
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_detection.retinaface import detection
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_detection, module_name=Pipelines.retina_face_detection)
class RetinaFaceDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a face detection pipeline for prediction.

        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {ckpt_path}')
        detector = detection.RetinaFaceDetection(
            model_path=ckpt_path, device=self.device)
        self.detector = detector
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img.astype(np.float32)
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result = self.detector(input)
        assert result is not None
        bboxes = result[0][:, :4].tolist()
        scores = result[0][:, 4].tolist()
        lms = result[1].tolist()
        return {
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: bboxes,
            OutputKeys.KEYPOINTS: lms,
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs

tests/pipelines/test_retina_face_detection.py (+33, -0)

@@ -0,0 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_face_detection_result
from modelscope.utils.test_utils import test_level


class RetinaFaceDetectionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet50_face-detection_retinaface'

    def show_result(self, img_path, detection_result):
        img = draw_face_detection_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
        img_path = 'data/test/images/retina_face_detection.jpg'

        result = face_detection(img_path)
        self.show_result(img_path, result)


if __name__ == '__main__':
    unittest.main()
