
test_trainer.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
import shutil
import tempfile
import unittest

import numpy as np
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import IterableDataset

from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys
from modelscope.models.base import Model
from modelscope.trainers import build_trainer
from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
from modelscope.utils.test_utils import create_dummy_test_dataset, test_level
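
# Iterable-style dataset that yields 500 random (feat, labels) samples per
# pass; used where the epoch length has to be given via *_iters_per_epoch.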
class DummyIterableDataset(IterableDataset):

    def __iter__(self):
        feat = np.random.random(size=(5, )).astype(np.float32)
        labels = np.random.randint(0, 4, (1, ))
        iterations = [{'feat': feat, 'labels': labels}] * 500
        return iter(iterations)
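
# Map-style dummy datasets built by create_dummy_test_dataset from a random
# 5-dim feature vector and a label in [0, 4): 20 samples and 40 samples.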
dummy_dataset_small = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)

dummy_dataset_big = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40)
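
# Minimal trainable model: Linear(5, 4) followed by BatchNorm1d. forward()
# returns a dict with logits and a dummy sum-of-logits loss so the trainer
# has something to backpropagate and log.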
class DummyModel(nn.Module, Model):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 4)
        self.bn = nn.BatchNorm1d(4)

    def forward(self, feat, labels):
        x = self.linear(feat)
        x = self.bn(x)
        loss = torch.sum(x)
        return dict(logits=x, loss=loss)
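
# End-to-end tests for the default trainer (Trainers.default): each test
# writes a JSON configuration into a temporary work_dir, builds the trainer
# with build_trainer and checks the produced checkpoints and log file.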
class TrainerTest(unittest.TestCase):

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)
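
    # Optimizer, lr_scheduler and all hooks are declared in the JSON config;
    # expects one checkpoint per epoch plus the timestamped JSON log file.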
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_train_0(self):
        json_cfg = {
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'optimizer': {
                    'type': 'SGD',
                    'lr': 0.01,
                    'options': {
                        'grad_clip': {
                            'max_norm': 2.0
                        }
                    }
                },
                'lr_scheduler': {
                    'type': 'StepLR',
                    'step_size': 2,
                    'options': {
                        'warmup': {
                            'type': 'LinearWarmup',
                            'warmup_iters': 2
                        }
                    }
                },
                'hooks': [{
                    'type': 'CheckpointHook',
                    'interval': 1
                }, {
                    'type': 'TextLoggerHook',
                    'interval': 1
                }, {
                    'type': 'IterTimerHook'
                }, {
                    'type': 'EvaluationHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=DummyModel(),
            data_collator=None,
            train_dataset=dummy_dataset_small,
            eval_dataset=dummy_dataset_small,
            max_epochs=3,
            device='cpu')

        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
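
    # Same flow as test_train_0, but the optimizer and lr_scheduler are built
    # in code and passed through the `optimizers` kwarg instead of the config.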
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_train_1(self):
        json_cfg = {
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'CheckpointHook',
                    'interval': 1
                }, {
                    'type': 'TextLoggerHook',
                    'interval': 1
                }, {
                    'type': 'IterTimerHook'
                }, {
                    'type': 'EvaluationHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = StepLR(optimizer, 2)
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            data_collator=None,
            train_dataset=dummy_dataset_small,
            eval_dataset=dummy_dataset_small,
            optimizers=(optimizer, lr_scheduler),
            max_epochs=3,
            device='cpu')

        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
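
    # Only an EvaluationHook is configured, so logging, timing and
    # checkpointing fall back to the trainer's defaults. The JSON log is then
    # checked line by line: 40 samples at batch size 2 give 20 iters per
    # epoch, and StepLR(step_size=2) drops the lr to 0.001 from epoch 3 on.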
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_train_with_default_config(self):
        json_cfg = {
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'EvaluationHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = StepLR(optimizer, 2)
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            data_collator=None,
            train_dataset=dummy_dataset_big,
            eval_dataset=dummy_dataset_small,
            optimizers=(optimizer, lr_scheduler),
            max_epochs=3,
            device='cpu')

        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        json_file = os.path.join(self.tmp_dir, f'{trainer.timestamp}.log.json')
        with open(json_file, 'r') as f:
            lines = [i.strip() for i in f.readlines()]

        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[0]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[1]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 20
            }, json.loads(lines[2]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[3]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[4]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 20
            }, json.loads(lines[5]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.001
            }, json.loads(lines[6]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.001
            }, json.loads(lines[7]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 20
            }, json.loads(lines[8]))

        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)

        for i in [0, 1, 3, 4, 6, 7]:
            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
            self.assertIn(LogKeys.ITER_TIME, lines[i])
        for i in [2, 5, 8]:
            self.assertIn(MetricKeys.ACCURACY, lines[i])
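
    # Both datasets are IterableDatasets without a fixed length, so the epoch
    # size is set explicitly via train_iters_per_epoch=20 and
    # val_iters_per_epoch=10; eval log lines are asserted at iter 10.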
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_train_with_iters_per_epoch(self):
        json_cfg = {
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'EvaluationHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = StepLR(optimizer, 2)
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            data_collator=None,
            optimizers=(optimizer, lr_scheduler),
            train_dataset=DummyIterableDataset(),
            eval_dataset=DummyIterableDataset(),
            train_iters_per_epoch=20,
            val_iters_per_epoch=10,
            max_epochs=3,
            device='cpu')

        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        json_file = os.path.join(self.tmp_dir, f'{trainer.timestamp}.log.json')
        with open(json_file, 'r') as f:
            lines = [i.strip() for i in f.readlines()]

        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[0]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[1]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10
            }, json.loads(lines[2]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[3]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[4]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10
            }, json.loads(lines[5]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.001
            }, json.loads(lines[6]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.001
            }, json.loads(lines[7]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 10
            }, json.loads(lines[8]))

        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)

        for i in [0, 1, 3, 4, 6, 7]:
            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
            self.assertIn(LogKeys.ITER_TIME, lines[i])
        for i in [2, 5, 8]:
            self.assertIn(MetricKeys.ACCURACY, lines[i])
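
# Smoke test for the trainer registered under the name 'dummy', driven
# entirely by an example configuration file.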
class DummyTrainerTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_dummy(self):
        default_args = dict(cfg_file='configs/examples/train.json')
        trainer = build_trainer('dummy', default_args)
        trainer.train()
        trainer.evaluate()


if __name__ == '__main__':
    unittest.main()