
save_checkpoint.py 5.7 kB

import os
import re
import shutil

import numpy.random
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_info

from utils import zip_dir


class SaveCheckpoint(ModelCheckpoint):
    def __init__(self,
                 max_epochs,
                 seed=None,
                 every_n_epochs=None,
                 path_final_save=None,
                 monitor=None,
                 save_top_k=None,
                 verbose=False,
                 mode='min',
                 no_save_before_epoch=0,
                 version_info='none'):
        """
        Implements the checkpoint-saving logic through a callback, while still providing the
        callback hooks such as on_validation_end.

        :param max_epochs: total number of epochs; used to pre-generate one sampling seed per epoch
        :param seed: global random seed
        :param every_n_epochs: save a checkpoint every n validation epochs
        :param path_final_save: optional extra directory to which the zipped log folder is copied
        :param monitor: name of the metric to monitor
        :param save_top_k: number of best checkpoints to keep
        :param verbose: verbosity mode
        :param mode: 'min' or 'max', direction in which the monitored metric improves
        :param no_save_before_epoch: do not save checkpoints before this epoch
        :param version_info: free-form description stored next to the best metric value
        """
        super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode)
        self.mode = mode
        numpy.random.seed(seed)
        self.seeds = numpy.random.randint(0, 2000, max_epochs)
        pl.seed_everything(seed)
        self.path_final_save = path_final_save
        self.monitor = monitor
        self.save_top_k = save_top_k
        self.flag_sanity_check = 0
        self.no_save_before_epoch = no_save_before_epoch
        self.version_info = version_info

    def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """
        Adjusts the seeding logic: the network seed is fixed by the given seed, while the sampling
        seed of each epoch is generated from that seed, so every epoch uses a different sampling
        sequence even when training is resumed from a checkpoint. Also saves the checkpoint.

        :param trainer:
        :param pl_module:
        :return:
        """
        # The first epoch keeps the original input seed; later epochs are re-seeded with the
        # pre-generated seed for the current epoch.
        if self.flag_sanity_check == 0:
            self.flag_sanity_check = 1
        else:
            pl.seed_everything(self.seeds[trainer.current_epoch])
        super().on_validation_end(trainer, pl_module)

    def _save_top_k_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates) -> None:
        epoch = monitor_candidates.get("epoch")
        if self.monitor is None or self.save_top_k == 0 or epoch < self.no_save_before_epoch:
            return
        current = monitor_candidates.get(self.monitor)
        if self.check_monitor_top_k(trainer, current):
            self._update_best_and_save(current, trainer, monitor_candidates)
            if self.mode == 'max':
                best_model_value = max([float(item) for item in list(self.best_k_models.values())])
            else:
                best_model_value = min([float(item) for item in list(self.best_k_models.values())])
            version_name = 'version_unknown'
            for item in re.split(r'[/\\]', self.dirpath):
                if 'version_' in item:
                    version_name = item
                    break
            # Save the version information (best metric value etc.) to a txt file.
            if not os.path.exists('./logs/default/version_info.txt'):
                with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f:
                    f.write(version_name + ' ' + str(best_model_value) + ' ' + self.version_info + '\n')
            else:
                with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f:
                    info_list = f.readlines()
                info_list = [item.strip('\n').split(' ') for item in info_list]
                # Transpose the list so that row 0 holds the version names, row 1 the metric
                # values and row 2 the descriptions.
                info_list = list(map(list, zip(*info_list)))
                if version_name in info_list[0]:
                    for cou in range(len(info_list[0])):
                        if version_name == info_list[0][cou]:
                            info_list[1][cou] = str(best_model_value)
                            info_list[2][cou] = self.version_info
                else:
                    info_list[0].append(version_name)
                    info_list[1].append(str(best_model_value))
                    info_list[2].append(self.version_info)
                # Transpose back so that each row is one version again.
                info_list = list(map(list, zip(*info_list)))
                with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f:
                    for line in info_list:
                        line = " ".join(line)
                        f.write(line + '\n')
            # After every checkpoint update, archive the log folder to another location.
            if self.path_final_save is not None:
                zip_dir('./logs/default/' + version_name, './' + version_name + '.zip')
                if os.path.exists(self.path_final_save + '/' + version_name + '.zip'):
                    os.remove(self.path_final_save + '/' + version_name + '.zip')
                shutil.move('./' + version_name + '.zip', self.path_final_save)
        elif self.verbose:
            epoch = monitor_candidates.get("epoch")
            step = monitor_candidates.get("step")
            best_model_values = 'now best model:'
            for cou_best_model in self.best_k_models:
                best_model_values = ' '.join(
                    (best_model_values, str(round(float(self.best_k_models[cou_best_model]), 4))))
            rank_zero_info(
                f"\nEpoch {epoch:d}, global step {step:d}: {self.monitor} ({float(current):f}) was not in "
                f"top {self.save_top_k:d}({best_model_values:s})")

A machine learning template based on PyTorch Lightning for training, validating and testing machine learning algorithms. It currently implements neural networks, deep learning, k-fold cross-validation, automatic saving of training information, and more; an illustrative k-fold sketch follows below.
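The k-fold cross-validation mentioned above could, for example, be driven by standard sklearn/torch utilities. The following is only an illustrative sketch of the splitting step under that assumption; the dataset and the per-fold training call are placeholders, not this repository's actual API.

import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset, TensorDataset

# Illustrative dataset; the template would supply its own.
full_dataset = TensorDataset(torch.randn(100, 8), torch.randn(100, 1))

kfold = KFold(n_splits=5, shuffle=True, random_state=1234)
for fold, (train_idx, val_idx) in enumerate(kfold.split(np.arange(len(full_dataset)))):
    train_loader = DataLoader(Subset(full_dataset, train_idx.tolist()), batch_size=16, shuffle=True)
    val_loader = DataLoader(Subset(full_dataset, val_idx.tolist()), batch_size=16)
    # Each fold would then get its own model, SaveCheckpoint callback and pl.Trainer
    # (e.g. trainer.fit(model, train_loader, val_loader)), so that the best checkpoint
    # of every fold ends up in its own logs/default/version_* directory.
    print(f'fold {fold}: {len(train_idx)} train / {len(val_idx)} val samples')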
