huizheng.hz · yingda.chen · 3 years ago
parent commit 8120891eb7
26 changed files with 1056 additions and 0 deletions
  1. data/test/images/noisy-demo-0.png (+3, -0)
  2. data/test/images/noisy-demo-1.png (+3, -0)
  3. modelscope/metainfo.py (+6, -0)
  4. modelscope/metrics/__init__.py (+1, -0)
  5. modelscope/metrics/builder.py (+1, -0)
  6. modelscope/metrics/image_denoise_metric.py (+45, -0)
  7. modelscope/models/__init__.py (+1, -0)
  8. modelscope/models/cv/__init__.py (+1, -0)
  9. modelscope/models/cv/image_denoise/__init__.py (+0, -0)
  10. modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py (+233, -0)
  11. modelscope/models/cv/image_denoise/nafnet/__init__.py (+0, -0)
  12. modelscope/models/cv/image_denoise/nafnet/arch_util.py (+42, -0)
  13. modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py (+119, -0)
  14. modelscope/msdatasets/image_denoise_data/__init__.py (+0, -0)
  15. modelscope/msdatasets/image_denoise_data/data_utils.py (+152, -0)
  16. modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py (+78, -0)
  17. modelscope/msdatasets/image_denoise_data/transforms.py (+96, -0)
  18. modelscope/outputs.py (+1, -0)
  19. modelscope/pipelines/builder.py (+2, -0)
  20. modelscope/pipelines/cv/__init__.py (+1, -0)
  21. modelscope/pipelines/cv/image_denoise_pipeline.py (+111, -0)
  22. modelscope/preprocessors/__init__.py (+1, -0)
  23. modelscope/preprocessors/image.py (+25, -0)
  24. modelscope/utils/constant.py (+1, -0)
  25. tests/pipelines/test_image_denoise.py (+59, -0)
  26. tests/trainers/test_image_denoise_trainer.py (+74, -0)

data/test/images/noisy-demo-0.png (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:403034182fa320130dae0d75b92e85e0850771378e674d65455c403a4958e29c
size 170716

data/test/images/noisy-demo-1.png (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ebd5dacad9b75ef80f87eb785d7818421dadb63257da0e91e123766c5913f855
size 149971

modelscope/metainfo.py (+6, -0)

@@ -10,6 +10,7 @@ class Models(object):
        Model name should only contain model info but not task info.
    """
    # vision models
    nafnet = 'nafnet'
    csrnet = 'csrnet'
    cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'

@@ -59,6 +60,7 @@ class Pipelines(object):
    """
    # vision tasks
    image_matting = 'unet-image-matting'
    image_denoise = 'nafnet-image-denoise'
    person_image_cartoon = 'unet-person-image-cartoon'
    ocr_detection = 'resnet18-ocr-detection'
    action_recognition = 'TAdaConv_action-recognition'

@@ -132,6 +134,7 @@ class Preprocessors(object):

    # cv preprocessor
    load_image = 'load-image'
    image_denoise_preprocessor = 'image-denoise-preprocessor'
    image_color_enhance_preprocessor = 'image-color-enhance-preprocessor'
    image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'

@@ -167,6 +170,9 @@ class Metrics(object):
    # accuracy
    accuracy = 'accuracy'

    # metrics for image denoise task
    image_denoise_metric = 'image-denoise-metric'

    # metric for image instance segmentation task
    image_ins_seg_coco_metric = 'image-ins-seg-coco-metric'
    # metrics for sequence classification task


modelscope/metrics/__init__.py (+1, -0)

@@ -1,6 +1,7 @@
from .base import Metric
from .builder import METRICS, build_metric, task_default_metrics
from .image_color_enhance_metric import ImageColorEnhanceMetric
from .image_denoise_metric import ImageDenoiseMetric
from .image_instance_segmentation_metric import \
    ImageInstanceSegmentationCOCOMetric
from .sequence_classification_metric import SequenceClassificationMetric


modelscope/metrics/builder.py (+1, -0)

@@ -22,6 +22,7 @@ task_default_metrics = {
    Tasks.sentence_similarity: [Metrics.seq_cls_metric],
    Tasks.sentiment_classification: [Metrics.seq_cls_metric],
    Tasks.text_generation: [Metrics.text_gen_metric],
    Tasks.image_denoise: [Metrics.image_denoise_metric],
    Tasks.image_color_enhance: [Metrics.image_color_enhance_metric]
}



modelscope/metrics/image_denoise_metric.py (+45, -0)

@@ -0,0 +1,45 @@
from typing import Dict

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
from modelscope.utils.tensor_utils import (torch_nested_detach,
                                           torch_nested_numpify)
from .base import Metric
from .builder import METRICS, MetricKeys


@METRICS.register_module(
    group_key=default_group, module_name=Metrics.image_denoise_metric)
class ImageDenoiseMetric(Metric):
    """The metric computation class for the image denoise task.
    """
    pred_name = 'pred'
    label_name = 'target'

    def __init__(self):
        self.preds = []
        self.labels = []

    def add(self, outputs: Dict, inputs: Dict):
        ground_truths = outputs[ImageDenoiseMetric.label_name]
        eval_results = outputs[ImageDenoiseMetric.pred_name]
        self.preds.append(
            torch_nested_numpify(torch_nested_detach(eval_results)))
        self.labels.append(
            torch_nested_numpify(torch_nested_detach(ground_truths)))

    def evaluate(self):
        psnr_list, ssim_list = [], []
        for (pred, label) in zip(self.preds, self.labels):
            psnr_list.append(
                peak_signal_noise_ratio(label[0], pred[0], data_range=255))
            ssim_list.append(
                structural_similarity(
                    label[0], pred[0], multichannel=True, data_range=255))
        return {
            MetricKeys.PSNR: np.mean(psnr_list),
            MetricKeys.SSIM: np.mean(ssim_list)
        }
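
For reference, a minimal sketch of the two scores this metric averages, on a synthetic HWC uint8 image pair (illustrative only, not part of the commit):

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

label = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
noise = np.random.randint(-5, 6, size=label.shape)
pred = np.clip(label.astype(np.int16) + noise, 0, 255).astype(np.uint8)
# same calls as evaluate() above, for one image pair instead of a batch
print(peak_signal_noise_ratio(label, pred, data_range=255))
print(structural_similarity(label, pred, multichannel=True, data_range=255))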

modelscope/models/__init__.py (+1, -0)

@@ -22,6 +22,7 @@ except ModuleNotFoundError as e:

try:
    from .multi_modal import OfaForImageCaptioning
    from .cv import NAFNetForImageDenoise
    from .nlp import (BertForMaskedLM, BertForSequenceClassification,
                      SbertForNLI, SbertForSentenceSimilarity,
                      SbertForSentimentClassification,


modelscope/models/cv/__init__.py (+1, -0)

@@ -1,2 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .image_color_enhance.image_color_enhance import ImageColorEnhance
from .image_denoise.nafnet_for_image_denoise import *  # noqa F403

modelscope/models/cv/image_denoise/__init__.py (+0, -0)


modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py (+233, -0)

@@ -0,0 +1,233 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .arch_util import LayerNorm2d


class SimpleGate(nn.Module):

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class NAFBlock(nn.Module):

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        # 1x1 pointwise expansion
        self.conv1 = nn.Conv2d(
            in_channels=c,
            out_channels=dw_channel,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True)
        # 3x3 depthwise convolution
        self.conv2 = nn.Conv2d(
            in_channels=dw_channel,
            out_channels=dw_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=dw_channel,
            bias=True)
        self.conv3 = nn.Conv2d(
            in_channels=dw_channel // 2,
            out_channels=c,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(
                in_channels=dw_channel // 2,
                out_channels=dw_channel // 2,
                kernel_size=1,
                padding=0,
                stride=1,
                groups=1,
                bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(
            in_channels=c,
            out_channels=ffn_channel,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True)
        self.conv5 = nn.Conv2d(
            in_channels=ffn_channel // 2,
            out_channels=c,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(
            drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(
            drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        # learnable per-channel residual scales, initialized to zero
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(
            torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma


class NAFNet(nn.Module):

    def __init__(self,
                 img_channel=3,
                 width=16,
                 middle_blk_num=1,
                 enc_blk_nums=[],
                 dec_blk_nums=[]):
        super().__init__()

        self.intro = nn.Conv2d(
            in_channels=img_channel,
            out_channels=width,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True)
        self.ending = nn.Conv2d(
            in_channels=width,
            out_channels=img_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
            self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)))
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))

        self.padder_size = 2**len(self.encoders)

    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size
                     - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size
                     - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x


class PSNRLoss(nn.Module):

    def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
        super(PSNRLoss, self).__init__()
        assert reduction == 'mean'
        self.loss_weight = loss_weight
        self.scale = 10 / np.log(10)
        self.toY = toY
        self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
        self.first = True

    def forward(self, pred, target):
        assert len(pred.size()) == 4
        if self.toY:
            if self.first:
                self.coef = self.coef.to(pred.device)
                self.first = False

            # RGB -> Y (luma) conversion in the [0, 255] range
            pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.

            pred, target = pred / 255., target / 255.

        return self.loss_weight * self.scale * torch.log((
            (pred - target)**2).mean(dim=(1, 2, 3)) + 1e-8).mean()
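
A minimal shape check for the architecture above (the width and block counts are illustrative placeholders, not the released SIDD configuration):

import torch

net = NAFNet(img_channel=3, width=16, middle_blk_num=1,
             enc_blk_nums=[1, 1, 1, 1], dec_blk_nums=[1, 1, 1, 1])
x = torch.rand(1, 3, 100, 100)  # 100 is not divisible by padder_size 16
with torch.no_grad():
    y = net(x)
# check_image_size pads up to a multiple of 16; forward crops back
assert y.shape == x.shape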

modelscope/models/cv/image_denoise/nafnet/__init__.py (+0, -0)


modelscope/models/cv/image_denoise/nafnet/arch_util.py (+42, -0)

@@ -0,0 +1,42 @@
import torch
import torch.nn as nn


class LayerNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_variables
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(
            dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None


class LayerNorm2d(nn.Module):

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
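
A quick numerical check (illustrative): LayerNorm2d normalizes over the channel dimension at every spatial location, so it should agree with F.layer_norm applied to a channels-last view of the same tensor:

import torch
import torch.nn.functional as F

ln = LayerNorm2d(8)
x = torch.randn(2, 8, 4, 4)
ref = F.layer_norm(x.permute(0, 2, 3, 1), (8, ), ln.weight, ln.bias, ln.eps)
assert torch.allclose(ln(x), ref.permute(0, 3, 1, 2), atol=1e-5)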

modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py (+119, -0)

@@ -0,0 +1,119 @@
import os
from copy import deepcopy
from typing import Any, Dict, Union

import numpy as np
import torch.cuda
from torch.nn.parallel import DataParallel, DistributedDataParallel

from modelscope.metainfo import Models
from modelscope.models.base import Tensor
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .nafnet.NAFNet_arch import NAFNet, PSNRLoss

logger = get_logger()
__all__ = ['NAFNetForImageDenoise']


@MODELS.register_module(Tasks.image_denoise, module_name=Models.nafnet)
class NAFNetForImageDenoise(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the image denoise model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        self.model = NAFNet(**self.config.model.network_g)
        self.loss = PSNRLoss()

        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')

        self.model = self.model.to(self._device)
        self.model = self._load_pretrained(self.model, model_path)

        if self.training:
            self.model.train()
        else:
            self.model.eval()

    def _load_pretrained(self,
                         net,
                         load_path,
                         strict=True,
                         param_key='params'):
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net = net.module
        load_net = torch.load(
            load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            if param_key not in load_net and 'params' in load_net:
                param_key = 'params'
                logger.info(
                    f'Loading: {param_key} does not exist, use params.')
            if param_key in load_net:
                load_net = load_net[param_key]
        logger.info(
            f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].'
        )
        # remove unnecessary 'module.' prefixes
        for k, v in deepcopy(load_net).items():
            if k.startswith('module.'):
                load_net[k[7:]] = v
                load_net.pop(k)
        net.load_state_dict(load_net, strict=strict)
        logger.info('load model done.')
        return net

    def _train_forward(self, input: Tensor,
                       target: Tensor) -> Dict[str, Tensor]:
        preds = self.model(input)
        return {'loss': self.loss(preds, target)}

    def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]:
        return {'outputs': self.model(input).clamp(0, 1)}

    def _evaluate_postprocess(self, input: Tensor,
                              target: Tensor) -> Dict[str, list]:
        preds = self.model(input)
        preds = list(torch.split(preds, 1, 0))
        targets = list(torch.split(target, 1, 0))

        # convert to HWC uint8 numpy arrays for the metric
        preds = [(pred.data * 255.).squeeze(0).permute(
            1, 2, 0).cpu().numpy().astype(np.uint8) for pred in preds]
        targets = [(target.data * 255.).squeeze(0).permute(
            1, 2, 0).cpu().numpy().astype(np.uint8) for target in targets]

        return {'pred': preds, 'target': targets}

    def forward(self, inputs: Dict[str,
                                   Tensor]) -> Dict[str, Union[list, Tensor]]:
        """Return the result of the model.

        Args:
            inputs (Dict[str, Tensor]): the preprocessed data

        Returns:
            Dict[str, Tensor]: results
        """
        for key, value in inputs.items():
            inputs[key] = inputs[key].to(self._device)
        if self.training:
            return self._train_forward(**inputs)
        elif 'target' in inputs:
            return self._evaluate_postprocess(**inputs)
        else:
            return self._inference_forward(**inputs)
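
A sketch of how the three forward paths are selected (the model_dir is a placeholder for a downloaded snapshot containing configuration.json and the torch checkpoint):

import torch

model = NAFNetForImageDenoise('/path/to/model_dir')  # placeholder path
model.eval()  # nn.Module defaults to training mode, which requires a target
noisy = torch.rand(1, 3, 64, 64)
out = model({'input': noisy})
# -> {'outputs': ...}, values clamped to [0, 1]
eval_out = model({'input': noisy, 'target': torch.rand(1, 3, 64, 64)})
# -> {'pred': [...], 'target': [...]} as HWC uint8 arrays for ImageDenoiseMetric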

modelscope/msdatasets/image_denoise_data/__init__.py (+0, -0)


modelscope/msdatasets/image_denoise_data/data_utils.py (+152, -0)

@@ -0,0 +1,152 @@
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import os
from os import path as osp

import cv2
import numpy as np
import torch

from .transforms import mod_crop


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert numpy arrays to tensors.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change BGR to RGB.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If the result has only one
            element, the tensor is returned directly.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)


def scandir(dir_path, keyword=None, recursive=False, full_path=False):
    """Scan a directory to find files of interest.

    Args:
        dir_path (str): Path of the directory.
        keyword (str | tuple(str), optional): File keyword that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator of the matched files, as paths relative to dir_path
            (or full paths if full_path is True).
    """

    if (keyword is not None) and not isinstance(keyword, (str, tuple)):
        raise TypeError('"keyword" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, keyword, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if keyword is None:
                    yield return_path
                elif keyword in return_path:
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(
                        entry.path, keyword=keyword, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, keyword=keyword, recursive=recursive)


def padding(img_lq, img_gt, gt_size):
    h, w, _ = img_lq.shape

    h_pad = max(0, gt_size - h)
    w_pad = max(0, gt_size - w)

    if h_pad == 0 and w_pad == 0:
        return img_lq, img_gt

    img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    return img_lq, img_gt


def read_img_seq(path, require_mod_crop=False, scale=1):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(list(scandir(path, full_path=True)))
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)
    return imgs


def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order
            should be consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl
            is for files in the input folder.

    Returns:
        list[dict]: Returned path list.
    """
    assert len(folders) == 2, (
        'The len of folders should be 2 with [input_folder, gt_folder]. '
        f'But got {len(folders)}')
    assert len(keys) == 2, (
        'The len of keys should be 2 with [input_key, gt_key]. '
        f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder, keyword='NOISY', recursive=True))
    gt_paths = list(scandir(gt_folder, keyword='GT', recursive=True))
    assert len(input_paths) == len(gt_paths), (
        f'{input_key} and {gt_key} datasets have different number of images: '
        f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for idx in range(len(gt_paths)):
        gt_path = os.path.join(gt_folder, gt_paths[idx])
        # note: when gt_folder is absolute, os.path.join discards
        # input_folder, so input_path is derived purely by the
        # 'GT' -> 'NOISY' substitution on the absolute gt path
        input_path = os.path.join(input_folder, gt_path.replace('GT', 'NOISY'))

        paths.append(
            dict([(f'{input_key}_path', input_path),
                  (f'{gt_key}_path', gt_path)]))
    return paths
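
For reference, these helpers assume a SIDD-style layout with 'NOISY' and 'GT' in both folder and file names; an illustrative sketch (the actual dataroots come from the model configuration):

root = '/abs/path/to/sidd'  # placeholder; must be absolute, see note above
paths = paired_paths_from_folder(
    [f'{root}/NOISY', f'{root}/GT'], ['lq', 'gt'], '{}')
# -> [{'lq_path': f'{root}/NOISY/0001_NOISY_SRGB_010.png',
#      'gt_path': f'{root}/GT/0001_GT_SRGB_010.png'}, ...]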

modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py (+78, -0)

@@ -0,0 +1,78 @@
import os
from typing import Callable, List, Optional, Tuple, Union

import cv2
import numpy as np
from torch.utils import data

from .data_utils import img2tensor, padding, paired_paths_from_folder
from .transforms import augment, paired_random_crop


def default_loader(path):
    return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0


class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.
    """

    def __init__(self, opt, root, is_train):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        self.is_train = is_train
        self.gt_folder, self.lq_folder = os.path.join(
            root, opt.dataroot_gt), os.path.join(root, opt.dataroot_lq)

        if opt.filename_tmpl is not None:
            self.filename_tmpl = opt.filename_tmpl
        else:
            self.filename_tmpl = '{}'
        self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder],
                                              ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        scale = self.opt.scale

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_gt = default_loader(gt_path)
        lq_path = self.paths[index]['lq_path']
        img_lq = default_loader(lq_path)

        # augmentation for training
        # if self.is_train:
        gt_size = self.opt.gt_size
        # padding
        img_gt, img_lq = padding(img_gt, img_lq, gt_size)

        # random crop
        img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale)

        # flip, rotation
        img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip,
                                 self.opt.use_rot)

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq],
                                    bgr2rgb=True,
                                    float32=True)

        return {
            'input': img_lq,
            'target': img_gt,
            'input_path': lq_path,
            'target_path': gt_path
        }

    def __len__(self):
        return len(self.paths)

    def to_torch_dataset(
        self,
        columns: Union[str, List[str]] = None,
        preprocessors: Union[Callable, List[Callable]] = None,
        **format_kwargs,
    ):
        return self
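
A construction sketch (the field names mirror what __getitem__ reads; in the trainer they come from the model's configuration.json, see tests/trainers/test_image_denoise_trainer.py; SimpleNamespace is just a stand-in for that config object):

from types import SimpleNamespace

opt = SimpleNamespace(dataroot_gt='GT', dataroot_lq='NOISY',
                      filename_tmpl='{}', scale=1, gt_size=128,
                      use_flip=True, use_rot=True)
ds = PairedImageDataset(opt, root='/abs/path/to/sidd', is_train=True)
sample = ds[0]
# sample['input'] / sample['target']: float CHW tensors in [0, 1]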

modelscope/msdatasets/image_denoise_data/transforms.py (+96, -0)

@@ -0,0 +1,96 @@
# Modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/data/transforms.py

import random


def mod_crop(img, scale):
    """Mod crop images, used during testing.

    Args:
        img (ndarray): Input image.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.
    """
    img = img.copy()
    if img.ndim in (2, 3):
        h, w = img.shape[0], img.shape[1]
        h_remainder, w_remainder = h % scale, w % scale
        img = img[:h - h_remainder, :w - w_remainder, ...]
    else:
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    return img


def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale):
    """Paired random crop.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images.
        img_lqs (list[ndarray] | ndarray): LQ images.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    h_lq, w_lq, _ = img_lqs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    lq_patch_size = gt_patch_size // scale

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [
        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
        for v in img_lqs
    ]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [
        v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
        for v in img_gts
    ]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs


def augment(imgs, hflip=True, rotation=True, vflip=False):
    """Augment: horizontal flips | rotate (0 or 90 degrees).

    All the images in the list use the same augmentation.
    """
    hflip = hflip and random.random() < 0.5
    if vflip or rotation:
        vflip = random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            img = img[:, ::-1, :].copy()
        if vflip:  # vertical
            img = img[::-1, :, :].copy()
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    return imgs
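
The crop contract in paired_random_crop, illustrated on synthetic arrays (scale=2 means the LQ window is half the GT patch size):

import numpy as np

lq = np.zeros((32, 32, 3), dtype=np.float32)
gt = np.zeros((64, 64, 3), dtype=np.float32)
gt_crop, lq_crop = paired_random_crop(gt, lq, gt_patch_size=16, scale=2)
assert lq_crop.shape[:2] == (8, 8)
assert gt_crop.shape[:2] == (16, 16)  # same window, scaled coordinates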

modelscope/outputs.py (+1, -0)

@@ -74,6 +74,7 @@ TASK_OUTPUTS = {
    Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
    Tasks.image_matting: [OutputKeys.OUTPUT_IMG],
    Tasks.image_generation: [OutputKeys.OUTPUT_IMG],
    Tasks.image_denoise: [OutputKeys.OUTPUT_IMG],
    Tasks.image_colorization: [OutputKeys.OUTPUT_IMG],
    Tasks.face_image_generation: [OutputKeys.OUTPUT_IMG],
    Tasks.image_super_resolution: [OutputKeys.OUTPUT_IMG],


modelscope/pipelines/builder.py (+2, -0)

@@ -35,6 +35,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    ),  # TODO: revise back after passing the pr
    Tasks.image_matting: (Pipelines.image_matting,
                          'damo/cv_unet_image-matting'),
    Tasks.image_denoise: (Pipelines.image_denoise,
                          'damo/cv_nafnet_image-denoise_sidd'),
    Tasks.text_classification: (Pipelines.sentiment_analysis,
                                'damo/bert-base-sst2'),
    Tasks.text_generation: (Pipelines.text_generation,


modelscope/pipelines/cv/__init__.py (+1, -0)

@@ -6,6 +6,7 @@ try:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .animal_recog_pipeline import AnimalRecogPipeline
    from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
    from .image_denoise_pipeline import ImageDenoisePipeline
    from .image_color_enhance_pipeline import ImageColorEnhancePipeline
    from .virtual_tryon_pipeline import VirtualTryonPipeline
    from .image_colorization_pipeline import ImageColorizationPipeline


modelscope/pipelines/cv/image_denoise_pipeline.py (+111, -0)

@@ -0,0 +1,111 @@
from typing import Any, Dict, Optional, Union

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.cv import NAFNetForImageDenoise
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import ImageDenoisePreprocessor, LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()

__all__ = ['ImageDenoisePipeline']


@PIPELINES.register_module(
    Tasks.image_denoise, module_name=Pipelines.image_denoise)
class ImageDenoisePipeline(Pipeline):

    def __init__(self,
                 model: Union[NAFNetForImageDenoise, str],
                 preprocessor: Optional[ImageDenoisePreprocessor] = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create a cv image denoise
        pipeline for prediction.

        Args:
            model: a model instance or a model id on the modelscope hub.
        """
        model = model if isinstance(
            model, NAFNetForImageDenoise) else Model.from_pretrained(model)
        model.eval()
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.config = model.config

        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        self.model = model
        logger.info('load image denoise model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_img(input)
        test_transforms = transforms.Compose([transforms.ToTensor()])
        img = test_transforms(img)
        result = {'img': img.unsqueeze(0).to(self._device)}
        return result

    def crop_process(self, input):
        """Denoise tile by tile to bound peak memory: split the input into
        roughly 512-pixel tiles, run each tile with a 16-pixel overlap, and
        trim the overlap when stitching the output."""
        output = torch.zeros_like(input)  # [1, C, H, W]
        # determine the tile grid
        ih, iw = input.shape[-2:]
        crop_rows, crop_cols = max(ih // 512, 1), max(iw // 512, 1)
        overlap = 16

        step_h, step_w = ih // crop_rows, iw // crop_cols
        for y in range(crop_rows):
            for x in range(crop_cols):
                crop_y = step_h * y
                crop_x = step_w * x

                crop_h = step_h if y < crop_rows - 1 else ih - crop_y
                crop_w = step_w if x < crop_cols - 1 else iw - crop_x

                crop_frames = input[:, :,
                                    max(0, crop_y - overlap
                                        ):min(crop_y + crop_h + overlap, ih),
                                    max(0, crop_x - overlap
                                        ):min(crop_x + crop_w
                                              + overlap, iw)].contiguous()
                h_start = overlap if max(0, crop_y - overlap) > 0 else 0
                w_start = overlap if max(0, crop_x - overlap) > 0 else 0
                h_end = h_start + crop_h if min(crop_y + crop_h
                                                + overlap, ih) < ih else ih
                w_end = w_start + crop_w if min(crop_x + crop_w
                                                + overlap, iw) < iw else iw

                output[:, :, crop_y:crop_y + crop_h,
                       crop_x:crop_x + crop_w] = self.model._inference_forward(
                           crop_frames)['outputs'][:, :, h_start:h_end,
                                                   w_start:w_end]
        return output

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

        def set_phase(model, is_train):
            if is_train:
                model.train()
            else:
                model.eval()

        is_train = False
        set_phase(self.model, is_train)
        with torch.no_grad():
            output = self.crop_process(input['img'])  # output tensor

        return {'output_tensor': output}

    def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
        output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute(
            1, 2, 0).numpy().astype('uint8')
        return {OutputKeys.OUTPUT_IMG: output_img}
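
Typical use, mirroring tests/pipelines/test_image_denoise.py:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipeline_ins = pipeline(task=Tasks.image_denoise,
                        model='damo/cv_nafnet_image-denoise_sidd')
result = pipeline_ins(input='data/test/images/noisy-demo-1.png')
denoised = result[OutputKeys.OUTPUT_IMG]  # HWC uint8 numpy array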

modelscope/preprocessors/__init__.py (+1, -0)

@@ -21,6 +21,7 @@ try:
    from .space.dialog_state_tracking_preprocessor import *  # noqa F403
    from .image import ImageColorEnhanceFinetunePreprocessor
    from .image import ImageInstanceSegmentationPreprocessor
    from .image import ImageDenoisePreprocessor
except ModuleNotFoundError as e:
    if str(e) == "No module named 'tensorflow'":
        print(TENSORFLOW_IMPORT_ERROR.format('tts'))


modelscope/preprocessors/image.py (+25, -0)

@@ -138,6 +138,31 @@ class ImageColorEnhanceFinetunePreprocessor(Preprocessor):
        return data


@PREPROCESSORS.register_module(
    Fields.cv, module_name=Preprocessors.image_denoise_preprocessor)
class ImageDenoisePreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """
        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)
        self.model_dir: str = model_dir

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the raw input data; currently a pass-through.

        Args:
            data (Dict[str, Any]): the raw input data

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        return data


@PREPROCESSORS.register_module(
    Fields.cv,
    module_name=Preprocessors.image_instance_segmentation_preprocessor)


modelscope/utils/constant.py (+1, -0)

@@ -24,6 +24,7 @@ class CVTasks(object):
    image_editing = 'image-editing'
    image_generation = 'image-generation'
    image_matting = 'image-matting'
    image_denoise = 'image-denoise'
    ocr_detection = 'ocr-detection'
    action_recognition = 'action-recognition'
    video_embedding = 'video-embedding'


tests/pipelines/test_image_denoise.py (+59, -0)

@@ -0,0 +1,59 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from PIL import Image

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import ImageDenoisePipeline, pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ImageDenoiseTest(unittest.TestCase):
    model_id = 'damo/cv_nafnet_image-denoise_sidd'
    demo_image_path = 'data/test/images/noisy-demo-1.png'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        cache_path = snapshot_download(self.model_id)
        pipeline = ImageDenoisePipeline(cache_path)
        denoise_img = pipeline(
            input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
        denoise_img = Image.fromarray(denoise_img)
        w, h = denoise_img.size
        print('pipeline: the shape of output_img is {}x{}'.format(h, w))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        pipeline_ins = pipeline(task=Tasks.image_denoise, model=model)
        denoise_img = pipeline_ins(
            input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
        denoise_img = Image.fromarray(denoise_img)
        w, h = denoise_img.size
        print('pipeline: the shape of output_img is {}x{}'.format(h, w))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(task=Tasks.image_denoise, model=self.model_id)
        denoise_img = pipeline_ins(
            input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
        denoise_img = Image.fromarray(denoise_img)
        w, h = denoise_img.size
        print('pipeline: the shape of output_img is {}x{}'.format(h, w))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.image_denoise)
        denoise_img = pipeline_ins(
            input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
        denoise_img = Image.fromarray(denoise_img)
        w, h = denoise_img.size
        print('pipeline: the shape of output_img is {}x{}'.format(h, w))


if __name__ == '__main__':
    unittest.main()

tests/trainers/test_image_denoise_trainer.py (+74, -0)

@@ -0,0 +1,74 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import NAFNetForImageDenoise
from modelscope.msdatasets.image_denoise_data.image_denoise_dataset import \
    PairedImageDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


class ImageDenoiseTrainerTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

        self.model_id = 'damo/cv_nafnet_image-denoise_sidd'
        self.cache_path = snapshot_download(self.model_id)
        self.config = Config.from_file(
            os.path.join(self.cache_path, ModelFile.CONFIGURATION))
        self.dataset_train = PairedImageDataset(
            self.config.dataset, self.cache_path, is_train=True)
        self.dataset_val = PairedImageDataset(
            self.config.dataset, self.cache_path, is_train=False)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        kwargs = dict(
            model=self.model_id,
            train_dataset=self.dataset_train,
            eval_dataset=self.dataset_val,
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(2):
            self.assertIn(f'epoch_{i+1}.pth', results_files)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):
        model = NAFNetForImageDenoise.from_pretrained(self.cache_path)
        kwargs = dict(
            cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
            model=model,
            train_dataset=self.dataset_train,
            eval_dataset=self.dataset_val,
            max_epochs=2,
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(2):
            self.assertIn(f'epoch_{i+1}.pth', results_files)


if __name__ == '__main__':
    unittest.main()
