Browse Source

[to #42322933] refine cartoon model and add model op utitlity

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8993758
master
yingda.chen 3 years ago
parent
commit
3c1ec035fd
4 changed files with 81 additions and 20 deletions
  1. +2
    -1
      modelscope/pipelines/builder.py
  2. +5
    -6
      modelscope/pipelines/cv/image_cartoon_pipeline.py
  3. +24
    -13
      tests/pipelines/test_person_image_cartoon.py
  4. +50
    -0
      tests/utils/test_hub_operation.py

+ 2
- 1
modelscope/pipelines/builder.py View File

@@ -23,7 +23,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),
Tasks.image_captioning: ('ofa', None),
Tasks.image_generation:
('cv_unet_person-image-cartoon', 'damo/cv_unet_image-matting_damo'),
('person-image-cartoon',
'damo/cv_unet_person-image-cartoon_compound-models'),
}




+ 5
- 6
modelscope/pipelines/cv/image_cartoon_pipeline.py View File

@@ -25,20 +25,19 @@ logger = get_logger()


@PIPELINES.register_module(
Tasks.image_generation, module_name='cv_unet_person-image-cartoon')
Tasks.image_generation, module_name='person-image-cartoon')
class ImageCartoonPipeline(Pipeline):

def __init__(self, model: str):
super().__init__(model=model)

self.facer = FaceAna(model)
self.facer = FaceAna(self.model)
self.sess_anime_head = self.load_sess(
os.path.join(model, 'cartoon_anime_h.pb'), 'model_anime_head')
os.path.join(self.model, 'cartoon_anime_h.pb'), 'model_anime_head')
self.sess_anime_bg = self.load_sess(
os.path.join(model, 'cartoon_anime_bg.pb'), 'model_anime_bg')
os.path.join(self.model, 'cartoon_anime_bg.pb'), 'model_anime_bg')

self.box_width = 288
global_mask = cv2.imread(os.path.join(model, 'alpha.jpg'))
global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg'))
global_mask = cv2.resize(
global_mask, (self.box_width, self.box_width),
interpolation=cv2.INTER_AREA)


+ 24
- 13
tests/pipelines/test_person_image_cartoon.py View File

@@ -1,26 +1,31 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import unittest

import cv2

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks


def all_file(file_dir):
L = []
for root, dirs, files in os.walk(file_dir):
for file in files:
extend = os.path.splitext(file)[1]
if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG' or extend == '.HEIC':
L.append(os.path.join(root, file))
return L
class ImageCartoonTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_unet_person-image-cartoon_compound-models'
self.test_image = \
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \
'/data/test/maas/image_carton/test.png'

class ImageCartoonTest(unittest.TestCase):
def pipeline_inference(self, pipeline: Pipeline, input_location: str):
result = pipeline(input_location)
if result is not None:
cv2.imwrite('result.png', result['output_png'])
print(f'Output written to {osp.abspath("result.png")}')

def test_run(self):
@unittest.skip('deprecated, download model from model hub instead')
def test_run_by_direct_model_download(self):
model_dir = './assets'
if not os.path.exists(model_dir):
os.system(
@@ -29,9 +34,15 @@ class ImageCartoonTest(unittest.TestCase):
os.system('unzip assets.zip')

img_cartoon = pipeline(Tasks.image_generation, model=model_dir)
result = img_cartoon(os.path.join(model_dir, 'test.png'))
if result is not None:
cv2.imwrite('result.png', result['output_png'])
self.pipeline_inference(img_cartoon, self.test_image)

def test_run_modelhub(self):
img_cartoon = pipeline(Tasks.image_generation, model=self.model_id)
self.pipeline_inference(img_cartoon, self.test_image)

def test_run_modelhub_default_model(self):
img_cartoon = pipeline(Tasks.image_generation)
self.pipeline_inference(img_cartoon, self.test_image)


if __name__ == '__main__':


+ 50
- 0
tests/utils/test_hub_operation.py View File

@@ -0,0 +1,50 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

from maas_hub.maas_api import MaasApi
from maas_hub.repository import Repository

USER_NAME = 'maasadmin'
PASSWORD = '12345678'


class HubOperationTest(unittest.TestCase):

def setUp(self):
self.api = MaasApi()
# note this is temporary before official account management is ready
self.api.login(USER_NAME, PASSWORD)

@unittest.skip('to be used for local test only')
def test_model_repo_creation(self):
# change to proper model names before use
model_name = 'cv_unet_person-image-cartoon_compound-models'
model_chinese_name = '达摩卡通化模型'
model_org = 'damo'
try:
self.api.create_model(
owner=model_org,
name=model_name,
chinese_name=model_chinese_name,
visibility=5, # 1-private, 5-public
license='apache-2.0')
# TODO: support proper name duplication checking
except KeyError as ke:
if ke.args[0] == 'name':
print(f'model {self.model_name} already exists, ignore')
else:
raise

# Note that this can be done via git operation once model repo
# has been created. Git-Op is the RECOMMENDED model upload approach
@unittest.skip('to be used for local test only')
def test_model_upload(self):
local_path = '/path/to/local/model/directory'
assert osp.exists(local_path), 'Local model directory not exist.'
repo = Repository(local_dir=local_path)
repo.push_to_hub(commit_message='Upload model files')


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save