Browse Source

[to #42322933]图像去噪using msdataset to load dataset

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10338265
master
huizheng.hz yingda.chen 3 years ago
parent
commit
922f4c589b
13 changed files with 284 additions and 260 deletions
  1. +134
    -6
      modelscope/metrics/image_denoise_metric.py
  2. +5
    -0
      modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py
  3. +5
    -0
      modelscope/models/cv/image_denoise/nafnet/arch_util.py
  4. +1
    -0
      modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py
  5. +0
    -152
      modelscope/msdatasets/image_denoise_data/data_utils.py
  6. +0
    -78
      modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py
  7. +2
    -2
      modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py
  8. +46
    -0
      modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py
  9. +62
    -0
      modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py
  10. +0
    -0
      modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py
  11. +1
    -1
      modelscope/pipelines/cv/image_denoise_pipeline.py
  12. +10
    -15
      tests/pipelines/test_image_denoise.py
  13. +18
    -6
      tests/trainers/test_image_denoise_trainer.py

+ 134
- 6
modelscope/metrics/image_denoise_metric.py View File

@@ -1,7 +1,9 @@
# The code is modified based on BasicSR metrics:
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py
from typing import Dict

import cv2
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
@@ -34,12 +36,138 @@ class ImageDenoiseMetric(Metric):
def evaluate(self):
psnr_list, ssim_list = [], []
for (pred, label) in zip(self.preds, self.labels):
psnr_list.append(
peak_signal_noise_ratio(label[0], pred[0], data_range=255))
ssim_list.append(
structural_similarity(
label[0], pred[0], multichannel=True, data_range=255))
psnr_list.append(calculate_psnr(label[0], pred[0], crop_border=0))
ssim_list.append(calculate_ssim(label[0], pred[0], crop_border=0))
return {
MetricKeys.PSNR: np.mean(psnr_list),
MetricKeys.SSIM: np.mean(ssim_list)
}


def reorder_image(img, input_order='HWC'):
"""Reorder images to 'HWC' order.
If the input_order is (h, w), return (h, w, 1);
If the input_order is (c, h, w), return (h, w, c);
If the input_order is (h, w, c), return as it is.
Args:
img (ndarray): Input image.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
If the input image shape is (h, w), input_order will not have
effects. Default: 'HWC'.
Returns:
ndarray: reordered image.
"""

if input_order not in ['HWC', 'CHW']:
raise ValueError(
f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'"
)
if len(img.shape) == 2:
img = img[..., None]
if input_order == 'CHW':
img = img.transpose(1, 2, 0)
return img


def calculate_psnr(img, img2, crop_border, input_order='HWC', **kwargs):
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
Returns:
float: PSNR result.
"""

assert img.shape == img2.shape, (
f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"'
)
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)

if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

img = img.astype(np.float64)
img2 = img2.astype(np.float64)

mse = np.mean((img - img2)**2)
if mse == 0:
return float('inf')
return 10. * np.log10(255. * 255. / mse)


def calculate_ssim(img, img2, crop_border, input_order='HWC', **kwargs):
"""Calculate SSIM (structural similarity).
``Paper: Image quality assessment: From error visibility to structural similarity``
The results are the same as that of the official released MATLAB code in
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
For three-channel images, SSIM is calculated for each channel and then
averaged.
Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
Default: 'HWC'.
Returns:
float: SSIM result.
"""

assert img.shape == img2.shape, (
f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"'
)
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)

if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

img = img.astype(np.float64)
img2 = img2.astype(np.float64)

ssims = []
for i in range(img.shape[2]):
ssims.append(_ssim(img[..., i], img2[..., i]))
return np.array(ssims).mean()


def _ssim(img, img2):
"""Calculate SSIM (structural similarity) for one channel images.
It is called by func:`calculate_ssim`.
Args:
img (ndarray): Images with range [0, 255] with order 'HWC'.
img2 (ndarray): Images with range [0, 255] with order 'HWC'.
Returns:
float: SSIM result.
"""

c1 = (0.01 * 255)**2
c2 = (0.03 * 255)**2
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())

mu1 = cv2.filter2D(img, -1, window)[5:-5,
5:-5] # valid mode for window size 11
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
ssim_map = tmp1 / tmp2
return ssim_map.mean()

+ 5
- 0
modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py View File

@@ -1,3 +1,8 @@
# ------------------------------------------------------------------------
# Modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/models/archs/NAFNet_arch.py
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------

import numpy as np
import torch
import torch.nn as nn


+ 5
- 0
modelscope/models/cv/image_denoise/nafnet/arch_util.py View File

@@ -1,3 +1,8 @@
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------

import torch
import torch.nn as nn



+ 1
- 0
modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py View File

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import deepcopy
from typing import Any, Dict, Union


+ 0
- 152
modelscope/msdatasets/image_denoise_data/data_utils.py View File

@@ -1,152 +0,0 @@
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import os
from os import path as osp

import cv2
import numpy as np
import torch

from .transforms import mod_crop


def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""

def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img

if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)


def scandir(dir_path, keyword=None, recursive=False, full_path=False):
"""Scan a directory to find the interested files.
Args:
dir_path (str): Path of the directory.
keyword (str | tuple(str), optional): File keyword that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
full_path (bool, optional): If set to True, include the dir_path.
Default: False.
Returns:
A generator for all the interested files with relative pathes.
"""

if (keyword is not None) and not isinstance(keyword, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')

root = dir_path

def _scandir(dir_path, keyword, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
if full_path:
return_path = entry.path
else:
return_path = osp.relpath(entry.path, root)

if keyword is None:
yield return_path
elif keyword in return_path:
yield return_path
else:
if recursive:
yield from _scandir(
entry.path, keyword=keyword, recursive=recursive)
else:
continue

return _scandir(dir_path, keyword=keyword, recursive=recursive)


def padding(img_lq, img_gt, gt_size):
h, w, _ = img_lq.shape

h_pad = max(0, gt_size - h)
w_pad = max(0, gt_size - w)

if h_pad == 0 and w_pad == 0:
return img_lq, img_gt

img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
return img_lq, img_gt


def read_img_seq(path, require_mod_crop=False, scale=1):
"""Read a sequence of images from a given folder path.
Args:
path (list[str] | str): List of image paths or image folder path.
require_mod_crop (bool): Require mod crop for each image.
Default: False.
scale (int): Scale factor for mod_crop. Default: 1.
Returns:
Tensor: size (t, c, h, w), RGB, [0, 1].
"""
if isinstance(path, list):
img_paths = path
else:
img_paths = sorted(list(scandir(path, full_path=True)))
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
if require_mod_crop:
imgs = [mod_crop(img, scale) for img in imgs]
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
imgs = torch.stack(imgs, dim=0)
return imgs


def paired_paths_from_folder(folders, keys, filename_tmpl):
"""Generate paired paths from folders.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, (
'The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, (
'The len of keys should be 2 with [input_key, gt_key]. '
f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys

input_paths = list(scandir(input_folder, keyword='NOISY', recursive=True))
gt_paths = list(scandir(gt_folder, keyword='GT', recursive=True))
assert len(input_paths) == len(gt_paths), (
f'{input_key} and {gt_key} datasets have different number of images: '
f'{len(input_paths)}, {len(gt_paths)}.')
paths = []
for idx in range(len(gt_paths)):
gt_path = os.path.join(gt_folder, gt_paths[idx])
input_path = os.path.join(input_folder, gt_path.replace('GT', 'NOISY'))

paths.append(
dict([(f'{input_key}_path', input_path),
(f'{gt_key}_path', gt_path)]))
return paths

+ 0
- 78
modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py View File

@@ -1,78 +0,0 @@
import os
from typing import Callable, List, Optional, Tuple, Union

import cv2
import numpy as np
from torch.utils import data

from .data_utils import img2tensor, padding, paired_paths_from_folder
from .transforms import augment, paired_random_crop


def default_loader(path):
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0


class PairedImageDataset(data.Dataset):
"""Paired image dataset for image restoration.
"""

def __init__(self, opt, root, is_train):
super(PairedImageDataset, self).__init__()
self.opt = opt
self.is_train = is_train
self.gt_folder, self.lq_folder = os.path.join(
root, opt.dataroot_gt), os.path.join(root, opt.dataroot_lq)

if opt.filename_tmpl is not None:
self.filename_tmpl = opt.filename_tmpl
else:
self.filename_tmpl = '{}'
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder],
['lq', 'gt'], self.filename_tmpl)

def __getitem__(self, index):
scale = self.opt.scale

# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
img_gt = default_loader(gt_path)
lq_path = self.paths[index]['lq_path']
img_lq = default_loader(lq_path)

# augmentation for training
# if self.is_train:
gt_size = self.opt.gt_size
# padding
img_gt, img_lq = padding(img_gt, img_lq, gt_size)

# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale)

# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip,
self.opt.use_rot)

# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq],
bgr2rgb=True,
float32=True)

return {
'input': img_lq,
'target': img_gt,
'input_path': lq_path,
'target_path': gt_path
}

def __len__(self):
return len(self.paths)

def to_torch_dataset(
self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable, List[Callable]] = None,
**format_kwargs,
):
return self

modelscope/msdatasets/image_denoise_data/__init__.py → modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py View File

@@ -4,11 +4,11 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .image_denoise_dataset import PairedImageDataset
from .sidd_image_denoising_dataset import SiddImageDenoisingDataset

else:
_import_structure = {
'image_denoise_dataset': ['PairedImageDataset'],
'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'],
}

import sys

+ 46
- 0
modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py View File

@@ -0,0 +1,46 @@
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------

import cv2
import torch


def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""

def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img

if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)


def padding(img_lq, img_gt, gt_size):
h, w, _ = img_lq.shape

h_pad = max(0, gt_size - h)
w_pad = max(0, gt_size - w)

if h_pad == 0 and w_pad == 0:
return img_lq, img_gt

img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
return img_lq, img_gt

+ 62
- 0
modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py View File

@@ -0,0 +1,62 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import cv2
import numpy as np

from modelscope.metainfo import Models
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
TorchTaskDataset
from modelscope.utils.constant import Tasks
from .data_utils import img2tensor, padding
from .transforms import augment, paired_random_crop


def default_loader(path):
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0


@TASK_DATASETS.register_module(
Tasks.image_denoising, module_name=Models.nafnet)
class SiddImageDenoisingDataset(TorchTaskDataset):
"""Paired image dataset for image restoration.
"""

def __init__(self, dataset, opt, is_train):
self.dataset = dataset
self.opt = opt
self.is_train = is_train

def __len__(self):
return len(self.dataset)

def __getitem__(self, index):

# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
item_dict = self.dataset[index]
gt_path = item_dict['Clean Image:FILE']
img_gt = default_loader(gt_path)
lq_path = item_dict['Noisy Image:FILE']
img_lq = default_loader(lq_path)

# augmentation for training
if self.is_train:
gt_size = self.opt.gt_size
# padding
img_gt, img_lq = padding(img_gt, img_lq, gt_size)

# random crop
img_gt, img_lq = paired_random_crop(
img_gt, img_lq, gt_size, scale=1)

# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip,
self.opt.use_rot)

# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq],
bgr2rgb=True,
float32=True)

return {'input': img_lq, 'target': img_gt}

modelscope/msdatasets/image_denoise_data/transforms.py → modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py View File


+ 1
- 1
modelscope/pipelines/cv/image_denoise_pipeline.py View File

@@ -105,4 +105,4 @@ class ImageDenoisePipeline(Pipeline):
def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute(
1, 2, 0).numpy().astype('uint8')
return {OutputKeys.OUTPUT_IMG: output_img}
return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]}

+ 10
- 15
tests/pipelines/test_image_denoise.py View File

@@ -2,8 +2,6 @@

import unittest

from PIL import Image

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.outputs import OutputKeys
@@ -20,16 +18,16 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck):
self.task = Tasks.image_denoising
self.model_id = 'damo/cv_nafnet_image-denoise_sidd'

demo_image_path = 'data/test/images/noisy-demo-1.png'
demo_image_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/noisy-demo-0.png'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
pipeline = ImageDenoisePipeline(cache_path)
pipeline.group_key = self.task
denoise_img = pipeline(
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
denoise_img = Image.fromarray(denoise_img)
w, h = denoise_img.size
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR
h, w = denoise_img.shape[:2]
print('pipeline: the shape of output_img is {}x{}'.format(h, w))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -37,9 +35,8 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck):
model = Model.from_pretrained(self.model_id)
pipeline_ins = pipeline(task=Tasks.image_denoising, model=model)
denoise_img = pipeline_ins(
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
denoise_img = Image.fromarray(denoise_img)
w, h = denoise_img.size
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR
h, w = denoise_img.shape[:2]
print('pipeline: the shape of output_img is {}x{}'.format(h, w))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -47,18 +44,16 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline(
task=Tasks.image_denoising, model=self.model_id)
denoise_img = pipeline_ins(
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
denoise_img = Image.fromarray(denoise_img)
w, h = denoise_img.size
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR
h, w = denoise_img.shape[:2]
print('pipeline: the shape of output_img is {}x{}'.format(h, w))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.image_denoising)
denoise_img = pipeline_ins(
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
denoise_img = Image.fromarray(denoise_img)
w, h = denoise_img.size
input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR
h, w = denoise_img.shape[:2]
print('pipeline: the shape of output_img is {}x{}'.format(h, w))

@unittest.skip('demo compatibility test is only enabled on a needed-basis')


+ 18
- 6
tests/trainers/test_image_denoise_trainer.py View File

@@ -6,10 +6,12 @@ import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.image_denoise import NAFNetForImageDenoise
from modelscope.msdatasets.image_denoise_data import PairedImageDataset
from modelscope.msdatasets import MsDataset
from modelscope.msdatasets.task_datasets.sidd_image_denoising import \
SiddImageDenoisingDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

@@ -28,10 +30,20 @@ class ImageDenoiseTrainerTest(unittest.TestCase):
self.cache_path = snapshot_download(self.model_id)
self.config = Config.from_file(
os.path.join(self.cache_path, ModelFile.CONFIGURATION))
self.dataset_train = PairedImageDataset(
self.config.dataset, self.cache_path, is_train=True)
self.dataset_val = PairedImageDataset(
self.config.dataset, self.cache_path, is_train=False)
dataset_train = MsDataset.load(
'SIDD',
namespace='huizheng',
split='validation',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
dataset_val = MsDataset.load(
'SIDD',
namespace='huizheng',
split='test',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
self.dataset_train = SiddImageDenoisingDataset(
dataset_train, self.config.dataset, is_train=True)
self.dataset_val = SiddImageDenoisingDataset(
dataset_val, self.config.dataset, is_train=False)

def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)


Loading…
Cancel
Save