Browse Source

[to #42322933]fix psnr/ssim metrics for NAFNet (image denoise)

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10403246
master
huizheng.hz yingda.chen 3 years ago
parent
commit
c5c14ad60a
3 changed files with 149 additions and 55 deletions
  1. +146
    -47
      modelscope/metrics/image_denoise_metric.py
  2. +2
    -8
      modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py
  3. +1
    -0
      modelscope/msdatasets/task_datasets/__init__.py

+ 146
- 47
modelscope/metrics/image_denoise_metric.py View File

@@ -1,14 +1,16 @@
# The code is modified based on BasicSR metrics:
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py
# ------------------------------------------------------------------------
# Copyright (c) Alibaba, Inc. and its affiliates.
# ------------------------------------------------------------------------
# modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/metrics/psnr_ssim.py
# ------------------------------------------------------------------------
from typing import Dict

import cv2
import numpy as np
import torch

from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
from modelscope.utils.tensor_utils import (torch_nested_detach,
torch_nested_numpify)
from .base import Metric
from .builder import METRICS, MetricKeys

@@ -22,16 +24,15 @@ class ImageDenoiseMetric(Metric):
label_name = 'target'

def __init__(self):
super(ImageDenoiseMetric, self).__init__()
self.preds = []
self.labels = []

def add(self, outputs: Dict, inputs: Dict):
ground_truths = outputs[ImageDenoiseMetric.label_name]
eval_results = outputs[ImageDenoiseMetric.pred_name]
self.preds.append(
torch_nested_numpify(torch_nested_detach(eval_results)))
self.labels.append(
torch_nested_numpify(torch_nested_detach(ground_truths)))
self.preds.append(eval_results)
self.labels.append(ground_truths)

def evaluate(self):
psnr_list, ssim_list = [], []
@@ -69,80 +70,117 @@ def reorder_image(img, input_order='HWC'):
return img


def calculate_psnr(img, img2, crop_border, input_order='HWC', **kwargs):
def calculate_psnr(img1, img2, crop_border, input_order='HWC'):
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
crop_border (int): Cropped pixels in each edge of an image. These
pixels are not involved in the PSNR calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: PSNR result.
float: psnr result.
"""

assert img.shape == img2.shape, (
f'Image shapes are different: {img.shape}, {img2.shape}.')
assert img1.shape == img2.shape, (
f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"'
)
img = reorder_image(img, input_order=input_order)
f'Wrong input_order {input_order}. Supported input_orders are '
'"HWC" and "CHW"')
if type(img1) == torch.Tensor:
if len(img1.shape) == 4:
img1 = img1.squeeze(0)
img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
if type(img2) == torch.Tensor:
if len(img2.shape) == 4:
img2 = img2.squeeze(0)
img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)

img1 = reorder_image(img1, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)

if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

img = img.astype(np.float64)
img2 = img2.astype(np.float64)
def _psnr(img1, img2):

mse = np.mean((img1 - img2)**2)
if mse == 0:
return float('inf')
max_value = 1. if img1.max() <= 1 else 255.
return 20. * np.log10(max_value / np.sqrt(mse))

mse = np.mean((img - img2)**2)
if mse == 0:
return float('inf')
return 10. * np.log10(255. * 255. / mse)
return _psnr(img1, img2)


def calculate_ssim(img, img2, crop_border, input_order='HWC', **kwargs):
def calculate_ssim(img1, img2, crop_border, input_order='HWC', ssim3d=True):
"""Calculate SSIM (structural similarity).
``Paper: Image quality assessment: From error visibility to structural similarity``
Ref:
Image quality assessment: From error visibility to structural similarity
The results are the same as that of the official released MATLAB code in
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
For three-channel images, SSIM is calculated for each channel and then
averaged.
Args:
img (ndarray): Images with range [0, 255].
img1 (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
crop_border (int): Cropped pixels in each edge of an image. These
pixels are not involved in the SSIM calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: SSIM result.
float: ssim result.
"""

assert img.shape == img2.shape, (
f'Image shapes are different: {img.shape}, {img2.shape}.')
assert img1.shape == img2.shape, (
f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"'
)
img = reorder_image(img, input_order=input_order)
f'Wrong input_order {input_order}. Supported input_orders are '
'"HWC" and "CHW"')

if type(img1) == torch.Tensor:
if len(img1.shape) == 4:
img1 = img1.squeeze(0)
img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
if type(img2) == torch.Tensor:
if len(img2.shape) == 4:
img2 = img2.squeeze(0)
img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)

img1 = reorder_image(img1, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)

img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)

if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

img = img.astype(np.float64)
img2 = img2.astype(np.float64)
def _cal_ssim(img1, img2):
ssims = []

max_value = 1 if img1.max() <= 1 else 255
with torch.no_grad():
final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim(
img1, img2, max_value)
ssims.append(final_ssim)

ssims = []
for i in range(img.shape[2]):
ssims.append(_ssim(img[..., i], img2[..., i]))
return np.array(ssims).mean()
return np.array(ssims).mean()

return _cal_ssim(img1, img2)

def _ssim(img, img2):

def _ssim(img, img2, max_value):
"""Calculate SSIM (structural similarity) for one channel images.
It is called by func:`calculate_ssim`.
Args:
@@ -152,8 +190,11 @@ def _ssim(img, img2):
float: SSIM result.
"""

c1 = (0.01 * 255)**2
c2 = (0.03 * 255)**2
c1 = (0.01 * max_value)**2
c2 = (0.03 * max_value)**2

img = img.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())

@@ -171,3 +212,61 @@ def _ssim(img, img2):
tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
ssim_map = tmp1 / tmp2
return ssim_map.mean()


def _3d_gaussian_calculator(img, conv3d):
out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
return out


def _generate_3d_gaussian_kernel():
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
kernel_3 = cv2.getGaussianKernel(11, 1.5)
kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0))
conv3d = torch.nn.Conv3d(
1,
1, (11, 11, 11),
stride=1,
padding=(5, 5, 5),
bias=False,
padding_mode='replicate')
conv3d.weight.requires_grad = False
conv3d.weight[0, 0, :, :, :] = kernel
return conv3d


def _ssim_3d(img1, img2, max_value):
assert len(img1.shape) == 3 and len(img2.shape) == 3
"""Calculate SSIM (structural similarity) for one channel images.
It is called by func:`calculate_ssim`.
Args:
img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
Returns:
float: ssim result.
"""
C1 = (0.01 * max_value)**2
C2 = (0.03 * max_value)**2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)

kernel = _generate_3d_gaussian_kernel().cuda()

img1 = torch.tensor(img1).float().cuda()
img2 = torch.tensor(img2).float().cuda()

mu1 = _3d_gaussian_calculator(img1, kernel)
mu2 = _3d_gaussian_calculator(img2, kernel)

mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq
sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq
sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2

tmp1 = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
tmp2 = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
ssim_map = tmp1 / tmp2
return float(ssim_map.mean())

+ 2
- 8
modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py View File

@@ -3,7 +3,6 @@ import os
from copy import deepcopy
from typing import Any, Dict, Union

import numpy as np
import torch.cuda
from torch.nn.parallel import DataParallel, DistributedDataParallel

@@ -78,13 +77,8 @@ class NAFNetForImageDenoise(TorchModel):
def _evaluate_postprocess(self, input: Tensor,
target: Tensor) -> Dict[str, list]:
preds = self.model(input)
preds = list(torch.split(preds, 1, 0))
targets = list(torch.split(target, 1, 0))

preds = [(pred.data * 255.).squeeze(0).permute(
1, 2, 0).cpu().numpy().astype(np.uint8) for pred in preds]
targets = [(target.data * 255.).squeeze(0).permute(
1, 2, 0).cpu().numpy().astype(np.uint8) for target in targets]
preds = list(torch.split(preds.clamp(0, 1), 1, 0))
targets = list(torch.split(target.clamp(0, 1), 1, 0))

return {'pred': preds, 'target': targets}



+ 1
- 0
modelscope/msdatasets/task_datasets/__init__.py View File

@@ -26,6 +26,7 @@ else:
'video_summarization_dataset': ['VideoSummarizationDataset'],
'movie_scene_segmentation': ['MovieSceneSegmentationDataset'],
'image_inpainting': ['ImageInpaintingDataset'],
'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'],
}
import sys



Loading…
Cancel
Save