
train_model.py

import pytorch_lightning as pl
from torch import nn
import torch
from torchmetrics.classification.accuracy import Accuracy
from network.MLP import MLP


class TrainModule(pl.LightningModule):
    def __init__(self, config=None):
        super().__init__()
        self.time_sum = None
        self.config = config
        self.net = MLP(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'], config['n_layers'])
        # TODO: switch the network initialization to Kaiming or Xavier initialization
        self.loss = nn.BCELoss()
        self.accuracy = Accuracy()  # note: torchmetrics >= 0.11 requires Accuracy(task="binary")

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = self.net(x)
        loss = self.loss(x, y.type(torch.float32))
        acc = self.accuracy(x, y)
        self.log("Training loss", loss)
        self.log("Training acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = self.net(x)
        loss = self.loss(x, y.type(torch.float32))
        acc = self.accuracy(x, y)
        self.log("Validation loss", loss)
        self.log("Validation acc", acc)
        return loss, acc

    def test_step(self, batch, batch_idx):
        # Placeholder: test logic is not implemented yet.
        return 0

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def load_pretrain_parameters(self):
        """Load pretrained parameters."""
        pass
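
For reference, a minimal sketch of how TrainModule might be driven with a pl.Trainer. The config keys mirror what __init__ reads, but their values, the dummy data, and the assumption that network.MLP outputs sigmoid probabilities shaped like the labels are illustrative only, not taken from this repository.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

from train_model import TrainModule

# Hypothetical hyper-parameters; only the key names come from TrainModule.__init__.
config = {
    'dim_in': 32,      # input feature dimension
    'dim': 128,        # hidden width
    'res_coef': 0.5,   # residual scaling coefficient
    'dropout_p': 0.1,  # dropout probability
    'n_layers': 4,     # number of MLP layers
}

# Dummy binary-classification data, just to make the sketch self-contained.
x = torch.randn(1024, config['dim_in'])
y = torch.randint(0, 2, (1024,))
train_loader = DataLoader(TensorDataset(x[:896], y[:896]), batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(x[896:], y[896:]), batch_size=64)

model = TrainModule(config)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)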

A machine-learning template based on PyTorch Lightning for training, validating, and testing machine-learning algorithms. It currently supports neural networks, deep learning, k-fold cross-validation, automatic saving of training information, and more.
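
As an illustration only (not this repository's own k-fold implementation), the sketch below shows how a k-fold loop could be driven around TrainModule using sklearn's KFold: a fresh model is trained per fold, and each fold's checkpoints and logs land in their own directory via default_root_dir. The config values and placeholder data are assumptions.

import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset, TensorDataset

from train_model import TrainModule

config = {'dim_in': 32, 'dim': 128, 'res_coef': 0.5, 'dropout_p': 0.1, 'n_layers': 4}  # assumed values

x = torch.randn(1024, config['dim_in'])   # placeholder data
y = torch.randint(0, 2, (1024,))
dataset = TensorDataset(x, y)

kfold = KFold(n_splits=5, shuffle=True, random_state=0)
for fold, (train_idx, val_idx) in enumerate(kfold.split(np.arange(len(dataset)))):
    train_loader = DataLoader(Subset(dataset, train_idx.tolist()), batch_size=64, shuffle=True)
    val_loader = DataLoader(Subset(dataset, val_idx.tolist()), batch_size=64)
    model = TrainModule(config)                     # fresh model for every fold
    trainer = pl.Trainer(max_epochs=10, default_root_dir=f"runs/fold_{fold}")
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)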
