You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

main.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import torch
  2. from save_checkpoint import SaveCheckpoint
  3. from data_module import DataModule
  4. from pytorch_lightning import loggers as pl_loggers
  5. import pytorch_lightning as pl
  6. from train_model import TrainModule
  7. from multiprocessing import cpu_count
  8. from utils import get_ckpt_path
  9. def main(stage,
  10. max_epochs,
  11. batch_size,
  12. precision,
  13. seed,
  14. dataset_path,
  15. k_fold,
  16. kth_fold_start,
  17. gpus=None,
  18. tpu_cores=None,
  19. version_nth=None,
  20. path_final_save=None,
  21. every_n_epochs=1,
  22. save_top_k=1,
  23. version_info='无'
  24. ):
  25. """
  26. 框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程
  27. 该函数的参数为训练过程中需要经常改动的参数
  28. 经常改动的 参数 作为main的输入参数
  29. 不常改动的 非通用参数 存放在config
  30. 不常改动的 通用参数 直接进行声明
  31. * 通用参数指的是所有网络中共有的参数, 如time_sum等
  32. :param stage: 表示处于训练阶段还是测试阶段, fit表示训练, test表示测试
  33. :param max_epochs:
  34. :param batch_size:
  35. :param precision: 训练精度, 正常精度为32, 半精度为16, 也可以是64. 精度代表每个参数的类型所占的位数
  36. :param seed:
  37. :param dataset_path: 数据集地址, 其目录下包含数据集文件夹, 标签文件夹, 全部数据的命名list
  38. :param gpus:
  39. :param tpu_cores:
  40. :param version_nth: 不论是重载训练还是测试, 固定为该folds的第一个版本的版本号
  41. :param path_final_save:
  42. :param every_n_epochs: 每n个epoch设置一个检查点
  43. :param save_top_k:
  44. :param kth_fold_start: 从第几个fold开始, 若使用重载训练, 则kth_fold_start为重载第几个fold, 第一个值为0.
  45. 非重载训练的情况下, 可以通过调整该值控制训练的次数;
  46. :param k_fold:
  47. :param version_info: 版本信息, 主要记录该版本的网络数据集等
  48. """
  49. # 处理输入数据
  50. precision = 32 if (gpus is None and tpu_cores is None) else precision
  51. # 自动处理:param gpus
  52. gpus = 1 if torch.cuda.is_available() and gpus is None and tpu_cores is None else None
  53. # 定义不常改动的通用参数
  54. # TODO 获得最优的batch size
  55. num_workers = min([cpu_count(), 8])
  56. # 获得非通用参数
  57. config = {'dim_in': 32, }
  58. for kth_fold in range(kth_fold_start, k_fold):
  59. load_checkpoint_path = get_ckpt_path(version_nth, kth_fold)
  60. logger = pl_loggers.TensorBoardLogger('logs/')
  61. dm = DataModule(batch_size=batch_size, num_workers=num_workers, k_fold=k_fold, kth_fold=kth_fold,
  62. dataset_path=dataset_path, config=config)
  63. if stage == 'fit':
  64. # SaveCheckpoint的创建需要在TrainModule之前, 以保证网络参数初始化的确定性
  65. save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs,
  66. path_final_save=path_final_save,
  67. every_n_epochs=every_n_epochs, verbose=True,
  68. monitor='Validation acc', save_top_k=save_top_k,
  69. mode='max', version_info=version_info)
  70. training_module = TrainModule(config=config)
  71. if kth_fold != kth_fold_start or load_checkpoint_path is None:
  72. print('进行初始训练')
  73. trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
  74. logger=logger, precision=precision, callbacks=[save_checkpoint])
  75. training_module.load_pretrain_parameters()
  76. else:
  77. print('进行重载训练')
  78. trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
  79. resume_from_checkpoint='./logs/default' + load_checkpoint_path,
  80. logger=logger, precision=precision, callbacks=[save_checkpoint])
  81. print('训练过程中请注意gpu利用率等情况')
  82. trainer.fit(training_module, datamodule=dm)
  83. if stage == 'test':
  84. if load_checkpoint_path is None:
  85. print('未载入权重信息,不能测试')
  86. else:
  87. print('进行测试')
  88. training_module = TrainModule.load_from_checkpoint(
  89. checkpoint_path='./logs/default' + load_checkpoint_path,
  90. **{'config': config})
  91. trainer = pl.Trainer(gpus=gpus, tpu_cores=tpu_cores, logger=logger, precision=precision)
  92. trainer.test(training_module, datamodule=dm)
  93. # 在cmd中使用tensorboard --logdir logs命令可以查看结果,在Jupyter格式下需要加%前缀
  94. if __name__ == "__main__":
  95. main('fit', max_epochs=1, batch_size=128, precision=16, seed=1234, dataset_path='./dataset', k_fold=10,
  96. kth_fold_start=9, version_info='ResNet-RuLe-CIFAR100',
  97. # version_nth=8,
  98. )

基于 PyTorch Lightning 的机器学习模板, 用于对机器学习算法进行训练, 验证, 测试等, 目前实现了神经网络, 深度学习, k折交叉验证, 自动保存训练信息等.

Contributors (1)