Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10780109master^2
@@ -48,6 +48,7 @@ class Models(object):
     ulfd = 'ulfd'
     arcface = 'arcface'
     facemask = 'facemask'
+    flc = 'flc'
     tinymog = 'tinymog'
     video_inpainting = 'video-inpainting'
     human_wholebody_keypoint = 'human-wholebody-keypoint'
@@ -186,6 +187,7 @@ class Pipelines(object):
     ulfd_face_detection = 'manual-face-detection-ulfd'
     tinymog_face_detection = 'manual-face-detection-tinymog'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
+    facial_landmark_confidence = 'manual-facial-landmark-confidence-flcm'
     face_attribute_recognition = 'resnet34-face-attribute-recognition-fairface'
     retina_face_detection = 'resnet50-face-detection-retinaface'
     mog_face_detection = 'resnet101-face-detection-cvpr22papermogface'
@@ -204,6 +206,7 @@ class Pipelines(object):
     realtime_object_detection = 'cspnet_realtime-object-detection_yolox'
     realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo'
     face_recognition = 'ir101-face-recognition-cfglint'
+    arc_face_recognition = 'ir50-face-recognition-arcface'
     mask_face_recognition = 'resnet-face-recognition-facemask'
     image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
     image2image_translation = 'image-to-image-translation'
@@ -0,0 +1,200 @@
# The implementation is adopted from TFace, made publicly available under the Apache-2.0 license at
# https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/backbones/iresnet.py
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

using_ckpt = False


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                'Dilation > 1 not supported in BasicBlock')
        self.bn1 = nn.BatchNorm2d(
            inplanes,
            eps=1e-05,
        )
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(
            planes,
            eps=1e-05,
        )
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(
            planes,
            eps=1e-05,
        )
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(self,
                 block,
                 layers,
                 dropout=0,
                 num_features=512,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 fp16=False):
        super(IResNet, self).__init__()
        self.extra_gflops = 0.0
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError('replace_stride_with_dilation should be None '
                             'or a 3-element tuple, got {}'.format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(
            512 * block.expansion,
            eps=1e-05,
        )
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale,
                            num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(
                    planes * block.expansion,
                    eps=1e-05,
                ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation))
        return nn.Sequential(*layers)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def _iresnet(arch, layers):
    model = IResNet(IBasicBlock, layers)
    return model
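For a quick sanity check, here is a minimal sketch of how this backbone is meant to be driven: the 'arcface_i50' tag and the [3, 4, 14, 3] layer configuration mirror the call made by the new ArcFaceRecognitionPipeline later in this change, and the random tensor merely stands in for a 112x112 aligned face crop.

import torch

backbone = _iresnet('arcface_i50', [3, 4, 14, 3])
backbone.eval()

# ArcFace-style models consume 112x112 aligned crops; with fc_scale = 7 * 7
# the flattened feature is 512 * 49, projected to a 512-d embedding.
dummy = torch.randn(1, 3, 112, 112)
with torch.no_grad():
    emb = backbone(dummy)
print(emb.shape)  # torch.Size([1, 512])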
@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .flc import FacialLandmarkConfidence

else:
    _import_structure = {'flc': ['FacialLandmarkConfidence']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
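The lazy-import wrapper defers loading flc.py (and its torch/cv2 dependencies) until the symbol is first touched. The import below, which the new pipeline file in this change also uses, is what triggers it.

# Resolved lazily through LazyImportModule on first access.
from modelscope.models.cv.facial_landmark_confidence import FacialLandmarkConfidence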
@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .facial_landmark_confidence import FacialLandmarkConfidence
@@ -0,0 +1,94 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from PIL import Image
from torch.autograd import Variable

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .manual_landmark_net import LandmarkConfidence


@MODELS.register_module(
    Tasks.facial_landmark_confidence, module_name=Models.flc)
class FacialLandmarkConfidence(TorchModel):

    def __init__(self, model_path, device='cuda'):
        super().__init__(model_path)
        cudnn.benchmark = True
        self.model_path = model_path
        self.device = device
        self.cfg_path = model_path.replace(ModelFile.TORCH_MODEL_FILE,
                                           ModelFile.CONFIGURATION)
        self.landmark_count = 5
        self.net = LandmarkConfidence(landmark_count=self.landmark_count)
        self.load_model()
        self.net = self.net.to(device)

    def load_model(self, load_to_cpu=False):
        pretrained_dict = torch.load(
            self.model_path, map_location=torch.device('cpu'))['state_dict']
        # the stored respirator-classifier weights are L2-normalized, scaled by
        # 32 and transposed to match FC.weight's (num_class, feat_dim) layout
        pretrained_dict['rp_net.binary_cls.weight'] = 32.0 * F.normalize(
            pretrained_dict['rp_net.binary_cls.weight'], dim=1).t()
        self.net.load_state_dict(pretrained_dict, strict=True)
        self.net.eval()

    def forward(self, input):
        img_org = input['orig_img']
        bbox = input['bbox']
        img_org = img_org.cpu().numpy()
        image_height = img_org.shape[0]
        image_width = img_org.shape[1]
        x1 = max(0, int(bbox[0]))
        y1 = max(0, int(bbox[1]))
        x2 = min(image_width, int(bbox[2]))
        y2 = min(image_height, int(bbox[3]))
        box_w = x2 - x1 + 1
        box_h = y2 - y1 + 1
        # pad the shorter side so the face crop becomes a square
        if box_h > box_w:
            delta = box_h - box_w
            dy = edy = 0
            dx = delta // 2
            edx = delta - dx
        else:
            dx = edx = 0
            delta = box_w - box_h
            dy = delta // 2
            edy = delta - dy
        cv_img = img_org[y1:y2, x1:x2]
        if dx > 0 or dy > 0 or edx > 0 or edy > 0:
            cv_img = cv2.copyMakeBorder(cv_img, dy, edy, dx, edx,
                                        cv2.BORDER_CONSTANT, 0)
        inter_x = cv_img.shape[1]
        inter_y = cv_img.shape[0]
        cv_img = cv2.resize(cv_img, (120, 120))
        cv_img = cv_img.transpose((2, 0, 1))
        input_blob = torch.from_numpy(cv_img[np.newaxis, :, :, :].astype(
            np.float32))
        tmp_conf_lms, tmp_feat, tmp_conf_resp, tmp_nose = self.net(
            input_blob.to(self.device))
        conf_lms = tmp_conf_lms.cpu().numpy().squeeze()
        feat = tmp_feat.cpu().numpy().squeeze()
        # map normalized predictions back to original image coordinates:
        # the first 5 values are x coordinates, the last 5 are y coordinates
        pts5pt = []
        for i in range(feat.shape[0]):
            if i < self.landmark_count:
                pts5pt.append(feat[i] * inter_x - dx + x1)
            else:
                pts5pt.append(feat[i] * inter_y - dy + y1)
        lm5pt = np.array(pts5pt).reshape(2, 5).T
        return lm5pt, conf_lms
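A rough sketch of driving the model directly follows; the checkpoint path, image file and box coordinates are placeholders, and in practice the FacialLandmarkConfidencePipeline later in this change assembles the 'orig_img'/'bbox' input from a face-detection result.

import cv2
import numpy as np
import torch

# Placeholder checkpoint path; the pipeline resolves the real one via
# ModelFile.TORCH_MODEL_FILE inside the downloaded model directory.
flcm = FacialLandmarkConfidence(
    model_path='/path/to/pytorch_model.pt', device='cpu')

img = cv2.imread('face.jpg')[:, :, ::-1].astype(np.float32)  # BGR -> RGB
bbox = np.array([50., 60., 210., 220.])  # x1, y1, x2, y2 from a face detector

lm5pt, conf = flcm({'orig_img': torch.from_numpy(img), 'bbox': bbox})
print(lm5pt.shape, float(conf))  # (5, 2) landmarks and a raw confidence value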
@@ -0,0 +1,152 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import torch
import torch.nn.functional as F
from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, Linear,
                      MaxPool2d, Module, Parameter, ReLU, Sequential)


class LandmarkConfidence(Module):
    """Joint head over a 120x120 face crop: 5-point landmark regression
    (LandmarkNetD), a landmark-confidence score (ClassNet) and a
    respirator/mask branch (RespiratorNet)."""

    def __init__(self, landmark_count=5):
        super(LandmarkConfidence, self).__init__()
        self.landmark_net = LandmarkNetD(landmark_count)
        self.landmark_net.eval()
        self.cls_net = ClassNet()
        self.cls_net.eval()
        self.rp_net = RespiratorNet()

    def forward(self, x):
        feat, nose_feat, lms = self.landmark_net(x)
        cls_respirator, nose = self.rp_net(feat, nose_feat)
        confidence = self.cls_net(feat)
        return confidence, lms, cls_respirator, nose


class FC(Module):
    """Linear head followed by softmax; its weights are rescaled at load time
    (see FacialLandmarkConfidence.load_model)."""

    def __init__(self, feat_dim=256, num_class=2):
        super(FC, self).__init__()
        self.weight = Parameter(
            torch.zeros(num_class, feat_dim, dtype=torch.float32))

    def forward(self, x):
        cos_theta = F.linear(x, self.weight)
        return F.softmax(cos_theta, dim=1)


class Flatten(Module):

    def forward(self, x):
        return torch.flatten(x, 1)


class RespiratorNet(Module):
    """Binary respirator (mask) classifier plus a small nose-point regressor,
    both fed from LandmarkNetD's intermediate feature maps."""

    def __init__(self):
        super(RespiratorNet, self).__init__()
        self.conv1 = Sequential(
            Conv2d(48, 48, 3, 2, 1), BatchNorm2d(48), ReLU(True))
        self.conv2 = AdaptiveAvgPool2d(
            (1, 1)
        )  # Sequential(Conv2d(48, 48, 5, 1, 0), BatchNorm2d(48), ReLU(True))
        self.binary_cls = FC(feat_dim=48, num_class=2)
        self.nose_layer = Sequential(
            Conv2d(48, 64, 3, 1, 0), BatchNorm2d(64), ReLU(True),
            Conv2d(64, 64, 3, 1, 0), BatchNorm2d(64), ReLU(True), Flatten(),
            Linear(64, 96), ReLU(True), Linear(96, 6))

    def train(self, mode=True):
        self.conv1.train(mode)
        self.conv2.train(mode)
        # self.nose_feat.train(mode)
        self.nose_layer.train(mode)
        self.binary_cls.train(mode)

    def forward(self, x, y):
        x = self.conv1(x)
        x = self.conv2(x)
        cls = self.binary_cls(torch.flatten(x, 1))
        # loc = self.nose_feat(y)
        loc = self.nose_layer(y)
        return cls, loc


class ClassNet(Module):
    """Predicts a single confidence score for the regressed landmarks."""

    def __init__(self):
        super(ClassNet, self).__init__()
        self.conv1 = Sequential(
            Conv2d(48, 48, 3, 1, 1), BatchNorm2d(48), ReLU(True))
        self.conv2 = Sequential(
            Conv2d(48, 54, 3, 2, 1), BatchNorm2d(54), ReLU(True))
        self.conv3 = Sequential(
            Conv2d(54, 54, 5, 1, 0), BatchNorm2d(54), ReLU(True))
        self.fc1 = Sequential(Flatten(), Linear(54, 54), ReLU(True))
        self.fc2 = Linear(54, 1)

    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        y = self.fc1(y)
        y = self.fc2(y)
        return y


class LandmarkNetD(Module):
    """Landmark regression backbone; besides the 2 * landmark_count outputs it
    returns two intermediate feature maps consumed by ClassNet and
    RespiratorNet."""

    def __init__(self, landmark_count=5):
        super(LandmarkNetD, self).__init__()
        self.conv_pre = Sequential(
            Conv2d(3, 16, 5, 2, 0), BatchNorm2d(16), ReLU(True))
        self.pool_pre = MaxPool2d(2, 2)  # output is 29
        self.conv1 = Sequential(
            Conv2d(16, 32, 3, 1, 1), BatchNorm2d(32), ReLU(True),
            Conv2d(32, 32, 3, 1, 1), BatchNorm2d(32), ReLU(True))
        self.pool1 = MaxPool2d(2, 2)  # 14
        self.conv2 = Sequential(
            Conv2d(32, 48, 3, 1, 0), BatchNorm2d(48), ReLU(True),
            Conv2d(48, 48, 3, 1, 0), BatchNorm2d(48), ReLU(True))
        self.pool2 = MaxPool2d(2, 2)  # 5
        self.conv3 = Sequential(
            Conv2d(48, 80, 3, 1, 0), BatchNorm2d(80), ReLU(True),
            Conv2d(80, 80, 3, 1, 0), BatchNorm2d(80), ReLU(True))
        self.fc1 = Sequential(Linear(80, 128), ReLU(True))
        self.fc2 = Sequential(Linear(128, 128), ReLU(True))
        self.output = Linear(128, landmark_count * 2)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def forward(self, x):
        y = self.conv_pre(x)
        y = self.pool_pre(y)
        y = self.conv1(y)
        y = self.pool1(y[:, :, :28, :28])
        feat = self.conv2(y)
        y2 = self.pool2(feat)
        y = self.conv3(y2)
        y = torch.flatten(y, 1)
        y = self.fc1(y)
        y = self.fc2(y)
        y = self.output(y)
        return feat, y2, y
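As a shape reference, here is a small sketch with randomly initialized weights, assuming the 120x120 crops that FacialLandmarkConfidence above feeds in.

import torch

net = LandmarkConfidence(landmark_count=5)
net.eval()

x = torch.randn(1, 3, 120, 120)  # matches cv2.resize(..., (120, 120)) above
with torch.no_grad():
    confidence, lms, cls_respirator, nose = net(x)

print(confidence.shape)      # torch.Size([1, 1])  landmark-confidence output
print(lms.shape)             # torch.Size([1, 10]) 5 x-coords then 5 y-coords
print(cls_respirator.shape)  # torch.Size([1, 2])  mask / no-mask probabilities
print(nose.shape)            # torch.Size([1, 6])  nose-region regression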
@@ -137,6 +137,26 @@ TASK_OUTPUTS = {
     Tasks.facial_expression_recognition:
     [OutputKeys.SCORES, OutputKeys.LABELS],

+    # face processing base result for single img
+    # {
+    #   "output_img": np.array with shape(h, w, 3) (output_img = aligned_img)
+    #   "scores": [0.85]
+    #   "boxes": [x1, y1, x2, y2]
+    #   "keypoints": [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5]
+    # }
+    Tasks.face_processing_base: [
+        OutputKeys.OUTPUT_IMG, OutputKeys.SCORES, OutputKeys.BOXES,
+        OutputKeys.KEYPOINTS
+    ],

+    # facial landmark confidence result for single sample
+    # {
+    #   "output_img": np.array with shape(h, w, 3) (output_img = aligned_img)
+    #   "scores": [0.85]
+    #   "keypoints": [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5]
+    #   "boxes": [x1, y1, x2, y2]
+    # }
+    Tasks.facial_landmark_confidence:
+    [OutputKeys.SCORES, OutputKeys.KEYPOINTS, OutputKeys.BOXES],

     # face attribute recognition result for single sample
     # {
     #     "scores": [[0.9, 0.1], [0.92, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
@@ -447,8 +467,9 @@ TASK_OUTPUTS = {
     # "masks": [np.array # 3D array with shape [frame_num, height, width]]
     # "timestamps": ["hh:mm:ss", "hh:mm:ss", "hh:mm:ss"]
     # }
-    Tasks.referring_video_object_segmentation:
-    [OutputKeys.MASKS, OutputKeys.TIMESTAMPS],
+    Tasks.referring_video_object_segmentation: [
+        OutputKeys.MASKS, OutputKeys.TIMESTAMPS
+    ],

     # video human matting result for a single video
     # {
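For reference, a sketch of reading these keys back from the new pipeline; the image path is a placeholder and the model id is the one registered for the task in the builder change below.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

flcm = pipeline(
    Tasks.facial_landmark_confidence,
    model='damo/cv_manual_facial-landmark-confidence_flcm')
result = flcm('face.jpg')  # placeholder image path

score = result[OutputKeys.SCORES][0]   # landmark confidence score
lms = result[OutputKeys.KEYPOINTS][0]  # 10 values: 5 (x, y) landmarks
box = result[OutputKeys.BOXES][0]      # [x1, y1, x2, y2]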
@@ -135,6 +135,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.facial_expression_recognition:
     (Pipelines.facial_expression_recognition,
      'damo/cv_vgg19_facial-expression-recognition_fer'),
+    Tasks.facial_landmark_confidence:
+    (Pipelines.facial_landmark_confidence,
+     'damo/cv_manual_facial-landmark-confidence_flcm'),
     Tasks.face_attribute_recognition:
     (Pipelines.face_attribute_recognition,
      'damo/cv_resnet34_face-attribute-recognition_fairface'),
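Assuming the builder falls back to DEFAULT_MODEL_FOR_PIPELINE when no model id is given (which is what this table is for), the task name alone should also be enough to build the pipeline; a minimal sketch:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Falls back to 'damo/cv_manual_facial-landmark-confidence_flcm' per the
# DEFAULT_MODEL_FOR_PIPELINE entry added above.
flcm = pipeline(Tasks.facial_landmark_confidence)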
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
     from .face_detection_pipeline import FaceDetectionPipeline
     from .face_image_generation_pipeline import FaceImageGenerationPipeline
     from .face_recognition_pipeline import FaceRecognitionPipeline
+    from .arc_face_recognition_pipeline import ArcFaceRecognitionPipeline
     from .mask_face_recognition_pipeline import MaskFaceRecognitionPipeline
     from .general_recognition_pipeline import GeneralRecognitionPipeline
     from .image_cartoon_pipeline import ImageCartoonPipeline
@@ -59,6 +60,8 @@ if TYPE_CHECKING:
     from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
     from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
     from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline
+    from .facial_landmark_confidence_pipeline import FacialLandmarkConfidencePipeline
+    from .face_processing_base_pipeline import FaceProcessingBasePipeline
     from .face_attribute_recognition_pipeline import FaceAttributeRecognitionPipeline
     from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipelin
     from .hand_static_pipeline import HandStaticPipeline
@@ -81,6 +84,7 @@ else:
         'face_detection_pipeline': ['FaceDetectionPipeline'],
         'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
         'face_recognition_pipeline': ['FaceRecognitionPipeline'],
+        'arc_face_recognition_pipeline': ['ArcFaceRecognitionPipeline'],
         'mask_face_recognition_pipeline': ['MaskFaceRecognitionPipeline'],
         'general_recognition_pipeline': ['GeneralRecognitionPipeline'],
         'image_classification_pipeline':
@@ -135,6 +139,10 @@ else:
         'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'],
         'facial_expression_recognition_pipeline':
         ['FacialExpressionRecognitionPipeline'],
+        'facial_landmark_confidence_pipeline': [
+            'FacialLandmarkConfidencePipeline'
+        ],
+        'face_processing_base_pipeline': ['FaceProcessingBasePipeline'],
         'face_attribute_recognition_pipeline': [
             'FaceAttributeRecognitionPipeline'
         ],
@@ -0,0 +1,66 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_recognition.align_face import align_face
from modelscope.models.cv.face_recognition.torchkit.backbone.arcface_backbone import \
    _iresnet
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from . import FaceProcessingBasePipeline

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_recognition, module_name=Pipelines.arc_face_recognition)
class ArcFaceRecognitionPipeline(FaceProcessingBasePipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a face recognition pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        # face recognition model
        super().__init__(model=model, **kwargs)
        face_model = _iresnet('arcface_i50', [3, 4, 14, 3])
        face_model.load_state_dict(
            torch.load(
                osp.join(model, ModelFile.TORCH_MODEL_FILE),
                map_location=self.device))
        face_model = face_model.to(self.device)
        face_model.eval()
        self.face_model = face_model
        logger.info('face recognition model loaded!')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        result = super(ArcFaceRecognitionPipeline, self).preprocess(input)
        align_img = result['img']
        face_img = align_img[:, :, ::-1]  # to rgb
        face_img = np.transpose(face_img, axes=(2, 0, 1))
        face_img = (face_img / 255. - 0.5) / 0.5
        face_img = face_img.astype(np.float32)
        result['img'] = face_img
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        img = input['img'].unsqueeze(0)
        emb = self.face_model(img).detach().cpu().numpy()
        emb /= np.sqrt(np.sum(emb**2, -1, keepdims=True))  # l2 norm
        return {OutputKeys.IMG_EMBEDDING: emb}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
@@ -0,0 +1,119 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_recognition.align_face import align_face
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


class FaceProcessingBasePipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a face processing pipeline and output cropped img, scores, bbox and lmks.
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        # face detect pipeline
        det_model_id = 'damo/cv_resnet50_face-detection_retinaface'
        self.face_detection = pipeline(
            Tasks.face_detection, model=det_model_id)

    def _choose_face(self,
                     det_result,
                     min_face=10,
                     top_face=1,
                     center_face=False,
                     img_shape=None):
        '''
        choose face with maximum area
        Args:
            det_result: output of face detection pipeline
            min_face: minimum size of valid face w/h
            top_face: take faces with top max areas
            center_face: choose the most centered face from multiple faces,
                only valid if top_face > 1 and img_shape is provided
            img_shape: (height, width, ...) of the original image, used to
                locate the image center when center_face is True
        '''
        bboxes = np.array(det_result[OutputKeys.BOXES])
        landmarks = np.array(det_result[OutputKeys.KEYPOINTS])
        scores = np.array(det_result[OutputKeys.SCORES])
        if bboxes.shape[0] == 0:
            logger.info('Warning: No face detected!')
            return None
        # face idx with enough size
        face_idx = []
        for i in range(bboxes.shape[0]):
            box = bboxes[i]
            if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face:
                face_idx += [i]
        if len(face_idx) == 0:
            logger.info(
                f'Warning: Face size not enough, less than {min_face}x{min_face}!'
            )
            return None
        bboxes = bboxes[face_idx]
        landmarks = landmarks[face_idx]
        scores = scores[face_idx]
        # find max faces
        boxes = np.array(bboxes)
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        sort_idx = np.argsort(area)[-top_face:]
        # find center face
        if top_face > 1 and center_face and bboxes.shape[0] > 1 \
                and img_shape is not None:
            img_center = [img_shape[1] // 2, img_shape[0] // 2]
            min_dist = float('inf')
            sel_idx = -1
            for _idx in sort_idx:
                box = boxes[_idx]
                dist = np.square(
                    np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square(
                        np.abs((box[1] + box[3]) / 2 - img_center[1]))
                if dist < min_dist:
                    min_dist = dist
                    sel_idx = _idx
            sort_idx = [sel_idx]
        main_idx = sort_idx[-1]
        return scores[main_idx], bboxes[main_idx], landmarks[main_idx]

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img[:, :, ::-1]
        det_result = self.face_detection(img.copy())
        rtn = self._choose_face(det_result, img_shape=img.shape)
        if rtn is not None:
            scores, bboxes, face_lmks = rtn
            face_lmks = face_lmks.reshape(5, 2)
            align_img, _ = align_face(img, (112, 112), face_lmks)

            result = {}
            result['img'] = np.ascontiguousarray(align_img)
            result['scores'] = [scores]
            result['bbox'] = bboxes
            result['lmks'] = face_lmks
            return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return {
            OutputKeys.OUTPUT_IMG: input['img'].cpu().numpy(),
            OutputKeys.SCORES: input['scores'].cpu().tolist(),
            OutputKeys.BOXES: [input['bbox'].cpu().tolist()],
            OutputKeys.KEYPOINTS: [input['lmks'].cpu().tolist()]
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
@@ -0,0 +1,67 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_recognition.align_face import align_face
from modelscope.models.cv.facial_landmark_confidence import \
    FacialLandmarkConfidence
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from . import FaceProcessingBasePipeline

logger = get_logger()


@PIPELINES.register_module(
    Tasks.facial_landmark_confidence,
    module_name=Pipelines.facial_landmark_confidence)
class FacialLandmarkConfidencePipeline(FaceProcessingBasePipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a facial landmark confidence pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {ckpt_path}')
        flcm = FacialLandmarkConfidence(
            model_path=ckpt_path, device=self.device)
        self.flcm = flcm
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        result = super(FacialLandmarkConfidencePipeline,
                       self).preprocess(input)
        img = LoadImage.convert_to_ndarray(input)
        img = img[:, :, ::-1]
        result['orig_img'] = img.astype(np.float32)
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result = self.flcm(input)
        assert result is not None
        lms = result[0].reshape(-1, 10).tolist()
        scores = [1 - result[1].tolist()]
        boxes = input['bbox'].cpu().numpy()[np.newaxis, :].tolist()
        return {
            OutputKeys.SCORES: scores,
            OutputKeys.KEYPOINTS: lms,
            OutputKeys.BOXES: boxes
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
@@ -25,6 +25,8 @@ class CVTasks(object):
     card_detection = 'card-detection'
     face_recognition = 'face-recognition'
     facial_expression_recognition = 'facial-expression-recognition'
+    facial_landmark_confidence = 'facial-landmark-confidence'
+    face_processing_base = 'face-processing-base'
     face_attribute_recognition = 'face-attribute-recognition'
     face_2d_keypoints = 'face-2d-keypoints'
     human_detection = 'human-detection'
@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class FaceRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.face_recognition
        self.model_id = 'damo/cv_ir50_face-recognition_arcface'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_face_compare(self):
        img1 = 'data/test/images/face_recognition_1.png'
        img2 = 'data/test/images/face_recognition_2.png'

        face_recognition = pipeline(
            Tasks.face_recognition, model=self.model_id)
        emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
        emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
        sim = np.dot(emb1[0], emb2[0])
        print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,35 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2
import numpy as np

from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_face_detection_result
from modelscope.utils.test_utils import test_level


class FacialLandmarkConfidenceTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_manual_facial-landmark-confidence_flcm'

    def show_result(self, img_path, facial_landmark_result):
        img = draw_face_detection_result(img_path, facial_landmark_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        flcm = pipeline(Tasks.facial_landmark_confidence, model=self.model_id)
        img_path = 'data/test/images/face_recognition_1.png'
        result = flcm(img_path)
        self.show_result(img_path, result)


if __name__ == '__main__':
    unittest.main()