You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

save_checkpoint.py 3.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os
  2. import numpy.random
  3. from pytorch_lightning.callbacks import ModelCheckpoint
  4. import pytorch_lightning as pl
  5. import shutil
  6. from pytorch_lightning.utilities import rank_zero_info
  7. from utils import zip_dir
class SaveCheckpoint(ModelCheckpoint):
    """Checkpoint callback that also manages per-epoch sampling seeds.

    Extends ``pytorch_lightning.callbacks.ModelCheckpoint`` so that, in
    addition to the normal checkpoint bookkeeping, every validation-end hook
    re-seeds the global RNGs with a pre-generated per-epoch seed, and the
    top-k save step can optionally zip-and-mirror the ``./logs`` directory to
    an external location.
    """

    def __init__(self,
                 max_epochs,
                 seed=None,
                 every_n_epochs=None,
                 path_final_save=None,
                 monitor=None,
                 save_top_k=None,
                 verbose=False,
                 mode='min',
                 no_save_before_epoch=0):
        """Implement checkpoint saving via callbacks, while keeping hooks such
        as ``on_validation_end`` available for customisation.

        The network seed is fixed to ``seed``; the per-epoch sampling seeds are
        derived from it, so even a resumed run gets a distinct sampling
        sequence for every epoch (see ``on_validation_end``).

        :param max_epochs: total number of epochs; one sampling seed is
            pre-generated for each epoch.
        :param seed: base random seed for numpy and for
            ``pl.seed_everything`` (``None`` lets numpy pick one).
        :param every_n_epochs: forwarded to ``ModelCheckpoint`` — save
            period in epochs.
        :param path_final_save: directory to mirror ``./logs.zip`` into after
            each improved checkpoint; ``None`` disables mirroring.
        :param monitor: metric name to monitor for top-k selection.
        :param save_top_k: number of best checkpoints to keep; ``0`` disables
            top-k saving.  # NOTE(review): default None is never compared
            # against parent logic here; only the explicit 0 short-circuits.
        :param verbose: forwarded to ``ModelCheckpoint``; also enables the
            "not in top k" info message below.
        :param no_save_before_epoch: skip top-k saving for epochs earlier
            than this.
        """
        super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode)
        # Derive one sampling seed per epoch from the base seed, so a resumed
        # run can reproduce the same per-epoch sequence.
        numpy.random.seed(seed)
        self.seeds = numpy.random.randint(0, 2000, max_epochs)
        pl.seed_everything(seed)
        self.path_final_save = path_final_save
        # Set monitor/save_top_k AFTER super().__init__ on purpose: this
        # bypasses the parent's init-time validation of these fields.
        self.monitor = monitor
        self.save_top_k = save_top_k
        # 0 until the first on_validation_end call, which in Lightning is the
        # sanity-check run; used to skip reseeding exactly once.
        self.flag_sanity_check = 0
        self.no_save_before_epoch = no_save_before_epoch

    def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """Re-seed the RNGs for the next epoch, then save checkpoints.

        The network seed is fixed at construction; the sampling seed for each
        epoch comes from the pre-generated ``self.seeds``, so even a reloaded
        training run gets a distinct sampling sequence per epoch.

        :param trainer: the running ``pl.Trainer``.
        :param pl_module: the attached ``pl.LightningModule``.
        :return: None
        """
        # The first call (Lightning's sanity-check validation) keeps the
        # original input seed; later calls reseed from the pre-generated
        # table.  # NOTE(review): the original comment said "epoch-1" but the
        # code indexes seeds[trainer.current_epoch] — confirm intent.
        if self.flag_sanity_check == 0:
            self.flag_sanity_check = 1
        else:
            pl.seed_everything(self.seeds[trainer.current_epoch])
        super().on_validation_end(trainer, pl_module)

    def _save_top_k_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates) -> None:
        # Override of ModelCheckpoint's private top-k step: adds the
        # no_save_before_epoch gate and the optional logs mirroring.
        # NOTE(review): ties this class to a specific pytorch_lightning
        # version's private API — verify on PL upgrades.
        epoch = monitor_candidates.get("epoch")
        # Skip entirely when not monitoring, top-k disabled, or too early.
        if self.monitor is None or self.save_top_k == 0 or epoch < self.no_save_before_epoch:
            return
        current = monitor_candidates.get(self.monitor)
        if self.check_monitor_top_k(trainer, current):
            # New best: let the parent persist it, then optionally mirror a
            # zip of ./logs to the external save location (replacing any
            # previous logs.zip there).
            self._update_best_and_save(current, trainer, monitor_candidates)
            if self.path_final_save is not None:
                zip_dir('./logs', './logs.zip')
                if os.path.exists(self.path_final_save + '/logs.zip'):
                    os.remove(self.path_final_save + '/logs.zip')
                shutil.move('./logs.zip', self.path_final_save)
        elif self.verbose:
            # Not in top-k: report the current metric and the values of the
            # models currently held in best_k_models.
            epoch = monitor_candidates.get("epoch")
            step = monitor_candidates.get("step")
            best_model_values = 'now best model:'
            for cou_best_model in self.best_k_models:
                best_model_values = ' '.join(
                    (best_model_values, str(round(float(self.best_k_models[cou_best_model]), 4))))
            rank_zero_info(
                f"\nEpoch {epoch:d}, global step {step:d}: {self.monitor} ({float(current):f}) was not in "
                f"top {self.save_top_k:d}({best_model_values:s})")

基于pytorch lightning的机器学习模板, 用于对机器学习算法进行训练, 验证, 测试等, 目前实现了神经网路, 深度学习, k折交叉, 自动保存训练信息等.

Contributors (1)