diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 06e614e6..6d0ec729 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -23,7 +23,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), Tasks.image_captioning: ('ofa', None), Tasks.image_generation: - ('cv_unet_person-image-cartoon', 'damo/cv_unet_image-matting_damo'), + ('person-image-cartoon', + 'damo/cv_unet_person-image-cartoon_compound-models'), } diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py index 6a6c10e0..d253eaf5 100644 --- a/modelscope/pipelines/cv/image_cartoon_pipeline.py +++ b/modelscope/pipelines/cv/image_cartoon_pipeline.py @@ -25,20 +25,19 @@ logger = get_logger() @PIPELINES.register_module( - Tasks.image_generation, module_name='cv_unet_person-image-cartoon') + Tasks.image_generation, module_name='person-image-cartoon') class ImageCartoonPipeline(Pipeline): def __init__(self, model: str): super().__init__(model=model) - - self.facer = FaceAna(model) + self.facer = FaceAna(self.model) self.sess_anime_head = self.load_sess( - os.path.join(model, 'cartoon_anime_h.pb'), 'model_anime_head') + os.path.join(self.model, 'cartoon_anime_h.pb'), 'model_anime_head') self.sess_anime_bg = self.load_sess( - os.path.join(model, 'cartoon_anime_bg.pb'), 'model_anime_bg') + os.path.join(self.model, 'cartoon_anime_bg.pb'), 'model_anime_bg') self.box_width = 288 - global_mask = cv2.imread(os.path.join(model, 'alpha.jpg')) + global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg')) global_mask = cv2.resize( global_mask, (self.box_width, self.box_width), interpolation=cv2.INTER_AREA) diff --git a/tests/pipelines/test_person_image_cartoon.py b/tests/pipelines/test_person_image_cartoon.py index 817593f1..6f352e42 100644 --- a/tests/pipelines/test_person_image_cartoon.py +++ b/tests/pipelines/test_person_image_cartoon.py @@ -1,26 +1,31 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +import os.path as osp import unittest import cv2 from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline from modelscope.utils.constant import Tasks -def all_file(file_dir): - L = [] - for root, dirs, files in os.walk(file_dir): - for file in files: - extend = os.path.splitext(file)[1] - if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG' or extend == '.HEIC': - L.append(os.path.join(root, file)) - return L +class ImageCartoonTest(unittest.TestCase): + def setUp(self) -> None: + self.model_id = 'damo/cv_unet_person-image-cartoon_compound-models' + self.test_image = \ + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \ + '/data/test/maas/image_carton/test.png' -class ImageCartoonTest(unittest.TestCase): + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + if result is not None: + cv2.imwrite('result.png', result['output_png']) + print(f'Output written to {osp.abspath("result.png")}') - def test_run(self): + @unittest.skip('deprecated, download model from model hub instead') + def test_run_by_direct_model_download(self): model_dir = './assets' if not os.path.exists(model_dir): os.system( @@ -29,9 +34,15 @@ class ImageCartoonTest(unittest.TestCase): os.system('unzip assets.zip') img_cartoon = pipeline(Tasks.image_generation, model=model_dir) - result = img_cartoon(os.path.join(model_dir, 'test.png')) - if result is not None: - cv2.imwrite('result.png', result['output_png']) + self.pipeline_inference(img_cartoon, self.test_image) + + def test_run_modelhub(self): + img_cartoon = pipeline(Tasks.image_generation, model=self.model_id) + self.pipeline_inference(img_cartoon, self.test_image) + + def test_run_modelhub_default_model(self): + img_cartoon = pipeline(Tasks.image_generation) + self.pipeline_inference(img_cartoon, self.test_image) if __name__ == '__main__': diff --git a/tests/utils/test_hub_operation.py b/tests/utils/test_hub_operation.py new file mode 100644 index 00000000..f432a60c --- /dev/null +++ b/tests/utils/test_hub_operation.py @@ -0,0 +1,50 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +from maas_hub.maas_api import MaasApi +from maas_hub.repository import Repository + +USER_NAME = 'maasadmin' +PASSWORD = '12345678' + + +class HubOperationTest(unittest.TestCase): + + def setUp(self): + self.api = MaasApi() + # note this is temporary before official account management is ready + self.api.login(USER_NAME, PASSWORD) + + @unittest.skip('to be used for local test only') + def test_model_repo_creation(self): + # change to proper model names before use + model_name = 'cv_unet_person-image-cartoon_compound-models' + model_chinese_name = '达摩卡通化模型' + model_org = 'damo' + try: + self.api.create_model( + owner=model_org, + name=model_name, + chinese_name=model_chinese_name, + visibility=5, # 1-private, 5-public + license='apache-2.0') + # TODO: support proper name duplication checking + except KeyError as ke: + if ke.args[0] == 'name': + print(f'model {self.model_name} already exists, ignore') + else: + raise + + # Note that this can be done via git operation once model repo + # has been created. Git-Op is the RECOMMENDED model upload approach + @unittest.skip('to be used for local test only') + def test_model_upload(self): + local_path = '/path/to/local/model/directory' + assert osp.exists(local_path), 'Local model directory not exist.' + repo = Repository(local_dir=local_path) + repo.push_to_hub(commit_message='Upload model files') + + +if __name__ == '__main__': + unittest.main()