Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491966master
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:403034182fa320130dae0d75b92e85e0850771378e674d65455c403a4958e29c | |||
size 170716 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:ebd5dacad9b75ef80f87eb785d7818421dadb63257da0e91e123766c5913f855 | |||
size 149971 |
@@ -10,6 +10,7 @@ class Models(object): | |||
Model name should only contain model info but not task info. | |||
""" | |||
# vision models | |||
nafnet = 'nafnet' | |||
csrnet = 'csrnet' | |||
cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' | |||
@@ -59,6 +60,7 @@ class Pipelines(object): | |||
""" | |||
# vision tasks | |||
image_matting = 'unet-image-matting' | |||
image_denoise = 'nafnet-image-denoise' | |||
person_image_cartoon = 'unet-person-image-cartoon' | |||
ocr_detection = 'resnet18-ocr-detection' | |||
action_recognition = 'TAdaConv_action-recognition' | |||
@@ -132,6 +134,7 @@ class Preprocessors(object): | |||
# cv preprocessor | |||
load_image = 'load-image' | |||
image_denoie_preprocessor = 'image-denoise-preprocessor' | |||
image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' | |||
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' | |||
@@ -167,6 +170,9 @@ class Metrics(object): | |||
# accuracy | |||
accuracy = 'accuracy' | |||
# metrics for image denoise task | |||
image_denoise_metric = 'image-denoise-metric' | |||
# metric for image instance segmentation task | |||
image_ins_seg_coco_metric = 'image-ins-seg-coco-metric' | |||
# metrics for sequence classification task | |||
@@ -1,6 +1,7 @@ | |||
from .base import Metric | |||
from .builder import METRICS, build_metric, task_default_metrics | |||
from .image_color_enhance_metric import ImageColorEnhanceMetric | |||
from .image_denoise_metric import ImageDenoiseMetric | |||
from .image_instance_segmentation_metric import \ | |||
ImageInstanceSegmentationCOCOMetric | |||
from .sequence_classification_metric import SequenceClassificationMetric | |||
@@ -22,6 +22,7 @@ task_default_metrics = { | |||
Tasks.sentence_similarity: [Metrics.seq_cls_metric], | |||
Tasks.sentiment_classification: [Metrics.seq_cls_metric], | |||
Tasks.text_generation: [Metrics.text_gen_metric], | |||
Tasks.image_denoise: [Metrics.image_denoise_metric], | |||
Tasks.image_color_enhance: [Metrics.image_color_enhance_metric] | |||
} | |||
@@ -0,0 +1,45 @@ | |||
from typing import Dict | |||
import numpy as np | |||
from skimage.metrics import peak_signal_noise_ratio, structural_similarity | |||
from modelscope.metainfo import Metrics | |||
from modelscope.utils.registry import default_group | |||
from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
torch_nested_numpify) | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
@METRICS.register_module( | |||
group_key=default_group, module_name=Metrics.image_denoise_metric) | |||
class ImageDenoiseMetric(Metric): | |||
"""The metric computation class for image denoise classes. | |||
""" | |||
pred_name = 'pred' | |||
label_name = 'target' | |||
def __init__(self): | |||
self.preds = [] | |||
self.labels = [] | |||
def add(self, outputs: Dict, inputs: Dict): | |||
ground_truths = outputs[ImageDenoiseMetric.label_name] | |||
eval_results = outputs[ImageDenoiseMetric.pred_name] | |||
self.preds.append( | |||
torch_nested_numpify(torch_nested_detach(eval_results))) | |||
self.labels.append( | |||
torch_nested_numpify(torch_nested_detach(ground_truths))) | |||
def evaluate(self): | |||
psnr_list, ssim_list = [], [] | |||
for (pred, label) in zip(self.preds, self.labels): | |||
psnr_list.append( | |||
peak_signal_noise_ratio(label[0], pred[0], data_range=255)) | |||
ssim_list.append( | |||
structural_similarity( | |||
label[0], pred[0], multichannel=True, data_range=255)) | |||
return { | |||
MetricKeys.PSNR: np.mean(psnr_list), | |||
MetricKeys.SSIM: np.mean(ssim_list) | |||
} |
@@ -22,6 +22,7 @@ except ModuleNotFoundError as e: | |||
try: | |||
from .multi_modal import OfaForImageCaptioning | |||
from .cv import NAFNetForImageDenoise | |||
from .nlp import (BertForMaskedLM, BertForSequenceClassification, | |||
SbertForNLI, SbertForSentenceSimilarity, | |||
SbertForSentimentClassification, | |||
@@ -1,2 +1,3 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .image_color_enhance.image_color_enhance import ImageColorEnhance | |||
from .image_denoise.nafnet_for_image_denoise import * # noqa F403 |
@@ -0,0 +1,233 @@ | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from .arch_util import LayerNorm2d | |||
class SimpleGate(nn.Module): | |||
def forward(self, x): | |||
x1, x2 = x.chunk(2, dim=1) | |||
return x1 * x2 | |||
class NAFBlock(nn.Module): | |||
def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): | |||
super().__init__() | |||
dw_channel = c * DW_Expand | |||
self.conv1 = nn.Conv2d( | |||
in_channels=c, | |||
out_channels=dw_channel, | |||
kernel_size=1, | |||
padding=0, | |||
stride=1, | |||
groups=1, | |||
bias=True) | |||
self.conv2 = nn.Conv2d( | |||
in_channels=dw_channel, | |||
out_channels=dw_channel, | |||
kernel_size=3, | |||
padding=1, | |||
stride=1, | |||
groups=dw_channel, | |||
bias=True) | |||
self.conv3 = nn.Conv2d( | |||
in_channels=dw_channel // 2, | |||
out_channels=c, | |||
kernel_size=1, | |||
padding=0, | |||
stride=1, | |||
groups=1, | |||
bias=True) | |||
# Simplified Channel Attention | |||
self.sca = nn.Sequential( | |||
nn.AdaptiveAvgPool2d(1), | |||
nn.Conv2d( | |||
in_channels=dw_channel // 2, | |||
out_channels=dw_channel // 2, | |||
kernel_size=1, | |||
padding=0, | |||
stride=1, | |||
groups=1, | |||
bias=True), | |||
) | |||
# SimpleGate | |||
self.sg = SimpleGate() | |||
ffn_channel = FFN_Expand * c | |||
self.conv4 = nn.Conv2d( | |||
in_channels=c, | |||
out_channels=ffn_channel, | |||
kernel_size=1, | |||
padding=0, | |||
stride=1, | |||
groups=1, | |||
bias=True) | |||
self.conv5 = nn.Conv2d( | |||
in_channels=ffn_channel // 2, | |||
out_channels=c, | |||
kernel_size=1, | |||
padding=0, | |||
stride=1, | |||
groups=1, | |||
bias=True) | |||
self.norm1 = LayerNorm2d(c) | |||
self.norm2 = LayerNorm2d(c) | |||
self.dropout1 = nn.Dropout( | |||
drop_out_rate) if drop_out_rate > 0. else nn.Identity() | |||
self.dropout2 = nn.Dropout( | |||
drop_out_rate) if drop_out_rate > 0. else nn.Identity() | |||
self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) | |||
self.gamma = nn.Parameter( | |||
torch.zeros((1, c, 1, 1)), requires_grad=True) | |||
def forward(self, inp): | |||
x = inp | |||
x = self.norm1(x) | |||
x = self.conv1(x) | |||
x = self.conv2(x) | |||
x = self.sg(x) | |||
x = x * self.sca(x) | |||
x = self.conv3(x) | |||
x = self.dropout1(x) | |||
y = inp + x * self.beta | |||
x = self.conv4(self.norm2(y)) | |||
x = self.sg(x) | |||
x = self.conv5(x) | |||
x = self.dropout2(x) | |||
return y + x * self.gamma | |||
class NAFNet(nn.Module): | |||
def __init__(self, | |||
img_channel=3, | |||
width=16, | |||
middle_blk_num=1, | |||
enc_blk_nums=[], | |||
dec_blk_nums=[]): | |||
super().__init__() | |||
self.intro = nn.Conv2d( | |||
in_channels=img_channel, | |||
out_channels=width, | |||
kernel_size=3, | |||
padding=1, | |||
stride=1, | |||
groups=1, | |||
bias=True) | |||
self.ending = nn.Conv2d( | |||
in_channels=width, | |||
out_channels=img_channel, | |||
kernel_size=3, | |||
padding=1, | |||
stride=1, | |||
groups=1, | |||
bias=True) | |||
self.encoders = nn.ModuleList() | |||
self.decoders = nn.ModuleList() | |||
self.middle_blks = nn.ModuleList() | |||
self.ups = nn.ModuleList() | |||
self.downs = nn.ModuleList() | |||
chan = width | |||
for num in enc_blk_nums: | |||
self.encoders.append( | |||
nn.Sequential(*[NAFBlock(chan) for _ in range(num)])) | |||
self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2)) | |||
chan = chan * 2 | |||
self.middle_blks = \ | |||
nn.Sequential( | |||
*[NAFBlock(chan) for _ in range(middle_blk_num)] | |||
) | |||
for num in dec_blk_nums: | |||
self.ups.append( | |||
nn.Sequential( | |||
nn.Conv2d(chan, chan * 2, 1, bias=False), | |||
nn.PixelShuffle(2))) | |||
chan = chan // 2 | |||
self.decoders.append( | |||
nn.Sequential(*[NAFBlock(chan) for _ in range(num)])) | |||
self.padder_size = 2**len(self.encoders) | |||
def forward(self, inp): | |||
B, C, H, W = inp.shape | |||
inp = self.check_image_size(inp) | |||
x = self.intro(inp) | |||
encs = [] | |||
for encoder, down in zip(self.encoders, self.downs): | |||
x = encoder(x) | |||
encs.append(x) | |||
x = down(x) | |||
x = self.middle_blks(x) | |||
for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): | |||
x = up(x) | |||
x = x + enc_skip | |||
x = decoder(x) | |||
x = self.ending(x) | |||
x = x + inp | |||
return x[:, :, :H, :W] | |||
def check_image_size(self, x): | |||
_, _, h, w = x.size() | |||
mod_pad_h = (self.padder_size | |||
- h % self.padder_size) % self.padder_size | |||
mod_pad_w = (self.padder_size | |||
- w % self.padder_size) % self.padder_size | |||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) | |||
return x | |||
class PSNRLoss(nn.Module): | |||
def __init__(self, loss_weight=1.0, reduction='mean', toY=False): | |||
super(PSNRLoss, self).__init__() | |||
assert reduction == 'mean' | |||
self.loss_weight = loss_weight | |||
self.scale = 10 / np.log(10) | |||
self.toY = toY | |||
self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) | |||
self.first = True | |||
def forward(self, pred, target): | |||
assert len(pred.size()) == 4 | |||
if self.toY: | |||
if self.first: | |||
self.coef = self.coef.to(pred.device) | |||
self.first = False | |||
pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. | |||
target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. | |||
pred, target = pred / 255., target / 255. | |||
pass | |||
assert len(pred.size()) == 4 | |||
return self.loss_weight * self.scale * torch.log(( | |||
(pred - target)**2).mean(dim=(1, 2, 3)) + 1e-8).mean() |
@@ -0,0 +1,42 @@ | |||
import torch | |||
import torch.nn as nn | |||
class LayerNormFunction(torch.autograd.Function): | |||
@staticmethod | |||
def forward(ctx, x, weight, bias, eps): | |||
ctx.eps = eps | |||
N, C, H, W = x.size() | |||
mu = x.mean(1, keepdim=True) | |||
var = (x - mu).pow(2).mean(1, keepdim=True) | |||
y = (x - mu) / (var + eps).sqrt() | |||
ctx.save_for_backward(y, var, weight) | |||
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) | |||
return y | |||
@staticmethod | |||
def backward(ctx, grad_output): | |||
eps = ctx.eps | |||
N, C, H, W = grad_output.size() | |||
y, var, weight = ctx.saved_variables | |||
g = grad_output * weight.view(1, C, 1, 1) | |||
mean_g = g.mean(dim=1, keepdim=True) | |||
mean_gy = (g * y).mean(dim=1, keepdim=True) | |||
gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) | |||
return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum( | |||
dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None | |||
class LayerNorm2d(nn.Module): | |||
def __init__(self, channels, eps=1e-6): | |||
super(LayerNorm2d, self).__init__() | |||
self.register_parameter('weight', nn.Parameter(torch.ones(channels))) | |||
self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) | |||
self.eps = eps | |||
def forward(self, x): | |||
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) |
@@ -0,0 +1,119 @@ | |||
import os | |||
from copy import deepcopy | |||
from typing import Any, Dict, Union | |||
import numpy as np | |||
import torch.cuda | |||
from torch.nn.parallel import DataParallel, DistributedDataParallel | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import Tensor | |||
from modelscope.models.base.base_torch_model import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .nafnet.NAFNet_arch import NAFNet, PSNRLoss | |||
logger = get_logger() | |||
__all__ = ['NAFNetForImageDenoise'] | |||
@MODELS.register_module(Tasks.image_denoise, module_name=Models.nafnet) | |||
class NAFNetForImageDenoise(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the image denoise model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
self.model_dir = model_dir | |||
self.config = Config.from_file( | |||
os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | |||
model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
self.model = NAFNet(**self.config.model.network_g) | |||
self.loss = PSNRLoss() | |||
if torch.cuda.is_available(): | |||
self._device = torch.device('cuda') | |||
else: | |||
self._device = torch.device('cpu') | |||
self.model = self.model.to(self._device) | |||
self.model = self._load_pretrained(self.model, model_path) | |||
if self.training: | |||
self.model.train() | |||
else: | |||
self.model.eval() | |||
def _load_pretrained(self, | |||
net, | |||
load_path, | |||
strict=True, | |||
param_key='params'): | |||
if isinstance(net, (DataParallel, DistributedDataParallel)): | |||
net = net.module | |||
load_net = torch.load( | |||
load_path, map_location=lambda storage, loc: storage) | |||
if param_key is not None: | |||
if param_key not in load_net and 'params' in load_net: | |||
param_key = 'params' | |||
logger.info( | |||
f'Loading: {param_key} does not exist, use params.') | |||
if param_key in load_net: | |||
load_net = load_net[param_key] | |||
logger.info( | |||
f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' | |||
) | |||
# remove unnecessary 'module.' | |||
for k, v in deepcopy(load_net).items(): | |||
if k.startswith('module.'): | |||
load_net[k[7:]] = v | |||
load_net.pop(k) | |||
net.load_state_dict(load_net, strict=strict) | |||
logger.info('load model done.') | |||
return net | |||
def _train_forward(self, input: Tensor, | |||
target: Tensor) -> Dict[str, Tensor]: | |||
preds = self.model(input) | |||
return {'loss': self.loss(preds, target)} | |||
def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]: | |||
return {'outputs': self.model(input).clamp(0, 1)} | |||
def _evaluate_postprocess(self, input: Tensor, | |||
target: Tensor) -> Dict[str, list]: | |||
preds = self.model(input) | |||
preds = list(torch.split(preds, 1, 0)) | |||
targets = list(torch.split(target, 1, 0)) | |||
preds = [(pred.data * 255.).squeeze(0).permute( | |||
1, 2, 0).cpu().numpy().astype(np.uint8) for pred in preds] | |||
targets = [(target.data * 255.).squeeze(0).permute( | |||
1, 2, 0).cpu().numpy().astype(np.uint8) for target in targets] | |||
return {'pred': preds, 'target': targets} | |||
def forward(self, inputs: Dict[str, | |||
Tensor]) -> Dict[str, Union[list, Tensor]]: | |||
"""return the result by the model | |||
Args: | |||
inputs (Tensor): the preprocessed data | |||
Returns: | |||
Dict[str, Tensor]: results | |||
""" | |||
for key, value in inputs.items(): | |||
inputs[key] = inputs[key].to(self._device) | |||
if self.training: | |||
return self._train_forward(**inputs) | |||
elif 'target' in inputs: | |||
return self._evaluate_postprocess(**inputs) | |||
else: | |||
return self._inference_forward(**inputs) |
@@ -0,0 +1,152 @@ | |||
# ------------------------------------------------------------------------ | |||
# Modified from BasicSR (https://github.com/xinntao/BasicSR) | |||
# Copyright 2018-2020 BasicSR Authors | |||
# ------------------------------------------------------------------------ | |||
import os | |||
from os import path as osp | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
from .transforms import mod_crop | |||
def img2tensor(imgs, bgr2rgb=True, float32=True): | |||
"""Numpy array to tensor. | |||
Args: | |||
imgs (list[ndarray] | ndarray): Input images. | |||
bgr2rgb (bool): Whether to change bgr to rgb. | |||
float32 (bool): Whether to change to float32. | |||
Returns: | |||
list[tensor] | tensor: Tensor images. If returned results only have | |||
one element, just return tensor. | |||
""" | |||
def _totensor(img, bgr2rgb, float32): | |||
if img.shape[2] == 3 and bgr2rgb: | |||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |||
img = torch.from_numpy(img.transpose(2, 0, 1)) | |||
if float32: | |||
img = img.float() | |||
return img | |||
if isinstance(imgs, list): | |||
return [_totensor(img, bgr2rgb, float32) for img in imgs] | |||
else: | |||
return _totensor(imgs, bgr2rgb, float32) | |||
def scandir(dir_path, keyword=None, recursive=False, full_path=False): | |||
"""Scan a directory to find the interested files. | |||
Args: | |||
dir_path (str): Path of the directory. | |||
keyword (str | tuple(str), optional): File keyword that we are | |||
interested in. Default: None. | |||
recursive (bool, optional): If set to True, recursively scan the | |||
directory. Default: False. | |||
full_path (bool, optional): If set to True, include the dir_path. | |||
Default: False. | |||
Returns: | |||
A generator for all the interested files with relative pathes. | |||
""" | |||
if (keyword is not None) and not isinstance(keyword, (str, tuple)): | |||
raise TypeError('"suffix" must be a string or tuple of strings') | |||
root = dir_path | |||
def _scandir(dir_path, keyword, recursive): | |||
for entry in os.scandir(dir_path): | |||
if not entry.name.startswith('.') and entry.is_file(): | |||
if full_path: | |||
return_path = entry.path | |||
else: | |||
return_path = osp.relpath(entry.path, root) | |||
if keyword is None: | |||
yield return_path | |||
elif keyword in return_path: | |||
yield return_path | |||
else: | |||
if recursive: | |||
yield from _scandir( | |||
entry.path, keyword=keyword, recursive=recursive) | |||
else: | |||
continue | |||
return _scandir(dir_path, keyword=keyword, recursive=recursive) | |||
def padding(img_lq, img_gt, gt_size): | |||
h, w, _ = img_lq.shape | |||
h_pad = max(0, gt_size - h) | |||
w_pad = max(0, gt_size - w) | |||
if h_pad == 0 and w_pad == 0: | |||
return img_lq, img_gt | |||
img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) | |||
img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) | |||
return img_lq, img_gt | |||
def read_img_seq(path, require_mod_crop=False, scale=1): | |||
"""Read a sequence of images from a given folder path. | |||
Args: | |||
path (list[str] | str): List of image paths or image folder path. | |||
require_mod_crop (bool): Require mod crop for each image. | |||
Default: False. | |||
scale (int): Scale factor for mod_crop. Default: 1. | |||
Returns: | |||
Tensor: size (t, c, h, w), RGB, [0, 1]. | |||
""" | |||
if isinstance(path, list): | |||
img_paths = path | |||
else: | |||
img_paths = sorted(list(scandir(path, full_path=True))) | |||
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] | |||
if require_mod_crop: | |||
imgs = [mod_crop(img, scale) for img in imgs] | |||
imgs = img2tensor(imgs, bgr2rgb=True, float32=True) | |||
imgs = torch.stack(imgs, dim=0) | |||
return imgs | |||
def paired_paths_from_folder(folders, keys, filename_tmpl): | |||
"""Generate paired paths from folders. | |||
Args: | |||
folders (list[str]): A list of folder path. The order of list should | |||
be [input_folder, gt_folder]. | |||
keys (list[str]): A list of keys identifying folders. The order should | |||
be in consistent with folders, e.g., ['lq', 'gt']. | |||
filename_tmpl (str): Template for each filename. Note that the | |||
template excludes the file extension. Usually the filename_tmpl is | |||
for files in the input folder. | |||
Returns: | |||
list[str]: Returned path list. | |||
""" | |||
assert len(folders) == 2, ( | |||
'The len of folders should be 2 with [input_folder, gt_folder]. ' | |||
f'But got {len(folders)}') | |||
assert len(keys) == 2, ( | |||
'The len of keys should be 2 with [input_key, gt_key]. ' | |||
f'But got {len(keys)}') | |||
input_folder, gt_folder = folders | |||
input_key, gt_key = keys | |||
input_paths = list(scandir(input_folder, keyword='NOISY', recursive=True)) | |||
gt_paths = list(scandir(gt_folder, keyword='GT', recursive=True)) | |||
assert len(input_paths) == len(gt_paths), ( | |||
f'{input_key} and {gt_key} datasets have different number of images: ' | |||
f'{len(input_paths)}, {len(gt_paths)}.') | |||
paths = [] | |||
for idx in range(len(gt_paths)): | |||
gt_path = os.path.join(gt_folder, gt_paths[idx]) | |||
input_path = os.path.join(input_folder, gt_path.replace('GT', 'NOISY')) | |||
paths.append( | |||
dict([(f'{input_key}_path', input_path), | |||
(f'{gt_key}_path', gt_path)])) | |||
return paths |
@@ -0,0 +1,78 @@ | |||
import os | |||
from typing import Callable, List, Optional, Tuple, Union | |||
import cv2 | |||
import numpy as np | |||
from torch.utils import data | |||
from .data_utils import img2tensor, padding, paired_paths_from_folder | |||
from .transforms import augment, paired_random_crop | |||
def default_loader(path): | |||
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0 | |||
class PairedImageDataset(data.Dataset): | |||
"""Paired image dataset for image restoration. | |||
""" | |||
def __init__(self, opt, root, is_train): | |||
super(PairedImageDataset, self).__init__() | |||
self.opt = opt | |||
self.is_train = is_train | |||
self.gt_folder, self.lq_folder = os.path.join( | |||
root, opt.dataroot_gt), os.path.join(root, opt.dataroot_lq) | |||
if opt.filename_tmpl is not None: | |||
self.filename_tmpl = opt.filename_tmpl | |||
else: | |||
self.filename_tmpl = '{}' | |||
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], | |||
['lq', 'gt'], self.filename_tmpl) | |||
def __getitem__(self, index): | |||
scale = self.opt.scale | |||
# Load gt and lq images. Dimension order: HWC; channel order: BGR; | |||
# image range: [0, 1], float32. | |||
gt_path = self.paths[index]['gt_path'] | |||
img_gt = default_loader(gt_path) | |||
lq_path = self.paths[index]['lq_path'] | |||
img_lq = default_loader(lq_path) | |||
# augmentation for training | |||
# if self.is_train: | |||
gt_size = self.opt.gt_size | |||
# padding | |||
img_gt, img_lq = padding(img_gt, img_lq, gt_size) | |||
# random crop | |||
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale) | |||
# flip, rotation | |||
img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip, | |||
self.opt.use_rot) | |||
# BGR to RGB, HWC to CHW, numpy to tensor | |||
img_gt, img_lq = img2tensor([img_gt, img_lq], | |||
bgr2rgb=True, | |||
float32=True) | |||
return { | |||
'input': img_lq, | |||
'target': img_gt, | |||
'input_path': lq_path, | |||
'target_path': gt_path | |||
} | |||
def __len__(self): | |||
return len(self.paths) | |||
def to_torch_dataset( | |||
self, | |||
columns: Union[str, List[str]] = None, | |||
preprocessors: Union[Callable, List[Callable]] = None, | |||
**format_kwargs, | |||
): | |||
return self |
@@ -0,0 +1,96 @@ | |||
# Modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/data/transforms.py | |||
import random | |||
def mod_crop(img, scale): | |||
"""Mod crop images, used during testing. | |||
Args: | |||
img (ndarray): Input image. | |||
scale (int): Scale factor. | |||
Returns: | |||
ndarray: Result image. | |||
""" | |||
img = img.copy() | |||
if img.ndim in (2, 3): | |||
h, w = img.shape[0], img.shape[1] | |||
h_remainder, w_remainder = h % scale, w % scale | |||
img = img[:h - h_remainder, :w - w_remainder, ...] | |||
else: | |||
raise ValueError(f'Wrong img ndim: {img.ndim}.') | |||
return img | |||
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale): | |||
"""Paired random crop. | |||
It crops lists of lq and gt images with corresponding locations. | |||
Args: | |||
img_gts (list[ndarray] | ndarray): GT images. | |||
img_lqs (list[ndarray] | ndarray): LQ images. | |||
gt_patch_size (int): GT patch size. | |||
scale (int): Scale factor. | |||
Returns: | |||
list[ndarray] | ndarray: GT images and LQ images. | |||
""" | |||
if not isinstance(img_gts, list): | |||
img_gts = [img_gts] | |||
if not isinstance(img_lqs, list): | |||
img_lqs = [img_lqs] | |||
h_lq, w_lq, _ = img_lqs[0].shape | |||
h_gt, w_gt, _ = img_gts[0].shape | |||
lq_patch_size = gt_patch_size // scale | |||
# randomly choose top and left coordinates for lq patch | |||
top = random.randint(0, h_lq - lq_patch_size) | |||
left = random.randint(0, w_lq - lq_patch_size) | |||
# crop lq patch | |||
img_lqs = [ | |||
v[top:top + lq_patch_size, left:left + lq_patch_size, ...] | |||
for v in img_lqs | |||
] | |||
# crop corresponding gt patch | |||
top_gt, left_gt = int(top * scale), int(left * scale) | |||
img_gts = [ | |||
v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] | |||
for v in img_gts | |||
] | |||
if len(img_gts) == 1: | |||
img_gts = img_gts[0] | |||
if len(img_lqs) == 1: | |||
img_lqs = img_lqs[0] | |||
return img_gts, img_lqs | |||
def augment(imgs, hflip=True, rotation=True, vflip=False): | |||
"""Augment: horizontal flips | rotate | |||
All the images in the list use the same augmentation. | |||
""" | |||
hflip = hflip and random.random() < 0.5 | |||
if vflip or rotation: | |||
vflip = random.random() < 0.5 | |||
rot90 = rotation and random.random() < 0.5 | |||
def _augment(img): | |||
if hflip: # horizontal | |||
img = img[:, ::-1, :].copy() | |||
if vflip: # vertical | |||
img = img[::-1, :, :].copy() | |||
if rot90: | |||
img = img.transpose(1, 0, 2) | |||
return img | |||
if not isinstance(imgs, list): | |||
imgs = [imgs] | |||
imgs = [_augment(img) for img in imgs] | |||
if len(imgs) == 1: | |||
imgs = imgs[0] | |||
return imgs |
@@ -74,6 +74,7 @@ TASK_OUTPUTS = { | |||
Tasks.image_editing: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_matting: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_generation: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_denoise: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_colorization: [OutputKeys.OUTPUT_IMG], | |||
Tasks.face_image_generation: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_super_resolution: [OutputKeys.OUTPUT_IMG], | |||
@@ -35,6 +35,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
), # TODO: revise back after passing the pr | |||
Tasks.image_matting: (Pipelines.image_matting, | |||
'damo/cv_unet_image-matting'), | |||
Tasks.image_denoise: (Pipelines.image_denoise, | |||
'damo/cv_nafnet_image-denoise_sidd'), | |||
Tasks.text_classification: (Pipelines.sentiment_analysis, | |||
'damo/bert-base-sst2'), | |||
Tasks.text_generation: (Pipelines.text_generation, | |||
@@ -6,6 +6,7 @@ try: | |||
from .action_recognition_pipeline import ActionRecognitionPipeline | |||
from .animal_recog_pipeline import AnimalRecogPipeline | |||
from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline | |||
from .image_denoise_pipeline import ImageDenoisePipeline | |||
from .image_color_enhance_pipeline import ImageColorEnhancePipeline | |||
from .virtual_tryon_pipeline import VirtualTryonPipeline | |||
from .image_colorization_pipeline import ImageColorizationPipeline | |||
@@ -0,0 +1,111 @@ | |||
from typing import Any, Dict, Optional, Union | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
from PIL import Image | |||
from torchvision import transforms | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models import Model | |||
from modelscope.models.cv import NAFNetForImageDenoise | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input | |||
from modelscope.preprocessors import ImageDenoisePreprocessor, LoadImage | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
from ..base import Pipeline | |||
from ..builder import PIPELINES | |||
logger = get_logger() | |||
__all__ = ['ImageDenoisePipeline'] | |||
@PIPELINES.register_module( | |||
Tasks.image_denoise, module_name=Pipelines.image_denoise) | |||
class ImageDenoisePipeline(Pipeline): | |||
def __init__(self, | |||
model: Union[NAFNetForImageDenoise, str], | |||
preprocessor: Optional[ImageDenoisePreprocessor] = None, | |||
**kwargs): | |||
""" | |||
use `model` and `preprocessor` to create a cv image denoise pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
model = model if isinstance( | |||
model, NAFNetForImageDenoise) else Model.from_pretrained(model) | |||
model.eval() | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
self.config = model.config | |||
if torch.cuda.is_available(): | |||
self._device = torch.device('cuda') | |||
else: | |||
self._device = torch.device('cpu') | |||
self.model = model | |||
logger.info('load image denoise model done') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
img = LoadImage.convert_to_img(input) | |||
test_transforms = transforms.Compose([transforms.ToTensor()]) | |||
img = test_transforms(img) | |||
result = {'img': img.unsqueeze(0).to(self._device)} | |||
return result | |||
def crop_process(self, input): | |||
output = torch.zeros_like(input) # [1, C, H, W] | |||
# determine crop_h and crop_w | |||
ih, iw = input.shape[-2:] | |||
crop_rows, crop_cols = max(ih // 512, 1), max(iw // 512, 1) | |||
overlap = 16 | |||
step_h, step_w = ih // crop_rows, iw // crop_cols | |||
for y in range(crop_rows): | |||
for x in range(crop_cols): | |||
crop_y = step_h * y | |||
crop_x = step_w * x | |||
crop_h = step_h if y < crop_rows - 1 else ih - crop_y | |||
crop_w = step_w if x < crop_cols - 1 else iw - crop_x | |||
crop_frames = input[:, :, | |||
max(0, crop_y - overlap | |||
):min(crop_y + crop_h + overlap, ih), | |||
max(0, crop_x - overlap | |||
):min(crop_x + crop_w | |||
+ overlap, iw)].contiguous() | |||
h_start = overlap if max(0, crop_y - overlap) > 0 else 0 | |||
w_start = overlap if max(0, crop_x - overlap) > 0 else 0 | |||
h_end = h_start + crop_h if min(crop_y + crop_h | |||
+ overlap, ih) < ih else ih | |||
w_end = w_start + crop_w if min(crop_x + crop_w | |||
+ overlap, iw) < iw else iw | |||
output[:, :, crop_y:crop_y + crop_h, | |||
crop_x:crop_x + crop_w] = self.model._inference_forward( | |||
crop_frames)['outputs'][:, :, h_start:h_end, | |||
w_start:w_end] | |||
return output | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
def set_phase(model, is_train): | |||
if is_train: | |||
model.train() | |||
else: | |||
model.eval() | |||
is_train = False | |||
set_phase(self.model, is_train) | |||
with torch.no_grad(): | |||
output = self.crop_process(input['img']) # output Tensor | |||
return {'output_tensor': output} | |||
def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute( | |||
1, 2, 0).numpy().astype('uint8') | |||
return {OutputKeys.OUTPUT_IMG: output_img} |
@@ -21,6 +21,7 @@ try: | |||
from .space.dialog_state_tracking_preprocessor import * # noqa F403 | |||
from .image import ImageColorEnhanceFinetunePreprocessor | |||
from .image import ImageInstanceSegmentationPreprocessor | |||
from .image import ImageDenoisePreprocessor | |||
except ModuleNotFoundError as e: | |||
if str(e) == "No module named 'tensorflow'": | |||
print(TENSORFLOW_IMPORT_ERROR.format('tts')) | |||
@@ -138,6 +138,31 @@ class ImageColorEnhanceFinetunePreprocessor(Preprocessor): | |||
return data | |||
@PREPROCESSORS.register_module( | |||
Fields.cv, module_name=Preprocessors.image_denoie_preprocessor) | |||
class ImageDenoisePreprocessor(Preprocessor): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
""" | |||
Args: | |||
model_dir (str): model path | |||
""" | |||
super().__init__(*args, **kwargs) | |||
self.model_dir: str = model_dir | |||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
"""process the raw input data | |||
Args: | |||
data Dict[str, Any] | |||
Returns: | |||
Dict[str, Any]: the preprocessed data | |||
""" | |||
return data | |||
@PREPROCESSORS.register_module( | |||
Fields.cv, | |||
module_name=Preprocessors.image_instance_segmentation_preprocessor) | |||
@@ -24,6 +24,7 @@ class CVTasks(object): | |||
image_editing = 'image-editing' | |||
image_generation = 'image-generation' | |||
image_matting = 'image-matting' | |||
image_denoise = 'image-denoise' | |||
ocr_detection = 'ocr-detection' | |||
action_recognition = 'action-recognition' | |||
video_embedding = 'video-embedding' | |||
@@ -0,0 +1,59 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from PIL import Image | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models import Model | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import ImageDenoisePipeline, pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class ImageDenoiseTest(unittest.TestCase): | |||
model_id = 'damo/cv_nafnet_image-denoise_sidd' | |||
demo_image_path = 'data/test/images/noisy-demo-1.png' | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_by_direct_model_download(self): | |||
cache_path = snapshot_download(self.model_id) | |||
pipeline = ImageDenoisePipeline(cache_path) | |||
denoise_img = pipeline( | |||
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] | |||
denoise_img = Image.fromarray(denoise_img) | |||
w, h = denoise_img.size | |||
print('pipeline: the shape of output_img is {}x{}'.format(h, w)) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_model_from_modelhub(self): | |||
model = Model.from_pretrained(self.model_id) | |||
pipeline_ins = pipeline(task=Tasks.image_denoise, model=model) | |||
denoise_img = pipeline_ins( | |||
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] | |||
denoise_img = Image.fromarray(denoise_img) | |||
w, h = denoise_img.size | |||
print('pipeline: the shape of output_img is {}x{}'.format(h, w)) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_model_name(self): | |||
pipeline_ins = pipeline(task=Tasks.image_denoise, model=self.model_id) | |||
denoise_img = pipeline_ins( | |||
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] | |||
denoise_img = Image.fromarray(denoise_img) | |||
w, h = denoise_img.size | |||
print('pipeline: the shape of output_img is {}x{}'.format(h, w)) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_default_model(self): | |||
pipeline_ins = pipeline(task=Tasks.image_denoise) | |||
denoise_img = pipeline_ins( | |||
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] | |||
denoise_img = Image.fromarray(denoise_img) | |||
w, h = denoise_img.size | |||
print('pipeline: the shape of output_img is {}x{}'.format(h, w)) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,74 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import shutil | |||
import tempfile | |||
import unittest | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models import NAFNetForImageDenoise | |||
from modelscope.msdatasets.image_denoise_data.image_denoise_dataset import \ | |||
PairedImageDataset | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.test_utils import test_level | |||
logger = get_logger() | |||
class ImageDenoiseTrainerTest(unittest.TestCase): | |||
def setUp(self): | |||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||
self.tmp_dir = tempfile.TemporaryDirectory().name | |||
if not os.path.exists(self.tmp_dir): | |||
os.makedirs(self.tmp_dir) | |||
self.model_id = 'damo/cv_nafnet_image-denoise_sidd' | |||
self.cache_path = snapshot_download(self.model_id) | |||
self.config = Config.from_file( | |||
os.path.join(self.cache_path, ModelFile.CONFIGURATION)) | |||
self.dataset_train = PairedImageDataset( | |||
self.config.dataset, self.cache_path, is_train=True) | |||
self.dataset_val = PairedImageDataset( | |||
self.config.dataset, self.cache_path, is_train=False) | |||
def tearDown(self): | |||
shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
super().tearDown() | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_trainer(self): | |||
kwargs = dict( | |||
model=self.model_id, | |||
train_dataset=self.dataset_train, | |||
eval_dataset=self.dataset_val, | |||
work_dir=self.tmp_dir) | |||
trainer = build_trainer(default_args=kwargs) | |||
trainer.train() | |||
results_files = os.listdir(self.tmp_dir) | |||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
for i in range(2): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_trainer_with_model_and_args(self): | |||
model = NAFNetForImageDenoise.from_pretrained(self.cache_path) | |||
kwargs = dict( | |||
cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION), | |||
model=model, | |||
train_dataset=self.dataset_train, | |||
eval_dataset=self.dataset_val, | |||
max_epochs=2, | |||
work_dir=self.tmp_dir) | |||
trainer = build_trainer(default_args=kwargs) | |||
trainer.train() | |||
results_files = os.listdir(self.tmp_dir) | |||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
for i in range(2): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
if __name__ == '__main__': | |||
unittest.main() |