Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8993758master
@@ -23,7 +23,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | ||||
Tasks.image_captioning: ('ofa', None), | Tasks.image_captioning: ('ofa', None), | ||||
Tasks.image_generation: | Tasks.image_generation: | ||||
('cv_unet_person-image-cartoon', 'damo/cv_unet_image-matting_damo'), | |||||
('person-image-cartoon', | |||||
'damo/cv_unet_person-image-cartoon_compound-models'), | |||||
} | } | ||||
@@ -25,20 +25,19 @@ logger = get_logger() | |||||
@PIPELINES.register_module( | @PIPELINES.register_module( | ||||
Tasks.image_generation, module_name='cv_unet_person-image-cartoon') | |||||
Tasks.image_generation, module_name='person-image-cartoon') | |||||
class ImageCartoonPipeline(Pipeline): | class ImageCartoonPipeline(Pipeline): | ||||
def __init__(self, model: str): | def __init__(self, model: str): | ||||
super().__init__(model=model) | super().__init__(model=model) | ||||
self.facer = FaceAna(model) | |||||
self.facer = FaceAna(self.model) | |||||
self.sess_anime_head = self.load_sess( | self.sess_anime_head = self.load_sess( | ||||
os.path.join(model, 'cartoon_anime_h.pb'), 'model_anime_head') | |||||
os.path.join(self.model, 'cartoon_anime_h.pb'), 'model_anime_head') | |||||
self.sess_anime_bg = self.load_sess( | self.sess_anime_bg = self.load_sess( | ||||
os.path.join(model, 'cartoon_anime_bg.pb'), 'model_anime_bg') | |||||
os.path.join(self.model, 'cartoon_anime_bg.pb'), 'model_anime_bg') | |||||
self.box_width = 288 | self.box_width = 288 | ||||
global_mask = cv2.imread(os.path.join(model, 'alpha.jpg')) | |||||
global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg')) | |||||
global_mask = cv2.resize( | global_mask = cv2.resize( | ||||
global_mask, (self.box_width, self.box_width), | global_mask, (self.box_width, self.box_width), | ||||
interpolation=cv2.INTER_AREA) | interpolation=cv2.INTER_AREA) | ||||
@@ -1,26 +1,31 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import os | import os | ||||
import os.path as osp | |||||
import unittest | import unittest | ||||
import cv2 | import cv2 | ||||
from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
from modelscope.pipelines.base import Pipeline | |||||
from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
def all_file(file_dir): | |||||
L = [] | |||||
for root, dirs, files in os.walk(file_dir): | |||||
for file in files: | |||||
extend = os.path.splitext(file)[1] | |||||
if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG' or extend == '.HEIC': | |||||
L.append(os.path.join(root, file)) | |||||
return L | |||||
class ImageCartoonTest(unittest.TestCase): | |||||
def setUp(self) -> None: | |||||
self.model_id = 'damo/cv_unet_person-image-cartoon_compound-models' | |||||
self.test_image = \ | |||||
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \ | |||||
'/data/test/maas/image_carton/test.png' | |||||
class ImageCartoonTest(unittest.TestCase): | |||||
def pipeline_inference(self, pipeline: Pipeline, input_location: str): | |||||
result = pipeline(input_location) | |||||
if result is not None: | |||||
cv2.imwrite('result.png', result['output_png']) | |||||
print(f'Output written to {osp.abspath("result.png")}') | |||||
def test_run(self): | |||||
@unittest.skip('deprecated, download model from model hub instead') | |||||
def test_run_by_direct_model_download(self): | |||||
model_dir = './assets' | model_dir = './assets' | ||||
if not os.path.exists(model_dir): | if not os.path.exists(model_dir): | ||||
os.system( | os.system( | ||||
@@ -29,9 +34,15 @@ class ImageCartoonTest(unittest.TestCase): | |||||
os.system('unzip assets.zip') | os.system('unzip assets.zip') | ||||
img_cartoon = pipeline(Tasks.image_generation, model=model_dir) | img_cartoon = pipeline(Tasks.image_generation, model=model_dir) | ||||
result = img_cartoon(os.path.join(model_dir, 'test.png')) | |||||
if result is not None: | |||||
cv2.imwrite('result.png', result['output_png']) | |||||
self.pipeline_inference(img_cartoon, self.test_image) | |||||
def test_run_modelhub(self): | |||||
img_cartoon = pipeline(Tasks.image_generation, model=self.model_id) | |||||
self.pipeline_inference(img_cartoon, self.test_image) | |||||
def test_run_modelhub_default_model(self): | |||||
img_cartoon = pipeline(Tasks.image_generation) | |||||
self.pipeline_inference(img_cartoon, self.test_image) | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
@@ -0,0 +1,50 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os.path as osp | |||||
import unittest | |||||
from maas_hub.maas_api import MaasApi | |||||
from maas_hub.repository import Repository | |||||
USER_NAME = 'maasadmin' | |||||
PASSWORD = '12345678' | |||||
class HubOperationTest(unittest.TestCase): | |||||
def setUp(self): | |||||
self.api = MaasApi() | |||||
# note this is temporary before official account management is ready | |||||
self.api.login(USER_NAME, PASSWORD) | |||||
@unittest.skip('to be used for local test only') | |||||
def test_model_repo_creation(self): | |||||
# change to proper model names before use | |||||
model_name = 'cv_unet_person-image-cartoon_compound-models' | |||||
model_chinese_name = '达摩卡通化模型' | |||||
model_org = 'damo' | |||||
try: | |||||
self.api.create_model( | |||||
owner=model_org, | |||||
name=model_name, | |||||
chinese_name=model_chinese_name, | |||||
visibility=5, # 1-private, 5-public | |||||
license='apache-2.0') | |||||
# TODO: support proper name duplication checking | |||||
except KeyError as ke: | |||||
if ke.args[0] == 'name': | |||||
print(f'model {self.model_name} already exists, ignore') | |||||
else: | |||||
raise | |||||
# Note that this can be done via git operation once model repo | |||||
# has been created. Git-Op is the RECOMMENDED model upload approach | |||||
@unittest.skip('to be used for local test only') | |||||
def test_model_upload(self): | |||||
local_path = '/path/to/local/model/directory' | |||||
assert osp.exists(local_path), 'Local model directory not exist.' | |||||
repo = Repository(local_dir=local_path) | |||||
repo.push_to_hub(commit_message='Upload model files') | |||||
if __name__ == '__main__': | |||||
unittest.main() |