Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10111615 (target branch: master)
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:46db348eae61448f1668ce282caec21375e96c3268d53da44aa67ec32cbf4fa5 | |||
size 2747938 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:709c1828ed2d56badf2f19a40194da9a5e5e6db2fb73ef55d047407f49bc7a15 | |||
size 27616 |
@@ -27,6 +27,7 @@ class Models(object): | |||
face_2d_keypoints = 'face-2d-keypoints' | |||
panoptic_segmentation = 'swinL-panoptic-segmentation' | |||
image_reid_person = 'passvitb' | |||
image_inpainting = 'FFTInpainting' | |||
video_summarization = 'pgl-video-summarization' | |||
swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
@@ -179,6 +180,7 @@ class Pipelines(object): | |||
video_summarization = 'googlenet_pgl_video_summarization' | |||
image_semantic_segmentation = 'image-semantic-segmentation' | |||
image_reid_person = 'passvitb-image-reid-person' | |||
image_inpainting = 'fft-inpainting' | |||
text_driven_segmentation = 'text-driven-segmentation' | |||
movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | |||
shop_segmentation = 'shop-segmentation' | |||
@@ -264,6 +266,7 @@ class Trainers(object): | |||
image_portrait_enhancement = 'image-portrait-enhancement' | |||
video_summarization = 'video-summarization' | |||
movie_scene_segmentation = 'movie-scene-segmentation' | |||
image_inpainting = 'image-inpainting' | |||
# nlp trainers | |||
bert_sentiment_analysis = 'bert-sentiment-analysis' | |||
@@ -363,6 +366,8 @@ class Metrics(object): | |||
video_summarization_metric = 'video-summarization-metric' | |||
# metric for movie-scene-segmentation task | |||
movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' | |||
# metric for inpainting task | |||
image_inpainting_metric = 'image-inpainting-metric' | |||
class Optimizers(object): | |||
@@ -17,6 +17,7 @@ if TYPE_CHECKING: | |||
from .token_classification_metric import TokenClassificationMetric | |||
from .video_summarization_metric import VideoSummarizationMetric | |||
from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric | |||
from .image_inpainting_metric import ImageInpaintingMetric | |||
else: | |||
_import_structure = { | |||
@@ -34,6 +35,7 @@ else: | |||
'token_classification_metric': ['TokenClassificationMetric'], | |||
'video_summarization_metric': ['VideoSummarizationMetric'], | |||
'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], | |||
'image_inpainting_metric': ['ImageInpaintingMetric'], | |||
} | |||
import sys | |||
@@ -18,6 +18,7 @@ class MetricKeys(object): | |||
SSIM = 'ssim' | |||
AVERAGE_LOSS = 'avg_loss' | |||
FScore = 'fscore' | |||
FID = 'fid' | |||
BLEU_1 = 'bleu-1' | |||
BLEU_4 = 'bleu-4' | |||
ROUGE_1 = 'rouge-1' | |||
@@ -39,6 +40,7 @@ task_default_metrics = { | |||
Tasks.image_captioning: [Metrics.text_gen_metric], | |||
Tasks.visual_question_answering: [Metrics.text_gen_metric], | |||
Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], | |||
Tasks.image_inpainting: [Metrics.image_inpainting_metric], | |||
} | |||
@@ -0,0 +1,210 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
""" | |||
from typing import Dict | |||
import numpy as np | |||
import torch | |||
import torch.nn.functional as F | |||
from scipy import linalg | |||
from modelscope.metainfo import Metrics | |||
from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3 | |||
from modelscope.utils.registry import default_group | |||
from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
torch_nested_numpify) | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
def fid_calculate_activation_statistics(act): | |||
mu = np.mean(act, axis=0) | |||
sigma = np.cov(act, rowvar=False) | |||
return mu, sigma | |||
def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6): | |||
mu1, sigma1 = fid_calculate_activation_statistics(activations_pred) | |||
mu2, sigma2 = fid_calculate_activation_statistics(activations_target) | |||
diff = mu1 - mu2 | |||
# Product might be almost singular | |||
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) | |||
if not np.isfinite(covmean).all(): | |||
offset = np.eye(sigma1.shape[0]) * eps | |||
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) | |||
# Numerical error might give slight imaginary component | |||
if np.iscomplexobj(covmean): | |||
# if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): | |||
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): | |||
m = np.max(np.abs(covmean.imag)) | |||
raise ValueError('Imaginary component {}'.format(m)) | |||
covmean = covmean.real | |||
tr_covmean = np.trace(covmean) | |||
return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) | |||
- 2 * tr_covmean) | |||
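The function above computes the closed-form Fréchet distance between two Gaussians fitted to the activation sets: d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrtm(sigma1 @ sigma2)). A minimal sanity check (a sketch only; the names below are illustrative and not part of this change):

import numpy as np

rng = np.random.default_rng(0)
acts_a = rng.normal(0.0, 1.0, size=(1000, 8))   # 1000 samples, 8-dim features
acts_b = rng.normal(0.0, 1.0, size=(1000, 8))   # same distribution as acts_a
acts_c = rng.normal(2.0, 1.0, size=(1000, 8))   # mean shifted by 2 per dim

print(calculate_frechet_distance(acts_a, acts_b))  # near 0
print(calculate_frechet_distance(acts_a, acts_c))  # near 8 * 2**2 = 32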
class FIDScore(torch.nn.Module): | |||
def __init__(self, dims=2048, eps=1e-6): | |||
super().__init__() | |||
if getattr(FIDScore, '_MODEL', None) is None: | |||
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] | |||
FIDScore._MODEL = InceptionV3([block_idx]).eval() | |||
self.model = FIDScore._MODEL | |||
self.eps = eps | |||
self.reset() | |||
def forward(self, pred_batch, target_batch, mask=None): | |||
activations_pred = self._get_activations(pred_batch) | |||
activations_target = self._get_activations(target_batch) | |||
self.activations_pred.append(activations_pred.detach().cpu()) | |||
self.activations_target.append(activations_target.detach().cpu()) | |||
def get_value(self): | |||
activations_pred, activations_target = (self.activations_pred, | |||
self.activations_target) | |||
activations_pred = torch.cat(activations_pred).cpu().numpy() | |||
activations_target = torch.cat(activations_target).cpu().numpy() | |||
total_distance = calculate_frechet_distance( | |||
activations_pred, activations_target, eps=self.eps) | |||
self.reset() | |||
return total_distance | |||
def reset(self): | |||
self.activations_pred = [] | |||
self.activations_target = [] | |||
def _get_activations(self, batch): | |||
activations = self.model(batch)[0] | |||
if activations.shape[2] != 1 or activations.shape[3] != 1: | |||
assert False, \ | |||
'We should not have got here, because Inception always scales inputs to 299x299' | |||
activations = activations.squeeze(-1).squeeze(-1) | |||
return activations | |||
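FIDScore only accumulates InceptionV3 activations in forward; the distance is computed (and the buffers reset) in get_value. The class-level _MODEL cache means every instance shares one InceptionV3, the standard FID feature extractor. A usage sketch under those assumptions:

import torch

fid = FIDScore()
for _ in range(4):
    pred_batch = torch.rand(2, 3, 256, 256)    # inpainted images in [0, 1]
    target_batch = torch.rand(2, 3, 256, 256)  # ground-truth images
    fid(pred_batch, target_batch)              # accumulates activations only
# score = fid.get_value()  # computes the distance, then resets the buffers;
# the Gaussian statistics are only meaningful after many samples have accumulated.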
class SSIM(torch.nn.Module): | |||
"""SSIM. Modified from: | |||
https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py | |||
""" | |||
def __init__(self, window_size=11, size_average=True): | |||
super().__init__() | |||
self.window_size = window_size | |||
self.size_average = size_average | |||
self.channel = 1 | |||
self.register_buffer('window', | |||
self._create_window(window_size, self.channel)) | |||
def forward(self, img1, img2): | |||
assert len(img1.shape) == 4 | |||
channel = img1.size()[1] | |||
if channel == self.channel and self.window.data.type( | |||
) == img1.data.type(): | |||
window = self.window | |||
else: | |||
window = self._create_window(self.window_size, channel) | |||
window = window.type_as(img1) | |||
self.window = window | |||
self.channel = channel | |||
return self._ssim(img1, img2, window, self.window_size, channel, | |||
self.size_average) | |||
def _gaussian(self, window_size, sigma): | |||
gauss = torch.Tensor([ | |||
np.exp(-(x - (window_size // 2))**2 / float(2 * sigma**2)) | |||
for x in range(window_size) | |||
]) | |||
return gauss / gauss.sum() | |||
def _create_window(self, window_size, channel): | |||
_1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) | |||
_2D_window = _1D_window.mm( | |||
_1D_window.t()).float().unsqueeze(0).unsqueeze(0) | |||
return _2D_window.expand(channel, 1, window_size, | |||
window_size).contiguous() | |||
def _ssim(self, | |||
img1, | |||
img2, | |||
window, | |||
window_size, | |||
channel, | |||
size_average=True): | |||
mu1 = F.conv2d( | |||
img1, window, padding=(window_size // 2), groups=channel) | |||
mu2 = F.conv2d( | |||
img2, window, padding=(window_size // 2), groups=channel) | |||
mu1_sq = mu1.pow(2) | |||
mu2_sq = mu2.pow(2) | |||
mu1_mu2 = mu1 * mu2 | |||
sigma1_sq = F.conv2d( | |||
img1 * img1, window, padding=(window_size // 2), | |||
groups=channel) - mu1_sq | |||
sigma2_sq = F.conv2d( | |||
img2 * img2, window, padding=(window_size // 2), | |||
groups=channel) - mu2_sq | |||
sigma12 = F.conv2d( | |||
img1 * img2, window, padding=(window_size // 2), | |||
groups=channel) - mu1_mu2 | |||
C1 = 0.01**2 | |||
C2 = 0.03**2 | |||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ | |||
((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) | |||
if size_average: | |||
return ssim_map.mean() | |||
return ssim_map.mean(1).mean(1).mean(1) | |||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | |||
missing_keys, unexpected_keys, error_msgs): | |||
return | |||
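The SSIM module builds an 11x11 Gaussian window once and lazily rebuilds it when the channel count or dtype of the inputs changes; _load_from_state_dict is overridden as a no-op so checkpoints are not required to carry the window buffer. A minimal sketch of usage on 4-D batches:

import torch

ssim = SSIM(window_size=11, size_average=True).eval()
img = torch.rand(1, 3, 64, 64)
noisy = (img + 0.1 * torch.randn_like(img)).clamp(0, 1)

print(ssim(img, img))    # 1.0 for identical images
print(ssim(img, noisy))  # < 1.0, decreasing as distortion grows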
@METRICS.register_module( | |||
group_key=default_group, module_name=Metrics.image_inpainting_metric) | |||
class ImageInpaintingMetric(Metric): | |||
"""The metric computation class for image inpainting classes. | |||
""" | |||
def __init__(self): | |||
self.preds = [] | |||
self.targets = [] | |||
self.SSIM = SSIM(window_size=11, size_average=False).eval() | |||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
self.FID = FIDScore().to(device) | |||
def add(self, outputs: Dict, inputs: Dict): | |||
pred = outputs['inpainted'] | |||
target = inputs['image'] | |||
self.preds.append(torch_nested_detach(pred)) | |||
self.targets.append(torch_nested_detach(target)) | |||
def evaluate(self): | |||
ssim_list = [] | |||
for (pred, target) in zip(self.preds, self.targets): | |||
ssim_list.append(self.SSIM(pred, target)) | |||
self.FID(pred, target) | |||
ssim_list = torch_nested_numpify(ssim_list) | |||
fid = self.FID.get_value() | |||
return {MetricKeys.SSIM: np.mean(ssim_list), MetricKeys.FID: fid} |
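For orientation, the Metric protocol is: the trainer calls add() once per batch with the model outputs and inputs, then evaluate() once at the end. A sketch (key names follow the add() signature above; shapes are illustrative):

import torch

metric = ImageInpaintingMetric()
inputs = {'image': torch.rand(4, 3, 256, 256)}       # ground truth
outputs = {'inpainted': torch.rand(4, 3, 256, 256)}  # model output
metric.add(outputs, inputs)
# results = metric.evaluate()  # -> {'ssim': ..., 'fid': ...}
# FID needs a full validation set; on a toy batch its statistics degenerate.

Note that __init__ moves the FID network to CUDA when available while add() stores detached tensors on whatever device they arrive on, so predictions are expected on the matching device.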
@@ -5,13 +5,14 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
body_3d_keypoints, cartoon, cmdssl_video_embedding, | |||
crowd_counting, face_2d_keypoints, face_detection, | |||
face_generation, image_classification, image_color_enhance, | |||
image_colorization, image_denoise, image_instance_segmentation, | |||
image_panoptic_segmentation, image_portrait_enhancement, | |||
image_reid_person, image_semantic_segmentation, | |||
image_to_image_generation, image_to_image_translation, | |||
movie_scene_segmentation, object_detection, | |||
product_retrieval_embedding, realtime_object_detection, | |||
salient_detection, shop_segmentation, super_resolution, | |||
video_single_object_tracking, video_summarization, virual_tryon) | |||
image_colorization, image_denoise, image_inpainting, | |||
image_instance_segmentation, image_panoptic_segmentation, | |||
image_portrait_enhancement, image_reid_person, | |||
image_semantic_segmentation, image_to_image_generation, | |||
image_to_image_translation, movie_scene_segmentation, | |||
object_detection, product_retrieval_embedding, | |||
realtime_object_detection, salient_detection, shop_segmentation, | |||
super_resolution, video_single_object_tracking, | |||
video_summarization, virual_tryon) | |||
# yapf: enable |
@@ -1,3 +1,5 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from typing import Any, Dict, Optional, Union | |||
@@ -1,10 +1,10 @@ | |||
# ------------------------------------------------------------------------------ | |||
# Copyright (c) Microsoft | |||
# Licensed under the MIT License. | |||
# Written by Bin Xiao (Bin.Xiao@microsoft.com) | |||
# Modified by Ke Sun (sunk@mail.ustc.edu.cn) | |||
# https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py | |||
# ------------------------------------------------------------------------------ | |||
""" | |||
Copyright (c) Microsoft | |||
Licensed under the MIT License. | |||
Written by Bin Xiao (Bin.Xiao@microsoft.com) | |||
Modified by Ke Sun (sunk@mail.ustc.edu.cn) | |||
https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py | |||
""" | |||
import functools | |||
import logging | |||
@@ -0,0 +1,22 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .model import FFTInpainting | |||
else: | |||
_import_structure = { | |||
'model': ['FFTInpainting'], | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,75 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
""" | |||
from typing import Dict, Tuple | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from modelscope.utils.logger import get_logger | |||
from .modules.adversarial import NonSaturatingWithR1 | |||
from .modules.ffc import FFCResNetGenerator | |||
from .modules.perceptual import ResNetPL | |||
from .modules.pix2pixhd import NLayerDiscriminator | |||
LOGGER = get_logger() | |||
class BaseInpaintingTrainingModule(nn.Module): | |||
def __init__(self, | |||
model_dir='', | |||
use_ddp=True, | |||
predict_only=False, | |||
visualize_each_iters=100, | |||
average_generator=False, | |||
generator_avg_beta=0.999, | |||
average_generator_start_step=30000, | |||
average_generator_period=10, | |||
store_discr_outputs_for_vis=False, | |||
**kwargs): | |||
super().__init__() | |||
LOGGER.info( | |||
f'BaseInpaintingTrainingModule init called, predict_only is {predict_only}' | |||
) | |||
self.generator = FFCResNetGenerator() | |||
self.use_ddp = use_ddp | |||
if not predict_only: | |||
self.discriminator = NLayerDiscriminator() | |||
self.adversarial_loss = NonSaturatingWithR1( | |||
weight=10, | |||
gp_coef=0.001, | |||
mask_as_fake_target=True, | |||
allow_scale_mask=True) | |||
self.average_generator = average_generator | |||
self.generator_avg_beta = generator_avg_beta | |||
self.average_generator_start_step = average_generator_start_step | |||
self.average_generator_period = average_generator_period | |||
self.generator_average = None | |||
self.last_generator_averaging_step = -1 | |||
self.store_discr_outputs_for_vis = store_discr_outputs_for_vis | |||
self.loss_l1 = nn.L1Loss(reduction='none') | |||
self.loss_resnet_pl = ResNetPL(weight=30, weights_path=model_dir) | |||
self.visualize_each_iters = visualize_each_iters | |||
LOGGER.info('BaseInpaintingTrainingModule init done') | |||
def forward(self, batch: Dict[str, | |||
torch.Tensor]) -> Dict[str, torch.Tensor]: | |||
"""Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys""" | |||
raise NotImplementedError() | |||
def generator_loss(self, | |||
batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |||
raise NotImplementedError() | |||
def discriminator_loss( | |||
self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |||
raise NotImplementedError() |
@@ -0,0 +1,210 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
""" | |||
import bisect | |||
import torch | |||
import torch.nn.functional as F | |||
from modelscope.utils.logger import get_logger | |||
from .base import BaseInpaintingTrainingModule | |||
from .modules.feature_matching import feature_matching_loss, masked_l1_loss | |||
LOGGER = get_logger() | |||
def set_requires_grad(module, value): | |||
for param in module.parameters(): | |||
param.requires_grad = value | |||
def add_prefix_to_keys(dct, prefix): | |||
return {prefix + k: v for k, v in dct.items()} | |||
class LinearRamp: | |||
def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): | |||
self.start_value = start_value | |||
self.end_value = end_value | |||
self.start_iter = start_iter | |||
self.end_iter = end_iter | |||
def __call__(self, i): | |||
if i < self.start_iter: | |||
return self.start_value | |||
if i >= self.end_iter: | |||
return self.end_value | |||
part = (i - self.start_iter) / (self.end_iter - self.start_iter) | |||
return self.start_value * (1 - part) + self.end_value * part | |||
class LadderRamp: | |||
def __init__(self, start_iters, values): | |||
self.start_iters = start_iters | |||
self.values = values | |||
assert len(values) == len(start_iters) + 1, (len(values), | |||
len(start_iters)) | |||
def __call__(self, i): | |||
segment_i = bisect.bisect_right(self.start_iters, i) | |||
return self.values[segment_i] | |||
def get_ramp(kind='ladder', **kwargs): | |||
if kind == 'linear': | |||
return LinearRamp(**kwargs) | |||
if kind == 'ladder': | |||
return LadderRamp(**kwargs) | |||
raise ValueError(f'Unexpected ramp kind: {kind}') | |||
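Both ramps map an iteration index to a scalar weight; DefaultInpaintingTrainingModule below stores one as its rescale_size_getter. A quick illustration of the two kinds:

# Sketch: values produced by the two ramp kinds.
linear = get_ramp(kind='linear', start_value=0, end_value=1,
                  start_iter=100, end_iter=200)
print(linear(50), linear(150), linear(300))   # 0, 0.5, 1

ladder = get_ramp(kind='ladder', start_iters=[100, 200], values=[0.0, 0.5, 1.0])
print(ladder(50), ladder(150), ladder(250))   # 0.0, 0.5, 1.0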
class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule): | |||
def __init__(self, | |||
model_dir='', | |||
predict_only=False, | |||
concat_mask=True, | |||
rescale_scheduler_kwargs=None, | |||
image_to_discriminator='predicted_image', | |||
add_noise_kwargs=None, | |||
noise_fill_hole=False, | |||
const_area_crop_kwargs=None, | |||
distance_weighter_kwargs=None, | |||
distance_weighted_mask_for_discr=False, | |||
fake_fakes_proba=0, | |||
fake_fakes_generator_kwargs=None, | |||
**kwargs): | |||
super().__init__(model_dir=model_dir, predict_only=predict_only) | |||
self.concat_mask = concat_mask | |||
self.rescale_size_getter = get_ramp( | |||
**rescale_scheduler_kwargs | |||
) if rescale_scheduler_kwargs is not None else None | |||
self.image_to_discriminator = image_to_discriminator | |||
self.add_noise_kwargs = add_noise_kwargs | |||
self.noise_fill_hole = noise_fill_hole | |||
self.const_area_crop_kwargs = const_area_crop_kwargs | |||
self.refine_mask_for_losses = None | |||
self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr | |||
self.feature_matching_weight = 100 | |||
self.losses_l1_weight_known = 10 | |||
self.losses_l1_weight_missing = 0 | |||
self.fake_fakes_proba = fake_fakes_proba | |||
def forward(self, batch): | |||
img = batch['image'] | |||
mask = batch['mask'] | |||
masked_img = img * (1 - mask) | |||
if self.concat_mask: | |||
masked_img = torch.cat([masked_img, mask], dim=1) | |||
batch['predicted_image'] = self.generator(masked_img) | |||
batch['inpainted'] = mask * batch['predicted_image'] + ( | |||
1 - mask) * batch['image'] | |||
batch['mask_for_losses'] = mask | |||
return batch | |||
def generator_loss(self, batch): | |||
img = batch['image'] | |||
predicted_img = batch[self.image_to_discriminator] | |||
original_mask = batch['mask'] | |||
supervised_mask = batch['mask_for_losses'] | |||
# L1 | |||
l1_value = masked_l1_loss(predicted_img, img, supervised_mask, | |||
self.losses_l1_weight_known, | |||
self.losses_l1_weight_missing) | |||
total_loss = l1_value | |||
metrics = dict(gen_l1=l1_value) | |||
# discriminator | |||
# adversarial_loss calls backward by itself | |||
mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask | |||
self.adversarial_loss.pre_generator_step( | |||
real_batch=img, | |||
fake_batch=predicted_img, | |||
generator=self.generator, | |||
discriminator=self.discriminator) | |||
discr_real_pred, discr_real_features = self.discriminator(img) | |||
discr_fake_pred, discr_fake_features = self.discriminator( | |||
predicted_img) | |||
adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss( | |||
real_batch=img, | |||
fake_batch=predicted_img, | |||
discr_real_pred=discr_real_pred, | |||
discr_fake_pred=discr_fake_pred, | |||
mask=mask_for_discr) | |||
total_loss = total_loss + adv_gen_loss | |||
metrics['gen_adv'] = adv_gen_loss | |||
metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) | |||
# feature matching | |||
if self.feature_matching_weight > 0: | |||
need_mask_in_fm = False | |||
mask_for_fm = supervised_mask if need_mask_in_fm else None | |||
fm_value = feature_matching_loss( | |||
discr_fake_features, discr_real_features, | |||
mask=mask_for_fm) * self.feature_matching_weight | |||
total_loss = total_loss + fm_value | |||
metrics['gen_fm'] = fm_value | |||
if self.loss_resnet_pl is not None: | |||
resnet_pl_value = self.loss_resnet_pl(predicted_img, img) | |||
total_loss = total_loss + resnet_pl_value | |||
metrics['gen_resnet_pl'] = resnet_pl_value | |||
return total_loss, metrics | |||
def discriminator_loss(self, batch): | |||
total_loss = 0 | |||
metrics = {} | |||
predicted_img = batch[self.image_to_discriminator].detach() | |||
self.adversarial_loss.pre_discriminator_step( | |||
real_batch=batch['image'], | |||
fake_batch=predicted_img, | |||
generator=self.generator, | |||
discriminator=self.discriminator) | |||
discr_real_pred, discr_real_features = self.discriminator( | |||
batch['image']) | |||
discr_fake_pred, discr_fake_features = self.discriminator( | |||
predicted_img) | |||
adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss( | |||
real_batch=batch['image'], | |||
fake_batch=predicted_img, | |||
discr_real_pred=discr_real_pred, | |||
discr_fake_pred=discr_fake_pred, | |||
mask=batch['mask']) | |||
total_loss = (total_loss + adv_discr_loss) * 0.1 | |||
metrics['discr_adv'] = adv_discr_loss | |||
metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) | |||
return total_loss, metrics | |||
def _do_step(self, batch, optimizer_idx=None): | |||
if optimizer_idx == 0: # step for generator | |||
set_requires_grad(self.generator, True) | |||
set_requires_grad(self.discriminator, False) | |||
elif optimizer_idx == 1: # step for discriminator | |||
set_requires_grad(self.generator, False) | |||
set_requires_grad(self.discriminator, True) | |||
batch = self(batch) | |||
total_loss = 0 | |||
if optimizer_idx is None or optimizer_idx == 0: # step for generator | |||
total_loss, metrics = self.generator_loss(batch) | |||
elif optimizer_idx is None or optimizer_idx == 1: # step for discriminator | |||
total_loss, metrics = self.discriminator_loss(batch) | |||
result = dict(loss=total_loss) | |||
return result |
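The composition in forward keeps known pixels from the input (inpainted = mask * predicted + (1 - mask) * image), and _do_step freezes one network while the other trains; note that when optimizer_idx is None only the generator branch runs, because the elif is short-circuited. A hedged sketch of the alternating update loop (the optimizers and data here are illustrative, the real wiring lives in the trainer, and model_dir must point at a downloaded model directory since ResNetPL loads its weights from there):

import torch

# Sketch only: '/path/to/model' is a placeholder for a real model directory.
module = DefaultInpaintingTrainingModule(model_dir='/path/to/model')
opt_g = torch.optim.Adam(module.generator.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(module.discriminator.parameters(), lr=1e-4)

batch = {'image': torch.rand(1, 3, 256, 256),
         'mask': (torch.rand(1, 1, 256, 256) > 0.9).float()}

for optimizer_idx, opt in ((0, opt_g), (1, opt_d)):
    opt.zero_grad()
    result = module._do_step(dict(batch), optimizer_idx=optimizer_idx)
    result['loss'].backward()
    opt.step()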
@@ -0,0 +1,36 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from typing import Any, Dict, Optional, Union | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base.base_torch_model import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
LOGGER = get_logger() | |||
@MODELS.register_module( | |||
Tasks.image_inpainting, module_name=Models.image_inpainting) | |||
class FFTInpainting(TorchModel): | |||
def __init__(self, model_dir: str, **kwargs): | |||
super().__init__(model_dir, **kwargs) | |||
from .default import DefaultInpaintingTrainingModule | |||
pretrained = kwargs.get('pretrained', True) | |||
predict_only = kwargs.get('predict_only', False) | |||
net = DefaultInpaintingTrainingModule( | |||
model_dir=model_dir, predict_only=predict_only) | |||
if pretrained: | |||
path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
LOGGER.info(f'loading pretrained model from {path}') | |||
state = torch.load(path, map_location='cpu') | |||
net.load_state_dict(state, strict=False) | |||
self.model = net | |||
def forward(self, inputs): | |||
return self.model(inputs) |
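Once registered under Tasks.image_inpainting and Models.image_inpainting, the model is reachable through the standard ModelScope entry points. A hedged sketch (the model id and the input keys are assumptions for illustration; this diff does not pin them down):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# 'damo/cv_fft_inpainting_lama' and the 'img'/'mask' keys are assumed names.
inpaint = pipeline(Tasks.image_inpainting, model='damo/cv_fft_inpainting_lama')
result = inpaint({'img': 'input.png', 'mask': 'mask.png'})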
@@ -0,0 +1,2 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .base import ModelBuilder |
@@ -0,0 +1,380 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
""" | |||
import os | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torch.nn.modules import BatchNorm2d | |||
from . import resnet | |||
NUM_CLASS = 150 | |||
# Model Builder | |||
class ModelBuilder: | |||
# custom weights initialization | |||
@staticmethod | |||
def weights_init(m): | |||
classname = m.__class__.__name__ | |||
if classname.find('Conv') != -1: | |||
nn.init.kaiming_normal_(m.weight.data) | |||
elif classname.find('BatchNorm') != -1: | |||
m.weight.data.fill_(1.) | |||
m.bias.data.fill_(1e-4) | |||
@staticmethod | |||
def build_encoder(arch='resnet50dilated', | |||
fc_dim=512, | |||
weights='', | |||
model_dir=''): | |||
pretrained = True if len(weights) == 0 else False | |||
arch = arch.lower() | |||
if arch == 'resnet50dilated': | |||
orig_resnet = resnet.__dict__['resnet50']( | |||
pretrained=pretrained, model_dir=model_dir) | |||
net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) | |||
elif arch == 'resnet50': | |||
orig_resnet = resnet.__dict__['resnet50']( | |||
pretrained=pretrained, model_dir=model_dir) | |||
net_encoder = Resnet(orig_resnet) | |||
else: | |||
raise Exception('Architecture undefined!') | |||
# encoders are usually pretrained | |||
# net_encoder.apply(ModelBuilder.weights_init) | |||
if len(weights) > 0: | |||
print('Loading weights for net_encoder') | |||
net_encoder.load_state_dict( | |||
torch.load(weights, map_location=lambda storage, loc: storage), | |||
strict=False) | |||
return net_encoder | |||
@staticmethod | |||
def build_decoder(arch='ppm_deepsup', | |||
fc_dim=512, | |||
num_class=NUM_CLASS, | |||
weights='', | |||
use_softmax=False, | |||
drop_last_conv=False): | |||
arch = arch.lower() | |||
if arch == 'ppm_deepsup': | |||
net_decoder = PPMDeepsup( | |||
num_class=num_class, | |||
fc_dim=fc_dim, | |||
use_softmax=use_softmax, | |||
drop_last_conv=drop_last_conv) | |||
elif arch == 'c1_deepsup': | |||
net_decoder = C1DeepSup( | |||
num_class=num_class, | |||
fc_dim=fc_dim, | |||
use_softmax=use_softmax, | |||
drop_last_conv=drop_last_conv) | |||
else: | |||
raise Exception('Architecture undefined!') | |||
net_decoder.apply(ModelBuilder.weights_init) | |||
if len(weights) > 0: | |||
print('Loading weights for net_decoder') | |||
net_decoder.load_state_dict( | |||
torch.load(weights, map_location=lambda storage, loc: storage), | |||
strict=False) | |||
return net_decoder | |||
@staticmethod | |||
def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, | |||
drop_last_conv, *arts, **kwargs): | |||
path = os.path.join( | |||
weights_path, 'ade20k', | |||
f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth') | |||
return ModelBuilder.build_decoder( | |||
arch=arch_decoder, | |||
fc_dim=fc_dim, | |||
weights=path, | |||
use_softmax=True, | |||
drop_last_conv=drop_last_conv) | |||
@staticmethod | |||
def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, | |||
segmentation, *arts, **kwargs): | |||
if segmentation: | |||
path = os.path.join( | |||
weights_path, 'ade20k', | |||
f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth') | |||
else: | |||
path = '' | |||
return ModelBuilder.build_encoder( | |||
arch=arch_encoder, | |||
fc_dim=fc_dim, | |||
weights=path, | |||
model_dir=weights_path) | |||
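ModelBuilder is consumed by the ResNetPL perceptual loss to assemble an ADE20k-style dilated ResNet-50 encoder. A sketch under the assumption that weights_path contains the checkpoints named above (resnet50-imagenet.pth, plus the ade20k/... files when segmentation=True):

# Sketch: ImageNet-initialized dilated encoder; '/path/to/weights' is a placeholder.
encoder = ModelBuilder.get_encoder(
    weights_path='/path/to/weights',
    arch_encoder='resnet50dilated',
    arch_decoder='ppm_deepsup',
    fc_dim=2048,
    segmentation=False)  # with False, build_encoder receives weights='' and
                         # falls back to the ImageNet checkpoint in weights_path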
def conv3x3_bn_relu(in_planes, out_planes, stride=1): | |||
return nn.Sequential( | |||
nn.Conv2d( | |||
in_planes, | |||
out_planes, | |||
kernel_size=3, | |||
stride=stride, | |||
padding=1, | |||
bias=False), | |||
BatchNorm2d(out_planes), | |||
nn.ReLU(inplace=True), | |||
) | |||
# pyramid pooling, deep supervision | |||
class PPMDeepsup(nn.Module): | |||
def __init__(self, | |||
num_class=NUM_CLASS, | |||
fc_dim=4096, | |||
use_softmax=False, | |||
pool_scales=(1, 2, 3, 6), | |||
drop_last_conv=False): | |||
super().__init__() | |||
self.use_softmax = use_softmax | |||
self.drop_last_conv = drop_last_conv | |||
self.ppm = [] | |||
for scale in pool_scales: | |||
self.ppm.append( | |||
nn.Sequential( | |||
nn.AdaptiveAvgPool2d(scale), | |||
nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), | |||
BatchNorm2d(512), nn.ReLU(inplace=True))) | |||
self.ppm = nn.ModuleList(self.ppm) | |||
self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) | |||
self.conv_last = nn.Sequential( | |||
nn.Conv2d( | |||
fc_dim + len(pool_scales) * 512, | |||
512, | |||
kernel_size=3, | |||
padding=1, | |||
bias=False), BatchNorm2d(512), nn.ReLU(inplace=True), | |||
nn.Dropout2d(0.1), nn.Conv2d(512, num_class, kernel_size=1)) | |||
self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) | |||
self.dropout_deepsup = nn.Dropout2d(0.1) | |||
def forward(self, conv_out, segSize=None): | |||
conv5 = conv_out[-1] | |||
input_size = conv5.size() | |||
ppm_out = [conv5] | |||
for pool_scale in self.ppm: | |||
ppm_out.append( | |||
nn.functional.interpolate( | |||
pool_scale(conv5), (input_size[2], input_size[3]), | |||
mode='bilinear', | |||
align_corners=False)) | |||
ppm_out = torch.cat(ppm_out, 1) | |||
if self.drop_last_conv: | |||
return ppm_out | |||
else: | |||
x = self.conv_last(ppm_out) | |||
if self.use_softmax: # is True during inference | |||
x = nn.functional.interpolate( | |||
x, size=segSize, mode='bilinear', align_corners=False) | |||
x = nn.functional.softmax(x, dim=1) | |||
return x | |||
# deep sup | |||
conv4 = conv_out[-2] | |||
_ = self.cbr_deepsup(conv4) | |||
_ = self.dropout_deepsup(_) | |||
_ = self.conv_last_deepsup(_) | |||
x = nn.functional.log_softmax(x, dim=1) | |||
_ = nn.functional.log_softmax(_, dim=1) | |||
return (x, _) | |||
class Resnet(nn.Module): | |||
def __init__(self, orig_resnet): | |||
super(Resnet, self).__init__() | |||
# take pretrained resnet, except AvgPool and FC | |||
self.conv1 = orig_resnet.conv1 | |||
self.bn1 = orig_resnet.bn1 | |||
self.relu1 = orig_resnet.relu1 | |||
self.conv2 = orig_resnet.conv2 | |||
self.bn2 = orig_resnet.bn2 | |||
self.relu2 = orig_resnet.relu2 | |||
self.conv3 = orig_resnet.conv3 | |||
self.bn3 = orig_resnet.bn3 | |||
self.relu3 = orig_resnet.relu3 | |||
self.maxpool = orig_resnet.maxpool | |||
self.layer1 = orig_resnet.layer1 | |||
self.layer2 = orig_resnet.layer2 | |||
self.layer3 = orig_resnet.layer3 | |||
self.layer4 = orig_resnet.layer4 | |||
def forward(self, x, return_feature_maps=False): | |||
conv_out = [] | |||
x = self.relu1(self.bn1(self.conv1(x))) | |||
x = self.relu2(self.bn2(self.conv2(x))) | |||
x = self.relu3(self.bn3(self.conv3(x))) | |||
x = self.maxpool(x) | |||
x = self.layer1(x) | |||
conv_out.append(x) | |||
x = self.layer2(x) | |||
conv_out.append(x) | |||
x = self.layer3(x) | |||
conv_out.append(x) | |||
x = self.layer4(x) | |||
conv_out.append(x) | |||
if return_feature_maps: | |||
return conv_out | |||
return [x] | |||
# Resnet Dilated | |||
class ResnetDilated(nn.Module): | |||
def __init__(self, orig_resnet, dilate_scale=8): | |||
super().__init__() | |||
from functools import partial | |||
if dilate_scale == 8: | |||
orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) | |||
orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) | |||
elif dilate_scale == 16: | |||
orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) | |||
# take pretrained resnet, except AvgPool and FC | |||
self.conv1 = orig_resnet.conv1 | |||
self.bn1 = orig_resnet.bn1 | |||
self.relu1 = orig_resnet.relu1 | |||
self.conv2 = orig_resnet.conv2 | |||
self.bn2 = orig_resnet.bn2 | |||
self.relu2 = orig_resnet.relu2 | |||
self.conv3 = orig_resnet.conv3 | |||
self.bn3 = orig_resnet.bn3 | |||
self.relu3 = orig_resnet.relu3 | |||
self.maxpool = orig_resnet.maxpool | |||
self.layer1 = orig_resnet.layer1 | |||
self.layer2 = orig_resnet.layer2 | |||
self.layer3 = orig_resnet.layer3 | |||
self.layer4 = orig_resnet.layer4 | |||
def _nostride_dilate(self, m, dilate): | |||
classname = m.__class__.__name__ | |||
if classname.find('Conv') != -1: | |||
# the convolution with stride | |||
if m.stride == (2, 2): | |||
m.stride = (1, 1) | |||
if m.kernel_size == (3, 3): | |||
m.dilation = (dilate // 2, dilate // 2) | |||
m.padding = (dilate // 2, dilate // 2) | |||
# other convolutions | |||
else: | |||
if m.kernel_size == (3, 3): | |||
m.dilation = (dilate, dilate) | |||
m.padding = (dilate, dilate) | |||
def forward(self, x, return_feature_maps=False): | |||
conv_out = [] | |||
x = self.relu1(self.bn1(self.conv1(x))) | |||
x = self.relu2(self.bn2(self.conv2(x))) | |||
x = self.relu3(self.bn3(self.conv3(x))) | |||
x = self.maxpool(x) | |||
x = self.layer1(x) | |||
conv_out.append(x) | |||
x = self.layer2(x) | |||
conv_out.append(x) | |||
x = self.layer3(x) | |||
conv_out.append(x) | |||
x = self.layer4(x) | |||
conv_out.append(x) | |||
if return_feature_maps: | |||
return conv_out | |||
return [x] | |||
# last conv, deep supervision | |||
class C1DeepSup(nn.Module): | |||
def __init__(self, | |||
num_class=150, | |||
fc_dim=2048, | |||
use_softmax=False, | |||
drop_last_conv=False): | |||
super(C1DeepSup, self).__init__() | |||
self.use_softmax = use_softmax | |||
self.drop_last_conv = drop_last_conv | |||
self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) | |||
self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) | |||
# last conv | |||
self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) | |||
self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) | |||
def forward(self, conv_out, segSize=None): | |||
conv5 = conv_out[-1] | |||
x = self.cbr(conv5) | |||
if self.drop_last_conv: | |||
return x | |||
else: | |||
x = self.conv_last(x) | |||
if self.use_softmax: # is True during inference | |||
x = nn.functional.interpolate( | |||
x, size=segSize, mode='bilinear', align_corners=False) | |||
x = nn.functional.softmax(x, dim=1) | |||
return x | |||
# deep sup | |||
conv4 = conv_out[-2] | |||
_ = self.cbr_deepsup(conv4) | |||
_ = self.conv_last_deepsup(_) | |||
x = nn.functional.log_softmax(x, dim=1) | |||
_ = nn.functional.log_softmax(_, dim=1) | |||
return (x, _) | |||
# last conv | |||
class C1(nn.Module): | |||
def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): | |||
super(C1, self).__init__() | |||
self.use_softmax = use_softmax | |||
self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) | |||
# last conv | |||
self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) | |||
def forward(self, conv_out, segSize=None): | |||
conv5 = conv_out[-1] | |||
x = self.cbr(conv5) | |||
x = self.conv_last(x) | |||
if self.use_softmax: # is True during inference | |||
x = nn.functional.interpolate( | |||
x, size=segSize, mode='bilinear', align_corners=False) | |||
x = nn.functional.softmax(x, dim=1) | |||
else: | |||
x = nn.functional.log_softmax(x, dim=1) | |||
return x |
@@ -0,0 +1,183 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
""" | |||
import math | |||
import os | |||
import torch | |||
import torch.nn as nn | |||
from torch.nn import BatchNorm2d | |||
__all__ = ['ResNet', 'resnet50'] | |||
def conv3x3(in_planes, out_planes, stride=1): | |||
'3x3 convolution with padding' | |||
return nn.Conv2d( | |||
in_planes, | |||
out_planes, | |||
kernel_size=3, | |||
stride=stride, | |||
padding=1, | |||
bias=False) | |||
class BasicBlock(nn.Module): | |||
expansion = 1 | |||
def __init__(self, inplanes, planes, stride=1, downsample=None): | |||
super(BasicBlock, self).__init__() | |||
self.conv1 = conv3x3(inplanes, planes, stride) | |||
self.bn1 = BatchNorm2d(planes) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.conv2 = conv3x3(planes, planes) | |||
self.bn2 = BatchNorm2d(planes) | |||
self.downsample = downsample | |||
self.stride = stride | |||
def forward(self, x): | |||
residual = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
if self.downsample is not None: | |||
residual = self.downsample(x) | |||
out += residual | |||
out = self.relu(out) | |||
return out | |||
class Bottleneck(nn.Module): | |||
expansion = 4 | |||
def __init__(self, inplanes, planes, stride=1, downsample=None): | |||
super(Bottleneck, self).__init__() | |||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) | |||
self.bn1 = BatchNorm2d(planes) | |||
self.conv2 = nn.Conv2d( | |||
planes, | |||
planes, | |||
kernel_size=3, | |||
stride=stride, | |||
padding=1, | |||
bias=False) | |||
self.bn2 = BatchNorm2d(planes) | |||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) | |||
self.bn3 = BatchNorm2d(planes * 4) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.downsample = downsample | |||
self.stride = stride | |||
def forward(self, x): | |||
residual = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
out = self.relu(out) | |||
out = self.conv3(out) | |||
out = self.bn3(out) | |||
if self.downsample is not None: | |||
residual = self.downsample(x) | |||
out += residual | |||
out = self.relu(out) | |||
return out | |||
class ResNet(nn.Module): | |||
def __init__(self, block, layers, num_classes=1000): | |||
self.inplanes = 128 | |||
super(ResNet, self).__init__() | |||
self.conv1 = conv3x3(3, 64, stride=2) | |||
self.bn1 = BatchNorm2d(64) | |||
self.relu1 = nn.ReLU(inplace=True) | |||
self.conv2 = conv3x3(64, 64) | |||
self.bn2 = BatchNorm2d(64) | |||
self.relu2 = nn.ReLU(inplace=True) | |||
self.conv3 = conv3x3(64, 128) | |||
self.bn3 = BatchNorm2d(128) | |||
self.relu3 = nn.ReLU(inplace=True) | |||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
self.layer1 = self._make_layer(block, 64, layers[0]) | |||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2) | |||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2) | |||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2) | |||
self.avgpool = nn.AvgPool2d(7, stride=1) | |||
self.fc = nn.Linear(512 * block.expansion, num_classes) | |||
for m in self.modules(): | |||
if isinstance(m, nn.Conv2d): | |||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |||
m.weight.data.normal_(0, math.sqrt(2. / n)) | |||
elif isinstance(m, BatchNorm2d): | |||
m.weight.data.fill_(1) | |||
m.bias.data.zero_() | |||
def _make_layer(self, block, planes, blocks, stride=1): | |||
downsample = None | |||
if stride != 1 or self.inplanes != planes * block.expansion: | |||
downsample = nn.Sequential( | |||
nn.Conv2d( | |||
self.inplanes, | |||
planes * block.expansion, | |||
kernel_size=1, | |||
stride=stride, | |||
bias=False), | |||
BatchNorm2d(planes * block.expansion), | |||
) | |||
layers = [] | |||
layers.append(block(self.inplanes, planes, stride, downsample)) | |||
self.inplanes = planes * block.expansion | |||
for i in range(1, blocks): | |||
layers.append(block(self.inplanes, planes)) | |||
return nn.Sequential(*layers) | |||
def forward(self, x): | |||
x = self.relu1(self.bn1(self.conv1(x))) | |||
x = self.relu2(self.bn2(self.conv2(x))) | |||
x = self.relu3(self.bn3(self.conv3(x))) | |||
x = self.maxpool(x) | |||
x = self.layer1(x) | |||
x = self.layer2(x) | |||
x = self.layer3(x) | |||
x = self.layer4(x) | |||
x = self.avgpool(x) | |||
x = x.view(x.size(0), -1) | |||
x = self.fc(x) | |||
return x | |||
def resnet50(pretrained=False, model_dir='', **kwargs): | |||
"""Constructs a ResNet-50 model. | |||
Args: | |||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||
""" | |||
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) | |||
if pretrained: | |||
cached_file = os.path.join(model_dir, 'resnet50-imagenet.pth') | |||
model.load_state_dict( | |||
torch.load(cached_file, map_location='cpu'), strict=False) | |||
return model |
@@ -0,0 +1,167 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
""" | |||
from typing import Dict, Optional, Tuple | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
class BaseAdversarialLoss: | |||
def pre_generator_step(self, real_batch: torch.Tensor, | |||
fake_batch: torch.Tensor, generator: nn.Module, | |||
discriminator: nn.Module): | |||
""" | |||
Prepare for generator step | |||
:param real_batch: Tensor, a batch of real samples | |||
:param fake_batch: Tensor, a batch of samples produced by generator | |||
:param generator: | |||
:param discriminator: | |||
:return: None | |||
""" | |||
def pre_discriminator_step(self, real_batch: torch.Tensor, | |||
fake_batch: torch.Tensor, generator: nn.Module, | |||
discriminator: nn.Module): | |||
""" | |||
Prepare for discriminator step | |||
:param real_batch: Tensor, a batch of real samples | |||
:param fake_batch: Tensor, a batch of samples produced by generator | |||
:param generator: | |||
:param discriminator: | |||
:return: None | |||
""" | |||
def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, | |||
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, | |||
mask: Optional[torch.Tensor] = None) \ | |||
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |||
""" | |||
Calculate generator loss | |||
:param real_batch: Tensor, a batch of real samples | |||
:param fake_batch: Tensor, a batch of samples produced by generator | |||
:param discr_real_pred: Tensor, discriminator output for real_batch | |||
:param discr_fake_pred: Tensor, discriminator output for fake_batch | |||
:param mask: Tensor, actual mask, which was at input of generator when making fake_batch | |||
:return: total generator loss along with some values that might be interesting to log | |||
""" | |||
raise NotImplementedError | |||
def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, | |||
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, | |||
mask: Optional[torch.Tensor] = None) \ | |||
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |||
""" | |||
Calculate discriminator loss and call .backward() on it | |||
:param real_batch: Tensor, a batch of real samples | |||
:param fake_batch: Tensor, a batch of samples produced by generator | |||
:param discr_real_pred: Tensor, discriminator output for real_batch | |||
:param discr_fake_pred: Tensor, discriminator output for fake_batch | |||
:param mask: Tensor, actual mask, which was at input of generator when making fake_batch | |||
:return: total discriminator loss along with some values that might be interesting to log | |||
""" | |||
raise NotImplementedError | |||
def interpolate_mask(self, mask, shape): | |||
assert mask is not None | |||
assert self.allow_scale_mask or shape == mask.shape[-2:] | |||
if shape != mask.shape[-2:] and self.allow_scale_mask: | |||
if self.mask_scale_mode == 'maxpool': | |||
mask = F.adaptive_max_pool2d(mask, shape) | |||
else: | |||
mask = F.interpolate( | |||
mask, size=shape, mode=self.mask_scale_mode) | |||
return mask | |||
def make_r1_gp(discr_real_pred, real_batch): | |||
if torch.is_grad_enabled(): | |||
grad_real = torch.autograd.grad( | |||
outputs=discr_real_pred.sum(), | |||
inputs=real_batch, | |||
create_graph=True)[0] | |||
grad_penalty = (grad_real.view(grad_real.shape[0], | |||
-1).norm(2, dim=1)**2).mean() | |||
else: | |||
grad_penalty = 0 | |||
real_batch.requires_grad = False | |||
return grad_penalty | |||
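make_r1_gp implements the R1 regularizer: the mean squared L2 norm of the discriminator's gradient with respect to real inputs, which is why pre_discriminator_step below switches real_batch.requires_grad on before the forward pass. A self-contained sketch with a toy critic:

import torch
import torch.nn as nn

discr = nn.Sequential(nn.Flatten(), nn.Linear(12, 1))  # toy discriminator
real = torch.rand(4, 3, 2, 2)
real.requires_grad = True            # mirrors pre_discriminator_step
gp = make_r1_gp(discr(real), real)   # mean of ||d D(x) / d x||^2
print(gp)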
class NonSaturatingWithR1(BaseAdversarialLoss): | |||
def __init__(self, | |||
gp_coef=5, | |||
weight=1, | |||
mask_as_fake_target=False, | |||
allow_scale_mask=False, | |||
mask_scale_mode='nearest', | |||
extra_mask_weight_for_gen=0, | |||
use_unmasked_for_gen=True, | |||
use_unmasked_for_discr=True): | |||
self.gp_coef = gp_coef | |||
self.weight = weight | |||
# use for discr => use for gen; | |||
# otherwise we teach only the discr to pay attention to very small difference | |||
assert use_unmasked_for_gen or (not use_unmasked_for_discr) | |||
# mask as target => use unmasked for discr: | |||
# if we don't care about unmasked regions at all | |||
# then it doesn't matter if the value of mask_as_fake_target is true or false | |||
assert use_unmasked_for_discr or (not mask_as_fake_target) | |||
self.use_unmasked_for_gen = use_unmasked_for_gen | |||
self.use_unmasked_for_discr = use_unmasked_for_discr | |||
self.mask_as_fake_target = mask_as_fake_target | |||
self.allow_scale_mask = allow_scale_mask | |||
self.mask_scale_mode = mask_scale_mode | |||
self.extra_mask_weight_for_gen = extra_mask_weight_for_gen | |||
def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, | |||
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, | |||
mask=None) \ | |||
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |||
fake_loss = F.softplus(-discr_fake_pred) | |||
if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \ | |||
not self.use_unmasked_for_gen: # == if masked region should be treated differently | |||
mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) | |||
if not self.use_unmasked_for_gen: | |||
fake_loss = fake_loss * mask | |||
else: | |||
pixel_weights = 1 + mask * self.extra_mask_weight_for_gen | |||
fake_loss = fake_loss * pixel_weights | |||
return fake_loss.mean() * self.weight, dict() | |||
def pre_discriminator_step(self, real_batch: torch.Tensor, | |||
fake_batch: torch.Tensor, generator: nn.Module, | |||
discriminator: nn.Module): | |||
real_batch.requires_grad = True | |||
def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, | |||
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, | |||
mask=None) \ | |||
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |||
real_loss = F.softplus(-discr_real_pred) | |||
grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef | |||
fake_loss = F.softplus(discr_fake_pred) | |||
if not self.use_unmasked_for_discr or self.mask_as_fake_target: | |||
# == if masked region should be treated differently | |||
mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) | |||
# use_unmasked_for_discr=False only makes sense for fakes; | |||
# for reals there is no difference between the two regions | |||
fake_loss = fake_loss * mask | |||
if self.mask_as_fake_target: | |||
fake_loss = fake_loss + (1 | |||
- mask) * F.softplus(-discr_fake_pred) | |||
sum_discr_loss = real_loss + grad_penalty + fake_loss | |||
metrics = dict( | |||
discr_real_out=discr_real_pred.mean(), | |||
discr_fake_out=discr_fake_pred.mean(), | |||
discr_real_gp=grad_penalty) | |||
return sum_discr_loss.mean(), metrics |
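This is the non-saturating GAN objective with R1: the generator minimizes softplus(-D(fake)), the discriminator minimizes softplus(-D(real)) + softplus(D(fake)) plus gp_coef times the gradient penalty, with optional mask weighting on the fake term. The logit direction at a glance:

import torch
import torch.nn.functional as F

logits_fake = torch.tensor([2.0])  # discriminator currently fooled
print(F.softplus(-logits_fake))    # ~0.13: small loss for the generator
print(F.softplus(logits_fake))     # ~2.13: large loss for the discriminator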
@@ -0,0 +1,45 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
""" | |||
from typing import List | |||
import torch | |||
import torch.nn.functional as F | |||
def masked_l2_loss(pred, target, mask, weight_known, weight_missing): | |||
per_pixel_l2 = F.mse_loss(pred, target, reduction='none') | |||
pixel_weights = mask * weight_missing + (1 - mask) * weight_known | |||
return (pixel_weights * per_pixel_l2).mean() | |||
def masked_l1_loss(pred, target, mask, weight_known, weight_missing): | |||
per_pixel_l1 = F.l1_loss(pred, target, reduction='none') | |||
pixel_weights = mask * weight_missing + (1 - mask) * weight_known | |||
return (pixel_weights * per_pixel_l1).mean() | |||
def feature_matching_loss(fake_features: List[torch.Tensor], | |||
target_features: List[torch.Tensor], | |||
mask=None): | |||
if mask is None: | |||
res = torch.stack([ | |||
F.mse_loss(fake_feat, target_feat) | |||
for fake_feat, target_feat in zip(fake_features, target_features) | |||
]).mean() | |||
else: | |||
res = 0 | |||
norm = 0 | |||
for fake_feat, target_feat in zip(fake_features, target_features): | |||
cur_mask = F.interpolate( | |||
mask, | |||
size=fake_feat.shape[-2:], | |||
mode='bilinear', | |||
align_corners=False) | |||
error_weights = 1 - cur_mask | |||
cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() | |||
res = res + cur_val | |||
norm += 1 | |||
res = res / norm | |||
return res |
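feature_matching_loss compares intermediate discriminator features of the fake against the real; when a mask is given, errors inside the hole are down-weighted (error_weights = 1 - mask) so matching concentrates on the known region. A sketch with mock feature pyramids:

import torch

fake_feats = [torch.rand(1, 8, 16, 16), torch.rand(1, 16, 8, 8)]
real_feats = [f + 0.01 * torch.randn_like(f) for f in fake_feats]
print(feature_matching_loss(fake_feats, real_feats))  # plain MSE average

mask = torch.zeros(1, 1, 32, 32)
mask[..., 8:24, 8:24] = 1  # the hole; excluded from matching
print(feature_matching_loss(fake_feats, real_feats, mask=mask))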
@@ -0,0 +1,588 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
""" | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from kornia.geometry.transform import rotate | |||
def get_activation(kind='tanh'): | |||
if kind == 'tanh': | |||
return nn.Tanh() | |||
if kind == 'sigmoid': | |||
return nn.Sigmoid() | |||
if kind is False: | |||
return nn.Identity() | |||
raise ValueError(f'Unknown activation kind {kind}') | |||
class SELayer(nn.Module): | |||
def __init__(self, channel, reduction=16): | |||
super(SELayer, self).__init__() | |||
self.avg_pool = nn.AdaptiveAvgPool2d(1) | |||
self.fc = nn.Sequential( | |||
nn.Linear(channel, channel // reduction, bias=False), | |||
nn.ReLU(inplace=True), | |||
nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid()) | |||
def forward(self, x): | |||
b, c, _, _ = x.size() | |||
y = self.avg_pool(x).view(b, c) | |||
y = self.fc(y).view(b, c, 1, 1) | |||
res = x * y.expand_as(x) | |||
return res | |||
class FourierUnit(nn.Module): | |||
def __init__(self, | |||
in_channels, | |||
out_channels, | |||
groups=1, | |||
spatial_scale_factor=None, | |||
spatial_scale_mode='bilinear', | |||
spectral_pos_encoding=False, | |||
use_se=False, | |||
se_kwargs=None, | |||
ffc3d=False, | |||
fft_norm='ortho'): | |||
# bn_layer not used | |||
super(FourierUnit, self).__init__() | |||
self.groups = groups | |||
self.conv_layer = torch.nn.Conv2d( | |||
in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), | |||
out_channels=out_channels * 2, | |||
kernel_size=1, | |||
stride=1, | |||
padding=0, | |||
groups=self.groups, | |||
bias=False) | |||
self.bn = torch.nn.BatchNorm2d(out_channels * 2) | |||
self.relu = torch.nn.ReLU(inplace=True) | |||
# squeeze and excitation block | |||
self.use_se = use_se | |||
if use_se: | |||
if se_kwargs is None: | |||
se_kwargs = {} | |||
self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) | |||
self.spatial_scale_factor = spatial_scale_factor | |||
self.spatial_scale_mode = spatial_scale_mode | |||
self.spectral_pos_encoding = spectral_pos_encoding | |||
self.ffc3d = ffc3d | |||
self.fft_norm = fft_norm | |||
def forward(self, x): | |||
batch = x.shape[0] | |||
if self.spatial_scale_factor is not None: | |||
orig_size = x.shape[-2:] | |||
x = F.interpolate( | |||
x, | |||
scale_factor=self.spatial_scale_factor, | |||
mode=self.spatial_scale_mode, | |||
align_corners=False) | |||
# (batch, c, h, w/2+1, 2) | |||
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) | |||
ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) | |||
ffted = torch.stack((ffted.real, ffted.imag), dim=-1) | |||
ffted = ffted.permute(0, 1, 4, 2, | |||
3).contiguous() # (batch, c, 2, h, w/2+1) | |||
ffted = ffted.view(( | |||
batch, | |||
-1, | |||
) + ffted.size()[3:]) | |||
if self.spectral_pos_encoding: | |||
height, width = ffted.shape[-2:] | |||
coords_vert = torch.linspace(0, 1, | |||
height)[None, None, :, None].expand( | |||
batch, 1, height, width).to(ffted) | |||
coords_hor = torch.linspace(0, 1, | |||
width)[None, None, None, :].expand( | |||
batch, 1, height, width).to(ffted) | |||
ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) | |||
if self.use_se: | |||
ffted = self.se(ffted) | |||
ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) | |||
ffted = self.relu(self.bn(ffted)) | |||
ffted = ffted.view(( | |||
batch, | |||
-1, | |||
2, | |||
) + ffted.size()[2:]).permute( | |||
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) | |||
ffted = torch.complex(ffted[..., 0], ffted[..., 1]) | |||
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] | |||
output = torch.fft.irfftn( | |||
ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) | |||
if self.spatial_scale_factor is not None: | |||
output = F.interpolate( | |||
output, | |||
size=orig_size, | |||
mode=self.spatial_scale_mode, | |||
align_corners=False) | |||
return output | |||
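The unit is image -> rFFT -> 1x1 convolution over stacked real/imaginary channels -> inverse rFFT, so every output position mixes information from the whole plane. A shape walkthrough on a dummy input:

import torch

fu = FourierUnit(in_channels=16, out_channels=16)
x = torch.rand(1, 16, 32, 32)
# rfftn over (-2, -1) yields a complex (1, 16, 32, 17) tensor; stacking
# real/imag and folding them into channels gives (1, 32, 32, 17) for the conv.
y = fu(x)
print(y.shape)  # torch.Size([1, 16, 32, 32]) -- spatial size preserved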
class SpectralTransform(nn.Module): | |||
def __init__(self, | |||
in_channels, | |||
out_channels, | |||
stride=1, | |||
groups=1, | |||
enable_lfu=True, | |||
**fu_kwargs): | |||
# bn_layer not used | |||
super(SpectralTransform, self).__init__() | |||
self.enable_lfu = enable_lfu | |||
if stride == 2: | |||
self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) | |||
else: | |||
self.downsample = nn.Identity() | |||
self.stride = stride | |||
self.conv1 = nn.Sequential( | |||
nn.Conv2d( | |||
in_channels, | |||
out_channels // 2, | |||
kernel_size=1, | |||
groups=groups, | |||
bias=False), nn.BatchNorm2d(out_channels // 2), | |||
nn.ReLU(inplace=True)) | |||
self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, | |||
**fu_kwargs) | |||
if self.enable_lfu: | |||
self.lfu = FourierUnit(out_channels // 2, out_channels // 2, | |||
groups) | |||
self.conv2 = torch.nn.Conv2d( | |||
out_channels // 2, | |||
out_channels, | |||
kernel_size=1, | |||
groups=groups, | |||
bias=False) | |||
def forward(self, x): | |||
x = self.downsample(x) | |||
x = self.conv1(x) | |||
output = self.fu(x) | |||
if self.enable_lfu: | |||
n, c, h, w = x.shape | |||
split_no = 2 | |||
split_s = h // split_no | |||
xs = torch.cat( | |||
torch.split(x[:, :c // 4], split_s, dim=-2), | |||
dim=1).contiguous() | |||
xs = torch.cat( | |||
torch.split(xs, split_s, dim=-1), dim=1).contiguous() | |||
xs = self.lfu(xs) | |||
xs = xs.repeat(1, 1, split_no, split_no).contiguous() | |||
else: | |||
xs = 0 | |||
output = self.conv2(x + output + xs) | |||
return output | |||
class LearnableSpatialTransformWrapper(nn.Module): | |||
def __init__(self, | |||
impl, | |||
pad_coef=0.5, | |||
angle_init_range=80, | |||
train_angle=True): | |||
super().__init__() | |||
self.impl = impl | |||
self.angle = torch.rand(1) * angle_init_range | |||
if train_angle: | |||
self.angle = nn.Parameter(self.angle, requires_grad=True) | |||
self.pad_coef = pad_coef | |||
def forward(self, x): | |||
if torch.is_tensor(x): | |||
return self.inverse_transform(self.impl(self.transform(x)), x) | |||
elif isinstance(x, tuple): | |||
x_trans = tuple(self.transform(elem) for elem in x) | |||
y_trans = self.impl(x_trans) | |||
return tuple( | |||
self.inverse_transform(elem, orig_x) | |||
for elem, orig_x in zip(y_trans, x)) | |||
else: | |||
raise ValueError(f'Unexpected input type {type(x)}') | |||
def transform(self, x): | |||
height, width = x.shape[2:] | |||
pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) | |||
x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect') | |||
x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) | |||
return x_padded_rotated | |||
def inverse_transform(self, y_padded_rotated, orig_x): | |||
height, width = orig_x.shape[2:] | |||
pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) | |||
y_padded = rotate( | |||
y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) | |||
y_height, y_width = y_padded.shape[2:] | |||
y = y_padded[:, :, pad_h:y_height - pad_h, pad_w:y_width - pad_w] | |||
return y | |||
class FFC(nn.Module): | |||
def __init__(self, | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
ratio_gin, | |||
ratio_gout, | |||
stride=1, | |||
padding=0, | |||
dilation=1, | |||
groups=1, | |||
bias=False, | |||
enable_lfu=True, | |||
padding_type='reflect', | |||
gated=False, | |||
**spectral_kwargs): | |||
super(FFC, self).__init__() | |||
assert stride == 1 or stride == 2, 'Stride should be 1 or 2.' | |||
self.stride = stride | |||
in_cg = int(in_channels * ratio_gin) | |||
in_cl = in_channels - in_cg | |||
out_cg = int(out_channels * ratio_gout) | |||
out_cl = out_channels - out_cg | |||
self.ratio_gin = ratio_gin | |||
self.ratio_gout = ratio_gout | |||
self.global_in_num = in_cg | |||
module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d | |||
self.convl2l = module( | |||
in_cl, | |||
out_cl, | |||
kernel_size, | |||
stride, | |||
padding, | |||
dilation, | |||
groups, | |||
bias, | |||
padding_mode=padding_type) | |||
module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d | |||
self.convl2g = module( | |||
in_cl, | |||
out_cg, | |||
kernel_size, | |||
stride, | |||
padding, | |||
dilation, | |||
groups, | |||
bias, | |||
padding_mode=padding_type) | |||
module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d | |||
self.convg2l = module( | |||
in_cg, | |||
out_cl, | |||
kernel_size, | |||
stride, | |||
padding, | |||
dilation, | |||
groups, | |||
bias, | |||
padding_mode=padding_type) | |||
module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform | |||
self.convg2g = module(in_cg, out_cg, stride, | |||
1 if groups == 1 else groups // 2, enable_lfu, | |||
**spectral_kwargs) | |||
self.gated = gated | |||
module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d | |||
self.gate = module(in_channels, 2, 1) | |||
def forward(self, x): | |||
x_l, x_g = x if type(x) is tuple else (x, 0) | |||
out_xl, out_xg = 0, 0 | |||
if self.gated: | |||
total_input_parts = [x_l] | |||
if torch.is_tensor(x_g): | |||
total_input_parts.append(x_g) | |||
total_input = torch.cat(total_input_parts, dim=1) | |||
gates = torch.sigmoid(self.gate(total_input)) | |||
g2l_gate, l2g_gate = gates.chunk(2, dim=1) | |||
else: | |||
g2l_gate, l2g_gate = 1, 1 | |||
if self.ratio_gout != 1: | |||
out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate | |||
if self.ratio_gout != 0: | |||
out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g) | |||
return out_xl, out_xg | |||
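To make the local/global channel bookkeeping concrete, a minimal sketch (not part of the patch): with in_channels=64 and ratio_gin=0.75 the input tuple splits into 16 local and 48 global channels; enable_lfu=False mirrors the generator configs used further down.

```python
import torch

ffc = FFC(64, 64, kernel_size=3, ratio_gin=0.75, ratio_gout=0.75,
          padding=1, enable_lfu=False)
x_l = torch.rand(1, 16, 32, 32)  # local branch:  64 - int(64 * 0.75) = 16 channels
x_g = torch.rand(1, 48, 32, 32)  # global branch: int(64 * 0.75) = 48 channels
out_l, out_g = ffc((x_l, x_g))   # shapes (1, 16, 32, 32) and (1, 48, 32, 32)
```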
class FFC_BN_ACT(nn.Module): | |||
def __init__(self, | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
ratio_gin, | |||
ratio_gout, | |||
stride=1, | |||
padding=0, | |||
dilation=1, | |||
groups=1, | |||
bias=False, | |||
norm_layer=nn.BatchNorm2d, | |||
activation_layer=nn.Identity, | |||
padding_type='reflect', | |||
enable_lfu=True, | |||
**kwargs): | |||
super(FFC_BN_ACT, self).__init__() | |||
self.ffc = FFC( | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
ratio_gin, | |||
ratio_gout, | |||
stride, | |||
padding, | |||
dilation, | |||
groups, | |||
bias, | |||
enable_lfu, | |||
padding_type=padding_type, | |||
**kwargs) | |||
lnorm = nn.Identity if ratio_gout == 1 else norm_layer | |||
gnorm = nn.Identity if ratio_gout == 0 else norm_layer | |||
global_channels = int(out_channels * ratio_gout) | |||
self.bn_l = lnorm(out_channels - global_channels) | |||
self.bn_g = gnorm(global_channels) | |||
lact = nn.Identity if ratio_gout == 1 else activation_layer | |||
gact = nn.Identity if ratio_gout == 0 else activation_layer | |||
self.act_l = lact(inplace=True) | |||
self.act_g = gact(inplace=True) | |||
def forward(self, x): | |||
x_l, x_g = self.ffc(x) | |||
x_l = self.act_l(self.bn_l(x_l)) | |||
x_g = self.act_g(self.bn_g(x_g)) | |||
return x_l, x_g | |||
class FFCResnetBlock(nn.Module): | |||
def __init__(self, | |||
dim, | |||
padding_type, | |||
norm_layer, | |||
activation_layer=nn.ReLU, | |||
dilation=1, | |||
spatial_transform_kwargs=None, | |||
inline=False, | |||
**conv_kwargs): | |||
super().__init__() | |||
self.conv1 = FFC_BN_ACT( | |||
dim, | |||
dim, | |||
kernel_size=3, | |||
padding=dilation, | |||
dilation=dilation, | |||
norm_layer=norm_layer, | |||
activation_layer=activation_layer, | |||
padding_type=padding_type, | |||
**conv_kwargs) | |||
self.conv2 = FFC_BN_ACT( | |||
dim, | |||
dim, | |||
kernel_size=3, | |||
padding=dilation, | |||
dilation=dilation, | |||
norm_layer=norm_layer, | |||
activation_layer=activation_layer, | |||
padding_type=padding_type, | |||
**conv_kwargs) | |||
if spatial_transform_kwargs is not None: | |||
self.conv1 = LearnableSpatialTransformWrapper( | |||
self.conv1, **spatial_transform_kwargs) | |||
self.conv2 = LearnableSpatialTransformWrapper( | |||
self.conv2, **spatial_transform_kwargs) | |||
self.inline = inline | |||
def forward(self, x): | |||
if self.inline: | |||
x_l = x[:, :-self.conv1.ffc.global_in_num] | |||
x_g = x[:, -self.conv1.ffc.global_in_num:] | |||
else: | |||
x_l, x_g = x if type(x) is tuple else (x, 0) | |||
id_l, id_g = x_l, x_g | |||
x_l, x_g = self.conv1((x_l, x_g)) | |||
x_l, x_g = self.conv2((x_l, x_g)) | |||
x_l, x_g = id_l + x_l, id_g + x_g | |||
out = x_l, x_g | |||
if self.inline: | |||
out = torch.cat(out, dim=1) | |||
return out | |||
class ConcatTupleLayer(nn.Module): | |||
def forward(self, x): | |||
assert isinstance(x, tuple) | |||
x_l, x_g = x | |||
assert torch.is_tensor(x_l) or torch.is_tensor(x_g) | |||
if not torch.is_tensor(x_g): | |||
return x_l | |||
return torch.cat(x, dim=1) | |||
class FFCResNetGenerator(nn.Module): | |||
def __init__(self, | |||
input_nc=4, | |||
output_nc=3, | |||
ngf=64, | |||
n_downsampling=3, | |||
n_blocks=18, | |||
norm_layer=nn.BatchNorm2d, | |||
padding_type='reflect', | |||
activation_layer=nn.ReLU, | |||
up_norm_layer=nn.BatchNorm2d, | |||
up_activation=nn.ReLU(True), | |||
init_conv_kwargs={ | |||
'ratio_gin': 0, | |||
'ratio_gout': 0, | |||
'enable_lfu': False | |||
}, | |||
downsample_conv_kwargs={ | |||
'ratio_gin': 0, | |||
'ratio_gout': 0, | |||
'enable_lfu': False | |||
}, | |||
resnet_conv_kwargs={ | |||
'ratio_gin': 0.75, | |||
'ratio_gout': 0.75, | |||
'enable_lfu': False | |||
}, | |||
spatial_transform_layers=None, | |||
spatial_transform_kwargs={}, | |||
add_out_act='sigmoid', | |||
max_features=1024, | |||
out_ffc=False, | |||
out_ffc_kwargs={}): | |||
super().__init__() | |||
assert n_blocks >= 0 | |||
model = [ | |||
nn.ReflectionPad2d(3), | |||
FFC_BN_ACT( | |||
input_nc, | |||
ngf, | |||
kernel_size=7, | |||
padding=0, | |||
norm_layer=norm_layer, | |||
activation_layer=activation_layer, | |||
**init_conv_kwargs) | |||
] | |||
# downsample | |||
for i in range(n_downsampling): | |||
mult = 2**i | |||
if i == n_downsampling - 1: | |||
cur_conv_kwargs = dict(downsample_conv_kwargs) | |||
cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get( | |||
'ratio_gin', 0) | |||
else: | |||
cur_conv_kwargs = downsample_conv_kwargs | |||
model += [ | |||
FFC_BN_ACT( | |||
min(max_features, ngf * mult), | |||
min(max_features, ngf * mult * 2), | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
norm_layer=norm_layer, | |||
activation_layer=activation_layer, | |||
**cur_conv_kwargs) | |||
] | |||
mult = 2**n_downsampling | |||
feats_num_bottleneck = min(max_features, ngf * mult) | |||
# resnet blocks | |||
for i in range(n_blocks): | |||
cur_resblock = FFCResnetBlock( | |||
feats_num_bottleneck, | |||
padding_type=padding_type, | |||
activation_layer=activation_layer, | |||
norm_layer=norm_layer, | |||
**resnet_conv_kwargs) | |||
if spatial_transform_layers is not None and i in spatial_transform_layers: | |||
cur_resblock = LearnableSpatialTransformWrapper( | |||
cur_resblock, **spatial_transform_kwargs) | |||
model += [cur_resblock] | |||
model += [ConcatTupleLayer()] | |||
# upsample | |||
for i in range(n_downsampling): | |||
mult = 2**(n_downsampling - i) | |||
model += [ | |||
nn.ConvTranspose2d( | |||
min(max_features, ngf * mult), | |||
min(max_features, int(ngf * mult / 2)), | |||
kernel_size=3, | |||
stride=2, | |||
padding=1, | |||
output_padding=1), | |||
up_norm_layer(min(max_features, int(ngf * mult / 2))), | |||
up_activation | |||
] | |||
if out_ffc: | |||
model += [ | |||
FFCResnetBlock( | |||
ngf, | |||
padding_type=padding_type, | |||
activation_layer=activation_layer, | |||
norm_layer=norm_layer, | |||
inline=True, | |||
**out_ffc_kwargs) | |||
] | |||
model += [ | |||
nn.ReflectionPad2d(3), | |||
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0) | |||
] | |||
if add_out_act: | |||
model.append( | |||
get_activation('tanh' if add_out_act is True else add_out_act)) | |||
self.model = nn.Sequential(*model) | |||
def forward(self, input): | |||
return self.model(input) |
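A hedged end-to-end sketch of the generator's I/O contract: the default input_nc=4 is the masked RGB image concatenated with its mask, built the same way as in the refinement code below; sizes are illustrative.

```python
import torch

net = FFCResNetGenerator()  # defaults: input_nc=4 (RGB + mask), 18 FFC resnet blocks
image = torch.rand(1, 3, 256, 256)  # values in [0, 1]
mask = (torch.rand(1, 1, 256, 256) > 0.5).float()
masked = torch.cat([image * (1 - mask), mask], dim=1)  # (1, 4, 256, 256)
with torch.no_grad():
    out = net(masked)  # (1, 3, 256, 256), squashed to [0, 1] by the sigmoid head
```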
@@ -0,0 +1,324 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torchvision import models | |||
from modelscope.utils.logger import get_logger | |||
try: | |||
from torchvision.models.utils import load_state_dict_from_url | |||
except ImportError: | |||
from torch.utils.model_zoo import load_url as load_state_dict_from_url | |||
# Inception weights ported to Pytorch from | |||
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz | |||
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/' \ | |||
'fid_weights/pt_inception-2015-12-05-6726825d.pth' | |||
LOGGER = get_logger() | |||
class InceptionV3(nn.Module): | |||
"""Pretrained InceptionV3 network returning feature maps""" | |||
# Index of default block of inception to return, | |||
# corresponds to output of final average pooling | |||
DEFAULT_BLOCK_INDEX = 3 | |||
# Maps feature dimensionality to their output blocks indices | |||
BLOCK_INDEX_BY_DIM = { | |||
64: 0, # First max pooling features | |||
192: 1, # Second max pooling features | |||
768: 2, # Pre-aux classifier features | |||
2048: 3 # Final average pooling features | |||
} | |||
def __init__(self, | |||
output_blocks=[DEFAULT_BLOCK_INDEX], | |||
resize_input=True, | |||
normalize_input=True, | |||
requires_grad=False, | |||
use_fid_inception=True): | |||
"""Build pretrained InceptionV3 | |||
Parameters | |||
---------- | |||
output_blocks : list of int | |||
Indices of blocks to return features of. Possible values are: | |||
- 0: corresponds to output of first max pooling | |||
- 1: corresponds to output of second max pooling | |||
- 2: corresponds to output which is fed to aux classifier | |||
- 3: corresponds to output of final average pooling | |||
resize_input : bool | |||
If true, bilinearly resizes input to width and height 299 before | |||
feeding input to model. As the network without fully connected | |||
layers is fully convolutional, it should be able to handle inputs | |||
of arbitrary size, so resizing might not be strictly needed | |||
normalize_input : bool | |||
If true, scales the input from range (0, 1) to the range the | |||
pretrained Inception network expects, namely (-1, 1) | |||
requires_grad : bool | |||
If true, parameters of the model require gradients. Possibly useful | |||
for finetuning the network | |||
use_fid_inception : bool | |||
If true, uses the pretrained Inception model used in Tensorflow's | |||
FID implementation. If false, uses the pretrained Inception model | |||
available in torchvision. The FID Inception model has different | |||
weights and a slightly different structure from torchvision's | |||
Inception model. If you want to compute FID scores, you are | |||
strongly advised to set this parameter to true to get comparable | |||
results. | |||
""" | |||
super(InceptionV3, self).__init__() | |||
self.resize_input = resize_input | |||
self.normalize_input = normalize_input | |||
self.output_blocks = sorted(output_blocks) | |||
self.last_needed_block = max(output_blocks) | |||
assert self.last_needed_block <= 3, \ | |||
'Last possible output block index is 3' | |||
self.blocks = nn.ModuleList() | |||
if use_fid_inception: | |||
inception = fid_inception_v3() | |||
else: | |||
inception = models.inception_v3(pretrained=True) | |||
# Block 0: input to maxpool1 | |||
block0 = [ | |||
inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, | |||
inception.Conv2d_2b_3x3, | |||
nn.MaxPool2d(kernel_size=3, stride=2) | |||
] | |||
self.blocks.append(nn.Sequential(*block0)) | |||
# Block 1: maxpool1 to maxpool2 | |||
if self.last_needed_block >= 1: | |||
block1 = [ | |||
inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, | |||
nn.MaxPool2d(kernel_size=3, stride=2) | |||
] | |||
self.blocks.append(nn.Sequential(*block1)) | |||
# Block 2: maxpool2 to aux classifier | |||
if self.last_needed_block >= 2: | |||
block2 = [ | |||
inception.Mixed_5b, | |||
inception.Mixed_5c, | |||
inception.Mixed_5d, | |||
inception.Mixed_6a, | |||
inception.Mixed_6b, | |||
inception.Mixed_6c, | |||
inception.Mixed_6d, | |||
inception.Mixed_6e, | |||
] | |||
self.blocks.append(nn.Sequential(*block2)) | |||
# Block 3: aux classifier to final avgpool | |||
if self.last_needed_block >= 3: | |||
block3 = [ | |||
inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, | |||
nn.AdaptiveAvgPool2d(output_size=(1, 1)) | |||
] | |||
self.blocks.append(nn.Sequential(*block3)) | |||
for param in self.parameters(): | |||
param.requires_grad = requires_grad | |||
def forward(self, inp): | |||
"""Get Inception feature maps | |||
Parameters | |||
---------- | |||
inp : torch.Tensor | |||
Input tensor of shape Bx3xHxW. Values are expected to be in | |||
range (0, 1) | |||
Returns | |||
------- | |||
List of torch.Tensor, corresponding to the selected output | |||
blocks, sorted ascending by index | |||
""" | |||
outp = [] | |||
x = inp | |||
if self.resize_input: | |||
x = F.interpolate( | |||
x, size=(299, 299), mode='bilinear', align_corners=False) | |||
if self.normalize_input: | |||
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) | |||
for idx, block in enumerate(self.blocks): | |||
x = block(x) | |||
if idx in self.output_blocks: | |||
outp.append(x) | |||
if idx == self.last_needed_block: | |||
break | |||
return outp | |||
def fid_inception_v3(): | |||
"""Build pretrained Inception model for FID computation | |||
The Inception model for FID computation uses a different set of weights | |||
and has a slightly different structure than torchvision's Inception. | |||
This method first constructs torchvision's Inception and then patches the | |||
necessary parts that are different in the FID Inception model. | |||
""" | |||
LOGGER.info('fid_inception_v3 called') | |||
inception = models.inception_v3( | |||
num_classes=1008, aux_logits=False, pretrained=False) | |||
LOGGER.info('models.inception_v3 done') | |||
inception.Mixed_5b = FIDInceptionA(192, pool_features=32) | |||
inception.Mixed_5c = FIDInceptionA(256, pool_features=64) | |||
inception.Mixed_5d = FIDInceptionA(288, pool_features=64) | |||
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) | |||
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) | |||
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) | |||
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) | |||
inception.Mixed_7b = FIDInceptionE_1(1280) | |||
inception.Mixed_7c = FIDInceptionE_2(2048) | |||
LOGGER.info('fid_inception_v3 patching done') | |||
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) | |||
LOGGER.info('fid_inception_v3 weights downloaded') | |||
inception.load_state_dict(state_dict) | |||
LOGGER.info('fid_inception_v3 weights loaded into model') | |||
return inception | |||
class FIDInceptionA(models.inception.InceptionA): | |||
"""InceptionA block patched for FID computation""" | |||
def __init__(self, in_channels, pool_features): | |||
super(FIDInceptionA, self).__init__(in_channels, pool_features) | |||
def forward(self, x): | |||
branch1x1 = self.branch1x1(x) | |||
branch5x5 = self.branch5x5_1(x) | |||
branch5x5 = self.branch5x5_2(branch5x5) | |||
branch3x3dbl = self.branch3x3dbl_1(x) | |||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) | |||
# Patch: TensorFlow's average pool does not use the padded zeros in | |||
# its average calculation | |||
branch_pool = F.avg_pool2d( | |||
x, kernel_size=3, stride=1, padding=1, count_include_pad=False) | |||
branch_pool = self.branch_pool(branch_pool) | |||
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] | |||
return torch.cat(outputs, 1) | |||
class FIDInceptionC(models.inception.InceptionC): | |||
"""InceptionC block patched for FID computation""" | |||
def __init__(self, in_channels, channels_7x7): | |||
super(FIDInceptionC, self).__init__(in_channels, channels_7x7) | |||
def forward(self, x): | |||
branch1x1 = self.branch1x1(x) | |||
branch7x7 = self.branch7x7_1(x) | |||
branch7x7 = self.branch7x7_2(branch7x7) | |||
branch7x7 = self.branch7x7_3(branch7x7) | |||
branch7x7dbl = self.branch7x7dbl_1(x) | |||
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) | |||
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) | |||
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) | |||
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) | |||
# Patch: TensorFlow's average pool does not use the padded zeros in | |||
# its average calculation | |||
branch_pool = F.avg_pool2d( | |||
x, kernel_size=3, stride=1, padding=1, count_include_pad=False) | |||
branch_pool = self.branch_pool(branch_pool) | |||
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] | |||
return torch.cat(outputs, 1) | |||
class FIDInceptionE_1(models.inception.InceptionE): | |||
"""First InceptionE block patched for FID computation""" | |||
def __init__(self, in_channels): | |||
super(FIDInceptionE_1, self).__init__(in_channels) | |||
def forward(self, x): | |||
branch1x1 = self.branch1x1(x) | |||
branch3x3 = self.branch3x3_1(x) | |||
branch3x3 = [ | |||
self.branch3x3_2a(branch3x3), | |||
self.branch3x3_2b(branch3x3), | |||
] | |||
branch3x3 = torch.cat(branch3x3, 1) | |||
branch3x3dbl = self.branch3x3dbl_1(x) | |||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||
branch3x3dbl = [ | |||
self.branch3x3dbl_3a(branch3x3dbl), | |||
self.branch3x3dbl_3b(branch3x3dbl), | |||
] | |||
branch3x3dbl = torch.cat(branch3x3dbl, 1) | |||
# Patch: TensorFlow's average pool does not use the padded zeros in | |||
# its average calculation | |||
branch_pool = F.avg_pool2d( | |||
x, kernel_size=3, stride=1, padding=1, count_include_pad=False) | |||
branch_pool = self.branch_pool(branch_pool) | |||
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] | |||
return torch.cat(outputs, 1) | |||
class FIDInceptionE_2(models.inception.InceptionE): | |||
"""Second InceptionE block patched for FID computation""" | |||
def __init__(self, in_channels): | |||
super(FIDInceptionE_2, self).__init__(in_channels) | |||
def forward(self, x): | |||
branch1x1 = self.branch1x1(x) | |||
branch3x3 = self.branch3x3_1(x) | |||
branch3x3 = [ | |||
self.branch3x3_2a(branch3x3), | |||
self.branch3x3_2b(branch3x3), | |||
] | |||
branch3x3 = torch.cat(branch3x3, 1) | |||
branch3x3dbl = self.branch3x3dbl_1(x) | |||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||
branch3x3dbl = [ | |||
self.branch3x3dbl_3a(branch3x3dbl), | |||
self.branch3x3dbl_3b(branch3x3dbl), | |||
] | |||
branch3x3dbl = torch.cat(branch3x3dbl, 1) | |||
# Patch: The FID Inception model uses max pooling instead of average | |||
# pooling. This is likely an error in this specific Inception | |||
# implementation, as other Inception models use average pooling here | |||
# (which matches the description in the paper). | |||
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) | |||
branch_pool = self.branch_pool(branch_pool) | |||
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] | |||
return torch.cat(outputs, 1) |
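For context, the FID computed by the new image-inpainting metric presumably compares Gaussian statistics of the 2048-d pool features from this backbone, FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2(Sigma_r Sigma_g)^(1/2)), which is where the scipy.linalg import in the metric file comes in. A sketch of the feature-extraction step (constructing the model downloads the ported weights on first use):

```python
import numpy as np
import torch

model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]]).eval()
imgs = torch.rand(8, 3, 299, 299)  # values in (0, 1); resize/rescale handled inside
with torch.no_grad():
    feats = model(imgs)[0]  # (8, 2048, 1, 1) final-average-pool features
acts = feats.squeeze(-1).squeeze(-1).numpy()
mu, sigma = acts.mean(axis=0), np.cov(acts, rowvar=False)  # inputs to the FID formula
```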
@@ -0,0 +1,47 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torchvision | |||
from .ade20k import ModelBuilder | |||
IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] | |||
IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] | |||
class ResNetPL(nn.Module): | |||
def __init__(self, | |||
weight=1, | |||
weights_path=None, | |||
arch_encoder='resnet50dilated', | |||
segmentation=True): | |||
super().__init__() | |||
self.impl = ModelBuilder.get_encoder( | |||
weights_path=weights_path, | |||
arch_encoder=arch_encoder, | |||
arch_decoder='ppm_deepsup', | |||
fc_dim=2048, | |||
segmentation=segmentation) | |||
self.impl.eval() | |||
for w in self.impl.parameters(): | |||
w.requires_grad_(False) | |||
self.weight = weight | |||
def forward(self, pred, target): | |||
pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred) | |||
target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target) | |||
pred_feats = self.impl(pred, return_feature_maps=True) | |||
target_feats = self.impl(target, return_feature_maps=True) | |||
result = torch.stack([ | |||
F.mse_loss(cur_pred, cur_target) | |||
for cur_pred, cur_target in zip(pred_feats, target_feats) | |||
]).sum() * self.weight | |||
return result |
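A hedged usage sketch of the perceptual loss. The weights_path below is a placeholder for wherever the ADE20k-pretrained encoder checkpoint lives, so this only runs where that checkpoint exists; weight simply scales the summed per-stage MSE.

```python
import torch

loss_fn = ResNetPL(weights_path='/path/to/ade20k')  # placeholder checkpoint location
pred = torch.rand(2, 3, 256, 256)    # values in [0, 1]; ImageNet-normalized inside
target = torch.rand(2, 3, 256, 256)
loss = loss_fn(pred, target)         # scalar: weight * sum of per-stage MSE losses
```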
@@ -0,0 +1,75 @@ | |||
""" | |||
The implementation is adopted from | |||
https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py | |||
""" | |||
import collections | |||
import functools | |||
import logging | |||
from collections import defaultdict | |||
from functools import partial | |||
import numpy as np | |||
import torch.nn as nn | |||
# Defines the PatchGAN discriminator with the specified arguments. | |||
class NLayerDiscriminator(nn.Module): | |||
def __init__( | |||
self, | |||
input_nc=3, | |||
ndf=64, | |||
n_layers=4, | |||
norm_layer=nn.BatchNorm2d, | |||
): | |||
super().__init__() | |||
self.n_layers = n_layers | |||
kw = 4 | |||
padw = int(np.ceil((kw - 1.0) / 2)) | |||
sequence = [[ | |||
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), | |||
nn.LeakyReLU(0.2, True) | |||
]] | |||
nf = ndf | |||
for n in range(1, n_layers): | |||
nf_prev = nf | |||
nf = min(nf * 2, 512) | |||
cur_model = [] | |||
cur_model += [ | |||
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), | |||
norm_layer(nf), | |||
nn.LeakyReLU(0.2, True) | |||
] | |||
sequence.append(cur_model) | |||
nf_prev = nf | |||
nf = min(nf * 2, 512) | |||
cur_model = [] | |||
cur_model += [ | |||
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), | |||
norm_layer(nf), | |||
nn.LeakyReLU(0.2, True) | |||
] | |||
sequence.append(cur_model) | |||
sequence += [[ | |||
nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw) | |||
]] | |||
for n in range(len(sequence)): | |||
setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) | |||
def get_all_activations(self, x): | |||
res = [x] | |||
for n in range(self.n_layers + 2): | |||
model = getattr(self, 'model' + str(n)) | |||
res.append(model(res[-1])) | |||
return res[1:] | |||
def forward(self, x): | |||
act = self.get_all_activations(x) | |||
return act[-1], act[:-1] |
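The discriminator exposes every intermediate activation alongside the patch logits, which adversarial setups like LaMa typically reuse for a feature-matching loss. A minimal sketch with the defaults:

```python
import torch

disc = NLayerDiscriminator()  # defaults: 3-channel input, ndf=64, 4 conv stages
x = torch.rand(2, 3, 256, 256)
logits, feats = disc(x)  # logits: (2, 1, h', w') patch map; feats: list of 5 maps
print(logits.shape, [f.shape[1] for f in feats])
```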
@@ -0,0 +1,393 @@ | |||
''' | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
''' | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from kornia.filters import gaussian_blur2d | |||
from kornia.geometry.transform import resize | |||
from kornia.morphology import erosion | |||
from torch.nn import functional as F | |||
from torch.optim import SGD, Adam | |||
from tqdm import tqdm | |||
from .modules.ffc import FFCResnetBlock | |||
def move_to_device(obj, device): | |||
if isinstance(obj, nn.Module): | |||
return obj.to(device) | |||
if torch.is_tensor(obj): | |||
return obj.to(device) | |||
if isinstance(obj, (tuple, list)): | |||
return [move_to_device(el, device) for el in obj] | |||
if isinstance(obj, dict): | |||
return {name: move_to_device(val, device) for name, val in obj.items()} | |||
raise ValueError(f'Unexpected type {type(obj)}') | |||
def ceil_modulo(x, mod): | |||
if x % mod == 0: | |||
return x | |||
return (x // mod + 1) * mod | |||
def pad_tensor_to_modulo(img, mod): | |||
batch_size, channels, height, width = img.shape | |||
out_height = ceil_modulo(height, mod) | |||
out_width = ceil_modulo(width, mod) | |||
return F.pad( | |||
img, | |||
pad=(0, out_width - width, 0, out_height - height), | |||
mode='reflect') | |||
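A quick check of the two padding helpers above, assuming a (B, C, H, W) tensor:

```python
import torch

img = torch.rand(1, 3, 250, 333)
padded = pad_tensor_to_modulo(img, mod=8)
# both spatial dims are reflect-padded up to the next multiple of 8
assert padded.shape[2:] == (256, 336)  # ceil_modulo(250, 8), ceil_modulo(333, 8)
```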
def _pyrdown(im: torch.Tensor, downsize: tuple = None): | |||
"""downscale the image""" | |||
if downsize is None: | |||
downsize = (im.shape[2] // 2, im.shape[3] // 2) | |||
assert im.shape[1] == 3, 'Expected shape for the input to be (n,3,height,width)' | |||
im = gaussian_blur2d(im, kernel_size=(5, 5), sigma=(1.0, 1.0)) | |||
im = F.interpolate(im, size=downsize, mode='bilinear', align_corners=False) | |||
return im | |||
def _pyrdown_mask(mask: torch.Tensor, | |||
downsize: tuple = None, | |||
eps: float = 1e-8, | |||
blur_mask: bool = True, | |||
round_up: bool = True): | |||
"""downscale the mask tensor | |||
Parameters | |||
---------- | |||
mask : torch.Tensor | |||
mask of size (B, 1, H, W) | |||
downsize : tuple, optional | |||
size to downscale to. If None, image is downscaled to half, by default None | |||
eps : float, optional | |||
threshold value for binarizing the mask, by default 1e-8 | |||
blur_mask : bool, optional | |||
if True, apply gaussian filter before downscaling, by default True | |||
round_up : bool, optional | |||
if True, values >= eps are set to 1; otherwise only values >= 1 - eps are set to 1, by default True | |||
Returns | |||
------- | |||
torch.Tensor | |||
downscaled mask | |||
""" | |||
if downsize is None: | |||
downsize = (mask.shape[2] // 2, mask.shape[3] // 2) | |||
assert mask.shape[1] == 1, 'Expected shape for the input to be (n,1,height,width)' | |||
if blur_mask is True: | |||
mask = gaussian_blur2d(mask, kernel_size=(5, 5), sigma=(1.0, 1.0)) | |||
mask = F.interpolate( | |||
mask, size=downsize, mode='bilinear', align_corners=False) | |||
else: | |||
mask = F.interpolate( | |||
mask, size=downsize, mode='bilinear', align_corners=False) | |||
if round_up: | |||
mask[mask >= eps] = 1 | |||
mask[mask < eps] = 0 | |||
else: | |||
mask[mask >= 1.0 - eps] = 1 | |||
mask[mask < 1.0 - eps] = 0 | |||
return mask | |||
def _erode_mask(mask: torch.Tensor, | |||
ekernel: torch.Tensor = None, | |||
eps: float = 1e-8): | |||
"""erode the mask, and set gray pixels to 0""" | |||
if ekernel is not None: | |||
mask = erosion(mask, ekernel) | |||
mask[mask >= 1.0 - eps] = 1 | |||
mask[mask < 1.0 - eps] = 0 | |||
return mask | |||
def _l1_loss(pred: torch.Tensor, | |||
pred_downscaled: torch.Tensor, | |||
ref: torch.Tensor, | |||
mask: torch.Tensor, | |||
mask_downscaled: torch.Tensor, | |||
image: torch.Tensor, | |||
on_pred: bool = True): | |||
"""l1 loss on src pixels, and downscaled predictions if on_pred=True""" | |||
loss = torch.mean(torch.abs(pred[mask < 1e-8] - image[mask < 1e-8])) | |||
if on_pred: | |||
loss += torch.mean( | |||
torch.abs(pred_downscaled[mask_downscaled >= 1e-8] | |||
- ref[mask_downscaled >= 1e-8])) | |||
return loss | |||
def _infer(image: torch.Tensor, | |||
mask: torch.Tensor, | |||
forward_front: nn.Module, | |||
forward_rears: nn.Module, | |||
ref_lower_res: torch.Tensor, | |||
orig_shape: tuple, | |||
devices: list, | |||
scale_ind: int, | |||
n_iters: int = 15, | |||
lr: float = 0.002): | |||
"""Performs inference with refinement at a given scale. | |||
Parameters | |||
---------- | |||
image : torch.Tensor | |||
input image to be inpainted, of size (1,3,H,W) | |||
mask : torch.Tensor | |||
input inpainting mask, of size (1,1,H,W) | |||
forward_front : nn.Module | |||
the front part of the inpainting network | |||
forward_rears : nn.Module | |||
the rear part of the inpainting network | |||
ref_lower_res : torch.Tensor | |||
the inpainting at previous scale, used as reference image | |||
orig_shape : tuple | |||
shape of the original input image before padding | |||
devices : list | |||
list of available devices | |||
scale_ind : int | |||
the scale index | |||
n_iters : int, optional | |||
number of iterations of refinement, by default 15 | |||
lr : float, optional | |||
learning rate, by default 0.002 | |||
Returns | |||
------- | |||
torch.Tensor | |||
inpainted image | |||
""" | |||
masked_image = image * (1 - mask) | |||
masked_image = torch.cat([masked_image, mask], dim=1) | |||
mask = mask.repeat(1, 3, 1, 1) | |||
if ref_lower_res is not None: | |||
ref_lower_res = ref_lower_res.detach() | |||
with torch.no_grad(): | |||
z1, z2 = forward_front(masked_image) | |||
# Inference | |||
mask = mask.to(devices[-1]) | |||
ekernel = torch.from_numpy( | |||
cv2.getStructuringElement(cv2.MORPH_ELLIPSE, | |||
(15, 15)).astype(bool)).float() | |||
ekernel = ekernel.to(devices[-1]) | |||
image = image.to(devices[-1]) | |||
z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0]) | |||
z1.requires_grad, z2.requires_grad = True, True | |||
optimizer = Adam([z1, z2], lr=lr) | |||
pbar = tqdm(range(n_iters), leave=False) | |||
for idi in pbar: | |||
optimizer.zero_grad() | |||
input_feat = (z1, z2) | |||
for idd, forward_rear in enumerate(forward_rears): | |||
output_feat = forward_rear(input_feat) | |||
if idd < len(devices) - 1: | |||
midz1, midz2 = output_feat | |||
midz1, midz2 = midz1.to(devices[idd + 1]), midz2.to( | |||
devices[idd + 1]) | |||
input_feat = (midz1, midz2) | |||
else: | |||
pred = output_feat | |||
if ref_lower_res is None: | |||
break | |||
losses = {} | |||
# scaled loss with downsampler | |||
pred_downscaled = _pyrdown(pred[:, :, :orig_shape[0], :orig_shape[1]]) | |||
mask_downscaled = _pyrdown_mask( | |||
mask[:, :1, :orig_shape[0], :orig_shape[1]], | |||
blur_mask=False, | |||
round_up=False) | |||
mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel) | |||
mask_downscaled = mask_downscaled.repeat(1, 3, 1, 1) | |||
losses['ms_l1'] = _l1_loss( | |||
pred, | |||
pred_downscaled, | |||
ref_lower_res, | |||
mask, | |||
mask_downscaled, | |||
image, | |||
on_pred=True) | |||
loss = sum(losses.values()) | |||
pbar.set_description( | |||
'Refining scale {} using scale {} ...current loss: {:.4f}'.format( | |||
scale_ind + 1, scale_ind, loss.item())) | |||
if idi < n_iters - 1: | |||
loss.backward() | |||
optimizer.step() | |||
del pred_downscaled | |||
del loss | |||
del pred | |||
# "pred" is the prediction after Plug-n-Play module | |||
inpainted = mask * pred + (1 - mask) * image | |||
inpainted = inpainted.detach().cpu() | |||
return inpainted | |||
def _get_image_mask_pyramid(batch: dict, min_side: int, max_scales: int, | |||
px_budget: int): | |||
"""Build the image mask pyramid | |||
Parameters | |||
---------- | |||
batch : dict | |||
batch containing image, mask, etc | |||
min_side : int | |||
minimum side length to limit the number of scales of the pyramid | |||
max_scales : int | |||
maximum number of scales allowed | |||
px_budget : int | |||
the product H*W cannot exceed this budget, because of resource constraints | |||
Returns | |||
------- | |||
tuple | |||
image-mask pyramid in the form of list of images and list of masks | |||
""" | |||
assert batch['image'].shape[0] == 1, 'refiner works only on batches of size 1!' | |||
h, w = batch['unpad_to_size'] | |||
h, w = h[0].item(), w[0].item() | |||
image = batch['image'][..., :h, :w] | |||
mask = batch['mask'][..., :h, :w] | |||
if h * w > px_budget: | |||
# resize | |||
ratio = np.sqrt(px_budget / float(h * w)) | |||
h_orig, w_orig = h, w | |||
h, w = int(h * ratio), int(w * ratio) | |||
print( | |||
f'Original image too large for refinement! Resizing {(h_orig, w_orig)} to {(h, w)}...' | |||
) | |||
image = resize( | |||
image, (h, w), interpolation='bilinear', align_corners=False) | |||
mask = resize( | |||
mask, (h, w), interpolation='bilinear', align_corners=False) | |||
mask[mask > 1e-8] = 1 | |||
breadth = min(h, w) | |||
n_scales = min(1 + int(round(max(0, np.log2(breadth / min_side)))), | |||
max_scales) | |||
ls_images = [] | |||
ls_masks = [] | |||
ls_images.append(image) | |||
ls_masks.append(mask) | |||
for _ in range(n_scales - 1): | |||
image_p = _pyrdown(ls_images[-1]) | |||
mask_p = _pyrdown_mask(ls_masks[-1]) | |||
ls_images.append(image_p) | |||
ls_masks.append(mask_p) | |||
# reverse the lists because we want the lowest resolution image as index 0 | |||
return ls_images[::-1], ls_masks[::-1] | |||
def refine_predict(batch: dict, inpainter: nn.Module, gpu_ids: str, | |||
modulo: int, n_iters: int, lr: float, min_side: int, | |||
max_scales: int, px_budget: int): | |||
"""Refines the inpainting of the network | |||
Parameters | |||
---------- | |||
batch : dict | |||
image-mask batch, currently we assume the batchsize to be 1 | |||
inpainter : nn.Module | |||
the inpainting neural network | |||
gpu_ids : str | |||
the GPU ids of the machine to use. If only single GPU, use: "0," | |||
modulo : int | |||
pad the image to ensure dimension % modulo == 0 | |||
n_iters : int | |||
number of iterations of refinement for each scale | |||
lr : float | |||
learning rate | |||
min_side : int | |||
all sides of image on all scales should be >= min_side / sqrt(2) | |||
max_scales : int | |||
max number of downscaling scales for the image-mask pyramid | |||
px_budget : int | |||
pixels budget. Any image will be resized to satisfy height*width <= px_budget | |||
Returns | |||
------- | |||
torch.Tensor | |||
inpainted image of size (1,3,H,W) | |||
""" | |||
inpainter = inpainter.model | |||
assert not inpainter.training | |||
assert not inpainter.add_noise_kwargs | |||
assert inpainter.concat_mask | |||
gpu_ids = [ | |||
f'cuda:{gpuid}' for gpuid in gpu_ids.replace(' ', '').split(',') | |||
if gpuid.isdigit() | |||
] | |||
n_resnet_blocks = 0 | |||
first_resblock_ind = 0 | |||
found_first_resblock = False | |||
for idl in range(len(inpainter.generator.model)): | |||
if isinstance(inpainter.generator.model[idl], FFCResnetBlock): | |||
n_resnet_blocks += 1 | |||
found_first_resblock = True | |||
elif not found_first_resblock: | |||
first_resblock_ind += 1 | |||
resblocks_per_gpu = n_resnet_blocks // len(gpu_ids) | |||
devices = [torch.device(gpu_id) for gpu_id in gpu_ids] | |||
# split the model into front, and rear parts | |||
forward_front = inpainter.generator.model[0:first_resblock_ind] | |||
forward_front.to(devices[0]) | |||
forward_rears = [] | |||
for idd in range(len(gpu_ids)): | |||
if idd < len(gpu_ids) - 1: | |||
forward_rears.append( | |||
inpainter.generator.model[first_resblock_ind | |||
+ resblocks_per_gpu | |||
* (idd):first_resblock_ind | |||
+ resblocks_per_gpu * (idd + 1)]) | |||
else: | |||
forward_rears.append( | |||
inpainter.generator.model[first_resblock_ind | |||
+ resblocks_per_gpu * (idd):]) | |||
forward_rears[idd].to(devices[idd]) | |||
ls_images, ls_masks = _get_image_mask_pyramid(batch, min_side, max_scales, | |||
px_budget) | |||
image_inpainted = None | |||
for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)): | |||
orig_shape = image.shape[2:] | |||
image = pad_tensor_to_modulo(image, modulo) | |||
mask = pad_tensor_to_modulo(mask, modulo) | |||
mask[mask >= 1e-8] = 1.0 | |||
mask[mask < 1e-8] = 0.0 | |||
image, mask = move_to_device(image, devices[0]), move_to_device( | |||
mask, devices[0]) | |||
if image_inpainted is not None: | |||
image_inpainted = move_to_device(image_inpainted, devices[-1]) | |||
image_inpainted = _infer(image, mask, forward_front, forward_rears, | |||
image_inpainted, orig_shape, devices, ids, | |||
n_iters, lr) | |||
image_inpainted = image_inpainted[:, :, :orig_shape[0], :orig_shape[1]] | |||
# detach everything to save resources | |||
image = image.detach().cpu() | |||
mask = mask.detach().cpu() | |||
return image_inpainted |
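One subtlety worth flagging for reviewers: gpu_ids is parsed by splitting on commas and keeping only digit tokens, which is why the docstring recommends passing '0,' for a single GPU. An illustration of just that parsing step:

```python
gpu_ids = '0, 1'
devices = [
    f'cuda:{g}' for g in gpu_ids.replace(' ', '').split(',') if g.isdigit()
]
assert devices == ['cuda:0', 'cuda:1']  # '0,' alone would yield ['cuda:0']
```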
@@ -11,6 +11,7 @@ if TYPE_CHECKING: | |||
from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset | |||
from .movie_scene_segmentation import MovieSceneSegmentationDataset | |||
from .video_summarization_dataset import VideoSummarizationDataset | |||
from .image_inpainting import ImageInpaintingDataset | |||
from .passage_ranking_dataset import PassageRankingDataset | |||
else: | |||
@@ -24,6 +25,7 @@ else: | |||
['ImageInstanceSegmentationCocoDataset'], | |||
'video_summarization_dataset': ['VideoSummarizationDataset'], | |||
'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], | |||
'image_inpainting': ['ImageInpaintingDataset'], | |||
} | |||
import sys | |||
@@ -0,0 +1,2 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .image_inpainting_dataset import ImageInpaintingDataset |
@@ -0,0 +1,100 @@ | |||
""" | |||
The implementation is borrowed from LaMa, | |||
publicly available at https://github.com/saic-mdal/lama | |||
""" | |||
import imgaug.augmenters as iaa | |||
from albumentations import DualIAATransform, to_tuple | |||
class IAAAffine2(DualIAATransform): | |||
"""Place a regular grid of points on the input and randomly move the neighbourhood of these point around | |||
via affine transformations. | |||
Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} | |||
Args: | |||
p (float): probability of applying the transform. Default: 0.5. | |||
Targets: | |||
image, mask | |||
""" | |||
def __init__( | |||
self, | |||
scale=(0.7, 1.3), | |||
translate_percent=None, | |||
translate_px=None, | |||
rotate=0.0, | |||
shear=(-0.1, 0.1), | |||
order=1, | |||
cval=0, | |||
mode='reflect', | |||
always_apply=False, | |||
p=0.5, | |||
): | |||
super(IAAAffine2, self).__init__(always_apply, p) | |||
self.scale = dict(x=scale, y=scale) | |||
self.translate_percent = to_tuple(translate_percent, 0) | |||
self.translate_px = to_tuple(translate_px, 0) | |||
self.rotate = to_tuple(rotate) | |||
self.shear = dict(x=shear, y=shear) | |||
self.order = order | |||
self.cval = cval | |||
self.mode = mode | |||
@property | |||
def processor(self): | |||
return iaa.Affine( | |||
self.scale, | |||
self.translate_percent, | |||
self.translate_px, | |||
self.rotate, | |||
self.shear, | |||
self.order, | |||
self.cval, | |||
self.mode, | |||
) | |||
def get_transform_init_args_names(self): | |||
return ('scale', 'translate_percent', 'translate_px', 'rotate', | |||
'shear', 'order', 'cval', 'mode') | |||
class IAAPerspective2(DualIAATransform): | |||
"""Perform a random four point perspective transform of the input. | |||
Note: this class introduces interpolation artifacts to the mask if it has values other than {0, 1} | |||
Args: | |||
scale ((float, float)): standard deviations of the normal distributions. These are used to sample | |||
the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). | |||
p (float): probability of applying the transform. Default: 0.5. | |||
Targets: | |||
image, mask | |||
""" | |||
def __init__(self, | |||
scale=(0.05, 0.1), | |||
keep_size=True, | |||
always_apply=False, | |||
p=0.5, | |||
order=1, | |||
cval=0, | |||
mode='replicate'): | |||
super(IAAPerspective2, self).__init__(always_apply, p) | |||
self.scale = to_tuple(scale, 1.0) | |||
self.keep_size = keep_size | |||
self.cval = cval | |||
self.mode = mode | |||
@property | |||
def processor(self): | |||
return iaa.PerspectiveTransform( | |||
self.scale, | |||
keep_size=self.keep_size, | |||
mode=self.mode, | |||
cval=self.cval) | |||
def get_transform_init_args_names(self): | |||
return ('scale', 'keep_size') |
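An illustrative composition of the two wrappers, using the same parameters as get_transforms in the dataset file below; this assumes an albumentations/imgaug pairing that still ships DualIAATransform, as the import above requires:

```python
import albumentations as A
import numpy as np

aug = A.Compose([
    IAAPerspective2(scale=(0.0, 0.06)),
    IAAAffine2(scale=(0.7, 1.3), rotate=(-40, 40), shear=(-0.1, 0.1)),
])
img = np.zeros((256, 256, 3), dtype=np.uint8)
out = aug(image=img)['image']  # spatial size is preserved (keep_size=True)
```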
@@ -0,0 +1,337 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, | |||
publicly available at https://github.com/saic-mdal/lama | |||
""" | |||
import glob | |||
import os | |||
import os.path as osp | |||
import random  # needed by MixedMaskGenerator for mask inversion below | |||
from enum import Enum | |||
import albumentations as A | |||
import cv2 | |||
import json | |||
import numpy as np | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||
TorchTaskDataset | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .aug import IAAAffine2, IAAPerspective2 | |||
LOGGER = get_logger() | |||
class LinearRamp: | |||
def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): | |||
self.start_value = start_value | |||
self.end_value = end_value | |||
self.start_iter = start_iter | |||
self.end_iter = end_iter | |||
def __call__(self, i): | |||
if i < self.start_iter: | |||
return self.start_value | |||
if i >= self.end_iter: | |||
return self.end_value | |||
part = (i - self.start_iter) / (self.end_iter - self.start_iter) | |||
return self.start_value * (1 - part) + self.end_value * part | |||
class DrawMethod(Enum): | |||
LINE = 'line' | |||
CIRCLE = 'circle' | |||
SQUARE = 'square' | |||
def make_random_superres_mask(shape, | |||
min_step=2, | |||
max_step=4, | |||
min_width=1, | |||
max_width=3): | |||
height, width = shape | |||
mask = np.zeros((height, width), np.float32) | |||
step_x = np.random.randint(min_step, max_step + 1) | |||
width_x = np.random.randint(min_width, min(step_x, max_width + 1)) | |||
offset_x = np.random.randint(0, step_x) | |||
step_y = np.random.randint(min_step, max_step + 1) | |||
width_y = np.random.randint(min_width, min(step_y, max_width + 1)) | |||
offset_y = np.random.randint(0, step_y) | |||
for dy in range(width_y): | |||
mask[offset_y + dy::step_y] = 1 | |||
for dx in range(width_x): | |||
mask[:, offset_x + dx::step_x] = 1 | |||
return mask[None, ...] | |||
class RandomSuperresMaskGenerator: | |||
def __init__(self, **kwargs): | |||
self.kwargs = kwargs | |||
def __call__(self, img, iter_i=None): | |||
return make_random_superres_mask(img.shape[1:], **self.kwargs) | |||
def make_random_rectangle_mask(shape, | |||
margin=10, | |||
bbox_min_size=30, | |||
bbox_max_size=100, | |||
min_times=0, | |||
max_times=3): | |||
height, width = shape | |||
mask = np.zeros((height, width), np.float32) | |||
bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2) | |||
times = np.random.randint(min_times, max_times + 1) | |||
for i in range(times): | |||
box_width = np.random.randint(bbox_min_size, bbox_max_size) | |||
box_height = np.random.randint(bbox_min_size, bbox_max_size) | |||
start_x = np.random.randint(margin, width - margin - box_width + 1) | |||
start_y = np.random.randint(margin, height - margin - box_height + 1) | |||
mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1 | |||
return mask[None, ...] | |||
class RandomRectangleMaskGenerator: | |||
def __init__(self, | |||
margin=10, | |||
bbox_min_size=30, | |||
bbox_max_size=100, | |||
min_times=0, | |||
max_times=3, | |||
ramp_kwargs=None): | |||
self.margin = margin | |||
self.bbox_min_size = bbox_min_size | |||
self.bbox_max_size = bbox_max_size | |||
self.min_times = min_times | |||
self.max_times = max_times | |||
self.ramp = LinearRamp( | |||
**ramp_kwargs) if ramp_kwargs is not None else None | |||
def __call__(self, img, iter_i=None, raw_image=None): | |||
coef = self.ramp(iter_i) if (self.ramp is not None) and ( | |||
iter_i is not None) else 1 | |||
cur_bbox_max_size = int(self.bbox_min_size + 1 | |||
+ (self.bbox_max_size - self.bbox_min_size) | |||
* coef) | |||
cur_max_times = int(self.min_times | |||
+ (self.max_times - self.min_times) * coef) | |||
return make_random_rectangle_mask( | |||
img.shape[1:], | |||
margin=self.margin, | |||
bbox_min_size=self.bbox_min_size, | |||
bbox_max_size=cur_bbox_max_size, | |||
min_times=self.min_times, | |||
max_times=cur_max_times) | |||
def make_random_irregular_mask(shape, | |||
max_angle=4, | |||
max_len=60, | |||
max_width=20, | |||
min_times=0, | |||
max_times=10, | |||
draw_method=DrawMethod.LINE): | |||
draw_method = DrawMethod(draw_method) | |||
height, width = shape | |||
mask = np.zeros((height, width), np.float32) | |||
times = np.random.randint(min_times, max_times + 1) | |||
for i in range(times): | |||
start_x = np.random.randint(width) | |||
start_y = np.random.randint(height) | |||
for j in range(1 + np.random.randint(5)): | |||
angle = 0.01 + np.random.randint(max_angle) | |||
if i % 2 == 0: | |||
angle = 2 * 3.1415926 - angle | |||
length = 10 + np.random.randint(max_len) | |||
brush_w = 5 + np.random.randint(max_width) | |||
end_x = np.clip( | |||
(start_x + length * np.sin(angle)).astype(np.int32), 0, width) | |||
end_y = np.clip( | |||
(start_y + length * np.cos(angle)).astype(np.int32), 0, height) | |||
if draw_method == DrawMethod.LINE: | |||
cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, | |||
brush_w) | |||
elif draw_method == DrawMethod.CIRCLE: | |||
cv2.circle( | |||
mask, (start_x, start_y), | |||
radius=brush_w, | |||
color=1., | |||
thickness=-1) | |||
elif draw_method == DrawMethod.SQUARE: | |||
radius = brush_w // 2 | |||
mask[start_y - radius:start_y + radius, | |||
start_x - radius:start_x + radius] = 1 | |||
start_x, start_y = end_x, end_y | |||
return mask[None, ...] | |||
class RandomIrregularMaskGenerator: | |||
def __init__(self, | |||
max_angle=4, | |||
max_len=60, | |||
max_width=20, | |||
min_times=0, | |||
max_times=10, | |||
ramp_kwargs=None, | |||
draw_method=DrawMethod.LINE): | |||
self.max_angle = max_angle | |||
self.max_len = max_len | |||
self.max_width = max_width | |||
self.min_times = min_times | |||
self.max_times = max_times | |||
self.draw_method = draw_method | |||
self.ramp = LinearRamp( | |||
**ramp_kwargs) if ramp_kwargs is not None else None | |||
def __call__(self, img, iter_i=None, raw_image=None): | |||
coef = self.ramp(iter_i) if (self.ramp is not None) and ( | |||
iter_i is not None) else 1 | |||
cur_max_len = int(max(1, self.max_len * coef)) | |||
cur_max_width = int(max(1, self.max_width * coef)) | |||
cur_max_times = int(self.min_times + 1 | |||
+ (self.max_times - self.min_times) * coef) | |||
return make_random_irregular_mask( | |||
img.shape[1:], | |||
max_angle=self.max_angle, | |||
max_len=cur_max_len, | |||
max_width=cur_max_width, | |||
min_times=self.min_times, | |||
max_times=cur_max_times, | |||
draw_method=self.draw_method) | |||
class MixedMaskGenerator: | |||
def __init__(self, | |||
irregular_proba=1 / 3, | |||
irregular_kwargs=None, | |||
box_proba=1 / 3, | |||
box_kwargs=None, | |||
segm_proba=1 / 3, | |||
segm_kwargs=None, | |||
squares_proba=0, | |||
squares_kwargs=None, | |||
superres_proba=0, | |||
superres_kwargs=None, | |||
outpainting_proba=0, | |||
outpainting_kwargs=None, | |||
invert_proba=0): | |||
self.probas = [] | |||
self.gens = [] | |||
if irregular_proba > 0: | |||
self.probas.append(irregular_proba) | |||
if irregular_kwargs is None: | |||
irregular_kwargs = {} | |||
else: | |||
irregular_kwargs = dict(irregular_kwargs) | |||
irregular_kwargs['draw_method'] = DrawMethod.LINE | |||
self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs)) | |||
if box_proba > 0: | |||
self.probas.append(box_proba) | |||
if box_kwargs is None: | |||
box_kwargs = {} | |||
self.gens.append(RandomRectangleMaskGenerator(**box_kwargs)) | |||
if squares_proba > 0: | |||
self.probas.append(squares_proba) | |||
if squares_kwargs is None: | |||
squares_kwargs = {} | |||
else: | |||
squares_kwargs = dict(squares_kwargs) | |||
squares_kwargs['draw_method'] = DrawMethod.SQUARE | |||
self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs)) | |||
if superres_proba > 0: | |||
self.probas.append(superres_proba) | |||
if superres_kwargs is None: | |||
superres_kwargs = {} | |||
self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs)) | |||
self.probas = np.array(self.probas, dtype='float32') | |||
self.probas /= self.probas.sum() | |||
self.invert_proba = invert_proba | |||
def __call__(self, img, iter_i=None, raw_image=None): | |||
kind = np.random.choice(len(self.probas), p=self.probas) | |||
gen = self.gens[kind] | |||
result = gen(img, iter_i=iter_i, raw_image=raw_image) | |||
if self.invert_proba > 0 and random.random() < self.invert_proba: | |||
result = 1 - result | |||
return result | |||
def get_transforms(test_mode, out_size): | |||
if not test_mode: | |||
transform = A.Compose([ | |||
IAAPerspective2(scale=(0.0, 0.06)), | |||
IAAAffine2(scale=(0.7, 1.3), rotate=(-40, 40), shear=(-0.1, 0.1)), | |||
A.PadIfNeeded(min_height=out_size, min_width=out_size), | |||
A.OpticalDistortion(), | |||
A.RandomCrop(height=out_size, width=out_size), | |||
A.HorizontalFlip(), | |||
A.CLAHE(), | |||
A.RandomBrightnessContrast( | |||
brightness_limit=0.2, contrast_limit=0.2), | |||
A.HueSaturationValue( | |||
hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), | |||
A.ToFloat() | |||
]) | |||
else: | |||
transform = A.Compose([ | |||
A.PadIfNeeded(min_height=out_size, min_width=out_size), | |||
A.CenterCrop(height=out_size, width=out_size), | |||
A.ToFloat() | |||
]) | |||
return transform | |||
@TASK_DATASETS.register_module( | |||
Tasks.image_inpainting, module_name=Models.image_inpainting) | |||
class ImageInpaintingDataset(TorchTaskDataset): | |||
def __init__(self, **kwargs): | |||
split_config = kwargs['split_config'] | |||
LOGGER.info(kwargs) | |||
test_mode = kwargs.get('test_mode', False) | |||
self.data_root = next(iter(split_config.values())) | |||
if not osp.exists(self.data_root): | |||
self.data_root = osp.dirname(self.data_root) | |||
assert osp.exists(self.data_root) | |||
mask_gen_kwargs = kwargs.get('mask_gen_kwargs', {}) | |||
out_size = kwargs.get('out_size', 256) | |||
self.mask_generator = MixedMaskGenerator(**mask_gen_kwargs) | |||
self.transform = get_transforms(test_mode, out_size) | |||
self.in_files = sorted( | |||
list( | |||
glob.glob( | |||
osp.join(self.data_root, '**', '*.jpg'), recursive=True)) | |||
+ list( | |||
glob.glob( | |||
osp.join(self.data_root, '**', '*.png'), recursive=True))) | |||
self.iter_i = 0 | |||
def __len__(self): | |||
return len(self.in_files) | |||
def __getitem__(self, index): | |||
path = self.in_files[index] | |||
img = cv2.imread(path) | |||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |||
img = self.transform(image=img)['image'] | |||
img = np.transpose(img, (2, 0, 1)) | |||
# TODO: maybe generate mask before augmentations? slower, but better for segmentation-based masks | |||
mask = self.mask_generator(img, iter_i=self.iter_i) | |||
self.iter_i += 1 | |||
return dict(image=img, mask=mask) |
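A standalone sketch of the mask generator. Note that segm_proba is accepted by the constructor but no segmentation-based branch is ported in this file, so it is disabled here; the remaining branch probabilities are renormalized.

```python
import numpy as np

gen = MixedMaskGenerator(irregular_proba=1 / 3, box_proba=1 / 3, segm_proba=0)
img = np.zeros((3, 256, 256), dtype=np.float32)  # channels-first, as in __getitem__
mask = gen(img, iter_i=0)
assert mask.shape == (1, 256, 256)  # float32 mask with values in {0, 1}
```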
@@ -177,6 +177,7 @@ TASK_OUTPUTS = { | |||
Tasks.image_denoising: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG], | |||
Tasks.crowd_counting: [OutputKeys.SCORES, OutputKeys.OUTPUT_IMG], | |||
Tasks.image_inpainting: [OutputKeys.OUTPUT_IMG], | |||
# image generation task result for a single image | |||
# {"output_img": np.array with shape (h, w, 3)} | |||
@@ -181,6 +181,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
'damo/cv_resnet50-bert_video-scene-segmentation_movienet'), | |||
Tasks.shop_segmentation: (Pipelines.shop_segmentation, | |||
'damo/cv_vitb16_segmentation_shop-seg'), | |||
Tasks.image_inpainting: (Pipelines.image_inpainting, | |||
'damo/cv_fft_inpainting_lama'), | |||
Tasks.video_inpainting: (Pipelines.video_inpainting, | |||
'damo/cv_video-inpainting'), | |||
Tasks.hand_static: (Pipelines.hand_static, | |||
@@ -35,6 +35,7 @@ if TYPE_CHECKING: | |||
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | |||
from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline | |||
from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | |||
from .image_inpainting_pipeline import ImageInpaintingPipeline | |||
from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline | |||
from .realtime_object_detection_pipeline import RealtimeObjectDetectionPipeline | |||
from .live_category_pipeline import LiveCategoryPipeline | |||
@@ -99,6 +100,7 @@ else: | |||
'live_category_pipeline': ['LiveCategoryPipeline'], | |||
'image_to_image_generation_pipeline': | |||
['Image2ImageGenerationPipeline'], | |||
'image_inpainting_pipeline': ['ImageInpaintingPipeline'], | |||
'ocr_detection_pipeline': ['OCRDetectionPipeline'], | |||
'ocr_recognition_pipeline': ['OCRRecognitionPipeline'], | |||
'skin_retouching_pipeline': ['SkinRetouchingPipeline'], | |||
@@ -0,0 +1,146 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import torch | |||
import torch.nn as nn | |||
from torch.utils.data._utils.collate import default_collate | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.cv.image_inpainting import FFTInpainting | |||
from modelscope.models.cv.image_inpainting.refinement import refine_predict | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors.image import LoadImage | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.image_inpainting, module_name=Pipelines.image_inpainting) | |||
class ImageInpaintingPipeline(Pipeline): | |||
def __init__(self, | |||
model: str, | |||
pad_out_to_modulo=8, | |||
refine=False, | |||
**kwargs): | |||
""" | |||
model: model id on modelscope hub. | |||
""" | |||
assert isinstance(model, str), 'model must be a single str' | |||
super().__init__(model=model, auto_collate=False, **kwargs) | |||
self.refine = refine | |||
logger.info(f'loading model from dir {model}') | |||
self.infer_model = FFTInpainting(model, predict_only=True) | |||
if not self.refine: | |||
self.infer_model.to(self.device) | |||
self.infer_model.eval() | |||
logger.info(f'loading model done, refinement is set to {self.refine}') | |||
self.pad_out_to_modulo = pad_out_to_modulo | |||
def move_to_device(self, obj, device): | |||
if isinstance(obj, nn.Module): | |||
return obj.to(device) | |||
if torch.is_tensor(obj): | |||
return obj.to(device) | |||
if isinstance(obj, (tuple, list)): | |||
return [self.move_to_device(el, device) for el in obj] | |||
if isinstance(obj, dict): | |||
return { | |||
name: self.move_to_device(val, device) | |||
for name, val in obj.items() | |||
} | |||
raise ValueError(f'Unexpected type {type(obj)}') | |||
def transforms(self, img): | |||
if img.ndim == 3: | |||
img = np.transpose(img, (2, 0, 1)) | |||
out_img = img.astype('float32') / 255 | |||
return out_img | |||
def ceil_modulo(self, x, mod): | |||
if x % mod == 0: | |||
return x | |||
return (x // mod + 1) * mod | |||
def pad_img_to_modulo(self, img, mod): | |||
channels, height, width = img.shape | |||
out_height = self.ceil_modulo(height, mod) | |||
out_width = self.ceil_modulo(width, mod) | |||
return np.pad( | |||
img, ((0, 0), (0, out_height - height), (0, out_width - width)), | |||
mode='symmetric') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
if isinstance(input, str): | |||
image_name, mask_name = input.split('+') | |||
img = LoadImage.convert_to_ndarray(image_name) | |||
img = self.transforms(img) | |||
mask = np.array(LoadImage(mode='L')(mask_name)['img']) | |||
mask = self.transforms(mask) | |||
elif isinstance(input, PIL.Image.Image): | |||
img = input.crop((0, 0, int(input.width / 2), input.height)) | |||
img = self.transforms(np.array(img)) | |||
mask = input.crop((int(input.width / 2), 0, input.width, | |||
input.height)).convert('L') | |||
mask = self.transforms(np.array(mask)) | |||
else: | |||
raise TypeError('input should be either str or PIL.Image') | |||
result = dict(image=img, mask=mask[None, ...]) | |||
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: | |||
result['unpad_to_size'] = result['image'].shape[1:] | |||
result['image'] = self.pad_img_to_modulo(result['image'], | |||
self.pad_out_to_modulo) | |||
result['mask'] = self.pad_img_to_modulo(result['mask'], | |||
self.pad_out_to_modulo) | |||
# The base Pipeline wraps forward() in torch.no_grad() by default, so we run | |||
# inference here in preprocess(), where gradients remain enabled for the | |||
# refinement's test-time optimization. | |||
result = self.perform_inference(result) | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
return {OutputKeys.OUTPUT_IMG: input} | |||
def perform_inference(self, data): | |||
batch = default_collate([data]) | |||
if self.refine: | |||
assert 'unpad_to_size' in batch, 'Unpadded size is required for the refinement' | |||
assert 'cuda' in str(self.device), 'GPU is required for refinement' | |||
gpu_ids = str(self.device).split(':')[-1] | |||
cur_res = refine_predict( | |||
batch, | |||
self.infer_model, | |||
gpu_ids=gpu_ids, | |||
modulo=self.pad_out_to_modulo, | |||
n_iters=15, | |||
lr=0.002, | |||
min_side=512, | |||
max_scales=3, | |||
px_budget=900000) | |||
cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy() | |||
else: | |||
with torch.no_grad(): | |||
batch = self.move_to_device(batch, self.device) | |||
batch['mask'] = (batch['mask'] > 0) * 1 | |||
batch = self.infer_model(batch) | |||
cur_res = batch['inpainted'][0].permute( | |||
1, 2, 0).detach().cpu().numpy() | |||
unpad_to_size = batch.get('unpad_to_size', None) | |||
if unpad_to_size is not None: | |||
orig_height, orig_width = unpad_to_size | |||
cur_res = cur_res[:orig_height, :orig_width] | |||
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8') | |||
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) | |||
return cur_res | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
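Putting the pieces together, a hedged sketch of how the pipeline is meant to be invoked. File names are placeholders, the 'image+mask' convention matches preprocess() above, and the default model id comes from the pipeline registration earlier in this patch:

```python
import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inpainting = pipeline(Tasks.image_inpainting, model='damo/cv_fft_inpainting_lama')
result = inpainting('image.png+mask.png')  # image and mask paths joined by '+'
cv2.imwrite('output.png', result[OutputKeys.OUTPUT_IMG])  # BGR uint8 array
```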
@@ -9,7 +9,7 @@ if TYPE_CHECKING: | |||
from .builder import build_trainer | |||
from .cv import (ImageInstanceSegmentationTrainer, | |||
ImagePortraitEnhancementTrainer, | |||
MovieSceneSegmentationTrainer) | |||
MovieSceneSegmentationTrainer, ImageInpaintingTrainer) | |||
from .multi_modal import CLIPTrainer | |||
from .nlp import SequenceClassificationTrainer, PassageRankingTrainer | |||
from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer | |||
@@ -22,7 +22,8 @@ else: | |||
'builder': ['build_trainer'], | |||
'cv': [ | |||
'ImageInstanceSegmentationTrainer', | |||
'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer' | |||
'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer', | |||
'ImageInpaintingTrainer' | |||
], | |||
'multi_modal': ['CLIPTrainer'], | |||
'nlp': ['SequenceClassificationTrainer', 'PassageRankingTrainer'], | |||
@@ -8,6 +8,7 @@ if TYPE_CHECKING: | |||
ImageInstanceSegmentationTrainer | |||
from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer | |||
from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer | |||
from .image_inpainting_trainer import ImageInpaintingTrainer | |||
else: | |||
_import_structure = { | |||
@@ -15,7 +16,8 @@ else: | |||
['ImageInstanceSegmentationTrainer'], | |||
'image_portrait_enhancement_trainer': | |||
['ImagePortraitEnhancementTrainer'], | |||
'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'] | |||
'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'], | |||
'image_inpainting_trainer': ['ImageInpaintingTrainer'] | |||
} | |||
import sys | |||
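For context on the `TYPE_CHECKING` / `_import_structure` pairs edited here: ModelScope resolves these exports lazily, importing a submodule only when one of its names is first accessed. A simplified sketch of the mechanism (the real implementation is ModelScope's `LazyImportModule`; this is not its exact API):

```python
import importlib
from types import ModuleType

class LazyModule(ModuleType):
    """Resolve exported names on first attribute access instead of at import time."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # map each exported class name back to the submodule defining it
        self._class_to_module = {
            cls: mod
            for mod, classes in import_structure.items() for cls in classes
        }

    def __getattr__(self, item):
        submodule = importlib.import_module(
            f'{self.__name__}.{self._class_to_module[item]}')
        return getattr(submodule, item)
```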
@@ -0,0 +1,111 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import time | |||
from collections.abc import Mapping | |||
from torch import distributed as dist | |||
from modelscope.metainfo import Trainers | |||
from modelscope.trainers.builder import TRAINERS | |||
from modelscope.trainers.trainer import EpochBasedTrainer | |||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, | |||
ConfigKeys, Hubs, ModeKeys, ModelFile, | |||
Tasks, TrainerStages) | |||
from modelscope.utils.data_utils import to_device | |||
from modelscope.utils.file_utils import func_receive_dict_inputs | |||
@TRAINERS.register_module(module_name=Trainers.image_inpainting) | |||
class ImageInpaintingTrainer(EpochBasedTrainer): | |||
def __init__(self, *args, **kwargs): | |||
super().__init__(*args, **kwargs) | |||
def train(self, *args, **kwargs): | |||
super().train(*args, **kwargs) | |||
def evaluate(self, *args, **kwargs): | |||
metric_values = super().evaluate(*args, **kwargs) | |||
return metric_values | |||
def prediction_step(self, model, inputs): | |||
pass | |||
def train_loop(self, data_loader): | |||
""" Training loop used by `EpochBasedTrainer.train()` | |||
""" | |||
self.invoke_hook(TrainerStages.before_run) | |||
self._epoch = 0 | |||
self.model.train() | |||
for _ in range(self._epoch, self._max_epochs): | |||
self.invoke_hook(TrainerStages.before_train_epoch) | |||
for i, data_batch in enumerate(data_loader): | |||
data_batch = to_device(data_batch, self.device) | |||
self.data_batch = data_batch | |||
self._inner_iter = i | |||
for idx in range(2): | |||
self.invoke_hook(TrainerStages.before_train_iter) | |||
self.train_step(self.model, data_batch, idx) | |||
self.invoke_hook(TrainerStages.after_train_iter) | |||
del self.data_batch | |||
self._iter += 1 | |||
self._mode = ModeKeys.TRAIN | |||
if i + 1 >= self.iters_per_epoch: | |||
break | |||
self.invoke_hook(TrainerStages.after_train_epoch) | |||
self._epoch += 1 | |||
self.invoke_hook(TrainerStages.after_run) | |||
def train_step(self, model, inputs, idx): | |||
""" Perform a training step on a batch of inputs. | |||
Subclass and override to inject custom behavior. | |||
Args: | |||
model (`TorchModel`): The model to train. | |||
inputs (`Dict[str, Union[torch.Tensor, Any]]`): | |||
The inputs and targets of the model. | |||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the | |||
argument `labels`. Check your model's documentation for all accepted arguments. | |||
        Returns: | |||
            None. The raw outputs are stored in `self.train_outputs` and | |||
            scalar losses are written to the log buffer. | |||
""" | |||
        # EvaluationHook runs evaluation and leaves the trainer in val mode, | |||
        # so switch back to train mode here. | |||
        # TODO: find a cleaner way to manage the mode switch | |||
model.train() | |||
self._mode = ModeKeys.TRAIN | |||
        # call the underlying model's _do_step() directly rather than __call__() | |||
        # so that postprocessing is skipped | |||
if isinstance(inputs, | |||
Mapping) and not func_receive_dict_inputs(model.forward): | |||
train_outputs = model.model._do_step(**inputs, optimizer_idx=idx) | |||
else: | |||
train_outputs = model.model._do_step(inputs, optimizer_idx=idx) | |||
if not isinstance(train_outputs, dict): | |||
raise TypeError('"model.forward()" must return a dict') | |||
# add model output info to log | |||
if 'log_vars' not in train_outputs: | |||
default_keys_pattern = ['loss'] | |||
            match_keys = set() | |||
for key_p in default_keys_pattern: | |||
match_keys.update( | |||
[key for key in train_outputs.keys() if key_p in key]) | |||
log_vars = {} | |||
for key in match_keys: | |||
value = train_outputs.get(key, None) | |||
if value is not None: | |||
if dist.is_available() and dist.is_initialized(): | |||
value = value.data.clone() | |||
dist.all_reduce(value.div_(dist.get_world_size())) | |||
log_vars.update({key: value.item()}) | |||
self.log_buffer.update(log_vars) | |||
else: | |||
self.log_buffer.update(train_outputs['log_vars']) | |||
self.train_outputs = train_outputs |
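The inner `for idx in range(2)` loop in `train_loop`, combined with the `optimizer_idx` passed to `_do_step`, is the standard two-optimizer adversarial schedule: each batch is stepped once per optimizer. A self-contained toy sketch of that pattern (the generator/discriminator mapping of `idx` is an assumption based on LaMa's adversarial training; `_do_step` hides these details):

```python
import torch
from torch import nn

# toy generator/discriminator pair standing in for the inpainting model
gen = nn.Linear(8, 8)
disc = nn.Linear(8, 1)
optimizers = [torch.optim.Adam(gen.parameters(), lr=1e-3),   # idx 0: generator (assumed)
              torch.optim.Adam(disc.parameters(), lr=1e-3)]  # idx 1: discriminator (assumed)

batch = torch.randn(4, 8)
for idx in range(2):  # mirrors the inner loop in train_loop above
    optimizers[idx].zero_grad()
    if idx == 0:
        # generator step: make fakes score high under the discriminator
        loss = -disc(gen(batch)).mean()
    else:
        # discriminator step: score real above fake (detach the generator)
        loss = disc(gen(batch).detach()).mean() - disc(batch).mean()
    loss.backward()
    optimizers[idx].step()
```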
@@ -47,6 +47,8 @@ class CVTasks(object): | |||
face_emotion = 'face-emotion' | |||
product_segmentation = 'product-segmentation' | |||
crowd_counting = 'crowd-counting' | |||
# image editing | |||
skin_retouching = 'skin-retouching' | |||
image_super_resolution = 'image-super-resolution' | |||
@@ -54,6 +56,7 @@ class CVTasks(object): | |||
image_color_enhancement = 'image-color-enhancement' | |||
image_denoising = 'image-denoising' | |||
image_portrait_enhancement = 'image-portrait-enhancement' | |||
image_inpainting = 'image-inpainting' | |||
# image generation | |||
image_to_image_translation = 'image-to-image-translation' | |||
@@ -72,7 +75,6 @@ class CVTasks(object): | |||
video_category = 'video-category' | |||
video_embedding = 'video-embedding' | |||
virtual_try_on = 'virtual-try-on' | |||
crowd_counting = 'crowd-counting' | |||
movie_scene_segmentation = 'movie-scene-segmentation' | |||
# video editing | |||
@@ -7,6 +7,8 @@ ffmpeg-python>=0.2.0 | |||
ftfy | |||
imageio>=2.9.0 | |||
imageio-ffmpeg>=0.4.2 | |||
imgaug>=0.4.0 | |||
kornia>=0.5.0 | |||
lmdb | |||
lpips | |||
ml_collections | |||
@@ -0,0 +1,77 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
import cv2 | |||
import torch | |||
from PIL import Image | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.test_utils import test_level | |||
logger = get_logger() | |||
class ImageInpaintingTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.input_location = 'data/test/images/image_inpainting/image_inpainting.png' | |||
self.input_mask_location = 'data/test/images/image_inpainting/image_inpainting_mask.png' | |||
self.model_id = 'damo/cv_fft_inpainting_lama' | |||
def save_result(self, result): | |||
vis_img = result[OutputKeys.OUTPUT_IMG] | |||
cv2.imwrite('result.png', vis_img) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_inpainting(self): | |||
inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) | |||
result = inpainting(self.input_location + '+' | |||
+ self.input_mask_location) | |||
if result: | |||
self.save_result(result) | |||
else: | |||
raise ValueError('process error') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') | |||
def test_inpainting_with_refinement(self): | |||
        # refine=True gives better results on high-resolution inputs | |||
inpainting = pipeline( | |||
Tasks.image_inpainting, model=self.model_id, refine=True) | |||
result = inpainting(self.input_location + '+' | |||
+ self.input_mask_location) | |||
if result: | |||
self.save_result(result) | |||
else: | |||
raise ValueError('process error') | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_inpainting_with_image(self): | |||
inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) | |||
img = Image.open(self.input_location).convert('RGB') | |||
mask = Image.open(self.input_mask_location).convert('RGB') | |||
img_new = Image.new('RGB', (img.width + mask.width, img.height)) | |||
img_new.paste(img, (0, 0)) | |||
img_new.paste(mask, (img.width, 0)) | |||
result = inpainting(img_new) | |||
if result: | |||
self.save_result(result) | |||
else: | |||
raise ValueError('process error') | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_inpainting_with_default_task(self): | |||
inpainting = pipeline(Tasks.image_inpainting) | |||
result = inpainting(self.input_location + '+' | |||
+ self.input_mask_location) | |||
if result: | |||
self.save_result(result) | |||
else: | |||
raise ValueError('process error') | |||
if __name__ == '__main__': | |||
unittest.main() |
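Outside the unittest harness, the two input conventions exercised by these tests reduce to the following usage sketch (the file paths are placeholders):

```python
from PIL import Image
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inpaint = pipeline(Tasks.image_inpainting, model='damo/cv_fft_inpainting_lama')

# convention 1: image path and mask path joined with a literal '+'
result = inpaint('img.png' + '+' + 'mask.png')

# convention 2: a single PIL image, image on the left half and mask on the right
img = Image.open('img.png').convert('RGB')
mask = Image.open('mask.png').convert('RGB')
combined = Image.new('RGB', (img.width + mask.width, img.height))
combined.paste(img, (0, 0))
combined.paste(mask, (img.width, 0))
result = inpaint(combined)

inpainted_bgr = result[OutputKeys.OUTPUT_IMG]  # uint8 BGR array, ready for cv2.imwrite
```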
@@ -10,6 +10,7 @@ isolated: # test cases that may require an excessive amount of GPU memory, which | |||
- test_easycv_trainer.py | |||
- test_segformer.py | |||
- test_segmentation_pipeline.py | |||
- test_image_inpainting.py | |||
envs: | |||
default: # default env, case not in other env will in default, pytorch. | |||
@@ -0,0 +1,84 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import shutil | |||
import tempfile | |||
import unittest | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.metainfo import Trainers | |||
from modelscope.models.cv.image_inpainting import FFTInpainting | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.config import Config, ConfigDict | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.test_utils import test_level | |||
logger = get_logger() | |||
class ImageInpaintingTrainerTest(unittest.TestCase): | |||
def setUp(self): | |||
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName)) | |||
self.tmp_dir = tempfile.TemporaryDirectory().name | |||
if not os.path.exists(self.tmp_dir): | |||
os.makedirs(self.tmp_dir) | |||
self.model_id = 'damo/cv_fft_inpainting_lama' | |||
self.cache_path = snapshot_download(self.model_id) | |||
cfg = Config.from_file( | |||
os.path.join(self.cache_path, ModelFile.CONFIGURATION)) | |||
train_data_cfg = ConfigDict( | |||
name='PlacesToydataset', | |||
split='train', | |||
mask_gen_kwargs=cfg.dataset.mask_gen_kwargs, | |||
out_size=cfg.dataset.train_out_size, | |||
test_mode=False) | |||
test_data_cfg = ConfigDict( | |||
name='PlacesToydataset', | |||
split='test', | |||
mask_gen_kwargs=cfg.dataset.mask_gen_kwargs, | |||
out_size=cfg.dataset.val_out_size, | |||
test_mode=True) | |||
self.train_dataset = MsDataset.load( | |||
dataset_name=train_data_cfg.name, | |||
split=train_data_cfg.split, | |||
mask_gen_kwargs=train_data_cfg.mask_gen_kwargs, | |||
out_size=train_data_cfg.out_size, | |||
test_mode=train_data_cfg.test_mode) | |||
assert next( | |||
iter(self.train_dataset.config_kwargs['split_config'].values())) | |||
self.test_dataset = MsDataset.load( | |||
dataset_name=test_data_cfg.name, | |||
split=test_data_cfg.split, | |||
mask_gen_kwargs=test_data_cfg.mask_gen_kwargs, | |||
out_size=test_data_cfg.out_size, | |||
test_mode=test_data_cfg.test_mode) | |||
assert next( | |||
iter(self.test_dataset.config_kwargs['split_config'].values())) | |||
def tearDown(self): | |||
shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
super().tearDown() | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_trainer(self): | |||
kwargs = dict( | |||
model=self.model_id, | |||
train_dataset=self.train_dataset, | |||
eval_dataset=self.test_dataset) | |||
trainer = build_trainer( | |||
name=Trainers.image_inpainting, default_args=kwargs) | |||
trainer.train() | |||
results_files = os.listdir(trainer.work_dir) | |||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
if __name__ == '__main__': | |||
unittest.main() |