Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9440706 * init * merge master
@@ -49,6 +49,7 @@ class Pipelines(object):
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognation = 'resnet101-animal_recog'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+    image_super_resolution = 'rrdb-image-super-resolution'
     face_image_generation = 'gan-face-image-generation'
     style_transfer = 'AAMS-style-transfer'
@@ -0,0 +1,226 @@ new file: arch_util.py (presumably under modelscope/models/cv/super_resolution, given the relative import in rrdbnet_arch.py below)
import collections.abc
import math
import warnings
from itertools import repeat

import torch
import torchvision
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0.
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
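
Reviewer note: a minimal usage sketch (illustration only, not part of this PR).
`scale` shrinks the Kaiming-initialized weights; the residual blocks below call
this with scale=0.1 to keep residual branches small at the start of training.

    conv = nn.Conv2d(64, 64, 3, 1, 1)
    default_init_weights(conv, scale=0.1)  # Kaiming init, then weight *= 0.1
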
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.Module): nn.Module class for the basic block.
        num_basic_block (int): Number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)
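
Reviewer note: make_layer forwards its keyword arguments to every block
constructor, so a trunk of identical blocks is a single call. A hedged sketch
(not part of the PR):

    trunk = make_layer(ResidualBlockNoBN, 3, num_feat=64)
    out = trunk(torch.randn(1, 64, 32, 32))  # shape preserved: (1, 64, 32, 32)
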
class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale
class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(
                f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
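
Reviewer note: each stage expands channels by r^2 and PixelShuffle(r) trades
them for spatial size, mapping (N, r^2*C, H, W) to (N, C, rH, rW). A hedged
shape check (illustration only):

    up = Upsample(scale=4, num_feat=64)   # two conv + PixelShuffle(2) stages
    out = up(torch.randn(1, 64, 16, 16))
    assert out.shape == (1, 64, 64, 64)
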
def flow_warp(x,
              flow,
              interp_mode='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2); non-normalized pixel
            displacements.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use True as the default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, h).type_as(x),
        torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1, 1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(
        x,
        vgrid_scaled,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners)
    # TODO, what if align_corners=False
    return output
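
Reviewer note: a quick sanity check that may help readers. A zero flow should
warp an image onto itself (exact at grid points with align_corners=True). A
hedged sketch, not part of the PR:

    x = torch.randn(1, 3, 8, 8)
    flow = torch.zeros(1, 8, 8, 2)  # (n, h, w, 2), zero displacement
    assert torch.allclose(flow_warp(x, flow), x, atol=1e-5)
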
def resize_flow(flow,
                size_type,
                sizes,
                interp_mode='bilinear',
                align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether to align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(
            f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    # flow values are pixel displacements, so they must be rescaled together
    # with the spatial resolution
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow,
        size=(output_h, output_w),
        mode=interp_mode,
        align_corners=align_corners)
    return resized_flow
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
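
Reviewer note: pixel_unshuffle folds each scale x scale spatial patch into
channels and is the exact inverse of pixel shuffle, which is what RRDBNet below
relies on for the x1 and x2 configurations. A hedged round-trip check
(illustration only):

    x = torch.randn(2, 3, 8, 8)
    down = pixel_unshuffle(x, scale=2)  # (2, 12, 4, 4)
    assert torch.equal(F.pixel_shuffle(down, 2), x)
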
@@ -0,0 +1,129 @@ new file: rrdbnet_arch.py (imported below as modelscope.models.cv.super_resolution.rrdbnet_arch)
import torch
from torch import nn as nn
from torch.nn import functional as F

from .arch_util import default_init_weights, make_layer, pixel_unshuffle
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights(
            [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        # dense connectivity: each conv sees the input and all earlier outputs
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x
class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixel shuffle)
    to reduce the spatial size and enlarge the channel size before feeding
    inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Upsampling scale. Supported: 1, 2, 4. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 scale=4,
                 num_feat=64,
                 num_block=23,
                 num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(
            RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(
            self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(
            self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
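
Reviewer note: a hedged end-to-end shape check (illustration, not part of the
PR). With scale=4 the input skips pixel_unshuffle, and the two nearest-neighbor
upsample stages produce a 4x output:

    net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64,
                  num_block=23, num_grow_ch=32)
    with torch.no_grad():
        sr = net(torch.randn(1, 3, 32, 32))
    assert sr.shape == (1, 3, 128, 128)
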
@@ -70,6 +70,7 @@ TASK_OUTPUTS = {
     Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
     Tasks.image_matting: [OutputKeys.OUTPUT_IMG],
     Tasks.image_generation: [OutputKeys.OUTPUT_IMG],
+    Tasks.image_restoration: [OutputKeys.OUTPUT_IMG],

     # action recognition result for single video
     # {
@@ -6,6 +6,7 @@ try:
     from .action_recognition_pipeline import ActionRecognitionPipeline
     from .animal_recog_pipeline import AnimalRecogPipeline
     from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
+    from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
     from .face_image_generation_pipeline import FaceImageGenerationPipeline
 except ModuleNotFoundError as e:
     if str(e) == "No module named 'torch'":
@@ -0,0 +1,77 @@ new file: image_super_resolution_pipeline.py (under modelscope/pipelines/cv, per the import above)
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.super_resolution import rrdbnet_arch
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()
@PIPELINES.register_module(
    Tasks.image_restoration, module_name=Pipelines.image_super_resolution)
class ImageSuperResolutionPipeline(Pipeline):

    def __init__(self, model: str):
        """Use `model` to create an image super-resolution pipeline for
        prediction.

        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model)
        self.num_feat = 64
        self.num_block = 23
        self.scale = 4
        self.sr_model = rrdbnet_arch.RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=self.num_feat,
            num_block=self.num_block,
            num_grow_ch=32,
            scale=self.scale)

        model_path = f'{self.model}/{ModelFile.TORCH_MODEL_FILE}'
        self.sr_model.load_state_dict(torch.load(model_path), strict=True)
        logger.info('load model done')
    def preprocess(self, input: Input) -> Dict[str, Any]:
        if isinstance(input, str):
            img = np.array(load_image(input))
        elif isinstance(input, PIL.Image.Image):
            img = np.array(input.convert('RGB'))
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                # promote grayscale input to 3 channels before the BGR->RGB flip
                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            # ndarray inputs are assumed BGR; flip to RGB (copy for contiguity,
            # since torch.from_numpy rejects negative strides)
            img = input[:, :, ::-1].copy()
        else:
            raise TypeError(f'input should be either str, PIL.Image,'
                            f' np.ndarray, but got {type(input)}')

        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) / 255.
        result = {'img': img}
        return result
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        self.sr_model.eval()
        with torch.no_grad():
            out = self.sr_model(input['img'])

        out = out.squeeze(0).permute(1, 2, 0).flip(2)  # CHW -> HWC, RGB -> BGR
        out_img = np.clip(out.float().cpu().numpy(), 0, 1) * 255
        return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
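
Reviewer note: for quick reference, a hedged usage sketch mirroring the test
below (model id taken from the test; the output is a BGR uint8 array at 4x the
input resolution):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    sr = pipeline(
        Tasks.image_restoration, model='damo/cv_rrdb_image-super-resolution')
    result = sr('data/test/images/dogs.jpg')
    cv2.imwrite('sr_result.png', result[OutputKeys.OUTPUT_IMG])
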
@@ -27,6 +27,7 @@ class CVTasks(object):
     ocr_detection = 'ocr-detection'
     action_recognition = 'action-recognition'
     video_embedding = 'video-embedding'
+    image_restoration = 'image-restoration'
     style_transfer = 'style-transfer'
@@ -0,0 +1,37 @@ new test file
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import unittest

import cv2

from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ImageSuperResolutionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_rrdb_image-super-resolution'
        self.img = 'data/test/images/dogs.jpg'

    def pipeline_inference(self, pipeline: Pipeline, img: str):
        result = pipeline(img)
        if result is not None:
            cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
            print(f'Output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_modelhub(self):
        super_resolution = pipeline(
            Tasks.image_restoration, model=self.model_id)
        self.pipeline_inference(super_resolution, self.img)


if __name__ == '__main__':
    unittest.main()