@@ -37,9 +37,10 @@ def main(stage,
     :param tpu_cores:
     :param version_nth: version number of the first version of this group of folds
     :param path_final_save:
-    :param every_n_epochs:
+    :param every_n_epochs: save a checkpoint every n epochs
     :param save_top_k:
-    :param kth_fold_start: the fold to start from; when resuming training, this is the index of the fold to resume, starting at 0
+    :param kth_fold_start: the fold to start from; when resuming training, this is the index of the fold to resume, starting at 0.
+        When not resuming, adjusting this value controls how many folds are trained
     :param k_fold:
     """
     # Frequently changed parameters are taken as input arguments of main
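The every_n_epochs and save_top_k parameters are the standard knobs of Lightning's ModelCheckpoint callback, to which they are presumably forwarded inside main. A minimal sketch of that wiring, assuming a monitored metric named 'val_loss' (the metric name is not shown in this patch):

    # Sketch only, not part of the patch: typical ModelCheckpoint wiring.
    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',             # assumed metric name
        every_n_epochs=every_n_epochs,  # save a checkpoint every n epochs
        save_top_k=save_top_k,          # keep only the k best checkpoints
    )
    # the callback would then be passed to pl.Trainer(callbacks=[checkpoint_callback], ...)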
@@ -59,9 +60,8 @@ def main(stage,
               'dropout_p': 0.1,
               'n_layers': 2,
               'dataset_len': 100000}
-    # for kth_fold in range(kth_fold_start, k_fold):
-    for kth_fold in range(kth_fold_start, kth_fold_start+1):
-        load_checkpoint_path = get_ckpt_path(f'version_{version_nth+kth_fold}')
+    for kth_fold in range(kth_fold_start, k_fold):
+        load_checkpoint_path = get_ckpt_path(version_nth, kth_fold)
         logger = pl_loggers.TensorBoardLogger('logs/')
         dm = DataModule(batch_size=batch_size, num_workers=num_workers, k_fold=k_fold, kth_fold=kth_fold,
                         dataset_path=dataset_path, config=config)
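The restored loop runs every remaining fold instead of a single one, and the checkpoint lookup now derives the TensorBoard version name inside get_ckpt_path. The arithmetic is easiest to see with concrete numbers (illustration only, values not taken from the patch):

    # With version_nth=8, kth_fold_start=2 and k_fold=5, the loop resumes:
    #   fold 2 -> ./logs/default/version_10/checkpoints/*.ckpt
    #   fold 3 -> ./logs/default/version_11/checkpoints/*.ckpt
    #   fold 4 -> ./logs/default/version_12/checkpoints/*.ckpt
    for kth_fold in range(2, 5):
        print(f'fold {kth_fold} -> version_{8 + kth_fold}')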
@@ -99,8 +99,8 @@ def main(stage,
 if __name__ == "__main__":
-    main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5
+    main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5,
          # gpus=1,
          # version_nth=8,  # version number of the first version of this group of folds
-         # kth_fold_start=0  # to resume training, specify the version to resume and its fold index within k_fold
+         kth_fold_start=4,
          )
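Note that this call site sets kth_fold_start=4 while version_nth stays commented out; since version_nth presumably falls back to a None default, get_ckpt_path returns None, so only the final fold is trained, and it is trained from scratch rather than reloaded from a checkpoint.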
@@ -0,0 +1,51 @@
+import math
+
+import torch.nn as nn
+
+from network_module.activation import jdlu, JDLU
+
+
+class MLPLayer(nn.Module):
+    def __init__(self, dim_in, dim_out, res_coef=0.0, dropout_p=0.1):
+        super().__init__()
+        self.linear = nn.Linear(dim_in, dim_out)
+        self.res_coef = res_coef
+        self.activation = JDLU(dim_out)
+        self.dropout = nn.Dropout(dropout_p)
+        self.ln = nn.LayerNorm(dim_out)
+
+    def forward(self, x):
+        y = self.linear(x)
+        y = self.activation(y)
+        y = self.dropout(y)
+        if self.res_coef == 0:
+            return y
+        else:
+            return self.res_coef * x + y
+
+
+class MLP_JDLU(nn.Module):
+    def __init__(self, dim_in, dim, res_coef=0.5, dropout_p=0.1, n_layers=10):
+        super().__init__()
+        self.mlp = nn.ModuleList()
+        self.first_linear = MLPLayer(dim_in, dim)
+        self.n_layers = n_layers
+        for i in range(n_layers):
+            self.mlp.append(MLPLayer(dim, dim, res_coef, dropout_p))
+        self.final = nn.Linear(dim, 1)
+        self.apply(self.weight_init)
+
+    def forward(self, x):
+        x = self.first_linear(x)
+        for layer in self.mlp:
+            x = layer(x)
+        x = self.final(x)
+        return x.squeeze()
+
+    @staticmethod
+    def weight_init(m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_normal_(m.weight)
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(m.bias, -bound, bound)
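A quick smoke test of the new module; the layer sizes below are arbitrary, and JDLU is the project's own activation from network_module.activation:

    # Sketch only, not part of the patch: exercising MLP_JDLU end to end.
    import torch
    from network.MLP_JDLU import MLP_JDLU

    net = MLP_JDLU(dim_in=8, dim=64, res_coef=0.5, dropout_p=0.1, n_layers=2)
    x = torch.randn(32, 8)  # batch of 32 samples, 8 features each
    y = net(x)              # final nn.Linear(dim, 1) followed by squeeze()
    print(y.shape)          # -> torch.Size([32])

In weight_init, the weights are drawn from a Xavier normal distribution while the bias bound 1/sqrt(fan_in) matches nn.Linear's own default bias initialization; note also that self.ln is constructed in MLPLayer but never applied in forward.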
@@ -5,7 +5,9 @@ import pytorch_lightning as pl
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT
 from torch import nn
 import torch
-from network.MLP import MLP
+from network.MLP_JDLU import MLP_JDLU
+from network.MLP_ReLU import MLP_ReLU


 class TrainModule(pl.LightningModule):
@@ -13,8 +15,12 @@ class TrainModule(pl.LightningModule):
         super().__init__()
         self.time_sum = None
         self.config = config
-        self.net = MLP(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'], config['n_layers'])
+        # TODO: change the network weight initialization to Kaiming or Xavier
+        if 1:
+            self.net = MLP_ReLU(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'],
+                                config['n_layers'])
+        else:
+            self.net = MLP_JDLU(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'],
+                                config['n_layers'])
         self.loss = nn.MSELoss()

     def training_step(self, batch, batch_idx):
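MLP_ReLU is imported but its source is not part of this patch; presumably it mirrors MLP_JDLU with plain ReLU activations. A hypothetical sketch under that assumption:

    # Hypothetical: network/MLP_ReLU.py is not shown in the patch; this sketch
    # assumes it mirrors MLP_JDLU with nn.ReLU swapped in for JDLU.
    import torch.nn as nn

    class MLP_ReLU(nn.Module):
        def __init__(self, dim_in, dim, res_coef=0.5, dropout_p=0.1, n_layers=10):
            super().__init__()
            self.first = nn.Sequential(nn.Linear(dim_in, dim), nn.ReLU())
            self.mlp = nn.ModuleList(
                nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Dropout(dropout_p))
                for _ in range(n_layers))
            self.res_coef = res_coef
            self.final = nn.Linear(dim, 1)

        def forward(self, x):
            x = self.first(x)
            for layer in self.mlp:
                x = self.res_coef * x + layer(x)  # same residual rule as MLPLayer
            return self.final(x).squeeze()

The if 1: switch hard-selects MLP_ReLU; flipping it to if 0: (or, more cleanly, a config flag) selects MLP_JDLU instead.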
@@ -95,10 +95,11 @@ def visual_label(dataset_path, n_classes):
                  quality=95)


-def get_ckpt_path(version_name: string):
-    if version_name is None:
+def get_ckpt_path(version_nth: int, kth_fold: int):
+    if version_nth is None:
         return None
     else:
+        version_name = f'version_{version_nth + kth_fold}'
         checkpoints_path = './logs/default/' + version_name + '/checkpoints'
         ckpt_path = glob.glob(checkpoints_path + '/*.ckpt')
         return ckpt_path[0].replace('\\', '/')
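With the new signature the caller passes the raw version number and fold index, and the helper assembles the version directory itself:

    # Illustration only: resolving the checkpoint of fold 2 when the first
    # fold of the run was logged as version_8.
    ckpt = get_ckpt_path(version_nth=8, kth_fold=2)
    # searches ./logs/default/version_10/checkpoints/*.ckpt and returns the
    # first match with forward slashes; returns None when version_nth is None

Note that ckpt_path[0] will raise an IndexError if the checkpoints directory exists but contains no .ckpt file.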