You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_model.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import time
  2. import numpy
  3. import pytorch_lightning as pl
  4. from pytorch_lightning.utilities.types import EPOCH_OUTPUT
  5. from torch import nn
  6. import torch
  7. from network.MLP_JDLU import MLP_JDLU
  8. from network.MLP_ReLU import MLP_ReLU
  class TrainModule(pl.LightningModule):
      """LightningModule that trains, validates and tests an MLP regressor.

      The wrapped network is built from ``config`` and optimised with Adam
      against an MSE loss. During testing, the forward-pass time of the first
      test batch is printed once. After the test epoch, labels and predictions
      are drawn as two overlaid plotly ``Mesh3d`` surfaces over the first two
      input features.

      Recognised ``config`` keys:
          dim_in, dim, res_coef, dropout_p, n_layers (required):
              passed positionally, in this order, to the network constructor.
          network (optional, default ``'MLP_ReLU'``):
              ``'MLP_ReLU'`` or ``'MLP_JDLU'``; selects the network class.
          lr (optional, default ``1e-3``): Adam learning rate.
      """

      def __init__(self, config=None):
          """Build the network and the MSE loss from ``config``.

          Args:
              config: mapping with the keys listed in the class docstring. The
                  ``None`` default exists only for signature compatibility; a
                  mapping is required.

          Raises:
              ValueError: if ``config`` is ``None`` or ``config['network']``
                  names an unknown network class.
          """
          super().__init__()
          if config is None:
              raise ValueError('TrainModule requires a config mapping, got None')
          # Forward-pass time (seconds) of the first test batch; set in test_step.
          self.time_sum = None
          self.config = config
          networks = {'MLP_ReLU': MLP_ReLU, 'MLP_JDLU': MLP_JDLU}
          # MLP_ReLU is the default so configs without a 'network' key keep working.
          network_name = config.get('network', 'MLP_ReLU')
          try:
              network_cls = networks[network_name]
          except KeyError:
              raise ValueError(
                  f'Unknown network {network_name!r}; expected one of {sorted(networks)}'
              ) from None
          self.net = network_cls(config['dim_in'], config['dim'], config['res_coef'],
                                 config['dropout_p'], config['n_layers'])
          self.loss = nn.MSELoss()

      def _shared_step(self, batch, log_name):
          """Forward ``batch``, log its MSE loss under ``log_name`` and return it.

          ``batch`` is unpacked as ``(inputs, targets)``; targets are cast to
          float32 before computing the loss.
          """
          inputs, targets = batch
          pred = self.net(inputs)
          loss = self.loss(pred, targets.type(torch.float32))
          self.log(log_name, loss)
          return loss

      def training_step(self, batch, batch_idx):
          """Compute, log and return the MSE loss of one training batch."""
          return self._shared_step(batch, "Training loss")

      def validation_step(self, batch, batch_idx):
          """Compute, log and return the MSE loss of one validation batch."""
          return self._shared_step(batch, "Validation loss")

      @staticmethod
      def _cuda_sync(tensor):
          """Wait for queued CUDA work on ``tensor``'s device; no-op for CPU tensors."""
          if getattr(tensor, 'is_cuda', False):
              torch.cuda.synchronize(tensor.device)

      def test_step(self, batch, batch_idx):
          """Run inference on one test batch and log its MSE loss.

          The forward pass of the first test batch is timed and printed once
          (the message reads "inference time is: <seconds>").

          Returns:
              ``(inputs, label, pred)`` so ``test_epoch_end`` can plot them.
          """
          inputs, label = batch
          if self.time_sum is None:
              # CUDA kernels launch asynchronously. Synchronise before and after
              # the forward pass so the interval covers the real computation.
              # perf_counter is the monotonic clock intended for intervals.
              self._cuda_sync(inputs)
              time_start = time.perf_counter()
              pred = self.net(inputs)
              self._cuda_sync(inputs)
              self.time_sum = time.perf_counter() - time_start
              print(f'\n推理时间为: {self.time_sum:f}')
          else:
              pred = self.net(inputs)
          # Flatten both sides so batches with more than one sample work.
          # Assumes the network emits one value per sample — TODO confirm.
          loss = self.loss(pred.reshape(-1), label.reshape(-1).type(torch.float32))
          self.log("Test loss", loss)
          return inputs, label, pred

      @staticmethod
      def _collect_records(outputs):
          """Stack ``test_step`` outputs into an ``(N, 4)`` float64 numpy array.

          Columns are ``[input feature 0, input feature 1, label, pred]``, with
          one row per test sample. The row count comes from the collected
          outputs, so no uninitialised or out-of-range rows can appear.
          Assumes inputs have at least two features and that label and pred
          hold one value per sample — TODO confirm for new datasets.
          """
          if not outputs:
              return numpy.empty((0, 4))
          feat0, feat1, labels, preds = [], [], [], []
          for inputs, label, pred in outputs:
              # detach().cpu() lets GPU tensors be converted to numpy below.
              flat_inputs = inputs.detach().cpu().reshape(inputs.shape[0], -1)
              feat0.append(flat_inputs[:, 0])
              feat1.append(flat_inputs[:, 1])
              labels.append(label.detach().cpu().reshape(-1))
              preds.append(pred.detach().cpu().reshape(-1))
          columns = [torch.cat(col).double() for col in (feat0, feat1, labels, preds)]
          return torch.stack(columns, dim=1).numpy()

      @staticmethod
      def _plot_records(records):
          """Show labels and predictions as two semi-transparent plotly meshes."""
          # Imported here so plotly is only needed when a test epoch finishes.
          import plotly.graph_objects as go
          traces = [
              go.Mesh3d(x=records[:, 0],
                        y=records[:, 1],
                        z=records[:, z_col],
                        opacity=0.5,
                        name=trace_name)
              for z_col, trace_name in ((2, 'label'), (3, 'pred'))
          ]
          fig = go.Figure(data=traces)
          fig.update_layout(scene=dict(aspectratio=dict(x=1, y=1, z=0.5)))
          fig.show()

      def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
          """Plot test labels and predictions over the first two input features.

          Does nothing when no test outputs were collected.
          """
          records = self._collect_records(outputs)
          if records.shape[0] == 0:
              return
          self._plot_records(records)

      def configure_optimizers(self):
          """Return an Adam optimizer over all parameters.

          The learning rate is ``config['lr']`` when present, otherwise ``1e-3``.
          """
          return torch.optim.Adam(self.parameters(), lr=self.config.get('lr', 1e-3))

      def load_pretrain_parameters(self):
          """Load pretrained parameters.

          Placeholder: not implemented yet; calling it currently does nothing.
          """
          # TODO: implement loading of pretrained weights.

基于 PyTorch Lightning 的机器学习模板，用于对机器学习算法进行训练、验证、测试等，目前实现了神经网络、深度学习、k 折交叉验证、自动保存训练信息等功能。

Contributors (1)