# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Optional, Union

import torch
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models.cv.image_denoise import NAFNetForImageDenoise
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import ImageDenoisePreprocessor, LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['ImageDenoisePipeline']


@PIPELINES.register_module(
    Tasks.image_denoising, module_name=Pipelines.image_denoise)
class ImageDenoisePipeline(Pipeline):

    def __init__(self,
                 model: Union[NAFNetForImageDenoise, str],
                 preprocessor: Optional[ImageDenoisePreprocessor] = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create a CV image-denoise pipeline for prediction.

        Args:
            model: a NAFNetForImageDenoise instance or a model id on the ModelScope hub.
            preprocessor: an optional ImageDenoisePreprocessor; when None, the
                default logic in `preprocess` below is used.
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.model.eval()
        self.config = self.model.config

        # Run inference on the GPU when one is available, otherwise on the CPU.
        self._device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        logger.info('load image denoise model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        # `input` may be an image file path, a URL, a PIL.Image or an ndarray;
        # LoadImage normalizes all of these to an RGB image.
        img = LoadImage.convert_to_img(input)
        test_transforms = transforms.Compose([transforms.ToTensor()])
        img = test_transforms(img)
        # Add a batch dimension and move the tensor to the target device.
        result = {'img': img.unsqueeze(0).to(self._device)}
        return result

    def crop_process(self, input):
        """Denoise `input` tile by tile to bound peak memory usage.

        The image is split into a grid of roughly 512x512 tiles. Each tile is
        denoised together with `overlap` extra pixels of context on every
        side, and that context is cropped away again before the tile is
        written back, which avoids visible seams at tile borders.
        """
        output = torch.zeros_like(input)  # [1, C, H, W]
        ih, iw = input.shape[-2:]
        crop_rows, crop_cols = max(ih // 512, 1), max(iw // 512, 1)
        overlap = 16

        step_h, step_w = ih // crop_rows, iw // crop_cols
        for y in range(crop_rows):
            for x in range(crop_cols):
                crop_y = step_h * y
                crop_x = step_w * x

                # The last row/column absorbs the remainder of the division.
                crop_h = step_h if y < crop_rows - 1 else ih - crop_y
                crop_w = step_w if x < crop_cols - 1 else iw - crop_x

                # Tile bounds padded with context, clipped to the image.
                top = max(0, crop_y - overlap)
                bottom = min(crop_y + crop_h + overlap, ih)
                left = max(0, crop_x - overlap)
                right = min(crop_x + crop_w + overlap, iw)
                crop_frames = input[:, :, top:bottom, left:right].contiguous()

                # Offsets of the valid region inside the padded tile.
                h_start = crop_y - top
                w_start = crop_x - left
                denoised = self.model._inference_forward(
                    crop_frames)['outputs']
                output[:, :, crop_y:crop_y + crop_h,
                       crop_x:crop_x + crop_w] = denoised[
                           :, :, h_start:h_start + crop_h,
                           w_start:w_start + crop_w]
        return output
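
    # A worked example of the tile geometry above: for a 1200x800 input,
    # crop_rows = 2, crop_cols = 1, step_h = 600 and step_w = 800, so the two
    # tiles cover rows [0, 600) and [600, 1200). The second tile is denoised
    # on rows [584, 1200) (16 rows of extra context at the top), and rows
    # [600, 1200) of that result are written back into the output tensor.
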

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        # Inference only: keep the model in eval mode and disable autograd.
        self.model.eval()
        with torch.no_grad():
            output = self.crop_process(input['img'])  # [1, C, H, W] tensor
        return {'output_tensor': output}

    def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
        # Clamp to [0, 1] before the uint8 cast so out-of-range model outputs
        # cannot wrap around, then convert the CHW float tensor to HWC uint8.
        output_img = (input['output_tensor'].squeeze(0).clamp(0, 1)
                      * 255).cpu().permute(1, 2, 0).numpy().astype('uint8')
        # Flip RGB to BGR to follow the OpenCV channel convention.
        return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]}
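

# A minimal usage sketch (not part of the pipeline class). The model id below
# is an assumption: any NAFNet image-denoising model id on the ModelScope hub
# should work. The pipeline returns a BGR uint8 image under
# OutputKeys.OUTPUT_IMG, so it can be written out directly with OpenCV.
if __name__ == '__main__':
    import cv2

    from modelscope.pipelines import pipeline

    denoise = pipeline(
        Tasks.image_denoising, model='damo/cv_nafnet_image-denoise_sidd')
    result = denoise('noisy_input.png')  # a path, URL, PIL.Image or ndarray
    cv2.imwrite('denoised_output.png', result[OutputKeys.OUTPUT_IMG])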