From ea7bd5d762f2adcf4e21c676882dfbb32cab7f3a Mon Sep 17 00:00:00 2001
From: shenyan <23357320@qq.com>
Date: Thu, 14 Oct 2021 15:36:01 +0800
Subject: [PATCH] Fix the bug in the function that derives the ckpt path from
 the version number and the current fold; add a network based on the JDLU
 activation function; change the network initialization method
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py                         | 14 ++++-----
 network/MLP_JDLU.py             | 51 +++++++++++++++++++++++++++++++++
 network/{MLP.py => MLP_ReLU.py} |  0
 train_model.py                  | 12 ++++++--
 utils.py                        |  5 ++--
 5 files changed, 70 insertions(+), 12 deletions(-)
 create mode 100644 network/MLP_JDLU.py
 rename network/{MLP.py => MLP_ReLU.py} (100%)

diff --git a/main.py b/main.py
index 474fc0e..de437ca 100644
--- a/main.py
+++ b/main.py
@@ -37,9 +37,10 @@ def main(stage,
     :param tpu_cores:
     :param version_nth: version number of the first version of these folds
     :param path_final_save:
-    :param every_n_epochs:
+    :param every_n_epochs: save a checkpoint every n epochs
     :param save_top_k:
-    :param kth_fold_start: which fold to start from; when resuming training, kth_fold_start is the fold being resumed, and the first value is 0
+    :param kth_fold_start: which fold to start from; when resuming training, kth_fold_start is the fold being resumed, and the first value is 0.
+        In non-resumed training, this value can also be adjusted to control how many folds are run
     :param k_fold:
     """
     # Frequently changed parameters are passed in as arguments of main
@@ -59,9 +60,8 @@
               'dropout_p': 0.1,
               'n_layers': 2,
               'dataset_len': 100000}
-    # for kth_fold in range(kth_fold_start, k_fold):
-    for kth_fold in range(kth_fold_start, kth_fold_start+1):
-        load_checkpoint_path = get_ckpt_path(f'version_{version_nth+kth_fold}')
+    for kth_fold in range(kth_fold_start, k_fold):
+        load_checkpoint_path = get_ckpt_path(version_nth, kth_fold)
         logger = pl_loggers.TensorBoardLogger('logs/')
         dm = DataModule(batch_size=batch_size, num_workers=num_workers, k_fold=k_fold, kth_fold=kth_fold,
                         dataset_path=dataset_path, config=config)
@@ -99,8 +99,8 @@ if __name__ == "__main__":
-    main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5
+    main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5,
          # gpus=1,
          # version_nth=8,  # version number of the first version of these folds
-         # kth_fold_start=0  # to resume training, specify the version to reload and its fold index within k_fold
+         kth_fold_start=4,
          )
diff --git a/network/MLP_JDLU.py b/network/MLP_JDLU.py
new file mode 100644
index 0000000..fe4baa7
--- /dev/null
+++ b/network/MLP_JDLU.py
@@ -0,0 +1,51 @@
+import math
+
+import torch.nn as nn
+from network_module.activation import jdlu, JDLU
+
+
+class MLPLayer(nn.Module):
+    def __init__(self, dim_in, dim_out, res_coef=0.0, dropout_p=0.1):
+        super().__init__()
+        self.linear = nn.Linear(dim_in, dim_out)
+        self.res_coef = res_coef
+        self.activation = JDLU(dim_out)
+        self.dropout = nn.Dropout(dropout_p)
+        self.ln = nn.LayerNorm(dim_out)
+
+    def forward(self, x):
+        y = self.linear(x)
+        y = self.activation(y)
+        y = self.dropout(y)
+        if self.res_coef == 0:
+            return y
+        else:
+            return self.res_coef * x + y
+
+
+class MLP_JDLU(nn.Module):
+    def __init__(self, dim_in, dim, res_coef=0.5, dropout_p=0.1, n_layers=10):
+        super().__init__()
+        self.mlp = nn.ModuleList()
+        self.first_linear = MLPLayer(dim_in, dim)
+        self.n_layers = n_layers
+        for i in range(n_layers):
+            self.mlp.append(MLPLayer(dim, dim, res_coef, dropout_p))
+        self.final = nn.Linear(dim, 1)
+        self.apply(self.weight_init)
+
+    def forward(self, x):
+        x = self.first_linear(x)
+        for layer in self.mlp:
+            x = layer(x)
+        x = self.final(x)
+        return x.squeeze()
+
+    @staticmethod
+    def weight_init(m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_normal_(m.weight)
+
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(m.bias, -bound, bound)
diff --git a/network/MLP.py b/network/MLP_ReLU.py
similarity index 100%
rename from network/MLP.py
rename to network/MLP_ReLU.py
diff --git a/train_model.py b/train_model.py
index aba7aef..40e2116 100644
--- a/train_model.py
+++ b/train_model.py
@@ -5,7 +5,9 @@
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT
 from torch import nn
 import torch
-from network.MLP import MLP
+
+from network.MLP_JDLU import MLP_JDLU
+from network.MLP_ReLU import MLP_ReLU
@@ -13,8 +15,12 @@ class TrainModule(pl.LightningModule):
         super().__init__()
         self.time_sum = None
         self.config = config
-        self.net = MLP(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'], config['n_layers'])
-        # TODO change the network initialization to the Kaiming or Xavier distribution
+        if 1:
+            self.net = MLP_ReLU(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'],
+                                config['n_layers'])
+        else:
+            self.net = MLP_JDLU(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'],
+                                config['n_layers'])
         self.loss = nn.MSELoss()

     def training_step(self, batch, batch_idx):
diff --git a/utils.py b/utils.py
index 74d6b7a..8e43915 100644
--- a/utils.py
+++ b/utils.py
@@ -95,10 +95,11 @@ def visual_label(dataset_path, n_classes):
                  quality=95)


-def get_ckpt_path(version_name: string):
-    if version_name is None:
+def get_ckpt_path(version_nth: int, kth_fold: int):
+    if version_nth is None:
         return None
     else:
+        version_name = f'version_{version_nth + kth_fold}'
         checkpoints_path = './logs/default/' + version_name + '/checkpoints'
         ckpt_path = glob.glob(checkpoints_path + '/*.ckpt')
         return ckpt_path[0].replace('\\', '/')
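
Note on the checkpoint-path fix: before this patch, main.py built the version string itself and passed f'version_{version_nth+kth_fold}' to get_ckpt_path, while the fold loop had been pinned to a single iteration. The patch restores the loop over all folds and moves the version-name construction into get_ckpt_path. The sketch below is a simplified, standalone illustration of how the revised function resolves a checkpoint per fold; the ./logs/default/version_N/checkpoints layout comes from utils.py, the demo values for k_fold, kth_fold_start, and version_nth are invented for the example, and the empty-glob fallback to None is a demo-only addition that the patched code does not include.

import glob


def get_ckpt_path(version_nth, kth_fold):
    # Mirrors the patched utils.get_ckpt_path: None means "start this fold from scratch".
    if version_nth is None:
        return None
    # Fold k of a run whose first fold was logged as version_nth lives in version_{version_nth + k}.
    version_name = f'version_{version_nth + kth_fold}'
    checkpoints_path = './logs/default/' + version_name + '/checkpoints'
    ckpt_path = glob.glob(checkpoints_path + '/*.ckpt')
    # Demo-only fallback: the patched code assumes a checkpoint exists and indexes [0] directly.
    return ckpt_path[0].replace('\\', '/') if ckpt_path else None


# Restored k-fold loop from main.py, reduced to the path lookup (demo values):
k_fold, kth_fold_start, version_nth = 5, 0, 8
for kth_fold in range(kth_fold_start, k_fold):
    load_checkpoint_path = get_ckpt_path(version_nth, kth_fold)
    print(kth_fold, load_checkpoint_path)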