
test_trainer.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
import shutil
import tempfile
import unittest

import numpy as np
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import IterableDataset

from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys
from modelscope.models.base import Model
from modelscope.trainers import build_trainer
from modelscope.trainers.base import DummyTrainer
from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile, Tasks
from modelscope.utils.test_utils import create_dummy_test_dataset, test_level

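# Iterable-style dummy dataset: each pass over it yields 500 random
# (feat, labels) samples and it defines no __len__, which is why the
# iters-per-epoch test below has to pin the epoch length explicitly.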
class DummyIterableDataset(IterableDataset):

    def __iter__(self):
        feat = np.random.random(size=(5, )).astype(np.float32)
        labels = np.random.randint(0, 4, (1, ))
        iterations = [{'feat': feat, 'labels': labels}] * 500
        return iter(iterations)


dummy_dataset_small = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)

dummy_dataset_big = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40)

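# Minimal model for the tests: a Linear + BatchNorm1d stack whose forward()
# returns both logits and a scalar loss, which is all the trainer loop needs.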
class DummyModel(nn.Module, Model):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 4)
        self.bn = nn.BatchNorm1d(4)

    def forward(self, feat, labels):
        x = self.linear(feat)
        x = self.bn(x)
        loss = torch.sum(x)
        return dict(logits=x, loss=loss)

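# End-to-end smoke tests for the default trainer returned by build_trainer().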
class TrainerTest(unittest.TestCase):

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

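    # Full JSON config: SGD with gradient clipping, StepLR with linear warmup,
    # and explicit checkpoint/logging/timer/evaluation hooks. Verifies that the
    # JSON log file and one checkpoint per epoch land in the work dir.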
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_train_0(self):
        json_cfg = {
            'task': Tasks.image_classification,
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'optimizer': {
                    'type': 'SGD',
                    'lr': 0.01,
                    'options': {
                        'grad_clip': {
                            'max_norm': 2.0
                        }
                    }
                },
                'lr_scheduler': {
                    'type': 'StepLR',
                    'step_size': 2,
                    'options': {
                        'warmup': {
                            'type': 'LinearWarmup',
                            'warmup_iters': 2
                        }
                    }
                },
                'hooks': [{
                    'type': 'CheckpointHook',
                    'interval': 1
                }, {
                    'type': 'TextLoggerHook',
                    'interval': 1
                }, {
                    'type': 'IterTimerHook'
                }, {
                    'type': 'EvaluationHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=DummyModel(),
            data_collator=None,
            train_dataset=dummy_dataset_small,
            eval_dataset=dummy_dataset_small,
            max_epochs=3,
            device='cpu')

        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)

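    # Same checks as test_train_0, but the optimizer and lr_scheduler are built
    # in Python and passed through the `optimizers` argument instead of being
    # declared in the config file.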
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_train_1(self):
        json_cfg = {
            'task': Tasks.image_classification,
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'CheckpointHook',
                    'interval': 1
                }, {
                    'type': 'TextLoggerHook',
                    'interval': 1
                }, {
                    'type': 'IterTimerHook'
                }, {
                    'type': 'EvaluationHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = StepLR(optimizer, 2)
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            data_collator=None,
            train_dataset=dummy_dataset_small,
            eval_dataset=dummy_dataset_small,
            optimizers=(optimizer, lr_scheduler),
            max_epochs=3,
            device='cpu')

        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)

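    # Config only registers an EvaluationHook, relying on the trainer's default
    # hooks for checkpointing, text logging and iter timing. The assertions walk
    # the JSON log line by line: lr is 0.01 for epochs 1-2 and drops to 0.001 in
    # epoch 3 (StepLR with step_size=2), with an eval record after every epoch.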
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_train_with_default_config(self):
        json_cfg = {
            'task': Tasks.image_classification,
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'EvaluationHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = StepLR(optimizer, 2)
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            data_collator=None,
            train_dataset=dummy_dataset_big,
            eval_dataset=dummy_dataset_small,
            optimizers=(optimizer, lr_scheduler),
            max_epochs=3,
            device='cpu')

        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        json_file = os.path.join(self.tmp_dir, f'{trainer.timestamp}.log.json')
        with open(json_file, 'r', encoding='utf-8') as f:
            lines = [i.strip() for i in f.readlines()]

        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[0]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[1]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10
            }, json.loads(lines[2]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[3]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[4]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10
            }, json.loads(lines[5]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.001
            }, json.loads(lines[6]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.001
            }, json.loads(lines[7]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 10
            }, json.loads(lines[8]))
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)

        for i in [0, 1, 3, 4, 6, 7]:
            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
            self.assertIn(LogKeys.ITER_TIME, lines[i])
        for i in [2, 5, 8]:
            self.assertIn(MetricKeys.ACCURACY, lines[i])

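    # Same assertions as above, but with IterableDataset inputs that have no
    # length, so the epoch size is pinned explicitly via train_iters_per_epoch
    # and val_iters_per_epoch.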
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_train_with_iters_per_epoch(self):
        json_cfg = {
            'task': Tasks.image_classification,
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'EvaluationHook',
                    'interval': 1
                }]
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1,
                    'shuffle': False
                },
                'metrics': [Metrics.seq_cls_metric]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = StepLR(optimizer, 2)
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            data_collator=None,
            optimizers=(optimizer, lr_scheduler),
            train_dataset=DummyIterableDataset(),
            eval_dataset=DummyIterableDataset(),
            train_iters_per_epoch=20,
            val_iters_per_epoch=10,
            max_epochs=3,
            device='cpu')

        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        json_file = os.path.join(self.tmp_dir, f'{trainer.timestamp}.log.json')
        with open(json_file, 'r', encoding='utf-8') as f:
            lines = [i.strip() for i in f.readlines()]

        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[0]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[1]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10
            }, json.loads(lines[2]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.01
            }, json.loads(lines[3]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.01
            }, json.loads(lines[4]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10
            }, json.loads(lines[5]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 10,
                LogKeys.LR: 0.001
            }, json.loads(lines[6]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 20,
                LogKeys.LR: 0.001
            }, json.loads(lines[7]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 3,
                LogKeys.ITER: 10
            }, json.loads(lines[8]))
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)

        for i in [0, 1, 3, 4, 6, 7]:
            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
            self.assertIn(LogKeys.ITER_TIME, lines[i])
        for i in [2, 5, 8]:
            self.assertIn(MetricKeys.ACCURACY, lines[i])

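# Builds the trainer registered under the name 'dummy' (DummyTrainer) from an
# example config and checks that train() and evaluate() run without error.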
class DummyTrainerTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_dummy(self):
        default_args = dict(cfg_file='configs/examples/train.json')
        trainer = build_trainer('dummy', default_args)
        trainer.train()
        trainer.evaluate()


if __name__ == '__main__':
    unittest.main()