Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9501913 * [to #43259593] refactor image preprocess master
@@ -1,6 +1,6 @@ | |||
import os.path as osp | |||
from typing import Any, Dict | |||
import cv2 | |||
import json | |||
import numpy as np | |||
import torch | |||
@@ -11,7 +11,6 @@ from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.torch_utils import create_device | |||
def audio_norm(x): | |||
@@ -8,7 +8,6 @@ from typing import Any, Dict, Generator, List, Mapping, Union | |||
import numpy as np | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models.base import Model | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.outputs import TASK_OUTPUTS | |||
@@ -13,7 +13,7 @@ from modelscope.models.cv.animal_recognition import resnet | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import load_image | |||
from modelscope.preprocessors import LoadImage, load_image | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
@@ -79,19 +79,7 @@ class AnimalRecogPipeline(Pipeline): | |||
logger.info('load model done') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
if isinstance(input, str): | |||
img = load_image(input) | |||
elif isinstance(input, PIL.Image.Image): | |||
img = input.convert('RGB') | |||
elif isinstance(input, np.ndarray): | |||
if len(input.shape) == 2: | |||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
img = input[:, :, ::-1] | |||
img = Image.fromarray(img.astype('uint8')).convert('RGB') | |||
else: | |||
raise TypeError(f'input should be either str, PIL.Image,' | |||
f' np.array, but got {type(input)}') | |||
img = LoadImage.convert_to_img(input) | |||
normalize = transforms.Normalize( | |||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |||
test_transforms = transforms.Compose([ | |||
@@ -3,7 +3,6 @@ from typing import Any, Dict | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import tensorflow as tf | |||
from modelscope.metainfo import Pipelines | |||
@@ -14,7 +13,7 @@ from modelscope.models.cv.cartoon.utils import get_f5p, padTo16x, resize_size | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import load_image | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
@@ -65,17 +64,7 @@ class ImageCartoonPipeline(Pipeline): | |||
return sess | |||
def preprocess(self, input: Input) -> Dict[str, Any]:
    """Turn the pipeline input into a float image array.

    Type dispatch (str / PIL.Image / np.ndarray) is delegated to
    ``LoadImage.convert_to_ndarray``, which raises ``TypeError`` for any
    other input type.

    Args:
        input: The image in any form accepted by
            ``LoadImage.convert_to_ndarray``.

    Returns:
        A dict with the single key ``'img'`` holding the converted
        array cast to ``np.float64``.
    """
    img = LoadImage.convert_to_ndarray(input)
    # ``np.float`` was only an alias of the builtin ``float`` and was
    # removed in NumPy 1.24; ``np.float64`` yields the same dtype.
    img = img.astype(np.float64)
    return {'img': img}
@@ -13,7 +13,7 @@ from modelscope.models.cv.image_color_enhance.image_color_enhance import \ | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input | |||
from modelscope.preprocessors import (ImageColorEnhanceFinetunePreprocessor, | |||
load_image) | |||
LoadImage, load_image) | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from ..base import Pipeline | |||
@@ -47,18 +47,7 @@ class ImageColorEnhancePipeline(Pipeline): | |||
self._device = torch.device('cpu') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
if isinstance(input, str): | |||
img = load_image(input) | |||
elif isinstance(input, PIL.Image.Image): | |||
img = input.convert('RGB') | |||
elif isinstance(input, np.ndarray): | |||
if len(input.shape) == 2: | |||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) | |||
img = Image.fromarray(img.astype('uint8')).convert('RGB') | |||
else: | |||
raise TypeError(f'input should be either str, PIL.Image,' | |||
f' np.array, but got {type(input)}') | |||
img = LoadImage.convert_to_img(input) | |||
test_transforms = transforms.Compose([transforms.ToTensor()]) | |||
img = test_transforms(img) | |||
result = {'src': img.unsqueeze(0).to(self._device)} | |||
@@ -88,7 +88,7 @@ class ImageColorizationPipeline(Pipeline): | |||
img = input.convert('LA').convert('RGB') | |||
elif isinstance(input, np.ndarray): | |||
if len(input.shape) == 2: | |||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) | |||
img = input[:, :, ::-1] # in rgb order | |||
img = PIL.Image.fromarray(img).convert('LA').convert('RGB') | |||
else: | |||
@@ -3,13 +3,12 @@ from typing import Any, Dict | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import load_image | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
@@ -47,17 +46,7 @@ class ImageMattingPipeline(Pipeline): | |||
logger.info('load model done') | |||
def preprocess(self, input: Input) -> Dict[str, Any]:
    """Turn the pipeline input into a float image array.

    Type dispatch (str / PIL.Image / np.ndarray) is delegated to
    ``LoadImage.convert_to_ndarray``, which raises ``TypeError`` for any
    other input type.

    Args:
        input: The image in any form accepted by
            ``LoadImage.convert_to_ndarray``.

    Returns:
        A dict with the single key ``'img'`` holding the converted
        array cast to ``np.float64``.
    """
    # ``convert_to_img`` returns a PIL image, which has no ``astype``;
    # the ndarray variant is required for the cast below.
    img = LoadImage.convert_to_ndarray(input)
    # ``np.float`` was only an alias of the builtin ``float`` and was
    # removed in NumPy 1.24; ``np.float64`` yields the same dtype.
    img = img.astype(np.float64)
    return {'img': img}
@@ -10,7 +10,7 @@ from modelscope.models.cv.super_resolution import rrdbnet_arch | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import load_image | |||
from modelscope.preprocessors import LoadImage, load_image | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
@@ -46,18 +46,7 @@ class ImageSuperResolutionPipeline(Pipeline): | |||
logger.info('load model done') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
if isinstance(input, str): | |||
img = np.array(load_image(input)) | |||
elif isinstance(input, PIL.Image.Image): | |||
img = np.array(input.convert('RGB')) | |||
elif isinstance(input, np.ndarray): | |||
if len(input.shape) == 2: | |||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
img = input[:, :, ::-1] # in rgb order | |||
else: | |||
raise TypeError(f'input should be either str, PIL.Image,' | |||
f' np.array, but got {type(input)}') | |||
img = LoadImage.convert_to_ndarray(input) | |||
img = torch.from_numpy(img).to(self.device).permute( | |||
2, 0, 1).unsqueeze(0) / 255. | |||
result = {'img': img} | |||
@@ -3,14 +3,13 @@ from typing import Any, Dict | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import tensorflow as tf | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import load_image | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils | |||
@@ -112,17 +111,8 @@ class OCRDetectionPipeline(Pipeline): | |||
model_loader.restore(sess, model_path) | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
if isinstance(input, str): | |||
img = np.array(load_image(input)) | |||
elif isinstance(input, PIL.Image.Image): | |||
img = np.array(input.convert('RGB')) | |||
elif isinstance(input, np.ndarray): | |||
if len(input.shape) == 2: | |||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
img = input[:, :, ::-1] # in rgb order | |||
else: | |||
raise TypeError(f'input should be either str, PIL.Image,' | |||
f' np.array, but got {type(input)}') | |||
img = LoadImage.convert_to_ndarray(input) | |||
h, w, c = img.shape | |||
img_pad = np.zeros((max(h, w), max(h, w), 3), dtype=np.float32) | |||
img_pad[:h, :w, :] = img | |||
@@ -3,13 +3,12 @@ from typing import Any, Dict | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import load_image | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
@@ -61,35 +60,12 @@ class StyleTransferPipeline(Pipeline): | |||
return pipeline_parameters, {}, {} | |||
def preprocess(self, content: Input, style: Input) -> Dict[str, Any]: | |||
if isinstance(content, str): | |||
content = np.array(load_image(content)) | |||
elif isinstance(content, PIL.Image.Image): | |||
content = np.array(content.convert('RGB')) | |||
elif isinstance(content, np.ndarray): | |||
if len(content.shape) == 2: | |||
content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR) | |||
content = content[:, :, ::-1] # in rgb order | |||
else: | |||
raise TypeError( | |||
f'modelscope error: content should be either str, PIL.Image,' | |||
f' np.array, but got {type(content)}') | |||
content = LoadImage.convert_to_ndarray(content) | |||
if len(content.shape) == 2: | |||
content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR) | |||
content_img = content.astype(np.float) | |||
if isinstance(style, str): | |||
style_img = np.array(load_image(style)) | |||
elif isinstance(style, PIL.Image.Image): | |||
style_img = np.array(style.convert('RGB')) | |||
elif isinstance(style, np.ndarray): | |||
if len(style.shape) == 2: | |||
style_img = cv2.cvtColor(style, cv2.COLOR_GRAY2BGR) | |||
style_img = style_img[:, :, ::-1] # in rgb order | |||
else: | |||
raise TypeError( | |||
f'modelscope error: style should be either str, PIL.Image,' | |||
f' np.array, but got {type(style)}') | |||
style_img = LoadImage.convert_to_ndarray(style) | |||
if len(style_img.shape) == 2: | |||
style_img = cv2.cvtColor(style_img, cv2.COLOR_GRAY2BGR) | |||
style_img = style_img.astype(np.float) | |||
@@ -1,21 +1,18 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
from abc import ABC, abstractmethod | |||
from typing import Any, Dict, Generator, List, Union | |||
from typing import Any, Dict | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import torch | |||
from PIL import Image | |||
from torchvision import transforms | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.cv.virual_tryon.sdafnet import SDAFNet_Tryon | |||
from modelscope.outputs import TASK_OUTPUTS, OutputKeys | |||
from modelscope.pipelines.util import is_model, is_official_hub_path | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.preprocessors import load_image | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from ..base import Pipeline | |||
@@ -67,6 +64,7 @@ class VirtualTryonPipeline(Pipeline): | |||
load_pretrained(self.model, src_params) | |||
self.model = self.model.eval() | |||
self.size = 192 | |||
from torchvision import transforms | |||
self.test_transforms = transforms.Compose([ | |||
transforms.Resize(self.size, interpolation=2), | |||
transforms.ToTensor(), | |||
@@ -2,7 +2,10 @@ | |||
import io | |||
from typing import Any, Dict, Union | |||
import torch | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
from numpy import ndarray | |||
from PIL import Image, ImageOps | |||
from modelscope.fileio import File | |||
@@ -60,6 +63,37 @@ class LoadImage: | |||
repr_str = f'{self.__class__.__name__}(' f'mode={self.mode})' | |||
return repr_str | |||
@staticmethod
def convert_to_ndarray(input) -> ndarray:
    """Convert a supported image input into a ``numpy.ndarray``.

    Args:
        input: One of

            * ``str`` - passed to ``load_image`` (file path or url);
            * ``PIL.Image.Image`` - converted to ``'RGB'`` mode;
            * ``np.ndarray`` - a two-dimensional array is first expanded
              with ``cv2.COLOR_GRAY2BGR``; the last axis is then
              reversed (presumably BGR -> RGB; confirm with callers).

    Returns:
        The image as an ndarray. For ndarray input the result is a
        C-contiguous copy, so the caller's array is never aliased.

    Raises:
        TypeError: If ``input`` is none of the supported types.
    """
    # The accepted types are mutually exclusive, so the order of the
    # checks does not affect the result.
    if isinstance(input, np.ndarray):
        if input.ndim == 2:
            input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
        # ``[:, :, ::-1]`` alone yields a negative-stride view, which
        # ``torch.from_numpy`` refuses; materialise a contiguous copy.
        return np.ascontiguousarray(input[:, :, ::-1])
    if isinstance(input, str):
        return np.array(load_image(input))
    if isinstance(input, PIL.Image.Image):
        return np.array(input.convert('RGB'))
    raise TypeError(f'input should be either str, PIL.Image,'
                    f' np.array, but got {type(input)}')
@staticmethod
def convert_to_img(input) -> 'Image.Image':
    """Convert a supported image input into an RGB ``PIL.Image.Image``.

    Args:
        input: One of

            * ``str`` - passed to ``load_image`` (file path or url);
            * ``PIL.Image.Image`` - converted to ``'RGB'`` mode;
            * ``np.ndarray`` - a two-dimensional array is first expanded
              with ``cv2.COLOR_GRAY2BGR``; the last axis is then
              reversed (presumably BGR -> RGB; confirm with callers)
              and the data is cast to ``uint8`` before building the
              PIL image.

    Returns:
        The image as a PIL image.

    Raises:
        TypeError: If ``input`` is none of the supported types.
    """
    if isinstance(input, str):
        return load_image(input)
    if isinstance(input, PIL.Image.Image):
        return input.convert('RGB')
    if isinstance(input, np.ndarray):
        if input.ndim == 2:
            # Rebind ``input`` so the channel reversal below operates on
            # the expanded 3-channel array; previously the expanded
            # result was discarded and the 2-D array was indexed with
            # three subscripts.
            input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
        rgb = input[:, :, ::-1]
        return Image.fromarray(rgb.astype('uint8')).convert('RGB')
    raise TypeError(f'input should be either str, PIL.Image,'
                    f' np.array, but got {type(input)}')
def load_image(image_path_or_url: str) -> Image.Image: | |||
""" simple interface to load an image from file or url | |||