Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10208603 * update easycv pipelinesmaster
@@ -10,6 +10,7 @@ from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.pipelines.util import is_official_hub_path | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
from modelscope.utils.device import create_device | |||
class EasyCVPipeline(object): | |||
@@ -53,16 +54,19 @@ class EasyCVPipeline(object): | |||
), f'Not find "{ModelFile.CONFIGURATION}" in model directory!' | |||
self.cfg = Config.from_file(self.config_file) | |||
self.predict_op = self._build_predict_op() | |||
if 'device' in kwargs: | |||
kwargs['device'] = create_device(kwargs['device']) | |||
self.predict_op = self._build_predict_op(**kwargs) | |||
def _build_predict_op(self): | |||
def _build_predict_op(self, **kwargs): | |||
"""Build EasyCV predictor.""" | |||
from easycv.predictors.builder import build_predictor | |||
easycv_config = self._to_easycv_config() | |||
pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, { | |||
'model_path': self.model_path, | |||
'config_file': easycv_config | |||
'config_file': easycv_config, | |||
**kwargs | |||
}) | |||
return pipeline_op | |||
@@ -91,5 +95,4 @@ class EasyCVPipeline(object): | |||
return easycv_config | |||
def __call__(self, inputs) -> Any: | |||
# TODO: support image url | |||
return self.predict_op(inputs) |
@@ -4,7 +4,6 @@ from typing import Any | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from .base import EasyCVPipeline | |||
@@ -34,8 +33,11 @@ class Face2DKeypointsPipeline(EasyCVPipeline): | |||
return self.predict_op.show_result(img, points, scale, save_path) | |||
def __call__(self, inputs) -> Any: | |||
output = self.predict_op(inputs)[0][0] | |||
points = output['point'] | |||
poses = output['pose'] | |||
outputs = self.predict_op(inputs) | |||
return {OutputKeys.KEYPOINTS: points, OutputKeys.POSES: poses} | |||
results = [{ | |||
OutputKeys.KEYPOINTS: output['point'], | |||
OutputKeys.POSES: output['pose'] | |||
} for output in outputs] | |||
return results |
@@ -28,7 +28,7 @@ class Hand2DKeypointsPipeline(EasyCVPipeline): | |||
*args, | |||
**kwargs) | |||
def _build_predict_op(self): | |||
def _build_predict_op(self, **kwargs): | |||
"""Build EasyCV predictor.""" | |||
from easycv.predictors.builder import build_predictor | |||
detection_predictor_type = self.cfg['DETECTION']['type'] | |||
@@ -46,6 +46,7 @@ class Hand2DKeypointsPipeline(EasyCVPipeline): | |||
easycv_config = self._to_easycv_config() | |||
pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, { | |||
'model_path': self.model_path, | |||
'config_file': easycv_config | |||
'config_file': easycv_config, | |||
**kwargs | |||
}) | |||
return pipeline_op |
@@ -14,7 +14,7 @@ mmcls>=0.21.0 | |||
mmdet>=2.25.0 | |||
networkx>=2.5 | |||
onnxruntime>=1.10 | |||
pai-easycv>=0.6.0 | |||
pai-easycv>=0.6.3.4 | |||
pandas | |||
psutil | |||
regex | |||
@@ -1,10 +1,11 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from distutils.version import LooseVersion | |||
import easycv | |||
import numpy as np | |||
from PIL import Image | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
@@ -24,38 +25,60 @@ class EasyCVSegmentationPipelineTest(unittest.TestCase): | |||
results = outputs[0] | |||
self.assertListEqual( | |||
list(img.shape)[:2], list(results['seg_pred'][0].shape)) | |||
self.assertListEqual(results['seg_pred'][0][1, 4:10].tolist(), | |||
list(img.shape)[:2], list(results['seg_pred'].shape)) | |||
self.assertListEqual(results['seg_pred'][1, 4:10].tolist(), | |||
[161 for i in range(6)]) | |||
self.assertListEqual(results['seg_pred'][0][-1, -10:].tolist(), | |||
self.assertListEqual(results['seg_pred'][-1, -10:].tolist(), | |||
[133 for i in range(10)]) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def _internal_test_batch(self, model_id, num_samples=2, batch_size=2): | |||
# TODO: support in the future | |||
img = np.asarray(Image.open(self.img_path)) | |||
num_samples = num_samples | |||
batch_size = batch_size | |||
semantic_seg = pipeline( | |||
task=Tasks.image_segmentation, | |||
model=model_id, | |||
batch_size=batch_size) | |||
outputs = semantic_seg([self.img_path] * num_samples) | |||
self.assertEqual(semantic_seg.predict_op.batch_size, batch_size) | |||
self.assertEqual(len(outputs), num_samples) | |||
for output in outputs: | |||
self.assertListEqual( | |||
list(img.shape)[:2], list(output['seg_pred'].shape)) | |||
self.assertListEqual(output['seg_pred'][1, 4:10].tolist(), | |||
[161 for i in range(6)]) | |||
self.assertListEqual(output['seg_pred'][-1, -10:].tolist(), | |||
[133 for i in range(10)]) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_segformer_b0(self): | |||
model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k' | |||
self._internal_test__(model_id) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_segformer_b1(self): | |||
model_id = 'damo/cv_segformer-b1_image_semantic-segmentation_coco-stuff164k' | |||
self._internal_test__(model_id) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_segformer_b2(self): | |||
model_id = 'damo/cv_segformer-b2_image_semantic-segmentation_coco-stuff164k' | |||
self._internal_test__(model_id) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_segformer_b3(self): | |||
model_id = 'damo/cv_segformer-b3_image_semantic-segmentation_coco-stuff164k' | |||
self._internal_test__(model_id) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_segformer_b4(self): | |||
model_id = 'damo/cv_segformer-b4_image_semantic-segmentation_coco-stuff164k' | |||
self._internal_test__(model_id) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_segformer_b5(self): | |||
model_id = 'damo/cv_segformer-b5_image_semantic-segmentation_coco-stuff164k' | |||
self._internal_test__(model_id) | |||
@@ -18,7 +18,7 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): | |||
face_2d_keypoints_align = pipeline( | |||
task=Tasks.face_2d_keypoints, model=model_id) | |||
output = face_2d_keypoints_align(img_path) | |||
output = face_2d_keypoints_align(img_path)[0] | |||
output_keypoints = output[OutputKeys.KEYPOINTS] | |||
output_pose = output[OutputKeys.POSES] | |||
@@ -9,6 +9,7 @@ isolated: # test cases that may require excessive anmount of GPU memory, which | |||
- test_image_super_resolution.py | |||
- test_easycv_trainer.py | |||
- test_segformer.py | |||
- test_segmentation_pipeline.py | |||
envs: | |||
default: # default env, case not in other env will in default, pytorch. | |||