You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_easycv_trainer.py 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import glob
  3. import os
  4. import shutil
  5. import tempfile
  6. import unittest
  7. import json
  8. import torch
  9. from modelscope.metainfo import Models, Pipelines, Trainers
  10. from modelscope.msdatasets import MsDataset
  11. from modelscope.trainers import build_trainer
  12. from modelscope.utils.config import Config
  13. from modelscope.utils.constant import LogKeys, ModeKeys, Tasks
  14. from modelscope.utils.logger import get_logger
  15. from modelscope.utils.test_utils import DistributedTestCase, test_level
  16. from modelscope.utils.torch_utils import is_master
  17. def train_func(work_dir, dist=False, log_interval=3, imgs_per_gpu=4):
  18. import easycv
  19. config_path = os.path.join(
  20. os.path.dirname(easycv.__file__),
  21. 'configs/detection/yolox/yolox_s_8xb16_300e_coco.py')
  22. cfg = Config.from_file(config_path)
  23. cfg.log_config.update(
  24. dict(hooks=[
  25. dict(type='TextLoggerHook'),
  26. dict(type='TensorboardLoggerHook')
  27. ])) # not support TensorboardLoggerHookV2
  28. ms_cfg_file = os.path.join(work_dir, 'ms_yolox_s_8xb16_300e_coco.json')
  29. from easycv.utils.ms_utils import to_ms_config
  30. if is_master():
  31. to_ms_config(
  32. cfg,
  33. dump=True,
  34. task=Tasks.image_object_detection,
  35. ms_model_name=Models.yolox,
  36. pipeline_name=Pipelines.easycv_detection,
  37. save_path=ms_cfg_file)
  38. trainer_name = Trainers.easycv
  39. train_dataset = MsDataset.load(
  40. dataset_name='small_coco_for_test', namespace='EasyCV', split='train')
  41. eval_dataset = MsDataset.load(
  42. dataset_name='small_coco_for_test',
  43. namespace='EasyCV',
  44. split='validation')
  45. cfg_options = {
  46. 'train.max_epochs':
  47. 2,
  48. 'train.dataloader.batch_size_per_gpu':
  49. imgs_per_gpu,
  50. 'evaluation.dataloader.batch_size_per_gpu':
  51. 2,
  52. 'train.hooks': [
  53. {
  54. 'type': 'CheckpointHook',
  55. 'interval': 1
  56. },
  57. {
  58. 'type': 'EvaluationHook',
  59. 'interval': 1
  60. },
  61. {
  62. 'type': 'TextLoggerHook',
  63. 'ignore_rounding_keys': None,
  64. 'interval': log_interval
  65. },
  66. ]
  67. }
  68. kwargs = dict(
  69. cfg_file=ms_cfg_file,
  70. train_dataset=train_dataset,
  71. eval_dataset=eval_dataset,
  72. work_dir=work_dir,
  73. cfg_options=cfg_options,
  74. launcher='pytorch' if dist else None)
  75. trainer = build_trainer(trainer_name, kwargs)
  76. trainer.train()
  77. @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
  78. class EasyCVTrainerTestSingleGpu(unittest.TestCase):
  79. def setUp(self):
  80. self.logger = get_logger()
  81. self.logger.info(('Testing %s.%s' %
  82. (type(self).__name__, self._testMethodName)))
  83. self.tmp_dir = tempfile.TemporaryDirectory().name
  84. if not os.path.exists(self.tmp_dir):
  85. os.makedirs(self.tmp_dir)
  86. def tearDown(self):
  87. super().tearDown()
  88. shutil.rmtree(self.tmp_dir, ignore_errors=True)
  89. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  90. def test_single_gpu(self):
  91. train_func(self.tmp_dir)
  92. results_files = os.listdir(self.tmp_dir)
  93. json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
  94. self.assertEqual(len(json_files), 1)
  95. with open(json_files[0], 'r', encoding='utf-8') as f:
  96. lines = [i.strip() for i in f.readlines()]
  97. self.assertDictContainsSubset(
  98. {
  99. LogKeys.MODE: ModeKeys.TRAIN,
  100. LogKeys.EPOCH: 1,
  101. LogKeys.ITER: 3,
  102. LogKeys.LR: 0.00013
  103. }, json.loads(lines[0]))
  104. self.assertDictContainsSubset(
  105. {
  106. LogKeys.MODE: ModeKeys.EVAL,
  107. LogKeys.EPOCH: 1,
  108. LogKeys.ITER: 10
  109. }, json.loads(lines[1]))
  110. self.assertDictContainsSubset(
  111. {
  112. LogKeys.MODE: ModeKeys.TRAIN,
  113. LogKeys.EPOCH: 2,
  114. LogKeys.ITER: 3,
  115. LogKeys.LR: 0.00157
  116. }, json.loads(lines[2]))
  117. self.assertDictContainsSubset(
  118. {
  119. LogKeys.MODE: ModeKeys.EVAL,
  120. LogKeys.EPOCH: 2,
  121. LogKeys.ITER: 10
  122. }, json.loads(lines[3]))
  123. self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
  124. self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
  125. for i in [0, 2]:
  126. self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
  127. self.assertIn(LogKeys.ITER_TIME, lines[i])
  128. self.assertIn(LogKeys.MEMORY, lines[i])
  129. self.assertIn('total_loss', lines[i])
  130. for i in [1, 3]:
  131. self.assertIn(
  132. 'CocoDetectionEvaluator_DetectionBoxes_Precision/mAP',
  133. lines[i])
  134. self.assertIn('DetectionBoxes_Precision/mAP', lines[i])
  135. self.assertIn('DetectionBoxes_Precision/mAP@.50IOU', lines[i])
  136. self.assertIn('DetectionBoxes_Precision/mAP@.75IOU', lines[i])
  137. self.assertIn('DetectionBoxes_Precision/mAP (small)', lines[i])
  138. @unittest.skipIf(not torch.cuda.is_available()
  139. or torch.cuda.device_count() <= 1, 'distributed unittest')
  140. class EasyCVTrainerTestMultiGpus(DistributedTestCase):
  141. def setUp(self):
  142. self.logger = get_logger()
  143. self.logger.info(('Testing %s.%s' %
  144. (type(self).__name__, self._testMethodName)))
  145. self.tmp_dir = tempfile.TemporaryDirectory().name
  146. if not os.path.exists(self.tmp_dir):
  147. os.makedirs(self.tmp_dir)
  148. def tearDown(self):
  149. super().tearDown()
  150. shutil.rmtree(self.tmp_dir, ignore_errors=True)
  151. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  152. def test_multi_gpus(self):
  153. self.start(
  154. train_func,
  155. num_gpus=2,
  156. work_dir=self.tmp_dir,
  157. dist=True,
  158. log_interval=2,
  159. imgs_per_gpu=5)
  160. results_files = os.listdir(self.tmp_dir)
  161. json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
  162. self.assertEqual(len(json_files), 1)
  163. with open(json_files[0], 'r', encoding='utf-8') as f:
  164. lines = [i.strip() for i in f.readlines()]
  165. self.assertDictContainsSubset(
  166. {
  167. LogKeys.MODE: ModeKeys.TRAIN,
  168. LogKeys.EPOCH: 1,
  169. LogKeys.ITER: 2,
  170. LogKeys.LR: 0.0002
  171. }, json.loads(lines[0]))
  172. self.assertDictContainsSubset(
  173. {
  174. LogKeys.MODE: ModeKeys.EVAL,
  175. LogKeys.EPOCH: 1,
  176. LogKeys.ITER: 5
  177. }, json.loads(lines[1]))
  178. self.assertDictContainsSubset(
  179. {
  180. LogKeys.MODE: ModeKeys.TRAIN,
  181. LogKeys.EPOCH: 2,
  182. LogKeys.ITER: 2,
  183. LogKeys.LR: 0.0018
  184. }, json.loads(lines[2]))
  185. self.assertDictContainsSubset(
  186. {
  187. LogKeys.MODE: ModeKeys.EVAL,
  188. LogKeys.EPOCH: 2,
  189. LogKeys.ITER: 5
  190. }, json.loads(lines[3]))
  191. self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
  192. self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
  193. for i in [0, 2]:
  194. self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
  195. self.assertIn(LogKeys.ITER_TIME, lines[i])
  196. self.assertIn(LogKeys.MEMORY, lines[i])
  197. self.assertIn('total_loss', lines[i])
  198. for i in [1, 3]:
  199. self.assertIn(
  200. 'CocoDetectionEvaluator_DetectionBoxes_Precision/mAP',
  201. lines[i])
  202. self.assertIn('DetectionBoxes_Precision/mAP', lines[i])
  203. self.assertIn('DetectionBoxes_Precision/mAP@.50IOU', lines[i])
  204. self.assertIn('DetectionBoxes_Precision/mAP@.75IOU', lines[i])
  205. self.assertIn('DetectionBoxes_Precision/mAP (small)', lines[i])
  206. if __name__ == '__main__':
  207. unittest.main()