Browse Source

[to #42322933] Merge request from 雪洛:cv/aams

* add test code
 * fix bug
 * support gray image
 * update unitest
 * bugfixed
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9419792
master
jianqiang.rjq huangjun.hj 3 years ago
parent
commit
5876fdc25c
8 changed files with 199 additions and 2 deletions
  1. +3
    -0
      data/test/images/style_transfer_content.jpg
  2. +3
    -0
      data/test/images/style_transfer_style.jpg
  3. +1
    -0
      modelscope/metainfo.py
  4. +3
    -1
      modelscope/pipelines/builder.py
  5. +2
    -1
      modelscope/pipelines/cv/__init__.py
  6. +131
    -0
      modelscope/pipelines/cv/style_transfer_pipeline.py
  7. +1
    -0
      modelscope/utils/constant.py
  8. +55
    -0
      tests/pipelines/test_style_transfer.py

+ 3
- 0
data/test/images/style_transfer_content.jpg View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f33a6ad9fcd7367cec2e81b8b0e4234d4f5f7d1be284d48085a25bb6d03782d7
size 72130

+ 3
- 0
data/test/images/style_transfer_style.jpg View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1af09b2c18a6674b7d88849cb87564dd77e1ce04d1517bb085449b614cc0c8d8
size 376101

+ 1
- 0
modelscope/metainfo.py View File

@@ -49,6 +49,7 @@ class Pipelines(object):
action_recognition = 'TAdaConv_action-recognition'
animal_recognation = 'resnet101-animal_recog'
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
style_transfer = 'AAMS-style-transfer'

# nlp tasks
sentence_similarity = 'sentence-similarity'


+ 3
- 1
modelscope/pipelines/builder.py View File

@@ -64,7 +64,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_r2p1d_video_embedding'),
Tasks.text_to_image_synthesis:
(Pipelines.text_to_image_synthesis,
'damo/cv_imagen_text-to-image-synthesis_tiny')
'damo/cv_imagen_text-to-image-synthesis_tiny'),
Tasks.style_transfer: (Pipelines.style_transfer,
'damo/cv_aams_style-transfer_damo')
}




+ 2
- 1
modelscope/pipelines/cv/__init__.py View File

@@ -15,11 +15,12 @@ except ModuleNotFoundError as e:
try:
from .image_cartoon_pipeline import ImageCartoonPipeline
from .image_matting_pipeline import ImageMattingPipeline
from .style_transfer_pipeline import StyleTransferPipeline
from .ocr_detection_pipeline import OCRDetectionPipeline
except ModuleNotFoundError as e:
if str(e) == "No module named 'tensorflow'":
print(
TENSORFLOW_IMPORT_ERROR.format(
'image-cartoon image-matting ocr-detection'))
'image-cartoon image-matting ocr-detection style-transfer'))
else:
raise ModuleNotFoundError(e)

+ 131
- 0
modelscope/pipelines/cv/style_transfer_pipeline.py View File

@@ -0,0 +1,131 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.style_transfer, module_name=Pipelines.style_transfer)
class StyleTransferPipeline(Pipeline):

def __init__(self, model: str):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
import tensorflow as tf
if tf.__version__ >= '2.0':
tf = tf.compat.v1
model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE)

config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
self._session = tf.Session(config=config)
self.max_length = 800
with self._session.as_default():
logger.info(f'loading model from {model_path}')
with tf.gfile.FastGFile(model_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')

self.content = tf.get_default_graph().get_tensor_by_name(
'content:0')
self.style = tf.get_default_graph().get_tensor_by_name(
'style:0')
self.output = tf.get_default_graph().get_tensor_by_name(
'stylized_output:0')
self.attention = tf.get_default_graph().get_tensor_by_name(
'attention_map:0')
self.inter_weight = tf.get_default_graph().get_tensor_by_name(
'inter_weight:0')
self.centroids = tf.get_default_graph().get_tensor_by_name(
'centroids:0')
logger.info('load model done')

def _sanitize_parameters(self, **pipeline_parameters):
return pipeline_parameters, {}, {}

def preprocess(self, content: Input, style: Input) -> Dict[str, Any]:
if isinstance(content, str):
content = np.array(load_image(content))
elif isinstance(content, PIL.Image.Image):
content = np.array(content.convert('RGB'))
elif isinstance(content, np.ndarray):
if len(content.shape) == 2:
content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR)
content = content[:, :, ::-1] # in rgb order
else:
raise TypeError(
f'modelscope error: content should be either str, PIL.Image,'
f' np.array, but got {type(content)}')
if len(content.shape) == 2:
content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR)
content_img = content.astype(np.float)

if isinstance(style, str):
style_img = np.array(load_image(style))
elif isinstance(style, PIL.Image.Image):
style_img = np.array(style.convert('RGB'))
elif isinstance(style, np.ndarray):
if len(style.shape) == 2:
style_img = cv2.cvtColor(style, cv2.COLOR_GRAY2BGR)
style_img = style_img[:, :, ::-1] # in rgb order
else:
raise TypeError(
f'modelscope error: style should be either str, PIL.Image,'
f' np.array, but got {type(style)}')

if len(style_img.shape) == 2:
style_img = cv2.cvtColor(style_img, cv2.COLOR_GRAY2BGR)
style_img = style_img.astype(np.float)

result = {'content': content_img, 'style': style_img}
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
content_feed, style_feed = input['content'], input['style']
h = np.shape(content_feed)[0]
w = np.shape(content_feed)[1]
if h > self.max_length or w > self.max_length:
if h > w:
content_feed = cv2.resize(
content_feed,
(int(self.max_length * w / h), self.max_length))
else:
content_feed = cv2.resize(
content_feed,
(self.max_length, int(self.max_length * h / w)))

with self._session.as_default():
feed_dict = {
self.content: content_feed,
self.style: style_feed,
self.inter_weight: 1.0
}
output_img = self._session.run(self.output, feed_dict=feed_dict)

# print('out_img shape:{}'.format(output_img.shape))
output_img = cv2.cvtColor(output_img[0], cv2.COLOR_RGB2BGR)
output_img = np.clip(output_img, 0, 255).astype(np.uint8)

output_img = cv2.resize(output_img, (w, h))

return {OutputKeys.OUTPUT_IMG: output_img}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

+ 1
- 0
modelscope/utils/constant.py View File

@@ -27,6 +27,7 @@ class CVTasks(object):
ocr_detection = 'ocr-detection'
action_recognition = 'action-recognition'
video_embedding = 'video-embedding'
style_transfer = 'style-transfer'


class NLPTasks(object):


+ 55
- 0
tests/pipelines/test_style_transfer.py View File

@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import tempfile
import unittest

import cv2

from modelscope.fileio import File
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class StyleTransferTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_aams_style-transfer_damo'

@unittest.skip('deprecated, download model from model hub instead')
def test_run_by_direct_model_download(self):
snapshot_path = snapshot_download(self.model_id)
print('snapshot_path: {}'.format(snapshot_path))
style_transfer = pipeline(Tasks.style_transfer, model=snapshot_path)

result = style_transfer(
'data/test/images/style_transfer_content.jpg',
style='data/test/images/style_transfer_style.jpg')
cv2.imwrite('result_styletransfer1.png', result[OutputKeys.OUTPUT_IMG])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_modelhub(self):
style_transfer = pipeline(Tasks.style_transfer, model=self.model_id)

result = style_transfer(
'data/test/images/style_transfer_content.jpg',
style='data/test/images/style_transfer_style.jpg')
cv2.imwrite('result_styletransfer2.png', result[OutputKeys.OUTPUT_IMG])
print('style_transfer.test_run_modelhub done')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
style_transfer = pipeline(Tasks.style_transfer)

result = style_transfer(
'data/test/images/style_transfer_content.jpg',
style='data/test/images/style_transfer_style.jpg')
cv2.imwrite('result_styletransfer3.png', result[OutputKeys.OUTPUT_IMG])
print('style_transfer.test_run_modelhub_default_model done')


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save