
Fix a bug in the function that derives the ckpt path from the version number and the current fold number; add a network based on the JDLU activation function; change the network initialization method

commit ea7bd5d762 by shenyan, 4 years ago (branch: master)
5 changed files with 70 additions and 12 deletions:
  1. main.py              +7   -7
  2. network/MLP_JDLU.py  +51  -0
  3. network/MLP_ReLU.py  +0   -0
  4. train_model.py       +9   -3
  5. utils.py             +3   -2

main.py (+7, -7)

@@ -37,9 +37,10 @@ def main(stage,
     :param tpu_cores:
     :param version_nth: version number of the first version of this group of folds
     :param path_final_save:
-    :param every_n_epochs:
+    :param every_n_epochs: save a checkpoint every n epochs
     :param save_top_k:
-    :param kth_fold_start: which fold to start from; when resuming training, kth_fold_start is the index of the fold being reloaded; the first value is 0
+    :param kth_fold_start: which fold to start from; when resuming training, kth_fold_start is the index of the fold being reloaded; the first value is 0.
+        When not resuming, this value can be adjusted to control how many folds are trained
     :param k_fold:
     """
     # Frequently changed parameters are passed in as arguments to main
@@ -59,9 +60,8 @@ def main(stage,
               'dropout_p': 0.1,
               'n_layers': 2,
               'dataset_len': 100000}
-    # for kth_fold in range(kth_fold_start, k_fold):
-    for kth_fold in range(kth_fold_start, kth_fold_start+1):
-        load_checkpoint_path = get_ckpt_path(f'version_{version_nth+kth_fold}')
+    for kth_fold in range(kth_fold_start, k_fold):
+        load_checkpoint_path = get_ckpt_path(version_nth, kth_fold)
         logger = pl_loggers.TensorBoardLogger('logs/')
         dm = DataModule(batch_size=batch_size, num_workers=num_workers, k_fold=k_fold, kth_fold=kth_fold,
                         dataset_path=dataset_path, config=config)
@@ -99,8 +99,8 @@


 if __name__ == "__main__":
-    main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5
+    main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5,
          # gpus=1,
          # version_nth=8,  # version number of the first version of this group of folds
          # kth_fold_start=0  # to resume training, specify the version to reload and its fold index within k_fold
+         kth_fold_start=4,
          )
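
With the loop fix above, one run now trains folds kth_fold_start through k_fold-1 instead of always stopping after a single fold. A minimal sketch of the resulting iteration, using the values from __main__ (k_fold=5, kth_fold_start=4):

    # Sketch of the fixed fold loop; with these values only fold 4 is trained,
    # but smaller kth_fold_start values now cover every remaining fold.
    k_fold = 5
    kth_fold_start = 4
    for kth_fold in range(kth_fold_start, k_fold):
        print(f'training fold {kth_fold}')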

network/MLP_JDLU.py (+51, -0)

@@ -0,0 +1,51 @@
+import math
+
+import torch.nn as nn
+from network_module.activation import jdlu, JDLU
+
+
+class MLPLayer(nn.Module):
+    def __init__(self, dim_in, dim_out, res_coef=0.0, dropout_p=0.1):
+        super().__init__()
+        self.linear = nn.Linear(dim_in, dim_out)
+        self.res_coef = res_coef
+        self.activation = JDLU(dim_out)
+        self.dropout = nn.Dropout(dropout_p)
+        self.ln = nn.LayerNorm(dim_out)
+
+    def forward(self, x):
+        y = self.linear(x)
+        y = self.activation(y)
+        y = self.dropout(y)
+        if self.res_coef == 0:
+            return y
+        else:
+            return self.res_coef * x + y
+
+
+class MLP_JDLU(nn.Module):
+    def __init__(self, dim_in, dim, res_coef=0.5, dropout_p=0.1, n_layers=10):
+        super().__init__()
+        self.mlp = nn.ModuleList()
+        self.first_linear = MLPLayer(dim_in, dim)
+        self.n_layers = n_layers
+        for i in range(n_layers):
+            self.mlp.append(MLPLayer(dim, dim, res_coef, dropout_p))
+        self.final = nn.Linear(dim, 1)
+        self.apply(self.weight_init)
+
+    def forward(self, x):
+        x = self.first_linear(x)
+        for layer in self.mlp:
+            x = layer(x)
+        x = self.final(x)
+        return x.squeeze()
+
+    @staticmethod
+    def weight_init(m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_normal_(m.weight)
+
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(m.bias, -bound, bound)

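A quick usage sketch for the new network. It assumes network_module.activation (which provides jdlu/JDLU but is not part of this commit) is importable; the dimensions below are illustrative, not values from the repo:

    import torch

    from network.MLP_JDLU import MLP_JDLU

    # Illustrative hyperparameters matching the config keys used in main.py.
    net = MLP_JDLU(dim_in=12, dim=64, res_coef=0.5, dropout_p=0.1, n_layers=2)
    x = torch.randn(32, 12)  # batch of 32 samples, 12 features each
    y = net(x)               # final nn.Linear(dim, 1) plus squeeze()
    print(y.shape)           # torch.Size([32])

Note that MLPLayer defines self.ln but never applies it in forward(), so the LayerNorm is currently inactive.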
network/MLP.py → network/MLP_ReLU.py (renamed, +0, -0)


train_model.py (+9, -3)

@@ -5,7 +5,9 @@ import pytorch_lightning as pl
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT
 from torch import nn
 import torch
-from network.MLP import MLP
+
+from network.MLP_JDLU import MLP_JDLU
+from network.MLP_ReLU import MLP_ReLU


 class TrainModule(pl.LightningModule):
@@ -13,8 +15,12 @@ class TrainModule(pl.LightningModule):
         super().__init__()
         self.time_sum = None
         self.config = config
-        self.net = MLP(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'], config['n_layers'])
+        # TODO: change the network initialization to a Kaiming or Xavier distribution
+        if 1:
+            self.net = MLP_ReLU(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'],
+                                config['n_layers'])
+        else:
+            self.net = MLP_JDLU(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'],
+                                config['n_layers'])
         self.loss = nn.MSELoss()

     def training_step(self, batch, batch_idx):

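The if 1: switch above hard-codes the ReLU variant. A hedged alternative sketch that picks the class from the config dict instead; the 'net_name' key is hypothetical and not part of this commit:

    from network.MLP_JDLU import MLP_JDLU
    from network.MLP_ReLU import MLP_ReLU

    # Map a config key to a network class instead of editing an if 1: switch.
    # The 'net_name' key is an assumption for illustration only.
    NETS = {'relu': MLP_ReLU, 'jdlu': MLP_JDLU}

    def build_net(config):
        net_cls = NETS[config.get('net_name', 'relu')]
        return net_cls(config['dim_in'], config['dim'], config['res_coef'],
                       config['dropout_p'], config['n_layers'])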

utils.py (+3, -2)

@@ -95,10 +95,11 @@ def visual_label(dataset_path, n_classes):
                      quality=95)


-def get_ckpt_path(version_name: string):
-    if version_name is None:
+def get_ckpt_path(version_nth: int, kth_fold: int):
+    if version_nth is None:
         return None
     else:
+        version_name = f'version_{version_nth + kth_fold}'
         checkpoints_path = './logs/default/' + version_name + '/checkpoints'
         ckpt_path = glob.glob(checkpoints_path + '/*.ckpt')
         return ckpt_path[0].replace('\\', '/')
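
This is the ckpt-path bug fix from the commit message: the function now builds the version directory from version_nth and kth_fold itself, so resuming fold k looks under version_{version_nth + k}. An illustrative call (the version number and the returned filename are hypothetical):

    # With version_nth=8 and kth_fold=2 the function globs
    # ./logs/default/version_10/checkpoints/*.ckpt and returns the first
    # match with backslashes normalized to forward slashes.
    ckpt = get_ckpt_path(8, 2)
    print(ckpt)  # e.g. './logs/default/version_10/checkpoints/epoch=1.ckpt'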

