ly261666 (yingda.chen) committed 3 years ago · commit 8f05fa8cf1
13 changed files with 898 additions and 2 deletions
  1. data/test/images/mog_face_detection.jpg (+3, -0)
  2. modelscope/metainfo.py (+2, -0)
  3. modelscope/models/cv/face_detection/__init__.py (+3, -2)
  4. modelscope/models/cv/face_detection/mogface/__init__.py (+1, -0)
  5. modelscope/models/cv/face_detection/mogface/models/__init__.py (+0, -0)
  6. modelscope/models/cv/face_detection/mogface/models/detectors.py (+96, -0)
  7. modelscope/models/cv/face_detection/mogface/models/mogface.py (+135, -0)
  8. modelscope/models/cv/face_detection/mogface/models/mogprednet.py (+164, -0)
  9. modelscope/models/cv/face_detection/mogface/models/resnet.py (+193, -0)
  10. modelscope/models/cv/face_detection/mogface/models/utils.py (+212, -0)
  11. modelscope/pipelines/cv/__init__.py (+2, -0)
  12. modelscope/pipelines/cv/mog_face_detection_pipeline.py (+54, -0)
  13. tests/pipelines/test_mog_face_detection.py (+33, -0)

data/test/images/mog_face_detection.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
size 87228

modelscope/metainfo.py (+2, -0)

@@ -35,6 +35,7 @@ class Models(object):
    fer = 'fer'
    retinaface = 'retinaface'
    shop_segmentation = 'shop-segmentation'
    mogface = 'mogface'
    mtcnn = 'mtcnn'
    ulfd = 'ulfd'

@@ -128,6 +129,7 @@ class Pipelines(object):
    ulfd_face_detection = 'manual-face-detection-ulfd'
    facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
    retina_face_detection = 'resnet50-face-detection-retinaface'
    mog_face_detection = 'resnet101-face-detection-cvpr22papermogface'
    mtcnn_face_detection = 'manual-face-detection-mtcnn'
    live_category = 'live-category'
    general_image_classification = 'vit-base_image-classification_ImageNet-labels'


modelscope/models/cv/face_detection/__init__.py (+3, -2)

@@ -4,15 +4,16 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .mogface import MogFaceDetector
    from .mtcnn import MtcnnFaceDetector
    from .retinaface import RetinaFaceDetection
    from .ulfd_slim import UlfdFaceDetector

else:
    _import_structure = {
        'ulfd_slim': ['UlfdFaceDetector'],
        'retinaface': ['RetinaFaceDetection'],
        'mtcnn': ['MtcnnFaceDetector']
        'mtcnn': ['MtcnnFaceDetector'],
        'mogface': ['MogFaceDetector']
    }

    import sys


modelscope/models/cv/face_detection/mogface/__init__.py (+1, -0)

@@ -0,0 +1 @@
from .models.detectors import MogFaceDetector

modelscope/models/cv/face_detection/mogface/models/__init__.py (+0, -0)


modelscope/models/cv/face_detection/mogface/models/detectors.py (+96, -0)

@@ -0,0 +1,96 @@
import os

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from .mogface import MogFace
from .utils import MogPriorBox, mogdecode, py_cpu_nms


@MODELS.register_module(Tasks.face_detection, module_name=Models.mogface)
class MogFaceDetector(TorchModel):

    def __init__(self, model_path, device='cuda'):
        super().__init__(model_path)
        torch.set_grad_enabled(False)
        cudnn.benchmark = True
        self.model_path = model_path
        self.device = device
        self.net = MogFace()
        self.load_model()
        self.net = self.net.to(device)

        self.mean = np.array([[104, 117, 123]])

    def load_model(self, load_to_cpu=False):
        pretrained_dict = torch.load(
            self.model_path, map_location=torch.device('cpu'))
        self.net.load_state_dict(pretrained_dict, strict=False)
        self.net.eval()

    def forward(self, input):
        img_raw = input['img']
        img = np.array(img_raw.cpu().detach())
        img = img[:, :, ::-1]

        im_height, im_width = img.shape[:2]
        ss = 1.0
        # downscale very large images so the longer side becomes ~1000 px
        if max(im_height, im_width) > 1500:
            ss = 1000.0 / max(im_height, im_width)
            img = cv2.resize(img, (0, 0), fx=ss, fy=ss)
            im_height, im_width = img.shape[:2]

        scale = torch.Tensor(
            [img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
        img -= np.array([[103.53, 116.28, 123.675]])
        img /= np.array([[57.375, 57.120003, 58.395]])
        img /= 255
        img = img[:, :, ::-1].copy()
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.to(self.device)
        scale = scale.to(self.device)

        conf, loc = self.net(img)  # forward pass

        confidence_threshold = 0.82
        nms_threshold = 0.4
        top_k = 5000
        keep_top_k = 750

        priorbox = MogPriorBox(scale_list=[0.68])
        priors = priorbox(im_height, im_width)
        priors = torch.tensor(priors).to(self.device)
        prior_data = priors.data

        boxes = mogdecode(loc.data.squeeze(0), prior_data)
        boxes = boxes.cpu().numpy()
        scores = conf.squeeze(0).data.cpu().numpy()[:, 0]

        # ignore low scores
        inds = np.where(scores > confidence_threshold)[0]
        boxes = boxes[inds]
        scores = scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1][:top_k]
        boxes = boxes[order]
        scores = scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(
            np.float32, copy=False)
        keep = py_cpu_nms(dets, nms_threshold)
        dets = dets[keep, :]

        # keep top-K after NMS
        dets = dets[:keep_top_k, :]

        return dets / ss
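
A minimal sketch of driving the new detector class directly, outside the pipeline wiring. The checkpoint path below is a placeholder, and the HxWx3 RGB float-tensor input format is inferred from `MogFaceDetectionPipeline.preprocess` further down in this commit.

# Sketch only: placeholder checkpoint path; input format mirrors the pipeline's preprocess.
import cv2
import numpy as np
import torch

from modelscope.models.cv.face_detection import MogFaceDetector

detector = MogFaceDetector(
    model_path='/path/to/pytorch_model.pt',  # placeholder, not part of this commit
    device='cuda' if torch.cuda.is_available() else 'cpu')

bgr = cv2.imread('data/test/images/mog_face_detection.jpg')
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32)

# dets is an (N, 5) array: x0, y0, x1, y1, score (already thresholded and NMS-filtered)
dets = detector({'img': torch.from_numpy(rgb)})
print(dets.shape, dets[:, 4])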

modelscope/models/cv/face_detection/mogface/models/mogface.py (+135, -0)

@@ -0,0 +1,135 @@
# --------------------------------------------------------
# The implementation is also open-sourced by the author, Yang Liu, and is
# publicly available at https://github.com/damo-cv/MogFace
# --------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

from .mogprednet import MogPredNet
from .resnet import ResNet


class MogFace(nn.Module):

    def __init__(self):
        super(MogFace, self).__init__()
        self.backbone = ResNet(depth=101)
        self.fpn = LFPN()
        self.pred_net = MogPredNet()

    def forward(self, x):
        feature_list = self.backbone(x)
        fpn_list = self.fpn(feature_list)
        pyramid_feature_list = fpn_list[0]
        conf, loc = self.pred_net(pyramid_feature_list)
        return conf, loc


class FeatureFusion(nn.Module):

    def __init__(self, lat_ch=256, **channels):
        super(FeatureFusion, self).__init__()
        self.main_conv = nn.Conv2d(channels['main'], lat_ch, kernel_size=1)

    def forward(self, up, main):
        main = self.main_conv(main)
        _, _, H, W = main.size()
        res = F.upsample(up, scale_factor=2, mode='bilinear')
        if res.size(2) != main.size(2) or res.size(3) != main.size(3):
            res = res[:, :, 0:H, 0:W]
        res = res + main
        return res


class LFPN(nn.Module):

    def __init__(self,
                 c2_out_ch=256,
                 c3_out_ch=512,
                 c4_out_ch=1024,
                 c5_out_ch=2048,
                 c6_mid_ch=512,
                 c6_out_ch=512,
                 c7_mid_ch=128,
                 c7_out_ch=256,
                 out_dsfd_ft=True):
        super(LFPN, self).__init__()
        self.out_dsfd_ft = out_dsfd_ft
        if self.out_dsfd_ft:
            dsfd_module = []
            dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1))
            dsfd_module.append(nn.Conv2d(512, 256, kernel_size=3, padding=1))
            dsfd_module.append(nn.Conv2d(1024, 256, kernel_size=3, padding=1))
            dsfd_module.append(nn.Conv2d(2048, 256, kernel_size=3, padding=1))
            dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1))
            dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1))
            self.dsfd_modules = nn.ModuleList(dsfd_module)

        c6_input_ch = c5_out_ch
        self.c6 = nn.Sequential(*[
            nn.Conv2d(
                c6_input_ch,
                c6_mid_ch,
                kernel_size=1,
            ),
            nn.BatchNorm2d(c6_mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                c6_mid_ch, c6_out_ch, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(c6_out_ch),
            nn.ReLU(inplace=True)
        ])
        self.c7 = nn.Sequential(*[
            nn.Conv2d(
                c6_out_ch,
                c7_mid_ch,
                kernel_size=1,
            ),
            nn.BatchNorm2d(c7_mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                c7_mid_ch, c7_out_ch, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(c7_out_ch),
            nn.ReLU(inplace=True)
        ])

        self.p2_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.p3_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.p4_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.c5_lat = nn.Conv2d(c6_input_ch, 256, kernel_size=3, padding=1)
        self.c6_lat = nn.Conv2d(c6_out_ch, 256, kernel_size=3, padding=1)
        self.c7_lat = nn.Conv2d(c7_out_ch, 256, kernel_size=3, padding=1)

        self.ff_c5_c4 = FeatureFusion(main=c4_out_ch)
        self.ff_c4_c3 = FeatureFusion(main=c3_out_ch)
        self.ff_c3_c2 = FeatureFusion(main=c2_out_ch)

    def forward(self, feature_list):
        c2, c3, c4, c5 = feature_list
        c6 = self.c6(c5)
        c7 = self.c7(c6)

        c5 = self.c5_lat(c5)
        c6 = self.c6_lat(c6)
        c7 = self.c7_lat(c7)

        if self.out_dsfd_ft:
            dsfd_fts = []
            dsfd_fts.append(self.dsfd_modules[0](c2))
            dsfd_fts.append(self.dsfd_modules[1](c3))
            dsfd_fts.append(self.dsfd_modules[2](c4))
            dsfd_fts.append(self.dsfd_modules[3](feature_list[-1]))
            dsfd_fts.append(self.dsfd_modules[4](c6))
            dsfd_fts.append(self.dsfd_modules[5](c7))

        p4 = self.ff_c5_c4(c5, c4)
        p3 = self.ff_c4_c3(p4, c3)
        p2 = self.ff_c3_c2(p3, c2)

        p2 = self.p2_lat(p2)
        p3 = self.p3_lat(p3)
        p4 = self.p4_lat(p4)

        if self.out_dsfd_ft:
            return ([p2, p3, p4, c5, c6, c7], dsfd_fts)
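
A standalone sketch (plain torch, no modelscope imports) of the top-down fusion step that `FeatureFusion` implements: upsample the coarser map by 2x, project the finer backbone map to 256 channels with a 1x1 lateral conv, crop to match, and add. Channel sizes follow the LFPN defaults above; the snippet uses `F.interpolate`, the non-deprecated equivalent of the `F.upsample` call in the diff.

import torch
import torch.nn as nn
import torch.nn.functional as F

lat_conv = nn.Conv2d(1024, 256, kernel_size=1)  # e.g. c4 (1024 ch) -> 256

up = torch.randn(1, 256, 20, 20)     # coarser pyramid level
main = torch.randn(1, 1024, 40, 40)  # finer backbone level (e.g. c4)

main = lat_conv(main)
res = F.interpolate(up, scale_factor=2, mode='bilinear', align_corners=False)
res = res[:, :, :main.size(2), :main.size(3)]  # crop if sizes disagree
fused = res + main
print(fused.shape)  # torch.Size([1, 256, 40, 40])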

modelscope/models/cv/face_detection/mogface/models/mogprednet.py (+164, -0)

@@ -0,0 +1,164 @@
# --------------------------------------------------------
# The implementation is also open-sourced by the author, Yang Liu, and is
# publicly available at https://github.com/damo-cv/MogFace
# --------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class conv_bn(nn.Module):
    """Convolution followed by batch normalization."""

    def __init__(self, in_plane, out_plane, kernel_size, stride, padding):
        super(conv_bn, self).__init__()
        self.conv1 = nn.Conv2d(
            in_plane,
            out_plane,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)
        self.bn1 = nn.BatchNorm2d(out_plane)

    def forward(self, x):
        x = self.conv1(x)
        return self.bn1(x)


class SSHContext(nn.Module):

    def __init__(self, channels, Xchannels=256):
        super(SSHContext, self).__init__()

        self.conv1 = nn.Conv2d(
            channels, Xchannels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(
            channels,
            Xchannels // 2,
            kernel_size=3,
            dilation=2,
            stride=1,
            padding=2)
        self.conv2_1 = nn.Conv2d(
            Xchannels // 2, Xchannels // 2, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(
            Xchannels // 2,
            Xchannels // 2,
            kernel_size=3,
            dilation=2,
            stride=1,
            padding=2)
        self.conv2_2_1 = nn.Conv2d(
            Xchannels // 2, Xchannels // 2, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x1 = F.relu(self.conv1(x), inplace=True)
        x2 = F.relu(self.conv2(x), inplace=True)
        x2_1 = F.relu(self.conv2_1(x2), inplace=True)
        x2_2 = F.relu(self.conv2_2(x2), inplace=True)
        x2_2 = F.relu(self.conv2_2_1(x2_2), inplace=True)

        return torch.cat([x1, x2_1, x2_2], 1)


class DeepHead(nn.Module):

    def __init__(self,
                 in_channel=256,
                 out_channel=256,
                 use_gn=False,
                 num_conv=4):
        super(DeepHead, self).__init__()
        self.use_gn = use_gn
        self.num_conv = num_conv
        self.conv1 = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
        self.conv3 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
        self.conv4 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
        if self.use_gn:
            self.gn1 = nn.GroupNorm(16, out_channel)
            self.gn2 = nn.GroupNorm(16, out_channel)
            self.gn3 = nn.GroupNorm(16, out_channel)
            self.gn4 = nn.GroupNorm(16, out_channel)

    def forward(self, x):
        if self.use_gn:
            x1 = F.relu(self.gn1(self.conv1(x)), inplace=True)
            x2 = F.relu(self.gn2(self.conv1(x1)), inplace=True)
            x3 = F.relu(self.gn3(self.conv1(x2)), inplace=True)
            x4 = F.relu(self.gn4(self.conv1(x3)), inplace=True)
        else:
            x1 = F.relu(self.conv1(x), inplace=True)
            x2 = F.relu(self.conv1(x1), inplace=True)
            if self.num_conv == 2:
                return x2
            x3 = F.relu(self.conv1(x2), inplace=True)
            x4 = F.relu(self.conv1(x3), inplace=True)

        return x4


class MogPredNet(nn.Module):

    def __init__(self,
                 num_anchor_per_pixel=1,
                 num_classes=1,
                 input_ch_list=[256, 256, 256, 256, 256, 256],
                 use_deep_head=True,
                 deep_head_with_gn=True,
                 use_ssh=True,
                 deep_head_ch=512):
        super(MogPredNet, self).__init__()
        self.num_classes = num_classes
        self.use_deep_head = use_deep_head
        self.deep_head_with_gn = deep_head_with_gn

        self.use_ssh = use_ssh

        self.deep_head_ch = deep_head_ch

        if self.use_ssh:
            self.conv_SSH = SSHContext(input_ch_list[0],
                                       self.deep_head_ch // 2)

        if self.use_deep_head:
            if self.deep_head_with_gn:
                self.deep_loc_head = DeepHead(
                    self.deep_head_ch, self.deep_head_ch, use_gn=True)
                self.deep_cls_head = DeepHead(
                    self.deep_head_ch, self.deep_head_ch, use_gn=True)

            self.pred_cls = nn.Conv2d(self.deep_head_ch,
                                      1 * num_anchor_per_pixel, 3, 1, 1)
            self.pred_loc = nn.Conv2d(self.deep_head_ch,
                                      4 * num_anchor_per_pixel, 3, 1, 1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, pyramid_feature_list, dsfd_ft_list=None):
        loc = []
        conf = []

        if self.use_deep_head:
            for x in pyramid_feature_list:
                if self.use_ssh:
                    x = self.conv_SSH(x)
                x_cls = self.deep_cls_head(x)
                x_loc = self.deep_loc_head(x)

                conf.append(
                    self.pred_cls(x_cls).permute(0, 2, 3, 1).contiguous())
                loc.append(
                    self.pred_loc(x_loc).permute(0, 2, 3, 1).contiguous())

        loc = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1)
        conf = torch.cat(
            [o.view(o.size(0), -1, self.num_classes) for o in conf], 1)
        output = (
            self.sigmoid(conf.view(conf.size(0), -1, self.num_classes)),
            loc.view(loc.size(0), -1, 4),
        )

        return output
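
A small standalone sketch of the permute/reshape bookkeeping used above: per-level prediction maps of shape (N, 4*A, H, W) become (N, H*W*A, 4) and are then concatenated across pyramid levels. The level sizes below are illustrative only.

import torch

num_anchor_per_pixel = 1
levels = [(160, 160), (80, 80), (40, 40)]  # example H, W per pyramid level

loc = []
for h, w in levels:
    pred = torch.randn(2, 4 * num_anchor_per_pixel, h, w)  # pred_loc output
    loc.append(pred.permute(0, 2, 3, 1).contiguous())       # (N, H, W, 4*A)

loc = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1)
print(loc.shape)  # torch.Size([2, 33600, 4]): 160*160 + 80*80 + 40*40 anchors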

modelscope/models/cv/face_detection/mogface/models/resnet.py (+193, -0)

@@ -0,0 +1,193 @@
# The implementation is modified from the original ResNet implementation, which is
# also open-sourced by the author, Yang Liu,
# and is publicly available at https://github.com/damo-cv/MogFace

import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 depth=50,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None,
                 inplanes=64,
                 shrink_ch_ratio=1):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        if depth == 50:
            block = Bottleneck
            layers = [3, 4, 6, 3]
        elif depth == 101:
            block = Bottleneck
            layers = [3, 4, 23, 3]
        elif depth == 152:
            block = Bottleneck
            layers = [3, 4, 36, 3]
        elif depth == 18:
            block = BasicBlock
            layers = [2, 2, 2, 2]
        else:
            raise ValueError('only support depth in [18, 50, 101, 152]')

        shrink_input_ch = int(inplanes * shrink_ch_ratio)
        self.inplanes = int(inplanes * shrink_ch_ratio)
        if shrink_ch_ratio == 0.125:
            layers = [2, 3, 3, 3]

        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError('replace_stride_with_dilation should be None '
                             'or a 3-element tuple, got {}'.format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, shrink_input_ch, layers[0])
        self.layer2 = self._make_layer(
            block,
            shrink_input_ch * 2,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block,
            shrink_input_ch * 4,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block,
            shrink_input_ch * 8,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2])

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        four_conv_layer = []
        x = self.layer1(x)
        four_conv_layer.append(x)
        x = self.layer2(x)
        four_conv_layer.append(x)
        x = self.layer3(x)
        four_conv_layer.append(x)
        x = self.layer4(x)
        four_conv_layer.append(x)

        return four_conv_layer
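
A quick sanity-check sketch (import path as added in this commit; randomly initialized weights, shapes only): the backbone returns the four stage outputs c2..c5 that LFPN consumes, with strides 4/8/16/32 and channels 256/512/1024/2048 for the Bottleneck depths.

import torch

from modelscope.models.cv.face_detection.mogface.models.resnet import ResNet

backbone = ResNet(depth=50)  # MogFace itself uses depth=101; 50 is lighter for a shape check
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
for f in feats:
    print(tuple(f.shape))
# (1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)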

modelscope/models/cv/face_detection/mogface/models/utils.py (+212, -0)

@@ -0,0 +1,212 @@
# Modified from https://github.com/biubug6/Pytorch_Retinaface

import math
from itertools import product as product
from math import ceil

import numpy as np
import torch


def transform_anchor(anchors):
    """
    Convert anchors from corner form [x0, y0, x1, y1] to center form [cx, cy, w, h].
    x1 = x0 + w - 1
    cx = (x0 + x1) / 2 = (2 * x0 + w - 1) / 2 = x0 + (w - 1) / 2
    """
    return np.concatenate(((anchors[:, :2] + anchors[:, 2:]) / 2,
                           anchors[:, 2:] - anchors[:, :2] + 1),
                          axis=1)


def normalize_anchor(anchors):
    """
    Convert anchors from center form [cx, cy, w, h] to corner form [x0, y0, x1, y1].
    """
    item_1 = anchors[:, :2] - (anchors[:, 2:] - 1) / 2
    item_2 = anchors[:, :2] + (anchors[:, 2:] - 1) / 2
    return np.concatenate((item_1, item_2), axis=1)


class MogPriorBox(object):
    """
    Generates priors for both the FPN and single-layer settings (the
    single-layer case is untested).
    Returns (np.array) [num_anchors, 4] in center form [cx, cy, w, h].
    """

    def __init__(self,
                 scale_list=[1.],
                 aspect_ratio_list=[1.0],
                 stride_list=[4, 8, 16, 32, 64, 128],
                 anchor_size_list=[16, 32, 64, 128, 256, 512]):
        self.scale_list = scale_list
        self.aspect_ratio_list = aspect_ratio_list
        self.stride_list = stride_list
        self.anchor_size_list = anchor_size_list

    def __call__(self, img_height, img_width):
        final_anchor_list = []

        for idx, stride in enumerate(self.stride_list):
            anchor_list = []
            cur_img_height = img_height
            cur_img_width = img_width
            tmp_stride = stride

            while tmp_stride != 1:
                tmp_stride = tmp_stride // 2
                cur_img_height = (cur_img_height + 1) // 2
                cur_img_width = (cur_img_width + 1) // 2

            for i in range(cur_img_height):
                for j in range(cur_img_width):
                    for scale in self.scale_list:
                        cx = (j + 0.5) * stride
                        cy = (i + 0.5) * stride
                        side_x = self.anchor_size_list[idx] * scale
                        side_y = self.anchor_size_list[idx] * scale
                        for ratio in self.aspect_ratio_list:
                            anchor_list.append([
                                cx, cy, side_x / math.sqrt(ratio),
                                side_y * math.sqrt(ratio)
                            ])

            final_anchor_list.append(anchor_list)
        final_anchor_arr = np.concatenate(final_anchor_list, axis=0)
        normalized_anchor_arr = normalize_anchor(final_anchor_arr).astype(
            'float32')
        transformed_anchor = transform_anchor(normalized_anchor_arr)

        return transformed_anchor


class PriorBox(object):

    def __init__(self, cfg, image_size=None, phase='train'):
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size
        self.feature_maps = [[
            ceil(self.image_size[0] / step),
            ceil(self.image_size[1] / step)
        ] for step in self.steps]
        self.name = 's'

    def forward(self):
        anchors = []
        for k, f in enumerate(self.feature_maps):
            min_sizes = self.min_sizes[k]
            for i, j in product(range(f[0]), range(f[1])):
                for min_size in min_sizes:
                    s_kx = min_size / self.image_size[1]
                    s_ky = min_size / self.image_size[0]
                    dense_cx = [
                        x * self.steps[k] / self.image_size[1]
                        for x in [j + 0.5]
                    ]
                    dense_cy = [
                        y * self.steps[k] / self.image_size[0]
                        for y in [i + 0.5]
                    ]
                    for cy, cx in product(dense_cy, dense_cx):
                        anchors += [cx, cy, s_kx, s_ky]

        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output


def py_cpu_nms(dets, thresh):
    """Pure Python NMS baseline."""
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep


def mogdecode(loc, anchors):
    """
    loc: torch.Tensor
    anchors: 2-d, torch.Tensor (cx, cy, w, h)
    boxes: 2-d, torch.Tensor (x0, y0, x1, y1)
    """

    boxes = torch.cat((anchors[:, :2] + loc[:, :2] * anchors[:, 2:],
                       anchors[:, 2:] * torch.exp(loc[:, 2:])), 1)

    boxes[:, 0] -= (boxes[:, 2] - 1) / 2
    boxes[:, 1] -= (boxes[:, 3] - 1) / 2
    boxes[:, 2] += boxes[:, 0] - 1
    boxes[:, 3] += boxes[:, 1] - 1

    return boxes


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors, 4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat(
        (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
         priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes


def decode_landm(pre, priors, variances):
    """Decode landm from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        pre (tensor): landm predictions for loc layers,
            Shape: [num_priors, 10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions
    """
    a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:]
    b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:]
    c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:]
    d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:]
    e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:]
    landms = torch.cat((a, b, c, d, e), dim=1)
    return landms
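
A standalone numpy sketch of the anchor coordinate conventions used above: `normalize_anchor` maps center form [cx, cy, w, h] to corner form [x0, y0, x1, y1] (inclusive coordinates, hence the "-1"), and `transform_anchor` inverts it. The two helpers are redefined locally so the snippet runs without modelscope installed.

import numpy as np

def normalize_anchor(anchors):
    # center form -> corner form
    item_1 = anchors[:, :2] - (anchors[:, 2:] - 1) / 2
    item_2 = anchors[:, :2] + (anchors[:, 2:] - 1) / 2
    return np.concatenate((item_1, item_2), axis=1)

def transform_anchor(anchors):
    # corner form -> center form
    return np.concatenate(((anchors[:, :2] + anchors[:, 2:]) / 2,
                           anchors[:, 2:] - anchors[:, :2] + 1), axis=1)

center = np.array([[8.0, 8.0, 16.0, 16.0]])  # cx, cy, w, h
corners = normalize_anchor(center)           # [[0.5, 0.5, 15.5, 15.5]]
print(corners)
print(transform_anchor(corners))             # round-trips to [[8., 8., 16., 16.]]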

modelscope/pipelines/cv/__init__.py (+2, -0)

@@ -48,6 +48,7 @@ if TYPE_CHECKING:
    from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline
    from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline
    from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline
    from .mog_face_detection_pipeline import MogFaceDetectionPipeline
    from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
    from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
    from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline
@@ -112,6 +113,7 @@ else:
        ['TextDrivenSegmentationPipeline'],
        'movie_scene_segmentation_pipeline':
        ['MovieSceneSegmentationPipeline'],
        'mog_face_detection_pipeline': ['MogFaceDetectionPipeline'],
        'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'],
        'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'],
        'facial_expression_recognition_pipelin':


modelscope/pipelines/cv/mog_face_detection_pipeline.py (+54, -0)

@@ -0,0 +1,54 @@
import os.path as osp
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_detection import MogFaceDetector
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_detection, module_name=Pipelines.mog_face_detection)
class MogFaceDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Use `model` to create a face detection pipeline for prediction.
        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {ckpt_path}')
        detector = MogFaceDetector(model_path=ckpt_path, device=self.device)
        self.detector = detector
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img.astype(np.float32)
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

        result = self.detector(input)
        assert result is not None
        bboxes = result[:, :4].tolist()
        scores = result[:, 4].tolist()
        return {
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: bboxes,
            OutputKeys.KEYPOINTS: None,
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
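
An end-to-end usage sketch for the new pipeline. The model id and image path match the ones used in tests/pipelines/test_mog_face_detection.py below; as the forward method above shows, the output dict carries OutputKeys.SCORES, OutputKeys.BOXES and OutputKeys.KEYPOINTS (None for this model).

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

mog_face_detection = pipeline(
    Tasks.face_detection,
    model='damo/cv_resnet101_face-detection_cvpr22papermogface')

result = mog_face_detection('data/test/images/mog_face_detection.jpg')
for box, score in zip(result[OutputKeys.BOXES], result[OutputKeys.SCORES]):
    print([round(v, 1) for v in box], round(score, 3))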

tests/pipelines/test_mog_face_detection.py (+33, -0)

@@ -0,0 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result
from modelscope.utils.test_utils import test_level


class MogFaceDetectionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet101_face-detection_cvpr22papermogface'

    def show_result(self, img_path, detection_result):
        img = draw_face_detection_no_lm_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
        img_path = 'data/test/images/mog_face_detection.jpg'

        result = face_detection(img_path)
        self.show_result(img_path, result)


if __name__ == '__main__':
    unittest.main()
