| @@ -37,9 +37,10 @@ def main(stage, | |||
| :param tpu_cores: | |||
| :param version_nth: 该folds的第一个版本的版本号 | |||
| :param path_final_save: | |||
| :param every_n_epochs: | |||
| :param every_n_epochs: 每n个epoch设置一个检查点 | |||
| :param save_top_k: | |||
| :param kth_fold_start: 从第几个fold开始, 若使用重载训练, 则kth_fold_start为重载第几个fold, 第一个值为0 | |||
| :param kth_fold_start: 从第几个fold开始, 若使用重载训练, 则kth_fold_start为重载第几个fold, 第一个值为0. | |||
| 非重载训练的情况下, 可以通过调整该值控制训练的次数 | |||
| :param k_fold: | |||
| """ | |||
| # 经常改动的 参数 作为main的输入参数 | |||
| @@ -59,9 +60,8 @@ def main(stage, | |||
| 'dropout_p': 0.1, | |||
| 'n_layers': 2, | |||
| 'dataset_len': 100000} | |||
| # for kth_fold in range(kth_fold_start, k_fold): | |||
| for kth_fold in range(kth_fold_start, kth_fold_start+1): | |||
| load_checkpoint_path = get_ckpt_path(f'version_{version_nth+kth_fold}') | |||
| for kth_fold in range(kth_fold_start, k_fold): | |||
| load_checkpoint_path = get_ckpt_path(version_nth, kth_fold) | |||
| logger = pl_loggers.TensorBoardLogger('logs/') | |||
| dm = DataModule(batch_size=batch_size, num_workers=num_workers, k_fold=k_fold, kth_fold=kth_fold, | |||
| dataset_path=dataset_path, config=config) | |||
| @@ -99,8 +99,8 @@ def main(stage, | |||
| if __name__ == "__main__": | |||
| main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5 | |||
| main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5, | |||
| # gpus=1, | |||
| # version_nth=8, # 该folds的第一个版本的版本号 | |||
| # kth_fold_start=0 # 如果需要重载训练, 则指定重载的版本和其位于k_fold的fold数 | |||
| kth_fold_start=4, | |||
| ) | |||
| @@ -0,0 +1,51 @@ | |||
| import math | |||
| import torch.nn as nn | |||
| from network_module.activation import jdlu, JDLU | |||
| class MLPLayer(nn.Module): | |||
| def __init__(self, dim_in, dim_out, res_coef=0.0, dropout_p=0.1): | |||
| super().__init__() | |||
| self.linear = nn.Linear(dim_in, dim_out) | |||
| self.res_coef = res_coef | |||
| self.activation = JDLU(dim_out) | |||
| self.dropout = nn.Dropout(dropout_p) | |||
| self.ln = nn.LayerNorm(dim_out) | |||
| def forward(self, x): | |||
| y = self.linear(x) | |||
| y = self.activation(y) | |||
| y = self.dropout(y) | |||
| if self.res_coef == 0: | |||
| return y | |||
| else: | |||
| return self.res_coef * x + y | |||
| class MLP_JDLU(nn.Module): | |||
| def __init__(self, dim_in, dim, res_coef=0.5, dropout_p=0.1, n_layers=10): | |||
| super().__init__() | |||
| self.mlp = nn.ModuleList() | |||
| self.first_linear = MLPLayer(dim_in, dim) | |||
| self.n_layers = n_layers | |||
| for i in range(n_layers): | |||
| self.mlp.append(MLPLayer(dim, dim, res_coef, dropout_p)) | |||
| self.final = nn.Linear(dim, 1) | |||
| self.apply(self.weight_init) | |||
| def forward(self, x): | |||
| x = self.first_linear(x) | |||
| for layer in self.mlp: | |||
| x = layer(x) | |||
| x = self.final(x) | |||
| return x.squeeze() | |||
| @staticmethod | |||
| def weight_init(m): | |||
| if isinstance(m, nn.Linear): | |||
| nn.init.xavier_normal_(m.weight) | |||
| fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) | |||
| bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 | |||
| nn.init.uniform_(m.bias, -bound, bound) | |||
| @@ -5,7 +5,9 @@ import pytorch_lightning as pl | |||
| from pytorch_lightning.utilities.types import EPOCH_OUTPUT | |||
| from torch import nn | |||
| import torch | |||
| from network.MLP import MLP | |||
| from network.MLP_JDLU import MLP_JDLU | |||
| from network.MLP_ReLU import MLP_ReLU | |||
| class TrainModule(pl.LightningModule): | |||
| @@ -13,8 +15,12 @@ class TrainModule(pl.LightningModule): | |||
| super().__init__() | |||
| self.time_sum = None | |||
| self.config = config | |||
| self.net = MLP(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'], config['n_layers']) | |||
| # TODO 修改网络初始化方式为kaiming分布或者xavier分布 | |||
| if 1: | |||
| self.net = MLP_ReLU(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'], | |||
| config['n_layers']) | |||
| else: | |||
| self.net = MLP_JDLU(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'], | |||
| config['n_layers']) | |||
| self.loss = nn.MSELoss() | |||
| def training_step(self, batch, batch_idx): | |||
| @@ -95,10 +95,11 @@ def visual_label(dataset_path, n_classes): | |||
| quality=95) | |||
| def get_ckpt_path(version_name: string): | |||
| if version_name is None: | |||
| def get_ckpt_path(version_nth: int, kth_fold: int): | |||
| if version_nth is None: | |||
| return None | |||
| else: | |||
| version_name = f'version_{version_nth + kth_fold}' | |||
| checkpoints_path = './logs/default/' + version_name + '/checkpoints' | |||
| ckpt_path = glob.glob(checkpoints_path + '/*.ckpt') | |||
| return ckpt_path[0].replace('\\', '/') | |||