You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

main.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from save_checkpoint import SaveCheckpoint
  2. from data_module import DataModule
  3. from pytorch_lightning import loggers as pl_loggers
  4. import pytorch_lightning as pl
  5. from train_model import TrainModule
  6. def main(stage,
  7. num_workers,
  8. max_epochs,
  9. batch_size,
  10. precision,
  11. seed,
  12. dataset_path=None,
  13. gpus=None,
  14. tpu_cores=None,
  15. load_checkpoint_path=None,
  16. save_name=None,
  17. path_final_save=None,
  18. every_n_epochs=1,
  19. save_top_k=1,):
  20. """
  21. 框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程
  22. 该函数的参数为训练过程中需要经常改动的参数
  23. :param stage: 表示处于训练阶段还是测试阶段, fit表示训练, test表示测试
  24. :param num_workers:
  25. :param max_epochs:
  26. :param batch_size:
  27. :param precision: 训练精度, 正常精度为32, 半精度为16, 也可以是64. 精度代表每个参数的类型所占的位数
  28. :param seed:
  29. :param dataset_path: 数据集地址, 其目录下包含数据集, 标签, 全部数据的命名list
  30. :param gpus:
  31. :param tpu_cores:
  32. :param load_checkpoint_path:
  33. :param save_name:
  34. :param path_final_save:
  35. :param every_n_epochs:
  36. :param save_top_k:
  37. """
  38. # config存放确定模型后不常改动的非通用的参数, 通用参数且不经常带动的直接进行声明
  39. if False:
  40. config = {'dataset_path': dataset_path,
  41. 'dim_in': 2,
  42. 'dim': 10,
  43. 'res_coef': 0.5,
  44. 'dropout_p': 0.1,
  45. 'n_layers': 2,
  46. 'flag': True}
  47. else:
  48. config = {'dataset_path': dataset_path,
  49. 'dim_in': 62,
  50. 'dim': 32,
  51. 'res_coef': 0.5,
  52. 'dropout_p': 0.1,
  53. 'n_layers': 20,
  54. 'flag': False}
  55. # TODO 获得最优的batch size
  56. # TODO 自动获取CPU核心数并设置num workers
  57. precision = 32 if (gpus is None and tpu_cores is None) else precision
  58. dm = DataModule(batch_size=batch_size, num_workers=num_workers, config=config)
  59. logger = pl_loggers.TensorBoardLogger('logs/')
  60. if stage == 'fit':
  61. training_module = TrainModule(config=config)
  62. save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs,
  63. save_name=save_name, path_final_save=path_final_save,
  64. every_n_epochs=every_n_epochs, verbose=True,
  65. monitor='Validation acc', save_top_k=save_top_k,
  66. mode='max')
  67. if load_checkpoint_path is None:
  68. print('进行初始训练')
  69. trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
  70. logger=logger, precision=precision, callbacks=[save_checkpoint])
  71. training_module.load_pretrain_parameters()
  72. else:
  73. print('进行重载训练')
  74. trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
  75. resume_from_checkpoint='./logs/default' + load_checkpoint_path,
  76. logger=logger, precision=precision, callbacks=[save_checkpoint])
  77. trainer.fit(training_module, datamodule=dm)
  78. if stage == 'test':
  79. if load_checkpoint_path is None:
  80. print('未载入权重信息,不能测试')
  81. else:
  82. print('进行测试')
  83. training_module = TrainModule.load_from_checkpoint(
  84. checkpoint_path='./logs/default' + load_checkpoint_path,
  85. **{'config': config})
  86. trainer = pl.Trainer(gpus=gpus, tpu_cores=tpu_cores, logger=logger, precision=precision)
  87. trainer.test(training_module, datamodule=dm)
  88. # 在cmd中使用tensorboard --logdir logs命令可以查看结果,在Jupyter格式下需要加%前缀
  89. if __name__ == "__main__":
  90. main('fit', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234,
  91. # gpus=1,
  92. # load_checkpoint_path='/version_5/checkpoints/epoch=149-step=7949.ckpt',
  93. )

基于 PyTorch Lightning 的机器学习模板, 用于对机器学习算法进行训练, 验证, 测试等. 目前实现了神经网络, 深度学习, k 折交叉验证, 自动保存训练信息等功能.

Contributors (1)