
[to #43259593] refactor image preprocess

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9501913

    * [to #43259593] refactor image preprocess
master · yingda.chen, 3 years ago · commit 590bc52f97
13 changed files with 56 additions and 116 deletions

  1. modelscope/models/multi_modal/clip/clip_model.py (+1, -1)
  2. modelscope/pipelines/audio/ans_pipeline.py (+0, -1)
  3. modelscope/pipelines/base.py (+0, -1)
  4. modelscope/pipelines/cv/animal_recog_pipeline.py (+2, -14)
  5. modelscope/pipelines/cv/image_cartoon_pipeline.py (+2, -13)
  6. modelscope/pipelines/cv/image_color_enhance_pipeline.py (+2, -13)
  7. modelscope/pipelines/cv/image_colorization_pipeline.py (+1, -1)
  8. modelscope/pipelines/cv/image_matting_pipeline.py (+2, -13)
  9. modelscope/pipelines/cv/image_super_resolution_pipeline.py (+2, -13)
  10. modelscope/pipelines/cv/ocr_detection_pipeline.py (+3, -13)
  11. modelscope/pipelines/cv/style_transfer_pipeline.py (+3, -27)
  12. modelscope/pipelines/cv/virtual_tryon_pipeline.py (+3, -5)
  13. modelscope/preprocessors/image.py (+35, -1)

modelscope/models/multi_modal/clip/clip_model.py (+1, -1)

@@ -1,6 +1,6 @@
import os.path as osp
from typing import Any, Dict

import cv2
import json
import numpy as np
import torch


modelscope/pipelines/audio/ans_pipeline.py (+0, -1)

@@ -11,7 +11,6 @@ from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.torch_utils import create_device


def audio_norm(x):


modelscope/pipelines/base.py (+0, -1)

@@ -8,7 +8,6 @@ from typing import Any, Dict, Generator, List, Mapping, Union

import numpy as np

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.base import Model
from modelscope.msdatasets import MsDataset
from modelscope.outputs import TASK_OUTPUTS


modelscope/pipelines/cv/animal_recog_pipeline.py (+2, -14)

@@ -13,7 +13,7 @@ from modelscope.models.cv.animal_recognition import resnet
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.preprocessors import LoadImage, load_image
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

@@ -79,19 +79,7 @@ class AnimalRecogPipeline(Pipeline):
logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
img = load_image(input)
elif isinstance(input, PIL.Image.Image):
img = input.convert('RGB')
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = input[:, :, ::-1]
img = Image.fromarray(img.astype('uint8')).convert('RGB')
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')

img = LoadImage.convert_to_img(input)
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
test_transforms = transforms.Compose([


modelscope/pipelines/cv/image_cartoon_pipeline.py (+2, -13)

@@ -3,7 +3,6 @@ from typing import Any, Dict

import cv2
import numpy as np
import PIL
import tensorflow as tf

from modelscope.metainfo import Pipelines
@@ -14,7 +13,7 @@ from modelscope.models.cv.cartoon.utils import get_f5p, padTo16x, resize_size
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

@@ -65,17 +64,7 @@ class ImageCartoonPipeline(Pipeline):
return sess

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
img = np.array(load_image(input))
elif isinstance(input, PIL.Image.Image):
img = np.array(input.convert('RGB'))
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
img = input[:, :, ::-1]
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')
img = LoadImage.convert_to_ndarray(input)
img = img.astype(np.float)
result = {'img': img}
return result


modelscope/pipelines/cv/image_color_enhance_pipeline.py (+2, -13)

@@ -13,7 +13,7 @@ from modelscope.models.cv.image_color_enhance.image_color_enhance import \
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import (ImageColorEnhanceFinetunePreprocessor,
load_image)
LoadImage, load_image)
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
@@ -47,18 +47,7 @@ class ImageColorEnhancePipeline(Pipeline):
self._device = torch.device('cpu')

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
img = load_image(input)
elif isinstance(input, PIL.Image.Image):
img = input.convert('RGB')
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
img = Image.fromarray(img.astype('uint8')).convert('RGB')
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')

img = LoadImage.convert_to_img(input)
test_transforms = transforms.Compose([transforms.ToTensor()])
img = test_transforms(img)
result = {'src': img.unsqueeze(0).to(self._device)}


modelscope/pipelines/cv/image_colorization_pipeline.py (+1, -1)

@@ -88,7 +88,7 @@ class ImageColorizationPipeline(Pipeline):
img = input.convert('LA').convert('RGB')
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
img = input[:, :, ::-1] # in rgb order
img = PIL.Image.fromarray(img).convert('LA').convert('RGB')
else:


modelscope/pipelines/cv/image_matting_pipeline.py (+2, -13)

@@ -3,13 +3,12 @@ from typing import Any, Dict

import cv2
import numpy as np
import PIL

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

@@ -47,17 +46,7 @@ class ImageMattingPipeline(Pipeline):
logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
img = np.array(load_image(input))
elif isinstance(input, PIL.Image.Image):
img = np.array(input.convert('RGB'))
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = input[:, :, ::-1] # in rgb order
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')
img = LoadImage.convert_to_img(input)
img = img.astype(np.float)
result = {'img': img}
return result


modelscope/pipelines/cv/image_super_resolution_pipeline.py (+2, -13)

@@ -10,7 +10,7 @@ from modelscope.models.cv.super_resolution import rrdbnet_arch
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.preprocessors import LoadImage, load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

@@ -46,18 +46,7 @@ class ImageSuperResolutionPipeline(Pipeline):
logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
img = np.array(load_image(input))
elif isinstance(input, PIL.Image.Image):
img = np.array(input.convert('RGB'))
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = input[:, :, ::-1] # in rgb order
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')

img = LoadImage.convert_to_ndarray(input)
img = torch.from_numpy(img).to(self.device).permute(
2, 0, 1).unsqueeze(0) / 255.
result = {'img': img}


modelscope/pipelines/cv/ocr_detection_pipeline.py (+3, -13)

@@ -3,14 +3,13 @@ from typing import Any, Dict

import cv2
import numpy as np
import PIL
import tensorflow as tf

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils
@@ -112,17 +111,8 @@ class OCRDetectionPipeline(Pipeline):
model_loader.restore(sess, model_path)

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
img = np.array(load_image(input))
elif isinstance(input, PIL.Image.Image):
img = np.array(input.convert('RGB'))
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = input[:, :, ::-1] # in rgb order
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')
img = LoadImage.convert_to_ndarray(input)

h, w, c = img.shape
img_pad = np.zeros((max(h, w), max(h, w), 3), dtype=np.float32)
img_pad[:h, :w, :] = img


modelscope/pipelines/cv/style_transfer_pipeline.py (+3, -27)

@@ -3,13 +3,12 @@ from typing import Any, Dict

import cv2
import numpy as np
import PIL

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

@@ -61,35 +60,12 @@ class StyleTransferPipeline(Pipeline):
return pipeline_parameters, {}, {}

def preprocess(self, content: Input, style: Input) -> Dict[str, Any]:
if isinstance(content, str):
content = np.array(load_image(content))
elif isinstance(content, PIL.Image.Image):
content = np.array(content.convert('RGB'))
elif isinstance(content, np.ndarray):
if len(content.shape) == 2:
content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR)
content = content[:, :, ::-1] # in rgb order
else:
raise TypeError(
f'modelscope error: content should be either str, PIL.Image,'
f' np.array, but got {type(content)}')
content = LoadImage.convert_to_ndarray(content)
if len(content.shape) == 2:
content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR)
content_img = content.astype(np.float)

if isinstance(style, str):
style_img = np.array(load_image(style))
elif isinstance(style, PIL.Image.Image):
style_img = np.array(style.convert('RGB'))
elif isinstance(style, np.ndarray):
if len(style.shape) == 2:
style_img = cv2.cvtColor(style, cv2.COLOR_GRAY2BGR)
style_img = style_img[:, :, ::-1] # in rgb order
else:
raise TypeError(
f'modelscope error: style should be either str, PIL.Image,'
f' np.array, but got {type(style)}')

style_img = LoadImage.convert_to_ndarray(style)
if len(style_img.shape) == 2:
style_img = cv2.cvtColor(style_img, cv2.COLOR_GRAY2BGR)
style_img = style_img.astype(np.float)


modelscope/pipelines/cv/virtual_tryon_pipeline.py (+3, -5)

@@ -1,21 +1,18 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os.path as osp
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Union
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch
from PIL import Image
from torchvision import transforms

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Pipelines
from modelscope.models.cv.virual_tryon.sdafnet import SDAFNet_Tryon
from modelscope.outputs import TASK_OUTPUTS, OutputKeys
from modelscope.pipelines.util import is_model, is_official_hub_path
from modelscope.outputs import OutputKeys
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from ..base import Pipeline
@@ -67,6 +64,7 @@ class VirtualTryonPipeline(Pipeline):
load_pretrained(self.model, src_params)
self.model = self.model.eval()
self.size = 192
from torchvision import transforms
self.test_transforms = transforms.Compose([
transforms.Resize(self.size, interpolation=2),
transforms.ToTensor(),


modelscope/preprocessors/image.py (+35, -1)

@@ -2,7 +2,10 @@
import io
from typing import Any, Dict, Union

import torch
import cv2
import numpy as np
import PIL
from numpy import ndarray
from PIL import Image, ImageOps

from modelscope.fileio import File
@@ -60,6 +63,37 @@ class LoadImage:
repr_str = f'{self.__class__.__name__}(' f'mode={self.mode})'
return repr_str

@staticmethod
def convert_to_ndarray(input) -> ndarray:
if isinstance(input, str):
img = np.array(load_image(input))
elif isinstance(input, PIL.Image.Image):
img = np.array(input.convert('RGB'))
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
img = input[:, :, ::-1]
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')
return img

@staticmethod
def convert_to_img(input) -> ndarray:
if isinstance(input, str):
img = load_image(input)
elif isinstance(input, PIL.Image.Image):
img = input.convert('RGB')
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
img = input[:, :, ::-1]
img = Image.fromarray(img.astype('uint8')).convert('RGB')
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')
return img


def load_image(image_path_or_url: str) -> Image.Image:
""" simple interface to load an image from file or url

