Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8993758master
@@ -23,7 +23,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | |||
Tasks.image_captioning: ('ofa', None), | |||
Tasks.image_generation: | |||
('cv_unet_person-image-cartoon', 'damo/cv_unet_image-matting_damo'), | |||
('person-image-cartoon', | |||
'damo/cv_unet_person-image-cartoon_compound-models'), | |||
} | |||
@@ -25,20 +25,19 @@ logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.image_generation, module_name='cv_unet_person-image-cartoon') | |||
Tasks.image_generation, module_name='person-image-cartoon') | |||
class ImageCartoonPipeline(Pipeline): | |||
def __init__(self, model: str): | |||
super().__init__(model=model) | |||
self.facer = FaceAna(model) | |||
self.facer = FaceAna(self.model) | |||
self.sess_anime_head = self.load_sess( | |||
os.path.join(model, 'cartoon_anime_h.pb'), 'model_anime_head') | |||
os.path.join(self.model, 'cartoon_anime_h.pb'), 'model_anime_head') | |||
self.sess_anime_bg = self.load_sess( | |||
os.path.join(model, 'cartoon_anime_bg.pb'), 'model_anime_bg') | |||
os.path.join(self.model, 'cartoon_anime_bg.pb'), 'model_anime_bg') | |||
self.box_width = 288 | |||
global_mask = cv2.imread(os.path.join(model, 'alpha.jpg')) | |||
global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg')) | |||
global_mask = cv2.resize( | |||
global_mask, (self.box_width, self.box_width), | |||
interpolation=cv2.INTER_AREA) | |||
@@ -1,26 +1,31 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
import unittest | |||
import cv2 | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.utils.constant import Tasks | |||
def all_file(file_dir): | |||
L = [] | |||
for root, dirs, files in os.walk(file_dir): | |||
for file in files: | |||
extend = os.path.splitext(file)[1] | |||
if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG' or extend == '.HEIC': | |||
L.append(os.path.join(root, file)) | |||
return L | |||
class ImageCartoonTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/cv_unet_person-image-cartoon_compound-models' | |||
self.test_image = \ | |||
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \ | |||
'/data/test/maas/image_carton/test.png' | |||
class ImageCartoonTest(unittest.TestCase): | |||
def pipeline_inference(self, pipeline: Pipeline, input_location: str): | |||
result = pipeline(input_location) | |||
if result is not None: | |||
cv2.imwrite('result.png', result['output_png']) | |||
print(f'Output written to {osp.abspath("result.png")}') | |||
def test_run(self): | |||
@unittest.skip('deprecated, download model from model hub instead') | |||
def test_run_by_direct_model_download(self): | |||
model_dir = './assets' | |||
if not os.path.exists(model_dir): | |||
os.system( | |||
@@ -29,9 +34,15 @@ class ImageCartoonTest(unittest.TestCase): | |||
os.system('unzip assets.zip') | |||
img_cartoon = pipeline(Tasks.image_generation, model=model_dir) | |||
result = img_cartoon(os.path.join(model_dir, 'test.png')) | |||
if result is not None: | |||
cv2.imwrite('result.png', result['output_png']) | |||
self.pipeline_inference(img_cartoon, self.test_image) | |||
def test_run_modelhub(self): | |||
img_cartoon = pipeline(Tasks.image_generation, model=self.model_id) | |||
self.pipeline_inference(img_cartoon, self.test_image) | |||
def test_run_modelhub_default_model(self): | |||
img_cartoon = pipeline(Tasks.image_generation) | |||
self.pipeline_inference(img_cartoon, self.test_image) | |||
if __name__ == '__main__': | |||
@@ -0,0 +1,50 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
import unittest | |||
from maas_hub.maas_api import MaasApi | |||
from maas_hub.repository import Repository | |||
USER_NAME = 'maasadmin' | |||
PASSWORD = '12345678' | |||
class HubOperationTest(unittest.TestCase): | |||
def setUp(self): | |||
self.api = MaasApi() | |||
# note this is temporary before official account management is ready | |||
self.api.login(USER_NAME, PASSWORD) | |||
@unittest.skip('to be used for local test only') | |||
def test_model_repo_creation(self): | |||
# change to proper model names before use | |||
model_name = 'cv_unet_person-image-cartoon_compound-models' | |||
model_chinese_name = '达摩卡通化模型' | |||
model_org = 'damo' | |||
try: | |||
self.api.create_model( | |||
owner=model_org, | |||
name=model_name, | |||
chinese_name=model_chinese_name, | |||
visibility=5, # 1-private, 5-public | |||
license='apache-2.0') | |||
# TODO: support proper name duplication checking | |||
except KeyError as ke: | |||
if ke.args[0] == 'name': | |||
print(f'model {self.model_name} already exists, ignore') | |||
else: | |||
raise | |||
# Note that this can be done via git operation once model repo | |||
# has been created. Git-Op is the RECOMMENDED model upload approach | |||
@unittest.skip('to be used for local test only') | |||
def test_model_upload(self): | |||
local_path = '/path/to/local/model/directory' | |||
assert osp.exists(local_path), 'Local model directory not exist.' | |||
repo = Repository(local_dir=local_path) | |||
repo.push_to_hub(commit_message='Upload model files') | |||
if __name__ == '__main__': | |||
unittest.main() |