You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

train_model.py 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import time
  2. import numpy
  3. import pytorch_lightning as pl
  4. from pytorch_lightning.utilities.types import EPOCH_OUTPUT
  5. from torch import nn
  6. import torch
  7. from network.res_net import resnet56, accuracy
  class TrainModule(pl.LightningModule):
      """LightningModule that trains, validates and tests a ``resnet56()`` network.

      Every step computes ``nn.CrossEntropyLoss`` on the network output and
      logs a loss and an accuracy under the keys "<Stage> loss" and
      "<Stage> acc", where <Stage> is "Training", "Validation" or "Test".

      Args:
          config: Optional configuration object. It is stored on
              ``self.config`` but not read anywhere in this class.
      """

      def __init__(self, config=None):
          super().__init__()
          # Seconds spent in the forward pass of the first test batch;
          # None until test_step has measured it.
          self.time_sum = None
          self.config = config
          self.net = resnet56()
          self.loss = nn.CrossEntropyLoss()

      @staticmethod
      def _unpack_batch(batch):
          """Split *batch* into ``(inputs, label)`` with ``label.flatten()`` applied.

          The first element of the batch is discarded.
          NOTE(review): presumably a sample id/index -- confirm against the dataset.
          """
          _, inputs, label = batch
          return inputs, label.flatten()

      def _log_metrics(self, stage, pred, label):
          """Log "<stage> loss" and "<stage> acc" for *pred* vs *label*; return the loss.

          The accuracy logged is the first element returned by ``accuracy``
          (presumably top-1 -- confirm in network.res_net).
          """
          loss = self.loss(pred, label)
          self.log(f"{stage} loss", loss)
          acc = accuracy(pred, label)[0]
          self.log(f"{stage} acc", acc)
          return loss

      def _timed_forward(self, inputs):
          """Run ``self.net`` on *inputs* and return ``(pred, elapsed_seconds)``.

          CUDA kernels are launched asynchronously, so when *inputs* lives on a
          CUDA device the device is synchronised before starting and before
          stopping the clock; otherwise the measurement would only cover the
          kernel launches rather than the actual computation.
          """
          on_cuda = isinstance(inputs, torch.Tensor) and inputs.is_cuda
          if on_cuda:
              torch.cuda.synchronize(inputs.device)
          # perf_counter is the monotonic, high-resolution clock meant for intervals.
          start = time.perf_counter()
          pred = self.net(inputs)
          if on_cuda:
              torch.cuda.synchronize(inputs.device)
          return pred, time.perf_counter() - start

      # Lightning contract: the value returned by training_step must contain the
      # loss -- either the loss itself or a dict with the loss under "loss".
      def training_step(self, batch, batch_idx):
          """Forward one training batch, log loss/accuracy and return the loss."""
          inputs, label = self._unpack_batch(batch)
          pred = self.net(inputs)
          return self._log_metrics("Training", pred, label)

      def validation_step(self, batch, batch_idx):
          """Forward one validation batch, log loss/accuracy and return the loss."""
          inputs, label = self._unpack_batch(batch)
          pred = self.net(inputs)
          return self._log_metrics("Validation", pred, label)

      def test_step(self, batch, batch_idx):
          """Forward one test batch, log loss/accuracy and return ``(inputs, label, pred)``.

          The forward pass of the first test batch seen by this module (while
          ``self.time_sum`` is None) is timed; the duration is stored in
          ``self.time_sum`` and printed ("inference time is: ...").
          The returned ``label`` is the flattened one.
          """
          inputs, label = self._unpack_batch(batch)
          if self.time_sum is None:
              pred, self.time_sum = self._timed_forward(inputs)
              print(f'\n推理时间为: {self.time_sum:f}')
          else:
              pred = self.net(inputs)
          self._log_metrics("Test", pred, label)
          return inputs, label, pred

      def configure_optimizers(self):
          """Return SGD over all parameters (lr=0.1, momentum=0.9, weight_decay=1e-4)."""
          return torch.optim.SGD(
              self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4
          )

      def load_pretrain_parameters(self):
          """Load pre-trained parameters.

          Not implemented yet: currently a no-op that returns None.
          """
          # TODO: implement loading; kept a no-op so existing callers are unaffected.

基于 PyTorch Lightning 的机器学习模板，用于对机器学习算法进行训练、验证、测试等，目前实现了神经网络、深度学习、k 折交叉验证、自动保存训练信息等功能。

Contributors (1)