You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_model.py 1.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import time
  2. import numpy
  3. import pytorch_lightning as pl
  4. from pytorch_lightning.utilities.types import EPOCH_OUTPUT
  5. from torch import nn
  6. import torch
  7. from network.res_net import resnet56, accuracy
  class TrainModule(pl.LightningModule):
      """LightningModule wrapping a ``resnet56`` network with cross-entropy loss.

      Every batch is unpacked as a 3-tuple whose first element is ignored,
      whose second element is fed to the network, and whose third element
      is the label tensor (flattened before use). Each step logs a loss and
      an accuracy value under stage-specific names ("Training ...",
      "Validation ...", "Test ...").
      """

      def __init__(self, config=None):
          """Create the network and the loss function.

          Args:
              config: Optional configuration object. It is stored on the
                  instance but not read anywhere else in this class.
          """
          super().__init__()
          # Forward-pass duration (seconds) of the first test batch;
          # stays None until test_step has measured it once.
          self.time_sum = None
          self.config = config
          self.net = resnet56()
          self.loss = nn.CrossEntropyLoss()

      @staticmethod
      def _split_batch(batch):
          """Return ``(inputs, labels)`` from a batch, with labels flattened."""
          _, inputs, label = batch
          # NOTE(review): flattening presumably turns (N, 1) class indices
          # into the (N,) shape CrossEntropyLoss expects -- confirm against
          # the dataset.
          return inputs, label.flatten()

      def _log_loss_and_accuracy(self, stage, pred, label):
          """Compute and log "<stage> loss" and "<stage> acc"; return the loss."""
          loss = self.loss(pred, label)
          self.log(f"{stage} loss", loss)
          # Only the first element of accuracy()'s result is logged
          # (presumably top-1 accuracy -- verify in network.res_net).
          acc = accuracy(pred, label)[0]
          self.log(f"{stage} acc", acc)
          return loss

      def _timed_forward(self, inputs):
          """Run the network once and return ``(pred, elapsed_seconds)``.

          CUDA kernels are launched asynchronously, so without explicit
          synchronisation the wall-clock delta would mostly measure launch
          overhead rather than the forward pass itself.
          """
          on_cuda = torch.is_tensor(inputs) and inputs.is_cuda
          if on_cuda:
              # Drain previously queued work so it is not billed to this pass.
              torch.cuda.synchronize(inputs.device)
          start = time.perf_counter()
          pred = self.net(inputs)
          if on_cuda:
              # Wait for the forward kernels to actually finish.
              torch.cuda.synchronize(inputs.device)
          return pred, time.perf_counter() - start

      def training_step(self, batch, batch_idx):
          """Run one training batch and return its loss.

          Args:
              batch: 3-tuple ``(ignored, inputs, labels)``.
              batch_idx: Index of the batch (unused).

          Returns:
              The cross-entropy loss for this batch.
          """
          inputs, label = self._split_batch(batch)
          pred = self.net(inputs)
          return self._log_loss_and_accuracy("Training", pred, label)

      def validation_step(self, batch, batch_idx):
          """Run one validation batch and return its loss.

          Args:
              batch: 3-tuple ``(ignored, inputs, labels)``.
              batch_idx: Index of the batch (unused).

          Returns:
              The cross-entropy loss for this batch.
          """
          inputs, label = self._split_batch(batch)
          pred = self.net(inputs)
          return self._log_loss_and_accuracy("Validation", pred, label)

      def test_step(self, batch, batch_idx):
          """Run one test batch; time and print the first forward pass.

          The forward pass is timed only while ``self.time_sum`` is None,
          i.e. for the first test batch this instance sees. The measurement
          is stored in ``self.time_sum`` and printed once.

          Args:
              batch: 3-tuple ``(ignored, inputs, labels)``.
              batch_idx: Index of the batch (unused).

          Returns:
              Tuple ``(inputs, labels, pred)`` for this batch.
          """
          inputs, label = self._split_batch(batch)
          if self.time_sum is None:
              pred, self.time_sum = self._timed_forward(inputs)
              print(f'\n推理时间为: {self.time_sum:f}')
          else:
              pred = self.net(inputs)
          self._log_loss_and_accuracy("Test", pred, label)
          return inputs, label, pred

      def configure_optimizers(self):
          """Return an SGD optimizer over all parameters of this module.

          Uses lr=0.1, momentum=0.9 and weight_decay=1e-4; no learning-rate
          scheduler is configured.
          """
          return torch.optim.SGD(
              self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4
          )

      def load_pretrain_parameters(self):
          """Load pretrained parameters.

          Placeholder: currently does nothing.
          """
          pass

基于 PyTorch Lightning 的机器学习模板，用于对机器学习算法进行训练、验证、测试等，目前实现了神经网络、深度学习、k折交叉验证、自动保存训练信息等功能。

Contributors (1)