Browse Source

[to #42322933] add multi-style cartoon models to ut

1. 卡通化接入多风格模型(原始日漫风、3D、手绘风、素描风、艺术特效风格),添加ut接入测试
2. 修改pipeline中模型文件名称至通用名
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10153717
master
myf272609 yingda.chen 3 years ago
parent
commit
0041ab0ab8
2 changed files with 30 additions and 4 deletions
  1. +2
    -4
      modelscope/pipelines/cv/image_cartoon_pipeline.py
  2. +28
    -0
      tests/pipelines/test_person_image_cartoon.py

+ 2
- 4
modelscope/pipelines/cv/image_cartoon_pipeline.py View File

@@ -40,11 +40,9 @@ class ImageCartoonPipeline(Pipeline):
with device_placement(self.framework, self.device_name):
self.facer = FaceAna(self.model)
self.sess_anime_head = self.load_sess(
os.path.join(self.model, 'cartoon_anime_h.pb'),
'model_anime_head')
os.path.join(self.model, 'cartoon_h.pb'), 'model_anime_head')
self.sess_anime_bg = self.load_sess(
os.path.join(self.model, 'cartoon_anime_bg.pb'),
'model_anime_bg')
os.path.join(self.model, 'cartoon_bg.pb'), 'model_anime_bg')

self.box_width = 288
global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg'))


+ 28
- 0
tests/pipelines/test_person_image_cartoon.py View File

@@ -16,6 +16,10 @@ class ImageCartoonTest(unittest.TestCase, DemoCompatibilityCheck):

def setUp(self) -> None:
self.model_id = 'damo/cv_unet_person-image-cartoon_compound-models'
self.model_id_3d = 'damo/cv_unet_person-image-cartoon-3d_compound-models'
self.model_id_handdrawn = 'damo/cv_unet_person-image-cartoon-handdrawn_compound-models'
self.model_id_sketch = 'damo/cv_unet_person-image-cartoon-sketch_compound-models'
self.model_id_artstyle = 'damo/cv_unet_person-image-cartoon-artstyle_compound-models'
self.task = Tasks.image_portrait_stylization
self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png'

@@ -31,6 +35,30 @@ class ImageCartoonTest(unittest.TestCase, DemoCompatibilityCheck):
Tasks.image_portrait_stylization, model=self.model_id)
self.pipeline_inference(img_cartoon, self.test_image)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub_3d(self):
img_cartoon = pipeline(
Tasks.image_portrait_stylization, model=self.model_id_3d)
self.pipeline_inference(img_cartoon, self.test_image)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub_handdrawn(self):
img_cartoon = pipeline(
Tasks.image_portrait_stylization, model=self.model_id_handdrawn)
self.pipeline_inference(img_cartoon, self.test_image)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub_sketch(self):
img_cartoon = pipeline(
Tasks.image_portrait_stylization, model=self.model_id_sketch)
self.pipeline_inference(img_cartoon, self.test_image)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub_artstyle(self):
img_cartoon = pipeline(
Tasks.image_portrait_stylization, model=self.model_id_artstyle)
self.pipeline_inference(img_cartoon, self.test_image)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
img_cartoon = pipeline(Tasks.image_portrait_stylization)


Loading…
Cancel
Save