
test_trainer.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import json
import os
import shutil
import tempfile
import unittest

import cv2
import numpy as np
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import IterableDataset

from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys
from modelscope.models.base import Model
from modelscope.trainers import build_trainer
from modelscope.trainers.base import DummyTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.trainer import EpochBasedTrainer
from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile, Tasks
from modelscope.utils.test_utils import create_dummy_test_dataset, test_level
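

# An IterableDataset stand-in: every epoch yields 500 copies of one random
# sample (a 5-dim float feature and an integer label in [0, 4)), enough to
# exercise the iteration-counting code paths in the tests below.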
class DummyIterableDataset(IterableDataset):

    def __iter__(self):
        feat = np.random.random(size=(5, )).astype(np.float32)
        labels = np.random.randint(0, 4, (1, ))
        iterations = [{'feat': feat, 'labels': labels}] * 500
        return iter(iterations)
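

# Two map-style dummy datasets built by the test helper, with samples shaped
# like the iterable dataset above; the trailing argument appears to be the
# sample count (20 for the small set, 40 for the big one).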
dummy_dataset_small = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)

dummy_dataset_big = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40)
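

# A toy model that is both an nn.Module and a ModelScope Model. Its "loss" is
# just the sum of the logits, so training learns nothing useful; the tests
# only care that forward/backward, logging and checkpointing all run.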
class DummyModel(nn.Module, Model):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 4)
        self.bn = nn.BatchNorm1d(4)

    def forward(self, feat, labels):
        x = self.linear(feat)
        x = self.bn(x)
        loss = torch.sum(x)
        return dict(logits=x, loss=loss)
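

# Registering under module_name='test_vis' lets build_trainer('test_vis', ...)
# construct this subclass. visualization() fills the trainer's visualization
# buffer, which is presumably flushed by the TensorboardHook during eval.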
@TRAINERS.register_module(module_name='test_vis')
class VisTrainer(EpochBasedTrainer):

    def visualization(self, results, dataset, **kwargs):
        num_image = 5
        f = 'data/test/images/bird.JPEG'
        filenames = [f for _ in range(num_image)]
        imgs = [cv2.imread(f) for f in filenames]
        filenames = [f + str(i) for i in range(num_image)]
        vis_results = {'images': imgs, 'filenames': filenames}
        # visualization results will be displayed in the group named 'eval_vis'
        self.visualization_buffer.output['eval_vis'] = vis_results
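

# Each test gets a fresh temporary work_dir; tearDown removes it so the
# checkpoint and log assertions never see files from a previous run.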
class TrainerTest(unittest.TestCase):

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)
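
    # Full-config run: optimizer, lr scheduler (with linear warmup) and all
    # hooks come from the JSON config written below. Asserts that a JSON log
    # file, one checkpoint per epoch, and tensorboard event files appear.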
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_train_0(self):
        json_cfg = {
            'task': Tasks.image_classification,
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'optimizer': {
                    'type': 'SGD',
                    'lr': 0.01,
                    'options': {
                        'grad_clip': {
                            'max_norm': 2.0
                        }
                    }
                },
                'lr_scheduler': {
                    'type': 'StepLR',
                    'step_size': 2,
                    'options': {
                        'warmup': {
                            'type': 'LinearWarmup',
                            'warmup_iters': 2
                        }
                    }
                },
                'hooks': [{
                    'type': 'CheckpointHook',
                    'interval': 1
                }, {
                    'type': 'TextLoggerHook',
                    'interval': 1
                }, {
                    'type': 'IterTimerHook'
                }, {
                    'type': 'EvaluationHook',
                    'interval': 1
                }, {
                    'type': 'TensorboardHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=DummyModel(),
            data_collator=None,
            train_dataset=dummy_dataset_small,
            eval_dataset=dummy_dataset_small,
            max_epochs=3,
            device='cpu')
        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
        self.assertIn('tensorboard_output', results_files)
        self.assertTrue(len(glob.glob(f'{self.tmp_dir}/*/*events*')) > 0)
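
    # Same config as test_train_0 plus an 'evaluation.visualization' section,
    # routed through the registered 'test_vis' trainer so that VisTrainer's
    # visualization() override is exercised during evaluation.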
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_train_visualization(self):
        json_cfg = {
            'task': Tasks.image_classification,
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'optimizer': {
                    'type': 'SGD',
                    'lr': 0.01,
                    'options': {
                        'grad_clip': {
                            'max_norm': 2.0
                        }
                    }
                },
                'lr_scheduler': {
                    'type': 'StepLR',
                    'step_size': 2,
                    'options': {
                        'warmup': {
                            'type': 'LinearWarmup',
                            'warmup_iters': 2
                        }
                    }
                },
                'hooks': [{
                    'type': 'CheckpointHook',
                    'interval': 1
                }, {
                    'type': 'TextLoggerHook',
                    'interval': 1
                }, {
                    'type': 'IterTimerHook'
                }, {
                    'type': 'EvaluationHook',
                    'interval': 1
                }, {
                    'type': 'TensorboardHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric],
                'visualization': {}
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        trainer_name = 'test_vis'
        kwargs = dict(
            cfg_file=config_path,
            model=DummyModel(),
            data_collator=None,
            train_dataset=dummy_dataset_small,
            eval_dataset=dummy_dataset_small,
            max_epochs=3,
            device='cpu')
        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
        self.assertTrue(len(glob.glob(f'{self.tmp_dir}/*/*events*')) > 0)
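
    # Here the optimizer and lr scheduler are constructed in code and passed
    # to the trainer via the optimizers=(optimizer, lr_scheduler) kwarg,
    # instead of being declared in the JSON config.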
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_train_1(self):
        json_cfg = {
            'task': Tasks.image_classification,
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'CheckpointHook',
                    'interval': 1
                }, {
                    'type': 'TextLoggerHook',
                    'interval': 1
                }, {
                    'type': 'IterTimerHook'
                }, {
                    'type': 'EvaluationHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = StepLR(optimizer, 2)
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            data_collator=None,
            train_dataset=dummy_dataset_small,
            eval_dataset=dummy_dataset_small,
            optimizers=(optimizer, lr_scheduler),
            max_epochs=3,
            device='cpu')
        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
        self.assertTrue(len(glob.glob(f'{self.tmp_dir}/*/*events*')) > 0)
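
    # Only the EvaluationHook is configured, so the remaining hooks come from
    # the trainer defaults. With 40 training samples and batch size 2 there
    # are 20 iters per epoch, and the default text logger (interval 10,
    # judging by the asserted ITER values) writes two train lines per epoch;
    # StepLR(step_size=2) drops the lr from 0.01 to 0.001 at epoch 3, which
    # is exactly what the per-line assertions below verify.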
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_train_with_default_config(self):
        json_cfg = {
            'task': Tasks.image_classification,
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'EvaluationHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = StepLR(optimizer, 2)
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            data_collator=None,
            train_dataset=dummy_dataset_big,
            eval_dataset=dummy_dataset_small,
            optimizers=(optimizer, lr_scheduler),
            max_epochs=3,
            device='cpu')
        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        json_file = os.path.join(self.tmp_dir, f'{trainer.timestamp}.log.json')
        with open(json_file, 'r', encoding='utf-8') as f:
            lines = [i.strip() for i in f.readlines()]

        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[0]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[1]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10
            }, json.loads(lines[2]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[3]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[4]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10
            }, json.loads(lines[5]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.001
            }, json.loads(lines[6]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.001
            }, json.loads(lines[7]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 10
            }, json.loads(lines[8]))
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
        for i in [0, 1, 3, 4, 6, 7]:
            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
            self.assertIn(LogKeys.ITER_TIME, lines[i])
        for i in [2, 5, 8]:
            self.assertIn(MetricKeys.ACCURACY, lines[i])
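
    # Same assertions as above, but the data comes from IterableDatasets with
    # no inherent length, so epoch boundaries are imposed explicitly via
    # train_iters_per_epoch=20 and val_iters_per_epoch=10.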
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_train_with_iters_per_epoch(self):
        json_cfg = {
            'task': Tasks.image_classification,
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'EvaluationHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = StepLR(optimizer, 2)
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            data_collator=None,
            optimizers=(optimizer, lr_scheduler),
            train_dataset=DummyIterableDataset(),
            eval_dataset=DummyIterableDataset(),
            train_iters_per_epoch=20,
            val_iters_per_epoch=10,
            max_epochs=3,
            device='cpu')
        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        json_file = os.path.join(self.tmp_dir, f'{trainer.timestamp}.log.json')
        with open(json_file, 'r', encoding='utf-8') as f:
            lines = [i.strip() for i in f.readlines()]

        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[0]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[1]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10
            }, json.loads(lines[2]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[3]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[4]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10
            }, json.loads(lines[5]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.001
            }, json.loads(lines[6]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.001
            }, json.loads(lines[7]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 10
            }, json.loads(lines[8]))
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
        for i in [0, 1, 3, 4, 6, 7]:
            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
            self.assertIn(LogKeys.ITER_TIME, lines[i])
        for i in [2, 5, 8]:
            self.assertIn(MetricKeys.ACCURACY, lines[i])
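

# DummyTrainer is imported above so that its registration under the name
# 'dummy' runs; its train()/evaluate() are stubs, so this mainly checks that
# the build_trainer registry lookup and construction work.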
class DummyTrainerTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_dummy(self):
        default_args = dict(cfg_file='configs/examples/train.json')
        trainer = build_trainer('dummy', default_args)
        trainer.train()
        trainer.evaluate()


if __name__ == '__main__':
    unittest.main()
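
# To run a single case locally, something like the following should work
# (a sketch: the tests/trainers/ path and the TEST_LEVEL variable that
# test_level() appears to read are assumptions about the surrounding repo):
#   TEST_LEVEL=0 python tests/trainers/test_trainer.py
#   python -m unittest test_trainer.TrainerTest.test_train_0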