* add test code
* fix bug
* support gray image
* update unittest
* bug fixed

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9419792
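This change registers a new AAMS style-transfer pipeline. As a quick orientation, here is a minimal usage sketch mirroring the unit tests added below (the image paths are the LFS test assets from this change; with no explicit model id, `pipeline(Tasks.style_transfer)` resolves to the default model registered in DEFAULT_MODEL_FOR_PIPELINE):

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# with no explicit model id this resolves to 'damo/cv_aams_style-transfer_damo'
style_transfer = pipeline(Tasks.style_transfer)
result = style_transfer(
    'data/test/images/style_transfer_content.jpg',
    style='data/test/images/style_transfer_style.jpg')
cv2.imwrite('stylized.png', result[OutputKeys.OUTPUT_IMG])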
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f33a6ad9fcd7367cec2e81b8b0e4234d4f5f7d1be284d48085a25bb6d03782d7
size 72130
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1af09b2c18a6674b7d88849cb87564dd77e1ce04d1517bb085449b614cc0c8d8
size 376101
@@ -49,6 +49,7 @@ class Pipelines(object):
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognation = 'resnet101-animal_recog'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+    style_transfer = 'AAMS-style-transfer'

     # nlp tasks
     sentence_similarity = 'sentence-similarity'
@@ -64,7 +64,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
         'damo/cv_r2p1d_video_embedding'),
     Tasks.text_to_image_synthesis:
     (Pipelines.text_to_image_synthesis,
-     'damo/cv_imagen_text-to-image-synthesis_tiny')
+     'damo/cv_imagen_text-to-image-synthesis_tiny'),
+    Tasks.style_transfer: (Pipelines.style_transfer,
+                           'damo/cv_aams_style-transfer_damo')
 }
@@ -15,11 +15,12 @@ except ModuleNotFoundError as e:
 try:
     from .image_cartoon_pipeline import ImageCartoonPipeline
     from .image_matting_pipeline import ImageMattingPipeline
+    from .style_transfer_pipeline import StyleTransferPipeline
     from .ocr_detection_pipeline import OCRDetectionPipeline
 except ModuleNotFoundError as e:
     if str(e) == "No module named 'tensorflow'":
         print(
             TENSORFLOW_IMPORT_ERROR.format(
-                'image-cartoon image-matting ocr-detection'))
+                'image-cartoon image-matting ocr-detection style-transfer'))
     else:
         raise ModuleNotFoundError(e)
@@ -0,0 +1,131 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()
@PIPELINES.register_module(
    Tasks.style_transfer, module_name=Pipelines.style_transfer)
class StyleTransferPipeline(Pipeline):

    def __init__(self, model: str):
        """
        use `model` to create a style transfer pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model)
        # lazy import so this module can be imported without tensorflow installed
        import tensorflow as tf
        if tf.__version__ >= '2.0':
            tf = tf.compat.v1
        model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self._session = tf.Session(config=config)
        self.max_length = 800
        with self._session.as_default():
            logger.info(f'loading model from {model_path}')
            # load the frozen inference graph and look up its I/O tensors
            with tf.gfile.FastGFile(model_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name='')
                self.content = tf.get_default_graph().get_tensor_by_name(
                    'content:0')
                self.style = tf.get_default_graph().get_tensor_by_name(
                    'style:0')
                self.output = tf.get_default_graph().get_tensor_by_name(
                    'stylized_output:0')
                self.attention = tf.get_default_graph().get_tensor_by_name(
                    'attention_map:0')
                self.inter_weight = tf.get_default_graph().get_tensor_by_name(
                    'inter_weight:0')
                self.centroids = tf.get_default_graph().get_tensor_by_name(
                    'centroids:0')
            logger.info('load model done')

    def _sanitize_parameters(self, **pipeline_parameters):
        return pipeline_parameters, {}, {}
    def preprocess(self, content: Input, style: Input) -> Dict[str, Any]:
        if isinstance(content, str):
            content = np.array(load_image(content))
        elif isinstance(content, PIL.Image.Image):
            content = np.array(content.convert('RGB'))
        elif isinstance(content, np.ndarray):
            if len(content.shape) == 2:
                # gray image: expand to 3 channels before reordering
                content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR)
            content = content[:, :, ::-1]  # in rgb order
        else:
            raise TypeError(
                f'modelscope error: content should be either str, PIL.Image,'
                f' np.array, but got {type(content)}')
        if len(content.shape) == 2:
            content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR)
        content_img = content.astype(np.float32)  # np.float was removed in numpy>=1.24

        if isinstance(style, str):
            style_img = np.array(load_image(style))
        elif isinstance(style, PIL.Image.Image):
            style_img = np.array(style.convert('RGB'))
        elif isinstance(style, np.ndarray):
            if len(style.shape) == 2:
                # gray image: expand to 3 channels before reordering
                style = cv2.cvtColor(style, cv2.COLOR_GRAY2BGR)
            # reassign first, otherwise style_img is undefined for 3-channel arrays
            style_img = style[:, :, ::-1]  # in rgb order
        else:
            raise TypeError(
                f'modelscope error: style should be either str, PIL.Image,'
                f' np.array, but got {type(style)}')
        if len(style_img.shape) == 2:
            style_img = cv2.cvtColor(style_img, cv2.COLOR_GRAY2BGR)
        style_img = style_img.astype(np.float32)

        result = {'content': content_img, 'style': style_img}
        return result
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        content_feed, style_feed = input['content'], input['style']
        h = np.shape(content_feed)[0]
        w = np.shape(content_feed)[1]
        if h > self.max_length or w > self.max_length:
            # downscale the longer side to max_length, keeping the aspect ratio
            if h > w:
                content_feed = cv2.resize(
                    content_feed,
                    (int(self.max_length * w / h), self.max_length))
            else:
                content_feed = cv2.resize(
                    content_feed,
                    (self.max_length, int(self.max_length * h / w)))

        with self._session.as_default():
            feed_dict = {
                self.content: content_feed,
                self.style: style_feed,
                self.inter_weight: 1.0
            }
            output_img = self._session.run(self.output, feed_dict=feed_dict)
            output_img = cv2.cvtColor(output_img[0], cv2.COLOR_RGB2BGR)
            output_img = np.clip(output_img, 0, 255).astype(np.uint8)
            # restore the original content resolution
            output_img = cv2.resize(output_img, (w, h))
        return {OutputKeys.OUTPUT_IMG: output_img}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
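The commit message calls out gray-image support: `preprocess` expands 2-D arrays to three channels (GRAY2BFR conversion via cv2.COLOR_GRAY2BGR) before the channel reorder. A hedged sketch of that path, assuming a local grayscale file (the file names here are illustrative, not part of this change):

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# cv2.imread with IMREAD_GRAYSCALE returns a 2-D uint8 array,
# which preprocess converts to 3 channels before feeding the graph
gray_content = cv2.imread('content_gray.jpg', cv2.IMREAD_GRAYSCALE)
style_transfer = pipeline(
    Tasks.style_transfer, model='damo/cv_aams_style-transfer_damo')
result = style_transfer(
    gray_content, style='data/test/images/style_transfer_style.jpg')
cv2.imwrite('stylized_gray.png', result[OutputKeys.OUTPUT_IMG])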
@@ -27,6 +27,7 @@ class CVTasks(object):
     ocr_detection = 'ocr-detection'
     action_recognition = 'action-recognition'
     video_embedding = 'video-embedding'
+    style_transfer = 'style-transfer'


 class NLPTasks(object):
@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import cv2

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class StyleTransferTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_aams_style-transfer_damo'

    @unittest.skip('deprecated, download model from model hub instead')
    def test_run_by_direct_model_download(self):
        snapshot_path = snapshot_download(self.model_id)
        print('snapshot_path: {}'.format(snapshot_path))
        style_transfer = pipeline(Tasks.style_transfer, model=snapshot_path)

        result = style_transfer(
            'data/test/images/style_transfer_content.jpg',
            style='data/test/images/style_transfer_style.jpg')
        cv2.imwrite('result_styletransfer1.png', result[OutputKeys.OUTPUT_IMG])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_modelhub(self):
        style_transfer = pipeline(Tasks.style_transfer, model=self.model_id)

        result = style_transfer(
            'data/test/images/style_transfer_content.jpg',
            style='data/test/images/style_transfer_style.jpg')
        cv2.imwrite('result_styletransfer2.png', result[OutputKeys.OUTPUT_IMG])
        print('style_transfer.test_run_modelhub done')

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        style_transfer = pipeline(Tasks.style_transfer)

        result = style_transfer(
            'data/test/images/style_transfer_content.jpg',
            style='data/test/images/style_transfer_style.jpg')
        cv2.imwrite('result_styletransfer3.png', result[OutputKeys.OUTPUT_IMG])
        print('style_transfer.test_run_modelhub_default_model done')


if __name__ == '__main__':
    unittest.main()