From 69da8f91ac5ca420408100c4ec5abd0c5987e65a Mon Sep 17 00:00:00 2001 From: "ashui.cbh" Date: Tue, 11 Oct 2022 20:49:13 +0800 Subject: [PATCH] [to #42322933]support image inpainting Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10111615 --- .../image_inpainting/image_inpainting.png | 3 + .../image_inpainting_mask.png | 3 + modelscope/metainfo.py | 5 + modelscope/metrics/__init__.py | 2 + modelscope/metrics/builder.py | 2 + modelscope/metrics/image_inpainting_metric.py | 210 +++++++ modelscope/models/cv/__init__.py | 17 +- .../models/cv/crowd_counting/cc_model.py | 2 + .../cv/crowd_counting/hrnet_aspp_relu.py | 14 +- .../models/cv/image_inpainting/__init__.py | 22 + modelscope/models/cv/image_inpainting/base.py | 75 +++ .../models/cv/image_inpainting/default.py | 210 +++++++ .../models/cv/image_inpainting/model.py | 36 ++ .../cv/image_inpainting/modules/__init__.py | 0 .../modules/ade20k/__init__.py | 2 + .../image_inpainting/modules/ade20k/base.py | 380 +++++++++++ .../image_inpainting/modules/ade20k/resnet.py | 183 ++++++ .../image_inpainting/modules/adversarial.py | 167 +++++ .../modules/feature_matching.py | 45 ++ .../models/cv/image_inpainting/modules/ffc.py | 588 ++++++++++++++++++ .../cv/image_inpainting/modules/inception.py | 324 ++++++++++ .../cv/image_inpainting/modules/perceptual.py | 47 ++ .../cv/image_inpainting/modules/pix2pixhd.py | 75 +++ .../models/cv/image_inpainting/refinement.py | 393 ++++++++++++ .../msdatasets/task_datasets/__init__.py | 2 + .../image_inpainting/__init__.py | 2 + .../task_datasets/image_inpainting/aug.py | 100 +++ .../image_inpainting_dataset.py | 337 ++++++++++ modelscope/outputs.py | 1 + modelscope/pipelines/builder.py | 2 + modelscope/pipelines/cv/__init__.py | 2 + .../pipelines/cv/image_inpainting_pipeline.py | 146 +++++ modelscope/trainers/__init__.py | 5 +- modelscope/trainers/cv/__init__.py | 4 +- .../trainers/cv/image_inpainting_trainer.py | 111 ++++ modelscope/utils/constant.py | 4 +- requirements/cv.txt | 2 + tests/pipelines/test_image_inpainting.py | 77 +++ tests/run_config.yaml | 1 + .../trainers/test_image_inpainting_trainer.py | 84 +++ 40 files changed, 3666 insertions(+), 19 deletions(-) create mode 100644 data/test/images/image_inpainting/image_inpainting.png create mode 100644 data/test/images/image_inpainting/image_inpainting_mask.png create mode 100644 modelscope/metrics/image_inpainting_metric.py create mode 100644 modelscope/models/cv/image_inpainting/__init__.py create mode 100644 modelscope/models/cv/image_inpainting/base.py create mode 100644 modelscope/models/cv/image_inpainting/default.py create mode 100644 modelscope/models/cv/image_inpainting/model.py create mode 100644 modelscope/models/cv/image_inpainting/modules/__init__.py create mode 100644 modelscope/models/cv/image_inpainting/modules/ade20k/__init__.py create mode 100644 modelscope/models/cv/image_inpainting/modules/ade20k/base.py create mode 100644 modelscope/models/cv/image_inpainting/modules/ade20k/resnet.py create mode 100644 modelscope/models/cv/image_inpainting/modules/adversarial.py create mode 100644 modelscope/models/cv/image_inpainting/modules/feature_matching.py create mode 100644 modelscope/models/cv/image_inpainting/modules/ffc.py create mode 100644 modelscope/models/cv/image_inpainting/modules/inception.py create mode 100644 modelscope/models/cv/image_inpainting/modules/perceptual.py create mode 100644 modelscope/models/cv/image_inpainting/modules/pix2pixhd.py create mode 100644 modelscope/models/cv/image_inpainting/refinement.py create 
mode 100644 modelscope/msdatasets/task_datasets/image_inpainting/__init__.py create mode 100644 modelscope/msdatasets/task_datasets/image_inpainting/aug.py create mode 100644 modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py create mode 100644 modelscope/pipelines/cv/image_inpainting_pipeline.py create mode 100644 modelscope/trainers/cv/image_inpainting_trainer.py create mode 100644 tests/pipelines/test_image_inpainting.py create mode 100644 tests/trainers/test_image_inpainting_trainer.py diff --git a/data/test/images/image_inpainting/image_inpainting.png b/data/test/images/image_inpainting/image_inpainting.png new file mode 100644 index 00000000..e141012d --- /dev/null +++ b/data/test/images/image_inpainting/image_inpainting.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46db348eae61448f1668ce282caec21375e96c3268d53da44aa67ec32cbf4fa5 +size 2747938 diff --git a/data/test/images/image_inpainting/image_inpainting_mask.png b/data/test/images/image_inpainting/image_inpainting_mask.png new file mode 100644 index 00000000..e30f67e7 --- /dev/null +++ b/data/test/images/image_inpainting/image_inpainting_mask.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:709c1828ed2d56badf2f19a40194da9a5e5e6db2fb73ef55d047407f49bc7a15 +size 27616 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 77627abc..cae9d188 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -27,6 +27,7 @@ class Models(object): face_2d_keypoints = 'face-2d-keypoints' panoptic_segmentation = 'swinL-panoptic-segmentation' image_reid_person = 'passvitb' + image_inpainting = 'FFTInpainting' video_summarization = 'pgl-video-summarization' swinL_semantic_segmentation = 'swinL-semantic-segmentation' vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' @@ -179,6 +180,7 @@ class Pipelines(object): video_summarization = 'googlenet_pgl_video_summarization' image_semantic_segmentation = 'image-semantic-segmentation' image_reid_person = 'passvitb-image-reid-person' + image_inpainting = 'fft-inpainting' text_driven_segmentation = 'text-driven-segmentation' movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' shop_segmentation = 'shop-segmentation' @@ -264,6 +266,7 @@ class Trainers(object): image_portrait_enhancement = 'image-portrait-enhancement' video_summarization = 'video-summarization' movie_scene_segmentation = 'movie-scene-segmentation' + image_inpainting = 'image-inpainting' # nlp trainers bert_sentiment_analysis = 'bert-sentiment-analysis' @@ -363,6 +366,8 @@ class Metrics(object): video_summarization_metric = 'video-summarization-metric' # metric for movie-scene-segmentation task movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' + # metric for inpainting task + image_inpainting_metric = 'image-inpainting-metric' class Optimizers(object): diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index d3975a2c..e6a03a22 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from .token_classification_metric import TokenClassificationMetric from .video_summarization_metric import VideoSummarizationMetric from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric + from .image_inpainting_metric import ImageInpaintingMetric else: _import_structure = { @@ -34,6 +35,7 @@ else: 'token_classification_metric': ['TokenClassificationMetric'], 'video_summarization_metric': 
['VideoSummarizationMetric'], 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], + 'image_inpainting_metric': ['ImageInpaintingMetric'], } import sys diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index 9e875cc4..ee4d2840 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -18,6 +18,7 @@ class MetricKeys(object): SSIM = 'ssim' AVERAGE_LOSS = 'avg_loss' FScore = 'fscore' + FID = 'fid' BLEU_1 = 'bleu-1' BLEU_4 = 'bleu-4' ROUGE_1 = 'rouge-1' @@ -39,6 +40,7 @@ task_default_metrics = { Tasks.image_captioning: [Metrics.text_gen_metric], Tasks.visual_question_answering: [Metrics.text_gen_metric], Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], + Tasks.image_inpainting: [Metrics.image_inpainting_metric], } diff --git a/modelscope/metrics/image_inpainting_metric.py b/modelscope/metrics/image_inpainting_metric.py new file mode 100644 index 00000000..954d4ca2 --- /dev/null +++ b/modelscope/metrics/image_inpainting_metric.py @@ -0,0 +1,210 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import Dict + +import numpy as np +import torch +import torch.nn.functional as F +from scipy import linalg + +from modelscope.metainfo import Metrics +from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3 +from modelscope.utils.registry import default_group +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) +from .base import Metric +from .builder import METRICS, MetricKeys + + +def fid_calculate_activation_statistics(act): + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6): + mu1, sigma1 = fid_calculate_activation_statistics(activations_pred) + mu2, sigma2 = fid_calculate_activation_statistics(activations_target) + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) + - 2 * tr_covmean) + + +class FIDScore(torch.nn.Module): + + def __init__(self, dims=2048, eps=1e-6): + super().__init__() + if getattr(FIDScore, '_MODEL', None) is None: + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + FIDScore._MODEL = InceptionV3([block_idx]).eval() + self.model = FIDScore._MODEL + self.eps = eps + self.reset() + + def forward(self, pred_batch, target_batch, mask=None): + activations_pred = self._get_activations(pred_batch) + activations_target = self._get_activations(target_batch) + + self.activations_pred.append(activations_pred.detach().cpu()) + self.activations_target.append(activations_target.detach().cpu()) + + def get_value(self): + activations_pred, activations_target = (self.activations_pred, + self.activations_target) + activations_pred = torch.cat(activations_pred).cpu().numpy() + activations_target = 
torch.cat(activations_target).cpu().numpy() + + total_distance = calculate_frechet_distance( + activations_pred, activations_target, eps=self.eps) + + self.reset() + return total_distance + + def reset(self): + self.activations_pred = [] + self.activations_target = [] + + def _get_activations(self, batch): + activations = self.model(batch)[0] + if activations.shape[2] != 1 or activations.shape[3] != 1: + assert False, \ + 'We should not have got here, because Inception always scales inputs to 299x299' + activations = activations.squeeze(-1).squeeze(-1) + return activations + + +class SSIM(torch.nn.Module): + """SSIM. Modified from: + https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py + """ + + def __init__(self, window_size=11, size_average=True): + super().__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.register_buffer('window', + self._create_window(window_size, self.channel)) + + def forward(self, img1, img2): + assert len(img1.shape) == 4 + + channel = img1.size()[1] + + if channel == self.channel and self.window.data.type( + ) == img1.data.type(): + window = self.window + else: + window = self._create_window(self.window_size, channel) + + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return self._ssim(img1, img2, window, self.window_size, channel, + self.size_average) + + def _gaussian(self, window_size, sigma): + gauss = torch.Tensor([ + np.exp(-(x - (window_size // 2))**2 / float(2 * sigma**2)) + for x in range(window_size) + ]) + return gauss / gauss.sum() + + def _create_window(self, window_size, channel): + _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm( + _1D_window.t()).float().unsqueeze(0).unsqueeze(0) + return _2D_window.expand(channel, 1, window_size, + window_size).contiguous() + + def _ssim(self, + img1, + img2, + window, + window_size, + channel, + size_average=True): + mu1 = F.conv2d( + img1, window, padding=(window_size // 2), groups=channel) + mu2 = F.conv2d( + img2, window, padding=(window_size // 2), groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d( + img1 * img1, window, padding=(window_size // 2), + groups=channel) - mu1_sq + sigma2_sq = F.conv2d( + img2 * img2, window, padding=(window_size // 2), + groups=channel) - mu2_sq + sigma12 = F.conv2d( + img1 * img2, window, padding=(window_size // 2), + groups=channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ + ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + + return ssim_map.mean(1).mean(1).mean(1) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + return + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.image_inpainting_metric) +class ImageInpaintingMetric(Metric): + """The metric computation class for image inpainting classes. 
+ """ + + def __init__(self): + self.preds = [] + self.targets = [] + self.SSIM = SSIM(window_size=11, size_average=False).eval() + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.FID = FIDScore().to(device) + + def add(self, outputs: Dict, inputs: Dict): + pred = outputs['inpainted'] + target = inputs['image'] + self.preds.append(torch_nested_detach(pred)) + self.targets.append(torch_nested_detach(target)) + + def evaluate(self): + ssim_list = [] + for (pred, target) in zip(self.preds, self.targets): + ssim_list.append(self.SSIM(pred, target)) + self.FID(pred, target) + ssim_list = torch_nested_numpify(ssim_list) + fid = self.FID.get_value() + return {MetricKeys.SSIM: np.mean(ssim_list), MetricKeys.FID: fid} diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index f2798b59..ba7b03c5 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -5,13 +5,14 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, body_3d_keypoints, cartoon, cmdssl_video_embedding, crowd_counting, face_2d_keypoints, face_detection, face_generation, image_classification, image_color_enhance, - image_colorization, image_denoise, image_instance_segmentation, - image_panoptic_segmentation, image_portrait_enhancement, - image_reid_person, image_semantic_segmentation, - image_to_image_generation, image_to_image_translation, - movie_scene_segmentation, object_detection, - product_retrieval_embedding, realtime_object_detection, - salient_detection, shop_segmentation, super_resolution, - video_single_object_tracking, video_summarization, virual_tryon) + image_colorization, image_denoise, image_inpainting, + image_instance_segmentation, image_panoptic_segmentation, + image_portrait_enhancement, image_reid_person, + image_semantic_segmentation, image_to_image_generation, + image_to_image_translation, movie_scene_segmentation, + object_detection, product_retrieval_embedding, + realtime_object_detection, salient_detection, shop_segmentation, + super_resolution, video_single_object_tracking, + video_summarization, virual_tryon) # yapf: enable diff --git a/modelscope/models/cv/crowd_counting/cc_model.py b/modelscope/models/cv/crowd_counting/cc_model.py index 582b26f4..16fbc261 100644 --- a/modelscope/models/cv/crowd_counting/cc_model.py +++ b/modelscope/models/cv/crowd_counting/cc_model.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict, Optional, Union diff --git a/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py b/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py index 982ba939..0d1bd3ca 100644 --- a/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py +++ b/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py @@ -1,10 +1,10 @@ -# ------------------------------------------------------------------------------ -# Copyright (c) Microsoft -# Licensed under the MIT License. -# Written by Bin Xiao (Bin.Xiao@microsoft.com) -# Modified by Ke Sun (sunk@mail.ustc.edu.cn) -# https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py -# ------------------------------------------------------------------------------ +""" +Copyright (c) Microsoft +Licensed under the MIT License. 
+Written by Bin Xiao (Bin.Xiao@microsoft.com) +Modified by Ke Sun (sunk@mail.ustc.edu.cn) +https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py +""" import functools import logging diff --git a/modelscope/models/cv/image_inpainting/__init__.py b/modelscope/models/cv/image_inpainting/__init__.py new file mode 100644 index 00000000..e7c63cd4 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .model import FFTInpainting + +else: + _import_structure = { + 'model': ['FFTInpainting'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_inpainting/base.py b/modelscope/models/cv/image_inpainting/base.py new file mode 100644 index 00000000..04e73630 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/base.py @@ -0,0 +1,75 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import Dict, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.utils.logger import get_logger +from .modules.adversarial import NonSaturatingWithR1 +from .modules.ffc import FFCResNetGenerator +from .modules.perceptual import ResNetPL +from .modules.pix2pixhd import NLayerDiscriminator + +LOGGER = get_logger() + + +class BaseInpaintingTrainingModule(nn.Module): + + def __init__(self, + model_dir='', + use_ddp=True, + predict_only=False, + visualize_each_iters=100, + average_generator=False, + generator_avg_beta=0.999, + average_generator_start_step=30000, + average_generator_period=10, + store_discr_outputs_for_vis=False, + **kwargs): + super().__init__() + LOGGER.info( + f'BaseInpaintingTrainingModule init called, predict_only is {predict_only}' + ) + + self.generator = FFCResNetGenerator() + self.use_ddp = use_ddp + + if not predict_only: + self.discriminator = NLayerDiscriminator() + self.adversarial_loss = NonSaturatingWithR1( + weight=10, + gp_coef=0.001, + mask_as_fake_target=True, + allow_scale_mask=True) + + self.average_generator = average_generator + self.generator_avg_beta = generator_avg_beta + self.average_generator_start_step = average_generator_start_step + self.average_generator_period = average_generator_period + self.generator_average = None + self.last_generator_averaging_step = -1 + self.store_discr_outputs_for_vis = store_discr_outputs_for_vis + + self.loss_l1 = nn.L1Loss(reduction='none') + + self.loss_resnet_pl = ResNetPL(weight=30, weights_path=model_dir) + + self.visualize_each_iters = visualize_each_iters + LOGGER.info('BaseInpaintingTrainingModule init done') + + def forward(self, batch: Dict[str, + torch.Tensor]) -> Dict[str, torch.Tensor]: + """Pass data through generator and obtain at least 'predicted_image' and 'inpainted' keys""" + raise NotImplementedError() + + def generator_loss(self, + batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raise NotImplementedError() + + def discriminator_loss( + self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raise NotImplementedError() diff --git a/modelscope/models/cv/image_inpainting/default.py b/modelscope/models/cv/image_inpainting/default.py new file mode 100644 index 00000000..5f57d63f --- 
/dev/null +++ b/modelscope/models/cv/image_inpainting/default.py @@ -0,0 +1,210 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import bisect + +import torch +import torch.nn.functional as F + +from modelscope.utils.logger import get_logger +from .base import BaseInpaintingTrainingModule +from .modules.feature_matching import feature_matching_loss, masked_l1_loss + +LOGGER = get_logger() + + +def set_requires_grad(module, value): + for param in module.parameters(): + param.requires_grad = value + + +def add_prefix_to_keys(dct, prefix): + return {prefix + k: v for k, v in dct.items()} + + +class LinearRamp: + + def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): + self.start_value = start_value + self.end_value = end_value + self.start_iter = start_iter + self.end_iter = end_iter + + def __call__(self, i): + if i < self.start_iter: + return self.start_value + if i >= self.end_iter: + return self.end_value + part = (i - self.start_iter) / (self.end_iter - self.start_iter) + return self.start_value * (1 - part) + self.end_value * part + + +class LadderRamp: + + def __init__(self, start_iters, values): + self.start_iters = start_iters + self.values = values + assert len(values) == len(start_iters) + 1, (len(values), + len(start_iters)) + + def __call__(self, i): + segment_i = bisect.bisect_right(self.start_iters, i) + return self.values[segment_i] + + +def get_ramp(kind='ladder', **kwargs): + if kind == 'linear': + return LinearRamp(**kwargs) + if kind == 'ladder': + return LadderRamp(**kwargs) + raise ValueError(f'Unexpected ramp kind: {kind}') + + +class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule): + + def __init__(self, + model_dir='', + predict_only=False, + concat_mask=True, + rescale_scheduler_kwargs=None, + image_to_discriminator='predicted_image', + add_noise_kwargs=None, + noise_fill_hole=False, + const_area_crop_kwargs=None, + distance_weighter_kwargs=None, + distance_weighted_mask_for_discr=False, + fake_fakes_proba=0, + fake_fakes_generator_kwargs=None, + **kwargs): + super().__init__(model_dir=model_dir, predict_only=predict_only) + self.concat_mask = concat_mask + self.rescale_size_getter = get_ramp( + **rescale_scheduler_kwargs + ) if rescale_scheduler_kwargs is not None else None + self.image_to_discriminator = image_to_discriminator + self.add_noise_kwargs = add_noise_kwargs + self.noise_fill_hole = noise_fill_hole + self.const_area_crop_kwargs = const_area_crop_kwargs + self.refine_mask_for_losses = None + self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr + + self.feature_matching_weight = 100 + self.losses_l1_weight_known = 10 + self.losses_l1_weight_missing = 0 + self.fake_fakes_proba = fake_fakes_proba + + def forward(self, batch): + img = batch['image'] + mask = batch['mask'] + + masked_img = img * (1 - mask) + + if self.concat_mask: + masked_img = torch.cat([masked_img, mask], dim=1) + + batch['predicted_image'] = self.generator(masked_img) + batch['inpainted'] = mask * batch['predicted_image'] + ( + 1 - mask) * batch['image'] + + batch['mask_for_losses'] = mask + + return batch + + def generator_loss(self, batch): + img = batch['image'] + predicted_img = batch[self.image_to_discriminator] + original_mask = batch['mask'] + supervised_mask = batch['mask_for_losses'] + + # L1 + l1_value = masked_l1_loss(predicted_img, img, supervised_mask, + self.losses_l1_weight_known, + self.losses_l1_weight_missing) + + total_loss = 
l1_value + metrics = dict(gen_l1=l1_value) + + # discriminator + # adversarial_loss calls backward by itself + mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask + self.adversarial_loss.pre_generator_step( + real_batch=img, + fake_batch=predicted_img, + generator=self.generator, + discriminator=self.discriminator) + discr_real_pred, discr_real_features = self.discriminator(img) + discr_fake_pred, discr_fake_features = self.discriminator( + predicted_img) + adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss( + real_batch=img, + fake_batch=predicted_img, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_pred, + mask=mask_for_discr) + total_loss = total_loss + adv_gen_loss + metrics['gen_adv'] = adv_gen_loss + metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) + + # feature matching + if self.feature_matching_weight > 0: + need_mask_in_fm = False + mask_for_fm = supervised_mask if need_mask_in_fm else None + fm_value = feature_matching_loss( + discr_fake_features, discr_real_features, + mask=mask_for_fm) * self.feature_matching_weight + total_loss = total_loss + fm_value + metrics['gen_fm'] = fm_value + + if self.loss_resnet_pl is not None: + resnet_pl_value = self.loss_resnet_pl(predicted_img, img) + total_loss = total_loss + resnet_pl_value + metrics['gen_resnet_pl'] = resnet_pl_value + + return total_loss, metrics + + def discriminator_loss(self, batch): + total_loss = 0 + metrics = {} + + predicted_img = batch[self.image_to_discriminator].detach() + self.adversarial_loss.pre_discriminator_step( + real_batch=batch['image'], + fake_batch=predicted_img, + generator=self.generator, + discriminator=self.discriminator) + discr_real_pred, discr_real_features = self.discriminator( + batch['image']) + discr_fake_pred, discr_fake_features = self.discriminator( + predicted_img) + adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss( + real_batch=batch['image'], + fake_batch=predicted_img, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_pred, + mask=batch['mask']) + + total_loss = (total_loss + adv_discr_loss) * 0.1 + metrics['discr_adv'] = adv_discr_loss + metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) + + return total_loss, metrics + + def _do_step(self, batch, optimizer_idx=None): + if optimizer_idx == 0: # step for generator + set_requires_grad(self.generator, True) + set_requires_grad(self.discriminator, False) + elif optimizer_idx == 1: # step for discriminator + set_requires_grad(self.generator, False) + set_requires_grad(self.discriminator, True) + + batch = self(batch) + total_loss = 0 + if optimizer_idx is None or optimizer_idx == 0: # step for generator + total_loss, metrics = self.generator_loss(batch) + + elif optimizer_idx is None or optimizer_idx == 1: # step for discriminator + total_loss, metrics = self.discriminator_loss(batch) + + result = dict(loss=total_loss) + return result diff --git a/modelscope/models/cv/image_inpainting/model.py b/modelscope/models/cv/image_inpainting/model.py new file mode 100644 index 00000000..b12f6edd --- /dev/null +++ b/modelscope/models/cv/image_inpainting/model.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +LOGGER = get_logger() + + +@MODELS.register_module( + Tasks.image_inpainting, module_name=Models.image_inpainting) +class FFTInpainting(TorchModel): + + def __init__(self, model_dir: str, **kwargs): + super().__init__(model_dir, **kwargs) + + from .default import DefaultInpaintingTrainingModule + pretrained = kwargs.get('pretrained', True) + predict_only = kwargs.get('predict_only', False) + net = DefaultInpaintingTrainingModule( + model_dir=model_dir, predict_only=predict_only) + if pretrained: + path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + LOGGER.info(f'loading pretrained model from {path}') + state = torch.load(path, map_location='cpu') + net.load_state_dict(state, strict=False) + self.model = net + + def forward(self, inputs): + return self.model(inputs) diff --git a/modelscope/models/cv/image_inpainting/modules/__init__.py b/modelscope/models/cv/image_inpainting/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_inpainting/modules/ade20k/__init__.py b/modelscope/models/cv/image_inpainting/modules/ade20k/__init__.py new file mode 100644 index 00000000..89c3e293 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ade20k/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .base import ModelBuilder diff --git a/modelscope/models/cv/image_inpainting/modules/ade20k/base.py b/modelscope/models/cv/image_inpainting/modules/ade20k/base.py new file mode 100644 index 00000000..02bd3cc4 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ade20k/base.py @@ -0,0 +1,380 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" + +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules import BatchNorm2d + +from . import resnet + +NUM_CLASS = 150 + + +# Model Builder +class ModelBuilder: + # custom weights initialization + @staticmethod + def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) 
+ m.bias.data.fill_(1e-4) + + @staticmethod + def build_encoder(arch='resnet50dilated', + fc_dim=512, + weights='', + model_dir=''): + pretrained = True if len(weights) == 0 else False + arch = arch.lower() + if arch == 'resnet50dilated': + orig_resnet = resnet.__dict__['resnet50']( + pretrained=pretrained, model_dir=model_dir) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnet50': + orig_resnet = resnet.__dict__['resnet50']( + pretrained=pretrained, model_dir=model_dir) + net_encoder = Resnet(orig_resnet) + else: + raise Exception('Architecture undefined!') + + # encoders are usually pretrained + # net_encoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + print('Loading weights for net_encoder') + net_encoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), + strict=False) + return net_encoder + + @staticmethod + def build_decoder(arch='ppm_deepsup', + fc_dim=512, + num_class=NUM_CLASS, + weights='', + use_softmax=False, + drop_last_conv=False): + arch = arch.lower() + if arch == 'ppm_deepsup': + net_decoder = PPMDeepsup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + drop_last_conv=drop_last_conv) + elif arch == 'c1_deepsup': + net_decoder = C1DeepSup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + drop_last_conv=drop_last_conv) + else: + raise Exception('Architecture undefined!') + + net_decoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + print('Loading weights for net_decoder') + net_decoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), + strict=False) + return net_decoder + + @staticmethod + def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, + drop_last_conv, *arts, **kwargs): + path = os.path.join( + weights_path, 'ade20k', + f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth') + return ModelBuilder.build_decoder( + arch=arch_decoder, + fc_dim=fc_dim, + weights=path, + use_softmax=True, + drop_last_conv=drop_last_conv) + + @staticmethod + def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, + segmentation, *arts, **kwargs): + if segmentation: + path = os.path.join( + weights_path, 'ade20k', + f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth') + else: + path = '' + return ModelBuilder.build_encoder( + arch=arch_encoder, + fc_dim=fc_dim, + weights=path, + model_dir=weights_path) + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False), + BatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + + +# pyramid pooling, deep supervision +class PPMDeepsup(nn.Module): + + def __init__(self, + num_class=NUM_CLASS, + fc_dim=4096, + use_softmax=False, + pool_scales=(1, 2, 3, 6), + drop_last_conv=False): + super().__init__() + self.use_softmax = use_softmax + self.drop_last_conv = drop_last_conv + + self.ppm = [] + for scale in pool_scales: + self.ppm.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + BatchNorm2d(512), nn.ReLU(inplace=True))) + self.ppm = nn.ModuleList(self.ppm) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + self.conv_last = nn.Sequential( + nn.Conv2d( + fc_dim + len(pool_scales) * 512, + 512, + kernel_size=3, + padding=1, + bias=False), BatchNorm2d(512), nn.ReLU(inplace=True), + nn.Dropout2d(0.1), nn.Conv2d(512, num_class, kernel_size=1)) + 
self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.dropout_deepsup = nn.Dropout2d(0.1) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale in self.ppm: + ppm_out.append( + nn.functional.interpolate( + pool_scale(conv5), (input_size[2], input_size[3]), + mode='bilinear', + align_corners=False)) + ppm_out = torch.cat(ppm_out, 1) + + if self.drop_last_conv: + return ppm_out + else: + x = self.conv_last(ppm_out) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.dropout_deepsup(_) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +class Resnet(nn.Module): + + def __init__(self, orig_resnet): + super(Resnet, self).__init__() + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + conv_out.append(x) + x = self.layer2(x) + conv_out.append(x) + x = self.layer3(x) + conv_out.append(x) + x = self.layer4(x) + conv_out.append(x) + + if return_feature_maps: + return conv_out + return [x] + + +# Resnet Dilated +class ResnetDilated(nn.Module): + + def __init__(self, orig_resnet, dilate_scale=8): + super().__init__() + from functools import partial + + if dilate_scale == 8: + orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) + orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) + elif dilate_scale == 16: + orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def _nostride_dilate(self, m, dilate): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + # the convolution with stride + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate // 2, dilate // 2) + m.padding = (dilate // 2, dilate // 2) + # other convolutions + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = 
self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + conv_out.append(x) + x = self.layer2(x) + conv_out.append(x) + x = self.layer3(x) + conv_out.append(x) + x = self.layer4(x) + conv_out.append(x) + + if return_feature_maps: + return conv_out + return [x] + + +# last conv, deep supervision +class C1DeepSup(nn.Module): + + def __init__(self, + num_class=150, + fc_dim=2048, + use_softmax=False, + drop_last_conv=False): + super(C1DeepSup, self).__init__() + self.use_softmax = use_softmax + self.drop_last_conv = drop_last_conv + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + x = self.cbr(conv5) + + if self.drop_last_conv: + return x + else: + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +# last conv +class C1(nn.Module): + + def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): + super(C1, self).__init__() + self.use_softmax = use_softmax + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + x = self.cbr(conv5) + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + else: + x = nn.functional.log_softmax(x, dim=1) + + return x diff --git a/modelscope/models/cv/image_inpainting/modules/ade20k/resnet.py b/modelscope/models/cv/image_inpainting/modules/ade20k/resnet.py new file mode 100644 index 00000000..7da9ff07 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ade20k/resnet.py @@ -0,0 +1,183 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import math +import os + +import torch +import torch.nn as nn +from torch.nn import BatchNorm2d + +__all__ = ['ResNet', 'resnet50'] + + +def conv3x3(in_planes, out_planes, stride=1): + '3x3 convolution with padding' + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, 
inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 128 + super(ResNet, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = BatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = BatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def resnet50(pretrained=False, model_dir='', **kwargs): + """Constructs a ResNet-50 model. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + cached_file = os.path.join(model_dir, 'resnet50-imagenet.pth') + model.load_state_dict( + torch.load(cached_file, map_location='cpu'), strict=False) + return model diff --git a/modelscope/models/cv/image_inpainting/modules/adversarial.py b/modelscope/models/cv/image_inpainting/modules/adversarial.py new file mode 100644 index 00000000..b183876b --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/adversarial.py @@ -0,0 +1,167 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BaseAdversarialLoss: + + def pre_generator_step(self, real_batch: torch.Tensor, + fake_batch: torch.Tensor, generator: nn.Module, + discriminator: nn.Module): + """ + Prepare for generator step + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param generator: + :param discriminator: + :return: None + """ + + def pre_discriminator_step(self, real_batch: torch.Tensor, + fake_batch: torch.Tensor, generator: nn.Module, + discriminator: nn.Module): + """ + Prepare for discriminator step + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param generator: + :param discriminator: + :return: None + """ + + def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Calculate generator loss + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param discr_real_pred: Tensor, discriminator output for real_batch + :param discr_fake_pred: Tensor, discriminator output for fake_batch + :param mask: Tensor, actual mask, which was at input of generator when making fake_batch + :return: total generator loss along with some values that might be interesting to log + """ + raise NotImplementedError + + def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Calculate discriminator loss and call .backward() on it + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param discr_real_pred: Tensor, discriminator output for real_batch + :param discr_fake_pred: Tensor, discriminator output for fake_batch + :param mask: Tensor, actual mask, which was at input of generator when making fake_batch + :return: total discriminator loss along with some values that might be interesting to log + """ + raise NotImplementedError + + def interpolate_mask(self, mask, shape): + assert mask is not None + assert self.allow_scale_mask or shape == mask.shape[-2:] + if shape != mask.shape[-2:] and self.allow_scale_mask: + if self.mask_scale_mode == 'maxpool': + mask = F.adaptive_max_pool2d(mask, shape) + else: + mask = F.interpolate( + mask, size=shape, mode=self.mask_scale_mode) + return mask + + +def make_r1_gp(discr_real_pred, 
real_batch): + if torch.is_grad_enabled(): + grad_real = torch.autograd.grad( + outputs=discr_real_pred.sum(), + inputs=real_batch, + create_graph=True)[0] + grad_penalty = (grad_real.view(grad_real.shape[0], + -1).norm(2, dim=1)**2).mean() + else: + grad_penalty = 0 + real_batch.requires_grad = False + + return grad_penalty + + +class NonSaturatingWithR1(BaseAdversarialLoss): + + def __init__(self, + gp_coef=5, + weight=1, + mask_as_fake_target=False, + allow_scale_mask=False, + mask_scale_mode='nearest', + extra_mask_weight_for_gen=0, + use_unmasked_for_gen=True, + use_unmasked_for_discr=True): + self.gp_coef = gp_coef + self.weight = weight + # use for discr => use for gen; + # otherwise we teach only the discr to pay attention to very small difference + assert use_unmasked_for_gen or (not use_unmasked_for_discr) + # mask as target => use unmasked for discr: + # if we don't care about unmasked regions at all + # then it doesn't matter if the value of mask_as_fake_target is true or false + assert use_unmasked_for_discr or (not mask_as_fake_target) + self.use_unmasked_for_gen = use_unmasked_for_gen + self.use_unmasked_for_discr = use_unmasked_for_discr + self.mask_as_fake_target = mask_as_fake_target + self.allow_scale_mask = allow_scale_mask + self.mask_scale_mode = mask_scale_mode + self.extra_mask_weight_for_gen = extra_mask_weight_for_gen + + def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask=None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + fake_loss = F.softplus(-discr_fake_pred) + if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \ + not self.use_unmasked_for_gen: # == if masked region should be treated differently + mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) + if not self.use_unmasked_for_gen: + fake_loss = fake_loss * mask + else: + pixel_weights = 1 + mask * self.extra_mask_weight_for_gen + fake_loss = fake_loss * pixel_weights + + return fake_loss.mean() * self.weight, dict() + + def pre_discriminator_step(self, real_batch: torch.Tensor, + fake_batch: torch.Tensor, generator: nn.Module, + discriminator: nn.Module): + real_batch.requires_grad = True + + def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask=None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + + real_loss = F.softplus(-discr_real_pred) + grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef + fake_loss = F.softplus(discr_fake_pred) + + if not self.use_unmasked_for_discr or self.mask_as_fake_target: + # == if masked region should be treated differently + mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) + # use_unmasked_for_discr=False only makes sense for fakes; + # for reals there is no difference between two regions + fake_loss = fake_loss * mask + if self.mask_as_fake_target: + fake_loss = fake_loss + (1 + - mask) * F.softplus(-discr_fake_pred) + + sum_discr_loss = real_loss + grad_penalty + fake_loss + metrics = dict( + discr_real_out=discr_real_pred.mean(), + discr_fake_out=discr_fake_pred.mean(), + discr_real_gp=grad_penalty) + return sum_discr_loss.mean(), metrics diff --git a/modelscope/models/cv/image_inpainting/modules/feature_matching.py b/modelscope/models/cv/image_inpainting/modules/feature_matching.py new file mode 100644 index 00000000..c2effb20 --- /dev/null +++ 
b/modelscope/models/cv/image_inpainting/modules/feature_matching.py @@ -0,0 +1,45 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import List + +import torch +import torch.nn.functional as F + + +def masked_l2_loss(pred, target, mask, weight_known, weight_missing): + per_pixel_l2 = F.mse_loss(pred, target, reduction='none') + pixel_weights = mask * weight_missing + (1 - mask) * weight_known + return (pixel_weights * per_pixel_l2).mean() + + +def masked_l1_loss(pred, target, mask, weight_known, weight_missing): + per_pixel_l1 = F.l1_loss(pred, target, reduction='none') + pixel_weights = mask * weight_missing + (1 - mask) * weight_known + return (pixel_weights * per_pixel_l1).mean() + + +def feature_matching_loss(fake_features: List[torch.Tensor], + target_features: List[torch.Tensor], + mask=None): + if mask is None: + res = torch.stack([ + F.mse_loss(fake_feat, target_feat) + for fake_feat, target_feat in zip(fake_features, target_features) + ]).mean() + else: + res = 0 + norm = 0 + for fake_feat, target_feat in zip(fake_features, target_features): + cur_mask = F.interpolate( + mask, + size=fake_feat.shape[-2:], + mode='bilinear', + align_corners=False) + error_weights = 1 - cur_mask + cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() + res = res + cur_val + norm += 1 + res = res / norm + return res diff --git a/modelscope/models/cv/image_inpainting/modules/ffc.py b/modelscope/models/cv/image_inpainting/modules/ffc.py new file mode 100644 index 00000000..c74425e3 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ffc.py @@ -0,0 +1,588 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from kornia.geometry.transform import rotate + + +def get_activation(kind='tanh'): + if kind == 'tanh': + return nn.Tanh() + if kind == 'sigmoid': + return nn.Sigmoid() + if kind is False: + return nn.Identity() + raise ValueError(f'Unknown activation kind {kind}') + + +class SELayer(nn.Module): + + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + res = x * y.expand_as(x) + return res + + +class FourierUnit(nn.Module): + + def __init__(self, + in_channels, + out_channels, + groups=1, + spatial_scale_factor=None, + spatial_scale_mode='bilinear', + spectral_pos_encoding=False, + use_se=False, + se_kwargs=None, + ffc3d=False, + fft_norm='ortho'): + # bn_layer not used + super(FourierUnit, self).__init__() + self.groups = groups + + self.conv_layer = torch.nn.Conv2d( + in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), + out_channels=out_channels * 2, + kernel_size=1, + stride=1, + padding=0, + groups=self.groups, + bias=False) + self.bn = torch.nn.BatchNorm2d(out_channels * 2) + self.relu = torch.nn.ReLU(inplace=True) + + # squeeze and excitation block + self.use_se = use_se + if use_se: + if se_kwargs is None: + se_kwargs = {} + self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) + + self.spatial_scale_factor = 
spatial_scale_factor + self.spatial_scale_mode = spatial_scale_mode + self.spectral_pos_encoding = spectral_pos_encoding + self.ffc3d = ffc3d + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + if self.spatial_scale_factor is not None: + orig_size = x.shape[-2:] + x = F.interpolate( + x, + scale_factor=self.spatial_scale_factor, + mode=self.spatial_scale_mode, + align_corners=False) + + # (batch, c, h, w/2+1, 2) + fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) + ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, + 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view(( + batch, + -1, + ) + ffted.size()[3:]) + + if self.spectral_pos_encoding: + height, width = ffted.shape[-2:] + coords_vert = torch.linspace(0, 1, + height)[None, None, :, None].expand( + batch, 1, height, width).to(ffted) + coords_hor = torch.linspace(0, 1, + width)[None, None, None, :].expand( + batch, 1, height, width).to(ffted) + ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) + + if self.use_se: + ffted = self.se(ffted) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(self.bn(ffted)) + + ffted = ffted.view(( + batch, + -1, + 2, + ) + ffted.size()[2:]).permute( + 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] + output = torch.fft.irfftn( + ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + + if self.spatial_scale_factor is not None: + output = F.interpolate( + output, + size=orig_size, + mode=self.spatial_scale_mode, + align_corners=False) + + return output + + +class SpectralTransform(nn.Module): + + def __init__(self, + in_channels, + out_channels, + stride=1, + groups=1, + enable_lfu=True, + **fu_kwargs): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.enable_lfu = enable_lfu + if stride == 2: + self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) + else: + self.downsample = nn.Identity() + + self.stride = stride + self.conv1 = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels // 2, + kernel_size=1, + groups=groups, + bias=False), nn.BatchNorm2d(out_channels // 2), + nn.ReLU(inplace=True)) + self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, + **fu_kwargs) + if self.enable_lfu: + self.lfu = FourierUnit(out_channels // 2, out_channels // 2, + groups) + self.conv2 = torch.nn.Conv2d( + out_channels // 2, + out_channels, + kernel_size=1, + groups=groups, + bias=False) + + def forward(self, x): + + x = self.downsample(x) + x = self.conv1(x) + output = self.fu(x) + + if self.enable_lfu: + n, c, h, w = x.shape + split_no = 2 + split_s = h // split_no + xs = torch.cat( + torch.split(x[:, :c // 4], split_s, dim=-2), + dim=1).contiguous() + xs = torch.cat( + torch.split(xs, split_s, dim=-1), dim=1).contiguous() + xs = self.lfu(xs) + xs = xs.repeat(1, 1, split_no, split_no).contiguous() + else: + xs = 0 + + output = self.conv2(x + output + xs) + + return output + + +class LearnableSpatialTransformWrapper(nn.Module): + + def __init__(self, + impl, + pad_coef=0.5, + angle_init_range=80, + train_angle=True): + super().__init__() + self.impl = impl + self.angle = torch.rand(1) * angle_init_range + if train_angle: + self.angle = nn.Parameter(self.angle, requires_grad=True) + self.pad_coef = pad_coef + + def forward(self, x): + if torch.is_tensor(x): + return 
self.inverse_transform(self.impl(self.transform(x)), x) + elif isinstance(x, tuple): + x_trans = tuple(self.transform(elem) for elem in x) + y_trans = self.impl(x_trans) + return tuple( + self.inverse_transform(elem, orig_x) + for elem, orig_x in zip(y_trans, x)) + else: + raise ValueError(f'Unexpected input type {type(x)}') + + def transform(self, x): + height, width = x.shape[2:] + pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) + x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect') + x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) + return x_padded_rotated + + def inverse_transform(self, y_padded_rotated, orig_x): + height, width = orig_x.shape[2:] + pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) + + y_padded = rotate( + y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) + y_height, y_width = y_padded.shape[2:] + y = y_padded[:, :, pad_h:y_height - pad_h, pad_w:y_width - pad_w] + return y + + +class FFC(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + enable_lfu=True, + padding_type='reflect', + gated=False, + **spectral_kwargs): + super(FFC, self).__init__() + + assert stride == 1 or stride == 2, 'Stride should be 1 or 2.' + self.stride = stride + + in_cg = int(in_channels * ratio_gin) + in_cl = in_channels - in_cg + out_cg = int(out_channels * ratio_gout) + out_cl = out_channels - out_cg + + self.ratio_gin = ratio_gin + self.ratio_gout = ratio_gout + self.global_in_num = in_cg + + module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d + self.convl2l = module( + in_cl, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type) + module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d + self.convl2g = module( + in_cl, + out_cg, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d + self.convg2l = module( + in_cg, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform + self.convg2g = module(in_cg, out_cg, stride, + 1 if groups == 1 else groups // 2, enable_lfu, + **spectral_kwargs) + + self.gated = gated + module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + self.gate = module(in_channels, 2, 1) + + def forward(self, x): + x_l, x_g = x if type(x) is tuple else (x, 0) + out_xl, out_xg = 0, 0 + + if self.gated: + total_input_parts = [x_l] + if torch.is_tensor(x_g): + total_input_parts.append(x_g) + total_input = torch.cat(total_input_parts, dim=1) + + gates = torch.sigmoid(self.gate(total_input)) + g2l_gate, l2g_gate = gates.chunk(2, dim=1) + else: + g2l_gate, l2g_gate = 1, 1 + + if self.ratio_gout != 1: + out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate + if self.ratio_gout != 0: + out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g) + + return out_xl, out_xg + + +class FFC_BN_ACT(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + norm_layer=nn.BatchNorm2d, + activation_layer=nn.Identity, + padding_type='reflect', + enable_lfu=True, + **kwargs): + super(FFC_BN_ACT, self).__init__() + self.ffc = FFC( + in_channels, + 
out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride, + padding, + dilation, + groups, + bias, + enable_lfu, + padding_type=padding_type, + **kwargs) + lnorm = nn.Identity if ratio_gout == 1 else norm_layer + gnorm = nn.Identity if ratio_gout == 0 else norm_layer + global_channels = int(out_channels * ratio_gout) + self.bn_l = lnorm(out_channels - global_channels) + self.bn_g = gnorm(global_channels) + + lact = nn.Identity if ratio_gout == 1 else activation_layer + gact = nn.Identity if ratio_gout == 0 else activation_layer + self.act_l = lact(inplace=True) + self.act_g = gact(inplace=True) + + def forward(self, x): + x_l, x_g = self.ffc(x) + x_l = self.act_l(self.bn_l(x_l)) + x_g = self.act_g(self.bn_g(x_g)) + return x_l, x_g + + +class FFCResnetBlock(nn.Module): + + def __init__(self, + dim, + padding_type, + norm_layer, + activation_layer=nn.ReLU, + dilation=1, + spatial_transform_kwargs=None, + inline=False, + **conv_kwargs): + super().__init__() + self.conv1 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + **conv_kwargs) + self.conv2 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + **conv_kwargs) + if spatial_transform_kwargs is not None: + self.conv1 = LearnableSpatialTransformWrapper( + self.conv1, **spatial_transform_kwargs) + self.conv2 = LearnableSpatialTransformWrapper( + self.conv2, **spatial_transform_kwargs) + self.inline = inline + + def forward(self, x): + if self.inline: + x_l, x_g = x[:, :-self.conv1.ffc. + global_in_num], x[:, -self.conv1.ffc.global_in_num:] + else: + x_l, x_g = x if type(x) is tuple else (x, 0) + + id_l, id_g = x_l, x_g + + x_l, x_g = self.conv1((x_l, x_g)) + x_l, x_g = self.conv2((x_l, x_g)) + + x_l, x_g = id_l + x_l, id_g + x_g + out = x_l, x_g + if self.inline: + out = torch.cat(out, dim=1) + return out + + +class ConcatTupleLayer(nn.Module): + + def forward(self, x): + assert isinstance(x, tuple) + x_l, x_g = x + assert torch.is_tensor(x_l) or torch.is_tensor(x_g) + if not torch.is_tensor(x_g): + return x_l + return torch.cat(x, dim=1) + + +class FFCResNetGenerator(nn.Module): + + def __init__(self, + input_nc=4, + output_nc=3, + ngf=64, + n_downsampling=3, + n_blocks=18, + norm_layer=nn.BatchNorm2d, + padding_type='reflect', + activation_layer=nn.ReLU, + up_norm_layer=nn.BatchNorm2d, + up_activation=nn.ReLU(True), + init_conv_kwargs={ + 'ratio_gin': 0, + 'ratio_gout': 0, + 'enable_lfu': False + }, + downsample_conv_kwargs={ + 'ratio_gin': 0, + 'ratio_gout': 0, + 'enable_lfu': False + }, + resnet_conv_kwargs={ + 'ratio_gin': 0.75, + 'ratio_gout': 0.75, + 'enable_lfu': False + }, + spatial_transform_layers=None, + spatial_transform_kwargs={}, + add_out_act='sigmoid', + max_features=1024, + out_ffc=False, + out_ffc_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + model = [ + nn.ReflectionPad2d(3), + FFC_BN_ACT( + input_nc, + ngf, + kernel_size=7, + padding=0, + norm_layer=norm_layer, + activation_layer=activation_layer, + **init_conv_kwargs) + ] + + # downsample + for i in range(n_downsampling): + mult = 2**i + if i == n_downsampling - 1: + cur_conv_kwargs = dict(downsample_conv_kwargs) + cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get( + 'ratio_gin', 0) + else: + cur_conv_kwargs = downsample_conv_kwargs + model += [ + FFC_BN_ACT( + min(max_features, ngf * mult), + 
min(max_features, ngf * mult * 2), + kernel_size=3, + stride=2, + padding=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + **cur_conv_kwargs) + ] + + mult = 2**n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + # resnet blocks + for i in range(n_blocks): + cur_resblock = FFCResnetBlock( + feats_num_bottleneck, + padding_type=padding_type, + activation_layer=activation_layer, + norm_layer=norm_layer, + **resnet_conv_kwargs) + if spatial_transform_layers is not None and i in spatial_transform_layers: + cur_resblock = LearnableSpatialTransformWrapper( + cur_resblock, **spatial_transform_kwargs) + model += [cur_resblock] + + model += [ConcatTupleLayer()] + + # upsample + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [ + nn.ConvTranspose2d( + min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + up_norm_layer(min(max_features, int(ngf * mult / 2))), + up_activation + ] + + if out_ffc: + model += [ + FFCResnetBlock( + ngf, + padding_type=padding_type, + activation_layer=activation_layer, + norm_layer=norm_layer, + inline=True, + **out_ffc_kwargs) + ] + + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0) + ] + if add_out_act: + model.append( + get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) diff --git a/modelscope/models/cv/image_inpainting/modules/inception.py b/modelscope/models/cv/image_inpainting/modules/inception.py new file mode 100644 index 00000000..5070533d --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/inception.py @@ -0,0 +1,324 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models + +from modelscope.utils.logger import get_logger + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/' \ + 'fid_weights/pt_inception-2015-12-05-6726825d.pth' + +LOGGER = get_logger() + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=[DEFAULT_BLOCK_INDEX], + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. 
Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. + """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. 
Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate( + x, size=(299, 299), mode='bilinear', align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. + """ + LOGGER.info('fid_inception_v3 called') + inception = models.inception_v3( + num_classes=1008, aux_logits=False, pretrained=False) + LOGGER.info('models.inception_v3 done') + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + LOGGER.info('fid_inception_v3 patching done') + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + LOGGER.info('fid_inception_v3 weights downloaded') + + inception.load_state_dict(state_dict) + LOGGER.info('fid_inception_v3 weights loaded into model') + + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = 
F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). 
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/modelscope/models/cv/image_inpainting/modules/perceptual.py b/modelscope/models/cv/image_inpainting/modules/perceptual.py new file mode 100644 index 00000000..80fe2b96 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/perceptual.py @@ -0,0 +1,47 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from .ade20k import ModelBuilder + +IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] +IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] + + +class ResNetPL(nn.Module): + + def __init__(self, + weight=1, + weights_path=None, + arch_encoder='resnet50dilated', + segmentation=True): + super().__init__() + self.impl = ModelBuilder.get_encoder( + weights_path=weights_path, + arch_encoder=arch_encoder, + arch_decoder='ppm_deepsup', + fc_dim=2048, + segmentation=segmentation) + self.impl.eval() + for w in self.impl.parameters(): + w.requires_grad_(False) + + self.weight = weight + + def forward(self, pred, target): + pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred) + target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target) + + pred_feats = self.impl(pred, return_feature_maps=True) + target_feats = self.impl(target, return_feature_maps=True) + + result = torch.stack([ + F.mse_loss(cur_pred, cur_target) + for cur_pred, cur_target in zip(pred_feats, target_feats) + ]).sum() * self.weight + return result diff --git a/modelscope/models/cv/image_inpainting/modules/pix2pixhd.py b/modelscope/models/cv/image_inpainting/modules/pix2pixhd.py new file mode 100644 index 00000000..32e18f3e --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/pix2pixhd.py @@ -0,0 +1,75 @@ +""" +The implementation is adopted from +https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py +""" +import collections +import functools +import logging +from collections import defaultdict +from functools import partial + +import numpy as np +import torch.nn as nn + + +# Defines the PatchGAN discriminator with the specified arguments. 
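+# With the defaults used here (input_nc=3, ndf=64, n_layers=4) the stack below is
+# 3 -> 64 -> 128 -> 256 -> 512 channels through four stride-2 4x4 convolutions,
+# followed by one stride-1 conv to 512 channels and a final 1-channel conv;
+# forward() therefore returns the patch-wise logit map plus all intermediate
+# activations (the kind of feature list a feature-matching loss would consume).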
+class NLayerDiscriminator(nn.Module): + + def __init__( + self, + input_nc=3, + ndf=64, + n_layers=4, + norm_layer=nn.BatchNorm2d, + ): + super().__init__() + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw - 1.0) / 2)) + sequence = [[ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + sequence += [[ + nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw) + ]] + + for n in range(len(sequence)): + setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) + + def get_all_activations(self, x): + res = [x] + for n in range(self.n_layers + 2): + model = getattr(self, 'model' + str(n)) + res.append(model(res[-1])) + return res[1:] + + def forward(self, x): + act = self.get_all_activations(x) + return act[-1], act[:-1] diff --git a/modelscope/models/cv/image_inpainting/refinement.py b/modelscope/models/cv/image_inpainting/refinement.py new file mode 100644 index 00000000..662d8a05 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/refinement.py @@ -0,0 +1,393 @@ +''' +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +''' +import cv2 +import numpy as np +import torch +import torch.nn as nn +from kornia.filters import gaussian_blur2d +from kornia.geometry.transform import resize +from kornia.morphology import erosion +from torch.nn import functional as F +from torch.optim import SGD, Adam +from tqdm import tqdm + +from .modules.ffc import FFCResnetBlock + + +def move_to_device(obj, device): + if isinstance(obj, nn.Module): + return obj.to(device) + if torch.is_tensor(obj): + return obj.to(device) + if isinstance(obj, (tuple, list)): + return [move_to_device(el, device) for el in obj] + if isinstance(obj, dict): + return {name: move_to_device(val, device) for name, val in obj.items()} + raise ValueError(f'Unexpected type {type(obj)}') + + +def ceil_modulo(x, mod): + if x % mod == 0: + return x + return (x // mod + 1) * mod + + +def pad_tensor_to_modulo(img, mod): + batch_size, channels, height, width = img.shape + out_height = ceil_modulo(height, mod) + out_width = ceil_modulo(width, mod) + return F.pad( + img, + pad=(0, out_width - width, 0, out_height - height), + mode='reflect') + + +def _pyrdown(im: torch.Tensor, downsize: tuple = None): + """downscale the image""" + if downsize is None: + downsize = (im.shape[2] // 2, im.shape[3] // 2) + assert im.shape[ + 1] == 3, 'Expected shape for the input to be (n,3,height,width)' + im = gaussian_blur2d(im, kernel_size=(5, 5), sigma=(1.0, 1.0)) + im = F.interpolate(im, size=downsize, mode='bilinear', align_corners=False) + return im + + +def _pyrdown_mask(mask: torch.Tensor, + downsize: tuple = None, + eps: float = 1e-8, + blur_mask: bool = True, + round_up: bool = True): + """downscale the mask tensor + + Parameters + ---------- + mask : torch.Tensor + mask of size (B, 1, H, W) + downsize : tuple, optional + size to downscale to. 
If None, image is downscaled to half, by default None + eps : float, optional + threshold value for binarizing the mask, by default 1e-8 + blur_mask : bool, optional + if True, apply gaussian filter before downscaling, by default True + round_up : bool, optional + if True, values above eps are marked 1, else, values below 1-eps are marked 0, by default True + + Returns + ------- + torch.Tensor + downscaled mask + """ + + if downsize is None: + downsize = (mask.shape[2] // 2, mask.shape[3] // 2) + assert mask.shape[ + 1] == 1, 'Expected shape for the input to be (n,1,height,width)' + if blur_mask is True: + mask = gaussian_blur2d(mask, kernel_size=(5, 5), sigma=(1.0, 1.0)) + mask = F.interpolate( + mask, size=downsize, mode='bilinear', align_corners=False) + else: + mask = F.interpolate( + mask, size=downsize, mode='bilinear', align_corners=False) + if round_up: + mask[mask >= eps] = 1 + mask[mask < eps] = 0 + else: + mask[mask >= 1.0 - eps] = 1 + mask[mask < 1.0 - eps] = 0 + return mask + + +def _erode_mask(mask: torch.Tensor, + ekernel: torch.Tensor = None, + eps: float = 1e-8): + """erode the mask, and set gray pixels to 0""" + if ekernel is not None: + mask = erosion(mask, ekernel) + mask[mask >= 1.0 - eps] = 1 + mask[mask < 1.0 - eps] = 0 + return mask + + +def _l1_loss(pred: torch.Tensor, + pred_downscaled: torch.Tensor, + ref: torch.Tensor, + mask: torch.Tensor, + mask_downscaled: torch.Tensor, + image: torch.Tensor, + on_pred: bool = True): + """l1 loss on src pixels, and downscaled predictions if on_pred=True""" + loss = torch.mean(torch.abs(pred[mask < 1e-8] - image[mask < 1e-8])) + if on_pred: + loss += torch.mean( + torch.abs(pred_downscaled[mask_downscaled >= 1e-8] + - ref[mask_downscaled >= 1e-8])) + return loss + + +def _infer(image: torch.Tensor, + mask: torch.Tensor, + forward_front: nn.Module, + forward_rears: nn.Module, + ref_lower_res: torch.Tensor, + orig_shape: tuple, + devices: list, + scale_ind: int, + n_iters: int = 15, + lr: float = 0.002): + """Performs inference with refinement at a given scale. 
+ + Parameters + ---------- + image : torch.Tensor + input image to be inpainted, of size (1,3,H,W) + mask : torch.Tensor + input inpainting mask, of size (1,1,H,W) + forward_front : nn.Module + the front part of the inpainting network + forward_rears : nn.Module + the rear part of the inpainting network + ref_lower_res : torch.Tensor + the inpainting at previous scale, used as reference image + orig_shape : tuple + shape of the original input image before padding + devices : list + list of available devices + scale_ind : int + the scale index + n_iters : int, optional + number of iterations of refinement, by default 15 + lr : float, optional + learning rate, by default 0.002 + + Returns + ------- + torch.Tensor + inpainted image + """ + masked_image = image * (1 - mask) + masked_image = torch.cat([masked_image, mask], dim=1) + + mask = mask.repeat(1, 3, 1, 1) + if ref_lower_res is not None: + ref_lower_res = ref_lower_res.detach() + with torch.no_grad(): + z1, z2 = forward_front(masked_image) + # Inference + mask = mask.to(devices[-1]) + ekernel = torch.from_numpy( + cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (15, 15)).astype(bool)).float() + ekernel = ekernel.to(devices[-1]) + image = image.to(devices[-1]) + z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0]) + z1.requires_grad, z2.requires_grad = True, True + + optimizer = Adam([z1, z2], lr=lr) + + pbar = tqdm(range(n_iters), leave=False) + for idi in pbar: + optimizer.zero_grad() + input_feat = (z1, z2) + for idd, forward_rear in enumerate(forward_rears): + output_feat = forward_rear(input_feat) + if idd < len(devices) - 1: + midz1, midz2 = output_feat + midz1, midz2 = midz1.to(devices[idd + 1]), midz2.to( + devices[idd + 1]) + input_feat = (midz1, midz2) + else: + pred = output_feat + + if ref_lower_res is None: + break + losses = {} + # scaled loss with downsampler + pred_downscaled = _pyrdown(pred[:, :, :orig_shape[0], :orig_shape[1]]) + mask_downscaled = _pyrdown_mask( + mask[:, :1, :orig_shape[0], :orig_shape[1]], + blur_mask=False, + round_up=False) + mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel) + mask_downscaled = mask_downscaled.repeat(1, 3, 1, 1) + losses['ms_l1'] = _l1_loss( + pred, + pred_downscaled, + ref_lower_res, + mask, + mask_downscaled, + image, + on_pred=True) + + loss = sum(losses.values()) + pbar.set_description( + 'Refining scale {} using scale {} ...current loss: {:.4f}'.format( + scale_ind + 1, scale_ind, loss.item())) + if idi < n_iters - 1: + loss.backward() + optimizer.step() + del pred_downscaled + del loss + del pred + # "pred" is the prediction after Plug-n-Play module + inpainted = mask * pred + (1 - mask) * image + inpainted = inpainted.detach().cpu() + return inpainted + + +def _get_image_mask_pyramid(batch: dict, min_side: int, max_scales: int, + px_budget: int): + """Build the image mask pyramid + + Parameters + ---------- + batch : dict + batch containing image, mask, etc + min_side : int + minimum side length to limit the number of scales of the pyramid + max_scales : int + maximum number of scales allowed + px_budget : int + the product H*W cannot exceed this budget, because of resource constraints + + Returns + ------- + tuple + image-mask pyramid in the form of list of images and list of masks + """ + + assert batch['image'].shape[ + 0] == 1, 'refiner works on only batches of size 1!' 
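+    # Rough shape of the computation below, using the values the pipeline passes in
+    # (min_side=512, max_scales=3, px_budget=900000) purely as an illustration:
+    # a 1024x1024 input exceeds the pixel budget, so it is resized by
+    # sqrt(900000 / 1024**2) ~= 0.93 to about 948x948; then
+    # n_scales = min(1 + round(log2(948 / 512)), 3) = 2, so the pyramid has two
+    # levels (~474x474 and 948x948), ordered from the coarsest scale to the finest.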
+ + h, w = batch['unpad_to_size'] + h, w = h[0].item(), w[0].item() + + image = batch['image'][..., :h, :w] + mask = batch['mask'][..., :h, :w] + if h * w > px_budget: + # resize + ratio = np.sqrt(px_budget / float(h * w)) + h_orig, w_orig = h, w + h, w = int(h * ratio), int(w * ratio) + print( + f'Original image too large for refinement! Resizing {(h_orig,w_orig)} to {(h,w)}...' + ) + image = resize( + image, (h, w), interpolation='bilinear', align_corners=False) + mask = resize( + mask, (h, w), interpolation='bilinear', align_corners=False) + mask[mask > 1e-8] = 1 + breadth = min(h, w) + n_scales = min(1 + int(round(max(0, np.log2(breadth / min_side)))), + max_scales) + ls_images = [] + ls_masks = [] + + ls_images.append(image) + ls_masks.append(mask) + + for _ in range(n_scales - 1): + image_p = _pyrdown(ls_images[-1]) + mask_p = _pyrdown_mask(ls_masks[-1]) + ls_images.append(image_p) + ls_masks.append(mask_p) + # reverse the lists because we want the lowest resolution image as index 0 + return ls_images[::-1], ls_masks[::-1] + + +def refine_predict(batch: dict, inpainter: nn.Module, gpu_ids: str, + modulo: int, n_iters: int, lr: float, min_side: int, + max_scales: int, px_budget: int): + """Refines the inpainting of the network + + Parameters + ---------- + batch : dict + image-mask batch, currently we assume the batchsize to be 1 + inpainter : nn.Module + the inpainting neural network + gpu_ids : str + the GPU ids of the machine to use. If only single GPU, use: "0," + modulo : int + pad the image to ensure dimension % modulo == 0 + n_iters : int + number of iterations of refinement for each scale + lr : float + learning rate + min_side : int + all sides of image on all scales should be >= min_side / sqrt(2) + max_scales : int + max number of downscaling scales for the image-mask pyramid + px_budget : int + pixels budget. 
Any image will be resized to satisfy height*width <= px_budget + + Returns + ------- + torch.Tensor + inpainted image of size (1,3,H,W) + """ + inpainter = inpainter.model + assert not inpainter.training + assert not inpainter.add_noise_kwargs + assert inpainter.concat_mask + + gpu_ids = [ + f'cuda:{gpuid}' for gpuid in gpu_ids.replace(' ', '').split(',') + if gpuid.isdigit() + ] + n_resnet_blocks = 0 + first_resblock_ind = 0 + found_first_resblock = False + for idl in range(len(inpainter.generator.model)): + if isinstance(inpainter.generator.model[idl], FFCResnetBlock): + n_resnet_blocks += 1 + found_first_resblock = True + elif not found_first_resblock: + first_resblock_ind += 1 + resblocks_per_gpu = n_resnet_blocks // len(gpu_ids) + + devices = [torch.device(gpu_id) for gpu_id in gpu_ids] + + # split the model into front, and rear parts + forward_front = inpainter.generator.model[0:first_resblock_ind] + forward_front.to(devices[0]) + forward_rears = [] + for idd in range(len(gpu_ids)): + if idd < len(gpu_ids) - 1: + forward_rears.append( + inpainter.generator.model[first_resblock_ind + + resblocks_per_gpu + * (idd):first_resblock_ind + + resblocks_per_gpu * (idd + 1)]) + else: + forward_rears.append( + inpainter.generator.model[first_resblock_ind + + resblocks_per_gpu * (idd):]) + forward_rears[idd].to(devices[idd]) + + ls_images, ls_masks = _get_image_mask_pyramid(batch, min_side, max_scales, + px_budget) + image_inpainted = None + + for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)): + orig_shape = image.shape[2:] + image = pad_tensor_to_modulo(image, modulo) + mask = pad_tensor_to_modulo(mask, modulo) + mask[mask >= 1e-8] = 1.0 + mask[mask < 1e-8] = 0.0 + image, mask = move_to_device(image, devices[0]), move_to_device( + mask, devices[0]) + if image_inpainted is not None: + image_inpainted = move_to_device(image_inpainted, devices[-1]) + image_inpainted = _infer(image, mask, forward_front, forward_rears, + image_inpainted, orig_shape, devices, ids, + n_iters, lr) + image_inpainted = image_inpainted[:, :, :orig_shape[0], :orig_shape[1]] + # detach everything to save resources + image = image.detach().cpu() + mask = mask.detach().cpu() + + return image_inpainted diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py index e2bf5bc1..35c060f0 100644 --- a/modelscope/msdatasets/task_datasets/__init__.py +++ b/modelscope/msdatasets/task_datasets/__init__.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset from .movie_scene_segmentation import MovieSceneSegmentationDataset from .video_summarization_dataset import VideoSummarizationDataset + from .image_inpainting import ImageInpaintingDataset from .passage_ranking_dataset import PassageRankingDataset else: @@ -24,6 +25,7 @@ else: ['ImageInstanceSegmentationCocoDataset'], 'video_summarization_dataset': ['VideoSummarizationDataset'], 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], + 'image_inpainting': ['ImageInpaintingDataset'], } import sys diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py b/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py new file mode 100644 index 00000000..732a1bd7 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from .image_inpainting_dataset import ImageInpaintingDataset diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/aug.py b/modelscope/msdatasets/task_datasets/image_inpainting/aug.py new file mode 100644 index 00000000..445bb9b4 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_inpainting/aug.py @@ -0,0 +1,100 @@ +""" +The implementation is borrowed from LaMa, +publicly available at https://github.com/saic-mdal/lama +""" +import imgaug.augmenters as iaa +from albumentations import DualIAATransform, to_tuple + + +class IAAAffine2(DualIAATransform): + """Place a regular grid of points on the input and randomly move the neighbourhood of these point around + via affine transformations. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask + """ + + def __init__( + self, + scale=(0.7, 1.3), + translate_percent=None, + translate_px=None, + rotate=0.0, + shear=(-0.1, 0.1), + order=1, + cval=0, + mode='reflect', + always_apply=False, + p=0.5, + ): + super(IAAAffine2, self).__init__(always_apply, p) + self.scale = dict(x=scale, y=scale) + self.translate_percent = to_tuple(translate_percent, 0) + self.translate_px = to_tuple(translate_px, 0) + self.rotate = to_tuple(rotate) + self.shear = dict(x=shear, y=shear) + self.order = order + self.cval = cval + self.mode = mode + + @property + def processor(self): + return iaa.Affine( + self.scale, + self.translate_percent, + self.translate_px, + self.rotate, + self.shear, + self.order, + self.cval, + self.mode, + ) + + def get_transform_init_args_names(self): + return ('scale', 'translate_percent', 'translate_px', 'rotate', + 'shear', 'order', 'cval', 'mode') + + +class IAAPerspective2(DualIAATransform): + """Perform a random four point perspective transform of the input. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + scale ((float, float): standard deviation of the normal distributions. These are used to sample + the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). + p (float): probability of applying the transform. Default: 0.5. 
+ + Targets: + image, mask + """ + + def __init__(self, + scale=(0.05, 0.1), + keep_size=True, + always_apply=False, + p=0.5, + order=1, + cval=0, + mode='replicate'): + super(IAAPerspective2, self).__init__(always_apply, p) + self.scale = to_tuple(scale, 1.0) + self.keep_size = keep_size + self.cval = cval + self.mode = mode + + @property + def processor(self): + return iaa.PerspectiveTransform( + self.scale, + keep_size=self.keep_size, + mode=self.mode, + cval=self.cval) + + def get_transform_init_args_names(self): + return ('scale', 'keep_size') diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py b/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py new file mode 100644 index 00000000..057b8f88 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py @@ -0,0 +1,337 @@ +""" +Part of the implementation is borrowed and modified from LaMa, +publicly available at https://github.com/saic-mdal/lama +""" +import glob +import os +import os.path as osp +from enum import Enum + +import albumentations as A +import cv2 +import json +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from .aug import IAAAffine2, IAAPerspective2 + +LOGGER = get_logger() + + +class LinearRamp: + + def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): + self.start_value = start_value + self.end_value = end_value + self.start_iter = start_iter + self.end_iter = end_iter + + def __call__(self, i): + if i < self.start_iter: + return self.start_value + if i >= self.end_iter: + return self.end_value + part = (i - self.start_iter) / (self.end_iter - self.start_iter) + return self.start_value * (1 - part) + self.end_value * part + + +class DrawMethod(Enum): + LINE = 'line' + CIRCLE = 'circle' + SQUARE = 'square' + + +def make_random_superres_mask(shape, + min_step=2, + max_step=4, + min_width=1, + max_width=3): + height, width = shape + mask = np.zeros((height, width), np.float32) + step_x = np.random.randint(min_step, max_step + 1) + width_x = np.random.randint(min_width, min(step_x, max_width + 1)) + offset_x = np.random.randint(0, step_x) + + step_y = np.random.randint(min_step, max_step + 1) + width_y = np.random.randint(min_width, min(step_y, max_width + 1)) + offset_y = np.random.randint(0, step_y) + + for dy in range(width_y): + mask[offset_y + dy::step_y] = 1 + for dx in range(width_x): + mask[:, offset_x + dx::step_x] = 1 + return mask[None, ...] 
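+# All mask generators in this file share one convention: they return a float32 array
+# of shape (1, H, W) in which 1 marks pixels to be inpainted and 0 marks known pixels.
+# The super-resolution mask above, for example, sets regularly spaced rows and columns
+# (width_y/width_x pixels wide, every step_y/step_x pixels) to 1, mimicking the missing
+# samples of an upscaling grid.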
+ + +class RandomSuperresMaskGenerator: + + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __call__(self, img, iter_i=None): + return make_random_superres_mask(img.shape[1:], **self.kwargs) + + +def make_random_rectangle_mask(shape, + margin=10, + bbox_min_size=30, + bbox_max_size=100, + min_times=0, + max_times=3): + height, width = shape + mask = np.zeros((height, width), np.float32) + bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + box_width = np.random.randint(bbox_min_size, bbox_max_size) + box_height = np.random.randint(bbox_min_size, bbox_max_size) + start_x = np.random.randint(margin, width - margin - box_width + 1) + start_y = np.random.randint(margin, height - margin - box_height + 1) + mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1 + return mask[None, ...] + + +class RandomRectangleMaskGenerator: + + def __init__(self, + margin=10, + bbox_min_size=30, + bbox_max_size=100, + min_times=0, + max_times=3, + ramp_kwargs=None): + self.margin = margin + self.bbox_min_size = bbox_min_size + self.bbox_max_size = bbox_max_size + self.min_times = min_times + self.max_times = max_times + self.ramp = LinearRamp( + **ramp_kwargs) if ramp_kwargs is not None else None + + def __call__(self, img, iter_i=None, raw_image=None): + coef = self.ramp(iter_i) if (self.ramp is not None) and ( + iter_i is not None) else 1 + cur_bbox_max_size = int(self.bbox_min_size + 1 + + (self.bbox_max_size - self.bbox_min_size) + * coef) + cur_max_times = int(self.min_times + + (self.max_times - self.min_times) * coef) + return make_random_rectangle_mask( + img.shape[1:], + margin=self.margin, + bbox_min_size=self.bbox_min_size, + bbox_max_size=cur_bbox_max_size, + min_times=self.min_times, + max_times=cur_max_times) + + +def make_random_irregular_mask(shape, + max_angle=4, + max_len=60, + max_width=20, + min_times=0, + max_times=10, + draw_method=DrawMethod.LINE): + draw_method = DrawMethod(draw_method) + + height, width = shape + mask = np.zeros((height, width), np.float32) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + start_x = np.random.randint(width) + start_y = np.random.randint(height) + for j in range(1 + np.random.randint(5)): + angle = 0.01 + np.random.randint(max_angle) + if i % 2 == 0: + angle = 2 * 3.1415926 - angle + length = 10 + np.random.randint(max_len) + brush_w = 5 + np.random.randint(max_width) + end_x = np.clip( + (start_x + length * np.sin(angle)).astype(np.int32), 0, width) + end_y = np.clip( + (start_y + length * np.cos(angle)).astype(np.int32), 0, height) + if draw_method == DrawMethod.LINE: + cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, + brush_w) + elif draw_method == DrawMethod.CIRCLE: + cv2.circle( + mask, (start_x, start_y), + radius=brush_w, + color=1., + thickness=-1) + elif draw_method == DrawMethod.SQUARE: + radius = brush_w // 2 + mask[start_y - radius:start_y + radius, + start_x - radius:start_x + radius] = 1 + start_x, start_y = end_x, end_y + return mask[None, ...] 
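+# Minimal sketch of exercising the brush-stroke generator above on its own (the shape
+# and keyword values are illustrative; during training MixedMaskGenerator below samples
+# one of these generators per item):
+#
+#     img = np.zeros((3, 256, 256), dtype=np.float32)   # CHW image, as in __getitem__
+#     mask = make_random_irregular_mask(img.shape[1:], max_len=60, max_width=20,
+#                                       max_times=10, draw_method=DrawMethod.LINE)
+#     # mask.shape == (1, 256, 256); up to max_times random polylines are rasterised
+#     # into it with cv2.line, each stroke between 5 and 24 pixels wide.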
+ + +class RandomIrregularMaskGenerator: + + def __init__(self, + max_angle=4, + max_len=60, + max_width=20, + min_times=0, + max_times=10, + ramp_kwargs=None, + draw_method=DrawMethod.LINE): + self.max_angle = max_angle + self.max_len = max_len + self.max_width = max_width + self.min_times = min_times + self.max_times = max_times + self.draw_method = draw_method + self.ramp = LinearRamp( + **ramp_kwargs) if ramp_kwargs is not None else None + + def __call__(self, img, iter_i=None, raw_image=None): + coef = self.ramp(iter_i) if (self.ramp is not None) and ( + iter_i is not None) else 1 + cur_max_len = int(max(1, self.max_len * coef)) + cur_max_width = int(max(1, self.max_width * coef)) + cur_max_times = int(self.min_times + 1 + + (self.max_times - self.min_times) * coef) + return make_random_irregular_mask( + img.shape[1:], + max_angle=self.max_angle, + max_len=cur_max_len, + max_width=cur_max_width, + min_times=self.min_times, + max_times=cur_max_times, + draw_method=self.draw_method) + + +class MixedMaskGenerator: + + def __init__(self, + irregular_proba=1 / 3, + irregular_kwargs=None, + box_proba=1 / 3, + box_kwargs=None, + segm_proba=1 / 3, + segm_kwargs=None, + squares_proba=0, + squares_kwargs=None, + superres_proba=0, + superres_kwargs=None, + outpainting_proba=0, + outpainting_kwargs=None, + invert_proba=0): + self.probas = [] + self.gens = [] + + if irregular_proba > 0: + self.probas.append(irregular_proba) + if irregular_kwargs is None: + irregular_kwargs = {} + else: + irregular_kwargs = dict(irregular_kwargs) + irregular_kwargs['draw_method'] = DrawMethod.LINE + self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs)) + + if box_proba > 0: + self.probas.append(box_proba) + if box_kwargs is None: + box_kwargs = {} + self.gens.append(RandomRectangleMaskGenerator(**box_kwargs)) + + if squares_proba > 0: + self.probas.append(squares_proba) + if squares_kwargs is None: + squares_kwargs = {} + else: + squares_kwargs = dict(squares_kwargs) + squares_kwargs['draw_method'] = DrawMethod.SQUARE + self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs)) + + if superres_proba > 0: + self.probas.append(superres_proba) + if superres_kwargs is None: + superres_kwargs = {} + self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs)) + + self.probas = np.array(self.probas, dtype='float32') + self.probas /= self.probas.sum() + self.invert_proba = invert_proba + + def __call__(self, img, iter_i=None, raw_image=None): + kind = np.random.choice(len(self.probas), p=self.probas) + gen = self.gens[kind] + result = gen(img, iter_i=iter_i, raw_image=raw_image) + if self.invert_proba > 0 and random.random() < self.invert_proba: + result = 1 - result + return result + + +def get_transforms(test_mode, out_size): + if not test_mode: + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.06)), + IAAAffine2(scale=(0.7, 1.3), rotate=(-40, 40), shear=(-0.1, 0.1)), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.OpticalDistortion(), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast( + brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue( + hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + else: + transform = A.Compose([ + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.CenterCrop(height=out_size, width=out_size), + A.ToFloat() + ]) + return transform + + +@TASK_DATASETS.register_module( + Tasks.image_inpainting, module_name=Models.image_inpainting) +class 
ImageInpaintingDataset(TorchTaskDataset): + + def __init__(self, **kwargs): + split_config = kwargs['split_config'] + LOGGER.info(kwargs) + mode = kwargs.get('test_mode', False) + + self.data_root = next(iter(split_config.values())) + if not osp.exists(self.data_root): + self.data_root = osp.dirname(self.data_root) + assert osp.exists(self.data_root) + mask_gen_kwargs = kwargs.get('mask_gen_kwargs', {}) + out_size = kwargs.get('out_size', 256) + self.mask_generator = MixedMaskGenerator(**mask_gen_kwargs) + self.transform = get_transforms(mode, out_size) + self.in_files = sorted( + list( + glob.glob( + osp.join(self.data_root, '**', '*.jpg'), recursive=True)) + + list( + glob.glob( + osp.join(self.data_root, '**', '*.png'), recursive=True))) + self.iter_i = 0 + + def __len__(self): + return len(self.in_files) + + def __getitem__(self, index): + path = self.in_files[index] + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = self.transform(image=img)['image'] + img = np.transpose(img, (2, 0, 1)) + # TODO: maybe generate mask before augmentations? slower, but better for segmentation-based masks + mask = self.mask_generator(img, iter_i=self.iter_i) + self.iter_i += 1 + return dict(image=img, mask=mask) diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 07a14191..dd59d6fb 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -177,6 +177,7 @@ TASK_OUTPUTS = { Tasks.image_denoising: [OutputKeys.OUTPUT_IMG], Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG], Tasks.crowd_counting: [OutputKeys.SCORES, OutputKeys.OUTPUT_IMG], + Tasks.image_inpainting: [OutputKeys.OUTPUT_IMG], # image generation task result for a single image # {"output_img": np.array with shape (h, w, 3)} diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index c9a70d14..b18d4465 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -181,6 +181,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'), Tasks.shop_segmentation: (Pipelines.shop_segmentation, 'damo/cv_vitb16_segmentation_shop-seg'), + Tasks.image_inpainting: (Pipelines.image_inpainting, + 'damo/cv_fft_inpainting_lama'), Tasks.video_inpainting: (Pipelines.video_inpainting, 'damo/cv_video-inpainting'), Tasks.hand_static: (Pipelines.hand_static, diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 55bad09a..118eaf17 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -35,6 +35,7 @@ if TYPE_CHECKING: from .image_super_resolution_pipeline import ImageSuperResolutionPipeline from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline + from .image_inpainting_pipeline import ImageInpaintingPipeline from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline from .realtime_object_detection_pipeline import RealtimeObjectDetectionPipeline from .live_category_pipeline import LiveCategoryPipeline @@ -99,6 +100,7 @@ else: 'live_category_pipeline': ['LiveCategoryPipeline'], 'image_to_image_generation_pipeline': ['Image2ImageGenerationPipeline'], + 'image_inpainting_pipeline': ['ImageInpaintingPipeline'], 'ocr_detection_pipeline': ['OCRDetectionPipeline'], 'ocr_recognition_pipeline': ['OCRRecognitionPipeline'], 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], diff --git a/modelscope/pipelines/cv/image_inpainting_pipeline.py 
b/modelscope/pipelines/cv/image_inpainting_pipeline.py new file mode 100644 index 00000000..6ae0d63e --- /dev/null +++ b/modelscope/pipelines/cv/image_inpainting_pipeline.py @@ -0,0 +1,146 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch +import torch.nn as nn +from torch.utils.data._utils.collate import default_collate + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_inpainting import FFTInpainting +from modelscope.models.cv.image_inpainting.refinement import refine_predict +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_inpainting, module_name=Pipelines.image_inpainting) +class ImageInpaintingPipeline(Pipeline): + + def __init__(self, + model: str, + pad_out_to_modulo=8, + refine=False, + **kwargs): + """ + model: model id on modelscope hub. + """ + assert isinstance(model, str), 'model must be a single str' + super().__init__(model=model, auto_collate=False, **kwargs) + self.refine = refine + logger.info(f'loading model from dir {model}') + self.infer_model = FFTInpainting(model, predict_only=True) + if not self.refine: + self.infer_model.to(self.device) + self.infer_model.eval() + logger.info(f'loading model done, refinement is set to {self.refine}') + self.pad_out_to_modulo = pad_out_to_modulo + + def move_to_device(self, obj, device): + if isinstance(obj, nn.Module): + return obj.to(device) + if torch.is_tensor(obj): + return obj.to(device) + if isinstance(obj, (tuple, list)): + return [self.move_to_device(el, device) for el in obj] + if isinstance(obj, dict): + return { + name: self.move_to_device(val, device) + for name, val in obj.items() + } + raise ValueError(f'Unexpected type {type(obj)}') + + def transforms(self, img): + if img.ndim == 3: + img = np.transpose(img, (2, 0, 1)) + out_img = img.astype('float32') / 255 + return out_img + + def ceil_modulo(self, x, mod): + if x % mod == 0: + return x + return (x // mod + 1) * mod + + def pad_img_to_modulo(self, img, mod): + channels, height, width = img.shape + out_height = self.ceil_modulo(height, mod) + out_width = self.ceil_modulo(width, mod) + return np.pad( + img, ((0, 0), (0, out_height - height), (0, out_width - width)), + mode='symmetric') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + image_name, mask_name = input.split('+') + img = LoadImage.convert_to_ndarray(image_name) + img = self.transforms(img) + mask = np.array(LoadImage(mode='L')(mask_name)['img']) + mask = self.transforms(mask) + elif isinstance(input, PIL.Image.Image): + img = input.crop((0, 0, int(input.width / 2), input.height)) + img = self.transforms(np.array(img)) + mask = input.crop((int(input.width / 2), 0, input.width, + input.height)).convert('L') + mask = self.transforms(np.array(mask)) + else: + raise TypeError('input should be either str or PIL.Image') + result = dict(image=img, mask=mask[None, ...]) + + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['unpad_to_size'] = result['image'].shape[1:] + result['image'] = self.pad_img_to_modulo(result['image'], + self.pad_out_to_modulo) + result['mask'] = self.pad_img_to_modulo(result['mask'], + 
self.pad_out_to_modulo) + + # Since Pipeline use default torch.no_grad() for performing forward func. + # We conduct inference here in case of doing training for refinement. + result = self.perform_inference(result) + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return {OutputKeys.OUTPUT_IMG: input} + + def perform_inference(self, data): + batch = default_collate([data]) + if self.refine: + assert 'unpad_to_size' in batch, 'Unpadded size is required for the refinement' + assert 'cuda' in str(self.device), 'GPU is required for refinement' + gpu_ids = str(self.device).split(':')[-1] + cur_res = refine_predict( + batch, + self.infer_model, + gpu_ids=gpu_ids, + modulo=self.pad_out_to_modulo, + n_iters=15, + lr=0.002, + min_side=512, + max_scales=3, + px_budget=900000) + cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy() + else: + with torch.no_grad(): + batch = self.move_to_device(batch, self.device) + batch['mask'] = (batch['mask'] > 0) * 1 + batch = self.infer_model(batch) + cur_res = batch['inpainted'][0].permute( + 1, 2, 0).detach().cpu().numpy() + unpad_to_size = batch.get('unpad_to_size', None) + if unpad_to_size is not None: + orig_height, orig_width = unpad_to_size + cur_res = cur_res[:orig_height, :orig_width] + + cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8') + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) + return cur_res + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py index a632642a..86917261 100644 --- a/modelscope/trainers/__init__.py +++ b/modelscope/trainers/__init__.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from .builder import build_trainer from .cv import (ImageInstanceSegmentationTrainer, ImagePortraitEnhancementTrainer, - MovieSceneSegmentationTrainer) + MovieSceneSegmentationTrainer, ImageInpaintingTrainer) from .multi_modal import CLIPTrainer from .nlp import SequenceClassificationTrainer, PassageRankingTrainer from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer @@ -22,7 +22,8 @@ else: 'builder': ['build_trainer'], 'cv': [ 'ImageInstanceSegmentationTrainer', - 'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer' + 'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer', + 'ImageInpaintingTrainer' ], 'multi_modal': ['CLIPTrainer'], 'nlp': ['SequenceClassificationTrainer', 'PassageRankingTrainer'], diff --git a/modelscope/trainers/cv/__init__.py b/modelscope/trainers/cv/__init__.py index 4c65870e..d09fd75c 100644 --- a/modelscope/trainers/cv/__init__.py +++ b/modelscope/trainers/cv/__init__.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: ImageInstanceSegmentationTrainer from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer + from .image_inpainting_trainer import ImageInpaintingTrainer else: _import_structure = { @@ -15,7 +16,8 @@ else: ['ImageInstanceSegmentationTrainer'], 'image_portrait_enhancement_trainer': ['ImagePortraitEnhancementTrainer'], - 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'] + 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'], + 'image_inpainting_trainer': ['ImageInpaintingTrainer'] } import sys diff --git a/modelscope/trainers/cv/image_inpainting_trainer.py b/modelscope/trainers/cv/image_inpainting_trainer.py new file mode 100644 index 00000000..74d1ed9f --- /dev/null +++ 
@@ -0,0 +1,111 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import time
+from collections.abc import Mapping
+
+from torch import distributed as dist
+
+from modelscope.metainfo import Trainers
+from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.trainer import EpochBasedTrainer
+from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
+                                       ConfigKeys, Hubs, ModeKeys, ModelFile,
+                                       Tasks, TrainerStages)
+from modelscope.utils.data_utils import to_device
+from modelscope.utils.file_utils import func_receive_dict_inputs
+
+
+@TRAINERS.register_module(module_name=Trainers.image_inpainting)
+class ImageInpaintingTrainer(EpochBasedTrainer):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def train(self, *args, **kwargs):
+        super().train(*args, **kwargs)
+
+    def evaluate(self, *args, **kwargs):
+        metric_values = super().evaluate(*args, **kwargs)
+        return metric_values
+
+    def prediction_step(self, model, inputs):
+        pass
+
+    def train_loop(self, data_loader):
+        """ Training loop used by `EpochBasedTrainer.train()`
+        """
+        self.invoke_hook(TrainerStages.before_run)
+        self._epoch = 0
+        self.model.train()
+        for _ in range(self._epoch, self._max_epochs):
+            self.invoke_hook(TrainerStages.before_train_epoch)
+            for i, data_batch in enumerate(data_loader):
+                data_batch = to_device(data_batch, self.device)
+                self.data_batch = data_batch
+                self._inner_iter = i
+                # each batch takes two sub-steps, one per optimizer
+                # (generator and discriminator), selected via optimizer_idx
+                for idx in range(2):
+                    self.invoke_hook(TrainerStages.before_train_iter)
+                    self.train_step(self.model, data_batch, idx)
+                    self.invoke_hook(TrainerStages.after_train_iter)
+                del self.data_batch
+                self._iter += 1
+                self._mode = ModeKeys.TRAIN
+
+                if i + 1 >= self.iters_per_epoch:
+                    break
+
+            self.invoke_hook(TrainerStages.after_train_epoch)
+            self._epoch += 1
+
+        self.invoke_hook(TrainerStages.after_run)
+
+    def train_step(self, model, inputs, idx):
+        """ Perform a training step on a batch of inputs.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (`TorchModel`): The model to train.
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model. The dictionary is
+                unpacked before being fed to the model's `_do_step` together
+                with `optimizer_idx`.
+            idx (`int`): Index of the optimizer used for this sub-step.
+
+        Return:
+            None. The raw outputs are stored in `self.train_outputs` and the
+            matched loss values are pushed to the log buffer.
+ """ + # EvaluationHook will do evaluate and change mode to val, return to train mode + # TODO: find more pretty way to change mode + model.train() + self._mode = ModeKeys.TRAIN + # call model forward but not __call__ to skip postprocess + if isinstance(inputs, + Mapping) and not func_receive_dict_inputs(model.forward): + train_outputs = model.model._do_step(**inputs, optimizer_idx=idx) + else: + train_outputs = model.model._do_step(inputs, optimizer_idx=idx) + + if not isinstance(train_outputs, dict): + raise TypeError('"model.forward()" must return a dict') + + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone() + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + + self.train_outputs = train_outputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 2331dc85..2a5ac694 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -47,6 +47,8 @@ class CVTasks(object): face_emotion = 'face-emotion' product_segmentation = 'product-segmentation' + crowd_counting = 'crowd-counting' + # image editing skin_retouching = 'skin-retouching' image_super_resolution = 'image-super-resolution' @@ -54,6 +56,7 @@ class CVTasks(object): image_color_enhancement = 'image-color-enhancement' image_denoising = 'image-denoising' image_portrait_enhancement = 'image-portrait-enhancement' + image_inpainting = 'image-inpainting' # image generation image_to_image_translation = 'image-to-image-translation' @@ -72,7 +75,6 @@ class CVTasks(object): video_category = 'video-category' video_embedding = 'video-embedding' virtual_try_on = 'virtual-try-on' - crowd_counting = 'crowd-counting' movie_scene_segmentation = 'movie-scene-segmentation' # video editing diff --git a/requirements/cv.txt b/requirements/cv.txt index f907256d..e6ffb5ff 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -7,6 +7,8 @@ ffmpeg-python>=0.2.0 ftfy imageio>=2.9.0 imageio-ffmpeg>=0.4.2 +imgaug>=0.4.0 +kornia>=0.5.0 lmdb lpips ml_collections diff --git a/tests/pipelines/test_image_inpainting.py b/tests/pipelines/test_image_inpainting.py new file mode 100644 index 00000000..b89ce399 --- /dev/null +++ b/tests/pipelines/test_image_inpainting.py @@ -0,0 +1,77 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest
+
+import cv2
+import torch
+from PIL import Image
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+
+class ImageInpaintingTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.input_location = 'data/test/images/image_inpainting/image_inpainting.png'
+        self.input_mask_location = 'data/test/images/image_inpainting/image_inpainting_mask.png'
+        self.model_id = 'damo/cv_fft_inpainting_lama'
+
+    def save_result(self, result):
+        vis_img = result[OutputKeys.OUTPUT_IMG]
+        cv2.imwrite('result.png', vis_img)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_inpainting(self):
+        inpainting = pipeline(Tasks.image_inpainting, model=self.model_id)
+        result = inpainting(self.input_location + '+'
+                            + self.input_mask_location)
+        if result:
+            self.save_result(result)
+        else:
+            raise ValueError('process error')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
+    def test_inpainting_with_refinement(self):
+        # for high-resolution inputs, refine=True usually gives better results
+        inpainting = pipeline(
+            Tasks.image_inpainting, model=self.model_id, refine=True)
+        result = inpainting(self.input_location + '+'
+                            + self.input_mask_location)
+        if result:
+            self.save_result(result)
+        else:
+            raise ValueError('process error')
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_inpainting_with_image(self):
+        inpainting = pipeline(Tasks.image_inpainting, model=self.model_id)
+        img = Image.open(self.input_location).convert('RGB')
+        mask = Image.open(self.input_mask_location).convert('RGB')
+        img_new = Image.new('RGB', (img.width + mask.width, img.height))
+        img_new.paste(img, (0, 0))
+        img_new.paste(mask, (img.width, 0))
+        result = inpainting(img_new)
+        if result:
+            self.save_result(result)
+        else:
+            raise ValueError('process error')
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_inpainting_with_default_task(self):
+        inpainting = pipeline(Tasks.image_inpainting)
+        result = inpainting(self.input_location + '+'
+                            + self.input_mask_location)
+        if result:
+            self.save_result(result)
+        else:
+            raise ValueError('process error')
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/run_config.yaml b/tests/run_config.yaml
index 4c571b7f..b4149dc9 100644
--- a/tests/run_config.yaml
+++ b/tests/run_config.yaml
@@ -10,6 +10,7 @@
 isolated: # test cases that may require excessive anmount of GPU memory, which
   - test_easycv_trainer.py
   - test_segformer.py
   - test_segmentation_pipeline.py
+  - test_image_inpainting.py
 envs:
   default: # default env, case not in other env will in default, pytorch.
diff --git a/tests/trainers/test_image_inpainting_trainer.py b/tests/trainers/test_image_inpainting_trainer.py
new file mode 100644
index 00000000..807fe64f
--- /dev/null
+++ b/tests/trainers/test_image_inpainting_trainer.py
@@ -0,0 +1,84 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
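+# The inpainting trainer is built through the trainer registry and consumes
+# MsDataset splits prepared with the mask generation settings from the model
+# configuration. A minimal call, mirroring test_trainer below (dataset
+# variables are placeholders):
+#
+#     trainer = build_trainer(
+#         name=Trainers.image_inpainting,
+#         default_args=dict(
+#             model='damo/cv_fft_inpainting_lama',
+#             train_dataset=train_dataset,
+#             eval_dataset=test_dataset))
+#     trainer.train()
+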
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Trainers
+from modelscope.models.cv.image_inpainting import FFTInpainting
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.config import Config, ConfigDict
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+
+class ImageInpaintingTrainerTest(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+        self.model_id = 'damo/cv_fft_inpainting_lama'
+        self.cache_path = snapshot_download(self.model_id)
+        cfg = Config.from_file(
+            os.path.join(self.cache_path, ModelFile.CONFIGURATION))
+
+        train_data_cfg = ConfigDict(
+            name='PlacesToydataset',
+            split='train',
+            mask_gen_kwargs=cfg.dataset.mask_gen_kwargs,
+            out_size=cfg.dataset.train_out_size,
+            test_mode=False)
+
+        test_data_cfg = ConfigDict(
+            name='PlacesToydataset',
+            split='test',
+            mask_gen_kwargs=cfg.dataset.mask_gen_kwargs,
+            out_size=cfg.dataset.val_out_size,
+            test_mode=True)
+
+        self.train_dataset = MsDataset.load(
+            dataset_name=train_data_cfg.name,
+            split=train_data_cfg.split,
+            mask_gen_kwargs=train_data_cfg.mask_gen_kwargs,
+            out_size=train_data_cfg.out_size,
+            test_mode=train_data_cfg.test_mode)
+        assert next(
+            iter(self.train_dataset.config_kwargs['split_config'].values()))
+
+        self.test_dataset = MsDataset.load(
+            dataset_name=test_data_cfg.name,
+            split=test_data_cfg.split,
+            mask_gen_kwargs=test_data_cfg.mask_gen_kwargs,
+            out_size=test_data_cfg.out_size,
+            test_mode=test_data_cfg.test_mode)
+        assert next(
+            iter(self.test_dataset.config_kwargs['split_config'].values()))
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir, ignore_errors=True)
+        super().tearDown()
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_trainer(self):
+        kwargs = dict(
+            model=self.model_id,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset)
+
+        trainer = build_trainer(
+            name=Trainers.image_inpainting, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(trainer.work_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+
+
+if __name__ == '__main__':
+    unittest.main()