You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_segmentation_pipeline.py 3.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from distutils.version import LooseVersion
  4. import cv2
  5. import easycv
  6. import numpy as np
  7. from PIL import Image
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.pipelines import pipeline
  10. from modelscope.utils.constant import Tasks
  11. from modelscope.utils.cv.image_utils import semantic_seg_masks_to_image
  12. from modelscope.utils.demo_utils import DemoCompatibilityCheck
  13. from modelscope.utils.test_utils import test_level
  14. class EasyCVSegmentationPipelineTest(unittest.TestCase,
  15. DemoCompatibilityCheck):
  16. img_path = 'data/test/images/image_segmentation.jpg'
  17. def setUp(self) -> None:
  18. self.task = Tasks.image_segmentation
  19. self.model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k'
  20. def _internal_test_(self, model_id):
  21. semantic_seg = pipeline(task=Tasks.image_segmentation, model=model_id)
  22. outputs = semantic_seg(self.img_path)
  23. draw_img = semantic_seg_masks_to_image(outputs[OutputKeys.MASKS])
  24. cv2.imwrite('result.jpg', draw_img)
  25. print('test ' + model_id + ' DONE')
  26. def _internal_test_batch_(self, model_id, num_samples=2, batch_size=2):
  27. # TODO: support in the future
  28. img = np.asarray(Image.open(self.img_path))
  29. num_samples = num_samples
  30. batch_size = batch_size
  31. semantic_seg = pipeline(
  32. task=Tasks.image_segmentation,
  33. model=model_id,
  34. batch_size=batch_size)
  35. outputs = semantic_seg([self.img_path] * num_samples)
  36. self.assertEqual(semantic_seg.predict_op.batch_size, batch_size)
  37. self.assertEqual(len(outputs), num_samples)
  38. for output in outputs:
  39. self.assertListEqual(
  40. list(img.shape)[:2], list(output['seg_pred'].shape))
  41. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  42. def test_segformer_b0(self):
  43. model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k'
  44. self._internal_test_(model_id)
  45. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  46. def test_segformer_b1(self):
  47. model_id = 'damo/cv_segformer-b1_image_semantic-segmentation_coco-stuff164k'
  48. self._internal_test_(model_id)
  49. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  50. def test_segformer_b2(self):
  51. model_id = 'damo/cv_segformer-b2_image_semantic-segmentation_coco-stuff164k'
  52. self._internal_test_(model_id)
  53. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  54. def test_segformer_b3(self):
  55. model_id = 'damo/cv_segformer-b3_image_semantic-segmentation_coco-stuff164k'
  56. self._internal_test_(model_id)
  57. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  58. def test_segformer_b4(self):
  59. model_id = 'damo/cv_segformer-b4_image_semantic-segmentation_coco-stuff164k'
  60. self._internal_test_(model_id)
  61. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  62. def test_segformer_b5(self):
  63. model_id = 'damo/cv_segformer-b5_image_semantic-segmentation_coco-stuff164k'
  64. self._internal_test_(model_id)
  65. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  66. def test_demo_compatibility(self):
  67. self.compatibility_check()
  68. if __name__ == '__main__':
  69. unittest.main()