Browse Source

[to #42322933] refine cartoon model and add model op utitlity

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8993758
master
yingda.chen 3 years ago
parent
commit
3c1ec035fd
4 changed files with 81 additions and 20 deletions
  1. +2
    -1
      modelscope/pipelines/builder.py
  2. +5
    -6
      modelscope/pipelines/cv/image_cartoon_pipeline.py
  3. +24
    -13
      tests/pipelines/test_person_image_cartoon.py
  4. +50
    -0
      tests/utils/test_hub_operation.py

+ 2
- 1
modelscope/pipelines/builder.py View File

@@ -23,7 +23,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),
Tasks.image_captioning: ('ofa', None), Tasks.image_captioning: ('ofa', None),
Tasks.image_generation: Tasks.image_generation:
('cv_unet_person-image-cartoon', 'damo/cv_unet_image-matting_damo'),
('person-image-cartoon',
'damo/cv_unet_person-image-cartoon_compound-models'),
} }






+ 5
- 6
modelscope/pipelines/cv/image_cartoon_pipeline.py View File

@@ -25,20 +25,19 @@ logger = get_logger()




@PIPELINES.register_module( @PIPELINES.register_module(
Tasks.image_generation, module_name='cv_unet_person-image-cartoon')
Tasks.image_generation, module_name='person-image-cartoon')
class ImageCartoonPipeline(Pipeline): class ImageCartoonPipeline(Pipeline):


def __init__(self, model: str): def __init__(self, model: str):
super().__init__(model=model) super().__init__(model=model)

self.facer = FaceAna(model)
self.facer = FaceAna(self.model)
self.sess_anime_head = self.load_sess( self.sess_anime_head = self.load_sess(
os.path.join(model, 'cartoon_anime_h.pb'), 'model_anime_head')
os.path.join(self.model, 'cartoon_anime_h.pb'), 'model_anime_head')
self.sess_anime_bg = self.load_sess( self.sess_anime_bg = self.load_sess(
os.path.join(model, 'cartoon_anime_bg.pb'), 'model_anime_bg')
os.path.join(self.model, 'cartoon_anime_bg.pb'), 'model_anime_bg')


self.box_width = 288 self.box_width = 288
global_mask = cv2.imread(os.path.join(model, 'alpha.jpg'))
global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg'))
global_mask = cv2.resize( global_mask = cv2.resize(
global_mask, (self.box_width, self.box_width), global_mask, (self.box_width, self.box_width),
interpolation=cv2.INTER_AREA) interpolation=cv2.INTER_AREA)


+ 24
- 13
tests/pipelines/test_person_image_cartoon.py View File

@@ -1,26 +1,31 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os import os
import os.path as osp
import unittest import unittest


import cv2 import cv2


from modelscope.pipelines import pipeline from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks




def all_file(file_dir):
L = []
for root, dirs, files in os.walk(file_dir):
for file in files:
extend = os.path.splitext(file)[1]
if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG' or extend == '.HEIC':
L.append(os.path.join(root, file))
return L
class ImageCartoonTest(unittest.TestCase):


def setUp(self) -> None:
self.model_id = 'damo/cv_unet_person-image-cartoon_compound-models'
self.test_image = \
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \
'/data/test/maas/image_carton/test.png'


class ImageCartoonTest(unittest.TestCase):
def pipeline_inference(self, pipeline: Pipeline, input_location: str):
result = pipeline(input_location)
if result is not None:
cv2.imwrite('result.png', result['output_png'])
print(f'Output written to {osp.abspath("result.png")}')


def test_run(self):
@unittest.skip('deprecated, download model from model hub instead')
def test_run_by_direct_model_download(self):
model_dir = './assets' model_dir = './assets'
if not os.path.exists(model_dir): if not os.path.exists(model_dir):
os.system( os.system(
@@ -29,9 +34,15 @@ class ImageCartoonTest(unittest.TestCase):
os.system('unzip assets.zip') os.system('unzip assets.zip')


img_cartoon = pipeline(Tasks.image_generation, model=model_dir) img_cartoon = pipeline(Tasks.image_generation, model=model_dir)
result = img_cartoon(os.path.join(model_dir, 'test.png'))
if result is not None:
cv2.imwrite('result.png', result['output_png'])
self.pipeline_inference(img_cartoon, self.test_image)

def test_run_modelhub(self):
img_cartoon = pipeline(Tasks.image_generation, model=self.model_id)
self.pipeline_inference(img_cartoon, self.test_image)

def test_run_modelhub_default_model(self):
img_cartoon = pipeline(Tasks.image_generation)
self.pipeline_inference(img_cartoon, self.test_image)




if __name__ == '__main__': if __name__ == '__main__':


+ 50
- 0
tests/utils/test_hub_operation.py View File

@@ -0,0 +1,50 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

from maas_hub.maas_api import MaasApi
from maas_hub.repository import Repository

USER_NAME = 'maasadmin'
PASSWORD = '12345678'


class HubOperationTest(unittest.TestCase):

def setUp(self):
self.api = MaasApi()
# note this is temporary before official account management is ready
self.api.login(USER_NAME, PASSWORD)

@unittest.skip('to be used for local test only')
def test_model_repo_creation(self):
# change to proper model names before use
model_name = 'cv_unet_person-image-cartoon_compound-models'
model_chinese_name = '达摩卡通化模型'
model_org = 'damo'
try:
self.api.create_model(
owner=model_org,
name=model_name,
chinese_name=model_chinese_name,
visibility=5, # 1-private, 5-public
license='apache-2.0')
# TODO: support proper name duplication checking
except KeyError as ke:
if ke.args[0] == 'name':
print(f'model {self.model_name} already exists, ignore')
else:
raise

# Note that this can be done via git operation once model repo
# has been created. Git-Op is the RECOMMENDED model upload approach
@unittest.skip('to be used for local test only')
def test_model_upload(self):
local_path = '/path/to/local/model/directory'
assert osp.exists(local_path), 'Local model directory not exist.'
repo = Repository(local_dir=local_path)
repo.push_to_hub(commit_message='Upload model files')


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save