You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_card_detection_scrfd_trainer.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import glob
  3. import os
  4. import shutil
  5. import tempfile
  6. import unittest
  7. import torch
  8. from modelscope.hub.snapshot_download import snapshot_download
  9. from modelscope.metainfo import Trainers
  10. from modelscope.msdatasets import MsDataset
  11. from modelscope.trainers import build_trainer
  12. from modelscope.utils.config import Config
  13. from modelscope.utils.constant import ModelFile
  14. from modelscope.utils.test_utils import DistributedTestCase, test_level
  15. def _setup():
  16. model_id = 'damo/cv_resnet_carddetection_scrfd34gkps'
  17. # mini dataset only for unit test, remove '_mini' for full dataset.
  18. ms_ds_syncards = MsDataset.load(
  19. 'SyntheticCards_mini', namespace='shaoxuan')
  20. data_path = ms_ds_syncards.config_kwargs['split_config']
  21. train_dir = data_path['train']
  22. val_dir = data_path['validation']
  23. train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/'
  24. val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/'
  25. max_epochs = 1 # run epochs in unit test
  26. cache_path = snapshot_download(model_id)
  27. tmp_dir = tempfile.TemporaryDirectory().name
  28. if not os.path.exists(tmp_dir):
  29. os.makedirs(tmp_dir)
  30. return train_root, val_root, max_epochs, cache_path, tmp_dir
  31. def train_func(**kwargs):
  32. trainer = build_trainer(
  33. name=Trainers.card_detection_scrfd, default_args=kwargs)
  34. trainer.train()
  35. class TestCardDetectionScrfdTrainerSingleGPU(unittest.TestCase):
  36. def setUp(self):
  37. print(('SingleGPU Testing %s.%s' %
  38. (type(self).__name__, self._testMethodName)))
  39. self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup(
  40. )
  41. def tearDown(self):
  42. shutil.rmtree(self.tmp_dir)
  43. super().tearDown()
  44. def _cfg_modify_fn(self, cfg):
  45. cfg.checkpoint_config.interval = 1
  46. cfg.log_config.interval = 10
  47. cfg.evaluation.interval = 1
  48. cfg.data.workers_per_gpu = 3
  49. cfg.data.samples_per_gpu = 4 # batch size
  50. return cfg
  51. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  52. def test_trainer_from_scratch(self):
  53. kwargs = dict(
  54. cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'),
  55. work_dir=self.tmp_dir,
  56. train_root=self.train_root,
  57. val_root=self.val_root,
  58. total_epochs=self.max_epochs,
  59. cfg_modify_fn=self._cfg_modify_fn)
  60. trainer = build_trainer(
  61. name=Trainers.card_detection_scrfd, default_args=kwargs)
  62. trainer.train()
  63. results_files = os.listdir(self.tmp_dir)
  64. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  65. for i in range(self.max_epochs):
  66. self.assertIn(f'epoch_{i+1}.pth', results_files)
  67. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  68. def test_trainer_finetune(self):
  69. pretrain_epoch = 640
  70. self.max_epochs += pretrain_epoch
  71. kwargs = dict(
  72. cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'),
  73. work_dir=self.tmp_dir,
  74. train_root=self.train_root,
  75. val_root=self.val_root,
  76. total_epochs=self.max_epochs,
  77. resume_from=os.path.join(self.cache_path,
  78. ModelFile.TORCH_MODEL_BIN_FILE),
  79. cfg_modify_fn=self._cfg_modify_fn)
  80. trainer = build_trainer(
  81. name=Trainers.card_detection_scrfd, default_args=kwargs)
  82. trainer.train()
  83. results_files = os.listdir(self.tmp_dir)
  84. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  85. for i in range(pretrain_epoch, self.max_epochs):
  86. self.assertIn(f'epoch_{i+1}.pth', results_files)
  87. @unittest.skipIf(not torch.cuda.is_available()
  88. or torch.cuda.device_count() <= 1, 'distributed unittest')
  89. class TestCardDetectionScrfdTrainerMultiGpus(DistributedTestCase):
  90. def setUp(self):
  91. print(('MultiGPUs Testing %s.%s' %
  92. (type(self).__name__, self._testMethodName)))
  93. self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup(
  94. )
  95. cfg_file_path = os.path.join(self.cache_path, 'mmcv_scrfd.py')
  96. cfg = Config.from_file(cfg_file_path)
  97. cfg.checkpoint_config.interval = 1
  98. cfg.log_config.interval = 10
  99. cfg.evaluation.interval = 1
  100. cfg.data.workers_per_gpu = 3
  101. cfg.data.samples_per_gpu = 4
  102. cfg.dump(cfg_file_path)
  103. def tearDown(self):
  104. shutil.rmtree(self.tmp_dir)
  105. super().tearDown()
  106. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  107. def test_multi_gpus_finetune(self):
  108. pretrain_epoch = 640
  109. self.max_epochs += pretrain_epoch
  110. kwargs = dict(
  111. cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'),
  112. work_dir=self.tmp_dir,
  113. train_root=self.train_root,
  114. val_root=self.val_root,
  115. total_epochs=self.max_epochs,
  116. resume_from=os.path.join(self.cache_path,
  117. ModelFile.TORCH_MODEL_BIN_FILE),
  118. launcher='pytorch')
  119. self.start(train_func, num_gpus=2, **kwargs)
  120. results_files = os.listdir(self.tmp_dir)
  121. json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
  122. self.assertEqual(len(json_files), 1)
  123. for i in range(pretrain_epoch, self.max_epochs):
  124. self.assertIn(f'epoch_{i+1}.pth', results_files)
# Allow running this test module directly via `python <file>`.
if __name__ == '__main__':
    unittest.main()