|
- # ------------------------------------------------------------------------
- # Copyright (c) Alibaba, Inc. and its affiliates.
- # ------------------------------------------------------------------------
- # modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/metrics/psnr_ssim.py
- # ------------------------------------------------------------------------
- from typing import Dict
-
- import cv2
- import numpy as np
- import torch
-
- from modelscope.metainfo import Metrics
- from modelscope.utils.registry import default_group
- from .base import Metric
- from .builder import METRICS, MetricKeys
-
-
@METRICS.register_module(
    group_key=default_group, module_name=Metrics.image_denoise_metric)
class ImageDenoiseMetric(Metric):
    """Collect denoising predictions and targets, then report mean PSNR/SSIM.
    """
    # Keys looked up in the ``outputs`` dict handed to :meth:`add`.
    pred_name = 'pred'
    label_name = 'target'

    def __init__(self):
        super().__init__()
        # Parallel lists: labels[i] is the reference for preds[i].
        self.preds, self.labels = [], []

    def add(self, outputs: Dict, inputs: Dict):
        """Store one prediction/target pair taken from ``outputs``.

        Args:
            outputs (Dict): Must contain the ``'target'`` and ``'pred'`` keys.
            inputs (Dict): Not read by this metric.
        """
        target = outputs[ImageDenoiseMetric.label_name]
        prediction = outputs[ImageDenoiseMetric.pred_name]
        self.preds.append(prediction)
        self.labels.append(target)

    def evaluate(self):
        """Score every stored pair and average the results.

        Only element ``[0]`` of each stored prediction/target is scored.

        Returns:
            dict: ``MetricKeys.PSNR`` and ``MetricKeys.SSIM`` mapped to the
                mean score over all stored pairs.
        """
        scores = [(calculate_psnr(label[0], pred[0], crop_border=0),
                   calculate_ssim(label[0], pred[0], crop_border=0))
                  for pred, label in zip(self.preds, self.labels)]
        return {
            MetricKeys.PSNR: np.mean([psnr for psnr, _ in scores]),
            MetricKeys.SSIM: np.mean([ssim for _, ssim in scores]),
        }
-
-
def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input shape is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: reordered image.

    Raises:
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'"
        )
    if len(img.shape) == 2:
        # A 2-D image has no channel axis to move, so return right after
        # appending one. Falling through to the CHW transpose would turn
        # (h, w, 1) into (w, 1, h).
        return img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img
-
-
def calculate_psnr(img1, img2, crop_border, input_order='HWC'):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.

    Returns:
        float: psnr result; ``inf`` when the two images are identical.

    Raises:
        AssertionError: If the two inputs differ in shape.
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')

    def _prepare(img):
        # isinstance (not ``type(...) ==``) so tensor subclasses such as
        # nn.Parameter also take the conversion path.
        if isinstance(img, torch.Tensor):
            if len(img.shape) == 4:
                img = img.squeeze(0)
            # Tensors are transposed (1, 2, 0) here, i.e. treated as CHW,
            # before ``input_order`` is applied by reorder_image below.
            img = img.detach().cpu().numpy().transpose(1, 2, 0)
        img = reorder_image(img, input_order=input_order).astype(np.float64)
        if crop_border != 0:
            img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        return img

    img1, img2 = _prepare(img1), _prepare(img2)

    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    # Peak value is inferred from the data: a maximum of at most 1 in img1
    # selects the [0, 1] range, anything larger selects [0, 255].
    max_value = 1. if img1.max() <= 1 else 255.
    return 20. * np.log10(max_value / np.sqrt(mse))
-
-
def calculate_ssim(img1, img2, crop_border, input_order='HWC', ssim3d=True):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    Args:
        img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        ssim3d (bool): If True, score with ``_ssim_3d``; otherwise score
            with ``_ssim``. Default: True.

    Returns:
        float: ssim result.

    Raises:
        AssertionError: If the two inputs differ in shape.
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')

    def _prepare(img):
        # isinstance (not ``type(...) ==``) so tensor subclasses such as
        # nn.Parameter also take the conversion path.
        if isinstance(img, torch.Tensor):
            if len(img.shape) == 4:
                img = img.squeeze(0)
            # Tensors are transposed (1, 2, 0) here, i.e. treated as CHW,
            # before ``input_order`` is applied by reorder_image below.
            img = img.detach().cpu().numpy().transpose(1, 2, 0)
        img = reorder_image(img, input_order=input_order).astype(np.float64)
        if crop_border != 0:
            img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        return img

    img1, img2 = _prepare(img1), _prepare(img2)

    # Constants in the SSIM formula scale with the peak value, which is
    # inferred from the data: max <= 1 selects [0, 1], otherwise [0, 255].
    max_value = 1 if img1.max() <= 1 else 255
    with torch.no_grad():
        if ssim3d:
            score = _ssim_3d(img1, img2, max_value)
        else:
            score = _ssim(img1, img2, max_value)

    # Keep the NumPy scalar return type regardless of which backend ran.
    return np.float64(score)
-
-
def _ssim(img, img2, max_value):
    """Mean SSIM of two images using an 11x11 Gaussian window (sigma 1.5).

    It is called by func:`calculate_ssim`.

    Args:
        img (ndarray): First image.
        img2 (ndarray): Second image.
        max_value (int | float): Peak pixel value; the two stabilising
            constants are derived from it.

    Returns:
        float: SSIM result.
    """

    stab_mean = (0.01 * max_value)**2
    stab_var = (0.03 * max_value)**2

    first = img.astype(np.float64)
    second = img2.astype(np.float64)

    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    gauss_2d = np.outer(gauss_1d, gauss_1d.transpose())

    def _smooth(arr):
        # Drop 5 px per side: the 'valid' region for a window of size 11.
        return cv2.filter2D(arr, -1, gauss_2d)[5:-5, 5:-5]

    mean_a = _smooth(first)
    mean_b = _smooth(second)
    mean_a_sq = mean_a**2
    mean_b_sq = mean_b**2
    mean_ab = mean_a * mean_b

    var_a = _smooth(first**2) - mean_a_sq
    var_b = _smooth(second**2) - mean_b_sq
    cov_ab = _smooth(first * second) - mean_ab

    numerator = (2 * mean_ab + stab_mean) * (2 * cov_ab + stab_var)
    denominator = (mean_a_sq + mean_b_sq + stab_mean) * (
        var_a + var_b + stab_var)
    return (numerator / denominator).mean()
-
-
- def _3d_gaussian_calculator(img, conv3d):
- out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
- return out
-
-
def _generate_3d_gaussian_kernel():
    """Build a bias-free ``Conv3d`` holding one 11x11x11 Gaussian filter.

    The filter is the product of three 1-D Gaussian kernels (size 11,
    sigma 1.5). The layer uses replicate padding of 5 on every axis, and its
    weight has ``requires_grad`` switched off.

    Returns:
        torch.nn.Conv3d: The configured convolution layer.
    """
    taps = cv2.getGaussianKernel(11, 1.5)  # shape (11, 1)
    plane = np.outer(taps, taps.transpose())
    # Broadcast the 1-D taps along a new leading axis: volume[d] = plane * taps[d].
    volume = torch.tensor(taps.reshape(-1, 1, 1) * plane)

    conv = torch.nn.Conv3d(
        in_channels=1,
        out_channels=1,
        kernel_size=(11, 11, 11),
        stride=1,
        padding=(5, 5, 5),
        bias=False,
        padding_mode='replicate')
    # Gradients must be off before the weight is overwritten in place.
    conv.weight.requires_grad = False
    conv.weight[0, 0] = volume
    return conv
-
-
def _ssim_3d(img1, img2, max_value):
    """Calculate SSIM (structural similarity) with a 3-D Gaussian window.

    It is called by func:`calculate_ssim`. Both inputs are filtered along
    all three of their axes at once. The computation runs on CUDA when it is
    available and on the CPU otherwise.

    Args:
        img1 (ndarray): 3-D image with range [0, 255]/[0, 1].
        img2 (ndarray): 3-D image with range [0, 255]/[0, 1].
        max_value (int | float): Peak pixel value; the two stabilising
            constants are derived from it.

    Returns:
        float: ssim result.

    Raises:
        AssertionError: If either input is not 3-dimensional.
    """
    assert len(img1.shape) == 3 and len(img2.shape) == 3

    C1 = (0.01 * max_value)**2
    C2 = (0.03 * max_value)**2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    # Fall back to the CPU so the metric also works on hosts without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    kernel = _generate_3d_gaussian_kernel().to(device)

    img1 = torch.tensor(img1).float().to(device)
    img2 = torch.tensor(img2).float().to(device)

    mu1 = _3d_gaussian_calculator(img1, kernel)
    mu2 = _3d_gaussian_calculator(img2, kernel)

    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq
    sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq
    sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2

    tmp1 = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    tmp2 = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = tmp1 / tmp2
    return float(ssim_map.mean())
|