diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 963fc14f..23e64ffc 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -49,6 +49,7 @@ class Pipelines(object):
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognation = 'resnet101-animal_recog'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+    image_super_resolution = 'rrdb-image-super-resolution'
     face_image_generation = 'gan-face-image-generation'
     style_transfer = 'AAMS-style-transfer'
diff --git a/modelscope/models/cv/super_resolution/arch_util.py b/modelscope/models/cv/super_resolution/arch_util.py
new file mode 100644
index 00000000..4b87c877
--- /dev/null
+++ b/modelscope/models/cv/super_resolution/arch_util.py
@@ -0,0 +1,226 @@
+import collections.abc
+import math
+import warnings
+from itertools import repeat
+
+import torch
+import torchvision
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+    """Initialize network weights.
+
+    Args:
+        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+        scale (float): Scale initialized weights, especially for residual
+            blocks. Default: 1.
+        bias_fill (float): The value to fill bias. Default: 0.
+        kwargs (dict): Other arguments for initialization function.
+    """
+    if not isinstance(module_list, list):
+        module_list = [module_list]
+    for module in module_list:
+        for m in module.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, _BatchNorm):
+                init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+
+    Args:
+        basic_block (nn.Module): nn.Module class for the basic block.
+        num_basic_block (int): Number of blocks.
+
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+    """Residual block without BN.
+
+    It has a style of:
+        ---Conv-ReLU-Conv-+-
+         |________________|
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        res_scale (float): Residual scale. Default: 1.
+        pytorch_init (bool): If set to True, use pytorch default init,
+            otherwise, use default_init_weights. Default: False.
+    """
+
+    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+        super(ResidualBlockNoBN, self).__init__()
+        self.res_scale = res_scale
+        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.relu = nn.ReLU(inplace=True)
+
+        if not pytorch_init:
+            default_init_weights([self.conv1, self.conv2], 0.1)
+
+    def forward(self, x):
+        identity = x
+        out = self.conv2(self.relu(self.conv1(x)))
+        return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+    """Upsample module.
+
+    Args:
+        scale (int): Scale factor. Supported scales: 2^n and 3.
+        num_feat (int): Channel number of intermediate features.
+    """
+
+    def __init__(self, scale, num_feat):
+        m = []
+        if (scale & (scale - 1)) == 0:  # scale = 2^n
+            for _ in range(int(math.log(scale, 2))):
+                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                m.append(nn.PixelShuffle(2))
+        elif scale == 3:
+            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+            m.append(nn.PixelShuffle(3))
+        else:
+            raise ValueError(
+                f'scale {scale} is not supported. Supported scales: 2^n and 3.'
+            )
+        super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x,
+              flow,
+              interp_mode='bilinear',
+              padding_mode='zeros',
+              align_corners=True):
+    """Warp an image or feature map with optical flow.
+
+    Args:
+        x (Tensor): Tensor with size (n, c, h, w).
+        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+        padding_mode (str): 'zeros' or 'border' or 'reflection'.
+            Default: 'zeros'.
+        align_corners (bool): Before pytorch 1.3, the default value is
+            align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use True as the default.
+
+    Returns:
+        Tensor: Warped image or feature map.
+    """
+    assert x.size()[-2:] == flow.size()[1:3]
+    _, _, h, w = x.size()
+    # create mesh grid
+    grid_y, grid_x = torch.meshgrid(
+        torch.arange(0, h).type_as(x),
+        torch.arange(0, w).type_as(x))
+    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+    grid.requires_grad = False
+
+    vgrid = grid + flow
+    # scale grid to [-1,1]
+    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+    output = F.grid_sample(
+        x,
+        vgrid_scaled,
+        mode=interp_mode,
+        padding_mode=padding_mode,
+        align_corners=align_corners)
+
+    # TODO, what if align_corners=False
+    return output
+
+
+def resize_flow(flow,
+                size_type,
+                sizes,
+                interp_mode='bilinear',
+                align_corners=False):
+    """Resize a flow according to ratio or shape.
+
+    Args:
+        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+        size_type (str): 'ratio' or 'shape'.
+        sizes (list[int | float]): the ratio for resizing or the final output
+            shape.
+            1) The order of ratio should be [ratio_h, ratio_w]. For
+            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+            ratio > 1.0).
+            2) The order of output_size should be [out_h, out_w].
+        interp_mode (str): The mode of interpolation for resizing.
+            Default: 'bilinear'.
+        align_corners (bool): Whether align corners. Default: False.
+
+    Returns:
+        Tensor: Resized flow.
+    """
+    _, _, flow_h, flow_w = flow.size()
+    if size_type == 'ratio':
+        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+    elif size_type == 'shape':
+        output_h, output_w = sizes[0], sizes[1]
+    else:
+        raise ValueError(
+            f'Size type should be ratio or shape, but got type {size_type}.')
+
+    input_flow = flow.clone()
+    ratio_h = output_h / flow_h
+    ratio_w = output_w / flow_w
+    input_flow[:, 0, :, :] *= ratio_w
+    input_flow[:, 1, :, :] *= ratio_h
+    resized_flow = F.interpolate(
+        input=input_flow,
+        size=(output_h, output_w),
+        mode=interp_mode,
+        align_corners=align_corners)
+    return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+    """Pixel unshuffle.
+
+    Args:
+        x (Tensor): Input feature with shape (b, c, hh, hw).
+        scale (int): Downsample ratio.
+
+    Returns:
+        Tensor: The pixel-unshuffled feature.
+    """
+    b, c, hh, hw = x.size()
+    out_channel = c * (scale**2)
+
+    assert hh % scale == 0 and hw % scale == 0
+
+    h = hh // scale
+    w = hw // scale
+    x_view = x.view(b, c, h, scale, w, scale)
+    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
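Reviewer note: pixel_unshuffle above is the exact inverse of torch.nn.PixelShuffle
(both are pure index permutations), which is what lets RRDBNet below reuse one 4x
trunk for the x2 and x1 scales. A minimal round-trip sketch (not part of this
patch) checking that property:

    import torch
    from torch import nn
    from modelscope.models.cv.super_resolution.arch_util import pixel_unshuffle

    x = torch.randn(1, 16, 8, 8)           # (b, c * scale**2, h, w)
    up = nn.PixelShuffle(2)(x)             # -> (1, 4, 16, 16)
    down = pixel_unshuffle(up, scale=2)    # -> (1, 16, 8, 8), element-wise equal
    assert torch.equal(down, x)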
+ """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError( + f'scale {scale} is not supported. Supported scales: 2^n and 3.' + ) + super(Upsample, self).__init__(*m) + + +def flow_warp(x, + flow, + interp_mode='bilinear', + padding_mode='zeros', + align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid( + torch.arange(0, h).type_as(x), + torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample( + x, + vgrid_scaled, + mode=interp_mode, + padding_mode=padding_mode, + align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, + size_type, + sizes, + interp_mode='bilinear', + align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError( + f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, + size=(output_h, output_w), + mode=interp_mode, + align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. 
+ """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + + assert hh % scale == 0 and hw % scale == 0 + + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) diff --git a/modelscope/models/cv/super_resolution/rrdbnet_arch.py b/modelscope/models/cv/super_resolution/rrdbnet_arch.py new file mode 100644 index 00000000..44947de1 --- /dev/null +++ b/modelscope/models/cv/super_resolution/rrdbnet_arch.py @@ -0,0 +1,129 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .arch_util import default_init_weights, make_layer, pixel_unshuffle + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, + 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, + 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights( + [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Emperically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Emperically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. 
+ """ + + def __init__(self, + num_in_ch, + num_out_ch, + scale=4, + num_feat=64, + num_block=23, + num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer( + RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu( + self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu( + self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out diff --git a/modelscope/outputs.py b/modelscope/outputs.py index a169f40b..9116b002 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -70,6 +70,7 @@ TASK_OUTPUTS = { Tasks.image_editing: [OutputKeys.OUTPUT_IMG], Tasks.image_matting: [OutputKeys.OUTPUT_IMG], Tasks.image_generation: [OutputKeys.OUTPUT_IMG], + Tasks.image_restoration: [OutputKeys.OUTPUT_IMG], # action recognition result for single video # { diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index bc1f7c49..b6b6dfa7 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -6,6 +6,7 @@ try: from .action_recognition_pipeline import ActionRecognitionPipeline from .animal_recog_pipeline import AnimalRecogPipeline from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline + from .image_super_resolution_pipeline import ImageSuperResolutionPipeline from .face_image_generation_pipeline import FaceImageGenerationPipeline except ModuleNotFoundError as e: if str(e) == "No module named 'torch'": diff --git a/modelscope/pipelines/cv/image_super_resolution_pipeline.py b/modelscope/pipelines/cv/image_super_resolution_pipeline.py new file mode 100644 index 00000000..354b18d3 --- /dev/null +++ b/modelscope/pipelines/cv/image_super_resolution_pipeline.py @@ -0,0 +1,77 @@ +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.super_resolution import rrdbnet_arch +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input +from modelscope.preprocessors import load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from ..base import Pipeline +from ..builder import PIPELINES + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_restoration, module_name=Pipelines.image_super_resolution) +class ImageSuperResolutionPipeline(Pipeline): + + def __init__(self, model: str): + """ + use `model` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. 
+ """ + super().__init__(model=model) + self.num_feat = 64 + self.num_block = 23 + self.scale = 4 + self.sr_model = rrdbnet_arch.RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=self.num_feat, + num_block=self.num_block, + num_grow_ch=32, + scale=self.scale) + + model_path = f'{self.model}/{ModelFile.TORCH_MODEL_FILE}' + self.sr_model.load_state_dict(torch.load(model_path), strict=True) + + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + img = np.array(load_image(input)) + elif isinstance(input, PIL.Image.Image): + img = np.array(input.convert('RGB')) + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + img = input[:, :, ::-1] # in rgb order + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + + img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) / 255. + result = {'img': img} + + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + self.sr_model.eval() + with torch.no_grad(): + out = self.sr_model(input['img']) + + out = out.squeeze(0).permute(1, 2, 0).flip(2) + out_img = np.clip(out.float().cpu().numpy(), 0, 1) * 255 + + return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index e5935c3e..cffe607c 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -27,6 +27,7 @@ class CVTasks(object): ocr_detection = 'ocr-detection' action_recognition = 'action-recognition' video_embedding = 'video-embedding' + image_restoration = 'image-restoration' style_transfer = 'style-transfer' diff --git a/tests/pipelines/test_image_super_resolution.py b/tests/pipelines/test_image_super_resolution.py new file mode 100644 index 00000000..d3012342 --- /dev/null +++ b/tests/pipelines/test_image_super_resolution.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import os.path as osp +import unittest + +import cv2 + +from modelscope.msdatasets import MsDataset +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class ImageSuperResolutionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_rrdb_image-super-resolution' + self.img = 'data/test/images/dogs.jpg' + + def pipeline_inference(self, pipeline: Pipeline, img: str): + result = pipeline(img) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_modelhub(self): + super_resolution = pipeline( + Tasks.image_restoration, model=self.model_id) + + self.pipeline_inference(super_resolution, self.img) + + +if __name__ == '__main__': + unittest.main()