You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image_instance_segmentation.py 3.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import unittest
  4. from modelscope.hub.snapshot_download import snapshot_download
  5. from modelscope.models import Model
  6. from modelscope.models.cv.image_instance_segmentation import \
  7. CascadeMaskRCNNSwinModel
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.pipelines import pipeline
  10. from modelscope.pipelines.cv import ImageInstanceSegmentationPipeline
  11. from modelscope.preprocessors import build_preprocessor
  12. from modelscope.utils.config import Config
  13. from modelscope.utils.constant import Fields, ModelFile, Tasks
  14. from modelscope.utils.demo_utils import DemoCompatibilityCheck
  15. from modelscope.utils.test_utils import test_level
  16. class ImageInstanceSegmentationTest(unittest.TestCase, DemoCompatibilityCheck):
  17. def setUp(self) -> None:
  18. self.task = Tasks.image_segmentation
  19. self.model_id = 'damo/cv_swin-b_image-instance-segmentation_coco'
  20. image = 'data/test/images/image_instance_segmentation.jpg'
  21. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  22. def test_run_with_model_from_modelhub(self):
  23. model = Model.from_pretrained(self.model_id)
  24. config_path = os.path.join(model.model_dir, ModelFile.CONFIGURATION)
  25. cfg = Config.from_file(config_path)
  26. preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
  27. pipeline_ins = pipeline(
  28. task=Tasks.image_segmentation,
  29. model=model,
  30. preprocessor=preprocessor)
  31. print(pipeline_ins(input=self.image)[OutputKeys.LABELS])
  32. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  33. def test_run_with_model_name(self):
  34. pipeline_ins = pipeline(
  35. task=Tasks.image_segmentation, model=self.model_id)
  36. print(pipeline_ins(input=self.image)[OutputKeys.LABELS])
  37. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  38. def test_run_with_default_model(self):
  39. pipeline_ins = pipeline(task=Tasks.image_segmentation)
  40. print(pipeline_ins(input=self.image)[OutputKeys.LABELS])
  41. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  42. def test_run_by_direct_model_download(self):
  43. cache_path = snapshot_download(self.model_id)
  44. config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
  45. cfg = Config.from_file(config_path)
  46. preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
  47. model = CascadeMaskRCNNSwinModel(cache_path)
  48. pipeline1 = ImageInstanceSegmentationPipeline(
  49. model, preprocessor=preprocessor)
  50. pipeline2 = pipeline(
  51. Tasks.image_segmentation, model=model, preprocessor=preprocessor)
  52. print(f'pipeline1:{pipeline1(input=self.image)[OutputKeys.LABELS]}')
  53. print(f'pipeline2: {pipeline2(input=self.image)[OutputKeys.LABELS]}')
  54. @unittest.skip('demo compatibility test is only enabled on a needed-basis')
  55. def test_demo_compatibility(self):
  56. self.compatibility_check()
  57. if __name__ == '__main__':
  58. unittest.main()