|
|
@@ -1,14 +1,16 @@ |
|
|
|
# The code is modified based on BasicSR metrics: |
|
|
|
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py |
|
|
|
# ------------------------------------------------------------------------ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
# ------------------------------------------------------------------------ |
|
|
|
# modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/metrics/psnr_ssim.py |
|
|
|
# ------------------------------------------------------------------------ |
|
|
|
from typing import Dict |
|
|
|
|
|
|
|
import cv2 |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
|
|
|
|
from modelscope.metainfo import Metrics |
|
|
|
from modelscope.utils.registry import default_group |
|
|
|
from modelscope.utils.tensor_utils import (torch_nested_detach, |
|
|
|
torch_nested_numpify) |
|
|
|
from .base import Metric |
|
|
|
from .builder import METRICS, MetricKeys |
|
|
|
|
|
|
@@ -22,16 +24,15 @@ class ImageDenoiseMetric(Metric): |
|
|
|
label_name = 'target' |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(ImageDenoiseMetric, self).__init__() |
|
|
|
self.preds = [] |
|
|
|
self.labels = [] |
|
|
|
|
|
|
|
def add(self, outputs: Dict, inputs: Dict): |
|
|
|
ground_truths = outputs[ImageDenoiseMetric.label_name] |
|
|
|
eval_results = outputs[ImageDenoiseMetric.pred_name] |
|
|
|
self.preds.append( |
|
|
|
torch_nested_numpify(torch_nested_detach(eval_results))) |
|
|
|
self.labels.append( |
|
|
|
torch_nested_numpify(torch_nested_detach(ground_truths))) |
|
|
|
self.preds.append(eval_results) |
|
|
|
self.labels.append(ground_truths) |
|
|
|
|
|
|
|
def evaluate(self): |
|
|
|
psnr_list, ssim_list = [], [] |
|
|
@@ -69,80 +70,117 @@ def reorder_image(img, input_order='HWC'): |
|
|
|
return img |
|
|
|
|
|
|
|
|
|
|
|
def calculate_psnr(img, img2, crop_border, input_order='HWC', **kwargs): |
|
|
|
def calculate_psnr(img1, img2, crop_border, input_order='HWC'): |
|
|
|
"""Calculate PSNR (Peak Signal-to-Noise Ratio). |
|
|
|
Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio |
|
|
|
Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio |
|
|
|
Args: |
|
|
|
img (ndarray): Images with range [0, 255]. |
|
|
|
img2 (ndarray): Images with range [0, 255]. |
|
|
|
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. |
|
|
|
input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. |
|
|
|
img1 (ndarray/tensor): Images with range [0, 255]/[0, 1]. |
|
|
|
img2 (ndarray/tensor): Images with range [0, 255]/[0, 1]. |
|
|
|
crop_border (int): Cropped pixels in each edge of an image. These |
|
|
|
pixels are not involved in the PSNR calculation. |
|
|
|
input_order (str): Whether the input order is 'HWC' or 'CHW'. |
|
|
|
Default: 'HWC'. |
|
|
|
test_y_channel (bool): Test on Y channel of YCbCr. Default: False. |
|
|
|
Returns: |
|
|
|
float: PSNR result. |
|
|
|
float: psnr result. |
|
|
|
""" |
|
|
|
|
|
|
|
assert img.shape == img2.shape, ( |
|
|
|
f'Image shapes are different: {img.shape}, {img2.shape}.') |
|
|
|
assert img1.shape == img2.shape, ( |
|
|
|
f'Image shapes are differnet: {img1.shape}, {img2.shape}.') |
|
|
|
if input_order not in ['HWC', 'CHW']: |
|
|
|
raise ValueError( |
|
|
|
f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' |
|
|
|
) |
|
|
|
img = reorder_image(img, input_order=input_order) |
|
|
|
f'Wrong input_order {input_order}. Supported input_orders are ' |
|
|
|
'"HWC" and "CHW"') |
|
|
|
if type(img1) == torch.Tensor: |
|
|
|
if len(img1.shape) == 4: |
|
|
|
img1 = img1.squeeze(0) |
|
|
|
img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) |
|
|
|
if type(img2) == torch.Tensor: |
|
|
|
if len(img2.shape) == 4: |
|
|
|
img2 = img2.squeeze(0) |
|
|
|
img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) |
|
|
|
|
|
|
|
img1 = reorder_image(img1, input_order=input_order) |
|
|
|
img2 = reorder_image(img2, input_order=input_order) |
|
|
|
img1 = img1.astype(np.float64) |
|
|
|
img2 = img2.astype(np.float64) |
|
|
|
|
|
|
|
if crop_border != 0: |
|
|
|
img = img[crop_border:-crop_border, crop_border:-crop_border, ...] |
|
|
|
img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] |
|
|
|
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] |
|
|
|
|
|
|
|
img = img.astype(np.float64) |
|
|
|
img2 = img2.astype(np.float64) |
|
|
|
def _psnr(img1, img2): |
|
|
|
|
|
|
|
mse = np.mean((img1 - img2)**2) |
|
|
|
if mse == 0: |
|
|
|
return float('inf') |
|
|
|
max_value = 1. if img1.max() <= 1 else 255. |
|
|
|
return 20. * np.log10(max_value / np.sqrt(mse)) |
|
|
|
|
|
|
|
mse = np.mean((img - img2)**2) |
|
|
|
if mse == 0: |
|
|
|
return float('inf') |
|
|
|
return 10. * np.log10(255. * 255. / mse) |
|
|
|
return _psnr(img1, img2) |
|
|
|
|
|
|
|
|
|
|
|
def calculate_ssim(img, img2, crop_border, input_order='HWC', **kwargs): |
|
|
|
def calculate_ssim(img1, img2, crop_border, input_order='HWC', ssim3d=True): |
|
|
|
"""Calculate SSIM (structural similarity). |
|
|
|
``Paper: Image quality assessment: From error visibility to structural similarity`` |
|
|
|
Ref: |
|
|
|
Image quality assessment: From error visibility to structural similarity |
|
|
|
The results are the same as that of the official released MATLAB code in |
|
|
|
https://ece.uwaterloo.ca/~z70wang/research/ssim/. |
|
|
|
For three-channel images, SSIM is calculated for each channel and then |
|
|
|
averaged. |
|
|
|
Args: |
|
|
|
img (ndarray): Images with range [0, 255]. |
|
|
|
img1 (ndarray): Images with range [0, 255]. |
|
|
|
img2 (ndarray): Images with range [0, 255]. |
|
|
|
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. |
|
|
|
crop_border (int): Cropped pixels in each edge of an image. These |
|
|
|
pixels are not involved in the SSIM calculation. |
|
|
|
input_order (str): Whether the input order is 'HWC' or 'CHW'. |
|
|
|
Default: 'HWC'. |
|
|
|
test_y_channel (bool): Test on Y channel of YCbCr. Default: False. |
|
|
|
Returns: |
|
|
|
float: SSIM result. |
|
|
|
float: ssim result. |
|
|
|
""" |
|
|
|
|
|
|
|
assert img.shape == img2.shape, ( |
|
|
|
f'Image shapes are different: {img.shape}, {img2.shape}.') |
|
|
|
assert img1.shape == img2.shape, ( |
|
|
|
f'Image shapes are differnet: {img1.shape}, {img2.shape}.') |
|
|
|
if input_order not in ['HWC', 'CHW']: |
|
|
|
raise ValueError( |
|
|
|
f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' |
|
|
|
) |
|
|
|
img = reorder_image(img, input_order=input_order) |
|
|
|
f'Wrong input_order {input_order}. Supported input_orders are ' |
|
|
|
'"HWC" and "CHW"') |
|
|
|
|
|
|
|
if type(img1) == torch.Tensor: |
|
|
|
if len(img1.shape) == 4: |
|
|
|
img1 = img1.squeeze(0) |
|
|
|
img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) |
|
|
|
if type(img2) == torch.Tensor: |
|
|
|
if len(img2.shape) == 4: |
|
|
|
img2 = img2.squeeze(0) |
|
|
|
img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) |
|
|
|
|
|
|
|
img1 = reorder_image(img1, input_order=input_order) |
|
|
|
img2 = reorder_image(img2, input_order=input_order) |
|
|
|
|
|
|
|
img1 = img1.astype(np.float64) |
|
|
|
img2 = img2.astype(np.float64) |
|
|
|
|
|
|
|
if crop_border != 0: |
|
|
|
img = img[crop_border:-crop_border, crop_border:-crop_border, ...] |
|
|
|
img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] |
|
|
|
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] |
|
|
|
|
|
|
|
img = img.astype(np.float64) |
|
|
|
img2 = img2.astype(np.float64) |
|
|
|
def _cal_ssim(img1, img2): |
|
|
|
ssims = [] |
|
|
|
|
|
|
|
max_value = 1 if img1.max() <= 1 else 255 |
|
|
|
with torch.no_grad(): |
|
|
|
final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim( |
|
|
|
img1, img2, max_value) |
|
|
|
ssims.append(final_ssim) |
|
|
|
|
|
|
|
ssims = [] |
|
|
|
for i in range(img.shape[2]): |
|
|
|
ssims.append(_ssim(img[..., i], img2[..., i])) |
|
|
|
return np.array(ssims).mean() |
|
|
|
return np.array(ssims).mean() |
|
|
|
|
|
|
|
return _cal_ssim(img1, img2) |
|
|
|
|
|
|
|
def _ssim(img, img2): |
|
|
|
|
|
|
|
def _ssim(img, img2, max_value): |
|
|
|
"""Calculate SSIM (structural similarity) for one channel images. |
|
|
|
It is called by func:`calculate_ssim`. |
|
|
|
Args: |
|
|
@@ -152,8 +190,11 @@ def _ssim(img, img2): |
|
|
|
float: SSIM result. |
|
|
|
""" |
|
|
|
|
|
|
|
c1 = (0.01 * 255)**2 |
|
|
|
c2 = (0.03 * 255)**2 |
|
|
|
c1 = (0.01 * max_value)**2 |
|
|
|
c2 = (0.03 * max_value)**2 |
|
|
|
|
|
|
|
img = img.astype(np.float64) |
|
|
|
img2 = img2.astype(np.float64) |
|
|
|
kernel = cv2.getGaussianKernel(11, 1.5) |
|
|
|
window = np.outer(kernel, kernel.transpose()) |
|
|
|
|
|
|
@@ -171,3 +212,61 @@ def _ssim(img, img2): |
|
|
|
tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) |
|
|
|
ssim_map = tmp1 / tmp2 |
|
|
|
return ssim_map.mean() |
|
|
|
|
|
|
|
|
|
|
|
def _3d_gaussian_calculator(img, conv3d): |
|
|
|
out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
def _generate_3d_gaussian_kernel(): |
|
|
|
kernel = cv2.getGaussianKernel(11, 1.5) |
|
|
|
window = np.outer(kernel, kernel.transpose()) |
|
|
|
kernel_3 = cv2.getGaussianKernel(11, 1.5) |
|
|
|
kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) |
|
|
|
conv3d = torch.nn.Conv3d( |
|
|
|
1, |
|
|
|
1, (11, 11, 11), |
|
|
|
stride=1, |
|
|
|
padding=(5, 5, 5), |
|
|
|
bias=False, |
|
|
|
padding_mode='replicate') |
|
|
|
conv3d.weight.requires_grad = False |
|
|
|
conv3d.weight[0, 0, :, :, :] = kernel |
|
|
|
return conv3d |
|
|
|
|
|
|
|
|
|
|
|
def _ssim_3d(img1, img2, max_value): |
|
|
|
assert len(img1.shape) == 3 and len(img2.shape) == 3 |
|
|
|
"""Calculate SSIM (structural similarity) for one channel images. |
|
|
|
It is called by func:`calculate_ssim`. |
|
|
|
Args: |
|
|
|
img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. |
|
|
|
img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. |
|
|
|
Returns: |
|
|
|
float: ssim result. |
|
|
|
""" |
|
|
|
C1 = (0.01 * max_value)**2 |
|
|
|
C2 = (0.03 * max_value)**2 |
|
|
|
img1 = img1.astype(np.float64) |
|
|
|
img2 = img2.astype(np.float64) |
|
|
|
|
|
|
|
kernel = _generate_3d_gaussian_kernel().cuda() |
|
|
|
|
|
|
|
img1 = torch.tensor(img1).float().cuda() |
|
|
|
img2 = torch.tensor(img2).float().cuda() |
|
|
|
|
|
|
|
mu1 = _3d_gaussian_calculator(img1, kernel) |
|
|
|
mu2 = _3d_gaussian_calculator(img2, kernel) |
|
|
|
|
|
|
|
mu1_sq = mu1**2 |
|
|
|
mu2_sq = mu2**2 |
|
|
|
mu1_mu2 = mu1 * mu2 |
|
|
|
sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq |
|
|
|
sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq |
|
|
|
sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2 |
|
|
|
|
|
|
|
tmp1 = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2) |
|
|
|
tmp2 = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) |
|
|
|
ssim_map = tmp1 / tmp2 |
|
|
|
return float(ssim_map.mean()) |