Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9921926master
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9 | |||
size 87228 |
@@ -35,6 +35,7 @@ class Models(object): | |||
fer = 'fer' | |||
retinaface = 'retinaface' | |||
shop_segmentation = 'shop-segmentation' | |||
mogface = 'mogface' | |||
mtcnn = 'mtcnn' | |||
ulfd = 'ulfd' | |||
@@ -128,6 +129,7 @@ class Pipelines(object): | |||
ulfd_face_detection = 'manual-face-detection-ulfd' | |||
facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' | |||
retina_face_detection = 'resnet50-face-detection-retinaface' | |||
mog_face_detection = 'resnet101-face-detection-cvpr22papermogface' | |||
mtcnn_face_detection = 'manual-face-detection-mtcnn' | |||
live_category = 'live-category' | |||
general_image_classification = 'vit-base_image-classification_ImageNet-labels' | |||
@@ -4,15 +4,16 @@ from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .mogface import MogFaceDetector | |||
from .mtcnn import MtcnnFaceDetector | |||
from .retinaface import RetinaFaceDetection | |||
from .ulfd_slim import UlfdFaceDetector | |||
else: | |||
_import_structure = { | |||
'ulfd_slim': ['UlfdFaceDetector'], | |||
'retinaface': ['RetinaFaceDetection'], | |||
'mtcnn': ['MtcnnFaceDetector'] | |||
'mtcnn': ['MtcnnFaceDetector'], | |||
'mogface': ['MogFaceDetector'] | |||
} | |||
import sys | |||
@@ -0,0 +1 @@ | |||
from .models.detectors import MogFaceDetector |
@@ -0,0 +1,96 @@ | |||
import os | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
import torch.backends.cudnn as cudnn | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import Tasks | |||
from .mogface import MogFace | |||
from .utils import MogPriorBox, mogdecode, py_cpu_nms | |||
@MODELS.register_module(Tasks.face_detection, module_name=Models.mogface) | |||
class MogFaceDetector(TorchModel): | |||
def __init__(self, model_path, device='cuda'): | |||
super().__init__(model_path) | |||
torch.set_grad_enabled(False) | |||
cudnn.benchmark = True | |||
self.model_path = model_path | |||
self.device = device | |||
self.net = MogFace() | |||
self.load_model() | |||
self.net = self.net.to(device) | |||
self.mean = np.array([[104, 117, 123]]) | |||
def load_model(self, load_to_cpu=False): | |||
pretrained_dict = torch.load( | |||
self.model_path, map_location=torch.device('cpu')) | |||
self.net.load_state_dict(pretrained_dict, strict=False) | |||
self.net.eval() | |||
def forward(self, input): | |||
img_raw = input['img'] | |||
img = np.array(img_raw.cpu().detach()) | |||
img = img[:, :, ::-1] | |||
im_height, im_width = img.shape[:2] | |||
ss = 1.0 | |||
# downscale very large inputs: if the longer side exceeds 1500 px, resize so it is about 1000 px
if max(im_height, im_width) > 1500: | |||
ss = 1000.0 / max(im_height, im_width) | |||
img = cv2.resize(img, (0, 0), fx=ss, fy=ss) | |||
im_height, im_width = img.shape[:2] | |||
scale = torch.Tensor( | |||
[img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) | |||
img -= np.array([[103.53, 116.28, 123.675]]) | |||
img /= np.array([[57.375, 57.120003, 58.395]]) | |||
img /= 255 | |||
img = img[:, :, ::-1].copy() | |||
img = img.transpose(2, 0, 1) | |||
img = torch.from_numpy(img).unsqueeze(0) | |||
img = img.to(self.device) | |||
scale = scale.to(self.device) | |||
conf, loc = self.net(img) # forward pass | |||
confidence_threshold = 0.82 | |||
nms_threshold = 0.4 | |||
top_k = 5000 | |||
keep_top_k = 750 | |||
priorbox = MogPriorBox(scale_list=[0.68]) | |||
priors = priorbox(im_height, im_width) | |||
priors = torch.tensor(priors).to(self.device) | |||
prior_data = priors.data | |||
boxes = mogdecode(loc.data.squeeze(0), prior_data) | |||
boxes = boxes.cpu().numpy() | |||
scores = conf.squeeze(0).data.cpu().numpy()[:, 0] | |||
# ignore low scores | |||
inds = np.where(scores > confidence_threshold)[0] | |||
boxes = boxes[inds] | |||
scores = scores[inds] | |||
# keep top-K before NMS | |||
order = scores.argsort()[::-1][:top_k] | |||
boxes = boxes[order] | |||
scores = scores[order] | |||
# do NMS | |||
dets = np.hstack((boxes, scores[:, np.newaxis])).astype( | |||
np.float32, copy=False) | |||
keep = py_cpu_nms(dets, nms_threshold) | |||
dets = dets[keep, :] | |||
# keep top-K after NMS
dets = dets[:keep_top_k, :] | |||
dets[:, :4] = dets[:, :4] / ss  # map boxes back to the original image scale; leave scores untouched
return dets
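The resize bookkeeping above is easy to misread, so here is a tiny self-contained illustration (all numbers made up) of how boxes predicted on the downscaled copy map back to original-image pixels:
# Illustration only: a 4000x2000 input is shrunk so its longer side is 1000 px,
# and predicted boxes are divided by the same factor to return to original pixels.
im_height, im_width = 2000, 4000
ss = 1000.0 / max(im_height, im_width)            # 0.25: longer side becomes 1000 px
box_on_resized = [150.0, 90.0, 210.0, 170.0]      # [x0, y0, x1, y1] on the resized image
box_on_original = [v / ss for v in box_on_resized]
print(box_on_original)                            # [600.0, 360.0, 840.0, 680.0]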
@@ -0,0 +1,135 @@ | |||
# -------------------------------------------------------- | |||
# The implementation is also open-sourced by its author, Yang Liu, and is publicly available at
# https://github.com/damo-cv/MogFace | |||
# -------------------------------------------------------- | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from .mogprednet import MogPredNet | |||
from .resnet import ResNet | |||
class MogFace(nn.Module): | |||
def __init__(self): | |||
super(MogFace, self).__init__() | |||
self.backbone = ResNet(depth=101) | |||
self.fpn = LFPN() | |||
self.pred_net = MogPredNet() | |||
def forward(self, x): | |||
feature_list = self.backbone(x) | |||
fpn_list = self.fpn(feature_list) | |||
pyramid_feature_list = fpn_list[0] | |||
conf, loc = self.pred_net(pyramid_feature_list) | |||
return conf, loc | |||
class FeatureFusion(nn.Module): | |||
def __init__(self, lat_ch=256, **channels): | |||
super(FeatureFusion, self).__init__() | |||
self.main_conv = nn.Conv2d(channels['main'], lat_ch, kernel_size=1) | |||
def forward(self, up, main): | |||
main = self.main_conv(main) | |||
_, _, H, W = main.size() | |||
res = F.interpolate(up, scale_factor=2, mode='bilinear', align_corners=False)
if res.size(2) != main.size(2) or res.size(3) != main.size(3): | |||
res = res[:, :, 0:H, 0:W] | |||
res = res + main | |||
return res | |||
class LFPN(nn.Module): | |||
def __init__(self, | |||
c2_out_ch=256, | |||
c3_out_ch=512, | |||
c4_out_ch=1024, | |||
c5_out_ch=2048, | |||
c6_mid_ch=512, | |||
c6_out_ch=512, | |||
c7_mid_ch=128, | |||
c7_out_ch=256, | |||
out_dsfd_ft=True): | |||
super(LFPN, self).__init__() | |||
self.out_dsfd_ft = out_dsfd_ft | |||
if self.out_dsfd_ft: | |||
dsfd_module = [] | |||
dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1)) | |||
dsfd_module.append(nn.Conv2d(512, 256, kernel_size=3, padding=1)) | |||
dsfd_module.append(nn.Conv2d(1024, 256, kernel_size=3, padding=1)) | |||
dsfd_module.append(nn.Conv2d(2048, 256, kernel_size=3, padding=1)) | |||
dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1)) | |||
dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1)) | |||
self.dsfd_modules = nn.ModuleList(dsfd_module) | |||
c6_input_ch = c5_out_ch | |||
self.c6 = nn.Sequential(*[ | |||
nn.Conv2d( | |||
c6_input_ch, | |||
c6_mid_ch, | |||
kernel_size=1, | |||
), | |||
nn.BatchNorm2d(c6_mid_ch), | |||
nn.ReLU(inplace=True), | |||
nn.Conv2d( | |||
c6_mid_ch, c6_out_ch, kernel_size=3, padding=1, stride=2), | |||
nn.BatchNorm2d(c6_out_ch), | |||
nn.ReLU(inplace=True) | |||
]) | |||
self.c7 = nn.Sequential(*[ | |||
nn.Conv2d( | |||
c6_out_ch, | |||
c7_mid_ch, | |||
kernel_size=1, | |||
), | |||
nn.BatchNorm2d(c7_mid_ch), | |||
nn.ReLU(inplace=True), | |||
nn.Conv2d( | |||
c7_mid_ch, c7_out_ch, kernel_size=3, padding=1, stride=2), | |||
nn.BatchNorm2d(c7_out_ch), | |||
nn.ReLU(inplace=True) | |||
]) | |||
self.p2_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1) | |||
self.p3_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1) | |||
self.p4_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1) | |||
self.c5_lat = nn.Conv2d(c6_input_ch, 256, kernel_size=3, padding=1) | |||
self.c6_lat = nn.Conv2d(c6_out_ch, 256, kernel_size=3, padding=1) | |||
self.c7_lat = nn.Conv2d(c7_out_ch, 256, kernel_size=3, padding=1) | |||
self.ff_c5_c4 = FeatureFusion(main=c4_out_ch) | |||
self.ff_c4_c3 = FeatureFusion(main=c3_out_ch) | |||
self.ff_c3_c2 = FeatureFusion(main=c2_out_ch) | |||
def forward(self, feature_list): | |||
c2, c3, c4, c5 = feature_list | |||
c6 = self.c6(c5) | |||
c7 = self.c7(c6) | |||
c5 = self.c5_lat(c5) | |||
c6 = self.c6_lat(c6) | |||
c7 = self.c7_lat(c7) | |||
if self.out_dsfd_ft: | |||
dsfd_fts = [] | |||
dsfd_fts.append(self.dsfd_modules[0](c2)) | |||
dsfd_fts.append(self.dsfd_modules[1](c3)) | |||
dsfd_fts.append(self.dsfd_modules[2](c4)) | |||
dsfd_fts.append(self.dsfd_modules[3](feature_list[-1])) | |||
dsfd_fts.append(self.dsfd_modules[4](c6)) | |||
dsfd_fts.append(self.dsfd_modules[5](c7)) | |||
p4 = self.ff_c5_c4(c5, c4) | |||
p3 = self.ff_c4_c3(p4, c3) | |||
p2 = self.ff_c3_c2(p3, c2) | |||
p2 = self.p2_lat(p2) | |||
p3 = self.p3_lat(p3) | |||
p4 = self.p4_lat(p4) | |||
if self.out_dsfd_ft: | |||
return ([p2, p3, p4, c5, c6, c7], dsfd_fts) |
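As a rough shape check of the FPN contract (every pyramid level is projected to 256 channels before the prediction head), a throwaway sketch, assuming LFPN is in scope (e.g., imported from this module):
import torch

fpn = LFPN()
# Dummy c2..c5 backbone features: 256/512/1024/2048 channels at strides 4/8/16/32 of a 256 px input.
feats = [torch.randn(1, c, s, s) for c, s in [(256, 64), (512, 32), (1024, 16), (2048, 8)]]
with torch.no_grad():
    pyramids, dsfd_fts = fpn(feats)
print([tuple(p.shape) for p in pyramids])
# [(1, 256, 64, 64), (1, 256, 32, 32), (1, 256, 16, 16), (1, 256, 8, 8), (1, 256, 4, 4), (1, 256, 2, 2)]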
@@ -0,0 +1,164 @@ | |||
# -------------------------------------------------------- | |||
# The implementation is also open-sourced by its author, Yang Liu, and is publicly available at
# https://github.com/damo-cv/MogFace | |||
# -------------------------------------------------------- | |||
import math | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
class conv_bn(nn.Module): | |||
"""docstring for conv""" | |||
def __init__(self, in_plane, out_plane, kernel_size, stride, padding): | |||
super(conv_bn, self).__init__() | |||
self.conv1 = nn.Conv2d( | |||
in_plane, | |||
out_plane, | |||
kernel_size=kernel_size, | |||
stride=stride, | |||
padding=padding) | |||
self.bn1 = nn.BatchNorm2d(out_plane) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
return self.bn1(x) | |||
class SSHContext(nn.Module): | |||
def __init__(self, channels, Xchannels=256): | |||
super(SSHContext, self).__init__() | |||
self.conv1 = nn.Conv2d( | |||
channels, Xchannels, kernel_size=3, stride=1, padding=1) | |||
self.conv2 = nn.Conv2d( | |||
channels, | |||
Xchannels // 2, | |||
kernel_size=3, | |||
dilation=2, | |||
stride=1, | |||
padding=2) | |||
self.conv2_1 = nn.Conv2d( | |||
Xchannels // 2, Xchannels // 2, kernel_size=3, stride=1, padding=1) | |||
self.conv2_2 = nn.Conv2d( | |||
Xchannels // 2, | |||
Xchannels // 2, | |||
kernel_size=3, | |||
dilation=2, | |||
stride=1, | |||
padding=2) | |||
self.conv2_2_1 = nn.Conv2d( | |||
Xchannels // 2, Xchannels // 2, kernel_size=3, stride=1, padding=1) | |||
def forward(self, x): | |||
x1 = F.relu(self.conv1(x), inplace=True) | |||
x2 = F.relu(self.conv2(x), inplace=True) | |||
x2_1 = F.relu(self.conv2_1(x2), inplace=True) | |||
x2_2 = F.relu(self.conv2_2(x2), inplace=True) | |||
x2_2 = F.relu(self.conv2_2_1(x2_2), inplace=True) | |||
return torch.cat([x1, x2_1, x2_2], 1) | |||
class DeepHead(nn.Module): | |||
def __init__(self, | |||
in_channel=256, | |||
out_channel=256, | |||
use_gn=False, | |||
num_conv=4): | |||
super(DeepHead, self).__init__() | |||
self.use_gn = use_gn | |||
self.num_conv = num_conv | |||
self.conv1 = nn.Conv2d(in_channel, out_channel, 3, 1, 1) | |||
self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1) | |||
self.conv3 = nn.Conv2d(out_channel, out_channel, 3, 1, 1) | |||
self.conv4 = nn.Conv2d(out_channel, out_channel, 3, 1, 1) | |||
if self.use_gn: | |||
self.gn1 = nn.GroupNorm(16, out_channel) | |||
self.gn2 = nn.GroupNorm(16, out_channel) | |||
self.gn3 = nn.GroupNorm(16, out_channel) | |||
self.gn4 = nn.GroupNorm(16, out_channel) | |||
def forward(self, x): | |||
# NOTE: conv1 is re-applied at every step below; conv2-conv4 are defined but never used.
if self.use_gn:
x1 = F.relu(self.gn1(self.conv1(x)), inplace=True) | |||
x2 = F.relu(self.gn2(self.conv1(x1)), inplace=True) | |||
x3 = F.relu(self.gn3(self.conv1(x2)), inplace=True) | |||
x4 = F.relu(self.gn4(self.conv1(x3)), inplace=True) | |||
else: | |||
x1 = F.relu(self.conv1(x), inplace=True) | |||
x2 = F.relu(self.conv1(x1), inplace=True) | |||
if self.num_conv == 2: | |||
return x2 | |||
x3 = F.relu(self.conv1(x2), inplace=True) | |||
x4 = F.relu(self.conv1(x3), inplace=True) | |||
return x4 | |||
class MogPredNet(nn.Module): | |||
def __init__(self, | |||
num_anchor_per_pixel=1, | |||
num_classes=1, | |||
input_ch_list=[256, 256, 256, 256, 256, 256], | |||
use_deep_head=True, | |||
deep_head_with_gn=True, | |||
use_ssh=True, | |||
deep_head_ch=512): | |||
super(MogPredNet, self).__init__() | |||
self.num_classes = num_classes | |||
self.use_deep_head = use_deep_head | |||
self.deep_head_with_gn = deep_head_with_gn | |||
self.use_ssh = use_ssh | |||
self.deep_head_ch = deep_head_ch | |||
if self.use_ssh: | |||
self.conv_SSH = SSHContext(input_ch_list[0], | |||
self.deep_head_ch // 2) | |||
if self.use_deep_head: | |||
if self.deep_head_with_gn: | |||
self.deep_loc_head = DeepHead( | |||
self.deep_head_ch, self.deep_head_ch, use_gn=True) | |||
self.deep_cls_head = DeepHead( | |||
self.deep_head_ch, self.deep_head_ch, use_gn=True) | |||
self.pred_cls = nn.Conv2d(self.deep_head_ch, | |||
1 * num_anchor_per_pixel, 3, 1, 1) | |||
self.pred_loc = nn.Conv2d(self.deep_head_ch, | |||
4 * num_anchor_per_pixel, 3, 1, 1) | |||
self.sigmoid = nn.Sigmoid() | |||
def forward(self, pyramid_feature_list, dsfd_ft_list=None): | |||
loc = [] | |||
conf = [] | |||
if self.use_deep_head: | |||
for x in pyramid_feature_list: | |||
if self.use_ssh: | |||
x = self.conv_SSH(x) | |||
x_cls = self.deep_cls_head(x) | |||
x_loc = self.deep_loc_head(x) | |||
conf.append( | |||
self.pred_cls(x_cls).permute(0, 2, 3, 1).contiguous()) | |||
loc.append( | |||
self.pred_loc(x_loc).permute(0, 2, 3, 1).contiguous()) | |||
loc = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1) | |||
conf = torch.cat( | |||
[o.view(o.size(0), -1, self.num_classes) for o in conf], 1) | |||
output = ( | |||
self.sigmoid(conf.view(conf.size(0), -1, self.num_classes)), | |||
loc.view(loc.size(0), -1, 4), | |||
) | |||
return output |
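MogPredNet then flattens all pyramid levels into a single anchor dimension; a quick shape sketch under the same assumption that the class is in scope:
import torch

head = MogPredNet()
pyramids = [torch.randn(1, 256, s, s) for s in (64, 32, 16, 8, 4, 2)]
with torch.no_grad():
    conf, loc = head(pyramids)
print(conf.shape, loc.shape)
# One anchor per pixel over all levels: 64*64 + 32*32 + 16*16 + 8*8 + 4*4 + 2*2 = 5460
# -> torch.Size([1, 5460, 1]) torch.Size([1, 5460, 4])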
@@ -0,0 +1,193 @@ | |||
# The implementation is modified from the original ResNet implementation, which is
# also open-sourced by its author, Yang Liu,
# and is publicly available at https://github.com/damo-cv/MogFace
import torch.nn as nn | |||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | |||
"""3x3 convolution with padding""" | |||
return nn.Conv2d( | |||
in_planes, | |||
out_planes, | |||
kernel_size=3, | |||
stride=stride, | |||
padding=dilation, | |||
groups=groups, | |||
bias=False, | |||
dilation=dilation) | |||
def conv1x1(in_planes, out_planes, stride=1): | |||
"""1x1 convolution""" | |||
return nn.Conv2d( | |||
in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||
class Bottleneck(nn.Module): | |||
expansion = 4 | |||
def __init__(self, | |||
inplanes, | |||
planes, | |||
stride=1, | |||
downsample=None, | |||
groups=1, | |||
base_width=64, | |||
dilation=1, | |||
norm_layer=None): | |||
super(Bottleneck, self).__init__() | |||
if norm_layer is None: | |||
norm_layer = nn.BatchNorm2d | |||
width = int(planes * (base_width / 64.)) * groups | |||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1 | |||
self.conv1 = conv1x1(inplanes, width) | |||
self.bn1 = norm_layer(width) | |||
self.conv2 = conv3x3(width, width, stride, groups, dilation) | |||
self.bn2 = norm_layer(width) | |||
self.conv3 = conv1x1(width, planes * self.expansion) | |||
self.bn3 = norm_layer(planes * self.expansion) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.downsample = downsample | |||
self.stride = stride | |||
def forward(self, x): | |||
identity = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
out = self.relu(out) | |||
out = self.conv3(out) | |||
out = self.bn3(out) | |||
if self.downsample is not None: | |||
identity = self.downsample(x) | |||
out += identity | |||
out = self.relu(out) | |||
return out | |||
class ResNet(nn.Module): | |||
def __init__(self, | |||
depth=50, | |||
groups=1, | |||
width_per_group=64, | |||
replace_stride_with_dilation=None, | |||
norm_layer=None, | |||
inplanes=64, | |||
shrink_ch_ratio=1): | |||
super(ResNet, self).__init__() | |||
if norm_layer is None: | |||
norm_layer = nn.BatchNorm2d | |||
self._norm_layer = norm_layer | |||
if depth == 50: | |||
block = Bottleneck | |||
layers = [3, 4, 6, 3] | |||
elif depth == 101: | |||
block = Bottleneck | |||
layers = [3, 4, 23, 3] | |||
elif depth == 152: | |||
block = Bottleneck | |||
layers = [3, 4, 36, 3] | |||
else:
# BasicBlock (needed for depth=18) is not defined in this file,
# so only bottleneck depths are supported here.
raise ValueError('only support depth in [50, 101, 152]')
shrink_input_ch = int(inplanes * shrink_ch_ratio) | |||
self.inplanes = int(inplanes * shrink_ch_ratio) | |||
if shrink_ch_ratio == 0.125: | |||
layers = [2, 3, 3, 3] | |||
self.dilation = 1 | |||
if replace_stride_with_dilation is None: | |||
# each element in the tuple indicates if we should replace | |||
# the 2x2 stride with a dilated convolution instead | |||
replace_stride_with_dilation = [False, False, False] | |||
if len(replace_stride_with_dilation) != 3: | |||
raise ValueError('replace_stride_with_dilation should be None ' | |||
'or a 3-element tuple, got {}'.format( | |||
replace_stride_with_dilation)) | |||
self.groups = groups | |||
self.base_width = width_per_group | |||
self.conv1 = nn.Conv2d( | |||
3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) | |||
self.bn1 = norm_layer(self.inplanes) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
self.layer1 = self._make_layer(block, shrink_input_ch, layers[0]) | |||
self.layer2 = self._make_layer( | |||
block, | |||
shrink_input_ch * 2, | |||
layers[1], | |||
stride=2, | |||
dilate=replace_stride_with_dilation[0]) | |||
self.layer3 = self._make_layer( | |||
block, | |||
shrink_input_ch * 4, | |||
layers[2], | |||
stride=2, | |||
dilate=replace_stride_with_dilation[1]) | |||
self.layer4 = self._make_layer( | |||
block, | |||
shrink_input_ch * 8, | |||
layers[3], | |||
stride=2, | |||
dilate=replace_stride_with_dilation[2]) | |||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | |||
norm_layer = self._norm_layer | |||
downsample = None | |||
previous_dilation = self.dilation | |||
if dilate: | |||
self.dilation *= stride | |||
stride = 1 | |||
if stride != 1 or self.inplanes != planes * block.expansion: | |||
downsample = nn.Sequential( | |||
conv1x1(self.inplanes, planes * block.expansion, stride), | |||
norm_layer(planes * block.expansion), | |||
) | |||
layers = [] | |||
layers.append( | |||
block(self.inplanes, planes, stride, downsample, self.groups, | |||
self.base_width, previous_dilation, norm_layer)) | |||
self.inplanes = planes * block.expansion | |||
for _ in range(1, blocks): | |||
layers.append( | |||
block( | |||
self.inplanes, | |||
planes, | |||
groups=self.groups, | |||
base_width=self.base_width, | |||
dilation=self.dilation, | |||
norm_layer=norm_layer)) | |||
return nn.Sequential(*layers) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.bn1(x) | |||
x = self.relu(x) | |||
x = self.maxpool(x) | |||
four_conv_layer = [] | |||
x = self.layer1(x) | |||
four_conv_layer.append(x) | |||
x = self.layer2(x) | |||
four_conv_layer.append(x) | |||
x = self.layer3(x) | |||
four_conv_layer.append(x) | |||
x = self.layer4(x) | |||
four_conv_layer.append(x) | |||
return four_conv_layer |
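For completeness, the backbone contract that the two sketches above assume (four feature maps at strides 4/8/16/32 with 256/512/1024/2048 channels) can be spot-checked the same way, assuming ResNet is in scope:
import torch

backbone = ResNet(depth=101)
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 256, 256))
print([tuple(f.shape) for f in feats])
# [(1, 256, 64, 64), (1, 512, 32, 32), (1, 1024, 16, 16), (1, 2048, 8, 8)]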
@@ -0,0 +1,212 @@ | |||
# Modified from https://github.com/biubug6/Pytorch_Retinaface | |||
import math | |||
from itertools import product as product | |||
from math import ceil | |||
import numpy as np | |||
import torch | |||
def transform_anchor(anchors): | |||
""" | |||
from [x0, y0, x1, y1] to [c_x, c_y, w, h]
x1 = x0 + w - 1 | |||
c_x = (x0 + x1) / 2 = (2x0 + w - 1) / 2 = x0 + (w - 1) / 2 | |||
""" | |||
return np.concatenate(((anchors[:, :2] + anchors[:, 2:]) / 2, | |||
anchors[:, 2:] - anchors[:, :2] + 1), | |||
axis=1) | |||
def normalize_anchor(anchors): | |||
""" | |||
from [c_x, c_y, w, h] to [x0, y0, x1, y1]
""" | |||
item_1 = anchors[:, :2] - (anchors[:, 2:] - 1) / 2 | |||
item_2 = anchors[:, :2] + (anchors[:, 2:] - 1) / 2 | |||
return np.concatenate((item_1, item_2), axis=1) | |||
class MogPriorBox(object): | |||
""" | |||
generates anchors for every FPN level (single-layer use is untested)
return (np.ndarray) [num_anchors, 4] in [c_x, c_y, w, h] form
""" | |||
def __init__(self, | |||
scale_list=[1.], | |||
aspect_ratio_list=[1.0], | |||
stride_list=[4, 8, 16, 32, 64, 128], | |||
anchor_size_list=[16, 32, 64, 128, 256, 512]): | |||
self.scale_list = scale_list | |||
self.aspect_ratio_list = aspect_ratio_list | |||
self.stride_list = stride_list | |||
self.anchor_size_list = anchor_size_list | |||
def __call__(self, img_height, img_width): | |||
final_anchor_list = [] | |||
for idx, stride in enumerate(self.stride_list): | |||
anchor_list = [] | |||
cur_img_height = img_height | |||
cur_img_width = img_width | |||
tmp_stride = stride | |||
while tmp_stride != 1: | |||
tmp_stride = tmp_stride // 2 | |||
cur_img_height = (cur_img_height + 1) // 2 | |||
cur_img_width = (cur_img_width + 1) // 2 | |||
for i in range(cur_img_height): | |||
for j in range(cur_img_width): | |||
for scale in self.scale_list: | |||
cx = (j + 0.5) * stride | |||
cy = (i + 0.5) * stride | |||
side_x = self.anchor_size_list[idx] * scale | |||
side_y = self.anchor_size_list[idx] * scale | |||
for ratio in self.aspect_ratio_list: | |||
anchor_list.append([ | |||
cx, cy, side_x / math.sqrt(ratio), | |||
side_y * math.sqrt(ratio) | |||
]) | |||
final_anchor_list.append(anchor_list) | |||
final_anchor_arr = np.concatenate(final_anchor_list, axis=0) | |||
normalized_anchor_arr = normalize_anchor(final_anchor_arr).astype( | |||
'float32') | |||
transformed_anchor = transform_anchor(normalized_anchor_arr) | |||
return transformed_anchor | |||
class PriorBox(object): | |||
def __init__(self, cfg, image_size=None, phase='train'): | |||
super(PriorBox, self).__init__() | |||
self.min_sizes = cfg['min_sizes'] | |||
self.steps = cfg['steps'] | |||
self.clip = cfg['clip'] | |||
self.image_size = image_size | |||
self.feature_maps = [[ | |||
ceil(self.image_size[0] / step), | |||
ceil(self.image_size[1] / step) | |||
] for step in self.steps] | |||
self.name = 's' | |||
def forward(self): | |||
anchors = [] | |||
for k, f in enumerate(self.feature_maps): | |||
min_sizes = self.min_sizes[k] | |||
for i, j in product(range(f[0]), range(f[1])): | |||
for min_size in min_sizes: | |||
s_kx = min_size / self.image_size[1] | |||
s_ky = min_size / self.image_size[0] | |||
dense_cx = [ | |||
x * self.steps[k] / self.image_size[1] | |||
for x in [j + 0.5] | |||
] | |||
dense_cy = [ | |||
y * self.steps[k] / self.image_size[0] | |||
for y in [i + 0.5] | |||
] | |||
for cy, cx in product(dense_cy, dense_cx): | |||
anchors += [cx, cy, s_kx, s_ky] | |||
# back to torch land | |||
output = torch.Tensor(anchors).view(-1, 4) | |||
if self.clip: | |||
output.clamp_(max=1, min=0) | |||
return output | |||
def py_cpu_nms(dets, thresh): | |||
"""Pure Python NMS baseline.""" | |||
x1 = dets[:, 0] | |||
y1 = dets[:, 1] | |||
x2 = dets[:, 2] | |||
y2 = dets[:, 3] | |||
scores = dets[:, 4] | |||
areas = (x2 - x1 + 1) * (y2 - y1 + 1) | |||
order = scores.argsort()[::-1] | |||
keep = [] | |||
while order.size > 0: | |||
i = order[0] | |||
keep.append(i) | |||
xx1 = np.maximum(x1[i], x1[order[1:]]) | |||
yy1 = np.maximum(y1[i], y1[order[1:]]) | |||
xx2 = np.minimum(x2[i], x2[order[1:]]) | |||
yy2 = np.minimum(y2[i], y2[order[1:]]) | |||
w = np.maximum(0.0, xx2 - xx1 + 1) | |||
h = np.maximum(0.0, yy2 - yy1 + 1) | |||
inter = w * h | |||
ovr = inter / (areas[i] + areas[order[1:]] - inter) | |||
inds = np.where(ovr <= thresh)[0] | |||
order = order[inds + 1] | |||
return keep | |||
def mogdecode(loc, anchors): | |||
""" | |||
loc: torch.Tensor | |||
anchors: 2-d, torch.Tensor (cx, cy, w, h) | |||
boxes: 2-d, torch.Tensor (x0, y0, x1, y1) | |||
""" | |||
boxes = torch.cat((anchors[:, :2] + loc[:, :2] * anchors[:, 2:], | |||
anchors[:, 2:] * torch.exp(loc[:, 2:])), 1) | |||
boxes[:, 0] -= (boxes[:, 2] - 1) / 2 | |||
boxes[:, 1] -= (boxes[:, 3] - 1) / 2 | |||
boxes[:, 2] += boxes[:, 0] - 1 | |||
boxes[:, 3] += boxes[:, 1] - 1 | |||
return boxes | |||
# Adapted from https://github.com/Hakuyume/chainer-ssd | |||
def decode(loc, priors, variances): | |||
"""Decode locations from predictions using priors to undo | |||
the encoding we did for offset regression at train time. | |||
Args: | |||
loc (tensor): location predictions for loc layers, | |||
Shape: [num_priors,4] | |||
priors (tensor): Prior boxes in center-offset form. | |||
Shape: [num_priors,4]. | |||
variances: (list[float]) Variances of priorboxes | |||
Return: | |||
decoded bounding box predictions | |||
""" | |||
boxes = torch.cat( | |||
(priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], | |||
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) | |||
boxes[:, :2] -= boxes[:, 2:] / 2 | |||
boxes[:, 2:] += boxes[:, :2] | |||
return boxes | |||
def decode_landm(pre, priors, variances): | |||
"""Decode landm from predictions using priors to undo | |||
the encoding we did for offset regression at train time. | |||
Args: | |||
pre (tensor): landm predictions for loc layers, | |||
Shape: [num_priors,10] | |||
priors (tensor): Prior boxes in center-offset form. | |||
Shape: [num_priors,4]. | |||
variances: (list[float]) Variances of priorboxes | |||
Return: | |||
decoded landm predictions | |||
""" | |||
a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:] | |||
b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:] | |||
c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:] | |||
d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:] | |||
e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:] | |||
landms = torch.cat((a, b, c, d, e), dim=1) | |||
return landms |
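A small self-contained check of the two helpers the detector actually uses, py_cpu_nms and mogdecode (box values are made up; assumes both functions are in scope):
import numpy as np
import torch

# Two heavily overlapping boxes and one separate box; rows are [x0, y0, x1, y1, score].
dets = np.array([[10, 10, 50, 50, 0.95],
                 [12, 12, 52, 52, 0.90],
                 [100, 100, 140, 140, 0.80]], dtype=np.float32)
print([int(i) for i in py_cpu_nms(dets, thresh=0.4)])   # [0, 2]: the lower-scored duplicate is suppressed

# With zero offsets, each anchor decodes onto itself in corner form.
anchors = torch.tensor([[32.0, 32.0, 16.0, 16.0]])      # [cx, cy, w, h]
print(mogdecode(torch.zeros(1, 4), anchors))            # -> [[24.5, 24.5, 39.5, 39.5]]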
@@ -48,6 +48,7 @@ if TYPE_CHECKING: | |||
from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline | |||
from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline | |||
from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | |||
from .mog_face_detection_pipeline import MogFaceDetectionPipeline | |||
from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline | |||
from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline | |||
from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline | |||
@@ -112,6 +113,7 @@ else: | |||
['TextDrivenSegmentationPipeline'], | |||
'movie_scene_segmentation_pipeline': | |||
['MovieSceneSegmentationPipeline'], | |||
'mog_face_detection_pipeline': ['MogFaceDetectionPipeline'], | |||
'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'], | |||
'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'], | |||
'facial_expression_recognition_pipelin': | |||
@@ -0,0 +1,54 @@ | |||
import os.path as osp | |||
from typing import Any, Dict | |||
import numpy as np | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.cv.face_detection import MogFaceDetector | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.face_detection, module_name=Pipelines.mog_face_detection) | |||
class MogFaceDetectionPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` to create a face detection pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__(model=model, **kwargs) | |||
ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) | |||
logger.info(f'loading model from {ckpt_path}') | |||
detector = MogFaceDetector(model_path=ckpt_path, device=self.device) | |||
self.detector = detector | |||
logger.info('load model done') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
img = LoadImage.convert_to_ndarray(input) | |||
img = img.astype(np.float32) | |||
result = {'img': img} | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
result = self.detector(input) | |||
assert result is not None | |||
bboxes = result[:, :4].tolist() | |||
scores = result[:, 4].tolist() | |||
return { | |||
OutputKeys.SCORES: scores, | |||
OutputKeys.BOXES: bboxes, | |||
OutputKeys.KEYPOINTS: None, | |||
} | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
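A hedged usage sketch of the new pipeline and its output dictionary (model id and image path are taken from the test below):
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

mog_face = pipeline(Tasks.face_detection,
                    model='damo/cv_resnet101_face-detection_cvpr22papermogface')
result = mog_face('data/test/images/mog_face_detection.jpg')
for box, score in zip(result[OutputKeys.BOXES], result[OutputKeys.SCORES]):
    print(box, score)   # box: [x0, y0, x1, y1] in pixels; score: confidence above the 0.82 threshold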
@@ -0,0 +1,33 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
import unittest | |||
import cv2 | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result | |||
from modelscope.utils.test_utils import test_level | |||
class MogFaceDetectionTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/cv_resnet101_face-detection_cvpr22papermogface' | |||
def show_result(self, img_path, detection_result): | |||
img = draw_face_detection_no_lm_result(img_path, detection_result) | |||
cv2.imwrite('result.png', img) | |||
print(f'output written to {osp.abspath("result.png")}') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_modelhub(self): | |||
face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||
img_path = 'data/test/images/mog_face_detection.jpg' | |||
result = face_detection(img_path) | |||
self.show_result(img_path, result) | |||
if __name__ == '__main__': | |||
unittest.main() |