
[to #42322933] feat: add hand keypoints pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9961906

    * feat: add hand keypoints pipeline
Branch: master
Authors: liangting.zl, yingda.chen · 3 years ago
Commit: 4484dcaa04
8 changed files with 121 additions and 0 deletions
  1. data/test/images/hand_keypoints.jpg (+3, -0)
  2. modelscope/metainfo.py (+1, -0)
  3. modelscope/outputs.py (+15, -0)
  4. modelscope/pipelines/builder.py (+3, -0)
  5. modelscope/pipelines/cv/__init__.py (+2, -0)
  6. modelscope/pipelines/cv/hand_2d_keypoints_pipeline.py (+51, -0)
  7. modelscope/utils/constant.py (+1, -0)
  8. tests/pipelines/test_hand_2d_keypoints.py (+45, -0)

data/test/images/hand_keypoints.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c05d58edee7398de37b8e479410676d6b97cfde69cc003e8356a348067e71988
size 7750

modelscope/metainfo.py (+1, -0)

@@ -112,6 +112,7 @@ class Pipelines(object):
     hicossl_video_embedding = 'hicossl-s3dg-video_embedding'
     body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
     body_3d_keypoints = 'canonical_body-3d-keypoints_video'
+    hand_2d_keypoints = 'hrnetv2w18_hand-2d-keypoints_image'
     human_detection = 'resnet18-human-detection'
     object_detection = 'vit-object-detection'
     easycv_detection = 'easycv-detection'


modelscope/outputs.py (+15, -0)

@@ -219,6 +219,21 @@ TASK_OUTPUTS = {
     #   }
     Tasks.body_3d_keypoints: [OutputKeys.POSES],

+    # 2D hand keypoints result for single sample
+    #   {
+    #       "keypoints": [
+    #           [[x, y, score] * 21],
+    #           [[x, y, score] * 21],
+    #           [[x, y, score] * 21],
+    #       ],
+    #       "boxes": [
+    #           [x1, y1, x2, y2],
+    #           [x1, y1, x2, y2],
+    #           [x1, y1, x2, y2],
+    #       ]
+    #   }
+    Tasks.hand_2d_keypoints: [OutputKeys.KEYPOINTS, OutputKeys.BOXES],
+
     # video single object tracking result for single video
     #   {
     #       "boxes": [


modelscope/pipelines/builder.py (+3, -0)

@@ -99,6 +99,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                            'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
     Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints,
                               'damo/cv_canonical_body-3d-keypoints_video'),
+    Tasks.hand_2d_keypoints:
+    (Pipelines.hand_2d_keypoints,
+     'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'),
     Tasks.face_detection: (Pipelines.face_detection,
                            'damo/cv_resnet_facedetection_scrfd10gkps'),
     Tasks.face_recognition: (Pipelines.face_recognition,
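
With this DEFAULT_MODEL_FOR_PIPELINE entry, constructing the pipeline without an explicit model id resolves to the damo model above, so the two calls below are equivalent (a sketch; the same pattern appears in the tests added in this commit):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints)  # default model
    hand_keypoint = pipeline(
        task=Tasks.hand_2d_keypoints,
        model='damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody')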


modelscope/pipelines/cv/__init__.py (+2, -0)

@@ -9,6 +9,7 @@ if TYPE_CHECKING:
     from .animal_recognition_pipeline import AnimalRecognitionPipeline
     from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
     from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline
+    from .hand_2d_keypoints_pipeline import Hand2DKeypointsPipeline
     from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
     from .hicossl_video_embedding_pipeline import HICOSSLVideoEmbeddingPipeline
     from .crowd_counting_pipeline import CrowdCountingPipeline

@@ -57,6 +58,7 @@ else:
         'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
         'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
         'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'],
+        'hand_2d_keypoints_pipeline': ['Hand2DKeypointsPipeline'],
         'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
         'hicossl_video_embedding_pipeline': ['HICOSSLVideoEmbeddingPipeline'],
         'crowd_counting_pipeline': ['CrowdCountingPipeline'],
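
Both additions point at the same class: the TYPE_CHECKING import exists for static type checkers, while the _import_structure entry drives the package's lazy module loading at runtime. Either way, the usual import path works (a quick sketch):

    from modelscope.pipelines.cv import Hand2DKeypointsPipeline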


modelscope/pipelines/cv/hand_2d_keypoints_pipeline.py (+51, -0)

@@ -0,0 +1,51 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path

from modelscope.metainfo import Pipelines
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import ModelFile, Tasks
from .easycv_pipelines.base import EasyCVPipeline


@PIPELINES.register_module(
    Tasks.hand_2d_keypoints, module_name=Pipelines.hand_2d_keypoints)
class Hand2DKeypointsPipeline(EasyCVPipeline):
    """Pipeline for the 2D hand keypoint detection task."""

    def __init__(self,
                 model: str,
                 model_file_pattern=ModelFile.TORCH_MODEL_FILE,
                 *args,
                 **kwargs):
        """
        Args:
            model (str): Model id on the ModelScope hub or a local model path.
            model_file_pattern (str): Model file pattern.
        """
        self.model_dir = model
        super(Hand2DKeypointsPipeline, self).__init__(
            model=model,
            model_file_pattern=model_file_pattern,
            *args,
            **kwargs)

    def _build_predict_op(self):
        """Build the EasyCV predictor."""
        from easycv.predictors.builder import build_predictor

        # Inject the hand detector settings from the model config into the
        # keypoint predictor config: the predictor first detects hand boxes,
        # then estimates keypoints inside them.
        detection_predictor_type = self.cfg['DETECTION']['type']
        detection_model_path = os.path.join(
            self.model_dir, self.cfg['DETECTION']['model_path'])
        detection_cfg_file = os.path.join(
            self.model_dir, self.cfg['DETECTION']['config_file'])
        detection_score_threshold = self.cfg['DETECTION']['score_threshold']
        self.cfg.pipeline.predictor_config[
            'detection_predictor_config'] = dict(
                type=detection_predictor_type,
                model_path=detection_model_path,
                config_file=detection_cfg_file,
                score_threshold=detection_score_threshold)

        easycv_config = self._to_easycv_config()
        pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, {
            'model_path': self.model_path,
            'config_file': easycv_config
        })
        return pipeline_op
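
_build_predict_op expects a DETECTION section in the model's configuration describing the hand detector that runs before keypoint estimation. The key names below are taken from the code above; the values are hypothetical placeholders, not part of this commit:

    "DETECTION": {
        "type": "DetectionPredictor",
        "model_path": "hand_detector/pytorch_model.pt",
        "config_file": "hand_detector/config.py",
        "score_threshold": 0.5
    }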

modelscope/utils/constant.py (+1, -0)

@@ -27,6 +27,7 @@ class CVTasks(object):
     face_image_generation = 'face-image-generation'
     body_2d_keypoints = 'body-2d-keypoints'
     body_3d_keypoints = 'body-3d-keypoints'
+    hand_2d_keypoints = 'hand-2d-keypoints'
     general_recognition = 'general-recognition'

     image_classification = 'image-classification'


tests/pipelines/test_hand_2d_keypoints.py (+45, -0)

@@ -0,0 +1,45 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class Hand2DKeypointsPipelineTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_hand_2d_keypoints(self):
        img_path = 'data/test/images/hand_keypoints.jpg'
        model_id = 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'

        hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints, model=model_id)
        outputs = hand_keypoint(img_path)
        self.assertEqual(len(outputs), 1)

        results = outputs[0]
        self.assertIn(OutputKeys.KEYPOINTS, results.keys())
        self.assertIn(OutputKeys.BOXES, results.keys())
        # Each detected hand has 21 keypoints, each an (x, y, score) triple.
        self.assertEqual(results[OutputKeys.KEYPOINTS].shape[1], 21)
        self.assertEqual(results[OutputKeys.KEYPOINTS].shape[2], 3)
        # Each hand box is (x1, y1, x2, y2).
        self.assertEqual(results[OutputKeys.BOXES].shape[1], 4)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_hand_2d_keypoints_with_default_model(self):
        img_path = 'data/test/images/hand_keypoints.jpg'

        hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints)
        outputs = hand_keypoint(img_path)
        self.assertEqual(len(outputs), 1)

        results = outputs[0]
        self.assertIn(OutputKeys.KEYPOINTS, results.keys())
        self.assertIn(OutputKeys.BOXES, results.keys())
        self.assertEqual(results[OutputKeys.KEYPOINTS].shape[1], 21)
        self.assertEqual(results[OutputKeys.KEYPOINTS].shape[2], 3)
        self.assertEqual(results[OutputKeys.BOXES].shape[1], 4)


if __name__ == '__main__':
    unittest.main()
