You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image_instance_segmentation.py 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import unittest
  4. from modelscope.hub.snapshot_download import snapshot_download
  5. from modelscope.models import Model
  6. from modelscope.models.cv.image_instance_segmentation import \
  7. CascadeMaskRCNNSwinModel
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.pipelines import pipeline
  10. from modelscope.pipelines.cv import ImageInstanceSegmentationPipeline
  11. from modelscope.preprocessors import build_preprocessor
  12. from modelscope.utils.config import Config
  13. from modelscope.utils.constant import Fields, ModelFile, Tasks
  14. from modelscope.utils.test_utils import test_level
  15. class ImageInstanceSegmentationTest(unittest.TestCase):
  16. model_id = 'damo/cv_swin-b_image-instance-segmentation_coco'
  17. image = 'data/test/images/image_instance_segmentation.jpg'
  18. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  19. def test_run_with_model_from_modelhub(self):
  20. model = Model.from_pretrained(self.model_id)
  21. config_path = os.path.join(model.model_dir, ModelFile.CONFIGURATION)
  22. cfg = Config.from_file(config_path)
  23. preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
  24. pipeline_ins = pipeline(
  25. task=Tasks.image_segmentation,
  26. model=model,
  27. preprocessor=preprocessor)
  28. print(pipeline_ins(input=self.image)[OutputKeys.LABELS])
  29. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  30. def test_run_with_model_name(self):
  31. pipeline_ins = pipeline(
  32. task=Tasks.image_segmentation, model=self.model_id)
  33. print(pipeline_ins(input=self.image)[OutputKeys.LABELS])
  34. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  35. def test_run_with_default_model(self):
  36. pipeline_ins = pipeline(task=Tasks.image_segmentation)
  37. print(pipeline_ins(input=self.image)[OutputKeys.LABELS])
  38. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  39. def test_run_by_direct_model_download(self):
  40. cache_path = snapshot_download(self.model_id)
  41. config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
  42. cfg = Config.from_file(config_path)
  43. preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
  44. model = CascadeMaskRCNNSwinModel(cache_path)
  45. pipeline1 = ImageInstanceSegmentationPipeline(
  46. model, preprocessor=preprocessor)
  47. pipeline2 = pipeline(
  48. Tasks.image_segmentation, model=model, preprocessor=preprocessor)
  49. print(f'pipeline1:{pipeline1(input=self.image)[OutputKeys.LABELS]}')
  50. print(f'pipeline2: {pipeline2(input=self.image)[OutputKeys.LABELS]}')
  51. if __name__ == '__main__':
  52. unittest.main()