Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9442509master
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:0fcd36e0ada8a506bb09d3e0f3594e2be978194ea4123e066331c0bcb7fc79bc | |||
size 683425 |
@@ -94,6 +94,7 @@ class Pipelines(object): | |||
video_category = 'video-category' | |||
image_portrait_enhancement = 'gpen-image-portrait-enhancement' | |||
image_to_image_generation = 'image-to-image-generation' | |||
skin_retouching = 'unet-skin-retouching' | |||
# nlp tasks | |||
sentence_similarity = 'sentence-similarity' | |||
@@ -0,0 +1,65 @@ | |||
import torch | |||
import torch.nn as nn | |||
class ConvBNActiv(nn.Module): | |||
def __init__(self, | |||
in_channels, | |||
out_channels, | |||
bn=True, | |||
sample='none-3', | |||
activ='relu', | |||
bias=False): | |||
super(ConvBNActiv, self).__init__() | |||
if sample == 'down-7': | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=7, | |||
stride=2, | |||
padding=3, | |||
bias=bias) | |||
elif sample == 'down-5': | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=5, | |||
stride=2, | |||
padding=2, | |||
bias=bias) | |||
elif sample == 'down-3': | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
bias=bias) | |||
else: | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=bias) | |||
if bn: | |||
self.bn = nn.BatchNorm2d(out_channels) | |||
if activ == 'relu': | |||
self.activation = nn.ReLU() | |||
elif activ == 'leaky': | |||
self.activation = nn.LeakyReLU(negative_slope=0.2) | |||
def forward(self, images): | |||
outputs = self.conv(images) | |||
if hasattr(self, 'bn'): | |||
outputs = self.bn(outputs) | |||
if hasattr(self, 'activation'): | |||
outputs = self.activation(outputs) | |||
return outputs |
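# Illustrative usage sketch (not part of the original change): the 'down-*'
# variants halve the spatial resolution via a stride-2 convolution, while the
# default 'none-3' variant preserves it.
if __name__ == '__main__':
    block = ConvBNActiv(3, 16, sample='down-3')
    dummy = torch.zeros(1, 3, 64, 64)
    print(block(dummy).shape)  # expected: torch.Size([1, 16, 32, 32])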
@@ -0,0 +1,66 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from ..weights_init import weights_init | |||
from .detection_module import ConvBNActiv | |||
class DetectionUNet(nn.Module): | |||
def __init__(self, | |||
n_channels, | |||
n_classes, | |||
up_sampling_node='nearest', | |||
init_weights=True): | |||
super(DetectionUNet, self).__init__() | |||
self.n_classes = n_classes | |||
self.up_sampling_node = up_sampling_node | |||
self.ec_images_1 = ConvBNActiv( | |||
n_channels, 64, bn=False, sample='down-3') | |||
self.ec_images_2 = ConvBNActiv(64, 128, sample='down-3') | |||
self.ec_images_3 = ConvBNActiv(128, 256, sample='down-3') | |||
self.ec_images_4 = ConvBNActiv(256, 512, sample='down-3') | |||
self.ec_images_5 = ConvBNActiv(512, 512, sample='down-3') | |||
self.ec_images_6 = ConvBNActiv(512, 512, sample='down-3') | |||
self.dc_images_6 = ConvBNActiv(512 + 512, 512, activ='leaky') | |||
self.dc_images_5 = ConvBNActiv(512 + 512, 512, activ='leaky') | |||
self.dc_images_4 = ConvBNActiv(512 + 256, 256, activ='leaky') | |||
self.dc_images_3 = ConvBNActiv(256 + 128, 128, activ='leaky') | |||
self.dc_images_2 = ConvBNActiv(128 + 64, 64, activ='leaky') | |||
self.dc_images_1 = nn.Conv2d(64 + n_channels, n_classes, kernel_size=1) | |||
if init_weights: | |||
self.apply(weights_init()) | |||
def forward(self, input_images): | |||
ec_images = {} | |||
ec_images['ec_images_0'] = input_images | |||
ec_images['ec_images_1'] = self.ec_images_1(input_images) | |||
ec_images['ec_images_2'] = self.ec_images_2(ec_images['ec_images_1']) | |||
ec_images['ec_images_3'] = self.ec_images_3(ec_images['ec_images_2']) | |||
ec_images['ec_images_4'] = self.ec_images_4(ec_images['ec_images_3']) | |||
ec_images['ec_images_5'] = self.ec_images_5(ec_images['ec_images_4']) | |||
ec_images['ec_images_6'] = self.ec_images_6(ec_images['ec_images_5']) | |||
# -------------- | |||
# images decoder | |||
# -------------- | |||
logits = ec_images['ec_images_6'] | |||
for _ in range(6, 0, -1): | |||
ec_images_skip = 'ec_images_{:d}'.format(_ - 1) | |||
dc_conv = 'dc_images_{:d}'.format(_) | |||
logits = F.interpolate( | |||
logits, scale_factor=2, mode=self.up_sampling_node) | |||
logits = torch.cat((logits, ec_images[ec_images_skip]), dim=1) | |||
logits = getattr(self, dc_conv)(logits) | |||
return logits |
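# Minimal forward-pass sketch (illustrative only): the encoder applies six
# stride-2 convolutions and the decoder upsamples six times, so input height
# and width should be divisible by 2**6 = 64.
if __name__ == '__main__':
    net = DetectionUNet(n_channels=3, n_classes=1)
    dummy = torch.zeros(1, 3, 128, 128)
    print(net(dummy).shape)  # expected: torch.Size([1, 1, 128, 128])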
@@ -0,0 +1,207 @@ | |||
import torch | |||
import torch.nn as nn | |||
class GatedConvBNActiv(nn.Module): | |||
def __init__(self, | |||
in_channels, | |||
out_channels, | |||
bn=True, | |||
sample='none-3', | |||
activ='relu', | |||
bias=False): | |||
super(GatedConvBNActiv, self).__init__() | |||
if sample == 'down-7': | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=7, | |||
stride=2, | |||
padding=3, | |||
bias=bias) | |||
self.gate = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=7, | |||
stride=2, | |||
padding=3, | |||
bias=bias) | |||
elif sample == 'down-5': | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=5, | |||
stride=2, | |||
padding=2, | |||
bias=bias) | |||
self.gate = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=5, | |||
stride=2, | |||
padding=2, | |||
bias=bias) | |||
elif sample == 'down-3': | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
bias=bias) | |||
self.gate = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
bias=bias) | |||
else: | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=bias) | |||
self.gate = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=bias) | |||
if bn: | |||
self.bn = nn.BatchNorm2d(out_channels) | |||
if activ == 'relu': | |||
self.activation = nn.ReLU() | |||
elif activ == 'leaky': | |||
self.activation = nn.LeakyReLU(negative_slope=0.2) | |||
self.sigmoid = nn.Sigmoid() | |||
def forward(self, x): | |||
images = self.conv(x) | |||
gates = self.sigmoid(self.gate(x)) | |||
if hasattr(self, 'bn'): | |||
images = self.bn(images) | |||
if hasattr(self, 'activation'): | |||
images = self.activation(images) | |||
images = images * gates | |||
return images | |||
class GatedConvBNActiv2(nn.Module): | |||
def __init__(self, | |||
in_channels, | |||
out_channels, | |||
bn=True, | |||
sample='none-3', | |||
activ='relu', | |||
bias=False): | |||
super(GatedConvBNActiv2, self).__init__() | |||
if sample == 'down-7': | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=7, | |||
stride=2, | |||
padding=3, | |||
bias=bias) | |||
self.gate = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=7, | |||
stride=2, | |||
padding=3, | |||
bias=bias) | |||
elif sample == 'down-5': | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=5, | |||
stride=2, | |||
padding=2, | |||
bias=bias) | |||
self.gate = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=5, | |||
stride=2, | |||
padding=2, | |||
bias=bias) | |||
elif sample == 'down-3': | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
bias=bias) | |||
self.gate = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
bias=bias) | |||
else: | |||
self.conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=bias) | |||
self.gate = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=bias) | |||
self.conv_skip = nn.Conv2d( | |||
out_channels, | |||
out_channels, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
bias=bias) | |||
if bn: | |||
self.bn = nn.BatchNorm2d(out_channels) | |||
if activ == 'relu': | |||
self.activation = nn.ReLU() | |||
elif activ == 'leaky': | |||
self.activation = nn.LeakyReLU(negative_slope=0.2) | |||
self.sigmoid = nn.Sigmoid() | |||
def forward(self, f_up, f_skip, mask): | |||
x = torch.cat((f_up, f_skip, mask), dim=1) | |||
images = self.conv(x) | |||
images_skip = self.conv_skip(f_skip) | |||
gates = self.sigmoid(self.gate(x)) | |||
if hasattr(self, 'bn'): | |||
images = self.bn(images) | |||
images_skip = self.bn(images_skip) | |||
if hasattr(self, 'activation'): | |||
images = self.activation(images) | |||
images_skip = self.activation(images_skip) | |||
images = images * gates + images_skip * (1 - gates) | |||
return images |
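# Illustrative sketch (not part of the original change): a gated convolution
# modulates the convolved features with a per-pixel soft mask produced by
# sigmoid(gate(x)), which is what makes the inpainting branch mask-aware.
if __name__ == '__main__':
    gated = GatedConvBNActiv(4, 8, sample='down-3')
    dummy = torch.zeros(1, 4, 64, 64)
    print(gated(dummy).shape)  # expected: torch.Size([1, 8, 32, 32])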
@@ -0,0 +1,88 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from modelscope.models.cv.skin_retouching.inpainting_model.gconv import \ | |||
GatedConvBNActiv | |||
from ..weights_init import weights_init | |||
class RetouchingNet(nn.Module): | |||
def __init__(self, | |||
in_channels=3, | |||
out_channels=3, | |||
up_sampling_node='nearest', | |||
init_weights=True): | |||
super(RetouchingNet, self).__init__() | |||
self.freeze_ec_bn = False | |||
self.up_sampling_node = up_sampling_node | |||
self.ec_images_1 = GatedConvBNActiv( | |||
in_channels, 64, bn=False, sample='down-3') | |||
self.ec_images_2 = GatedConvBNActiv(64, 128, sample='down-3') | |||
self.ec_images_3 = GatedConvBNActiv(128, 256, sample='down-3') | |||
self.ec_images_4 = GatedConvBNActiv(256, 512, sample='down-3') | |||
self.ec_images_5 = GatedConvBNActiv(512, 512, sample='down-3') | |||
self.ec_images_6 = GatedConvBNActiv(512, 512, sample='down-3') | |||
self.dc_images_6 = GatedConvBNActiv(512 + 512, 512, activ='leaky') | |||
self.dc_images_5 = GatedConvBNActiv(512 + 512, 512, activ='leaky') | |||
self.dc_images_4 = GatedConvBNActiv(512 + 256, 256, activ='leaky') | |||
self.dc_images_3 = GatedConvBNActiv(256 + 128, 128, activ='leaky') | |||
self.dc_images_2 = GatedConvBNActiv(128 + 64, 64, activ='leaky') | |||
self.dc_images_1 = GatedConvBNActiv( | |||
64 + in_channels, | |||
out_channels, | |||
bn=False, | |||
sample='none-3', | |||
activ=None, | |||
bias=True) | |||
self.tanh = nn.Tanh() | |||
if init_weights: | |||
self.apply(weights_init()) | |||
def forward(self, input_images, input_masks): | |||
ec_images = {} | |||
ec_images['ec_images_0'] = torch.cat((input_images, input_masks), | |||
dim=1) | |||
ec_images['ec_images_1'] = self.ec_images_1(ec_images['ec_images_0']) | |||
ec_images['ec_images_2'] = self.ec_images_2(ec_images['ec_images_1']) | |||
ec_images['ec_images_3'] = self.ec_images_3(ec_images['ec_images_2']) | |||
ec_images['ec_images_4'] = self.ec_images_4(ec_images['ec_images_3']) | |||
ec_images['ec_images_5'] = self.ec_images_5(ec_images['ec_images_4']) | |||
ec_images['ec_images_6'] = self.ec_images_6(ec_images['ec_images_5']) | |||
# -------------- | |||
# images decoder | |||
# -------------- | |||
dc_images = ec_images['ec_images_6'] | |||
for _ in range(6, 0, -1): | |||
ec_images_skip = 'ec_images_{:d}'.format(_ - 1) | |||
dc_conv = 'dc_images_{:d}'.format(_) | |||
dc_images = F.interpolate( | |||
dc_images, scale_factor=2, mode=self.up_sampling_node) | |||
dc_images = torch.cat((dc_images, ec_images[ec_images_skip]), | |||
dim=1) | |||
dc_images = getattr(self, dc_conv)(dc_images) | |||
outputs = self.tanh(dc_images) | |||
return outputs | |||
def train(self, mode=True): | |||
super().train(mode) | |||
if self.freeze_ec_bn: | |||
for name, module in self.named_modules(): | |||
if isinstance(module, nn.BatchNorm2d): | |||
module.eval() |
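# Illustrative forward pass (assumption: a 3-channel RGB image plus a
# 1-channel mask, i.e. in_channels=4, as constructed by the pipeline in this
# change). Height and width should be divisible by 64; the tanh keeps the
# output in [-1, 1].
if __name__ == '__main__':
    net = RetouchingNet(in_channels=4, out_channels=3)
    image = torch.zeros(1, 3, 128, 128)
    mask = torch.zeros(1, 1, 128, 128)
    print(net(image, mask).shape)  # expected: torch.Size([1, 3, 128, 128])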
@@ -0,0 +1,271 @@ | |||
# Implementation in this file is modified from source code available via https://github.com/ternaus/retinaface
from typing import List, Tuple, Union | |||
import numpy as np | |||
import torch | |||
def point_form(boxes: torch.Tensor) -> torch.Tensor: | |||
"""Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form ground truth data. | |||
Args: | |||
boxes: center-size default boxes from priorbox layers. | |||
Return: | |||
boxes: Converted x_min, y_min, x_max, y_max form of boxes. | |||
""" | |||
return torch.cat( | |||
(boxes[:, :2] - boxes[:, 2:] / 2, boxes[:, :2] + boxes[:, 2:] / 2), | |||
dim=1) | |||
def center_size(boxes: torch.Tensor) -> torch.Tensor: | |||
"""Convert prior_boxes to (cx, cy, w, h) representation for comparison to center-size form ground truth data. | |||
Args: | |||
boxes: point_form boxes | |||
Return: | |||
        boxes: Converted (cx, cy, w, h) form of boxes.
""" | |||
return torch.cat( | |||
((boxes[:, 2:] + boxes[:, :2]) / 2, boxes[:, 2:] - boxes[:, :2]), | |||
dim=1) | |||
def intersect(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: | |||
""" We resize both tensors to [A,B,2] without new malloc: | |||
[A, 2] -> [A, 1, 2] -> [A, B, 2] | |||
[B, 2] -> [1, B, 2] -> [A, B, 2] | |||
Then we compute the area of intersect between box_a and box_b. | |||
Args: | |||
box_a: bounding boxes, Shape: [A, 4]. | |||
box_b: bounding boxes, Shape: [B, 4]. | |||
Return: | |||
intersection area, Shape: [A, B]. | |||
""" | |||
A = box_a.size(0) | |||
B = box_b.size(0) | |||
max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), | |||
box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) | |||
min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), | |||
box_b[:, :2].unsqueeze(0).expand(A, B, 2)) | |||
inter = torch.clamp((max_xy - min_xy), min=0) | |||
return inter[:, :, 0] * inter[:, :, 1] | |||
def jaccard(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: | |||
"""Compute the jaccard overlap of two sets of boxes. The jaccard overlap is simply the intersection over | |||
union of two boxes. Here we operate on ground truth boxes and default boxes. | |||
E.g.: | |||
A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) | |||
Args: | |||
box_a: Ground truth bounding boxes, Shape: [num_objects,4] | |||
box_b: Prior boxes from priorbox layers, Shape: [num_priors,4] | |||
Return: | |||
jaccard overlap: Shape: [box_a.size(0), box_b.size(0)] | |||
""" | |||
inter = intersect(box_a, box_b) | |||
area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]) | |||
area_a = area_a.unsqueeze(1).expand_as(inter) # [A,B] | |||
area_b = (box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1]) | |||
area_b = area_b.unsqueeze(0).expand_as(inter) # [A,B] | |||
union = area_a + area_b - inter | |||
return inter / union | |||
def matrix_iof(a: np.ndarray, b: np.ndarray) -> np.ndarray: | |||
""" | |||
    Return the intersection over foreground (IoF) of a and b; NumPy version used for data augmentation.
""" | |||
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) | |||
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) | |||
area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) | |||
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) | |||
return area_i / np.maximum(area_a[:, np.newaxis], 1) | |||
def match( | |||
threshold: float, | |||
box_gt: torch.Tensor, | |||
priors: torch.Tensor, | |||
variances: List[float], | |||
labels_gt: torch.Tensor, | |||
landmarks_gt: torch.Tensor, | |||
box_t: torch.Tensor, | |||
label_t: torch.Tensor, | |||
landmarks_t: torch.Tensor, | |||
batch_id: int, | |||
) -> None: | |||
"""Match each prior box with the ground truth box of the highest jaccard overlap, encode the bounding | |||
boxes, then return the matched indices corresponding to both confidence and location preds. | |||
Args: | |||
threshold: The overlap threshold used when matching boxes. | |||
box_gt: Ground truth boxes, Shape: [num_obj, 4]. | |||
priors: Prior boxes from priorbox layers, Shape: [n_priors, 4]. | |||
        variances: Variances corresponding to each prior coord (a list of two floats).
labels_gt: All the class labels for the image, Shape: [num_obj, 2]. | |||
        landmarks_gt: Ground truth landmarks, Shape [num_obj, 10].
        box_t: Tensor to be filled w/ encoded location targets.
label_t: Tensor to be filled w/ matched indices for labels predictions. | |||
        landmarks_t: Tensor to be filled w/ encoded landmarks targets.
batch_id: current batch index | |||
Return: | |||
The matched indices corresponding to 1)location 2)confidence 3)landmarks preds. | |||
""" | |||
# Compute iou between gt and priors | |||
overlaps = jaccard(box_gt, point_form(priors)) | |||
# (Bipartite Matching) | |||
# [1, num_objects] best prior for each ground truth | |||
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) | |||
# ignore hard gt | |||
valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 | |||
best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] | |||
if best_prior_idx_filter.shape[0] <= 0: | |||
box_t[batch_id] = 0 | |||
label_t[batch_id] = 0 | |||
return | |||
# [1, num_priors] best ground truth for each prior | |||
best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) | |||
best_truth_idx.squeeze_(0) | |||
best_truth_overlap.squeeze_(0) | |||
best_prior_idx.squeeze_(1) | |||
best_prior_idx_filter.squeeze_(1) | |||
best_prior_overlap.squeeze_(1) | |||
best_truth_overlap.index_fill_(0, best_prior_idx_filter, | |||
2) # ensure best prior | |||
# TODO refactor: index best_prior_idx with long tensor | |||
# ensure every gt matches with its prior of max overlap | |||
for j in range(best_prior_idx.size(0)): | |||
best_truth_idx[best_prior_idx[j]] = j | |||
matches = box_gt[best_truth_idx] # Shape: [num_priors, 4] | |||
labels = labels_gt[best_truth_idx] # Shape: [num_priors] | |||
# label as background | |||
labels[best_truth_overlap < threshold] = 0 | |||
loc = encode(matches, priors, variances) | |||
matches_landm = landmarks_gt[best_truth_idx] | |||
landmarks_gt = encode_landm(matches_landm, priors, variances) | |||
box_t[batch_id] = loc # [num_priors, 4] encoded offsets to learn | |||
label_t[batch_id] = labels # [num_priors] top class label for each prior | |||
landmarks_t[batch_id] = landmarks_gt | |||
def encode(matched, priors, variances): | |||
"""Encode the variances from the priorbox layers into the ground truth boxes | |||
we have matched (based on jaccard overlap) with the prior boxes. | |||
Args: | |||
matched: (tensor) Coords of ground truth for each prior in point-form | |||
Shape: [num_priors, 4]. | |||
priors: (tensor) Prior boxes in center-offset form | |||
Shape: [num_priors,4]. | |||
variances: (list[float]) Variances of priorboxes | |||
Return: | |||
encoded boxes (tensor), Shape: [num_priors, 4] | |||
""" | |||
# dist b/t match center and prior's center | |||
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] | |||
# encode variance | |||
g_cxcy /= variances[0] * priors[:, 2:] | |||
# match wh / prior wh | |||
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] | |||
g_wh = torch.log(g_wh) / variances[1] | |||
# return target for smooth_l1_loss | |||
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] | |||
def encode_landm( | |||
matched: torch.Tensor, priors: torch.Tensor, | |||
variances: Union[List[float], Tuple[float, float]]) -> torch.Tensor: | |||
"""Encode the variances from the priorbox layers into the ground truth boxes we have matched | |||
(based on jaccard overlap) with the prior boxes. | |||
Args: | |||
matched: Coords of ground truth for each prior in point-form | |||
Shape: [num_priors, 10]. | |||
priors: Prior boxes in center-offset form | |||
Shape: [num_priors,4]. | |||
variances: Variances of priorboxes | |||
Return: | |||
encoded landmarks, Shape: [num_priors, 10] | |||
""" | |||
# dist b/t match center and prior's center | |||
matched = torch.reshape(matched, (matched.size(0), 5, 2)) | |||
priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), | |||
5).unsqueeze(2) | |||
priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), | |||
5).unsqueeze(2) | |||
priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), | |||
5).unsqueeze(2) | |||
priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), | |||
5).unsqueeze(2) | |||
priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) | |||
g_cxcy = matched[:, :, :2] - priors[:, :, :2] | |||
# encode variance | |||
    g_cxcy /= variances[0] * priors[:, :, 2:]
# return target for smooth_l1_loss | |||
return g_cxcy.reshape(g_cxcy.size(0), -1) | |||
# Adapted from https://github.com/Hakuyume/chainer-ssd | |||
def decode(loc: torch.Tensor, priors: torch.Tensor, | |||
variances: Union[List[float], Tuple[float, float]]) -> torch.Tensor: | |||
"""Decode locations from predictions using priors to undo the encoding we did for offset regression at train time. | |||
Args: | |||
loc: location predictions for loc layers, | |||
Shape: [num_priors, 4] | |||
priors: Prior boxes in center-offset form. | |||
Shape: [num_priors, 4]. | |||
variances: Variances of priorboxes | |||
Return: | |||
decoded bounding box predictions | |||
""" | |||
boxes = torch.cat( | |||
( | |||
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], | |||
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]), | |||
), | |||
1, | |||
) | |||
boxes[:, :2] -= boxes[:, 2:] / 2 | |||
boxes[:, 2:] += boxes[:, :2] | |||
return boxes | |||
def decode_landm( | |||
pre: torch.Tensor, priors: torch.Tensor, | |||
variances: Union[List[float], Tuple[float, float]]) -> torch.Tensor: | |||
"""Decode landmarks from predictions using priors to undo the encoding we did for offset regression at train time. | |||
Args: | |||
pre: landmark predictions for loc layers, | |||
Shape: [num_priors, 10] | |||
priors: Prior boxes in center-offset form. | |||
Shape: [num_priors, 4]. | |||
variances: Variances of priorboxes | |||
Return: | |||
decoded landmark predictions | |||
""" | |||
return torch.cat( | |||
( | |||
priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], | |||
priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], | |||
priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], | |||
priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], | |||
priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], | |||
), | |||
dim=1, | |||
) | |||
def log_sum_exp(x: torch.Tensor) -> torch.Tensor: | |||
"""Utility function for computing log_sum_exp while determining This will be used to determine unaveraged | |||
confidence loss across all examples in a batch. | |||
Args: | |||
x: conf_preds from conf layers | |||
""" | |||
x_max = x.data.max() | |||
return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max |
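# Self-check sketch (illustrative values, not part of the original change):
# decode() undoes encode(), so encoding a ground-truth box against a prior and
# decoding it again should recover the original point-form box.
if __name__ == '__main__':
    priors_example = torch.tensor([[0.5, 0.5, 0.2, 0.2]])  # (cx, cy, w, h)
    gt_example = torch.tensor([[0.45, 0.40, 0.60, 0.65]])  # (x_min, y_min, x_max, y_max)
    variances_example = [0.1, 0.2]
    loc_example = encode(gt_example, priors_example, variances_example)
    print(decode(loc_example, priors_example, variances_example))  # ~= gt_example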
@@ -0,0 +1,124 @@ | |||
# Implementation in this file is modified from source code available via https://github.com/ternaus/retinaface
from typing import Dict, List | |||
import torch | |||
import torch.nn.functional as F | |||
from torch import nn | |||
def conv_bn(inp: int, | |||
oup: int, | |||
stride: int = 1, | |||
leaky: float = 0) -> nn.Sequential: | |||
return nn.Sequential( | |||
nn.Conv2d(inp, oup, 3, stride, 1, bias=False), | |||
nn.BatchNorm2d(oup), | |||
nn.LeakyReLU(negative_slope=leaky, inplace=True), | |||
) | |||
def conv_bn_no_relu(inp: int, oup: int, stride: int) -> nn.Sequential: | |||
return nn.Sequential( | |||
nn.Conv2d(inp, oup, 3, stride, 1, bias=False), | |||
nn.BatchNorm2d(oup), | |||
) | |||
def conv_bn1X1(inp: int, | |||
oup: int, | |||
stride: int, | |||
leaky: float = 0) -> nn.Sequential: | |||
return nn.Sequential( | |||
nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), | |||
nn.BatchNorm2d(oup), | |||
nn.LeakyReLU(negative_slope=leaky, inplace=True), | |||
) | |||
def conv_dw(inp: int, | |||
oup: int, | |||
stride: int, | |||
leaky: float = 0.1) -> nn.Sequential: | |||
return nn.Sequential( | |||
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), | |||
nn.BatchNorm2d(inp), | |||
nn.LeakyReLU(negative_slope=leaky, inplace=True), | |||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False), | |||
nn.BatchNorm2d(oup), | |||
nn.LeakyReLU(negative_slope=leaky, inplace=True), | |||
) | |||
class SSH(nn.Module): | |||
def __init__(self, in_channel: int, out_channel: int) -> None: | |||
super().__init__() | |||
        if out_channel % 4 != 0:
            raise ValueError(
                f'Expected out_channel % 4 == 0, but got {out_channel % 4}')
leaky: float = 0 | |||
if out_channel <= 64: | |||
leaky = 0.1 | |||
self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) | |||
self.conv5X5_1 = conv_bn( | |||
in_channel, out_channel // 4, stride=1, leaky=leaky) | |||
self.conv5X5_2 = conv_bn_no_relu( | |||
out_channel // 4, out_channel // 4, stride=1) | |||
self.conv7X7_2 = conv_bn( | |||
out_channel // 4, out_channel // 4, stride=1, leaky=leaky) | |||
self.conv7x7_3 = conv_bn_no_relu( | |||
out_channel // 4, out_channel // 4, stride=1) | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
conv3X3 = self.conv3X3(x) | |||
conv5X5_1 = self.conv5X5_1(x) | |||
conv5X5 = self.conv5X5_2(conv5X5_1) | |||
conv7X7_2 = self.conv7X7_2(conv5X5_1) | |||
conv7X7 = self.conv7x7_3(conv7X7_2) | |||
out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) | |||
return F.relu(out) | |||
class FPN(nn.Module): | |||
def __init__(self, in_channels_list: List[int], out_channels: int) -> None: | |||
super().__init__() | |||
leaky = 0.0 | |||
if out_channels <= 64: | |||
leaky = 0.1 | |||
self.output1 = conv_bn1X1( | |||
in_channels_list[0], out_channels, stride=1, leaky=leaky) | |||
self.output2 = conv_bn1X1( | |||
in_channels_list[1], out_channels, stride=1, leaky=leaky) | |||
self.output3 = conv_bn1X1( | |||
in_channels_list[2], out_channels, stride=1, leaky=leaky) | |||
self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) | |||
self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) | |||
def forward(self, x: Dict[str, torch.Tensor]) -> List[torch.Tensor]: | |||
y = list(x.values()) | |||
output1 = self.output1(y[0]) | |||
output2 = self.output2(y[1]) | |||
output3 = self.output3(y[2]) | |||
up3 = F.interpolate( | |||
output3, size=[output2.size(2), output2.size(3)], mode='nearest') | |||
output2 = output2 + up3 | |||
output2 = self.merge2(output2) | |||
up2 = F.interpolate( | |||
output2, size=[output1.size(2), output1.size(3)], mode='nearest') | |||
output1 = output1 + up2 | |||
output1 = self.merge1(output1) | |||
return [output1, output2, output3] |
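# Shape sketch (illustrative channel counts, not the real backbone sizes): FPN
# takes an ordered dict of three feature maps, projects each to out_channels
# with a 1x1 conv + BN, and merges top-down via nearest upsampling; spatial
# sizes are preserved per level.
if __name__ == '__main__':
    fpn = FPN(in_channels_list=[8, 16, 32], out_channels=16)
    feats = {
        'c3': torch.zeros(1, 8, 32, 32),
        'c4': torch.zeros(1, 16, 16, 16),
        'c5': torch.zeros(1, 32, 8, 8),
    }
    for level in fpn(feats):
        print(level.shape)  # 16 channels at 32x32, 16x16 and 8x8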
@@ -0,0 +1,146 @@ | |||
# Implementation in this file is modified from source code available via https://github.com/ternaus/retinaface
from typing import Dict, Tuple | |||
import torch | |||
from torch import nn | |||
from torchvision import models | |||
from torchvision.models import _utils | |||
from .net import FPN, SSH | |||
class ClassHead(nn.Module): | |||
def __init__(self, in_channels: int = 512, num_anchors: int = 3) -> None: | |||
super().__init__() | |||
self.conv1x1 = nn.Conv2d( | |||
in_channels, | |||
num_anchors * 2, | |||
kernel_size=(1, 1), | |||
stride=1, | |||
padding=0) | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
out = self.conv1x1(x) | |||
out = out.permute(0, 2, 3, 1).contiguous() | |||
return out.view(out.shape[0], -1, 2) | |||
class BboxHead(nn.Module): | |||
def __init__(self, in_channels: int = 512, num_anchors: int = 3): | |||
super().__init__() | |||
self.conv1x1 = nn.Conv2d( | |||
in_channels, | |||
num_anchors * 4, | |||
kernel_size=(1, 1), | |||
stride=1, | |||
padding=0) | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
out = self.conv1x1(x) | |||
out = out.permute(0, 2, 3, 1).contiguous() | |||
return out.view(out.shape[0], -1, 4) | |||
class LandmarkHead(nn.Module): | |||
def __init__(self, in_channels: int = 512, num_anchors: int = 3): | |||
super().__init__() | |||
self.conv1x1 = nn.Conv2d( | |||
in_channels, | |||
num_anchors * 10, | |||
kernel_size=(1, 1), | |||
stride=1, | |||
padding=0) | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
out = self.conv1x1(x) | |||
out = out.permute(0, 2, 3, 1).contiguous() | |||
return out.view(out.shape[0], -1, 10) | |||
class RetinaFace(nn.Module): | |||
def __init__(self, name: str, pretrained: bool, in_channels: int, | |||
return_layers: Dict[str, int], out_channels: int) -> None: | |||
super().__init__() | |||
if name == 'Resnet50': | |||
backbone = models.resnet50(pretrained=pretrained) | |||
else: | |||
raise NotImplementedError( | |||
f'Only Resnet50 backbone is supported but got {name}') | |||
self.body = _utils.IntermediateLayerGetter(backbone, return_layers) | |||
in_channels_stage2 = in_channels | |||
in_channels_list = [ | |||
in_channels_stage2 * 2, | |||
in_channels_stage2 * 4, | |||
in_channels_stage2 * 8, | |||
] | |||
self.fpn = FPN(in_channels_list, out_channels) | |||
self.ssh1 = SSH(out_channels, out_channels) | |||
self.ssh2 = SSH(out_channels, out_channels) | |||
self.ssh3 = SSH(out_channels, out_channels) | |||
self.ClassHead = self._make_class_head( | |||
fpn_num=3, in_channels=out_channels) | |||
self.BboxHead = self._make_bbox_head( | |||
fpn_num=3, in_channels=out_channels) | |||
self.LandmarkHead = self._make_landmark_head( | |||
fpn_num=3, in_channels=out_channels) | |||
@staticmethod | |||
def _make_class_head(fpn_num: int = 3, | |||
in_channels: int = 64, | |||
anchor_num: int = 2) -> nn.ModuleList: | |||
classhead = nn.ModuleList() | |||
for _ in range(fpn_num): | |||
classhead.append(ClassHead(in_channels, anchor_num)) | |||
return classhead | |||
@staticmethod | |||
def _make_bbox_head(fpn_num: int = 3, | |||
in_channels: int = 64, | |||
anchor_num: int = 2) -> nn.ModuleList: | |||
bboxhead = nn.ModuleList() | |||
for _ in range(fpn_num): | |||
bboxhead.append(BboxHead(in_channels, anchor_num)) | |||
return bboxhead | |||
@staticmethod | |||
def _make_landmark_head(fpn_num: int = 3, | |||
in_channels: int = 64, | |||
anchor_num: int = 2) -> nn.ModuleList: | |||
landmarkhead = nn.ModuleList() | |||
for _ in range(fpn_num): | |||
landmarkhead.append(LandmarkHead(in_channels, anchor_num)) | |||
return landmarkhead | |||
def forward( | |||
self, inputs: torch.Tensor | |||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |||
out = self.body(inputs) | |||
# FPN | |||
fpn = self.fpn(out) | |||
# SSH | |||
feature1 = self.ssh1(fpn[0]) | |||
feature2 = self.ssh2(fpn[1]) | |||
feature3 = self.ssh3(fpn[2]) | |||
features = [feature1, feature2, feature3] | |||
bbox_regressions = torch.cat( | |||
[self.BboxHead[i](feature) for i, feature in enumerate(features)], | |||
dim=1) | |||
classifications = torch.cat( | |||
[self.ClassHead[i](feature) for i, feature in enumerate(features)], | |||
dim=1) | |||
ldm_regressions = [ | |||
self.LandmarkHead[i](feature) for i, feature in enumerate(features) | |||
] | |||
ldm_regressions = torch.cat(ldm_regressions, dim=1) | |||
return bbox_regressions, classifications, ldm_regressions |
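# Forward-pass sketch (illustrative; pretrained=False avoids any download).
# The constructor arguments mirror those used by predict_single.Model. For a
# 256x256 input the three FPN levels have strides 8/16/32, giving
# (32*32 + 16*16 + 8*8) * 2 anchors = 2688 predictions per image.
if __name__ == '__main__':
    net = RetinaFace(
        name='Resnet50',
        pretrained=False,
        in_channels=256,
        return_layers={'layer2': 1, 'layer3': 2, 'layer4': 3},
        out_channels=256)
    net.eval()
    with torch.no_grad():
        boxes, classes, landmarks = net(torch.zeros(1, 3, 256, 256))
    print(boxes.shape, classes.shape, landmarks.shape)
    # expected: [1, 2688, 4], [1, 2688, 2], [1, 2688, 10]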
@@ -0,0 +1,152 @@ | |||
# Implementation in this file is modified from source code available via https://github.com/ternaus/retinaface
""" | |||
There is a lot of post processing of the predictions. | |||
""" | |||
from typing import Dict, List, Union | |||
import albumentations as A | |||
import numpy as np | |||
import torch | |||
from torch.nn import functional as F | |||
from torchvision.ops import nms | |||
from ..utils import pad_to_size, unpad_from_size | |||
from .box_utils import decode, decode_landm | |||
from .network import RetinaFace | |||
from .prior_box import priorbox | |||
from .utils import tensor_from_rgb_image | |||
class Model: | |||
def __init__(self, max_size: int = 960, device: str = 'cpu') -> None: | |||
self.model = RetinaFace( | |||
name='Resnet50', | |||
pretrained=False, | |||
return_layers={ | |||
'layer2': 1, | |||
'layer3': 2, | |||
'layer4': 3 | |||
}, | |||
in_channels=256, | |||
out_channels=256, | |||
).to(device) | |||
self.device = device | |||
self.transform = A.Compose( | |||
[A.LongestMaxSize(max_size=max_size, p=1), | |||
A.Normalize(p=1)]) | |||
self.max_size = max_size | |||
self.prior_box = priorbox( | |||
min_sizes=[[16, 32], [64, 128], [256, 512]], | |||
steps=[8, 16, 32], | |||
clip=False, | |||
image_size=(self.max_size, self.max_size), | |||
).to(device) | |||
self.variance = [0.1, 0.2] | |||
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: | |||
self.model.load_state_dict(state_dict) | |||
def eval(self): | |||
self.model.eval() | |||
def predict_jsons( | |||
self, | |||
            image: np.ndarray,
confidence_threshold: float = 0.7, | |||
nms_threshold: float = 0.4) -> List[Dict[str, Union[List, float]]]: | |||
with torch.no_grad(): | |||
original_height, original_width = image.shape[:2] | |||
scale_landmarks = torch.from_numpy( | |||
np.tile([self.max_size, self.max_size], | |||
5)).to(self.device).float() | |||
scale_bboxes = torch.from_numpy( | |||
np.tile([self.max_size, self.max_size], | |||
2)).to(self.device).float() | |||
transformed_image = self.transform(image=image)['image'] | |||
paded = pad_to_size( | |||
target_size=(self.max_size, self.max_size), | |||
image=transformed_image) | |||
pads = paded['pads'] | |||
torched_image = tensor_from_rgb_image(paded['image']).to( | |||
self.device) | |||
loc, conf, land = self.model(torched_image.unsqueeze(0)) | |||
conf = F.softmax(conf, dim=-1) | |||
annotations: List[Dict[str, Union[List, float]]] = [] | |||
boxes = decode(loc.data[0], self.prior_box, self.variance) | |||
boxes *= scale_bboxes | |||
scores = conf[0][:, 1] | |||
landmarks = decode_landm(land.data[0], self.prior_box, | |||
self.variance) | |||
landmarks *= scale_landmarks | |||
# ignore low scores | |||
valid_index = scores > confidence_threshold | |||
boxes = boxes[valid_index] | |||
landmarks = landmarks[valid_index] | |||
scores = scores[valid_index] | |||
# Sort from high to low | |||
order = scores.argsort(descending=True) | |||
boxes = boxes[order] | |||
landmarks = landmarks[order] | |||
scores = scores[order] | |||
# do NMS | |||
keep = nms(boxes, scores, nms_threshold) | |||
boxes = boxes[keep, :].int() | |||
if boxes.shape[0] == 0: | |||
return [{'bbox': [], 'score': -1, 'landmarks': []}] | |||
landmarks = landmarks[keep] | |||
scores = scores[keep].cpu().numpy().astype(np.float64) | |||
boxes = boxes.cpu().numpy() | |||
landmarks = landmarks.cpu().numpy() | |||
landmarks = landmarks.reshape([-1, 2]) | |||
unpadded = unpad_from_size(pads, bboxes=boxes, keypoints=landmarks) | |||
resize_coeff = max(original_height, original_width) / self.max_size | |||
boxes = (unpadded['bboxes'] * resize_coeff).astype(int) | |||
landmarks = (unpadded['keypoints'].reshape(-1, 10) | |||
* resize_coeff).astype(int) | |||
for box_id, bbox in enumerate(boxes): | |||
x_min, y_min, x_max, y_max = bbox | |||
x_min = np.clip(x_min, 0, original_width - 1) | |||
x_max = np.clip(x_max, x_min + 1, original_width - 1) | |||
if x_min >= x_max: | |||
continue | |||
y_min = np.clip(y_min, 0, original_height - 1) | |||
y_max = np.clip(y_max, y_min + 1, original_height - 1) | |||
if y_min >= y_max: | |||
continue | |||
annotations += [{ | |||
'bbox': | |||
bbox.tolist(), | |||
'score': | |||
scores[box_id], | |||
'landmarks': | |||
landmarks[box_id].reshape(-1, 2).tolist(), | |||
}] | |||
return annotations |
@@ -0,0 +1,28 @@ | |||
# Implementation in this file is modified from source code available via https://github.com/ternaus/retinaface
from itertools import product | |||
from math import ceil | |||
import torch | |||
def priorbox(min_sizes, steps, clip, image_size): | |||
feature_maps = [[ceil(image_size[0] / step), | |||
ceil(image_size[1] / step)] for step in steps] | |||
anchors = [] | |||
for k, f in enumerate(feature_maps): | |||
t_min_sizes = min_sizes[k] | |||
for i, j in product(range(f[0]), range(f[1])): | |||
for min_size in t_min_sizes: | |||
s_kx = min_size / image_size[1] | |||
s_ky = min_size / image_size[0] | |||
dense_cx = [x * steps[k] / image_size[1] for x in [j + 0.5]] | |||
dense_cy = [y * steps[k] / image_size[0] for y in [i + 0.5]] | |||
for cy, cx in product(dense_cy, dense_cx): | |||
anchors += [cx, cy, s_kx, s_ky] | |||
# back to torch land | |||
output = torch.Tensor(anchors).view(-1, 4) | |||
if clip: | |||
output.clamp_(max=1, min=0) | |||
return output |
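# Worked example (illustrative only): with the settings used by
# predict_single.Model and a 64x64 image, the feature maps have 8x8, 4x4 and
# 2x2 cells with two anchor sizes per cell, i.e. (64 + 16 + 4) * 2 = 168
# anchors in (cx, cy, w, h) form expressed as fractions of the image size.
if __name__ == '__main__':
    anchors = priorbox(
        min_sizes=[[16, 32], [64, 128], [256, 512]],
        steps=[8, 16, 32],
        clip=False,
        image_size=(64, 64))
    print(anchors.shape)  # expected: torch.Size([168, 4])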
@@ -0,0 +1,70 @@ | |||
# Implementation in this file is modified from source code available via https://github.com/ternaus/retinaface
import re | |||
from pathlib import Path | |||
from typing import Any, Dict, List, Optional, Union | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
def load_checkpoint(file_path: Union[Path, str], | |||
rename_in_layers: Optional[dict] = None) -> Dict[str, Any]: | |||
"""Loads PyTorch checkpoint, optionally renaming layer names. | |||
Args: | |||
file_path: path to the torch checkpoint. | |||
rename_in_layers: {from_name: to_name} | |||
ex: {"model.0.": "", | |||
"model.": ""} | |||
    Returns: the loaded checkpoint; if rename_in_layers is given, the keys in checkpoint['state_dict'] are renamed accordingly.
""" | |||
checkpoint = torch.load( | |||
file_path, map_location=lambda storage, loc: storage) | |||
if rename_in_layers is not None: | |||
model_state_dict = checkpoint['state_dict'] | |||
result = {} | |||
for key, value in model_state_dict.items(): | |||
for key_r, value_r in rename_in_layers.items(): | |||
key = re.sub(key_r, value_r, key) | |||
result[key] = value | |||
checkpoint['state_dict'] = result | |||
return checkpoint | |||
def tensor_from_rgb_image(image: np.ndarray) -> torch.Tensor: | |||
image = np.transpose(image, (2, 0, 1)) | |||
return torch.from_numpy(image) | |||
def vis_annotations(image: np.ndarray, | |||
annotations: List[Dict[str, Any]]) -> np.ndarray: | |||
vis_image = image.copy() | |||
for annotation in annotations: | |||
landmarks = annotation['landmarks'] | |||
colors = [(255, 0, 0), (128, 255, 0), (255, 178, 102), (102, 128, 255), | |||
(0, 255, 255)] | |||
for landmark_id, (x, y) in enumerate(landmarks): | |||
vis_image = cv2.circle( | |||
vis_image, (x, y), | |||
radius=3, | |||
color=colors[landmark_id], | |||
thickness=3) | |||
x_min, y_min, x_max, y_max = annotation['bbox'] | |||
x_min = np.clip(x_min, 0, x_max - 1) | |||
y_min = np.clip(y_min, 0, y_max - 1) | |||
vis_image = cv2.rectangle( | |||
vis_image, (x_min, y_min), (x_max, y_max), | |||
color=(0, 255, 0), | |||
thickness=2) | |||
return vis_image |
@@ -0,0 +1,143 @@ | |||
import warnings | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from .weights_init import weights_init | |||
warnings.filterwarnings(action='ignore') | |||
class double_conv(nn.Module): | |||
'''(conv => BN => ReLU) * 2''' | |||
def __init__(self, in_ch, out_ch): | |||
super(double_conv, self).__init__() | |||
self.conv = nn.Sequential( | |||
nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), | |||
nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), | |||
nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)) | |||
def forward(self, x): | |||
x = self.conv(x) | |||
return x | |||
class inconv(nn.Module): | |||
def __init__(self, in_ch, out_ch): | |||
super(inconv, self).__init__() | |||
self.conv = double_conv(in_ch, out_ch) | |||
def forward(self, x): | |||
x = self.conv(x) | |||
return x | |||
class down(nn.Module): | |||
def __init__(self, in_ch, out_ch): | |||
super(down, self).__init__() | |||
self.mpconv = nn.Sequential( | |||
nn.MaxPool2d(2), double_conv(in_ch, out_ch)) | |||
def forward(self, x): | |||
x = self.mpconv(x) | |||
return x | |||
class up(nn.Module): | |||
def __init__(self, in_ch, out_ch, bilinear=True): | |||
super(up, self).__init__() | |||
if bilinear: | |||
self.up = nn.Upsample( | |||
scale_factor=2, mode='bilinear', align_corners=True) | |||
else: | |||
self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) | |||
self.conv = double_conv(in_ch, out_ch) | |||
def forward(self, x1, x2): | |||
x1 = self.up(x1) | |||
diffY = x2.size()[2] - x1.size()[2] | |||
diffX = x2.size()[3] - x1.size()[3] | |||
x1 = F.pad( | |||
x1, | |||
(diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) | |||
x = torch.cat([x2, x1], dim=1) | |||
x = self.conv(x) | |||
return x | |||
class outconv(nn.Module): | |||
def __init__(self, in_ch, out_ch): | |||
super(outconv, self).__init__() | |||
self.conv = nn.Conv2d(in_ch, out_ch, 1) | |||
def forward(self, x): | |||
x = self.conv(x) | |||
return x | |||
class UNet(nn.Module): | |||
def __init__(self, | |||
n_channels, | |||
n_classes, | |||
deep_supervision=False, | |||
init_weights=True): | |||
super(UNet, self).__init__() | |||
self.deep_supervision = deep_supervision | |||
self.inc = inconv(n_channels, 64) | |||
self.down1 = down(64, 128) | |||
self.down2 = down(128, 256) | |||
self.down3 = down(256, 512) | |||
self.down4 = down(512, 512) | |||
self.up1 = up(1024, 256) | |||
self.up2 = up(512, 128) | |||
self.up3 = up(256, 64) | |||
self.up4 = up(128, 64) | |||
self.outc = outconv(64, n_classes) | |||
self.dsoutc4 = outconv(256, n_classes) | |||
self.dsoutc3 = outconv(128, n_classes) | |||
self.dsoutc2 = outconv(64, n_classes) | |||
self.dsoutc1 = outconv(64, n_classes) | |||
self.sigmoid = nn.Sigmoid() | |||
if init_weights: | |||
self.apply(weights_init()) | |||
def forward(self, x): | |||
x1 = self.inc(x) | |||
x2 = self.down1(x1) | |||
x3 = self.down2(x2) | |||
x4 = self.down3(x3) | |||
x5 = self.down4(x4) | |||
x44 = self.up1(x5, x4) | |||
x33 = self.up2(x44, x3) | |||
x22 = self.up3(x33, x2) | |||
x11 = self.up4(x22, x1) | |||
x0 = self.outc(x11) | |||
x0 = self.sigmoid(x0) | |||
if self.deep_supervision: | |||
x11 = F.interpolate( | |||
self.dsoutc1(x11), x0.shape[2:], mode='bilinear') | |||
x22 = F.interpolate( | |||
self.dsoutc2(x22), x0.shape[2:], mode='bilinear') | |||
x33 = F.interpolate( | |||
self.dsoutc3(x33), x0.shape[2:], mode='bilinear') | |||
x44 = F.interpolate( | |||
self.dsoutc4(x44), x0.shape[2:], mode='bilinear') | |||
return x0, x11, x22, x33, x44 | |||
else: | |||
return x0 |
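# Forward-pass sketch (illustrative only): four pooling stages mean the input
# works most cleanly when height and width are divisible by 16 (odd sizes are
# tolerated thanks to the padding in up.forward); the sigmoid keeps the output
# in (0, 1).
if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=3)
    dummy = torch.zeros(1, 3, 64, 64)
    print(net(dummy).shape)  # expected: torch.Size([1, 3, 64, 64])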
@@ -0,0 +1,327 @@ | |||
import time | |||
from typing import Dict, List, Optional, Tuple, Union | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
import torch.nn.functional as F | |||
from einops import rearrange | |||
__all__ = [
    'gen_diffuse_mask', 'get_crop_bbox', 'get_roi_without_padding',
    'pad_to_size', 'patch_aggregation_overlap', 'patch_partition_overlap',
    'preprocess_roi', 'resize_on_long_side', 'roi_to_tensor',
    'smooth_border_mg', 'unpad_from_size', 'whiten_img'
]
def resize_on_long_side(img, long_side=800): | |||
src_height = img.shape[0] | |||
src_width = img.shape[1] | |||
if src_height > src_width: | |||
scale = long_side * 1.0 / src_height | |||
_img = cv2.resize( | |||
img, (int(src_width * scale), long_side), | |||
interpolation=cv2.INTER_LINEAR) | |||
else: | |||
scale = long_side * 1.0 / src_width | |||
_img = cv2.resize( | |||
img, (long_side, int(src_height * scale)), | |||
interpolation=cv2.INTER_LINEAR) | |||
return _img, scale | |||
def get_crop_bbox(detecting_results): | |||
boxes = [] | |||
for anno in detecting_results: | |||
if anno['score'] == -1: | |||
break | |||
boxes.append({ | |||
'x1': anno['bbox'][0], | |||
'y1': anno['bbox'][1], | |||
'x2': anno['bbox'][2], | |||
'y2': anno['bbox'][3] | |||
}) | |||
face_count = len(boxes) | |||
suitable_bboxes = [] | |||
for i in range(face_count): | |||
face_bbox = boxes[i] | |||
face_bbox_width = abs(face_bbox['x2'] - face_bbox['x1']) | |||
face_bbox_height = abs(face_bbox['y2'] - face_bbox['y1']) | |||
face_bbox_center = ((face_bbox['x1'] + face_bbox['x2']) / 2, | |||
(face_bbox['y1'] + face_bbox['y2']) / 2) | |||
square_bbox_length = face_bbox_height if face_bbox_height > face_bbox_width else face_bbox_width | |||
enlarge_ratio = 1.5 | |||
square_bbox_length = int(enlarge_ratio * square_bbox_length) | |||
sideScale = 1 | |||
square_bbox = { | |||
'x1': | |||
int(face_bbox_center[0] - sideScale * square_bbox_length / 2), | |||
'x2': | |||
int(face_bbox_center[0] + sideScale * square_bbox_length / 2), | |||
'y1': | |||
int(face_bbox_center[1] - sideScale * square_bbox_length / 2), | |||
'y2': int(face_bbox_center[1] + sideScale * square_bbox_length / 2) | |||
} | |||
suitable_bboxes.append(square_bbox) | |||
return suitable_bboxes | |||
def get_roi_without_padding(img, bbox): | |||
crop_t = max(bbox['y1'], 0) | |||
crop_b = min(bbox['y2'], img.shape[0]) | |||
crop_l = max(bbox['x1'], 0) | |||
crop_r = min(bbox['x2'], img.shape[1]) | |||
roi = img[crop_t:crop_b, crop_l:crop_r] | |||
return roi, 0, [crop_t, crop_b, crop_l, crop_r] | |||
def roi_to_tensor(img): | |||
img = torch.from_numpy(img.transpose((2, 0, 1)))[None, ...] | |||
return img | |||
def preprocess_roi(img): | |||
img = img.float() / 255.0 | |||
img = (img - 0.5) * 2 | |||
return img | |||
def patch_partition_overlap(image, p1, p2, padding=32): | |||
B, C, H, W = image.size() | |||
h, w = H // p1, W // p2 | |||
image = F.pad( | |||
image, | |||
pad=(padding, padding, padding, padding, 0, 0), | |||
mode='constant', | |||
value=0) | |||
patch_list = [] | |||
for i in range(h): | |||
for j in range(w): | |||
patch = image[:, :, p1 * i:p1 * (i + 1) + padding * 2, | |||
p2 * j:p2 * (j + 1) + padding * 2] | |||
patch_list.append(patch) | |||
output = torch.cat( | |||
patch_list, dim=0) # (b h w) c (p1 + 2 * padding) (p2 + 2 * padding) | |||
return output | |||
def patch_aggregation_overlap(image, h, w, padding=32): | |||
image = image[:, :, padding:-padding, padding:-padding] | |||
output = rearrange(image, '(b h w) c p1 p2 -> b c (h p1) (w p2)', h=h, w=w) | |||
return output | |||
def smooth_border_mg(diffuse_mask, mg): | |||
mg = mg - 0.5 | |||
diffuse_mask = F.interpolate( | |||
diffuse_mask, mg.shape[:2], mode='bilinear')[0].permute(1, 2, 0) | |||
mg = mg * diffuse_mask | |||
mg = mg + 0.5 | |||
return mg | |||
def whiten_img(image, skin_mask, whitening_degree, flag_bigKernal=False): | |||
""" | |||
image: rgb | |||
""" | |||
dilate_kernalsize = 30 | |||
if flag_bigKernal: | |||
dilate_kernalsize = 80 | |||
new_kernel1 = cv2.getStructuringElement( | |||
cv2.MORPH_ELLIPSE, (dilate_kernalsize, dilate_kernalsize)) | |||
new_kernel2 = cv2.getStructuringElement( | |||
cv2.MORPH_ELLIPSE, (dilate_kernalsize, dilate_kernalsize)) | |||
if len(skin_mask.shape) == 3: | |||
skin_mask = skin_mask[:, :, -1] | |||
skin_mask = cv2.dilate(skin_mask, new_kernel1, 1) | |||
skin_mask = cv2.erode(skin_mask, new_kernel2, 1) | |||
skin_mask = cv2.blur(skin_mask, (20, 20)) / 255.0 | |||
skin_mask = skin_mask.squeeze() | |||
skin_mask = torch.from_numpy(skin_mask).to(image.device) | |||
skin_mask = torch.stack([skin_mask, skin_mask, skin_mask], dim=0)[None, | |||
...] | |||
skin_mask[:, 1:, :, :] *= 0.75 | |||
whiten_mg = skin_mask * 0.2 * whitening_degree + 0.5 | |||
assert len(whiten_mg.shape) == 4 | |||
whiten_mg = F.interpolate( | |||
whiten_mg, image.shape[:2], mode='bilinear')[0].permute(1, 2, | |||
0).half() | |||
output_pred = image.half() | |||
output_pred = output_pred / 255.0 | |||
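    # Quadratic tone curve: out = (1 - 2*mg) * x**2 + 2*mg * x. With
    # whiten_mg == 0.5 this is the identity mapping; larger values brighten
    # the skin region while keeping values in [0, 1].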
output_pred = ( | |||
-2 * whiten_mg + 1 | |||
) * output_pred * output_pred + 2 * whiten_mg * output_pred # value: 0~1 | |||
output_pred = output_pred * 255.0 | |||
output_pred = output_pred.byte() | |||
output_pred = output_pred.cpu().numpy() | |||
return output_pred | |||
def gen_diffuse_mask(out_channels=3): | |||
mask_size = 500 | |||
diffuse_with = 20 | |||
a = np.ones(shape=(mask_size, mask_size), dtype=np.float32) | |||
for i in range(mask_size): | |||
for j in range(mask_size): | |||
if i >= diffuse_with and i <= ( | |||
mask_size - diffuse_with) and j >= diffuse_with and j <= ( | |||
mask_size - diffuse_with): | |||
a[i, j] = 1.0 | |||
elif i <= diffuse_with: | |||
a[i, j] = i * 1.0 / diffuse_with | |||
elif i > (mask_size - diffuse_with): | |||
a[i, j] = (mask_size - i) * 1.0 / diffuse_with | |||
for i in range(mask_size): | |||
for j in range(mask_size): | |||
if j <= diffuse_with: | |||
a[i, j] = min(a[i, j], j * 1.0 / diffuse_with) | |||
elif j > (mask_size - diffuse_with): | |||
a[i, j] = min(a[i, j], (mask_size - j) * 1.0 / diffuse_with) | |||
a = np.dstack([a] * out_channels) | |||
return a | |||
def pad_to_size( | |||
target_size: Tuple[int, int], | |||
        image: np.ndarray,
bboxes: Optional[np.ndarray] = None, | |||
keypoints: Optional[np.ndarray] = None, | |||
) -> Dict[str, Union[np.ndarray, Tuple[int, int, int, int]]]: | |||
"""Pads the image on the sides to the target_size | |||
Args: | |||
target_size: (target_height, target_width) | |||
image: | |||
bboxes: np.array with shape (num_boxes, 4). Each row: [x_min, y_min, x_max, y_max] | |||
keypoints: np.array with shape (num_keypoints, 2), each row: [x, y] | |||
Returns: | |||
{ | |||
"image": padded_image, | |||
"pads": (x_min_pad, y_min_pad, x_max_pad, y_max_pad), | |||
"bboxes": shifted_boxes, | |||
"keypoints": shifted_keypoints | |||
} | |||
""" | |||
target_height, target_width = target_size | |||
image_height, image_width = image.shape[:2] | |||
    if target_width < image_width:
        raise ValueError(f'Target width should be bigger than image_width. '
                         f'We got {target_width} {image_width}')
    if target_height < image_height:
        raise ValueError(f'Target height should be bigger than image_height. '
                         f'We got {target_height} {image_height}')
if image_height == target_height: | |||
y_min_pad = 0 | |||
y_max_pad = 0 | |||
else: | |||
y_pad = target_height - image_height | |||
y_min_pad = y_pad // 2 | |||
y_max_pad = y_pad - y_min_pad | |||
if image_width == target_width: | |||
x_min_pad = 0 | |||
x_max_pad = 0 | |||
else: | |||
x_pad = target_width - image_width | |||
x_min_pad = x_pad // 2 | |||
x_max_pad = x_pad - x_min_pad | |||
result = { | |||
'pads': (x_min_pad, y_min_pad, x_max_pad, y_max_pad), | |||
'image': | |||
cv2.copyMakeBorder(image, y_min_pad, y_max_pad, x_min_pad, x_max_pad, | |||
cv2.BORDER_CONSTANT), | |||
} | |||
if bboxes is not None: | |||
bboxes[:, 0] += x_min_pad | |||
bboxes[:, 1] += y_min_pad | |||
bboxes[:, 2] += x_min_pad | |||
bboxes[:, 3] += y_min_pad | |||
result['bboxes'] = bboxes | |||
if keypoints is not None: | |||
keypoints[:, 0] += x_min_pad | |||
keypoints[:, 1] += y_min_pad | |||
result['keypoints'] = keypoints | |||
return result | |||
def unpad_from_size( | |||
pads: Tuple[int, int, int, int], | |||
        image: Optional[np.ndarray] = None,
bboxes: Optional[np.ndarray] = None, | |||
keypoints: Optional[np.ndarray] = None, | |||
) -> Dict[str, np.ndarray]: | |||
"""Crops patch from the center so that sides are equal to pads. | |||
Args: | |||
image: | |||
pads: (x_min_pad, y_min_pad, x_max_pad, y_max_pad) | |||
bboxes: np.array with shape (num_boxes, 4). Each row: [x_min, y_min, x_max, y_max] | |||
keypoints: np.array with shape (num_keypoints, 2), each row: [x, y] | |||
Returns: cropped image | |||
{ | |||
"image": cropped_image, | |||
"bboxes": shifted_boxes, | |||
"keypoints": shifted_keypoints | |||
} | |||
""" | |||
x_min_pad, y_min_pad, x_max_pad, y_max_pad = pads | |||
result = {} | |||
if image is not None: | |||
height, width = image.shape[:2] | |||
result['image'] = image[y_min_pad:height - y_max_pad, | |||
x_min_pad:width - x_max_pad] | |||
if bboxes is not None: | |||
bboxes[:, 0] -= x_min_pad | |||
bboxes[:, 1] -= y_min_pad | |||
bboxes[:, 2] -= x_min_pad | |||
bboxes[:, 3] -= y_min_pad | |||
result['bboxes'] = bboxes | |||
if keypoints is not None: | |||
keypoints[:, 0] -= x_min_pad | |||
keypoints[:, 1] -= y_min_pad | |||
result['keypoints'] = keypoints | |||
return result |
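# Round-trip sketch (illustrative sizes): pad_to_size centres the image inside
# the target canvas and shifts boxes/keypoints by the pads; unpad_from_size
# reverses both. Note that both functions modify bboxes/keypoints in place.
if __name__ == '__main__':
    img = np.zeros((30, 40, 3), dtype=np.uint8)
    boxes = np.array([[5, 5, 20, 25]], dtype=np.int64)
    padded = pad_to_size(target_size=(64, 64), image=img, bboxes=boxes.copy())
    print(padded['image'].shape, padded['pads'])  # (64, 64, 3) (12, 17, 12, 17)
    restored = unpad_from_size(
        padded['pads'], image=padded['image'], bboxes=padded['bboxes'])
    print(restored['image'].shape)  # (30, 40, 3)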
@@ -0,0 +1,36 @@ | |||
import torch | |||
import torch.nn as nn | |||
def weights_init(init_type='kaiming', gain=0.02): | |||
def init_func(m): | |||
classname = m.__class__.__name__ | |||
if hasattr(m, 'weight') and (classname.find('Conv') != -1 | |||
or classname.find('Linear') != -1): | |||
if init_type == 'normal': | |||
nn.init.normal_(m.weight.data, 0.0, gain) | |||
elif init_type == 'xavier': | |||
nn.init.xavier_normal_(m.weight.data, gain=gain) | |||
elif init_type == 'kaiming': | |||
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') | |||
elif init_type == 'orthogonal': | |||
nn.init.orthogonal_(m.weight.data, gain=gain) | |||
if hasattr(m, 'bias') and m.bias is not None: | |||
nn.init.constant_(m.bias.data, 0.0) | |||
elif classname.find('BatchNorm2d') != -1: | |||
nn.init.normal_(m.weight.data, 1.0, gain) | |||
nn.init.constant_(m.bias.data, 0.0) | |||
return init_func | |||
def spectral_norm(module, mode=True): | |||
if mode: | |||
return nn.utils.spectral_norm(module) | |||
return module |
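# Usage sketch (illustrative only): weights_init returns an initialisation
# function that is applied recursively with nn.Module.apply, as done by the
# models in this change.
if __name__ == '__main__':
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
    model.apply(weights_init(init_type='kaiming'))
    print(model[0].weight.mean().item(), model[1].weight.mean().item())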
@@ -122,6 +122,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
Tasks.image_classification: | |||
(Pipelines.daily_image_classification, | |||
'damo/cv_vit-base_image-classification_Dailylife-labels'), | |||
Tasks.skin_retouching: (Pipelines.skin_retouching, | |||
'damo/cv_unet_skin-retouching'), | |||
} | |||
@@ -27,6 +27,7 @@ if TYPE_CHECKING: | |||
from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline | |||
from .live_category_pipeline import LiveCategoryPipeline | |||
from .ocr_detection_pipeline import OCRDetectionPipeline | |||
from .skin_retouching_pipeline import SkinRetouchingPipeline | |||
from .video_category_pipeline import VideoCategoryPipeline | |||
from .virtual_try_on_pipeline import VirtualTryonPipeline | |||
else: | |||
@@ -59,6 +60,7 @@ else: | |||
'image_to_image_generation_pipeline': | |||
['Image2ImageGenerationePipeline'], | |||
'ocr_detection_pipeline': ['OCRDetectionPipeline'], | |||
'skin_retouching_pipeline': ['SkinRetouchingPipeline'], | |||
'video_category_pipeline': ['VideoCategoryPipeline'], | |||
'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | |||
} | |||
@@ -0,0 +1,302 @@ | |||
import os | |||
from typing import Any, Dict | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import tensorflow as tf | |||
import torch | |||
import torch.nn.functional as F | |||
import torchvision.transforms as transforms | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.cv.skin_retouching.detection_model.detection_unet_in import \ | |||
DetectionUNet | |||
from modelscope.models.cv.skin_retouching.inpainting_model.inpainting_unet import \ | |||
RetouchingNet | |||
from modelscope.models.cv.skin_retouching.retinaface.predict_single import \ | |||
Model | |||
from modelscope.models.cv.skin_retouching.unet_deploy import UNet | |||
from modelscope.models.cv.skin_retouching.utils import * # noqa F403 | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
if tf.__version__ >= '2.0': | |||
tf = tf.compat.v1 | |||
tf.disable_eager_execution() | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.skin_retouching, module_name=Pipelines.skin_retouching) | |||
class SkinRetouchingPipeline(Pipeline): | |||
def __init__(self, model: str, device: str): | |||
""" | |||
        Use `model` to create a skin retouching pipeline for prediction.
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__(model=model) | |||
if device == 'gpu': | |||
device = 'cuda' | |||
model_path = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE) | |||
detector_model_path = os.path.join( | |||
self.model, 'retinaface_resnet50_2020-07-20_old_torch.pth') | |||
local_model_path = os.path.join(self.model, 'joint_20210926.pth') | |||
skin_model_path = os.path.join(self.model, ModelFile.TF_GRAPH_FILE) | |||
self.generator = UNet(3, 3).to(device) | |||
self.generator.load_state_dict( | |||
torch.load(model_path, map_location='cpu')['generator']) | |||
self.generator.eval() | |||
self.detector = Model(max_size=512, device=device) | |||
state_dict = torch.load(detector_model_path, map_location='cpu') | |||
self.detector.load_state_dict(state_dict) | |||
self.detector.eval() | |||
self.local_model_path = local_model_path | |||
ckpt_dict_load = torch.load(self.local_model_path, map_location='cpu') | |||
self.inpainting_net = RetouchingNet( | |||
in_channels=4, out_channels=3).to(device) | |||
self.detection_net = DetectionUNet( | |||
n_channels=3, n_classes=1).to(device) | |||
self.inpainting_net.load_state_dict(ckpt_dict_load['inpainting_net']) | |||
self.detection_net.load_state_dict(ckpt_dict_load['detection_net']) | |||
self.inpainting_net.eval() | |||
self.detection_net.eval() | |||
self.patch_size = 512 | |||
self.skin_model_path = skin_model_path | |||
if self.skin_model_path is not None: | |||
config = tf.ConfigProto(allow_soft_placement=True) | |||
config.gpu_options.per_process_gpu_memory_fraction = 0.3 | |||
config.gpu_options.allow_growth = True | |||
self.sess = tf.Session(config=config) | |||
with tf.gfile.FastGFile(self.skin_model_path, 'rb') as f: | |||
graph_def = tf.GraphDef() | |||
graph_def.ParseFromString(f.read()) | |||
self.sess.graph.as_default() | |||
tf.import_graph_def(graph_def, name='') | |||
self.sess.run(tf.global_variables_initializer()) | |||
self.image_files_transforms = transforms.Compose([ | |||
transforms.ToTensor(), | |||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | |||
]) | |||
self.diffuse_mask = gen_diffuse_mask() | |||
self.diffuse_mask = torch.from_numpy( | |||
self.diffuse_mask).to(device).float() | |||
self.diffuse_mask = self.diffuse_mask.permute(2, 0, 1)[None, ...] | |||
self.input_size = 512 | |||
self.device = device | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
img = LoadImage.convert_to_ndarray(input) | |||
if len(img.shape) == 2: | |||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
        img = img.astype(np.float32)
result = {'img': img} | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
rgb_image = input['img'].astype(np.uint8) | |||
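        # fixed retouching options: local blemish inpainting on, whitening with
        # the TF skin mask on, full smoothing strength, no mg-map output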
retouch_local = True | |||
whitening = True | |||
degree = 1.0 | |||
whitening_degree = 0.8 | |||
return_mg = False | |||
with torch.no_grad(): | |||
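            # estimate a skin-probability mask on a copy downscaled to an 800-px
            # long side; whiten_img applies it at the end of the pipeline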
if whitening and whitening_degree > 0 and self.skin_model_path is not None: | |||
rgb_image_small, resize_scale = resize_on_long_side( | |||
rgb_image, 800) | |||
skin_mask = self.sess.run( | |||
self.sess.graph.get_tensor_by_name('output_png:0'), | |||
feed_dict={'input_image:0': rgb_image_small}) | |||
output_pred = torch.from_numpy(rgb_image).to(self.device) | |||
if return_mg: | |||
output_mg = np.ones( | |||
(rgb_image.shape[0], rgb_image.shape[1], 3), | |||
dtype=np.float32) * 0.5 | |||
            results = self.detector.predict_jsons(
                rgb_image)  # [{'bbox': [x1, y1, x2, y2], 'score': ...}, ...]
crop_bboxes = get_crop_bbox(results) | |||
face_num = len(crop_bboxes) | |||
            if face_num == 0:
                output_pred = output_pred.cpu().numpy()[:, :, ::-1]
                return {
                    OutputKeys.OUTPUT_IMG: output_pred,
                    'face_num': face_num
                }
flag_bigKernal = False | |||
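            # process each detected face: crop the ROI, optionally remove local
            # blemishes with the inpainting net, smooth it with the generator,
            # then paste the retouched patch back into the full image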
for bbox in crop_bboxes: | |||
roi, expand, crop_tblr = get_roi_without_padding( | |||
rgb_image, bbox) | |||
roi = roi_to_tensor(roi) # bgr -> rgb | |||
if roi.shape[2] > 0.4 * rgb_image.shape[0]: | |||
flag_bigKernal = True | |||
roi = roi.to(self.device) | |||
roi = preprocess_roi(roi) | |||
if retouch_local and self.local_model_path is not None: | |||
roi = self.retouch_local(roi) | |||
roi_output = self.predict_roi( | |||
roi, | |||
degree=degree, | |||
smooth_border=True, | |||
return_mg=return_mg) | |||
roi_pred = roi_output['pred'] | |||
output_pred[crop_tblr[0]:crop_tblr[1], | |||
crop_tblr[2]:crop_tblr[3]] = roi_pred | |||
if return_mg: | |||
roi_mg = roi_output['pred_mg'] | |||
output_mg[crop_tblr[0]:crop_tblr[1], | |||
crop_tblr[2]:crop_tblr[3]] = roi_mg | |||
if whitening and whitening_degree > 0 and self.skin_model_path is not None: | |||
output_pred = whiten_img( | |||
output_pred, | |||
skin_mask, | |||
whitening_degree, | |||
flag_bigKernal=flag_bigKernal) | |||
if not isinstance(output_pred, np.ndarray): | |||
output_pred = output_pred.cpu().numpy() | |||
output_pred = output_pred[:, :, ::-1] | |||
return {OutputKeys.OUTPUT_IMG: output_pred} | |||
    def retouch_local(self, image):
        """Remove local blemishes from a face ROI.

        Args:
            image: RGB tensor of shape (1, 3, H, W), values normalized to [-1, 1].
        """
with torch.no_grad(): | |||
sub_H, sub_W = image.shape[2:] | |||
sub_image_standard = F.interpolate( | |||
image, size=(768, 768), mode='bilinear', align_corners=True) | |||
sub_mask_pred = torch.sigmoid( | |||
self.detection_net(sub_image_standard)) | |||
sub_mask_pred = F.interpolate( | |||
sub_mask_pred, size=(sub_H, sub_W), mode='nearest') | |||
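            # combine soft and hard thresholds: scores >= 0.5 become 1, scores
            # in [0.35, 0.5) keep their soft value, scores < 0.35 become 0; the
            # final inversion makes 1 = keep pixel, 0 = region to inpaint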
sub_mask_pred_hard_low = (sub_mask_pred >= 0.35).float() | |||
sub_mask_pred_hard_high = (sub_mask_pred >= 0.5).float() | |||
sub_mask_pred = sub_mask_pred * ( | |||
1 - sub_mask_pred_hard_high) + sub_mask_pred_hard_high | |||
sub_mask_pred = sub_mask_pred * sub_mask_pred_hard_low | |||
sub_mask_pred = 1 - sub_mask_pred | |||
sub_H_standard = sub_H if sub_H % self.patch_size == 0 else ( | |||
sub_H // self.patch_size + 1) * self.patch_size | |||
sub_W_standard = sub_W if sub_W % self.patch_size == 0 else ( | |||
sub_W // self.patch_size + 1) * self.patch_size | |||
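            # pad height and width up to the next multiple of patch_size so the
            # image tiles exactly into patch_size x patch_size windows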
sub_image_padding = F.pad( | |||
image, | |||
pad=(0, sub_W_standard - sub_W, 0, sub_H_standard - sub_H, 0, | |||
0), | |||
mode='constant', | |||
value=0) | |||
sub_mask_pred_padding = F.pad( | |||
sub_mask_pred, | |||
pad=(0, sub_W_standard - sub_W, 0, sub_H_standard - sub_H, 0, | |||
0), | |||
mode='constant', | |||
value=0) | |||
sub_image_padding = patch_partition_overlap( | |||
sub_image_padding, p1=self.patch_size, p2=self.patch_size) | |||
sub_mask_pred_padding = patch_partition_overlap( | |||
sub_mask_pred_padding, p1=self.patch_size, p2=self.patch_size) | |||
B_padding, C_padding, _, _ = sub_image_padding.size() | |||
sub_comp_padding_list = [] | |||
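            # inpaint patch by patch; the composite keeps the original pixels
            # where the mask is 1 and uses the network output where it is 0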
for window_item in range(B_padding): | |||
sub_image_padding_window = sub_image_padding[ | |||
window_item:window_item + 1] | |||
sub_mask_pred_padding_window = sub_mask_pred_padding[ | |||
window_item:window_item + 1] | |||
sub_input_image_padding_window = sub_image_padding_window * sub_mask_pred_padding_window | |||
sub_output_padding_window = self.inpainting_net( | |||
sub_input_image_padding_window, | |||
sub_mask_pred_padding_window) | |||
sub_comp_padding_window = sub_input_image_padding_window + ( | |||
1 | |||
- sub_mask_pred_padding_window) * sub_output_padding_window | |||
sub_comp_padding_list.append(sub_comp_padding_window) | |||
sub_comp_padding = torch.cat(sub_comp_padding_list, dim=0) | |||
sub_comp = patch_aggregation_overlap( | |||
sub_comp_padding, | |||
h=int(round(sub_H_standard / self.patch_size)), | |||
w=int(round(sub_W_standard | |||
/ self.patch_size)))[:, :, :sub_H, :sub_W] | |||
return sub_comp | |||
def predict_roi(self, | |||
roi, | |||
degree=1.0, | |||
smooth_border=False, | |||
return_mg=False): | |||
with torch.no_grad(): | |||
image = F.interpolate( | |||
roi, (self.input_size, self.input_size), mode='bilinear') | |||
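            # the generator predicts a per-pixel "mg" adjustment map around 0.5;
            # degree scales how far it deviates from the neutral value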
pred_mg = self.generator(image) # value: 0~1 | |||
pred_mg = (pred_mg - 0.5) * degree + 0.5 | |||
pred_mg = pred_mg.clamp(0.0, 1.0) | |||
pred_mg = F.interpolate(pred_mg, roi.shape[2:], mode='bilinear') | |||
            pred_mg = pred_mg[0].permute(
                1, 2, 0)  # tensor, (h, w, 1) or (h, w, 3)
if len(pred_mg.shape) == 2: | |||
pred_mg = pred_mg[..., None] | |||
if smooth_border: | |||
pred_mg = smooth_border_mg(self.diffuse_mask, pred_mg) | |||
image = (roi[0].permute(1, 2, 0) + 1.0) / 2 | |||
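            # quadratic blend driven by the mg map:
            # out = (1 - 2*m) * x^2 + 2*m * x, which is the identity when m = 0.5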
pred = (1 - 2 * pred_mg | |||
) * image * image + 2 * pred_mg * image # value: 0~1 | |||
            pred = (pred * 255.0).byte()  # uint8 tensor, (h, w, 3), rgb
output = {'pred': pred} | |||
if return_mg: | |||
output['pred_mg'] = pred_mg.cpu().numpy() | |||
return output | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
@@ -0,0 +1,46 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
import unittest | |||
import cv2 | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class SkinRetouchingTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/cv_unet_skin-retouching' | |||
self.test_image = 'data/test/images/skin_retouching.png' | |||
def pipeline_inference(self, pipeline: Pipeline, input_location: str): | |||
result = pipeline(input_location) | |||
cv2.imwrite('result_skinretouching.png', result[OutputKeys.OUTPUT_IMG]) | |||
print(f'Output written to {osp.abspath("result_skinretouching.png")}') | |||
@unittest.skip('deprecated, download model from model hub instead') | |||
def test_run_by_direct_model_download(self): | |||
model_dir = snapshot_download(self.model_id) | |||
skin_retouching = pipeline(Tasks.skin_retouching, model=model_dir) | |||
self.pipeline_inference(skin_retouching, self.test_image) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_modelhub(self): | |||
skin_retouching = pipeline(Tasks.skin_retouching, model=self.model_id) | |||
self.pipeline_inference(skin_retouching, self.test_image) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_modelhub_default_model(self): | |||
skin_retouching = pipeline(Tasks.skin_retouching) | |||
self.pipeline_inference(skin_retouching, self.test_image) | |||
if __name__ == '__main__': | |||
unittest.main() |
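# A minimal usage sketch, assuming the 'damo/cv_unet_skin-retouching' model and
# the test image above are available locally:
import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

skin_retouching = pipeline(
    Tasks.skin_retouching, model='damo/cv_unet_skin-retouching')
result = skin_retouching('data/test/images/skin_retouching.png')
# the retouched image is returned under OutputKeys.OUTPUT_IMG in BGR order
cv2.imwrite('result_skinretouching.png', result[OutputKeys.OUTPUT_IMG])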