diff --git a/data/test/images/skin_retouching.png b/data/test/images/skin_retouching.png new file mode 100644 index 00000000..a0b8df2a --- /dev/null +++ b/data/test/images/skin_retouching.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0fcd36e0ada8a506bb09d3e0f3594e2be978194ea4123e066331c0bcb7fc79bc +size 683425 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 451c0bec..215233fe 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -94,6 +94,7 @@ class Pipelines(object): video_category = 'video-category' image_portrait_enhancement = 'gpen-image-portrait-enhancement' image_to_image_generation = 'image-to-image-generation' + skin_retouching = 'unet-skin-retouching' # nlp tasks sentence_similarity = 'sentence-similarity' diff --git a/modelscope/models/cv/skin_retouching/__init__.py b/modelscope/models/cv/skin_retouching/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/skin_retouching/detection_model/__init__.py b/modelscope/models/cv/skin_retouching/detection_model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/skin_retouching/detection_model/detection_module.py b/modelscope/models/cv/skin_retouching/detection_model/detection_module.py new file mode 100644 index 00000000..f89ce37b --- /dev/null +++ b/modelscope/models/cv/skin_retouching/detection_model/detection_module.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn + + +class ConvBNActiv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + bn=True, + sample='none-3', + activ='relu', + bias=False): + super(ConvBNActiv, self).__init__() + + if sample == 'down-7': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + elif sample == 'down-5': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=5, + stride=2, + padding=2, + bias=bias) + elif sample == 'down-3': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=bias) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + + if bn: + self.bn = nn.BatchNorm2d(out_channels) + + if activ == 'relu': + self.activation = nn.ReLU() + elif activ == 'leaky': + self.activation = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, images): + + outputs = self.conv(images) + if hasattr(self, 'bn'): + outputs = self.bn(outputs) + if hasattr(self, 'activation'): + outputs = self.activation(outputs) + + return outputs diff --git a/modelscope/models/cv/skin_retouching/detection_model/detection_unet_in.py b/modelscope/models/cv/skin_retouching/detection_model/detection_unet_in.py new file mode 100644 index 00000000..b48f6e5f --- /dev/null +++ b/modelscope/models/cv/skin_retouching/detection_model/detection_unet_in.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..weights_init import weights_init +from .detection_module import ConvBNActiv + + +class DetectionUNet(nn.Module): + + def __init__(self, + n_channels, + n_classes, + up_sampling_node='nearest', + init_weights=True): + super(DetectionUNet, self).__init__() + + self.n_classes = n_classes + self.up_sampling_node = up_sampling_node + + self.ec_images_1 = ConvBNActiv( + n_channels, 64, bn=False, sample='down-3') + self.ec_images_2 = ConvBNActiv(64, 128, sample='down-3') + self.ec_images_3 = ConvBNActiv(128, 256, sample='down-3') + self.ec_images_4 = 
ConvBNActiv(256, 512, sample='down-3') + self.ec_images_5 = ConvBNActiv(512, 512, sample='down-3') + self.ec_images_6 = ConvBNActiv(512, 512, sample='down-3') + + self.dc_images_6 = ConvBNActiv(512 + 512, 512, activ='leaky') + self.dc_images_5 = ConvBNActiv(512 + 512, 512, activ='leaky') + self.dc_images_4 = ConvBNActiv(512 + 256, 256, activ='leaky') + self.dc_images_3 = ConvBNActiv(256 + 128, 128, activ='leaky') + self.dc_images_2 = ConvBNActiv(128 + 64, 64, activ='leaky') + self.dc_images_1 = nn.Conv2d(64 + n_channels, n_classes, kernel_size=1) + + if init_weights: + self.apply(weights_init()) + + def forward(self, input_images): + + ec_images = {} + + ec_images['ec_images_0'] = input_images + ec_images['ec_images_1'] = self.ec_images_1(input_images) + ec_images['ec_images_2'] = self.ec_images_2(ec_images['ec_images_1']) + ec_images['ec_images_3'] = self.ec_images_3(ec_images['ec_images_2']) + ec_images['ec_images_4'] = self.ec_images_4(ec_images['ec_images_3']) + ec_images['ec_images_5'] = self.ec_images_5(ec_images['ec_images_4']) + ec_images['ec_images_6'] = self.ec_images_6(ec_images['ec_images_5']) + # -------------- + # images decoder + # -------------- + logits = ec_images['ec_images_6'] + + for _ in range(6, 0, -1): + + ec_images_skip = 'ec_images_{:d}'.format(_ - 1) + dc_conv = 'dc_images_{:d}'.format(_) + + logits = F.interpolate( + logits, scale_factor=2, mode=self.up_sampling_node) + logits = torch.cat((logits, ec_images[ec_images_skip]), dim=1) + + logits = getattr(self, dc_conv)(logits) + + return logits diff --git a/modelscope/models/cv/skin_retouching/inpainting_model/__init__.py b/modelscope/models/cv/skin_retouching/inpainting_model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/skin_retouching/inpainting_model/gconv.py b/modelscope/models/cv/skin_retouching/inpainting_model/gconv.py new file mode 100644 index 00000000..e0910d2c --- /dev/null +++ b/modelscope/models/cv/skin_retouching/inpainting_model/gconv.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn + + +class GatedConvBNActiv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + bn=True, + sample='none-3', + activ='relu', + bias=False): + super(GatedConvBNActiv, self).__init__() + + if sample == 'down-7': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + elif sample == 'down-5': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=5, + stride=2, + padding=2, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=5, + stride=2, + padding=2, + bias=bias) + elif sample == 'down-3': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=bias) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + + if bn: + self.bn = nn.BatchNorm2d(out_channels) + + if activ == 'relu': + self.activation = nn.ReLU() + elif activ == 'leaky': + self.activation = nn.LeakyReLU(negative_slope=0.2) + + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + + images = self.conv(x) + gates = self.sigmoid(self.gate(x)) + + if 
hasattr(self, 'bn'): + images = self.bn(images) + if hasattr(self, 'activation'): + images = self.activation(images) + + images = images * gates + + return images + + +class GatedConvBNActiv2(nn.Module): + + def __init__(self, + in_channels, + out_channels, + bn=True, + sample='none-3', + activ='relu', + bias=False): + super(GatedConvBNActiv2, self).__init__() + + if sample == 'down-7': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + elif sample == 'down-5': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=5, + stride=2, + padding=2, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=5, + stride=2, + padding=2, + bias=bias) + elif sample == 'down-3': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=bias) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + + self.conv_skip = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + + if bn: + self.bn = nn.BatchNorm2d(out_channels) + + if activ == 'relu': + self.activation = nn.ReLU() + elif activ == 'leaky': + self.activation = nn.LeakyReLU(negative_slope=0.2) + + self.sigmoid = nn.Sigmoid() + + def forward(self, f_up, f_skip, mask): + x = torch.cat((f_up, f_skip, mask), dim=1) + images = self.conv(x) + images_skip = self.conv_skip(f_skip) + gates = self.sigmoid(self.gate(x)) + + if hasattr(self, 'bn'): + images = self.bn(images) + images_skip = self.bn(images_skip) + if hasattr(self, 'activation'): + images = self.activation(images) + images_skip = self.activation(images_skip) + + images = images * gates + images_skip * (1 - gates) + + return images diff --git a/modelscope/models/cv/skin_retouching/inpainting_model/inpainting_unet.py b/modelscope/models/cv/skin_retouching/inpainting_model/inpainting_unet.py new file mode 100644 index 00000000..09cea1fc --- /dev/null +++ b/modelscope/models/cv/skin_retouching/inpainting_model/inpainting_unet.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.cv.skin_retouching.inpainting_model.gconv import \ + GatedConvBNActiv +from ..weights_init import weights_init + + +class RetouchingNet(nn.Module): + + def __init__(self, + in_channels=3, + out_channels=3, + up_sampling_node='nearest', + init_weights=True): + super(RetouchingNet, self).__init__() + + self.freeze_ec_bn = False + self.up_sampling_node = up_sampling_node + + self.ec_images_1 = GatedConvBNActiv( + in_channels, 64, bn=False, sample='down-3') + self.ec_images_2 = GatedConvBNActiv(64, 128, sample='down-3') + self.ec_images_3 = GatedConvBNActiv(128, 256, sample='down-3') + self.ec_images_4 = GatedConvBNActiv(256, 512, sample='down-3') + self.ec_images_5 = GatedConvBNActiv(512, 512, sample='down-3') + self.ec_images_6 = GatedConvBNActiv(512, 512, sample='down-3') + + self.dc_images_6 = GatedConvBNActiv(512 + 512, 512, activ='leaky') + self.dc_images_5 = GatedConvBNActiv(512 + 512, 512, activ='leaky') + self.dc_images_4 = GatedConvBNActiv(512 + 256, 256, activ='leaky') + 
self.dc_images_3 = GatedConvBNActiv(256 + 128, 128, activ='leaky') + self.dc_images_2 = GatedConvBNActiv(128 + 64, 64, activ='leaky') + self.dc_images_1 = GatedConvBNActiv( + 64 + in_channels, + out_channels, + bn=False, + sample='none-3', + activ=None, + bias=True) + + self.tanh = nn.Tanh() + + if init_weights: + self.apply(weights_init()) + + def forward(self, input_images, input_masks): + + ec_images = {} + + ec_images['ec_images_0'] = torch.cat((input_images, input_masks), + dim=1) + ec_images['ec_images_1'] = self.ec_images_1(ec_images['ec_images_0']) + ec_images['ec_images_2'] = self.ec_images_2(ec_images['ec_images_1']) + ec_images['ec_images_3'] = self.ec_images_3(ec_images['ec_images_2']) + + ec_images['ec_images_4'] = self.ec_images_4(ec_images['ec_images_3']) + ec_images['ec_images_5'] = self.ec_images_5(ec_images['ec_images_4']) + ec_images['ec_images_6'] = self.ec_images_6(ec_images['ec_images_5']) + + # -------------- + # images decoder + # -------------- + dc_images = ec_images['ec_images_6'] + for _ in range(6, 0, -1): + ec_images_skip = 'ec_images_{:d}'.format(_ - 1) + dc_conv = 'dc_images_{:d}'.format(_) + + dc_images = F.interpolate( + dc_images, scale_factor=2, mode=self.up_sampling_node) + dc_images = torch.cat((dc_images, ec_images[ec_images_skip]), + dim=1) + + dc_images = getattr(self, dc_conv)(dc_images) + + outputs = self.tanh(dc_images) + + return outputs + + def train(self, mode=True): + + super().train(mode) + + if self.freeze_ec_bn: + for name, module in self.named_modules(): + if isinstance(module, nn.BatchNorm2d): + module.eval() diff --git a/modelscope/models/cv/skin_retouching/retinaface/__init__.py b/modelscope/models/cv/skin_retouching/retinaface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/skin_retouching/retinaface/box_utils.py b/modelscope/models/cv/skin_retouching/retinaface/box_utils.py new file mode 100644 index 00000000..89cf8bf6 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/box_utils.py @@ -0,0 +1,271 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +from typing import List, Tuple, Union + +import numpy as np +import torch + + +def point_form(boxes: torch.Tensor) -> torch.Tensor: + """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form ground truth data. + + Args: + boxes: center-size default boxes from priorbox layers. + Return: + boxes: Converted x_min, y_min, x_max, y_max form of boxes. + """ + return torch.cat( + (boxes[:, :2] - boxes[:, 2:] / 2, boxes[:, :2] + boxes[:, 2:] / 2), + dim=1) + + +def center_size(boxes: torch.Tensor) -> torch.Tensor: + """Convert prior_boxes to (cx, cy, w, h) representation for comparison to center-size form ground truth data. + Args: + boxes: point_form boxes + Return: + boxes: Converted x_min, y_min, x_max, y_max form of boxes. + """ + return torch.cat( + ((boxes[:, 2:] + boxes[:, :2]) / 2, boxes[:, 2:] - boxes[:, :2]), + dim=1) + + +def intersect(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: + """ We resize both tensors to [A,B,2] without new malloc: + [A, 2] -> [A, 1, 2] -> [A, B, 2] + [B, 2] -> [1, B, 2] -> [A, B, 2] + Then we compute the area of intersect between box_a and box_b. + Args: + box_a: bounding boxes, Shape: [A, 4]. + box_b: bounding boxes, Shape: [B, 4]. + Return: + intersection area, Shape: [A, B]. 
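+        Example (shapes only): with box_a of shape [2, 4] and box_b of shape [3, 4],
+        the result has shape [2, 3]; entry (i, j) is the intersection area of
+        box_a[i] and box_b[j].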
+ """ + A = box_a.size(0) + B = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), + box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), + box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + inter = torch.clamp((max_xy - min_xy), min=0) + return inter[:, :, 0] * inter[:, :, 1] + + +def jaccard(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap is simply the intersection over + union of two boxes. Here we operate on ground truth boxes and default boxes. + E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: Ground truth bounding boxes, Shape: [num_objects,4] + box_b: Prior boxes from priorbox layers, Shape: [num_priors,4] + Return: + jaccard overlap: Shape: [box_a.size(0), box_b.size(0)] + """ + inter = intersect(box_a, box_b) + area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]) + area_a = area_a.unsqueeze(1).expand_as(inter) # [A,B] + area_b = (box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1]) + area_b = area_b.unsqueeze(0).expand_as(inter) # [A,B] + union = area_a + area_b - inter + return inter / union + + +def matrix_iof(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """ + return iof of a and b, numpy version for data augmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + return area_i / np.maximum(area_a[:, np.newaxis], 1) + + +def match( + threshold: float, + box_gt: torch.Tensor, + priors: torch.Tensor, + variances: List[float], + labels_gt: torch.Tensor, + landmarks_gt: torch.Tensor, + box_t: torch.Tensor, + label_t: torch.Tensor, + landmarks_t: torch.Tensor, + batch_id: int, +) -> None: + """Match each prior box with the ground truth box of the highest jaccard overlap, encode the bounding + boxes, then return the matched indices corresponding to both confidence and location preds. + + Args: + threshold: The overlap threshold used when matching boxes. + box_gt: Ground truth boxes, Shape: [num_obj, 4]. + priors: Prior boxes from priorbox layers, Shape: [n_priors, 4]. + variances: Variances corresponding to each prior coord, Shape: [num_priors, 4]. + labels_gt: All the class labels for the image, Shape: [num_obj, 2]. + landmarks_gt: Ground truth landms, Shape [num_obj, 10]. + box_t: Tensor to be filled w/ endcoded location targets. + label_t: Tensor to be filled w/ matched indices for labels predictions. + landmarks_t: Tensor to be filled w/ endcoded landmarks targets. + batch_id: current batch index + Return: + The matched indices corresponding to 1)location 2)confidence 3)landmarks preds. 
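+        Note: box_t, label_t and landmarks_t are filled in place at index batch_id;
+        the function itself returns None.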
+ """ + # Compute iou between gt and priors + overlaps = jaccard(box_gt, point_form(priors)) + # (Bipartite Matching) + # [1, num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + + # ignore hard gt + valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 + best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] + if best_prior_idx_filter.shape[0] <= 0: + box_t[batch_id] = 0 + label_t[batch_id] = 0 + return + + # [1, num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_idx_filter.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx_filter, + 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): + best_truth_idx[best_prior_idx[j]] = j + + matches = box_gt[best_truth_idx] # Shape: [num_priors, 4] + labels = labels_gt[best_truth_idx] # Shape: [num_priors] + # label as background + labels[best_truth_overlap < threshold] = 0 + loc = encode(matches, priors, variances) + + matches_landm = landmarks_gt[best_truth_idx] + landmarks_gt = encode_landm(matches_landm, priors, variances) + box_t[batch_id] = loc # [num_priors, 4] encoded offsets to learn + label_t[batch_id] = labels # [num_priors] top class label for each prior + landmarks_t[batch_id] = landmarks_gt + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= variances[0] * priors[:, 2:] + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def encode_landm( + matched: torch.Tensor, priors: torch.Tensor, + variances: Union[List[float], Tuple[float, float]]) -> torch.Tensor: + """Encode the variances from the priorbox layers into the ground truth boxes we have matched + (based on jaccard overlap) with the prior boxes. + Args: + matched: Coords of ground truth for each prior in point-form + Shape: [num_priors, 10]. + priors: Prior boxes in center-offset form + Shape: [num_priors,4]. 
+ variances: Variances of priorboxes + Return: + encoded landmarks, Shape: [num_priors, 10] + """ + + # dist b/t match center and prior's center + matched = torch.reshape(matched, (matched.size(0), 5, 2)) + priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), + 5).unsqueeze(2) + priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), + 5).unsqueeze(2) + priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), + 5).unsqueeze(2) + priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), + 5).unsqueeze(2) + priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) + g_cxcy = matched[:, :, :2] - priors[:, :, :2] + # encode variance + g_cxcy = g_cxcy // variances[0] * priors[:, :, 2:] + # return target for smooth_l1_loss + return g_cxcy.reshape(g_cxcy.size(0), -1) + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc: torch.Tensor, priors: torch.Tensor, + variances: Union[List[float], Tuple[float, float]]) -> torch.Tensor: + """Decode locations from predictions using priors to undo the encoding we did for offset regression at train time. + Args: + loc: location predictions for loc layers, + Shape: [num_priors, 4] + priors: Prior boxes in center-offset form. + Shape: [num_priors, 4]. + variances: Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat( + ( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]), + ), + 1, + ) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm( + pre: torch.Tensor, priors: torch.Tensor, + variances: Union[List[float], Tuple[float, float]]) -> torch.Tensor: + """Decode landmarks from predictions using priors to undo the encoding we did for offset regression at train time. + Args: + pre: landmark predictions for loc layers, + Shape: [num_priors, 10] + priors: Prior boxes in center-offset form. + Shape: [num_priors, 4]. + variances: Variances of priorboxes + Return: + decoded landmark predictions + """ + return torch.cat( + ( + priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], + ), + dim=1, + ) + + +def log_sum_exp(x: torch.Tensor) -> torch.Tensor: + """Utility function for computing log_sum_exp while determining This will be used to determine unaveraged + confidence loss across all examples in a batch. 
+ Args: + x: conf_preds from conf layers + """ + x_max = x.data.max() + return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max diff --git a/modelscope/models/cv/skin_retouching/retinaface/net.py b/modelscope/models/cv/skin_retouching/retinaface/net.py new file mode 100644 index 00000000..e9b0297b --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/net.py @@ -0,0 +1,124 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +from typing import Dict, List + +import torch +import torch.nn.functional as F +from torch import nn + + +def conv_bn(inp: int, + oup: int, + stride: int = 1, + leaky: float = 0) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +def conv_bn_no_relu(inp: int, oup: int, stride: int) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp: int, + oup: int, + stride: int, + leaky: float = 0) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +def conv_dw(inp: int, + oup: int, + stride: int, + leaky: float = 0.1) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel: int, out_channel: int) -> None: + super().__init__() + if out_channel % 4 != 0: + raise ValueError( + f'Expect out channel % 4 == 0, but we got {out_channel % 4}') + + leaky: float = 0 + if out_channel <= 64: + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + + self.conv5X5_1 = conv_bn( + in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn( + out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + conv3X3 = self.conv3X3(x) + + conv5X5_1 = self.conv5X5_1(x) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + + return F.relu(out) + + +class FPN(nn.Module): + + def __init__(self, in_channels_list: List[int], out_channels: int) -> None: + super().__init__() + leaky = 0.0 + if out_channels <= 64: + leaky = 0.1 + + self.output1 = conv_bn1X1( + in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1( + in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1( + in_channels_list[2], out_channels, stride=1, leaky=leaky) + + self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, x: Dict[str, torch.Tensor]) -> List[torch.Tensor]: + y = list(x.values()) + + output1 = self.output1(y[0]) + output2 = self.output2(y[1]) + output3 = self.output3(y[2]) + + up3 = F.interpolate( + output3, size=[output2.size(2), output2.size(3)], 
mode='nearest') + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate( + output2, size=[output1.size(2), output1.size(3)], mode='nearest') + output1 = output1 + up2 + output1 = self.merge1(output1) + + return [output1, output2, output3] diff --git a/modelscope/models/cv/skin_retouching/retinaface/network.py b/modelscope/models/cv/skin_retouching/retinaface/network.py new file mode 100644 index 00000000..3b197ca9 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/network.py @@ -0,0 +1,146 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +from typing import Dict, Tuple + +import torch +from torch import nn +from torchvision import models +from torchvision.models import _utils + +from .net import FPN, SSH + + +class ClassHead(nn.Module): + + def __init__(self, in_channels: int = 512, num_anchors: int = 3) -> None: + super().__init__() + self.conv1x1 = nn.Conv2d( + in_channels, + num_anchors * 2, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, in_channels: int = 512, num_anchors: int = 3): + super().__init__() + self.conv1x1 = nn.Conv2d( + in_channels, + num_anchors * 4, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, in_channels: int = 512, num_anchors: int = 3): + super().__init__() + self.conv1x1 = nn.Conv2d( + in_channels, + num_anchors * 10, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + return out.view(out.shape[0], -1, 10) + + +class RetinaFace(nn.Module): + + def __init__(self, name: str, pretrained: bool, in_channels: int, + return_layers: Dict[str, int], out_channels: int) -> None: + super().__init__() + + if name == 'Resnet50': + backbone = models.resnet50(pretrained=pretrained) + else: + raise NotImplementedError( + f'Only Resnet50 backbone is supported but got {name}') + + self.body = _utils.IntermediateLayerGetter(backbone, return_layers) + in_channels_stage2 = in_channels + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = self._make_class_head( + fpn_num=3, in_channels=out_channels) + self.BboxHead = self._make_bbox_head( + fpn_num=3, in_channels=out_channels) + self.LandmarkHead = self._make_landmark_head( + fpn_num=3, in_channels=out_channels) + + @staticmethod + def _make_class_head(fpn_num: int = 3, + in_channels: int = 64, + anchor_num: int = 2) -> nn.ModuleList: + classhead = nn.ModuleList() + for _ in range(fpn_num): + classhead.append(ClassHead(in_channels, anchor_num)) + return classhead + + @staticmethod + def _make_bbox_head(fpn_num: int = 3, + in_channels: int = 64, + anchor_num: int = 2) -> nn.ModuleList: + bboxhead = nn.ModuleList() + for _ in range(fpn_num): + bboxhead.append(BboxHead(in_channels, anchor_num)) + return bboxhead + + 
@staticmethod + def _make_landmark_head(fpn_num: int = 3, + in_channels: int = 64, + anchor_num: int = 2) -> nn.ModuleList: + landmarkhead = nn.ModuleList() + for _ in range(fpn_num): + landmarkhead.append(LandmarkHead(in_channels, anchor_num)) + return landmarkhead + + def forward( + self, inputs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + out = self.body(inputs) + + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat( + [self.BboxHead[i](feature) for i, feature in enumerate(features)], + dim=1) + classifications = torch.cat( + [self.ClassHead[i](feature) for i, feature in enumerate(features)], + dim=1) + ldm_regressions = [ + self.LandmarkHead[i](feature) for i, feature in enumerate(features) + ] + ldm_regressions = torch.cat(ldm_regressions, dim=1) + + return bbox_regressions, classifications, ldm_regressions diff --git a/modelscope/models/cv/skin_retouching/retinaface/predict_single.py b/modelscope/models/cv/skin_retouching/retinaface/predict_single.py new file mode 100644 index 00000000..659a1134 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/predict_single.py @@ -0,0 +1,152 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +""" +There is a lot of post processing of the predictions. +""" +from typing import Dict, List, Union + +import albumentations as A +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.ops import nms + +from ..utils import pad_to_size, unpad_from_size +from .box_utils import decode, decode_landm +from .network import RetinaFace +from .prior_box import priorbox +from .utils import tensor_from_rgb_image + + +class Model: + + def __init__(self, max_size: int = 960, device: str = 'cpu') -> None: + self.model = RetinaFace( + name='Resnet50', + pretrained=False, + return_layers={ + 'layer2': 1, + 'layer3': 2, + 'layer4': 3 + }, + in_channels=256, + out_channels=256, + ).to(device) + self.device = device + self.transform = A.Compose( + [A.LongestMaxSize(max_size=max_size, p=1), + A.Normalize(p=1)]) + self.max_size = max_size + self.prior_box = priorbox( + min_sizes=[[16, 32], [64, 128], [256, 512]], + steps=[8, 16, 32], + clip=False, + image_size=(self.max_size, self.max_size), + ).to(device) + self.variance = [0.1, 0.2] + + def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: + self.model.load_state_dict(state_dict) + + def eval(self): + self.model.eval() + + def predict_jsons( + self, + image: np.array, + confidence_threshold: float = 0.7, + nms_threshold: float = 0.4) -> List[Dict[str, Union[List, float]]]: + with torch.no_grad(): + original_height, original_width = image.shape[:2] + + scale_landmarks = torch.from_numpy( + np.tile([self.max_size, self.max_size], + 5)).to(self.device).float() + scale_bboxes = torch.from_numpy( + np.tile([self.max_size, self.max_size], + 2)).to(self.device).float() + + transformed_image = self.transform(image=image)['image'] + + paded = pad_to_size( + target_size=(self.max_size, self.max_size), + image=transformed_image) + + pads = paded['pads'] + + torched_image = tensor_from_rgb_image(paded['image']).to( + self.device) + + loc, conf, land = self.model(torched_image.unsqueeze(0)) + + conf = F.softmax(conf, dim=-1) + + annotations: List[Dict[str, Union[List, float]]] = [] + + boxes = decode(loc.data[0], self.prior_box, 
self.variance) + + boxes *= scale_bboxes + scores = conf[0][:, 1] + + landmarks = decode_landm(land.data[0], self.prior_box, + self.variance) + landmarks *= scale_landmarks + + # ignore low scores + valid_index = scores > confidence_threshold + boxes = boxes[valid_index] + landmarks = landmarks[valid_index] + scores = scores[valid_index] + + # Sort from high to low + order = scores.argsort(descending=True) + boxes = boxes[order] + landmarks = landmarks[order] + scores = scores[order] + + # do NMS + keep = nms(boxes, scores, nms_threshold) + boxes = boxes[keep, :].int() + + if boxes.shape[0] == 0: + return [{'bbox': [], 'score': -1, 'landmarks': []}] + + landmarks = landmarks[keep] + + scores = scores[keep].cpu().numpy().astype(np.float64) + boxes = boxes.cpu().numpy() + landmarks = landmarks.cpu().numpy() + landmarks = landmarks.reshape([-1, 2]) + + unpadded = unpad_from_size(pads, bboxes=boxes, keypoints=landmarks) + + resize_coeff = max(original_height, original_width) / self.max_size + + boxes = (unpadded['bboxes'] * resize_coeff).astype(int) + landmarks = (unpadded['keypoints'].reshape(-1, 10) + * resize_coeff).astype(int) + + for box_id, bbox in enumerate(boxes): + x_min, y_min, x_max, y_max = bbox + + x_min = np.clip(x_min, 0, original_width - 1) + x_max = np.clip(x_max, x_min + 1, original_width - 1) + + if x_min >= x_max: + continue + + y_min = np.clip(y_min, 0, original_height - 1) + y_max = np.clip(y_max, y_min + 1, original_height - 1) + + if y_min >= y_max: + continue + + annotations += [{ + 'bbox': + bbox.tolist(), + 'score': + scores[box_id], + 'landmarks': + landmarks[box_id].reshape(-1, 2).tolist(), + }] + + return annotations diff --git a/modelscope/models/cv/skin_retouching/retinaface/prior_box.py b/modelscope/models/cv/skin_retouching/retinaface/prior_box.py new file mode 100644 index 00000000..863a676c --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/prior_box.py @@ -0,0 +1,28 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +from itertools import product +from math import ceil + +import torch + + +def priorbox(min_sizes, steps, clip, image_size): + feature_maps = [[ceil(image_size[0] / step), + ceil(image_size[1] / step)] for step in steps] + + anchors = [] + for k, f in enumerate(feature_maps): + t_min_sizes = min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in t_min_sizes: + s_kx = min_size / image_size[1] + s_ky = min_size / image_size[0] + dense_cx = [x * steps[k] / image_size[1] for x in [j + 0.5]] + dense_cy = [y * steps[k] / image_size[0] for y in [i + 0.5]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if clip: + output.clamp_(max=1, min=0) + return output diff --git a/modelscope/models/cv/skin_retouching/retinaface/utils.py b/modelscope/models/cv/skin_retouching/retinaface/utils.py new file mode 100644 index 00000000..c6b97484 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/utils.py @@ -0,0 +1,70 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import cv2 +import numpy as np +import torch + + +def load_checkpoint(file_path: Union[Path, str], + rename_in_layers: Optional[dict] = None) -> Dict[str, Any]: + """Loads PyTorch checkpoint, optionally renaming layer names. 
+ Args: + file_path: path to the torch checkpoint. + rename_in_layers: {from_name: to_name} + ex: {"model.0.": "", + "model.": ""} + Returns: + """ + checkpoint = torch.load( + file_path, map_location=lambda storage, loc: storage) + + if rename_in_layers is not None: + model_state_dict = checkpoint['state_dict'] + + result = {} + for key, value in model_state_dict.items(): + for key_r, value_r in rename_in_layers.items(): + key = re.sub(key_r, value_r, key) + + result[key] = value + + checkpoint['state_dict'] = result + + return checkpoint + + +def tensor_from_rgb_image(image: np.ndarray) -> torch.Tensor: + image = np.transpose(image, (2, 0, 1)) + return torch.from_numpy(image) + + +def vis_annotations(image: np.ndarray, + annotations: List[Dict[str, Any]]) -> np.ndarray: + vis_image = image.copy() + + for annotation in annotations: + landmarks = annotation['landmarks'] + + colors = [(255, 0, 0), (128, 255, 0), (255, 178, 102), (102, 128, 255), + (0, 255, 255)] + + for landmark_id, (x, y) in enumerate(landmarks): + vis_image = cv2.circle( + vis_image, (x, y), + radius=3, + color=colors[landmark_id], + thickness=3) + + x_min, y_min, x_max, y_max = annotation['bbox'] + + x_min = np.clip(x_min, 0, x_max - 1) + y_min = np.clip(y_min, 0, y_max - 1) + + vis_image = cv2.rectangle( + vis_image, (x_min, y_min), (x_max, y_max), + color=(0, 255, 0), + thickness=2) + return vis_image diff --git a/modelscope/models/cv/skin_retouching/unet_deploy.py b/modelscope/models/cv/skin_retouching/unet_deploy.py new file mode 100755 index 00000000..cb37b04c --- /dev/null +++ b/modelscope/models/cv/skin_retouching/unet_deploy.py @@ -0,0 +1,143 @@ +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .weights_init import weights_init + +warnings.filterwarnings(action='ignore') + + +class double_conv(nn.Module): + '''(conv => BN => ReLU) * 2''' + + def __init__(self, in_ch, out_ch): + super(double_conv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), + nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)) + + def forward(self, x): + x = self.conv(x) + return x + + +class inconv(nn.Module): + + def __init__(self, in_ch, out_ch): + super(inconv, self).__init__() + self.conv = double_conv(in_ch, out_ch) + + def forward(self, x): + x = self.conv(x) + return x + + +class down(nn.Module): + + def __init__(self, in_ch, out_ch): + super(down, self).__init__() + self.mpconv = nn.Sequential( + nn.MaxPool2d(2), double_conv(in_ch, out_ch)) + + def forward(self, x): + x = self.mpconv(x) + return x + + +class up(nn.Module): + + def __init__(self, in_ch, out_ch, bilinear=True): + super(up, self).__init__() + + if bilinear: + self.up = nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=True) + else: + self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) + + self.conv = double_conv(in_ch, out_ch) + + def forward(self, x1, x2): + x1 = self.up(x1) + + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad( + x1, + (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) + + x = torch.cat([x2, x1], dim=1) + x = self.conv(x) + return x + + +class outconv(nn.Module): + + def __init__(self, in_ch, out_ch): + super(outconv, self).__init__() + self.conv = nn.Conv2d(in_ch, out_ch, 1) + + def forward(self, x): + x = self.conv(x) + return x + + +class UNet(nn.Module): + + def __init__(self, + n_channels, + 
n_classes, + deep_supervision=False, + init_weights=True): + super(UNet, self).__init__() + self.deep_supervision = deep_supervision + self.inc = inconv(n_channels, 64) + self.down1 = down(64, 128) + self.down2 = down(128, 256) + self.down3 = down(256, 512) + self.down4 = down(512, 512) + self.up1 = up(1024, 256) + self.up2 = up(512, 128) + self.up3 = up(256, 64) + self.up4 = up(128, 64) + self.outc = outconv(64, n_classes) + + self.dsoutc4 = outconv(256, n_classes) + self.dsoutc3 = outconv(128, n_classes) + self.dsoutc2 = outconv(64, n_classes) + self.dsoutc1 = outconv(64, n_classes) + + self.sigmoid = nn.Sigmoid() + + if init_weights: + self.apply(weights_init()) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x44 = self.up1(x5, x4) + x33 = self.up2(x44, x3) + x22 = self.up3(x33, x2) + x11 = self.up4(x22, x1) + x0 = self.outc(x11) + x0 = self.sigmoid(x0) + if self.deep_supervision: + x11 = F.interpolate( + self.dsoutc1(x11), x0.shape[2:], mode='bilinear') + x22 = F.interpolate( + self.dsoutc2(x22), x0.shape[2:], mode='bilinear') + x33 = F.interpolate( + self.dsoutc3(x33), x0.shape[2:], mode='bilinear') + x44 = F.interpolate( + self.dsoutc4(x44), x0.shape[2:], mode='bilinear') + + return x0, x11, x22, x33, x44 + else: + return x0 diff --git a/modelscope/models/cv/skin_retouching/utils.py b/modelscope/models/cv/skin_retouching/utils.py new file mode 100644 index 00000000..12653f41 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/utils.py @@ -0,0 +1,327 @@ +import time +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + 'gen_diffuse_mask', 'get_crop_bbox', 'get_roi_without_padding', + 'patch_aggregation_overlap', 'patch_partition_overlap', 'preprocess_roi', + 'resize_on_long_side', 'roi_to_tensor', 'smooth_border_mg', 'whiten_img' +] + + +def resize_on_long_side(img, long_side=800): + src_height = img.shape[0] + src_width = img.shape[1] + + if src_height > src_width: + scale = long_side * 1.0 / src_height + _img = cv2.resize( + img, (int(src_width * scale), long_side), + interpolation=cv2.INTER_LINEAR) + else: + scale = long_side * 1.0 / src_width + _img = cv2.resize( + img, (long_side, int(src_height * scale)), + interpolation=cv2.INTER_LINEAR) + + return _img, scale + + +def get_crop_bbox(detecting_results): + boxes = [] + for anno in detecting_results: + if anno['score'] == -1: + break + boxes.append({ + 'x1': anno['bbox'][0], + 'y1': anno['bbox'][1], + 'x2': anno['bbox'][2], + 'y2': anno['bbox'][3] + }) + face_count = len(boxes) + + suitable_bboxes = [] + for i in range(face_count): + face_bbox = boxes[i] + + face_bbox_width = abs(face_bbox['x2'] - face_bbox['x1']) + face_bbox_height = abs(face_bbox['y2'] - face_bbox['y1']) + + face_bbox_center = ((face_bbox['x1'] + face_bbox['x2']) / 2, + (face_bbox['y1'] + face_bbox['y2']) / 2) + + square_bbox_length = face_bbox_height if face_bbox_height > face_bbox_width else face_bbox_width + enlarge_ratio = 1.5 + square_bbox_length = int(enlarge_ratio * square_bbox_length) + + sideScale = 1 + + square_bbox = { + 'x1': + int(face_bbox_center[0] - sideScale * square_bbox_length / 2), + 'x2': + int(face_bbox_center[0] + sideScale * square_bbox_length / 2), + 'y1': + int(face_bbox_center[1] - sideScale * square_bbox_length / 2), + 'y2': int(face_bbox_center[1] + sideScale * square_bbox_length / 2) + } + + 
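+            # square_bbox is a square crop centered on the face box, with side length
+            # enlarge_ratio (1.5) times the longer side of the detected face; it may
+            # extend beyond the image borders and is clipped later in
+            # get_roi_without_padding.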
suitable_bboxes.append(square_bbox) + + return suitable_bboxes + + +def get_roi_without_padding(img, bbox): + crop_t = max(bbox['y1'], 0) + crop_b = min(bbox['y2'], img.shape[0]) + crop_l = max(bbox['x1'], 0) + crop_r = min(bbox['x2'], img.shape[1]) + roi = img[crop_t:crop_b, crop_l:crop_r] + return roi, 0, [crop_t, crop_b, crop_l, crop_r] + + +def roi_to_tensor(img): + img = torch.from_numpy(img.transpose((2, 0, 1)))[None, ...] + + return img + + +def preprocess_roi(img): + img = img.float() / 255.0 + img = (img - 0.5) * 2 + + return img + + +def patch_partition_overlap(image, p1, p2, padding=32): + + B, C, H, W = image.size() + h, w = H // p1, W // p2 + image = F.pad( + image, + pad=(padding, padding, padding, padding, 0, 0), + mode='constant', + value=0) + + patch_list = [] + for i in range(h): + for j in range(w): + patch = image[:, :, p1 * i:p1 * (i + 1) + padding * 2, + p2 * j:p2 * (j + 1) + padding * 2] + patch_list.append(patch) + + output = torch.cat( + patch_list, dim=0) # (b h w) c (p1 + 2 * padding) (p2 + 2 * padding) + return output + + +def patch_aggregation_overlap(image, h, w, padding=32): + + image = image[:, :, padding:-padding, padding:-padding] + + output = rearrange(image, '(b h w) c p1 p2 -> b c (h p1) (w p2)', h=h, w=w) + + return output + + +def smooth_border_mg(diffuse_mask, mg): + mg = mg - 0.5 + diffuse_mask = F.interpolate( + diffuse_mask, mg.shape[:2], mode='bilinear')[0].permute(1, 2, 0) + mg = mg * diffuse_mask + mg = mg + 0.5 + return mg + + +def whiten_img(image, skin_mask, whitening_degree, flag_bigKernal=False): + """ + image: rgb + """ + dilate_kernalsize = 30 + if flag_bigKernal: + dilate_kernalsize = 80 + new_kernel1 = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (dilate_kernalsize, dilate_kernalsize)) + new_kernel2 = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (dilate_kernalsize, dilate_kernalsize)) + if len(skin_mask.shape) == 3: + skin_mask = skin_mask[:, :, -1] + skin_mask = cv2.dilate(skin_mask, new_kernel1, 1) + skin_mask = cv2.erode(skin_mask, new_kernel2, 1) + skin_mask = cv2.blur(skin_mask, (20, 20)) / 255.0 + skin_mask = skin_mask.squeeze() + skin_mask = torch.from_numpy(skin_mask).to(image.device) + skin_mask = torch.stack([skin_mask, skin_mask, skin_mask], dim=0)[None, + ...] 
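+    # Scale down the mask on the G and B channels (the tensor is RGB) so the
+    # quadratic brightening applied below acts most strongly on the red channel.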
+ skin_mask[:, 1:, :, :] *= 0.75 + + whiten_mg = skin_mask * 0.2 * whitening_degree + 0.5 + assert len(whiten_mg.shape) == 4 + whiten_mg = F.interpolate( + whiten_mg, image.shape[:2], mode='bilinear')[0].permute(1, 2, + 0).half() + output_pred = image.half() + output_pred = output_pred / 255.0 + output_pred = ( + -2 * whiten_mg + 1 + ) * output_pred * output_pred + 2 * whiten_mg * output_pred # value: 0~1 + output_pred = output_pred * 255.0 + output_pred = output_pred.byte() + + output_pred = output_pred.cpu().numpy() + return output_pred + + +def gen_diffuse_mask(out_channels=3): + mask_size = 500 + diffuse_with = 20 + a = np.ones(shape=(mask_size, mask_size), dtype=np.float32) + + for i in range(mask_size): + for j in range(mask_size): + if i >= diffuse_with and i <= ( + mask_size - diffuse_with) and j >= diffuse_with and j <= ( + mask_size - diffuse_with): + a[i, j] = 1.0 + elif i <= diffuse_with: + a[i, j] = i * 1.0 / diffuse_with + elif i > (mask_size - diffuse_with): + a[i, j] = (mask_size - i) * 1.0 / diffuse_with + + for i in range(mask_size): + for j in range(mask_size): + if j <= diffuse_with: + a[i, j] = min(a[i, j], j * 1.0 / diffuse_with) + elif j > (mask_size - diffuse_with): + a[i, j] = min(a[i, j], (mask_size - j) * 1.0 / diffuse_with) + a = np.dstack([a] * out_channels) + return a + + +def pad_to_size( + target_size: Tuple[int, int], + image: np.array, + bboxes: Optional[np.ndarray] = None, + keypoints: Optional[np.ndarray] = None, +) -> Dict[str, Union[np.ndarray, Tuple[int, int, int, int]]]: + """Pads the image on the sides to the target_size + + Args: + target_size: (target_height, target_width) + image: + bboxes: np.array with shape (num_boxes, 4). Each row: [x_min, y_min, x_max, y_max] + keypoints: np.array with shape (num_keypoints, 2), each row: [x, y] + + Returns: + { + "image": padded_image, + "pads": (x_min_pad, y_min_pad, x_max_pad, y_max_pad), + "bboxes": shifted_boxes, + "keypoints": shifted_keypoints + } + + """ + target_height, target_width = target_size + + image_height, image_width = image.shape[:2] + + if target_width < image_width: + raise ValueError(f'Target width should bigger than image_width' + f'We got {target_width} {image_width}') + + if target_height < image_height: + raise ValueError(f'Target height should bigger than image_height' + f'We got {target_height} {image_height}') + + if image_height == target_height: + y_min_pad = 0 + y_max_pad = 0 + else: + y_pad = target_height - image_height + y_min_pad = y_pad // 2 + y_max_pad = y_pad - y_min_pad + + if image_width == target_width: + x_min_pad = 0 + x_max_pad = 0 + else: + x_pad = target_width - image_width + x_min_pad = x_pad // 2 + x_max_pad = x_pad - x_min_pad + + result = { + 'pads': (x_min_pad, y_min_pad, x_max_pad, y_max_pad), + 'image': + cv2.copyMakeBorder(image, y_min_pad, y_max_pad, x_min_pad, x_max_pad, + cv2.BORDER_CONSTANT), + } + + if bboxes is not None: + bboxes[:, 0] += x_min_pad + bboxes[:, 1] += y_min_pad + bboxes[:, 2] += x_min_pad + bboxes[:, 3] += y_min_pad + + result['bboxes'] = bboxes + + if keypoints is not None: + keypoints[:, 0] += x_min_pad + keypoints[:, 1] += y_min_pad + + result['keypoints'] = keypoints + + return result + + +def unpad_from_size( + pads: Tuple[int, int, int, int], + image: Optional[np.array] = None, + bboxes: Optional[np.ndarray] = None, + keypoints: Optional[np.ndarray] = None, +) -> Dict[str, np.ndarray]: + """Crops patch from the center so that sides are equal to pads. 
+ + Args: + image: + pads: (x_min_pad, y_min_pad, x_max_pad, y_max_pad) + bboxes: np.array with shape (num_boxes, 4). Each row: [x_min, y_min, x_max, y_max] + keypoints: np.array with shape (num_keypoints, 2), each row: [x, y] + + Returns: cropped image + + { + "image": cropped_image, + "bboxes": shifted_boxes, + "keypoints": shifted_keypoints + } + + """ + x_min_pad, y_min_pad, x_max_pad, y_max_pad = pads + + result = {} + + if image is not None: + height, width = image.shape[:2] + result['image'] = image[y_min_pad:height - y_max_pad, + x_min_pad:width - x_max_pad] + + if bboxes is not None: + bboxes[:, 0] -= x_min_pad + bboxes[:, 1] -= y_min_pad + bboxes[:, 2] -= x_min_pad + bboxes[:, 3] -= y_min_pad + + result['bboxes'] = bboxes + + if keypoints is not None: + keypoints[:, 0] -= x_min_pad + keypoints[:, 1] -= y_min_pad + + result['keypoints'] = keypoints + + return result diff --git a/modelscope/models/cv/skin_retouching/weights_init.py b/modelscope/models/cv/skin_retouching/weights_init.py new file mode 100644 index 00000000..efd24843 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/weights_init.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + + +def weights_init(init_type='kaiming', gain=0.02): + + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 + or classname.find('Linear') != -1): + + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + elif classname.find('BatchNorm2d') != -1: + nn.init.normal_(m.weight.data, 1.0, gain) + nn.init.constant_(m.bias.data, 0.0) + + return init_func + + +def spectral_norm(module, mode=True): + + if mode: + return nn.utils.spectral_norm(module) + + return module diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 28c03e73..6ef21752 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -122,6 +122,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.image_classification: (Pipelines.daily_image_classification, 'damo/cv_vit-base_image-classification_Dailylife-labels'), + Tasks.skin_retouching: (Pipelines.skin_retouching, + 'damo/cv_unet_skin-retouching'), } diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 3c7f6092..d8b09c63 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline from .live_category_pipeline import LiveCategoryPipeline from .ocr_detection_pipeline import OCRDetectionPipeline + from .skin_retouching_pipeline import SkinRetouchingPipeline from .video_category_pipeline import VideoCategoryPipeline from .virtual_try_on_pipeline import VirtualTryonPipeline else: @@ -59,6 +60,7 @@ else: 'image_to_image_generation_pipeline': ['Image2ImageGenerationePipeline'], 'ocr_detection_pipeline': ['OCRDetectionPipeline'], + 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], 'video_category_pipeline': ['VideoCategoryPipeline'], 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], } diff --git a/modelscope/pipelines/cv/skin_retouching_pipeline.py 
b/modelscope/pipelines/cv/skin_retouching_pipeline.py new file mode 100644 index 00000000..056409df --- /dev/null +++ b/modelscope/pipelines/cv/skin_retouching_pipeline.py @@ -0,0 +1,302 @@ +import os +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import tensorflow as tf +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.skin_retouching.detection_model.detection_unet_in import \ + DetectionUNet +from modelscope.models.cv.skin_retouching.inpainting_model.inpainting_unet import \ + RetouchingNet +from modelscope.models.cv.skin_retouching.retinaface.predict_single import \ + Model +from modelscope.models.cv.skin_retouching.unet_deploy import UNet +from modelscope.models.cv.skin_retouching.utils import * # noqa F403 +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + tf.disable_eager_execution() + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.skin_retouching, module_name=Pipelines.skin_retouching) +class SkinRetouchingPipeline(Pipeline): + + def __init__(self, model: str, device: str): + """ + use `model` to create a skin retouching pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + + if device == 'gpu': + device = 'cuda' + model_path = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE) + detector_model_path = os.path.join( + self.model, 'retinaface_resnet50_2020-07-20_old_torch.pth') + local_model_path = os.path.join(self.model, 'joint_20210926.pth') + skin_model_path = os.path.join(self.model, ModelFile.TF_GRAPH_FILE) + + self.generator = UNet(3, 3).to(device) + self.generator.load_state_dict( + torch.load(model_path, map_location='cpu')['generator']) + self.generator.eval() + + self.detector = Model(max_size=512, device=device) + state_dict = torch.load(detector_model_path, map_location='cpu') + self.detector.load_state_dict(state_dict) + self.detector.eval() + + self.local_model_path = local_model_path + ckpt_dict_load = torch.load(self.local_model_path, map_location='cpu') + self.inpainting_net = RetouchingNet( + in_channels=4, out_channels=3).to(device) + self.detection_net = DetectionUNet( + n_channels=3, n_classes=1).to(device) + + self.inpainting_net.load_state_dict(ckpt_dict_load['inpainting_net']) + self.detection_net.load_state_dict(ckpt_dict_load['detection_net']) + + self.inpainting_net.eval() + self.detection_net.eval() + + self.patch_size = 512 + + self.skin_model_path = skin_model_path + if self.skin_model_path is not None: + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.per_process_gpu_memory_fraction = 0.3 + config.gpu_options.allow_growth = True + self.sess = tf.Session(config=config) + with tf.gfile.FastGFile(self.skin_model_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + self.sess.graph.as_default() + tf.import_graph_def(graph_def, name='') + self.sess.run(tf.global_variables_initializer()) + + self.image_files_transforms = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + self.diffuse_mask = gen_diffuse_mask() + self.diffuse_mask = 
torch.from_numpy( + self.diffuse_mask).to(device).float() + self.diffuse_mask = self.diffuse_mask.permute(2, 0, 1)[None, ...] + + self.input_size = 512 + self.device = device + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + if len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + img = img.astype(np.float) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + rgb_image = input['img'].astype(np.uint8) + + retouch_local = True + whitening = True + degree = 1.0 + whitening_degree = 0.8 + return_mg = False + + with torch.no_grad(): + if whitening and whitening_degree > 0 and self.skin_model_path is not None: + rgb_image_small, resize_scale = resize_on_long_side( + rgb_image, 800) + skin_mask = self.sess.run( + self.sess.graph.get_tensor_by_name('output_png:0'), + feed_dict={'input_image:0': rgb_image_small}) + + output_pred = torch.from_numpy(rgb_image).to(self.device) + if return_mg: + output_mg = np.ones( + (rgb_image.shape[0], rgb_image.shape[1], 3), + dtype=np.float32) * 0.5 + + results = self.detector.predict_jsons( + rgb_image + ) # list, [{'bbox':, [x1, y1, x2, y2], 'score'...}, ...] + + crop_bboxes = get_crop_bbox(results) + + face_num = len(crop_bboxes) + if face_num == 0: + output = { + 'pred': output_pred.cpu().numpy()[:, :, ::-1], + 'face_num': face_num + } + return output + + flag_bigKernal = False + for bbox in crop_bboxes: + roi, expand, crop_tblr = get_roi_without_padding( + rgb_image, bbox) + roi = roi_to_tensor(roi) # bgr -> rgb + + if roi.shape[2] > 0.4 * rgb_image.shape[0]: + flag_bigKernal = True + + roi = roi.to(self.device) + + roi = preprocess_roi(roi) + + if retouch_local and self.local_model_path is not None: + roi = self.retouch_local(roi) + + roi_output = self.predict_roi( + roi, + degree=degree, + smooth_border=True, + return_mg=return_mg) + + roi_pred = roi_output['pred'] + output_pred[crop_tblr[0]:crop_tblr[1], + crop_tblr[2]:crop_tblr[3]] = roi_pred + + if return_mg: + roi_mg = roi_output['pred_mg'] + output_mg[crop_tblr[0]:crop_tblr[1], + crop_tblr[2]:crop_tblr[3]] = roi_mg + + if whitening and whitening_degree > 0 and self.skin_model_path is not None: + output_pred = whiten_img( + output_pred, + skin_mask, + whitening_degree, + flag_bigKernal=flag_bigKernal) + + if not isinstance(output_pred, np.ndarray): + output_pred = output_pred.cpu().numpy() + + output_pred = output_pred[:, :, ::-1] + + return {OutputKeys.OUTPUT_IMG: output_pred} + + def retouch_local(self, image): + """ + image: rgb + """ + with torch.no_grad(): + sub_H, sub_W = image.shape[2:] + + sub_image_standard = F.interpolate( + image, size=(768, 768), mode='bilinear', align_corners=True) + sub_mask_pred = torch.sigmoid( + self.detection_net(sub_image_standard)) + sub_mask_pred = F.interpolate( + sub_mask_pred, size=(sub_H, sub_W), mode='nearest') + + sub_mask_pred_hard_low = (sub_mask_pred >= 0.35).float() + sub_mask_pred_hard_high = (sub_mask_pred >= 0.5).float() + sub_mask_pred = sub_mask_pred * ( + 1 - sub_mask_pred_hard_high) + sub_mask_pred_hard_high + sub_mask_pred = sub_mask_pred * sub_mask_pred_hard_low + sub_mask_pred = 1 - sub_mask_pred + + sub_H_standard = sub_H if sub_H % self.patch_size == 0 else ( + sub_H // self.patch_size + 1) * self.patch_size + sub_W_standard = sub_W if sub_W % self.patch_size == 0 else ( + sub_W // self.patch_size + 1) * self.patch_size + + sub_image_padding = F.pad( + image, + pad=(0, sub_W_standard - sub_W, 0, sub_H_standard - sub_H, 0, + 
+                     0),
+                mode='constant',
+                value=0)
+            sub_mask_pred_padding = F.pad(
+                sub_mask_pred,
+                pad=(0, sub_W_standard - sub_W, 0, sub_H_standard - sub_H, 0,
+                     0),
+                mode='constant',
+                value=0)
+
+            sub_image_padding = patch_partition_overlap(
+                sub_image_padding, p1=self.patch_size, p2=self.patch_size)
+            sub_mask_pred_padding = patch_partition_overlap(
+                sub_mask_pred_padding, p1=self.patch_size, p2=self.patch_size)
+            B_padding, C_padding, _, _ = sub_image_padding.size()
+
+            sub_comp_padding_list = []
+            for window_item in range(B_padding):
+                sub_image_padding_window = sub_image_padding[
+                    window_item:window_item + 1]
+                sub_mask_pred_padding_window = sub_mask_pred_padding[
+                    window_item:window_item + 1]
+
+                sub_input_image_padding_window = sub_image_padding_window * sub_mask_pred_padding_window
+
+                sub_output_padding_window = self.inpainting_net(
+                    sub_input_image_padding_window,
+                    sub_mask_pred_padding_window)
+                sub_comp_padding_window = sub_input_image_padding_window + (
+                    1
+                    - sub_mask_pred_padding_window) * sub_output_padding_window
+
+                sub_comp_padding_list.append(sub_comp_padding_window)
+
+            sub_comp_padding = torch.cat(sub_comp_padding_list, dim=0)
+            sub_comp = patch_aggregation_overlap(
+                sub_comp_padding,
+                h=int(round(sub_H_standard / self.patch_size)),
+                w=int(round(sub_W_standard
+                            / self.patch_size)))[:, :, :sub_H, :sub_W]
+
+        return sub_comp
+
+    def predict_roi(self,
+                    roi,
+                    degree=1.0,
+                    smooth_border=False,
+                    return_mg=False):
+        with torch.no_grad():
+            image = F.interpolate(
+                roi, (self.input_size, self.input_size), mode='bilinear')
+
+            pred_mg = self.generator(image)  # value: 0~1
+            pred_mg = (pred_mg - 0.5) * degree + 0.5
+            pred_mg = pred_mg.clamp(0.0, 1.0)
+            pred_mg = F.interpolate(pred_mg, roi.shape[2:], mode='bilinear')
+            pred_mg = pred_mg[0].permute(
+                1, 2, 0)  # tensor, (h, w, 1) or (h0, w0, 3)
+            if len(pred_mg.shape) == 2:
+                pred_mg = pred_mg[..., None]
+
+            if smooth_border:
+                pred_mg = smooth_border_mg(self.diffuse_mask, pred_mg)
+
+            image = (roi[0].permute(1, 2, 0) + 1.0) / 2
+
+            pred = (1 - 2 * pred_mg
+                    ) * image * image + 2 * pred_mg * image  # value: 0~1
+
+            pred = (pred * 255.0).byte()  # tensor, (h, w, 3), rgb, uint8
+
+        output = {'pred': pred}
+        if return_mg:
+            output['pred_mg'] = pred_mg.cpu().numpy()
+        return output
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
diff --git a/tests/pipelines/test_skin_retouching.py b/tests/pipelines/test_skin_retouching.py
new file mode 100644
index 00000000..54cdaa73
--- /dev/null
+++ b/tests/pipelines/test_skin_retouching.py
@@ -0,0 +1,46 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import os.path as osp
+import unittest
+
+import cv2
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class SkinRetouchingTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_unet_skin-retouching'
+        self.test_image = 'data/test/images/skin_retouching.png'
+
+    def pipeline_inference(self, pipeline: Pipeline, input_location: str):
+        result = pipeline(input_location)
+        cv2.imwrite('result_skinretouching.png', result[OutputKeys.OUTPUT_IMG])
+        print(f'Output written to {osp.abspath("result_skinretouching.png")}')
+
+    @unittest.skip('deprecated, download model from model hub instead')
+    def test_run_by_direct_model_download(self):
+        model_dir = snapshot_download(self.model_id)
+
+        skin_retouching = pipeline(Tasks.skin_retouching, model=model_dir)
+        self.pipeline_inference(skin_retouching, self.test_image)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_modelhub(self):
+        skin_retouching = pipeline(Tasks.skin_retouching, model=self.model_id)
+        self.pipeline_inference(skin_retouching, self.test_image)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        skin_retouching = pipeline(Tasks.skin_retouching)
+        self.pipeline_inference(skin_retouching, self.test_image)
+
+
+if __name__ == '__main__':
+    unittest.main()
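For quick reference, a minimal usage sketch of the new pipeline outside the unittest harness, mirroring the hub-based test above. The model id and test image are the ones used in the tests; the output filename is an illustrative placeholder.

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline from the hub model id used in the tests above.
skin_retouching = pipeline(
    Tasks.skin_retouching, model='damo/cv_unet_skin-retouching')

# Run on a local image; the array under OutputKeys.OUTPUT_IMG is BGR,
# so it can be written directly with OpenCV.
result = skin_retouching('data/test/images/skin_retouching.png')
cv2.imwrite('skin_retouching_result.png', result[OutputKeys.OUTPUT_IMG])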