You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image_instance_segmentation_trainer.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. import zipfile
  7. from functools import partial
  8. from modelscope.hub.snapshot_download import snapshot_download
  9. from modelscope.models.cv.image_instance_segmentation import \
  10. CascadeMaskRCNNSwinModel
  11. from modelscope.models.cv.image_instance_segmentation.datasets import \
  12. ImageInstanceSegmentationCocoDataset
  13. from modelscope.trainers import build_trainer
  14. from modelscope.utils.config import Config
  15. from modelscope.utils.constant import ModelFile
  16. from modelscope.utils.test_utils import test_level
  17. class TestImageInstanceSegmentationTrainer(unittest.TestCase):
  18. model_id = 'damo/cv_swin-b_image-instance-segmentation_coco'
  19. def setUp(self):
  20. print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
  21. cache_path = snapshot_download(self.model_id)
  22. config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
  23. cfg = Config.from_file(config_path)
  24. data_root = cfg.dataset.data_root
  25. classes = tuple(cfg.dataset.classes)
  26. max_epochs = cfg.train.max_epochs
  27. samples_per_gpu = cfg.train.dataloader.batch_size_per_gpu
  28. if data_root is None:
  29. # use default toy data
  30. dataset_path = os.path.join(cache_path, 'toydata.zip')
  31. with zipfile.ZipFile(dataset_path, 'r') as zipf:
  32. zipf.extractall(cache_path)
  33. data_root = cache_path + '/toydata/'
  34. classes = ('Cat', 'Dog')
  35. self.train_dataset = ImageInstanceSegmentationCocoDataset(
  36. data_root + 'annotations/instances_train.json',
  37. classes=classes,
  38. data_root=data_root,
  39. img_prefix=data_root + 'images/train/',
  40. seg_prefix=None,
  41. test_mode=False)
  42. self.eval_dataset = ImageInstanceSegmentationCocoDataset(
  43. data_root + 'annotations/instances_val.json',
  44. classes=classes,
  45. data_root=data_root,
  46. img_prefix=data_root + 'images/val/',
  47. seg_prefix=None,
  48. test_mode=True)
  49. from mmcv.parallel import collate
  50. self.collate_fn = partial(collate, samples_per_gpu=samples_per_gpu)
  51. self.max_epochs = max_epochs
  52. self.tmp_dir = tempfile.TemporaryDirectory().name
  53. if not os.path.exists(self.tmp_dir):
  54. os.makedirs(self.tmp_dir)
  55. def tearDown(self):
  56. shutil.rmtree(self.tmp_dir)
  57. super().tearDown()
  58. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  59. def test_trainer(self):
  60. kwargs = dict(
  61. model=self.model_id,
  62. data_collator=self.collate_fn,
  63. train_dataset=self.train_dataset,
  64. eval_dataset=self.eval_dataset,
  65. work_dir=self.tmp_dir)
  66. trainer = build_trainer(
  67. name='image-instance-segmentation', default_args=kwargs)
  68. trainer.train()
  69. results_files = os.listdir(self.tmp_dir)
  70. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  71. for i in range(self.max_epochs):
  72. self.assertIn(f'epoch_{i+1}.pth', results_files)
  73. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  74. def test_trainer_with_model_and_args(self):
  75. tmp_dir = tempfile.TemporaryDirectory().name
  76. if not os.path.exists(tmp_dir):
  77. os.makedirs(tmp_dir)
  78. cache_path = snapshot_download(self.model_id)
  79. model = CascadeMaskRCNNSwinModel.from_pretrained(cache_path)
  80. kwargs = dict(
  81. cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
  82. model=model,
  83. data_collator=self.collate_fn,
  84. train_dataset=self.train_dataset,
  85. eval_dataset=self.eval_dataset,
  86. work_dir=self.tmp_dir)
  87. trainer = build_trainer(
  88. name='image-instance-segmentation', default_args=kwargs)
  89. trainer.train()
  90. results_files = os.listdir(self.tmp_dir)
  91. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  92. for i in range(self.max_epochs):
  93. self.assertIn(f'epoch_{i+1}.pth', results_files)
  94. if __name__ == '__main__':
  95. unittest.main()