Browse Source

[to #42322933]图像去噪using msdataset to load dataset

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10338265
master
huizheng.hz yingda.chen 3 years ago
parent
commit
922f4c589b
13 changed files with 284 additions and 260 deletions
  1. +134
    -6
      modelscope/metrics/image_denoise_metric.py
  2. +5
    -0
      modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py
  3. +5
    -0
      modelscope/models/cv/image_denoise/nafnet/arch_util.py
  4. +1
    -0
      modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py
  5. +0
    -152
      modelscope/msdatasets/image_denoise_data/data_utils.py
  6. +0
    -78
      modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py
  7. +2
    -2
      modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py
  8. +46
    -0
      modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py
  9. +62
    -0
      modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py
  10. +0
    -0
      modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py
  11. +1
    -1
      modelscope/pipelines/cv/image_denoise_pipeline.py
  12. +10
    -15
      tests/pipelines/test_image_denoise.py
  13. +18
    -6
      tests/trainers/test_image_denoise_trainer.py

+ 134
- 6
modelscope/metrics/image_denoise_metric.py View File

@@ -1,7 +1,9 @@
# The code is modified based on BasicSR metrics:
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py
from typing import Dict from typing import Dict


import cv2
import numpy as np import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


from modelscope.metainfo import Metrics from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group from modelscope.utils.registry import default_group
@@ -34,12 +36,138 @@ class ImageDenoiseMetric(Metric):
def evaluate(self): def evaluate(self):
psnr_list, ssim_list = [], [] psnr_list, ssim_list = [], []
for (pred, label) in zip(self.preds, self.labels): for (pred, label) in zip(self.preds, self.labels):
psnr_list.append(
peak_signal_noise_ratio(label[0], pred[0], data_range=255))
ssim_list.append(
structural_similarity(
label[0], pred[0], multichannel=True, data_range=255))
psnr_list.append(calculate_psnr(label[0], pred[0], crop_border=0))
ssim_list.append(calculate_ssim(label[0], pred[0], crop_border=0))
return { return {
MetricKeys.PSNR: np.mean(psnr_list), MetricKeys.PSNR: np.mean(psnr_list),
MetricKeys.SSIM: np.mean(ssim_list) MetricKeys.SSIM: np.mean(ssim_list)
} }


def reorder_image(img, input_order='HWC'):
"""Reorder images to 'HWC' order.
If the input_order is (h, w), return (h, w, 1);
If the input_order is (c, h, w), return (h, w, c);
If the input_order is (h, w, c), return as it is.
Args:
img (ndarray): Input image.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
If the input image shape is (h, w), input_order will not have
effects. Default: 'HWC'.
Returns:
ndarray: reordered image.
"""

if input_order not in ['HWC', 'CHW']:
raise ValueError(
f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'"
)
if len(img.shape) == 2:
img = img[..., None]
if input_order == 'CHW':
img = img.transpose(1, 2, 0)
return img


def calculate_psnr(img, img2, crop_border, input_order='HWC', **kwargs):
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
Returns:
float: PSNR result.
"""

assert img.shape == img2.shape, (
f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"'
)
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)

if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

img = img.astype(np.float64)
img2 = img2.astype(np.float64)

mse = np.mean((img - img2)**2)
if mse == 0:
return float('inf')
return 10. * np.log10(255. * 255. / mse)


def calculate_ssim(img, img2, crop_border, input_order='HWC', **kwargs):
"""Calculate SSIM (structural similarity).
``Paper: Image quality assessment: From error visibility to structural similarity``
The results are the same as that of the official released MATLAB code in
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
For three-channel images, SSIM is calculated for each channel and then
averaged.
Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
Default: 'HWC'.
Returns:
float: SSIM result.
"""

assert img.shape == img2.shape, (
f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"'
)
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)

if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

img = img.astype(np.float64)
img2 = img2.astype(np.float64)

ssims = []
for i in range(img.shape[2]):
ssims.append(_ssim(img[..., i], img2[..., i]))
return np.array(ssims).mean()


def _ssim(img, img2):
"""Calculate SSIM (structural similarity) for one channel images.
It is called by func:`calculate_ssim`.
Args:
img (ndarray): Images with range [0, 255] with order 'HWC'.
img2 (ndarray): Images with range [0, 255] with order 'HWC'.
Returns:
float: SSIM result.
"""

c1 = (0.01 * 255)**2
c2 = (0.03 * 255)**2
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())

mu1 = cv2.filter2D(img, -1, window)[5:-5,
5:-5] # valid mode for window size 11
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
ssim_map = tmp1 / tmp2
return ssim_map.mean()

+ 5
- 0
modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py View File

@@ -1,3 +1,8 @@
# ------------------------------------------------------------------------
# Modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/models/archs/NAFNet_arch.py
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------

import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn


+ 5
- 0
modelscope/models/cv/image_denoise/nafnet/arch_util.py View File

@@ -1,3 +1,8 @@
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------

import torch import torch
import torch.nn as nn import torch.nn as nn




+ 1
- 0
modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py View File

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os import os
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, Union from typing import Any, Dict, Union


+ 0
- 152
modelscope/msdatasets/image_denoise_data/data_utils.py View File

@@ -1,152 +0,0 @@
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import os
from os import path as osp

import cv2
import numpy as np
import torch

from .transforms import mod_crop


def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""

def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img

if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)


def scandir(dir_path, keyword=None, recursive=False, full_path=False):
"""Scan a directory to find the interested files.
Args:
dir_path (str): Path of the directory.
keyword (str | tuple(str), optional): File keyword that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
full_path (bool, optional): If set to True, include the dir_path.
Default: False.
Returns:
A generator for all the interested files with relative pathes.
"""

if (keyword is not None) and not isinstance(keyword, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')

root = dir_path

def _scandir(dir_path, keyword, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
if full_path:
return_path = entry.path
else:
return_path = osp.relpath(entry.path, root)

if keyword is None:
yield return_path
elif keyword in return_path:
yield return_path
else:
if recursive:
yield from _scandir(
entry.path, keyword=keyword, recursive=recursive)
else:
continue

return _scandir(dir_path, keyword=keyword, recursive=recursive)


def padding(img_lq, img_gt, gt_size):
h, w, _ = img_lq.shape

h_pad = max(0, gt_size - h)
w_pad = max(0, gt_size - w)

if h_pad == 0 and w_pad == 0:
return img_lq, img_gt

img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
return img_lq, img_gt


def read_img_seq(path, require_mod_crop=False, scale=1):
"""Read a sequence of images from a given folder path.
Args:
path (list[str] | str): List of image paths or image folder path.
require_mod_crop (bool): Require mod crop for each image.
Default: False.
scale (int): Scale factor for mod_crop. Default: 1.
Returns:
Tensor: size (t, c, h, w), RGB, [0, 1].
"""
if isinstance(path, list):
img_paths = path
else:
img_paths = sorted(list(scandir(path, full_path=True)))
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
if require_mod_crop:
imgs = [mod_crop(img, scale) for img in imgs]
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
imgs = torch.stack(imgs, dim=0)
return imgs


def paired_paths_from_folder(folders, keys, filename_tmpl):
"""Generate paired paths from folders.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, (
'The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, (
'The len of keys should be 2 with [input_key, gt_key]. '
f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys

input_paths = list(scandir(input_folder, keyword='NOISY', recursive=True))
gt_paths = list(scandir(gt_folder, keyword='GT', recursive=True))
assert len(input_paths) == len(gt_paths), (
f'{input_key} and {gt_key} datasets have different number of images: '
f'{len(input_paths)}, {len(gt_paths)}.')
paths = []
for idx in range(len(gt_paths)):
gt_path = os.path.join(gt_folder, gt_paths[idx])
input_path = os.path.join(input_folder, gt_path.replace('GT', 'NOISY'))

paths.append(
dict([(f'{input_key}_path', input_path),
(f'{gt_key}_path', gt_path)]))
return paths

+ 0
- 78
modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py View File

@@ -1,78 +0,0 @@
import os
from typing import Callable, List, Optional, Tuple, Union

import cv2
import numpy as np
from torch.utils import data

from .data_utils import img2tensor, padding, paired_paths_from_folder
from .transforms import augment, paired_random_crop


def default_loader(path):
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0


class PairedImageDataset(data.Dataset):
"""Paired image dataset for image restoration.
"""

def __init__(self, opt, root, is_train):
super(PairedImageDataset, self).__init__()
self.opt = opt
self.is_train = is_train
self.gt_folder, self.lq_folder = os.path.join(
root, opt.dataroot_gt), os.path.join(root, opt.dataroot_lq)

if opt.filename_tmpl is not None:
self.filename_tmpl = opt.filename_tmpl
else:
self.filename_tmpl = '{}'
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder],
['lq', 'gt'], self.filename_tmpl)

def __getitem__(self, index):
scale = self.opt.scale

# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
img_gt = default_loader(gt_path)
lq_path = self.paths[index]['lq_path']
img_lq = default_loader(lq_path)

# augmentation for training
# if self.is_train:
gt_size = self.opt.gt_size
# padding
img_gt, img_lq = padding(img_gt, img_lq, gt_size)

# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale)

# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip,
self.opt.use_rot)

# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq],
bgr2rgb=True,
float32=True)

return {
'input': img_lq,
'target': img_gt,
'input_path': lq_path,
'target_path': gt_path
}

def __len__(self):
return len(self.paths)

def to_torch_dataset(
self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable, List[Callable]] = None,
**format_kwargs,
):
return self

modelscope/msdatasets/image_denoise_data/__init__.py → modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py View File

@@ -4,11 +4,11 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule from modelscope.utils.import_utils import LazyImportModule


if TYPE_CHECKING: if TYPE_CHECKING:
from .image_denoise_dataset import PairedImageDataset
from .sidd_image_denoising_dataset import SiddImageDenoisingDataset


else: else:
_import_structure = { _import_structure = {
'image_denoise_dataset': ['PairedImageDataset'],
'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'],
} }


import sys import sys

+ 46
- 0
modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py View File

@@ -0,0 +1,46 @@
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------

import cv2
import torch


def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""

def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img

if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)


def padding(img_lq, img_gt, gt_size):
h, w, _ = img_lq.shape

h_pad = max(0, gt_size - h)
w_pad = max(0, gt_size - w)

if h_pad == 0 and w_pad == 0:
return img_lq, img_gt

img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
return img_lq, img_gt

+ 62
- 0
modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py View File

@@ -0,0 +1,62 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import cv2
import numpy as np

from modelscope.metainfo import Models
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
TorchTaskDataset
from modelscope.utils.constant import Tasks
from .data_utils import img2tensor, padding
from .transforms import augment, paired_random_crop


def default_loader(path):
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0


@TASK_DATASETS.register_module(
Tasks.image_denoising, module_name=Models.nafnet)
class SiddImageDenoisingDataset(TorchTaskDataset):
"""Paired image dataset for image restoration.
"""

def __init__(self, dataset, opt, is_train):
self.dataset = dataset
self.opt = opt
self.is_train = is_train

def __len__(self):
return len(self.dataset)

def __getitem__(self, index):

# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
item_dict = self.dataset[index]
gt_path = item_dict['Clean Image:FILE']
img_gt = default_loader(gt_path)
lq_path = item_dict['Noisy Image:FILE']
img_lq = default_loader(lq_path)

# augmentation for training
if self.is_train:
gt_size = self.opt.gt_size
# padding
img_gt, img_lq = padding(img_gt, img_lq, gt_size)

# random crop
img_gt, img_lq = paired_random_crop(
img_gt, img_lq, gt_size, scale=1)

# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip,
self.opt.use_rot)

# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq],
bgr2rgb=True,
float32=True)

return {'input': img_lq, 'target': img_gt}

modelscope/msdatasets/image_denoise_data/transforms.py → modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py View File


+ 1
- 1
modelscope/pipelines/cv/image_denoise_pipeline.py View File

@@ -105,4 +105,4 @@ class ImageDenoisePipeline(Pipeline):
def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]: def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute( output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute(
1, 2, 0).numpy().astype('uint8') 1, 2, 0).numpy().astype('uint8')
return {OutputKeys.OUTPUT_IMG: output_img}
return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]}

+ 10
- 15
tests/pipelines/test_image_denoise.py View File

@@ -2,8 +2,6 @@


import unittest import unittest


from PIL import Image

from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model from modelscope.models import Model
from modelscope.outputs import OutputKeys from modelscope.outputs import OutputKeys
@@ -20,16 +18,16 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck):
self.task = Tasks.image_denoising self.task = Tasks.image_denoising
self.model_id = 'damo/cv_nafnet_image-denoise_sidd' self.model_id = 'damo/cv_nafnet_image-denoise_sidd'


demo_image_path = 'data/test/images/noisy-demo-1.png'
demo_image_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/noisy-demo-0.png'


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self): def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id) cache_path = snapshot_download(self.model_id)
pipeline = ImageDenoisePipeline(cache_path) pipeline = ImageDenoisePipeline(cache_path)
pipeline.group_key = self.task
denoise_img = pipeline( denoise_img = pipeline(
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
denoise_img = Image.fromarray(denoise_img)
w, h = denoise_img.size
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR
h, w = denoise_img.shape[:2]
print('pipeline: the shape of output_img is {}x{}'.format(h, w)) print('pipeline: the shape of output_img is {}x{}'.format(h, w))


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -37,9 +35,8 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck):
model = Model.from_pretrained(self.model_id) model = Model.from_pretrained(self.model_id)
pipeline_ins = pipeline(task=Tasks.image_denoising, model=model) pipeline_ins = pipeline(task=Tasks.image_denoising, model=model)
denoise_img = pipeline_ins( denoise_img = pipeline_ins(
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
denoise_img = Image.fromarray(denoise_img)
w, h = denoise_img.size
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR
h, w = denoise_img.shape[:2]
print('pipeline: the shape of output_img is {}x{}'.format(h, w)) print('pipeline: the shape of output_img is {}x{}'.format(h, w))


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -47,18 +44,16 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline( pipeline_ins = pipeline(
task=Tasks.image_denoising, model=self.model_id) task=Tasks.image_denoising, model=self.model_id)
denoise_img = pipeline_ins( denoise_img = pipeline_ins(
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
denoise_img = Image.fromarray(denoise_img)
w, h = denoise_img.size
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR
h, w = denoise_img.shape[:2]
print('pipeline: the shape of output_img is {}x{}'.format(h, w)) print('pipeline: the shape of output_img is {}x{}'.format(h, w))


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self): def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.image_denoising) pipeline_ins = pipeline(task=Tasks.image_denoising)
denoise_img = pipeline_ins( denoise_img = pipeline_ins(
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
denoise_img = Image.fromarray(denoise_img)
w, h = denoise_img.size
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR
h, w = denoise_img.shape[:2]
print('pipeline: the shape of output_img is {}x{}'.format(h, w)) print('pipeline: the shape of output_img is {}x{}'.format(h, w))


@unittest.skip('demo compatibility test is only enabled on a needed-basis') @unittest.skip('demo compatibility test is only enabled on a needed-basis')


+ 18
- 6
tests/trainers/test_image_denoise_trainer.py View File

@@ -6,10 +6,12 @@ import unittest


from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.image_denoise import NAFNetForImageDenoise from modelscope.models.cv.image_denoise import NAFNetForImageDenoise
from modelscope.msdatasets.image_denoise_data import PairedImageDataset
from modelscope.msdatasets import MsDataset
from modelscope.msdatasets.task_datasets.sidd_image_denoising import \
SiddImageDenoisingDataset
from modelscope.trainers import build_trainer from modelscope.trainers import build_trainer
from modelscope.utils.config import Config from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level from modelscope.utils.test_utils import test_level


@@ -28,10 +30,20 @@ class ImageDenoiseTrainerTest(unittest.TestCase):
self.cache_path = snapshot_download(self.model_id) self.cache_path = snapshot_download(self.model_id)
self.config = Config.from_file( self.config = Config.from_file(
os.path.join(self.cache_path, ModelFile.CONFIGURATION)) os.path.join(self.cache_path, ModelFile.CONFIGURATION))
self.dataset_train = PairedImageDataset(
self.config.dataset, self.cache_path, is_train=True)
self.dataset_val = PairedImageDataset(
self.config.dataset, self.cache_path, is_train=False)
dataset_train = MsDataset.load(
'SIDD',
namespace='huizheng',
split='validation',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
dataset_val = MsDataset.load(
'SIDD',
namespace='huizheng',
split='test',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
self.dataset_train = SiddImageDenoisingDataset(
dataset_train, self.config.dataset, is_train=True)
self.dataset_val = SiddImageDenoisingDataset(
dataset_val, self.config.dataset, is_train=False)


def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True) shutil.rmtree(self.tmp_dir, ignore_errors=True)


Loading…
Cancel
Save