You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image_instance_segmentation.py 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import unittest
  4. from modelscope.hub.snapshot_download import snapshot_download
  5. from modelscope.models import Model
  6. from modelscope.models.cv.image_instance_segmentation.model import \
  7. CascadeMaskRCNNSwinModel
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.pipelines import ImageInstanceSegmentationPipeline, pipeline
  10. from modelscope.preprocessors import build_preprocessor
  11. from modelscope.utils.config import Config
  12. from modelscope.utils.constant import Fields, ModelFile, Tasks
  13. from modelscope.utils.test_utils import test_level
  14. class ImageInstanceSegmentationTest(unittest.TestCase):
  15. model_id = 'damo/cv_swin-b_image-instance-segmentation_coco'
  16. image = 'data/test/images/image_instance_segmentation.jpg'
  17. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  18. def test_run_with_model_from_modelhub(self):
  19. model = Model.from_pretrained(self.model_id)
  20. config_path = os.path.join(model.model_dir, ModelFile.CONFIGURATION)
  21. cfg = Config.from_file(config_path)
  22. preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
  23. pipeline_ins = pipeline(
  24. task=Tasks.image_segmentation,
  25. model=model,
  26. preprocessor=preprocessor)
  27. print(pipeline_ins(input=self.image)[OutputKeys.LABELS])
  28. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  29. def test_run_with_model_name(self):
  30. pipeline_ins = pipeline(
  31. task=Tasks.image_segmentation, model=self.model_id)
  32. print(pipeline_ins(input=self.image)[OutputKeys.LABELS])
  33. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  34. def test_run_with_default_model(self):
  35. pipeline_ins = pipeline(task=Tasks.image_segmentation)
  36. print(pipeline_ins(input=self.image)[OutputKeys.LABELS])
  37. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  38. def test_run_by_direct_model_download(self):
  39. cache_path = snapshot_download(self.model_id)
  40. config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
  41. cfg = Config.from_file(config_path)
  42. preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv)
  43. model = CascadeMaskRCNNSwinModel(cache_path)
  44. pipeline1 = ImageInstanceSegmentationPipeline(
  45. model, preprocessor=preprocessor)
  46. pipeline2 = pipeline(
  47. Tasks.image_segmentation, model=model, preprocessor=preprocessor)
  48. print(f'pipeline1:{pipeline1(input=self.image)[OutputKeys.LABELS]}')
  49. print(f'pipeline2: {pipeline2(input=self.image)[OutputKeys.LABELS]}')
  50. if __name__ == '__main__':
  51. unittest.main()