You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_hand_2d_keypoints.py 1.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.outputs import OutputKeys
  4. from modelscope.pipelines import pipeline
  5. from modelscope.utils.constant import Tasks
  6. from modelscope.utils.test_utils import test_level
  class Hand2DKeypointsPipelineTest(unittest.TestCase):
      """Tests for the ``Tasks.hand_2d_keypoints`` pipeline.

      Both tests run the pipeline on the same sample image and check that the
      output contains keypoints and boxes with the expected trailing dimensions.
      One test uses an explicit model id; the other uses the task's default model.
      """

      # Sample image shared by both tests (path relative to the repo root,
      # presumably; confirm against how the test suite is launched).
      img_path = 'data/test/images/hand_keypoints.jpg'

      def _check_results(self, results):
          """Assert that a pipeline result has the expected keys and shapes.

          Args:
              results: Output of the hand keypoint pipeline. It must expose
                  ``keys()`` and be indexable by ``OutputKeys`` values.
          """
          self.assertIn(OutputKeys.KEYPOINTS, results.keys())
          self.assertIn(OutputKeys.BOXES, results.keys())
          keypoints = results[OutputKeys.KEYPOINTS]
          boxes = results[OutputKeys.BOXES]
          # Dim 1 == 21 and dim 2 == 3: presumably 21 hand joints, each with
          # (x, y, score). TODO confirm against the model card.
          self.assertEqual(keypoints.shape[1], 21)
          self.assertEqual(keypoints.shape[2], 3)
          # Dim 1 == 4: presumably one 4-value box per detected hand.
          self.assertEqual(boxes.shape[1], 4)

      @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
      def test_hand_2d_keypoints(self):
          """Run the pipeline with an explicitly named model."""
          model_id = 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'
          hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints, model=model_id)
          self._check_results(hand_keypoint(self.img_path))

      @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
      def test_hand_2d_keypoints_with_default_model(self):
          """Run the pipeline with the task's default model (no ``model=``)."""
          hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints)
          self._check_results(hand_keypoint(self.img_path))
  # Allow running this test module directly: `python test_hand_2d_keypoints.py`.
  if __name__ == '__main__':
      unittest.main()