Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9590794 (branch: master)
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:fa8ab905e8374a0f94b4bfbfc81da14e762c71eaf64bae85bdd03b07cdf884c2 | |||
size 859206 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:8cd14710143ba1a912e3ef574d0bf71c7e40bf9897522cba07ecae2567343064 | |||
size 850603 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:d7f166ecb3a6913dbd05a1eb271399cbaa731d1074ac03184c13ae245ca66819 | |||
size 800380 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:e95d11661485fc0e6f326398f953459dcb3e65b7f4a6c892611266067cf8fe3a | |||
size 245773 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:03972400b20b3e6f1d056b359d9c9f12952653a67a73b36018504ce9ee9edf9d | |||
size 254261 |
@@ -16,6 +16,7 @@ class Models(object): | |||
nafnet = 'nafnet' | |||
csrnet = 'csrnet' | |||
cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' | |||
gpen = 'gpen' | |||
product_retrieval_embedding = 'product-retrieval-embedding' | |||
# nlp models | |||
@@ -91,6 +92,7 @@ class Pipelines(object): | |||
image2image_translation = 'image-to-image-translation' | |||
live_category = 'live-category' | |||
video_category = 'video-category' | |||
image_portrait_enhancement = 'gpen-image-portrait-enhancement' | |||
image_to_image_generation = 'image-to-image-generation' | |||
# nlp tasks | |||
@@ -160,6 +162,7 @@ class Preprocessors(object): | |||
image_denoie_preprocessor = 'image-denoise-preprocessor' | |||
image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' | |||
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' | |||
image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' | |||
# nlp preprocessor | |||
sen_sim_tokenizer = 'sen-sim-tokenizer' | |||
@@ -207,3 +210,5 @@ class Metrics(object): | |||
text_gen_metric = 'text-gen-metric' | |||
# metrics for image-color-enhance task | |||
image_color_enhance_metric = 'image-color-enhance-metric' | |||
# metrics for image-portrait-enhancement task | |||
image_portrait_enhancement_metric = 'image-portrait-enhancement-metric' |
@@ -10,6 +10,7 @@ if TYPE_CHECKING: | |||
from .image_denoise_metric import ImageDenoiseMetric | |||
from .image_instance_segmentation_metric import \ | |||
ImageInstanceSegmentationCOCOMetric | |||
from .image_portrait_enhancement_metric import ImagePortraitEnhancementMetric | |||
from .sequence_classification_metric import SequenceClassificationMetric | |||
from .text_generation_metric import TextGenerationMetric | |||
@@ -21,6 +22,8 @@ else: | |||
'image_denoise_metric': ['ImageDenoiseMetric'], | |||
'image_instance_segmentation_metric': | |||
['ImageInstanceSegmentationCOCOMetric'], | |||
'image_portrait_enhancement_metric': | |||
['ImagePortraitEnhancementMetric'], | |||
'sequence_classification_metric': ['SequenceClassificationMetric'], | |||
'text_generation_metric': ['TextGenerationMetric'], | |||
} | |||
@@ -23,7 +23,9 @@ task_default_metrics = { | |||
Tasks.sentiment_classification: [Metrics.seq_cls_metric], | |||
Tasks.text_generation: [Metrics.text_gen_metric], | |||
Tasks.image_denoising: [Metrics.image_denoise_metric], | |||
Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric] | |||
Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | |||
Tasks.image_portrait_enhancement: | |||
[Metrics.image_portrait_enhancement_metric], | |||
} | |||
@@ -0,0 +1,47 @@ | |||
from typing import Dict | |||
import numpy as np | |||
from modelscope.metainfo import Metrics | |||
from modelscope.utils.registry import default_group | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
def calculate_psnr(img, img2): | |||
assert img.shape == img2.shape, ( | |||
f'Image shapes are different: {img.shape}, {img2.shape}.') | |||
img = img.astype(np.float64) | |||
img2 = img2.astype(np.float64) | |||
mse = np.mean((img - img2)**2) | |||
if mse == 0: | |||
return float('inf') | |||
return 10. * np.log10(255. * 255. / mse) | |||
@METRICS.register_module( | |||
group_key=default_group, | |||
module_name=Metrics.image_portrait_enhancement_metric) | |||
class ImagePortraitEnhancementMetric(Metric): | |||
"""The metric for image-portrait-enhancement task. | |||
""" | |||
def __init__(self): | |||
self.preds = [] | |||
self.targets = [] | |||
def add(self, outputs: Dict, inputs: Dict): | |||
ground_truths = outputs['target'] | |||
eval_results = outputs['pred'] | |||
self.preds.extend(eval_results) | |||
self.targets.extend(ground_truths) | |||
def evaluate(self): | |||
psnrs = [ | |||
calculate_psnr(pred, target) | |||
for pred, target in zip(self.preds, self.targets) | |||
] | |||
return {MetricKeys.PSNR: sum(psnrs) / len(psnrs)} |
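Note for reviewers: a minimal sketch of exercising the new metric on its own, not part of this change. The import path and the exact key returned by MetricKeys.PSNR are assumptions based on the surrounding package layout.

import numpy as np

# assumed import path for the class added above
from modelscope.metrics.image_portrait_enhancement_metric import \
    ImagePortraitEnhancementMetric

pred = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
target = pred.copy()  # identical images -> infinite PSNR

metric = ImagePortraitEnhancementMetric()
metric.add(outputs={'pred': [pred], 'target': [target]}, inputs={})
print(metric.evaluate())  # e.g. {MetricKeys.PSNR: inf}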
@@ -3,6 +3,6 @@ from . import (action_recognition, animal_recognition, cartoon, | |||
cmdssl_video_embedding, face_detection, face_generation, | |||
image_classification, image_color_enhance, image_colorization, | |||
image_denoise, image_instance_segmentation, | |||
image_to_image_generation, image_to_image_translation, | |||
object_detection, product_retrieval_embedding, super_resolution, | |||
virual_tryon) | |||
image_portrait_enhancement, image_to_image_generation, | |||
image_to_image_translation, object_detection, | |||
product_retrieval_embedding, super_resolution, virual_tryon) |
@@ -0,0 +1,22 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .image_portrait_enhancement import ImagePortraitEnhancement | |||
else: | |||
_import_structure = { | |||
'image_portrait_enhancement': ['ImagePortraitEnhancement'] | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,252 @@ | |||
import cv2 | |||
import numpy as np | |||
from skimage import transform as trans | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
# reference facial points, a list of coordinates (x,y) | |||
REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], | |||
                           [65.53179932, 51.50139999], | |||
                           [48.02519989, 71.73660278], | |||
                           [33.54930115, 92.3655014], | |||
                           [62.72990036, 92.20410156]] | |||
DEFAULT_CROP_SIZE = (96, 112) | |||
def _umeyama(src, dst, estimate_scale=True, scale=1.0): | |||
"""Estimate N-D similarity transformation with or without scaling. | |||
Parameters | |||
---------- | |||
src : (M, N) array | |||
Source coordinates. | |||
dst : (M, N) array | |||
Destination coordinates. | |||
estimate_scale : bool | |||
Whether to estimate scaling factor. | |||
Returns | |||
------- | |||
T : (N + 1, N + 1) | |||
The homogeneous similarity transformation matrix. The matrix contains | |||
NaN values only if the problem is not well-conditioned. | |||
References | |||
---------- | |||
.. [1] "Least-squares estimation of transformation parameters between two | |||
point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` | |||
""" | |||
num = src.shape[0] | |||
dim = src.shape[1] | |||
# Compute mean of src and dst. | |||
src_mean = src.mean(axis=0) | |||
dst_mean = dst.mean(axis=0) | |||
# Subtract mean from src and dst. | |||
src_demean = src - src_mean | |||
dst_demean = dst - dst_mean | |||
# Eq. (38). | |||
A = dst_demean.T @ src_demean / num | |||
# Eq. (39). | |||
d = np.ones((dim, ), dtype=np.double) | |||
if np.linalg.det(A) < 0: | |||
d[dim - 1] = -1 | |||
T = np.eye(dim + 1, dtype=np.double) | |||
U, S, V = np.linalg.svd(A) | |||
# Eq. (40) and (43). | |||
rank = np.linalg.matrix_rank(A) | |||
if rank == 0: | |||
return np.nan * T | |||
elif rank == dim - 1: | |||
if np.linalg.det(U) * np.linalg.det(V) > 0: | |||
T[:dim, :dim] = U @ V | |||
else: | |||
s = d[dim - 1] | |||
d[dim - 1] = -1 | |||
T[:dim, :dim] = U @ np.diag(d) @ V | |||
d[dim - 1] = s | |||
else: | |||
T[:dim, :dim] = U @ np.diag(d) @ V | |||
if estimate_scale: | |||
# Eq. (41) and (42). | |||
scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) | |||
else: | |||
scale = scale | |||
T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) | |||
T[:dim, :dim] *= scale | |||
return T, scale | |||
class FaceWarpException(Exception): | |||
def __str__(self): | |||
        return 'In File {}:{}'.format(__file__, super().__str__()) | |||
def get_reference_facial_points(output_size=None, | |||
inner_padding_factor=0.0, | |||
outer_padding=(0, 0), | |||
default_square=False): | |||
ref_5pts = np.array(REFERENCE_FACIAL_POINTS) | |||
ref_crop_size = np.array(DEFAULT_CROP_SIZE) | |||
# 0) make the inner region a square | |||
if default_square: | |||
size_diff = max(ref_crop_size) - ref_crop_size | |||
ref_5pts += size_diff / 2 | |||
ref_crop_size += size_diff | |||
if (output_size and output_size[0] == ref_crop_size[0] | |||
and output_size[1] == ref_crop_size[1]): | |||
return ref_5pts | |||
if (inner_padding_factor == 0 and outer_padding == (0, 0)): | |||
if output_size is None: | |||
            logger.info('No padding to apply: returning default reference points') | |||
return ref_5pts | |||
else: | |||
raise FaceWarpException( | |||
                'No padding to apply: output_size must be None or {}'.format( | |||
ref_crop_size)) | |||
# check output size | |||
if not (0 <= inner_padding_factor <= 1.0): | |||
raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') | |||
if ((inner_padding_factor > 0 or outer_padding[0] > 0 | |||
or outer_padding[1] > 0) and output_size is None): | |||
        output_size = (ref_crop_size * | |||
                       (1 + inner_padding_factor * 2)).astype(np.int32) | |||
        output_size += np.array(outer_padding) | |||
        logger.info('deduced from paddings, output_size = %s', output_size) | |||
if not (outer_padding[0] < output_size[0] | |||
and outer_padding[1] < output_size[1]): | |||
        raise FaceWarpException('Not (outer_padding[0] < output_size[0] ' | |||
                                'and outer_padding[1] < output_size[1])') | |||
# 1) pad the inner region according inner_padding_factor | |||
if inner_padding_factor > 0: | |||
size_diff = ref_crop_size * inner_padding_factor * 2 | |||
ref_5pts += size_diff / 2 | |||
ref_crop_size += np.round(size_diff).astype(np.int32) | |||
# 2) resize the padded inner region | |||
size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 | |||
if size_bf_outer_pad[0] * ref_crop_size[1] != size_bf_outer_pad[ | |||
1] * ref_crop_size[0]: | |||
        raise FaceWarpException( | |||
            'Must have (output_size - outer_padding) ' | |||
            '= some_scale * (crop_size * (1.0 + inner_padding_factor))') | |||
scale_factor = size_bf_outer_pad[0].astype(np.float32) / ref_crop_size[0] | |||
ref_5pts = ref_5pts * scale_factor | |||
ref_crop_size = size_bf_outer_pad | |||
# 3) add outer_padding to make output_size | |||
reference_5point = ref_5pts + np.array(outer_padding) | |||
ref_crop_size = output_size | |||
return reference_5point | |||
def get_affine_transform_matrix(src_pts, dst_pts): | |||
tfm = np.float32([[1, 0, 0], [0, 1, 0]]) | |||
n_pts = src_pts.shape[0] | |||
ones = np.ones((n_pts, 1), src_pts.dtype) | |||
src_pts_ = np.hstack([src_pts, ones]) | |||
dst_pts_ = np.hstack([dst_pts, ones]) | |||
    A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None) | |||
if rank == 3: | |||
tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], | |||
[A[0, 1], A[1, 1], A[2, 1]]]) | |||
elif rank == 2: | |||
tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) | |||
return tfm | |||
def get_params(reference_pts, facial_pts, align_type): | |||
ref_pts = np.float32(reference_pts) | |||
ref_pts_shp = ref_pts.shape | |||
if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: | |||
raise FaceWarpException( | |||
'reference_pts.shape must be (K,2) or (2,K) and K>2') | |||
if ref_pts_shp[0] == 2: | |||
ref_pts = ref_pts.T | |||
src_pts = np.float32(facial_pts) | |||
src_pts_shp = src_pts.shape | |||
if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: | |||
raise FaceWarpException( | |||
'facial_pts.shape must be (K,2) or (2,K) and K>2') | |||
if src_pts_shp[0] == 2: | |||
src_pts = src_pts.T | |||
if src_pts.shape != ref_pts.shape: | |||
raise FaceWarpException( | |||
'facial_pts and reference_pts must have the same shape') | |||
if align_type == 'cv2_affine': | |||
tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) | |||
tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) | |||
elif align_type == 'affine': | |||
tfm = get_affine_transform_matrix(src_pts, ref_pts) | |||
tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) | |||
else: | |||
params, scale = _umeyama(src_pts, ref_pts) | |||
tfm = params[:2, :] | |||
params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale) | |||
tfm_inv = params[:2, :] | |||
return tfm, tfm_inv | |||
def warp_and_crop_face(src_img, | |||
facial_pts, | |||
reference_pts=None, | |||
crop_size=(96, 112), | |||
                       align_type='similarity'):  # similarity | cv2_affine | affine | |||
reference_pts_112 = get_reference_facial_points((112, 112), 0.25, (0, 0), | |||
True) | |||
if reference_pts is None: | |||
if crop_size[0] == 96 and crop_size[1] == 112: | |||
reference_pts = REFERENCE_FACIAL_POINTS | |||
else: | |||
default_square = True # False | |||
inner_padding_factor = 0.25 # 0 | |||
outer_padding = (0, 0) | |||
output_size = crop_size | |||
reference_pts = get_reference_facial_points( | |||
output_size, inner_padding_factor, outer_padding, | |||
default_square) | |||
tfm, tfm_inv = get_params(reference_pts, facial_pts, align_type) | |||
tfm_112, tfm_inv_112 = get_params(reference_pts_112, facial_pts, | |||
align_type) | |||
if src_img is not None: | |||
face_img = cv2.warpAffine( | |||
src_img, tfm, (crop_size[0], crop_size[1]), flags=3) | |||
face_img_112 = cv2.warpAffine(src_img, tfm_112, (112, 112), flags=3) | |||
return face_img, face_img_112, tfm_inv | |||
else: | |||
return tfm, tfm_inv |
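Note for reviewers: a rough usage sketch of the alignment helpers above (assumes warp_and_crop_face is imported from this new module); the landmark coordinates are made up, and the default align_type path goes through _umeyama.

import numpy as np

# five hypothetical (x, y) landmarks: eye centers, nose tip, mouth corners
facial5points = np.array([[180., 210.], [260., 208.], [222., 255.],
                          [190., 300.], [255., 298.]])
src_img = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder image

# aligned 256x256 crop, a 112x112 crop for the ID/quality nets, and the
# inverse transform used later to paste the enhanced face back
face, face_112, tfm_inv = warp_and_crop_face(
    src_img, facial5points, crop_size=(256, 256))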
@@ -0,0 +1,57 @@ | |||
import os | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
from .model_resnet import FaceQuality, ResNet | |||
class FQA(object): | |||
def __init__(self, backbone_path, quality_path, device='cuda', size=112): | |||
self.BACKBONE = ResNet(num_layers=100, feature_dim=512) | |||
self.QUALITY = FaceQuality(512 * 7 * 7) | |||
self.size = size | |||
self.device = device | |||
self.load_model(backbone_path, quality_path) | |||
def load_model(self, backbone_path, quality_path): | |||
checkpoint = torch.load(backbone_path, map_location='cpu') | |||
self.load_state_dict(self.BACKBONE, checkpoint) | |||
checkpoint = torch.load(quality_path, map_location='cpu') | |||
self.load_state_dict(self.QUALITY, checkpoint) | |||
self.BACKBONE.to(self.device) | |||
self.QUALITY.to(self.device) | |||
self.BACKBONE.eval() | |||
self.QUALITY.eval() | |||
def load_state_dict(self, model, state_dict): | |||
all_keys = {k for k in state_dict.keys()} | |||
for k in all_keys: | |||
if k.startswith('module.'): | |||
state_dict[k[7:]] = state_dict.pop(k) | |||
model_dict = model.state_dict() | |||
pretrained_dict = { | |||
k: v | |||
for k, v in state_dict.items() | |||
if k in model_dict and v.size() == model_dict[k].size() | |||
} | |||
model_dict.update(pretrained_dict) | |||
model.load_state_dict(model_dict) | |||
def get_face_quality(self, img): | |||
        img = torch.from_numpy(img).permute( | |||
            2, 0, 1).unsqueeze(0).flip(1).to(self.device) | |||
img = (img - 127.5) / 128.0 | |||
# extract features & predict quality | |||
with torch.no_grad(): | |||
feature, fc = self.BACKBONE(img.to(self.device), True) | |||
s = self.QUALITY(fc)[0] | |||
return s.cpu().numpy()[0], feature.cpu().numpy()[0] |
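Note for reviewers: a hedged sketch of driving FQA; the checkpoint paths are placeholders and a CUDA device is assumed, since get_face_quality moves the input onto the configured device.

import numpy as np

# placeholder checkpoint paths; real ones come from the downloaded model dir
fqa = FQA('backbone.pth', 'quality.pth', device='cuda', size=112)

face_112 = np.zeros((112, 112, 3), dtype=np.float32)  # aligned BGR face crop
score, feature = fqa.get_face_quality(face_112)
print(score)          # scalar face-quality score in [0, 1]
print(feature.shape)  # (512,) identity embedding from the ResNet backbone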
@@ -0,0 +1,130 @@ | |||
import torch | |||
from torch import nn | |||
class BottleNeck_IR(nn.Module): | |||
def __init__(self, in_channel, out_channel, stride, dim_match): | |||
super(BottleNeck_IR, self).__init__() | |||
self.res_layer = nn.Sequential( | |||
nn.BatchNorm2d(in_channel), | |||
nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), | |||
nn.BatchNorm2d(out_channel), nn.PReLU(out_channel), | |||
nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), | |||
nn.BatchNorm2d(out_channel)) | |||
if dim_match: | |||
self.shortcut_layer = None | |||
else: | |||
self.shortcut_layer = nn.Sequential( | |||
nn.Conv2d( | |||
in_channel, | |||
out_channel, | |||
kernel_size=(1, 1), | |||
stride=stride, | |||
bias=False), nn.BatchNorm2d(out_channel)) | |||
def forward(self, x): | |||
shortcut = x | |||
res = self.res_layer(x) | |||
if self.shortcut_layer is not None: | |||
shortcut = self.shortcut_layer(x) | |||
return shortcut + res | |||
channel_list = [64, 64, 128, 256, 512] | |||
def get_layers(num_layers): | |||
if num_layers == 34: | |||
return [3, 4, 6, 3] | |||
if num_layers == 50: | |||
return [3, 4, 14, 3] | |||
elif num_layers == 100: | |||
return [3, 13, 30, 3] | |||
elif num_layers == 152: | |||
return [3, 8, 36, 3] | |||
class ResNet(nn.Module): | |||
def __init__(self, | |||
num_layers=100, | |||
feature_dim=512, | |||
drop_ratio=0.4, | |||
channel_list=channel_list): | |||
super(ResNet, self).__init__() | |||
assert num_layers in [34, 50, 100, 152] | |||
layers = get_layers(num_layers) | |||
block = BottleNeck_IR | |||
self.input_layer = nn.Sequential( | |||
nn.Conv2d( | |||
3, channel_list[0], (3, 3), stride=1, padding=1, bias=False), | |||
nn.BatchNorm2d(channel_list[0]), nn.PReLU(channel_list[0])) | |||
self.layer1 = self._make_layer( | |||
block, channel_list[0], channel_list[1], layers[0], stride=2) | |||
self.layer2 = self._make_layer( | |||
block, channel_list[1], channel_list[2], layers[1], stride=2) | |||
self.layer3 = self._make_layer( | |||
block, channel_list[2], channel_list[3], layers[2], stride=2) | |||
self.layer4 = self._make_layer( | |||
block, channel_list[3], channel_list[4], layers[3], stride=2) | |||
self.output_layer = nn.Sequential( | |||
nn.BatchNorm2d(512), nn.Dropout(drop_ratio), nn.Flatten()) | |||
self.feature_layer = nn.Sequential( | |||
nn.Linear(512 * 7 * 7, feature_dim), nn.BatchNorm1d(feature_dim)) | |||
for m in self.modules(): | |||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | |||
nn.init.xavier_uniform_(m.weight) | |||
if m.bias is not None: | |||
nn.init.constant_(m.bias, 0.0) | |||
elif isinstance(m, nn.BatchNorm2d) or isinstance( | |||
m, nn.BatchNorm1d): | |||
nn.init.constant_(m.weight, 1) | |||
nn.init.constant_(m.bias, 0) | |||
def _make_layer(self, block, in_channel, out_channel, blocks, stride): | |||
layers = [] | |||
layers.append(block(in_channel, out_channel, stride, False)) | |||
for i in range(1, blocks): | |||
layers.append(block(out_channel, out_channel, 1, True)) | |||
return nn.Sequential(*layers) | |||
def forward(self, x, fc=False): | |||
x = self.input_layer(x) | |||
x = self.layer1(x) | |||
x = self.layer2(x) | |||
x = self.layer3(x) | |||
x = self.layer4(x) | |||
x = self.output_layer(x) | |||
feature = self.feature_layer(x) | |||
if fc: | |||
return feature, x | |||
return feature | |||
class FaceQuality(nn.Module): | |||
def __init__(self, feature_dim): | |||
super(FaceQuality, self).__init__() | |||
self.qualtiy = nn.Sequential( | |||
nn.Linear(feature_dim, 512, bias=False), nn.BatchNorm1d(512), | |||
nn.ReLU(inplace=True), nn.Linear(512, 2, bias=False), | |||
nn.Softmax(dim=1)) | |||
for m in self.modules(): | |||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | |||
nn.init.xavier_uniform_(m.weight) | |||
if m.bias is not None: | |||
nn.init.constant_(m.bias, 0.0) | |||
elif isinstance(m, nn.BatchNorm2d) or isinstance( | |||
m, nn.BatchNorm1d): | |||
nn.init.constant_(m.weight, 1) | |||
nn.init.constant_(m.bias, 0) | |||
def forward(self, x): | |||
x = self.qualtiy(x) | |||
return x[:, 0:1] |
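Note for reviewers: a quick shape check of the two networks above with random weights (no checkpoint needed):

import torch

backbone = ResNet(num_layers=100, feature_dim=512).eval()
quality_head = FaceQuality(512 * 7 * 7).eval()

x = torch.randn(2, 3, 112, 112)      # batch of aligned 112x112 faces
feature, fc = backbone(x, fc=True)   # (2, 512) embedding, (2, 25088) flattened map
score = quality_head(fc)             # (2, 1) quality score in [0, 1]
print(feature.shape, fc.shape, score.shape)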
@@ -0,0 +1,813 @@ | |||
import functools | |||
import itertools | |||
import math | |||
import operator | |||
import random | |||
import torch | |||
from torch import nn | |||
from torch.autograd import Function | |||
from torch.nn import functional as F | |||
from modelscope.models.cv.face_generation.op import (FusedLeakyReLU, | |||
fused_leaky_relu, | |||
upfirdn2d) | |||
class PixelNorm(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
def forward(self, input): | |||
return input * torch.rsqrt( | |||
torch.mean(input**2, dim=1, keepdim=True) + 1e-8) | |||
def make_kernel(k): | |||
k = torch.tensor(k, dtype=torch.float32) | |||
if k.ndim == 1: | |||
k = k[None, :] * k[:, None] | |||
k /= k.sum() | |||
return k | |||
class Upsample(nn.Module): | |||
def __init__(self, kernel, factor=2): | |||
super().__init__() | |||
self.factor = factor | |||
kernel = make_kernel(kernel) * (factor**2) | |||
self.register_buffer('kernel', kernel) | |||
p = kernel.shape[0] - factor | |||
pad0 = (p + 1) // 2 + factor - 1 | |||
pad1 = p // 2 | |||
self.pad = (pad0, pad1) | |||
def forward(self, input): | |||
out = upfirdn2d( | |||
input, self.kernel, up=self.factor, down=1, pad=self.pad) | |||
return out | |||
class Downsample(nn.Module): | |||
def __init__(self, kernel, factor=2): | |||
super().__init__() | |||
self.factor = factor | |||
kernel = make_kernel(kernel) | |||
self.register_buffer('kernel', kernel) | |||
p = kernel.shape[0] - factor | |||
pad0 = (p + 1) // 2 | |||
pad1 = p // 2 | |||
self.pad = (pad0, pad1) | |||
def forward(self, input): | |||
out = upfirdn2d( | |||
input, self.kernel, up=1, down=self.factor, pad=self.pad) | |||
return out | |||
class Blur(nn.Module): | |||
def __init__(self, kernel, pad, upsample_factor=1): | |||
super().__init__() | |||
kernel = make_kernel(kernel) | |||
if upsample_factor > 1: | |||
kernel = kernel * (upsample_factor**2) | |||
self.register_buffer('kernel', kernel) | |||
self.pad = pad | |||
def forward(self, input): | |||
out = upfirdn2d(input, self.kernel, pad=self.pad) | |||
return out | |||
class EqualConv2d(nn.Module): | |||
def __init__(self, | |||
in_channel, | |||
out_channel, | |||
kernel_size, | |||
stride=1, | |||
padding=0, | |||
bias=True): | |||
super().__init__() | |||
self.weight = nn.Parameter( | |||
torch.randn(out_channel, in_channel, kernel_size, kernel_size)) | |||
self.scale = 1 / math.sqrt(in_channel * kernel_size**2) | |||
self.stride = stride | |||
self.padding = padding | |||
if bias: | |||
self.bias = nn.Parameter(torch.zeros(out_channel)) | |||
else: | |||
self.bias = None | |||
def forward(self, input): | |||
out = F.conv2d( | |||
input, | |||
self.weight * self.scale, | |||
bias=self.bias, | |||
stride=self.stride, | |||
padding=self.padding, | |||
) | |||
return out | |||
def __repr__(self): | |||
return ( | |||
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' | |||
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' | |||
) | |||
class EqualLinear(nn.Module): | |||
def __init__(self, | |||
in_dim, | |||
out_dim, | |||
bias=True, | |||
bias_init=0, | |||
lr_mul=1, | |||
activation=None): | |||
super().__init__() | |||
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) | |||
if bias: | |||
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) | |||
else: | |||
self.bias = None | |||
self.activation = activation | |||
self.scale = (1 / math.sqrt(in_dim)) * lr_mul | |||
self.lr_mul = lr_mul | |||
def forward(self, input): | |||
if self.activation: | |||
out = F.linear(input, self.weight * self.scale) | |||
out = fused_leaky_relu(out, self.bias * self.lr_mul) | |||
else: | |||
out = F.linear( | |||
input, self.weight * self.scale, bias=self.bias * self.lr_mul) | |||
return out | |||
def __repr__(self): | |||
return ( | |||
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' | |||
) | |||
class ScaledLeakyReLU(nn.Module): | |||
def __init__(self, negative_slope=0.2): | |||
super().__init__() | |||
self.negative_slope = negative_slope | |||
def forward(self, input): | |||
out = F.leaky_relu(input, negative_slope=self.negative_slope) | |||
return out * math.sqrt(2) | |||
class ModulatedConv2d(nn.Module): | |||
def __init__( | |||
self, | |||
in_channel, | |||
out_channel, | |||
kernel_size, | |||
style_dim, | |||
demodulate=True, | |||
upsample=False, | |||
downsample=False, | |||
blur_kernel=[1, 3, 3, 1], | |||
): | |||
super().__init__() | |||
self.eps = 1e-8 | |||
self.kernel_size = kernel_size | |||
self.in_channel = in_channel | |||
self.out_channel = out_channel | |||
self.upsample = upsample | |||
self.downsample = downsample | |||
if upsample: | |||
factor = 2 | |||
p = (len(blur_kernel) - factor) - (kernel_size - 1) | |||
pad0 = (p + 1) // 2 + factor - 1 | |||
pad1 = p // 2 + 1 | |||
self.blur = Blur( | |||
blur_kernel, pad=(pad0, pad1), upsample_factor=factor) | |||
if downsample: | |||
factor = 2 | |||
p = (len(blur_kernel) - factor) + (kernel_size - 1) | |||
pad0 = (p + 1) // 2 | |||
pad1 = p // 2 | |||
self.blur = Blur(blur_kernel, pad=(pad0, pad1)) | |||
fan_in = in_channel * kernel_size**2 | |||
self.scale = 1 / math.sqrt(fan_in) | |||
self.padding = kernel_size // 2 | |||
self.weight = nn.Parameter( | |||
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)) | |||
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) | |||
self.demodulate = demodulate | |||
def __repr__(self): | |||
return ( | |||
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' | |||
f'upsample={self.upsample}, downsample={self.downsample})') | |||
def forward(self, input, style): | |||
batch, in_channel, height, width = input.shape | |||
style = self.modulation(style).view(batch, 1, in_channel, 1, 1) | |||
weight = self.scale * self.weight * style | |||
if self.demodulate: | |||
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) | |||
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) | |||
weight = weight.view(batch * self.out_channel, in_channel, | |||
self.kernel_size, self.kernel_size) | |||
if self.upsample: | |||
input = input.view(1, batch * in_channel, height, width) | |||
weight = weight.view(batch, self.out_channel, in_channel, | |||
self.kernel_size, self.kernel_size) | |||
weight = weight.transpose(1, 2).reshape(batch * in_channel, | |||
self.out_channel, | |||
self.kernel_size, | |||
self.kernel_size) | |||
out = F.conv_transpose2d( | |||
input, weight, padding=0, stride=2, groups=batch) | |||
_, _, height, width = out.shape | |||
out = out.view(batch, self.out_channel, height, width) | |||
out = self.blur(out) | |||
elif self.downsample: | |||
input = self.blur(input) | |||
_, _, height, width = input.shape | |||
input = input.view(1, batch * in_channel, height, width) | |||
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) | |||
_, _, height, width = out.shape | |||
out = out.view(batch, self.out_channel, height, width) | |||
else: | |||
input = input.view(1, batch * in_channel, height, width) | |||
out = F.conv2d(input, weight, padding=self.padding, groups=batch) | |||
_, _, height, width = out.shape | |||
out = out.view(batch, self.out_channel, height, width) | |||
return out | |||
class NoiseInjection(nn.Module): | |||
def __init__(self, isconcat=True): | |||
super().__init__() | |||
self.isconcat = isconcat | |||
self.weight = nn.Parameter(torch.zeros(1)) | |||
def forward(self, image, noise=None): | |||
if noise is None: | |||
batch, channel, height, width = image.shape | |||
noise = image.new_empty(batch, channel, height, width).normal_() | |||
if self.isconcat: | |||
return torch.cat((image, self.weight * noise), dim=1) | |||
else: | |||
return image + self.weight * noise | |||
class ConstantInput(nn.Module): | |||
def __init__(self, channel, size=4): | |||
super().__init__() | |||
self.input = nn.Parameter(torch.randn(1, channel, size, size)) | |||
def forward(self, input): | |||
batch = input.shape[0] | |||
out = self.input.repeat(batch, 1, 1, 1) | |||
return out | |||
class StyledConv(nn.Module): | |||
def __init__( | |||
self, | |||
in_channel, | |||
out_channel, | |||
kernel_size, | |||
style_dim, | |||
upsample=False, | |||
blur_kernel=[1, 3, 3, 1], | |||
demodulate=True, | |||
isconcat=True, | |||
): | |||
super().__init__() | |||
self.conv = ModulatedConv2d( | |||
in_channel, | |||
out_channel, | |||
kernel_size, | |||
style_dim, | |||
upsample=upsample, | |||
blur_kernel=blur_kernel, | |||
demodulate=demodulate, | |||
) | |||
self.noise = NoiseInjection(isconcat) | |||
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) | |||
# self.activate = ScaledLeakyReLU(0.2) | |||
feat_multiplier = 2 if isconcat else 1 | |||
self.activate = FusedLeakyReLU(out_channel * feat_multiplier) | |||
def forward(self, input, style, noise=None): | |||
out = self.conv(input, style) | |||
out = self.noise(out, noise=noise) | |||
# out = out + self.bias | |||
out = self.activate(out) | |||
return out | |||
class ToRGB(nn.Module): | |||
def __init__(self, | |||
in_channel, | |||
style_dim, | |||
upsample=True, | |||
blur_kernel=[1, 3, 3, 1]): | |||
super().__init__() | |||
if upsample: | |||
self.upsample = Upsample(blur_kernel) | |||
self.conv = ModulatedConv2d( | |||
in_channel, 3, 1, style_dim, demodulate=False) | |||
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) | |||
def forward(self, input, style, skip=None): | |||
out = self.conv(input, style) | |||
out = out + self.bias | |||
if skip is not None: | |||
skip = self.upsample(skip) | |||
out = out + skip | |||
return out | |||
class Generator(nn.Module): | |||
def __init__( | |||
self, | |||
size, | |||
style_dim, | |||
n_mlp, | |||
channel_multiplier=2, | |||
blur_kernel=[1, 3, 3, 1], | |||
lr_mlp=0.01, | |||
isconcat=True, | |||
narrow=1, | |||
): | |||
super().__init__() | |||
self.size = size | |||
self.n_mlp = n_mlp | |||
self.style_dim = style_dim | |||
self.feat_multiplier = 2 if isconcat else 1 | |||
layers = [PixelNorm()] | |||
for i in range(n_mlp): | |||
layers.append( | |||
EqualLinear( | |||
style_dim, | |||
style_dim, | |||
lr_mul=lr_mlp, | |||
activation='fused_lrelu')) | |||
self.style = nn.Sequential(*layers) | |||
self.channels = { | |||
4: int(512 * narrow), | |||
8: int(512 * narrow), | |||
16: int(512 * narrow), | |||
32: int(512 * narrow), | |||
64: int(256 * channel_multiplier * narrow), | |||
128: int(128 * channel_multiplier * narrow), | |||
256: int(64 * channel_multiplier * narrow), | |||
512: int(32 * channel_multiplier * narrow), | |||
1024: int(16 * channel_multiplier * narrow), | |||
2048: int(8 * channel_multiplier * narrow) | |||
} | |||
self.input = ConstantInput(self.channels[4]) | |||
self.conv1 = StyledConv( | |||
self.channels[4], | |||
self.channels[4], | |||
3, | |||
style_dim, | |||
blur_kernel=blur_kernel, | |||
isconcat=isconcat) | |||
self.to_rgb1 = ToRGB( | |||
self.channels[4] * self.feat_multiplier, style_dim, upsample=False) | |||
self.log_size = int(math.log(size, 2)) | |||
self.convs = nn.ModuleList() | |||
self.upsamples = nn.ModuleList() | |||
self.to_rgbs = nn.ModuleList() | |||
in_channel = self.channels[4] | |||
for i in range(3, self.log_size + 1): | |||
out_channel = self.channels[2**i] | |||
self.convs.append( | |||
StyledConv( | |||
in_channel * self.feat_multiplier, | |||
out_channel, | |||
3, | |||
style_dim, | |||
upsample=True, | |||
blur_kernel=blur_kernel, | |||
isconcat=isconcat, | |||
)) | |||
self.convs.append( | |||
StyledConv( | |||
out_channel * self.feat_multiplier, | |||
out_channel, | |||
3, | |||
style_dim, | |||
blur_kernel=blur_kernel, | |||
isconcat=isconcat)) | |||
self.to_rgbs.append( | |||
ToRGB(out_channel * self.feat_multiplier, style_dim)) | |||
in_channel = out_channel | |||
self.n_latent = self.log_size * 2 - 2 | |||
def make_noise(self): | |||
device = self.input.input.device | |||
noises = [torch.randn(1, 1, 2**2, 2**2, device=device)] | |||
for i in range(3, self.log_size + 1): | |||
for _ in range(2): | |||
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) | |||
return noises | |||
def mean_latent(self, n_latent): | |||
latent_in = torch.randn( | |||
n_latent, self.style_dim, device=self.input.input.device) | |||
latent = self.style(latent_in).mean(0, keepdim=True) | |||
return latent | |||
def get_latent(self, input): | |||
return self.style(input) | |||
def forward( | |||
self, | |||
styles, | |||
return_latents=False, | |||
inject_index=None, | |||
truncation=1, | |||
truncation_latent=None, | |||
input_is_latent=False, | |||
noise=None, | |||
): | |||
if not input_is_latent: | |||
styles = [self.style(s) for s in styles] | |||
if noise is None: | |||
''' | |||
noise = [None] * (2 * (self.log_size - 2) + 1) | |||
''' | |||
noise = [] | |||
batch = styles[0].shape[0] | |||
for i in range(self.n_mlp + 1): | |||
size = 2**(i + 2) | |||
noise.append( | |||
torch.randn( | |||
batch, | |||
self.channels[size], | |||
size, | |||
size, | |||
device=styles[0].device)) | |||
if truncation < 1: | |||
style_t = [] | |||
for style in styles: | |||
style_t.append(truncation_latent | |||
+ truncation * (style - truncation_latent)) | |||
styles = style_t | |||
if len(styles) < 2: | |||
inject_index = self.n_latent | |||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) | |||
else: | |||
if inject_index is None: | |||
inject_index = random.randint(1, self.n_latent - 1) | |||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) | |||
latent2 = styles[1].unsqueeze(1).repeat( | |||
1, self.n_latent - inject_index, 1) | |||
latent = torch.cat([latent, latent2], 1) | |||
out = self.input(latent) | |||
out = self.conv1(out, latent[:, 0], noise=noise[0]) | |||
skip = self.to_rgb1(out, latent[:, 1]) | |||
i = 1 | |||
for conv1, conv2, noise1, noise2, to_rgb in zip( | |||
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], | |||
self.to_rgbs): | |||
out = conv1(out, latent[:, i], noise=noise1) | |||
out = conv2(out, latent[:, i + 1], noise=noise2) | |||
skip = to_rgb(out, latent[:, i + 2], skip) | |||
i += 2 | |||
image = skip | |||
if return_latents: | |||
return image, latent | |||
else: | |||
return image, None | |||
class ConvLayer(nn.Sequential): | |||
def __init__( | |||
self, | |||
in_channel, | |||
out_channel, | |||
kernel_size, | |||
downsample=False, | |||
blur_kernel=[1, 3, 3, 1], | |||
bias=True, | |||
activate=True, | |||
): | |||
layers = [] | |||
if downsample: | |||
factor = 2 | |||
p = (len(blur_kernel) - factor) + (kernel_size - 1) | |||
pad0 = (p + 1) // 2 | |||
pad1 = p // 2 | |||
layers.append(Blur(blur_kernel, pad=(pad0, pad1))) | |||
stride = 2 | |||
self.padding = 0 | |||
else: | |||
stride = 1 | |||
self.padding = kernel_size // 2 | |||
layers.append( | |||
EqualConv2d( | |||
in_channel, | |||
out_channel, | |||
kernel_size, | |||
padding=self.padding, | |||
stride=stride, | |||
bias=bias and not activate, | |||
)) | |||
if activate: | |||
if bias: | |||
layers.append(FusedLeakyReLU(out_channel)) | |||
else: | |||
layers.append(ScaledLeakyReLU(0.2)) | |||
super().__init__(*layers) | |||
class ResBlock(nn.Module): | |||
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): | |||
super().__init__() | |||
self.conv1 = ConvLayer(in_channel, in_channel, 3) | |||
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) | |||
self.skip = ConvLayer( | |||
in_channel, | |||
out_channel, | |||
1, | |||
downsample=True, | |||
activate=False, | |||
bias=False) | |||
def forward(self, input): | |||
out = self.conv1(input) | |||
out = self.conv2(out) | |||
skip = self.skip(input) | |||
out = (out + skip) / math.sqrt(2) | |||
return out | |||
class FullGenerator(nn.Module): | |||
def __init__( | |||
self, | |||
size, | |||
style_dim, | |||
n_mlp, | |||
channel_multiplier=2, | |||
blur_kernel=[1, 3, 3, 1], | |||
lr_mlp=0.01, | |||
isconcat=True, | |||
narrow=1, | |||
): | |||
super().__init__() | |||
channels = { | |||
4: int(512 * narrow), | |||
8: int(512 * narrow), | |||
16: int(512 * narrow), | |||
32: int(512 * narrow), | |||
64: int(256 * channel_multiplier * narrow), | |||
128: int(128 * channel_multiplier * narrow), | |||
256: int(64 * channel_multiplier * narrow), | |||
512: int(32 * channel_multiplier * narrow), | |||
1024: int(16 * channel_multiplier * narrow), | |||
2048: int(8 * channel_multiplier * narrow) | |||
} | |||
self.log_size = int(math.log(size, 2)) | |||
self.generator = Generator( | |||
size, | |||
style_dim, | |||
n_mlp, | |||
channel_multiplier=channel_multiplier, | |||
blur_kernel=blur_kernel, | |||
lr_mlp=lr_mlp, | |||
isconcat=isconcat, | |||
narrow=narrow) | |||
conv = [ConvLayer(3, channels[size], 1)] | |||
self.ecd0 = nn.Sequential(*conv) | |||
in_channel = channels[size] | |||
self.names = ['ecd%d' % i for i in range(self.log_size - 1)] | |||
for i in range(self.log_size, 2, -1): | |||
out_channel = channels[2**(i - 1)] | |||
# conv = [ResBlock(in_channel, out_channel, blur_kernel)] | |||
conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)] | |||
setattr(self, self.names[self.log_size - i + 1], | |||
nn.Sequential(*conv)) | |||
in_channel = out_channel | |||
self.final_linear = nn.Sequential( | |||
EqualLinear( | |||
channels[4] * 4 * 4, style_dim, activation='fused_lrelu')) | |||
def forward( | |||
self, | |||
inputs, | |||
return_latents=False, | |||
inject_index=None, | |||
truncation=1, | |||
truncation_latent=None, | |||
input_is_latent=False, | |||
): | |||
noise = [] | |||
for i in range(self.log_size - 1): | |||
ecd = getattr(self, self.names[i]) | |||
inputs = ecd(inputs) | |||
noise.append(inputs) | |||
inputs = inputs.view(inputs.shape[0], -1) | |||
outs = self.final_linear(inputs) | |||
noise = list( | |||
itertools.chain.from_iterable( | |||
itertools.repeat(x, 2) for x in noise))[::-1] | |||
outs = self.generator([outs], | |||
return_latents, | |||
inject_index, | |||
truncation, | |||
truncation_latent, | |||
input_is_latent, | |||
noise=noise[1:]) | |||
return outs | |||
class Discriminator(nn.Module): | |||
def __init__(self, | |||
size, | |||
channel_multiplier=2, | |||
blur_kernel=[1, 3, 3, 1], | |||
narrow=1): | |||
super().__init__() | |||
channels = { | |||
4: int(512 * narrow), | |||
8: int(512 * narrow), | |||
16: int(512 * narrow), | |||
32: int(512 * narrow), | |||
64: int(256 * channel_multiplier * narrow), | |||
128: int(128 * channel_multiplier * narrow), | |||
256: int(64 * channel_multiplier * narrow), | |||
512: int(32 * channel_multiplier * narrow), | |||
1024: int(16 * channel_multiplier * narrow), | |||
2048: int(8 * channel_multiplier * narrow) | |||
} | |||
convs = [ConvLayer(3, channels[size], 1)] | |||
log_size = int(math.log(size, 2)) | |||
in_channel = channels[size] | |||
for i in range(log_size, 2, -1): | |||
out_channel = channels[2**(i - 1)] | |||
convs.append(ResBlock(in_channel, out_channel, blur_kernel)) | |||
in_channel = out_channel | |||
self.convs = nn.Sequential(*convs) | |||
self.stddev_group = 4 | |||
self.stddev_feat = 1 | |||
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) | |||
self.final_linear = nn.Sequential( | |||
EqualLinear( | |||
channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), | |||
EqualLinear(channels[4], 1), | |||
) | |||
def forward(self, input): | |||
out = self.convs(input) | |||
batch, channel, height, width = out.shape | |||
group = min(batch, self.stddev_group) | |||
stddev = out.view(group, -1, self.stddev_feat, | |||
channel // self.stddev_feat, height, width) | |||
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) | |||
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) | |||
stddev = stddev.repeat(group, 1, height, width) | |||
out = torch.cat([out, stddev], 1) | |||
out = self.final_conv(out) | |||
out = out.view(batch, -1) | |||
out = self.final_linear(out) | |||
return out |
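Note for reviewers: a minimal smoke test of the generator/discriminator wiring above, with random weights. It assumes the fused CUDA ops imported from face_generation are built (they generally require a GPU); shapes only.

import torch

netG = FullGenerator(512, 512, 8, isconcat=True).cuda().eval()
netD = Discriminator(512).cuda()

lq = torch.randn(1, 3, 512, 512).cuda()  # degraded face, values in [-1, 1]
with torch.no_grad():
    restored, _ = netG(lq)               # (1, 3, 512, 512) restored face
    score = netD(restored)               # (1, 1) realism logit
print(restored.shape, score.shape)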
@@ -0,0 +1,205 @@ | |||
import math | |||
import os.path as osp | |||
from copy import deepcopy | |||
from typing import Any, Dict, List, Union | |||
import torch | |||
import torch.nn.functional as F | |||
from torch import autograd, nn | |||
from torch.nn.parallel import DataParallel, DistributedDataParallel | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import Tensor, TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .gpen import Discriminator, FullGenerator | |||
from .losses.losses import IDLoss, L1Loss | |||
logger = get_logger() | |||
__all__ = ['ImagePortraitEnhancement'] | |||
@MODELS.register_module( | |||
Tasks.image_portrait_enhancement, module_name=Models.gpen) | |||
class ImagePortraitEnhancement(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the face enhancement model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
self.size = 512 | |||
self.style_dim = 512 | |||
self.n_mlp = 8 | |||
self.mean_path_length = 0 | |||
self.accum = 0.5**(32 / (10 * 1000)) | |||
if torch.cuda.is_available(): | |||
self._device = torch.device('cuda') | |||
else: | |||
self._device = torch.device('cpu') | |||
self.l1_loss = L1Loss() | |||
self.id_loss = IDLoss(f'{model_dir}/arcface/model_ir_se50.pth', | |||
self._device) | |||
self.generator = FullGenerator( | |||
self.size, self.style_dim, self.n_mlp, | |||
isconcat=True).to(self._device) | |||
self.g_ema = FullGenerator( | |||
self.size, self.style_dim, self.n_mlp, | |||
isconcat=True).to(self._device) | |||
self.discriminator = Discriminator(self.size).to(self._device) | |||
if self.size == 512: | |||
self.load_pretrained(model_dir) | |||
def load_pretrained(self, model_dir): | |||
g_path = f'{model_dir}/{ModelFile.TORCH_MODEL_FILE}' | |||
g_dict = torch.load(g_path, map_location=torch.device('cpu')) | |||
self.generator.load_state_dict(g_dict) | |||
self.g_ema.load_state_dict(g_dict) | |||
d_path = f'{model_dir}/net_d.pt' | |||
d_dict = torch.load(d_path, map_location=torch.device('cpu')) | |||
self.discriminator.load_state_dict(d_dict) | |||
logger.info('load model done.') | |||
def accumulate(self): | |||
par1 = dict(self.g_ema.named_parameters()) | |||
par2 = dict(self.generator.named_parameters()) | |||
for k in par1.keys(): | |||
            par1[k].data.mul_(self.accum).add_( | |||
                par2[k].data, alpha=1 - self.accum) | |||
def requires_grad(self, model, flag=True): | |||
for p in model.parameters(): | |||
p.requires_grad = flag | |||
def d_logistic_loss(self, real_pred, fake_pred): | |||
real_loss = F.softplus(-real_pred) | |||
fake_loss = F.softplus(fake_pred) | |||
return real_loss.mean() + fake_loss.mean() | |||
def d_r1_loss(self, real_pred, real_img): | |||
grad_real, = autograd.grad( | |||
outputs=real_pred.sum(), inputs=real_img, create_graph=True) | |||
grad_penalty = grad_real.pow(2).view(grad_real.shape[0], | |||
-1).sum(1).mean() | |||
return grad_penalty | |||
def g_nonsaturating_loss(self, | |||
fake_pred, | |||
fake_img=None, | |||
real_img=None, | |||
input_img=None): | |||
loss = F.softplus(-fake_pred).mean() | |||
loss_l1 = self.l1_loss(fake_img, real_img) | |||
loss_id, __, __ = self.id_loss(fake_img, real_img, input_img) | |||
        loss_id = 0  # identity loss currently disabled; only the L1 term contributes | |||
loss += 1.0 * loss_l1 + 1.0 * loss_id | |||
return loss | |||
def g_path_regularize(self, | |||
fake_img, | |||
latents, | |||
mean_path_length, | |||
decay=0.01): | |||
noise = torch.randn_like(fake_img) / math.sqrt( | |||
fake_img.shape[2] * fake_img.shape[3]) | |||
grad, = autograd.grad( | |||
outputs=(fake_img * noise).sum(), | |||
inputs=latents, | |||
create_graph=True) | |||
path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) | |||
path_mean = mean_path_length + decay * ( | |||
path_lengths.mean() - mean_path_length) | |||
path_penalty = (path_lengths - path_mean).pow(2).mean() | |||
return path_penalty, path_mean.detach(), path_lengths | |||
@torch.no_grad() | |||
def _evaluate_postprocess(self, src: Tensor, | |||
target: Tensor) -> Dict[str, list]: | |||
preds, _ = self.generator(src) | |||
preds = list(torch.split(preds, 1, 0)) | |||
targets = list(torch.split(target, 1, 0)) | |||
preds = [((pred.data * 0.5 + 0.5) * 255.).squeeze(0).type( | |||
torch.uint8).permute(1, 2, 0).cpu().numpy() for pred in preds] | |||
targets = [((target.data * 0.5 + 0.5) * 255.).squeeze(0).type( | |||
torch.uint8).permute(1, 2, 0).cpu().numpy() for target in targets] | |||
return {'pred': preds, 'target': targets} | |||
def _train_forward_d(self, src: Tensor, target: Tensor) -> Tensor: | |||
self.requires_grad(self.generator, False) | |||
self.requires_grad(self.discriminator, True) | |||
preds, _ = self.generator(src) | |||
fake_pred = self.discriminator(preds) | |||
real_pred = self.discriminator(target) | |||
d_loss = self.d_logistic_loss(real_pred, fake_pred) | |||
return d_loss | |||
def _train_forward_d_r1(self, src: Tensor, target: Tensor) -> Tensor: | |||
src.requires_grad = True | |||
target.requires_grad = True | |||
real_pred = self.discriminator(target) | |||
r1_loss = self.d_r1_loss(real_pred, target) | |||
return r1_loss | |||
def _train_forward_g(self, src: Tensor, target: Tensor) -> Tensor: | |||
self.requires_grad(self.generator, True) | |||
self.requires_grad(self.discriminator, False) | |||
preds, _ = self.generator(src) | |||
fake_pred = self.discriminator(preds) | |||
g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, src) | |||
return g_loss | |||
def _train_forward_g_path(self, src: Tensor, target: Tensor) -> Tensor: | |||
fake_img, latents = self.generator(src, return_latents=True) | |||
path_loss, self.mean_path_length, path_lengths = self.g_path_regularize( | |||
fake_img, latents, self.mean_path_length) | |||
return path_loss | |||
@torch.no_grad() | |||
def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]: | |||
return {'outputs': (self.generator(src)[0] * 0.5 + 0.5).clamp(0, 1)} | |||
def forward(self, input: Dict[str, | |||
Tensor]) -> Dict[str, Union[list, Tensor]]: | |||
"""return the result by the model | |||
Args: | |||
input (Dict[str, Tensor]): the preprocessed data | |||
Returns: | |||
Dict[str, Union[list, Tensor]]: results | |||
""" | |||
        for key, value in input.items(): | |||
            input[key] = value.to(self._device) | |||
if 'target' in input: | |||
return self._evaluate_postprocess(**input) | |||
else: | |||
return self._inference_forward(**input) |
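Note for reviewers: a hedged sketch of the two forward paths above. The model directory is a placeholder and must contain the generator weights, net_d.pt, and the arcface checkpoint referenced in __init__.

import torch

model = ImagePortraitEnhancement('/path/to/gpen_model_dir')  # placeholder dir

src = torch.randn(1, 3, 512, 512)            # degraded face, normalized to [-1, 1]
out = model({'src': src})                    # inference branch
print(out['outputs'].shape)                  # (1, 3, 512, 512), clamped to [0, 1]

target = torch.randn(1, 3, 512, 512)
res = model({'src': src, 'target': target})  # evaluation branch
print(res['pred'][0].shape)                  # (512, 512, 3) uint8, ready for the metric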
@@ -0,0 +1,129 @@ | |||
from collections import namedtuple | |||
import torch | |||
from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, MaxPool2d, | |||
Module, PReLU, ReLU, Sequential, Sigmoid) | |||
class Flatten(Module): | |||
def forward(self, input): | |||
return input.view(input.size(0), -1) | |||
def l2_norm(input, axis=1): | |||
norm = torch.norm(input, 2, axis, True) | |||
output = torch.div(input, norm) | |||
return output | |||
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): | |||
""" A named tuple describing a ResNet block. """ | |||
def get_block(in_channel, depth, num_units, stride=2): | |||
return [Bottleneck(in_channel, depth, stride) | |||
] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] | |||
def get_blocks(num_layers): | |||
if num_layers == 50: | |||
blocks = [ | |||
get_block(in_channel=64, depth=64, num_units=3), | |||
get_block(in_channel=64, depth=128, num_units=4), | |||
get_block(in_channel=128, depth=256, num_units=14), | |||
get_block(in_channel=256, depth=512, num_units=3) | |||
] | |||
elif num_layers == 100: | |||
blocks = [ | |||
get_block(in_channel=64, depth=64, num_units=3), | |||
get_block(in_channel=64, depth=128, num_units=13), | |||
get_block(in_channel=128, depth=256, num_units=30), | |||
get_block(in_channel=256, depth=512, num_units=3) | |||
] | |||
elif num_layers == 152: | |||
blocks = [ | |||
get_block(in_channel=64, depth=64, num_units=3), | |||
get_block(in_channel=64, depth=128, num_units=8), | |||
get_block(in_channel=128, depth=256, num_units=36), | |||
get_block(in_channel=256, depth=512, num_units=3) | |||
] | |||
else: | |||
raise ValueError( | |||
'Invalid number of layers: {}. Must be one of [50, 100, 152]'. | |||
format(num_layers)) | |||
return blocks | |||
class SEModule(Module): | |||
def __init__(self, channels, reduction): | |||
super(SEModule, self).__init__() | |||
self.avg_pool = AdaptiveAvgPool2d(1) | |||
self.fc1 = Conv2d( | |||
channels, | |||
channels // reduction, | |||
kernel_size=1, | |||
padding=0, | |||
bias=False) | |||
self.relu = ReLU(inplace=True) | |||
self.fc2 = Conv2d( | |||
channels // reduction, | |||
channels, | |||
kernel_size=1, | |||
padding=0, | |||
bias=False) | |||
self.sigmoid = Sigmoid() | |||
def forward(self, x): | |||
module_input = x | |||
x = self.avg_pool(x) | |||
x = self.fc1(x) | |||
x = self.relu(x) | |||
x = self.fc2(x) | |||
x = self.sigmoid(x) | |||
return module_input * x | |||
class bottleneck_IR(Module): | |||
def __init__(self, in_channel, depth, stride): | |||
super(bottleneck_IR, self).__init__() | |||
if in_channel == depth: | |||
self.shortcut_layer = MaxPool2d(1, stride) | |||
else: | |||
self.shortcut_layer = Sequential( | |||
Conv2d(in_channel, depth, (1, 1), stride, bias=False), | |||
BatchNorm2d(depth)) | |||
self.res_layer = Sequential( | |||
BatchNorm2d(in_channel), | |||
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), | |||
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), | |||
BatchNorm2d(depth)) | |||
def forward(self, x): | |||
shortcut = self.shortcut_layer(x) | |||
res = self.res_layer(x) | |||
return res + shortcut | |||
class bottleneck_IR_SE(Module): | |||
def __init__(self, in_channel, depth, stride): | |||
super(bottleneck_IR_SE, self).__init__() | |||
if in_channel == depth: | |||
self.shortcut_layer = MaxPool2d(1, stride) | |||
else: | |||
self.shortcut_layer = Sequential( | |||
Conv2d(in_channel, depth, (1, 1), stride, bias=False), | |||
BatchNorm2d(depth)) | |||
self.res_layer = Sequential( | |||
BatchNorm2d(in_channel), | |||
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), | |||
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), | |||
BatchNorm2d(depth), SEModule(depth, 16)) | |||
def forward(self, x): | |||
shortcut = self.shortcut_layer(x) | |||
res = self.res_layer(x) | |||
return res + shortcut |
@@ -0,0 +1,90 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from .model_irse import Backbone | |||
class L1Loss(nn.Module): | |||
"""L1 (mean absolute error, MAE) loss. | |||
Args: | |||
loss_weight (float): Loss weight for L1 loss. Default: 1.0. | |||
reduction (str): Specifies the reduction to apply to the output. | |||
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. | |||
""" | |||
def __init__(self, loss_weight=1.0, reduction='mean'): | |||
super(L1Loss, self).__init__() | |||
        if reduction not in ['none', 'mean', 'sum']: | |||
            raise ValueError( | |||
                f'Unsupported reduction mode: {reduction}. ' | |||
                "Supported ones are: ['none', 'mean', 'sum']") | |||
self.loss_weight = loss_weight | |||
self.reduction = reduction | |||
def forward(self, pred, target, weight=None, **kwargs): | |||
""" | |||
Args: | |||
pred (Tensor): of shape (N, C, H, W). Predicted tensor. | |||
target (Tensor): of shape (N, C, H, W). Ground truth tensor. | |||
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. | |||
""" | |||
return self.loss_weight * F.l1_loss( | |||
pred, target, reduction=self.reduction) | |||
class IDLoss(nn.Module): | |||
def __init__(self, model_path, device='cuda', ckpt_dict=None): | |||
super(IDLoss, self).__init__() | |||
print('Loading ResNet ArcFace') | |||
self.facenet = Backbone( | |||
input_size=112, num_layers=50, drop_ratio=0.6, | |||
mode='ir_se').to(device) | |||
if ckpt_dict is None: | |||
self.facenet.load_state_dict( | |||
torch.load(model_path, map_location=torch.device('cpu'))) | |||
else: | |||
self.facenet.load_state_dict(ckpt_dict) | |||
self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) | |||
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) | |||
self.facenet.eval() | |||
def extract_feats(self, x): | |||
_, _, h, w = x.shape | |||
assert h == w | |||
if h != 256: | |||
x = self.pool(x) | |||
x = x[:, :, 35:-33, 32:-36] # crop roi | |||
x = self.face_pool(x) | |||
x_feats = self.facenet(x) | |||
return x_feats | |||
@torch.no_grad() | |||
def forward(self, y_hat, y, x): | |||
n_samples = x.shape[0] | |||
x_feats = self.extract_feats(x) | |||
        y_feats = self.extract_feats(y)  # features of the ground-truth (target) image | |||
y_hat_feats = self.extract_feats(y_hat) | |||
y_feats = y_feats.detach() | |||
loss = 0 | |||
sim_improvement = 0 | |||
id_logs = [] | |||
count = 0 | |||
for i in range(n_samples): | |||
diff_target = y_hat_feats[i].dot(y_feats[i]) | |||
diff_input = y_hat_feats[i].dot(x_feats[i]) | |||
diff_views = y_feats[i].dot(x_feats[i]) | |||
id_logs.append({ | |||
'diff_target': float(diff_target), | |||
'diff_input': float(diff_input), | |||
'diff_views': float(diff_views) | |||
}) | |||
loss += 1 - diff_target | |||
id_diff = float(diff_target) - float(diff_views) | |||
sim_improvement += id_diff | |||
count += 1 | |||
return loss / count, sim_improvement / count, id_logs |
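Note for reviewers: L1Loss is self-contained and can be sanity-checked directly; IDLoss needs a real ArcFace checkpoint, so that part is only sketched as comments with a placeholder path.

import torch

l1 = L1Loss(loss_weight=1.0, reduction='mean')
pred = torch.rand(2, 3, 512, 512)
target = torch.rand(2, 3, 512, 512)
print(l1(pred, target))  # scalar mean absolute error

# IDLoss usage (placeholder path, requires model_ir_se50.pth):
# id_loss = IDLoss('/path/to/model_ir_se50.pth', device='cuda')
# loss, sim_improvement, logs = id_loss(pred, target, pred)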
@@ -0,0 +1,92 @@ | |||
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, | |||
Module, PReLU, Sequential) | |||
from .helpers import (Flatten, bottleneck_IR, bottleneck_IR_SE, get_blocks, | |||
l2_norm) | |||
class Backbone(Module): | |||
def __init__(self, | |||
input_size, | |||
num_layers, | |||
mode='ir', | |||
drop_ratio=0.4, | |||
affine=True): | |||
super(Backbone, self).__init__() | |||
assert input_size in [112, 224], 'input_size should be 112 or 224' | |||
assert num_layers in [50, 100, | |||
152], 'num_layers should be 50, 100 or 152' | |||
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' | |||
blocks = get_blocks(num_layers) | |||
if mode == 'ir': | |||
unit_module = bottleneck_IR | |||
elif mode == 'ir_se': | |||
unit_module = bottleneck_IR_SE | |||
self.input_layer = Sequential( | |||
Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), | |||
PReLU(64)) | |||
if input_size == 112: | |||
self.output_layer = Sequential( | |||
BatchNorm2d(512), Dropout(drop_ratio), Flatten(), | |||
Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine)) | |||
else: | |||
self.output_layer = Sequential( | |||
BatchNorm2d(512), Dropout(drop_ratio), Flatten(), | |||
Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine)) | |||
modules = [] | |||
for block in blocks: | |||
for bottleneck in block: | |||
modules.append( | |||
unit_module(bottleneck.in_channel, bottleneck.depth, | |||
bottleneck.stride)) | |||
self.body = Sequential(*modules) | |||
def forward(self, x): | |||
x = self.input_layer(x) | |||
x = self.body(x) | |||
x = self.output_layer(x) | |||
return l2_norm(x) | |||
def IR_50(input_size): | |||
"""Constructs a ir-50 model.""" | |||
model = Backbone( | |||
input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) | |||
return model | |||
def IR_101(input_size): | |||
"""Constructs a ir-101 model.""" | |||
model = Backbone( | |||
input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) | |||
return model | |||
def IR_152(input_size): | |||
"""Constructs a ir-152 model.""" | |||
model = Backbone( | |||
input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) | |||
return model | |||
def IR_SE_50(input_size): | |||
"""Constructs a ir_se-50 model.""" | |||
model = Backbone( | |||
input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) | |||
return model | |||
def IR_SE_101(input_size): | |||
"""Constructs a ir_se-101 model.""" | |||
model = Backbone( | |||
input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) | |||
return model | |||
def IR_SE_152(input_size): | |||
"""Constructs a ir_se-152 model.""" | |||
model = Backbone( | |||
input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) | |||
return model |
@@ -0,0 +1,217 @@ | |||
import os | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
import torch.backends.cudnn as cudnn | |||
import torch.nn.functional as F | |||
from .models.retinaface import RetinaFace | |||
from .utils import PriorBox, decode, decode_landm, py_cpu_nms | |||
cfg_re50 = { | |||
'name': 'Resnet50', | |||
'min_sizes': [[16, 32], [64, 128], [256, 512]], | |||
'steps': [8, 16, 32], | |||
'variance': [0.1, 0.2], | |||
'clip': False, | |||
'pretrain': False, | |||
'return_layers': { | |||
'layer2': 1, | |||
'layer3': 2, | |||
'layer4': 3 | |||
}, | |||
'in_channel': 256, | |||
'out_channel': 256 | |||
} | |||
class RetinaFaceDetection(object): | |||
def __init__(self, model_path, device='cuda'): | |||
torch.set_grad_enabled(False) | |||
cudnn.benchmark = True | |||
self.model_path = model_path | |||
self.device = device | |||
self.cfg = cfg_re50 | |||
self.net = RetinaFace(cfg=self.cfg) | |||
self.load_model() | |||
self.net = self.net.to(device) | |||
self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device) | |||
def check_keys(self, pretrained_state_dict): | |||
ckpt_keys = set(pretrained_state_dict.keys()) | |||
model_keys = set(self.net.state_dict().keys()) | |||
used_pretrained_keys = model_keys & ckpt_keys | |||
assert len( | |||
used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' | |||
return True | |||
def remove_prefix(self, state_dict, prefix): | |||
new_state_dict = dict() | |||
# remove unnecessary 'module.' | |||
for k, v in state_dict.items(): | |||
if k.startswith(prefix): | |||
new_state_dict[k[len(prefix):]] = v | |||
else: | |||
new_state_dict[k] = v | |||
return new_state_dict | |||
def load_model(self, load_to_cpu=False): | |||
pretrained_dict = torch.load( | |||
self.model_path, map_location=torch.device('cpu')) | |||
if 'state_dict' in pretrained_dict.keys(): | |||
pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], | |||
'module.') | |||
else: | |||
pretrained_dict = self.remove_prefix(pretrained_dict, 'module.') | |||
self.check_keys(pretrained_dict) | |||
self.net.load_state_dict(pretrained_dict, strict=False) | |||
self.net.eval() | |||
def detect(self, | |||
img_raw, | |||
resize=1, | |||
confidence_threshold=0.9, | |||
nms_threshold=0.4, | |||
top_k=5000, | |||
keep_top_k=750, | |||
save_image=False): | |||
img = np.float32(img_raw) | |||
im_height, im_width = img.shape[:2] | |||
ss = 1.0 | |||
# downscale very large inputs: if the longer side exceeds 1500 px, resize it to 1000 px before detection | |||
if max(im_height, im_width) > 1500: | |||
ss = 1000.0 / max(im_height, im_width) | |||
img = cv2.resize(img, (0, 0), fx=ss, fy=ss) | |||
im_height, im_width = img.shape[:2] | |||
scale = torch.Tensor( | |||
[img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) | |||
img -= (104, 117, 123) | |||
img = img.transpose(2, 0, 1) | |||
img = torch.from_numpy(img).unsqueeze(0) | |||
img = img.to(self.device) | |||
scale = scale.to(self.device) | |||
loc, conf, landms = self.net(img) # forward pass | |||
del img | |||
priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) | |||
priors = priorbox.forward() | |||
priors = priors.to(self.device) | |||
prior_data = priors.data | |||
boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) | |||
boxes = boxes * scale / resize | |||
boxes = boxes.cpu().numpy() | |||
scores = conf.squeeze(0).data.cpu().numpy()[:, 1] | |||
landms = decode_landm( | |||
landms.data.squeeze(0), prior_data, self.cfg['variance']) | |||
scale1 = torch.Tensor([ | |||
im_width, im_height, im_width, im_height, im_width, im_height, | |||
im_width, im_height, im_width, im_height | |||
]) | |||
scale1 = scale1.to(self.device) | |||
landms = landms * scale1 / resize | |||
landms = landms.cpu().numpy() | |||
# ignore low scores | |||
inds = np.where(scores > confidence_threshold)[0] | |||
boxes = boxes[inds] | |||
landms = landms[inds] | |||
scores = scores[inds] | |||
# keep top-K before NMS | |||
order = scores.argsort()[::-1][:top_k] | |||
boxes = boxes[order] | |||
landms = landms[order] | |||
scores = scores[order] | |||
# do NMS | |||
dets = np.hstack((boxes, scores[:, np.newaxis])).astype( | |||
np.float32, copy=False) | |||
keep = py_cpu_nms(dets, nms_threshold) | |||
dets = dets[keep, :] | |||
landms = landms[keep] | |||
# keep top-K after NMS | |||
dets = dets[:keep_top_k, :] | |||
landms = landms[:keep_top_k, :] | |||
landms = landms.reshape((-1, 5, 2)) | |||
landms = landms.transpose((0, 2, 1)) | |||
landms = landms.reshape(-1, 10) | |||
return dets / ss, landms / ss | |||
def detect_tensor(self, | |||
img, | |||
resize=1, | |||
confidence_threshold=0.9, | |||
nms_threshold=0.4, | |||
top_k=5000, | |||
keep_top_k=750, | |||
save_image=False): | |||
im_height, im_width = img.shape[-2:] | |||
ss = 1000 / max(im_height, im_width) | |||
img = F.interpolate(img, scale_factor=ss) | |||
im_height, im_width = img.shape[-2:] | |||
scale = torch.Tensor([im_width, im_height, im_width, | |||
im_height]).to(self.device) | |||
img -= self.mean | |||
loc, conf, landms = self.net(img) # forward pass | |||
priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) | |||
priors = priorbox.forward() | |||
priors = priors.to(self.device) | |||
prior_data = priors.data | |||
boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) | |||
boxes = boxes * scale / resize | |||
boxes = boxes.cpu().numpy() | |||
scores = conf.squeeze(0).data.cpu().numpy()[:, 1] | |||
landms = decode_landm( | |||
landms.data.squeeze(0), prior_data, self.cfg['variance']) | |||
scale1 = torch.Tensor([ | |||
img.shape[3], img.shape[2], img.shape[3], img.shape[2], | |||
img.shape[3], img.shape[2], img.shape[3], img.shape[2], | |||
img.shape[3], img.shape[2] | |||
]) | |||
scale1 = scale1.to(self.device) | |||
landms = landms * scale1 / resize | |||
landms = landms.cpu().numpy() | |||
# ignore low scores | |||
inds = np.where(scores > confidence_threshold)[0] | |||
boxes = boxes[inds] | |||
landms = landms[inds] | |||
scores = scores[inds] | |||
# keep top-K before NMS | |||
order = scores.argsort()[::-1][:top_k] | |||
boxes = boxes[order] | |||
landms = landms[order] | |||
scores = scores[order] | |||
# do NMS | |||
dets = np.hstack((boxes, scores[:, np.newaxis])).astype( | |||
np.float32, copy=False) | |||
keep = py_cpu_nms(dets, nms_threshold) | |||
dets = dets[keep, :] | |||
landms = landms[keep] | |||
# keep top-K after NMS | |||
dets = dets[:keep_top_k, :] | |||
landms = landms[:keep_top_k, :] | |||
landms = landms.reshape((-1, 5, 2)) | |||
landms = landms.transpose((0, 2, 1)) | |||
landms = landms.reshape(-1, 10) | |||
return dets / ss, landms / ss |
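# Illustrative only: driving the detector directly. The checkpoint path and image
# file below are placeholders, not files shipped with this change.
import cv2

detector = RetinaFaceDetection('RetinaFace-R50.pth', device='cuda')
img_bgr = cv2.imread('portrait.jpg')      # OpenCV loads BGR, matching the (104, 117, 123) mean
dets, landms = detector.detect(img_bgr)   # dets: (N, 5) = x1, y1, x2, y2, score; landms: (N, 10)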
@@ -0,0 +1,148 @@ | |||
import time | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torchvision.models as models | |||
import torchvision.models._utils as _utils | |||
from torch.autograd import Variable | |||
def conv_bn(inp, oup, stride=1, leaky=0): | |||
return nn.Sequential( | |||
nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), | |||
nn.LeakyReLU(negative_slope=leaky, inplace=True)) | |||
def conv_bn_no_relu(inp, oup, stride): | |||
return nn.Sequential( | |||
nn.Conv2d(inp, oup, 3, stride, 1, bias=False), | |||
nn.BatchNorm2d(oup), | |||
) | |||
def conv_bn1X1(inp, oup, stride, leaky=0): | |||
return nn.Sequential( | |||
nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), | |||
nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True)) | |||
def conv_dw(inp, oup, stride, leaky=0.1): | |||
return nn.Sequential( | |||
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), | |||
nn.BatchNorm2d(inp), | |||
nn.LeakyReLU(negative_slope=leaky, inplace=True), | |||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False), | |||
nn.BatchNorm2d(oup), | |||
nn.LeakyReLU(negative_slope=leaky, inplace=True), | |||
) | |||
class SSH(nn.Module): | |||
def __init__(self, in_channel, out_channel): | |||
super(SSH, self).__init__() | |||
assert out_channel % 4 == 0 | |||
leaky = 0 | |||
if (out_channel <= 64): | |||
leaky = 0.1 | |||
self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) | |||
self.conv5X5_1 = conv_bn( | |||
in_channel, out_channel // 4, stride=1, leaky=leaky) | |||
self.conv5X5_2 = conv_bn_no_relu( | |||
out_channel // 4, out_channel // 4, stride=1) | |||
self.conv7X7_2 = conv_bn( | |||
out_channel // 4, out_channel // 4, stride=1, leaky=leaky) | |||
self.conv7x7_3 = conv_bn_no_relu( | |||
out_channel // 4, out_channel // 4, stride=1) | |||
def forward(self, input): | |||
conv3X3 = self.conv3X3(input) | |||
conv5X5_1 = self.conv5X5_1(input) | |||
conv5X5 = self.conv5X5_2(conv5X5_1) | |||
conv7X7_2 = self.conv7X7_2(conv5X5_1) | |||
conv7X7 = self.conv7x7_3(conv7X7_2) | |||
out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) | |||
out = F.relu(out) | |||
return out | |||
class FPN(nn.Module): | |||
def __init__(self, in_channels_list, out_channels): | |||
super(FPN, self).__init__() | |||
leaky = 0 | |||
if (out_channels <= 64): | |||
leaky = 0.1 | |||
self.output1 = conv_bn1X1( | |||
in_channels_list[0], out_channels, stride=1, leaky=leaky) | |||
self.output2 = conv_bn1X1( | |||
in_channels_list[1], out_channels, stride=1, leaky=leaky) | |||
self.output3 = conv_bn1X1( | |||
in_channels_list[2], out_channels, stride=1, leaky=leaky) | |||
self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) | |||
self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) | |||
def forward(self, input): | |||
# names = list(input.keys()) | |||
input = list(input.values()) | |||
output1 = self.output1(input[0]) | |||
output2 = self.output2(input[1]) | |||
output3 = self.output3(input[2]) | |||
up3 = F.interpolate( | |||
output3, size=[output2.size(2), output2.size(3)], mode='nearest') | |||
output2 = output2 + up3 | |||
output2 = self.merge2(output2) | |||
up2 = F.interpolate( | |||
output2, size=[output1.size(2), output1.size(3)], mode='nearest') | |||
output1 = output1 + up2 | |||
output1 = self.merge1(output1) | |||
out = [output1, output2, output3] | |||
return out | |||
class MobileNetV1(nn.Module): | |||
def __init__(self): | |||
super(MobileNetV1, self).__init__() | |||
self.stage1 = nn.Sequential( | |||
conv_bn(3, 8, 2, leaky=0.1), # 3 | |||
conv_dw(8, 16, 1), # 7 | |||
conv_dw(16, 32, 2), # 11 | |||
conv_dw(32, 32, 1), # 19 | |||
conv_dw(32, 64, 2), # 27 | |||
conv_dw(64, 64, 1), # 43 | |||
) | |||
self.stage2 = nn.Sequential( | |||
conv_dw(64, 128, 2), # 43 + 16 = 59 | |||
conv_dw(128, 128, 1), # 59 + 32 = 91 | |||
conv_dw(128, 128, 1), # 91 + 32 = 123 | |||
conv_dw(128, 128, 1), # 123 + 32 = 155 | |||
conv_dw(128, 128, 1), # 155 + 32 = 187 | |||
conv_dw(128, 128, 1), # 187 + 32 = 219 | |||
) | |||
self.stage3 = nn.Sequential( | |||
conv_dw(128, 256, 2), # 219 + 32 = 251 | |||
conv_dw(256, 256, 1), # 251 + 64 = 315 | |||
) | |||
self.avg = nn.AdaptiveAvgPool2d((1, 1)) | |||
self.fc = nn.Linear(256, 1000) | |||
def forward(self, x): | |||
x = self.stage1(x) | |||
x = self.stage2(x) | |||
x = self.stage3(x) | |||
x = self.avg(x) | |||
x = x.view(-1, 256) | |||
x = self.fc(x) | |||
return x |
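# Illustrative shape check for the FPN + SSH blocks, using the channel sizes implied
# by cfg_re50 (in_channel=256, i.e. ResNet-50 stages of 512/1024/2048 channels).
import torch

fpn = FPN([512, 1024, 2048], 256)
ssh = SSH(256, 256)
feats = {'layer2': torch.randn(1, 512, 80, 80),
         'layer3': torch.randn(1, 1024, 40, 40),
         'layer4': torch.randn(1, 2048, 20, 20)}
p2, p3, p4 = fpn(feats)            # three 256-channel maps at strides 8/16/32
print(p2.shape, ssh(p2).shape)     # SSH preserves both channels and spatial size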
@@ -0,0 +1,144 @@ | |||
from collections import OrderedDict | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torchvision.models as models | |||
import torchvision.models._utils as _utils | |||
import torchvision.models.detection.backbone_utils as backbone_utils | |||
from .net import FPN, SSH, MobileNetV1 | |||
class ClassHead(nn.Module): | |||
def __init__(self, inchannels=512, num_anchors=3): | |||
super(ClassHead, self).__init__() | |||
self.num_anchors = num_anchors | |||
self.conv1x1 = nn.Conv2d( | |||
inchannels, | |||
self.num_anchors * 2, | |||
kernel_size=(1, 1), | |||
stride=1, | |||
padding=0) | |||
def forward(self, x): | |||
out = self.conv1x1(x) | |||
out = out.permute(0, 2, 3, 1).contiguous() | |||
return out.view(out.shape[0], -1, 2) | |||
class BboxHead(nn.Module): | |||
def __init__(self, inchannels=512, num_anchors=3): | |||
super(BboxHead, self).__init__() | |||
self.conv1x1 = nn.Conv2d( | |||
inchannels, | |||
num_anchors * 4, | |||
kernel_size=(1, 1), | |||
stride=1, | |||
padding=0) | |||
def forward(self, x): | |||
out = self.conv1x1(x) | |||
out = out.permute(0, 2, 3, 1).contiguous() | |||
return out.view(out.shape[0], -1, 4) | |||
class LandmarkHead(nn.Module): | |||
def __init__(self, inchannels=512, num_anchors=3): | |||
super(LandmarkHead, self).__init__() | |||
self.conv1x1 = nn.Conv2d( | |||
inchannels, | |||
num_anchors * 10, | |||
kernel_size=(1, 1), | |||
stride=1, | |||
padding=0) | |||
def forward(self, x): | |||
out = self.conv1x1(x) | |||
out = out.permute(0, 2, 3, 1).contiguous() | |||
return out.view(out.shape[0], -1, 10) | |||
class RetinaFace(nn.Module): | |||
def __init__(self, cfg=None): | |||
""" | |||
:param cfg: Network related settings. | |||
""" | |||
super(RetinaFace, self).__init__() | |||
backbone = None | |||
if cfg['name'] == 'Resnet50': | |||
backbone = models.resnet50(pretrained=cfg['pretrain']) | |||
else: | |||
raise Exception('Invalid name') | |||
self.body = _utils.IntermediateLayerGetter(backbone, | |||
cfg['return_layers']) | |||
in_channels_stage2 = cfg['in_channel'] | |||
in_channels_list = [ | |||
in_channels_stage2 * 2, | |||
in_channels_stage2 * 4, | |||
in_channels_stage2 * 8, | |||
] | |||
out_channels = cfg['out_channel'] | |||
self.fpn = FPN(in_channels_list, out_channels) | |||
self.ssh1 = SSH(out_channels, out_channels) | |||
self.ssh2 = SSH(out_channels, out_channels) | |||
self.ssh3 = SSH(out_channels, out_channels) | |||
self.ClassHead = self._make_class_head( | |||
fpn_num=3, inchannels=cfg['out_channel']) | |||
self.BboxHead = self._make_bbox_head( | |||
fpn_num=3, inchannels=cfg['out_channel']) | |||
self.LandmarkHead = self._make_landmark_head( | |||
fpn_num=3, inchannels=cfg['out_channel']) | |||
def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2): | |||
classhead = nn.ModuleList() | |||
for i in range(fpn_num): | |||
classhead.append(ClassHead(inchannels, anchor_num)) | |||
return classhead | |||
def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2): | |||
bboxhead = nn.ModuleList() | |||
for i in range(fpn_num): | |||
bboxhead.append(BboxHead(inchannels, anchor_num)) | |||
return bboxhead | |||
def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2): | |||
landmarkhead = nn.ModuleList() | |||
for i in range(fpn_num): | |||
landmarkhead.append(LandmarkHead(inchannels, anchor_num)) | |||
return landmarkhead | |||
def forward(self, inputs): | |||
out = self.body(inputs) | |||
# FPN | |||
fpn = self.fpn(out) | |||
# SSH | |||
feature1 = self.ssh1(fpn[0]) | |||
feature2 = self.ssh2(fpn[1]) | |||
feature3 = self.ssh3(fpn[2]) | |||
features = [feature1, feature2, feature3] | |||
bbox_regressions = torch.cat( | |||
[self.BboxHead[i](feature) for i, feature in enumerate(features)], | |||
dim=1) | |||
classifications = torch.cat( | |||
[self.ClassHead[i](feature) for i, feature in enumerate(features)], | |||
dim=1) | |||
ldm_regressions = torch.cat( | |||
[self.LandmarkHead[i](feat) for i, feat in enumerate(features)], | |||
dim=1) | |||
output = (bbox_regressions, F.softmax(classifications, | |||
dim=-1), ldm_regressions) | |||
return output |
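# Illustrative forward pass, reusing the cfg_re50 dict from the detection module in
# this PR; pretrain=False there avoids downloading torchvision ResNet-50 weights.
import torch

net = RetinaFace(cfg=cfg_re50).eval()
with torch.no_grad():
    boxes, scores, landms = net(torch.randn(1, 3, 640, 640))
print(boxes.shape, scores.shape, landms.shape)   # (1, N, 4), (1, N, 2), (1, N, 10)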
@@ -0,0 +1,123 @@ | |||
# -------------------------------------------------------- | |||
# Modified from https://github.com/biubug6/Pytorch_Retinaface | |||
# -------------------------------------------------------- | |||
from itertools import product as product | |||
from math import ceil | |||
import numpy as np | |||
import torch | |||
class PriorBox(object): | |||
def __init__(self, cfg, image_size=None, phase='train'): | |||
super(PriorBox, self).__init__() | |||
self.min_sizes = cfg['min_sizes'] | |||
self.steps = cfg['steps'] | |||
self.clip = cfg['clip'] | |||
self.image_size = image_size | |||
self.feature_maps = [[ | |||
ceil(self.image_size[0] / step), | |||
ceil(self.image_size[1] / step) | |||
] for step in self.steps] | |||
self.name = 's' | |||
def forward(self): | |||
anchors = [] | |||
for k, f in enumerate(self.feature_maps): | |||
min_sizes = self.min_sizes[k] | |||
for i, j in product(range(f[0]), range(f[1])): | |||
for min_size in min_sizes: | |||
s_kx = min_size / self.image_size[1] | |||
s_ky = min_size / self.image_size[0] | |||
dense_cx = [ | |||
x * self.steps[k] / self.image_size[1] | |||
for x in [j + 0.5] | |||
] | |||
dense_cy = [ | |||
y * self.steps[k] / self.image_size[0] | |||
for y in [i + 0.5] | |||
] | |||
for cy, cx in product(dense_cy, dense_cx): | |||
anchors += [cx, cy, s_kx, s_ky] | |||
# back to torch land | |||
output = torch.Tensor(anchors).view(-1, 4) | |||
if self.clip: | |||
output.clamp_(max=1, min=0) | |||
return output | |||
def py_cpu_nms(dets, thresh): | |||
"""Pure Python NMS baseline.""" | |||
x1 = dets[:, 0] | |||
y1 = dets[:, 1] | |||
x2 = dets[:, 2] | |||
y2 = dets[:, 3] | |||
scores = dets[:, 4] | |||
areas = (x2 - x1 + 1) * (y2 - y1 + 1) | |||
order = scores.argsort()[::-1] | |||
keep = [] | |||
while order.size > 0: | |||
i = order[0] | |||
keep.append(i) | |||
xx1 = np.maximum(x1[i], x1[order[1:]]) | |||
yy1 = np.maximum(y1[i], y1[order[1:]]) | |||
xx2 = np.minimum(x2[i], x2[order[1:]]) | |||
yy2 = np.minimum(y2[i], y2[order[1:]]) | |||
w = np.maximum(0.0, xx2 - xx1 + 1) | |||
h = np.maximum(0.0, yy2 - yy1 + 1) | |||
inter = w * h | |||
ovr = inter / (areas[i] + areas[order[1:]] - inter) | |||
inds = np.where(ovr <= thresh)[0] | |||
order = order[inds + 1] | |||
return keep | |||
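# Tiny worked example: two heavily overlapping boxes and one separate box; at an
# IoU threshold of 0.4 only the higher-scoring box of the overlapping pair survives.
import numpy as np

dets = np.array([[10, 10, 60, 60, 0.95],
                 [12, 12, 62, 62, 0.90],
                 [100, 100, 150, 150, 0.80]], dtype=np.float32)
print(py_cpu_nms(dets, 0.4))   # -> [0, 2]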
# Adapted from https://github.com/Hakuyume/chainer-ssd | |||
def decode(loc, priors, variances): | |||
"""Decode locations from predictions using priors to undo | |||
the encoding we did for offset regression at train time. | |||
Args: | |||
loc (tensor): location predictions for loc layers, | |||
Shape: [num_priors,4] | |||
priors (tensor): Prior boxes in center-offset form. | |||
Shape: [num_priors,4]. | |||
variances: (list[float]) Variances of priorboxes | |||
Return: | |||
decoded bounding box predictions | |||
""" | |||
boxes = torch.cat( | |||
(priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], | |||
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) | |||
boxes[:, :2] -= boxes[:, 2:] / 2 | |||
boxes[:, 2:] += boxes[:, :2] | |||
return boxes | |||
def decode_landm(pre, priors, variances): | |||
"""Decode landm from predictions using priors to undo | |||
the encoding we did for offset regression at train time. | |||
Args: | |||
pre (tensor): landm predictions for loc layers, | |||
Shape: [num_priors,10] | |||
priors (tensor): Prior boxes in center-offset form. | |||
Shape: [num_priors,4]. | |||
variances: (list[float]) Variances of priorboxes | |||
Return: | |||
decoded landm predictions | |||
""" | |||
a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:] | |||
b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:] | |||
c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:] | |||
d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:] | |||
e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:] | |||
landms = torch.cat((a, b, c, d, e), dim=1) | |||
return landms |
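# Illustrative decode round-trip: a single prior at the image centre with zero
# regression offsets decodes back to a box centred on that prior.
import torch

priors = torch.tensor([[0.5, 0.5, 0.2, 0.2]])   # (cx, cy, w, h), normalized coordinates
loc = torch.zeros(1, 4)                          # zero offsets
print(decode(loc, priors, [0.1, 0.2]))           # -> [[0.4, 0.4, 0.6, 0.6]] as (x1, y1, x2, y2)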
@@ -137,6 +137,7 @@ TASK_OUTPUTS = { | |||
Tasks.image_colorization: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_color_enhancement: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_denoising: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG], | |||
# image generation task result for a single image | |||
# {"output_img": np.array with shape (h, w, 3)} | |||
@@ -110,6 +110,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
'damo/cv_gan_face-image-generation'), | |||
Tasks.image_super_resolution: (Pipelines.image_super_resolution, | |||
'damo/cv_rrdb_image-super-resolution'), | |||
Tasks.image_portrait_enhancement: | |||
(Pipelines.image_portrait_enhancement, | |||
'damo/cv_gpen_image-portrait-enhancement'), | |||
Tasks.product_retrieval_embedding: | |||
(Pipelines.product_retrieval_embedding, | |||
'damo/cv_resnet50_product-bag-embedding-models'), | |||
@@ -11,14 +11,15 @@ if TYPE_CHECKING: | |||
from .face_detection_pipeline import FaceDetectionPipeline | |||
from .face_recognition_pipeline import FaceRecognitionPipeline | |||
from .face_image_generation_pipeline import FaceImageGenerationPipeline | |||
from .image_classification_pipeline import ImageClassificationPipeline | |||
from .image_cartoon_pipeline import ImageCartoonPipeline | |||
from .image_classification_pipeline import GeneralImageClassificationPipeline | |||
from .image_denoise_pipeline import ImageDenoisePipeline | |||
from .image_color_enhance_pipeline import ImageColorEnhancePipeline | |||
from .image_colorization_pipeline import ImageColorizationPipeline | |||
from .image_classification_pipeline import ImageClassificationPipeline | |||
from .image_denoise_pipeline import ImageDenoisePipeline | |||
from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | |||
from .image_matting_pipeline import ImageMattingPipeline | |||
from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline | |||
from .image_style_transfer_pipeline import ImageStyleTransferPipeline | |||
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | |||
from .image_to_image_generate_pipeline import Image2ImageGenerationePipeline | |||
@@ -46,6 +47,8 @@ else: | |||
'image_instance_segmentation_pipeline': | |||
['ImageInstanceSegmentationPipeline'], | |||
'image_matting_pipeline': ['ImageMattingPipeline'], | |||
'image_portrait_enhancement_pipeline': | |||
['ImagePortraitEnhancementPipeline'], | |||
'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'], | |||
'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], | |||
'image_to_image_translation_pipeline': | |||
@@ -0,0 +1,216 @@ | |||
import math | |||
from typing import Any, Dict | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import torch | |||
from scipy.ndimage import gaussian_filter | |||
from scipy.spatial.distance import pdist, squareform | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.cv.image_portrait_enhancement import gpen | |||
from modelscope.models.cv.image_portrait_enhancement.align_faces import ( | |||
get_reference_facial_points, warp_and_crop_face) | |||
from modelscope.models.cv.image_portrait_enhancement.eqface import fqa | |||
from modelscope.models.cv.image_portrait_enhancement.retinaface import \ | |||
detection | |||
from modelscope.models.cv.super_resolution import rrdbnet_arch | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import LoadImage, load_image | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.image_portrait_enhancement, | |||
module_name=Pipelines.image_portrait_enhancement) | |||
class ImagePortraitEnhancementPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` to create an image portrait enhancement pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__(model=model, **kwargs) | |||
if torch.cuda.is_available(): | |||
self.device = torch.device('cuda') | |||
else: | |||
self.device = torch.device('cpu') | |||
self.use_sr = True | |||
self.size = 512 | |||
self.n_mlp = 8 | |||
self.channel_multiplier = 2 | |||
self.narrow = 1 | |||
self.face_enhancer = gpen.FullGenerator( | |||
self.size, | |||
512, | |||
self.n_mlp, | |||
self.channel_multiplier, | |||
narrow=self.narrow).to(self.device) | |||
gpen_model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}' | |||
self.face_enhancer.load_state_dict( | |||
torch.load(gpen_model_path), strict=True) | |||
logger.info('load face enhancer model done') | |||
self.threshold = 0.9 | |||
detector_model_path = f'{model}/face_detection/RetinaFace-R50.pth' | |||
self.face_detector = detection.RetinaFaceDetection( | |||
detector_model_path, self.device) | |||
logger.info('load face detector model done') | |||
self.num_feat = 32 | |||
self.num_block = 23 | |||
self.scale = 2 | |||
self.sr_model = rrdbnet_arch.RRDBNet( | |||
num_in_ch=3, | |||
num_out_ch=3, | |||
num_feat=self.num_feat, | |||
num_block=self.num_block, | |||
num_grow_ch=32, | |||
scale=self.scale).to(self.device) | |||
sr_model_path = f'{model}/super_resolution/realesrnet_x{self.scale}.pth' | |||
self.sr_model.load_state_dict( | |||
torch.load(sr_model_path)['params_ema'], strict=True) | |||
logger.info('load sr model done') | |||
self.fqa_thres = 0.1 | |||
self.id_thres = 0.15 | |||
self.alpha = 1.0 | |||
backbone_model_path = f'{model}/face_quality/eqface_backbone.pth' | |||
fqa_model_path = f'{model}/face_quality/eqface_quality.pth' | |||
self.eqface = fqa.FQA(backbone_model_path, fqa_model_path, self.device) | |||
logger.info('load fqa model done') | |||
# the mask for pasting restored faces back | |||
self.mask = np.zeros((512, 512, 3), np.float32) | |||
cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, | |||
cv2.LINE_AA) | |||
self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4) | |||
self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4) | |||
def enhance_face(self, img): | |||
img = cv2.resize(img, (self.size, self.size)) | |||
img_t = self.img2tensor(img) | |||
self.face_enhancer.eval() | |||
with torch.no_grad(): | |||
out, __ = self.face_enhancer(img_t) | |||
del img_t | |||
out = self.tensor2img(out) | |||
return out | |||
def img2tensor(self, img, is_norm=True): | |||
img_t = torch.from_numpy(img).to(self.device) / 255. | |||
if is_norm: | |||
img_t = (img_t - 0.5) / 0.5 | |||
img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB | |||
return img_t | |||
def tensor2img(self, img_t, pmax=255.0, is_denorm=True, imtype=np.uint8): | |||
if is_denorm: | |||
img_t = img_t * 0.5 + 0.5 | |||
img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR | |||
img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax | |||
return img_np.astype(imtype) | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
img = LoadImage.convert_to_ndarray(input) | |||
img_sr = None | |||
if self.use_sr: | |||
self.sr_model.eval() | |||
with torch.no_grad(): | |||
img_t = self.img2tensor(img, is_norm=False) | |||
img_out = self.sr_model(img_t) | |||
img_sr = img_out.squeeze(0).permute(1, 2, 0).flip(2).cpu().clamp_( | |||
0, 1).numpy() | |||
img_sr = (img_sr * 255.0).round().astype(np.uint8) | |||
img = cv2.resize(img, img_sr.shape[:2][::-1]) | |||
result = {'img': img, 'img_sr': img_sr} | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
img, img_sr = input['img'], input['img_sr'] | |||
img, img_sr = img.cpu().numpy(), img_sr.cpu().numpy() | |||
facebs, landms = self.face_detector.detect(img) | |||
height, width = img.shape[:2] | |||
full_mask = np.zeros(img.shape, dtype=np.float32) | |||
full_img = np.zeros(img.shape, dtype=np.uint8) | |||
for i, (faceb, facial5points) in enumerate(zip(facebs, landms)): | |||
if faceb[4] < self.threshold: | |||
continue | |||
# fh, fw = (faceb[3] - faceb[1]), (faceb[2] - faceb[0]) | |||
facial5points = np.reshape(facial5points, (2, 5)) | |||
of, of_112, tfm_inv = warp_and_crop_face( | |||
img, facial5points, crop_size=(self.size, self.size)) | |||
# detect orig face quality | |||
fq_o, fea_o = self.eqface.get_face_quality(of_112) | |||
if fq_o < self.fqa_thres: | |||
continue | |||
# enhance the face | |||
ef = self.enhance_face(of) | |||
# detect enhanced face quality | |||
ss = self.size // 256 | |||
ef_112 = cv2.resize(ef[35 * ss:-33 * ss, 32 * ss:-36 * ss], | |||
(112, 112)) # crop roi | |||
fq_e, fea_e = self.eqface.get_face_quality(ef_112) | |||
dist = squareform(pdist([fea_o, fea_e], 'cosine')).mean() | |||
if dist > self.id_thres: | |||
continue | |||
# blending parameter | |||
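# the weight shrinks as the identity distance between original and enhanced face | |||
# grows, and the sigmoid squashes the (thresholded) quality score into (0, 1) | |||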
fq = max(1., (fq_o - self.fqa_thres)) | |||
fq = (1 - 2 * dist) * (1.0 / (1 + math.exp(-(2 * fq - 1)))) | |||
# blend face | |||
ef = cv2.addWeighted(ef, fq * self.alpha, of, 1 - fq * self.alpha, | |||
0.0) | |||
tmp_mask = self.mask | |||
tmp_mask = cv2.resize(tmp_mask, ef.shape[:2]) | |||
tmp_mask = cv2.warpAffine( | |||
tmp_mask, tfm_inv, (width, height), flags=3) | |||
tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3) | |||
mask = np.clip(tmp_mask - full_mask, 0, 1) | |||
full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)] | |||
full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)] | |||
if self.use_sr and img_sr is not None: | |||
out_img = cv2.convertScaleAbs(img_sr * (1 - full_mask) | |||
+ full_img * full_mask) | |||
else: | |||
out_img = cv2.convertScaleAbs(img * (1 - full_mask) | |||
+ full_img * full_mask) | |||
return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)} | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
@@ -163,6 +163,32 @@ class ImageDenoisePreprocessor(Preprocessor): | |||
return data | |||
@PREPROCESSORS.register_module( | |||
Fields.cv, | |||
module_name=Preprocessors.image_portrait_enhancement_preprocessor) | |||
class ImagePortraitEnhancementPreprocessor(Preprocessor): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
""" | |||
Args: | |||
model_dir (str): model path | |||
""" | |||
super().__init__(*args, **kwargs) | |||
self.model_dir: str = model_dir | |||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
"""process the raw input data | |||
Args: | |||
data Dict[str, Any] | |||
Returns: | |||
Dict[str, Any]: the preprocessed data | |||
""" | |||
return data | |||
@PREPROCESSORS.register_module( | |||
Fields.cv, | |||
module_name=Preprocessors.image_instance_segmentation_preprocessor) | |||
@@ -1,6 +1,7 @@ | |||
from .base import DummyTrainer | |||
from .builder import build_trainer | |||
from .cv import ImageInstanceSegmentationTrainer | |||
from .cv import (ImageInstanceSegmentationTrainer, | |||
ImagePortraitEnhancementTrainer) | |||
from .multi_modal import CLIPTrainer | |||
from .nlp import SequenceClassificationTrainer | |||
from .trainer import EpochBasedTrainer |
@@ -1,2 +1,3 @@ | |||
from .image_instance_segmentation_trainer import \ | |||
ImageInstanceSegmentationTrainer | |||
from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer |
@@ -0,0 +1,148 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from collections.abc import Mapping | |||
import torch | |||
from torch import distributed as dist | |||
from modelscope.trainers.builder import TRAINERS | |||
from modelscope.trainers.optimizer.builder import build_optimizer | |||
from modelscope.trainers.trainer import EpochBasedTrainer | |||
from modelscope.utils.constant import ModeKeys | |||
from modelscope.utils.logger import get_logger | |||
@TRAINERS.register_module(module_name='gpen') | |||
class ImagePortraitEnhancementTrainer(EpochBasedTrainer): | |||
def train_step(self, model, inputs): | |||
""" Perform a training step on a batch of inputs. | |||
Subclass and override to inject custom behavior. | |||
Args: | |||
model (`TorchModel`): The model to train. | |||
inputs (`Dict[str, Union[torch.Tensor, Any]]`): | |||
The inputs and targets of the model. | |||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the | |||
argument `labels`. Check your model's documentation for all accepted arguments. | |||
Return: | |||
`torch.Tensor`: The tensor with training loss on this batch. | |||
""" | |||
# EvaluationHook runs evaluation and switches the mode to val; switch back to train mode here | |||
# TODO: find a cleaner way to switch the mode | |||
self.d_reg_every = self.cfg.train.get('d_reg_every', 16) | |||
self.g_reg_every = self.cfg.train.get('g_reg_every', 4) | |||
self.path_regularize = self.cfg.train.get('path_regularize', 2) | |||
self.r1 = self.cfg.train.get('r1', 10) | |||
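# StyleGAN2-style schedule: a discriminator step every iteration, lazy R1 | |||
# regularization every d_reg_every iterations, a generator step every iteration, | |||
# and lazy path-length regularization every g_reg_every iterations | |||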
train_outputs = dict() | |||
self._mode = ModeKeys.TRAIN | |||
inputs = self.collate_fn(inputs) | |||
# call model forward but not __call__ to skip postprocess | |||
if isinstance(inputs, Mapping): | |||
d_loss = model._train_forward_d(**inputs) | |||
else: | |||
d_loss = model._train_forward_d(inputs) | |||
train_outputs['d_loss'] = d_loss | |||
model.discriminator.zero_grad() | |||
d_loss.backward() | |||
self.optimizer_d.step() | |||
if self._iter % self.d_reg_every == 0: | |||
if isinstance(inputs, Mapping): | |||
r1_loss = model._train_forward_d_r1(**inputs) | |||
else: | |||
r1_loss = model._train_forward_d_r1(inputs) | |||
train_outputs['r1_loss'] = r1_loss | |||
model.discriminator.zero_grad() | |||
(self.r1 / 2 * r1_loss * self.d_reg_every).backward() | |||
self.optimizer_d.step() | |||
if isinstance(inputs, Mapping): | |||
g_loss = model._train_forward_g(**inputs) | |||
else: | |||
g_loss = model._train_forward_g(inputs) | |||
train_outputs['g_loss'] = g_loss | |||
model.generator.zero_grad() | |||
g_loss.backward() | |||
self.optimizer.step() | |||
path_loss = 0 | |||
if self._iter % self.g_reg_every == 0: | |||
if isinstance(inputs, Mapping): | |||
path_loss = model._train_forward_g_path(**inputs) | |||
else: | |||
path_loss = model._train_forward_g_path(inputs) | |||
train_outputs['path_loss'] = path_loss | |||
model.generator.zero_grad() | |||
weighted_path_loss = self.path_regularize * self.g_reg_every * path_loss | |||
weighted_path_loss.backward() | |||
self.optimizer.step() | |||
model.accumulate() | |||
if not isinstance(train_outputs, dict): | |||
raise TypeError('"model.forward()" must return a dict') | |||
# add model output info to log | |||
if 'log_vars' not in train_outputs: | |||
default_keys_pattern = ['loss'] | |||
match_keys = set([]) | |||
for key_p in default_keys_pattern: | |||
match_keys.update( | |||
[key for key in train_outputs.keys() if key_p in key]) | |||
log_vars = {} | |||
for key in match_keys: | |||
value = train_outputs.get(key, None) | |||
if value is not None: | |||
if dist.is_available() and dist.is_initialized(): | |||
value = value.data.clone() | |||
dist.all_reduce(value.div_(dist.get_world_size())) | |||
log_vars.update({key: value.item()}) | |||
self.log_buffer.update(log_vars) | |||
else: | |||
self.log_buffer.update(train_outputs['log_vars']) | |||
self.train_outputs = train_outputs | |||
def create_optimizer_and_scheduler(self): | |||
""" Create optimizer and lr scheduler | |||
We provide a default implementation, if you want to customize your own optimizer | |||
and lr scheduler, you can either pass a tuple through trainer init function or | |||
subclass this class and override this method. | |||
""" | |||
optimizer, lr_scheduler = self.optimizers | |||
if optimizer is None: | |||
optimizer_cfg = self.cfg.train.get('optimizer', None) | |||
else: | |||
optimizer_cfg = None | |||
optimizer_d_cfg = self.cfg.train.get('optimizer_d', None) | |||
optim_options = {} | |||
if optimizer_cfg is not None: | |||
optim_options = optimizer_cfg.pop('options', {}) | |||
optimizer = build_optimizer( | |||
self.model.generator, cfg=optimizer_cfg) | |||
if optimizer_d_cfg is not None: | |||
optimizer_d = build_optimizer( | |||
self.model.discriminator, cfg=optimizer_d_cfg) | |||
lr_options = {} | |||
self.optimizer = optimizer | |||
self.lr_scheduler = lr_scheduler | |||
self.optimizer_d = optimizer_d | |||
return self.optimizer, self.lr_scheduler, optim_options, lr_options |
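# Illustrative only: the shape of the training config this trainer reads. The key
# names follow the cfg lookups above; the optimizer settings are placeholder values.
train_cfg = {
    'd_reg_every': 16,                            # discriminator R1 interval
    'g_reg_every': 4,                             # generator path-length interval
    'path_regularize': 2,
    'r1': 10,
    'optimizer': {'type': 'Adam', 'lr': 2e-3},    # generator optimizer (placeholder)
    'optimizer_d': {'type': 'Adam', 'lr': 2e-3},  # discriminator optimizer (placeholder)
}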
@@ -14,5 +14,5 @@ __all__ = [ | |||
'Hook', 'HOOKS', 'CheckpointHook', 'EvaluationHook', 'LrSchedulerHook', | |||
'OptimizerHook', 'Priority', 'build_hook', 'TextLoggerHook', | |||
'IterTimerHook', 'TorchAMPOptimizerHook', 'ApexAMPOptimizerHook', | |||
'BestCkptSaverHook' | |||
'BestCkptSaverHook', 'NoneOptimizerHook', 'NoneLrSchedulerHook' | |||
] |
@@ -115,3 +115,18 @@ class PlateauLrSchedulerHook(LrSchedulerHook): | |||
self.warmup_lr_scheduler.step(metrics=metrics) | |||
else: | |||
trainer.lr_scheduler.step(metrics=metrics) | |||
@HOOKS.register_module() | |||
class NoneLrSchedulerHook(LrSchedulerHook): | |||
PRIORITY = Priority.LOW # should be after EvaluationHook | |||
def __init__(self, by_epoch=True, warmup=None) -> None: | |||
super().__init__(by_epoch=by_epoch, warmup=warmup) | |||
def before_run(self, trainer): | |||
return | |||
def after_train_epoch(self, trainer): | |||
return |
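# NoneLrSchedulerHook above is intentionally a no-op: it lets a trainer opt out of | |||
# the automatic LR-scheduler stepping performed by the default hook. | |||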
@@ -200,3 +200,19 @@ class ApexAMPOptimizerHook(OptimizerHook): | |||
trainer.optimizer.step() | |||
trainer.optimizer.zero_grad() | |||
@HOOKS.register_module() | |||
class NoneOptimizerHook(OptimizerHook): | |||
def __init__(self, cumulative_iters=1, grad_clip=None, loss_keys='loss'): | |||
super(NoneOptimizerHook, self).__init__( | |||
grad_clip=grad_clip, loss_keys=loss_keys) | |||
self.cumulative_iters = cumulative_iters | |||
def before_run(self, trainer): | |||
return | |||
def after_train_iter(self, trainer): | |||
return |
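# NoneOptimizerHook above is intentionally a no-op: trainers such as the GPEN | |||
# trainer in this PR call optimizer.step() themselves inside train_step, so the | |||
# hook must not step the optimizer again. | |||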
@@ -43,6 +43,7 @@ class CVTasks(object): | |||
image_colorization = 'image-colorization' | |||
image_color_enhancement = 'image-color-enhancement' | |||
image_denoising = 'image-denoising' | |||
image_portrait_enhancement = 'image-portrait-enhancement' | |||
# image generation | |||
image_to_image_translation = 'image-to-image-translation' | |||
@@ -0,0 +1,43 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
import unittest | |||
import cv2 | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class ImagePortraitEnhancementTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/cv_gpen_image-portrait-enhancement' | |||
self.test_image = 'data/test/images/Solvay_conference_1927.png' | |||
def pipeline_inference(self, pipeline: Pipeline, test_image: str): | |||
result = pipeline(test_image) | |||
if result is not None: | |||
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) | |||
print(f'Output written to {osp.abspath("result.png")}') | |||
else: | |||
raise Exception('Testing failed: invalid output') | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_modelhub(self): | |||
face_enhancement = pipeline( | |||
Tasks.image_portrait_enhancement, model=self.model_id) | |||
self.pipeline_inference(face_enhancement, self.test_image) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_modelhub_default_model(self): | |||
face_enhancement = pipeline(Tasks.image_portrait_enhancement) | |||
self.pipeline_inference(face_enhancement, self.test_image) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,119 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
import shutil | |||
import tempfile | |||
import unittest | |||
from typing import Callable, List, Optional, Tuple, Union | |||
import cv2 | |||
import torch | |||
from torch.utils import data as data | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models.cv.image_portrait_enhancement import \ | |||
ImagePortraitEnhancement | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.test_utils import test_level | |||
class TestImagePortraitEnhancementTrainer(unittest.TestCase): | |||
def setUp(self): | |||
print('Testing %s.%s' % (type(self).__name__, self._testMethodName)) | |||
self.tmp_dir = tempfile.TemporaryDirectory().name | |||
if not os.path.exists(self.tmp_dir): | |||
os.makedirs(self.tmp_dir) | |||
self.model_id = 'damo/cv_gpen_image-portrait-enhancement' | |||
class PairedImageDataset(data.Dataset): | |||
def __init__(self, root, size=512): | |||
super(PairedImageDataset, self).__init__() | |||
self.size = size | |||
gt_dir = osp.join(root, 'gt') | |||
lq_dir = osp.join(root, 'lq') | |||
self.gt_filelist = os.listdir(gt_dir) | |||
self.gt_filelist = sorted( | |||
self.gt_filelist, key=lambda x: int(x[:-4])) | |||
self.gt_filelist = [ | |||
osp.join(gt_dir, f) for f in self.gt_filelist | |||
] | |||
self.lq_filelist = os.listdir(lq_dir) | |||
self.lq_filelist = sorted( | |||
self.lq_filelist, key=lambda x: int(x[:-4])) | |||
self.lq_filelist = [ | |||
osp.join(lq_dir, f) for f in self.lq_filelist | |||
] | |||
def _img_to_tensor(self, img): | |||
img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute( | |||
2, 0, 1).type(torch.float32) / 255. | |||
return (img - 0.5) / 0.5 | |||
def __getitem__(self, index): | |||
lq = cv2.imread(self.lq_filelist[index]) | |||
gt = cv2.imread(self.gt_filelist[index]) | |||
lq = cv2.resize( | |||
lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC) | |||
gt = cv2.resize( | |||
gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC) | |||
return \ | |||
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} | |||
def __len__(self): | |||
return len(self.gt_filelist) | |||
def to_torch_dataset(self, | |||
columns: Union[str, List[str]] = None, | |||
preprocessors: Union[Callable, | |||
List[Callable]] = None, | |||
**format_kwargs): | |||
# self.preprocessor = preprocessors | |||
return self | |||
self.dataset = PairedImageDataset( | |||
'./data/test/images/face_enhancement/') | |||
def tearDown(self): | |||
shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
super().tearDown() | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_trainer(self): | |||
kwargs = dict( | |||
model=self.model_id, | |||
train_dataset=self.dataset, | |||
eval_dataset=self.dataset, | |||
device='gpu', | |||
work_dir=self.tmp_dir) | |||
trainer = build_trainer(name='gpen', default_args=kwargs) | |||
trainer.train() | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_trainer_with_model_and_args(self): | |||
tmp_dir = tempfile.TemporaryDirectory().name | |||
if not os.path.exists(tmp_dir): | |||
os.makedirs(tmp_dir) | |||
cache_path = snapshot_download(self.model_id) | |||
model = ImagePortraitEnhancement.from_pretrained(cache_path) | |||
kwargs = dict( | |||
cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||
model=model, | |||
train_dataset=self.dataset, | |||
eval_dataset=self.dataset, | |||
device='gpu', | |||
max_epochs=2, | |||
work_dir=self.tmp_dir) | |||
trainer = build_trainer(name='gpen', default_args=kwargs) | |||
trainer.train() | |||
if __name__ == '__main__': | |||
unittest.main() |