import time

import numpy
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
from torch import nn

from network.MLP_JDLU import MLP_JDLU
from network.MLP_ReLU import MLP_ReLU


class TrainModule(pl.LightningModule):
    """Lightning wrapper that trains an MLP regressor with an MSE loss.

    The test loop expects single-sample batches (see ``test_step``). At the
    end of testing, labels and predictions are plotted as two 3-D meshes over
    the first two input features.
    """

    def __init__(self, config=None, use_jdlu=False):
        """Build the network and the loss.

        Args:
            config: Mapping providing ``'dim_in'``, ``'dim'``, ``'res_coef'``,
                ``'dropout_p'`` and ``'n_layers'``. These are passed
                positionally to the network constructor. The argument is
                required despite its ``None`` default.
            use_jdlu: When True, build ``MLP_JDLU`` instead of the default
                ``MLP_ReLU``. This replaces a former hard-coded toggle that
                always selected ``MLP_ReLU``.

        Raises:
            TypeError: If ``config`` is None.
        """
        super().__init__()
        if config is None:
            raise TypeError("TrainModule requires a config mapping; got None")
        # Wall-clock duration of the first test forward pass; set lazily in test_step.
        self.time_sum = None
        self.config = config
        net_cls = MLP_JDLU if use_jdlu else MLP_ReLU
        self.net = net_cls(
            config['dim_in'],
            config['dim'],
            config['res_coef'],
            config['dropout_p'],
            config['n_layers'],
        )
        self.loss = nn.MSELoss()

    def _shared_step(self, batch):
        """Run the network on ``(x, y)`` and return the MSE against float32 targets."""
        x, y = batch
        pred = self.net(x)
        return self.loss(pred, y.type(torch.float32))

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        loss = self._shared_step(batch)
        self.log("Training loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss for one batch."""
        loss = self._shared_step(batch)
        self.log("Validation loss", loss)
        return loss

    @staticmethod
    def _sync_device(tensor):
        """Block until queued CUDA work on ``tensor``'s device finishes (no-op on CPU)."""
        if tensor.is_cuda:
            torch.cuda.synchronize(tensor.device)

    def test_step(self, batch, batch_idx):
        """Predict one test batch, log its loss and return it for plotting.

        On the very first call the forward pass is also timed and printed.
        NOTE(review): that first pass includes one-off warm-up cost, so the
        printed value is an upper bound rather than a steady-state latency.

        Returns:
            Tuple ``(inputs, label, pred)``, consumed by ``test_epoch_end``.
        """
        inputs, label = batch
        if self.time_sum is None:
            # Synchronise around the timed call: CUDA kernels launch
            # asynchronously, so without this the timer would stop early.
            self._sync_device(inputs)
            time_start = time.perf_counter()
            pred = self.net(inputs)
            self._sync_device(pred)
            self.time_sum = time.perf_counter() - time_start
            # Message reads "Inference time: <seconds>".
            print(f'\n推理时间为: {self.time_sum:f}')
        else:
            pred = self.net(inputs)
        # reshape(1) assumes batch size 1 with a single scalar output.
        loss = self.loss(pred.reshape(1), label.type(torch.float32))
        self.log("Test loss", loss)
        return inputs, label, pred

    def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        """Plot labels and predictions for the whole test set as 3-D meshes.

        Each element of ``outputs`` is the ``(inputs, label, pred)`` tuple that
        ``test_step`` returned for one single-sample batch. The first two input
        features give the x/y coordinates. The label and the prediction each
        give one z surface.
        """
        if not outputs:
            return
        # One row per test batch: [x0, x1, label, pred]. The array is sized
        # from ``outputs`` itself rather than config['dataset_len'], so a
        # mismatch can neither leave uninitialised rows in the plot nor index
        # past the end of the array. float() calls Tensor.item(), which works
        # for single-element tensors on any device.
        records = numpy.array(
            [
                [float(inputs[0, 0]), float(inputs[0, 1]), float(label[0]), float(pred)]
                for inputs, label, pred in outputs
            ],
            dtype=numpy.float64,
        )

        # Imported here so plotly is only needed when a test run finishes.
        import plotly.graph_objects as go

        xs, ys = records[:, 0], records[:, 1]
        traces = [
            go.Mesh3d(x=xs, y=ys, z=records[:, 2], opacity=0.5, name='label'),
            go.Mesh3d(x=xs, y=ys, z=records[:, 3], opacity=0.5, name='pred'),
        ]
        fig = go.Figure(data=traces)
        # Halve the rendered z extent. Fixed axis ranges can be added if needed,
        # e.g. scene=dict(xaxis=dict(nticks=4, range=[lo, hi]), ...).
        fig.update_layout(scene=dict(aspectratio=dict(x=1, y=1, z=0.5)))
        fig.show()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def load_pretrain_parameters(self):
        """Load pretrained parameters.

        Placeholder: currently does nothing.
        """
        pass