1. Add the RetinaFace face detection model;
2. Completed the MaaS-cv CR standard self-check.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9945188
master
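
For reviewers, a minimal usage sketch of the new pipeline (the model id and the test image path are taken from the test case at the end of this review):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    face_detection = pipeline(
        Tasks.face_detection,
        model='damo/cv_resnet50_face-detection_retinaface')
    result = face_detection('data/test/images/retina_face_detection.jpg')
    # result carries OutputKeys.SCORES / BOXES / KEYPOINTS lists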
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
size 87228
@@ -32,6 +32,7 @@ class Models(object):
     vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
     text_driven_segmentation = 'text-driven-segmentation'
     resnet50_bert = 'resnet50-bert'
+    retinaface = 'retinaface'
     shop_segmentation = 'shop-segmentation'
     # EasyCV models
@@ -118,6 +119,7 @@ class Pipelines(object):
     salient_detection = 'u2net-salient-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
+    retina_face_detection = 'resnet50-face-detection-retinaface'
     live_category = 'live-category'
     general_image_classification = 'vit-base_image-classification_ImageNet-labels'
     daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
@@ -0,0 +1,137 @@
# The implementation is based on Pytorch_Retinaface, available at
# https://github.com/biubug6/Pytorch_Retinaface
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from .models.retinaface import RetinaFace
from .utils import PriorBox, decode, decode_landm, py_cpu_nms


@MODELS.register_module(Tasks.face_detection, module_name=Models.retinaface)
class RetinaFaceDetection(TorchModel):

    def __init__(self, model_path, device='cuda'):
        super().__init__(model_path)
        torch.set_grad_enabled(False)  # inference only
        cudnn.benchmark = True
        self.model_path = model_path
        # The model configuration sits next to the checkpoint file.
        self.cfg = Config.from_file(
            model_path.replace(ModelFile.TORCH_MODEL_FILE,
                               ModelFile.CONFIGURATION))['models']
        self.net = RetinaFace(cfg=self.cfg)
        self.load_model()
        self.device = device
        self.net = self.net.to(self.device)
        # BGR channel means (104, 117, 123); note forward() subtracts the
        # means in numpy, so this tensor is currently unused.
        self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device)

    def check_keys(self, pretrained_state_dict):
        # Make sure the checkpoint and the model share at least some keys.
        ckpt_keys = set(pretrained_state_dict.keys())
        model_keys = set(self.net.state_dict().keys())
        used_pretrained_keys = model_keys & ckpt_keys
        assert len(
            used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
        return True

    def remove_prefix(self, state_dict, prefix):
        # Strip a leading prefix such as 'module.' left by DataParallel,
        # e.g. 'module.body.conv1.weight' -> 'body.conv1.weight'.
        new_state_dict = dict()
        for k, v in state_dict.items():
            if k.startswith(prefix):
                new_state_dict[k[len(prefix):]] = v
            else:
                new_state_dict[k] = v
        return new_state_dict

    def load_model(self, load_to_cpu=False):
        # Checkpoints may be saved raw or wrapped in a 'state_dict' entry.
        pretrained_dict = torch.load(
            self.model_path, map_location=torch.device('cpu'))
        if 'state_dict' in pretrained_dict.keys():
            pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'],
                                                 'module.')
        else:
            pretrained_dict = self.remove_prefix(pretrained_dict, 'module.')
        self.check_keys(pretrained_dict)
        self.net.load_state_dict(pretrained_dict, strict=False)
        self.net.eval()

    def forward(self, input):
        img_raw = input['img'].cpu().numpy()
        img = np.float32(img_raw)
        im_height, im_width = img.shape[:2]
        ss = 1.0
        # Downscale very large images so the longer side is ~1000 px;
        # detections are scaled back by 1/ss before returning.
        if max(im_height, im_width) > 1500:
            ss = 1000.0 / max(im_height, im_width)
            img = cv2.resize(img, (0, 0), fx=ss, fy=ss)
            im_height, im_width = img.shape[:2]
        scale = torch.Tensor(
            [img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
        img -= (104, 117, 123)  # subtract per-channel means
        img = img.transpose(2, 0, 1)  # HWC -> CHW
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.to(self.device)
        scale = scale.to(self.device)
        loc, conf, landms = self.net(img)  # forward pass
        del img

        confidence_threshold = 0.9
        nms_threshold = 0.4
        top_k = 5000
        keep_top_k = 750

        priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
        priors = priorbox.forward()
        priors = priors.to(self.device)
        prior_data = priors.data
        boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
        boxes = boxes * scale
        boxes = boxes.cpu().numpy()
        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
        landms = decode_landm(
            landms.data.squeeze(0), prior_data, self.cfg['variance'])
        scale1 = torch.Tensor([
            im_width, im_height, im_width, im_height, im_width, im_height,
            im_width, im_height, im_width, im_height
        ])
        scale1 = scale1.to(self.device)
        landms = landms * scale1
        landms = landms.cpu().numpy()

        # ignore low scores
        inds = np.where(scores > confidence_threshold)[0]
        boxes = boxes[inds]
        landms = landms[inds]
        scores = scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1][:top_k]
        boxes = boxes[order]
        landms = landms[order]
        scores = scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(
            np.float32, copy=False)
        keep = py_cpu_nms(dets, nms_threshold)
        dets = dets[keep, :]
        landms = landms[keep]

        # keep top-K after NMS
        dets = dets[:keep_top_k, :]
        landms = landms[:keep_top_k, :]
        landms = landms.reshape(-1, 10)

        # rescale back to original image coordinates
        return dets / ss, landms / ss
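
A minimal sketch of consuming the detector output (`detector` and `img` are hypothetical names; shapes follow the code above): `dets` is an (N, 5) float array of [x1, y1, x2, y2, score] in original-image coordinates, and `landms` an (N, 10) array of five (x, y) landmarks per face.

    # Hypothetical usage; img is the HxWx3 float tensor the pipeline feeds in.
    dets, landms = detector({'img': img})
    boxes, scores = dets[:, :4], dets[:, 4]
    keypoints = landms.reshape(-1, 5, 2)  # five (x, y) points per face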

@@ -0,0 +1,149 @@
# The implementation is based on Pytorch_Retinaface, available at
# https://github.com/biubug6/Pytorch_Retinaface
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn(inp, oup, stride=1, leaky=0):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True))


def conv_bn_no_relu(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
    )


def conv_bn1X1(inp, oup, stride, leaky=0):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
        nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True))


def conv_dw(inp, oup, stride, leaky=0.1):
    # Depthwise-separable conv: depthwise 3x3 followed by pointwise 1x1.
    return nn.Sequential(
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    )


class SSH(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(SSH, self).__init__()
        # Branch outputs are concatenated: out_channel//2 + out_channel//4
        # + out_channel//4 == out_channel, hence the divisibility assert.
        assert out_channel % 4 == 0
        leaky = 0
        if out_channel <= 64:
            leaky = 0.1
        self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
        # 5x5 and 7x7 receptive fields are emulated with stacked 3x3 convs.
        self.conv5X5_1 = conv_bn(
            in_channel, out_channel // 4, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(
            out_channel // 4, out_channel // 4, stride=1)
        self.conv7X7_2 = conv_bn(
            out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
        # Lowercase 'x' kept as in the upstream code so pretrained
        # state_dict keys still match.
        self.conv7x7_3 = conv_bn_no_relu(
            out_channel // 4, out_channel // 4, stride=1)

    def forward(self, input):
        conv3X3 = self.conv3X3(input)
        conv5X5_1 = self.conv5X5_1(input)
        conv5X5 = self.conv5X5_2(conv5X5_1)
        conv7X7_2 = self.conv7X7_2(conv5X5_1)
        conv7X7 = self.conv7x7_3(conv7X7_2)
        out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
        out = F.relu(out)
        return out


class FPN(nn.Module):

    def __init__(self, in_channels_list, out_channels):
        super(FPN, self).__init__()
        leaky = 0
        if out_channels <= 64:
            leaky = 0.1
        self.output1 = conv_bn1X1(
            in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(
            in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(
            in_channels_list[2], out_channels, stride=1, leaky=leaky)
        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        # input is an ordered dict of backbone features, shallow to deep.
        input = list(input.values())
        output1 = self.output1(input[0])
        output2 = self.output2(input[1])
        output3 = self.output3(input[2])
        # Top-down pathway: upsample the deeper map and add it in.
        up3 = F.interpolate(
            output3, size=[output2.size(2), output2.size(3)], mode='nearest')
        output2 = output2 + up3
        output2 = self.merge2(output2)
        up2 = F.interpolate(
            output2, size=[output1.size(2), output1.size(3)], mode='nearest')
        output1 = output1 + up2
        output1 = self.merge1(output1)
        out = [output1, output2, output3]
        return out


class MobileNetV1(nn.Module):
    # Lightweight backbone carried over from the upstream implementation;
    # the ResNet50 config registered in this CR does not use it. Inline
    # numbers are the upstream receptive-field bookkeeping.

    def __init__(self):
        super(MobileNetV1, self).__init__()
        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky=0.1),  # 3
            conv_dw(8, 16, 1),  # 7
            conv_dw(16, 32, 2),  # 11
            conv_dw(32, 32, 1),  # 19
            conv_dw(32, 64, 2),  # 27
            conv_dw(64, 64, 1),  # 43
        )
        self.stage2 = nn.Sequential(
            conv_dw(64, 128, 2),  # 43 + 16 = 59
            conv_dw(128, 128, 1),  # 59 + 32 = 91
            conv_dw(128, 128, 1),  # 91 + 32 = 123
            conv_dw(128, 128, 1),  # 123 + 32 = 155
            conv_dw(128, 128, 1),  # 155 + 32 = 187
            conv_dw(128, 128, 1),  # 187 + 32 = 219
        )
        self.stage3 = nn.Sequential(
            conv_dw(128, 256, 2),  # 219 + 32 = 251
            conv_dw(256, 256, 1),  # 251 + 64 = 315
        )
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 1000)

    def forward(self, x):
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.avg(x)
        x = x.view(-1, 256)
        x = self.fc(x)
        return x
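
A quick smoke check of this backbone head (sketch only; the adaptive pooling makes it input-size agnostic):

    m = MobileNetV1().eval()
    print(m(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])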

@@ -0,0 +1,145 @@
# The implementation is based on Pytorch_Retinaface, available at
# https://github.com/biubug6/Pytorch_Retinaface
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.models._utils as _utils

from .net import FPN, SSH, MobileNetV1


class ClassHead(nn.Module):
    # Two class scores (background / face) per anchor at each location.

    def __init__(self, inchannels=512, num_anchors=3):
        super(ClassHead, self).__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(
            inchannels,
            self.num_anchors * 2,
            kernel_size=(1, 1),
            stride=1,
            padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 2)


class BboxHead(nn.Module):
    # Four box-regression offsets per anchor.

    def __init__(self, inchannels=512, num_anchors=3):
        super(BboxHead, self).__init__()
        self.conv1x1 = nn.Conv2d(
            inchannels,
            num_anchors * 4,
            kernel_size=(1, 1),
            stride=1,
            padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 4)


class LandmarkHead(nn.Module):
    # Ten landmark offsets (five x, y pairs) per anchor.

    def __init__(self, inchannels=512, num_anchors=3):
        super(LandmarkHead, self).__init__()
        self.conv1x1 = nn.Conv2d(
            inchannels,
            num_anchors * 10,
            kernel_size=(1, 1),
            stride=1,
            padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 10)


class RetinaFace(nn.Module):

    def __init__(self, cfg=None):
        """
        :param cfg: Network related settings.
        """
        super(RetinaFace, self).__init__()
        backbone = None
        if cfg['name'] == 'Resnet50':
            backbone = models.resnet50(pretrained=cfg['pretrain'])
        else:
            raise ValueError(f"Unsupported backbone name: {cfg['name']}")
        # Expose the intermediate ResNet stages named in cfg['return_layers']
        # as the FPN inputs.
        self.body = _utils.IntermediateLayerGetter(backbone,
                                                   cfg['return_layers'])
        in_channels_stage2 = cfg['in_channel']
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]
        out_channels = cfg['out_channel']
        self.fpn = FPN(in_channels_list, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)
        self.ClassHead = self._make_class_head(
            fpn_num=3, inchannels=cfg['out_channel'])
        self.BboxHead = self._make_bbox_head(
            fpn_num=3, inchannels=cfg['out_channel'])
        self.LandmarkHead = self._make_landmark_head(
            fpn_num=3, inchannels=cfg['out_channel'])

    def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        classhead = nn.ModuleList()
        for i in range(fpn_num):
            classhead.append(ClassHead(inchannels, anchor_num))
        return classhead

    def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        bboxhead = nn.ModuleList()
        for i in range(fpn_num):
            bboxhead.append(BboxHead(inchannels, anchor_num))
        return bboxhead

    def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        landmarkhead = nn.ModuleList()
        for i in range(fpn_num):
            landmarkhead.append(LandmarkHead(inchannels, anchor_num))
        return landmarkhead

    def forward(self, inputs):
        out = self.body(inputs)

        # FPN
        fpn = self.fpn(out)

        # SSH
        feature1 = self.ssh1(fpn[0])
        feature2 = self.ssh2(fpn[1])
        feature3 = self.ssh3(fpn[2])
        features = [feature1, feature2, feature3]

        bbox_regressions = torch.cat(
            [self.BboxHead[i](feature) for i, feature in enumerate(features)],
            dim=1)
        classifications = torch.cat(
            [self.ClassHead[i](feature) for i, feature in enumerate(features)],
            dim=1)
        ldm_regressions = torch.cat(
            [self.LandmarkHead[i](feat) for i, feat in enumerate(features)],
            dim=1)

        output = (bbox_regressions, F.softmax(classifications,
                                              dim=-1), ldm_regressions)
        return output
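
A hedged shape check for the review (the cfg here is an assumption modeled on the upstream Resnet50 settings):

    net = RetinaFace(cfg=cfg).eval()
    loc, conf, landm = net(torch.randn(1, 3, 640, 640))
    # loc: (1, N, 4), conf: (1, N, 2), landm: (1, N, 10), where N is the
    # total anchor count produced by PriorBox (see utils below)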
| @@ -0,0 +1,123 @@ | |||||
| # -------------------------------------------------------- | |||||
| # Modified from https://github.com/biubug6/Pytorch_Retinaface | |||||
| # -------------------------------------------------------- | |||||
| from itertools import product as product | |||||
| from math import ceil | |||||
| import numpy as np | |||||
| import torch | |||||


class PriorBox(object):

    def __init__(self, cfg, image_size=None, phase='train'):
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size
        # One feature map per stride; sizes are ceil(image_size / step).
        self.feature_maps = [[
            ceil(self.image_size[0] / step),
            ceil(self.image_size[1] / step)
        ] for step in self.steps]
        self.name = 's'

    def forward(self):
        anchors = []
        for k, f in enumerate(self.feature_maps):
            min_sizes = self.min_sizes[k]
            for i, j in product(range(f[0]), range(f[1])):
                for min_size in min_sizes:
                    # Anchor size and center, normalized to [0, 1].
                    s_kx = min_size / self.image_size[1]
                    s_ky = min_size / self.image_size[0]
                    cx = (j + 0.5) * self.steps[k] / self.image_size[1]
                    cy = (i + 0.5) * self.steps[k] / self.image_size[0]
                    anchors += [cx, cy, s_kx, s_ky]

        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output
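
For orientation, a worked prior count (the numbers assume the usual upstream settings of steps [8, 16, 32] with two min_sizes per level):

    # 640x640 input -> feature maps 80x80, 40x40, 20x20
    # priors = 2 * (80*80 + 40*40 + 20*20) = 16800, i.e. output is (16800, 4)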


def py_cpu_nms(dets, thresh):
    """Pure Python NMS baseline."""
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        # Keep the highest-scoring remaining box ...
        i = order[0]
        keep.append(i)
        # ... and drop every remaining box whose IoU with it exceeds thresh.
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep
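
A toy check of the NMS behaviour (hypothetical boxes; with thresh=0.4 the lower-scoring of the two overlapping boxes is suppressed):

    dets = np.array([
        [10, 10, 50, 50, 0.95],
        [12, 12, 52, 52, 0.90],  # IoU ~0.83 with the first box -> dropped
        [100, 100, 140, 140, 0.80],
    ], dtype=np.float32)
    print(py_cpu_nms(dets, 0.4))  # indices [0, 2]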


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors, 4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """
    boxes = torch.cat(
        (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
         priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    # Convert (cx, cy, w, h) to (x1, y1, x2, y2).
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes
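
A worked numeric check of the decoding (hypothetical values; variances [0.1, 0.2] as in the usual upstream config):

    # prior (cx, cy, w, h) = (0.50, 0.50, 0.20, 0.20), loc = (1.0, 0.0, 0.0, 0.0)
    # cx = 0.50 + 1.0 * 0.1 * 0.20 = 0.52     w = 0.20 * exp(0.0) = 0.20
    # -> corners (x1, y1, x2, y2) = (0.42, 0.40, 0.62, 0.60)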


def decode_landm(pre, priors, variances):
    """Decode landm from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        pre (tensor): landm predictions for loc layers,
            Shape: [num_priors, 10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions
    """
    # Each of the five landmarks is offset from the prior center and scaled
    # by the prior size, exactly as the box centers in decode().
    a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:]
    b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:]
    c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:]
    d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:]
    e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:]
    landms = torch.cat((a, b, c, d, e), dim=1)
    return landms

@@ -2,7 +2,6 @@
 import os.path as osp
 from abc import ABC, abstractmethod
-from contextlib import contextmanager
 from threading import Lock
 from typing import Any, Dict, Generator, List, Mapping, Union

@@ -0,0 +1,55 @@
import os.path as osp
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_detection.retinaface import detection
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_detection, module_name=Pipelines.retina_face_detection)
class RetinaFaceDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a face detection pipeline for prediction.

        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {ckpt_path}')
        detector = detection.RetinaFaceDetection(
            model_path=ckpt_path, device=self.device)
        self.detector = detector
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img.astype(np.float32)
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result = self.detector(input)
        assert result is not None
        # result[0]: (N, 5) detections, result[1]: (N, 10) landmarks.
        bboxes = result[0][:, :4].tolist()
        scores = result[0][:, 4].tolist()
        lms = result[1].tolist()
        return {
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: bboxes,
            OutputKeys.KEYPOINTS: lms,
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
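
An illustrative output for one detected face (keys are the OutputKeys constants above; the numbers are made up):

    {
        'scores': [0.998],
        'boxes': [[112.0, 84.5, 256.3, 270.1]],
        'keypoints': [[150.2, 151.5, 210.8, 150.9, 182.4, 192.7,
                       155.1, 225.3, 208.6, 224.8]]
    }

Note that keypoints come flattened as 10 floats (x1, y1, ..., x5, y5) per face, matching the reshape(-1, 10) in detection.py.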

@@ -0,0 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_face_detection_result
from modelscope.utils.test_utils import test_level


class RetinaFaceDetectionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet50_face-detection_retinaface'

    def show_result(self, img_path, detection_result):
        img = draw_face_detection_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
        img_path = 'data/test/images/retina_face_detection.jpg'
        result = face_detection(img_path)
        self.show_result(img_path, result)


if __name__ == '__main__':
    unittest.main()
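
The case can be run directly thanks to the main guard (the file path below is an assumption about where the test lands in the repo):

    python tests/pipelines/test_retina_face_detection.py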