
Fixed the spelling error in evaluate; use the custom activation function jdlu; added a function that resolves a version name to that version's ckpt; changed the save behavior when a final save path is given; implemented k-fold cross-validation; added Excel read/write helper functions

master
shenyan 4 years ago
parent
commit
d7bfcf0830
6 changed files with 156 additions and 206 deletions
  1. +42
    -28
      data_module.py
  2. +0
    -0
      evaluate.py
  3. +62
    -61
      main.py
  4. +5
    -3
      network/MLP.py
  5. +5
    -10
      save_checkpoint.py
  6. +42
    -104
      utils.py

+ 42
- 28
data_module.py View File

@@ -1,45 +1,59 @@
import os
import numpy
import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


class DataModule(pl.LightningDataModule):
def __init__(self, batch_size, num_workers, config=None):
def __init__(self, batch_size, num_workers, k_fold, kth_fold, dataset_path, config=None):
super().__init__()
# TODO use k-fold cross-validation
# divide_dataset(config['dataset_path'], [0.8, 0, 0.2])
if config['flag']:
self.x = torch.randn(config['dataset_len'], 2)
noise = torch.randn(config['dataset_len'], )
self.y = 1.0 * self.x[:, 0] + 2.0 * self.x[:, 1] + noise
else:
x_1 = torch.randn(config['dataset_len'])
x_2 = torch.randn(config['dataset_len'])
x_useful = torch.cos(1.5 * x_1) * (x_2 ** 2)
x_1_rest_small = torch.randn(config['dataset_len'], 15) + 0.01 * x_1.unsqueeze(1)
x_1_rest_large = torch.randn(config['dataset_len'], 15) + 0.1 * x_1.unsqueeze(1)
x_2_rest_small = torch.randn(config['dataset_len'], 15) + 0.01 * x_2.unsqueeze(1)
x_2_rest_large = torch.randn(config['dataset_len'], 15) + 0.1 * x_2.unsqueeze(1)
self.x = torch.cat([x_1[:, None], x_2[:, None], x_1_rest_small, x_1_rest_large, x_2_rest_small, x_2_rest_large],
dim=1)
self.y = (10 * x_useful) + 5 * torch.randn(config['dataset_len'])

self.y_train, self.y_test = self.y[:50000], self.y[50000:]
self.x_train, self.x_test = self.x[:50000, :], self.x[50000:, :]

self.batch_size = batch_size
self.num_workers = num_workers
self.config = config
self.train_dataset = None
self.val_dataset = None
self.test_dataset = None
self.k_fold = k_fold
self.kth_fold = kth_fold
self.dataset_path = dataset_path

def setup(self, stage=None) -> None:
# Get the list of all the data
# dataset_list = get_dataset_list(dataset_path)
x, y = self.get_fit_dataset_list()
if stage == 'fit' or stage is None:
self.train_dataset = CustomDataset(self.x_train, self.y_train, self.config)
self.val_dataset = CustomDataset(self.x_test, self.y_test, self.config)
x_train, y_train, x_val, y_val = self.get_dataset_lists(x, y)
self.train_dataset = CustomDataset(x_train, y_train, self.config)
self.val_dataset = CustomDataset(x_val, y_val, self.config)
if stage == 'test' or stage is None:
self.test_dataset = CustomDataset(self.x, self.y, self.config)
self.test_dataset = CustomDataset(x, y, self.config)

def get_fit_dataset_list(self):
if not os.path.exists(self.dataset_path + '/dataset_list.txt'):
x = torch.randn(self.config['dataset_len'], self.config['dim_in'])
noise = torch.randn(self.config['dataset_len'])
y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + noise
with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f:
for line in range(self.config['dataset_len']):
f.write(' '.join([str(temp) for temp in x[line].tolist()]) + ' ' + str(y[line].item()) + '\n')
print('Generated a new dataset list')
else:
dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines()
dataset_list = [[float(temp) for temp in item.strip('\n').split(' ')] for item in dataset_list]
x = torch.from_numpy(numpy.array(dataset_list)[:, 0:self.config['dim_in']]).float()
y = torch.from_numpy(numpy.array(dataset_list)[:, self.config['dim_in']]).float()
return x, y

def get_dataset_lists(self, x: Tensor, y):
# Number of samples in one fold, and the number of leftover samples that do not fill a fold
num_1fold, remainder = divmod(self.config['dataset_len'], self.k_fold)
# Split all the data into the training set and the validation set for this fold
x_val = x[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)]
y_val = y[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)]
temp = torch.ones(x.shape[0])
temp[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)] = 0
x_train = x[temp == 1]
y_train = y[temp == 1]
return x_train, y_train, x_val, y_val

def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers,
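The fold split above is plain index arithmetic: get_fit_dataset_list caches the synthetic data in dataset_path/dataset_list.txt (one sample per line, dim_in space-separated feature values followed by the target), and get_dataset_lists carves the kth fold out of it with divmod, appending the leftover samples to the validation slice. A minimal standalone sketch of that arithmetic, with dataset_len=10, k_fold=3 and kth_fold=1 as made-up illustrative values:

import torch

dataset_len, k_fold, kth_fold = 10, 3, 1            # illustrative values only
x = torch.arange(dataset_len)

num_1fold, remainder = divmod(dataset_len, k_fold)  # 3 samples per fold, 1 left over
lo = num_1fold * kth_fold
hi = num_1fold * (kth_fold + 1) + remainder
mask = torch.ones(dataset_len)
mask[lo:hi] = 0                                     # zero out the validation slice
x_val, x_train = x[lo:hi], x[mask == 1]
print(x_val.tolist(), x_train.tolist())             # [3, 4, 5, 6] [0, 1, 2, 7, 8, 9]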


evalute.py → evaluate.py View File


+ 62
- 61
main.py View File

@@ -5,20 +5,23 @@ import pytorch_lightning as pl
from train_model import TrainModule
from multiprocessing import cpu_count

from utils import get_ckpt_path


def main(stage,
max_epochs,
batch_size,
precision,
seed,
dataset_path=None,
dataset_path,
gpus=None,
tpu_cores=None,
load_checkpoint_path=None,
save_name=None,
version_nth=None,
path_final_save=None,
every_n_epochs=1,
save_top_k=1,):
save_top_k=1,
k_fold=5,
kth_fold_start=0):
"""
Entry point of the framework. Covers setting hyperparameters, splitting the dataset, and choosing between training and testing.
The parameters of this function are the ones that change frequently during experiments
@@ -32,74 +35,72 @@ def main(stage,
:param dataset_path: path to the dataset; its directory contains the data, the labels, and the name list of all samples
:param gpus:
:param tpu_cores:
:param load_checkpoint_path:
:param save_name:
:param version_nth: version number of the first version of this group of folds
:param path_final_save:
:param every_n_epochs:
:param save_top_k:
:param kth_fold_start: which fold to start from; when resuming training, this is the fold that is reloaded; the first fold is 0
:param k_fold: number of folds for cross-validation
"""
# Frequently changed parameters are passed as arguments of main
# Rarely changed, non-generic parameters are stored in config
# Rarely changed, generic parameters are declared directly
# Generic parameters are those shared by all networks, such as time_sum
if True:
config = {'dataset_path': dataset_path,
'dim_in': 2,
'dim': 10,
'res_coef': 0.5,
'dropout_p': 0.1,
'n_layers': 2,
'dataset_len': 100000,
'flag': True}
else:
config = {'dataset_path': dataset_path,
'dim_in': 62,
'dim': 32,
'res_coef': 0.5,
'dropout_p': 0.1,
'n_layers': 20,
'dataset_len': 100000,
'flag': False}

# Preprocess the input arguments
precision = 32 if (gpus is None and tpu_cores is None) else precision
# Get the generic parameters
# TODO find the optimal batch size
num_workers = cpu_count()
precision = 32 if (gpus is None and tpu_cores is None) else precision
dm = DataModule(batch_size=batch_size, num_workers=num_workers, config=config)
logger = pl_loggers.TensorBoardLogger('logs/')
if stage == 'fit':
# SaveCheckpoint must be created before TrainModule to keep the network parameter initialization deterministic
save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs,
save_name=save_name, path_final_save=path_final_save,
every_n_epochs=every_n_epochs, verbose=True,
monitor='Validation loss', save_top_k=save_top_k,
mode='min')
training_module = TrainModule(config=config)
if load_checkpoint_path is None:
print('Starting initial training')
trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
logger=logger, precision=precision, callbacks=[save_checkpoint])
training_module.load_pretrain_parameters()
else:
print('Resuming training')
trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
resume_from_checkpoint='./logs/default' + load_checkpoint_path,
logger=logger, precision=precision, callbacks=[save_checkpoint])
print('Keep an eye on GPU utilization during training')
trainer.fit(training_module, datamodule=dm)
if stage == 'test':
if load_checkpoint_path is None:
print('No checkpoint loaded, cannot run the test')
else:
print('Running test')
training_module = TrainModule.load_from_checkpoint(
checkpoint_path='./logs/default' + load_checkpoint_path,
**{'config': config})
trainer = pl.Trainer(gpus=gpus, tpu_cores=tpu_cores, logger=logger, precision=precision)
trainer.test(training_module, datamodule=dm)
# Run `tensorboard --logdir logs` in the terminal to view the results; in Jupyter, prefix the command with %
# Get the non-generic parameters
config = {'dim_in': 2,
'dim': 10,
'res_coef': 0.5,
'dropout_p': 0.1,
'n_layers': 2,
'dataset_len': 100000}
# for kth_fold in range(kth_fold_start, k_fold):
for kth_fold in range(kth_fold_start, kth_fold_start+1):
# guard: version_nth is None when there is no earlier run to resume, so skip the checkpoint lookup
load_checkpoint_path = get_ckpt_path(f'version_{version_nth + kth_fold}') if version_nth is not None else None
logger = pl_loggers.TensorBoardLogger('logs/')
dm = DataModule(batch_size=batch_size, num_workers=num_workers, k_fold=k_fold, kth_fold=kth_fold,
dataset_path=dataset_path, config=config)
if stage == 'fit':
# SaveCheckpoint must be created before TrainModule to keep the network parameter initialization deterministic
save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs,
path_final_save=path_final_save,
every_n_epochs=every_n_epochs, verbose=True,
monitor='Validation loss', save_top_k=save_top_k,
mode='min')
training_module = TrainModule(config=config)
if kth_fold != kth_fold_start or load_checkpoint_path is None:
print('Starting initial training')
trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
logger=logger, precision=precision, callbacks=[save_checkpoint])
training_module.load_pretrain_parameters()
else:
print('Resuming training')
trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
resume_from_checkpoint='./logs/default' + load_checkpoint_path,
logger=logger, precision=precision, callbacks=[save_checkpoint])
print('Keep an eye on GPU utilization during training')
trainer.fit(training_module, datamodule=dm)
if stage == 'test':
if load_checkpoint_path is None:
print('No checkpoint loaded, cannot run the test')
else:
print('Running test')
training_module = TrainModule.load_from_checkpoint(
checkpoint_path='./logs/default' + load_checkpoint_path,
**{'config': config})
trainer = pl.Trainer(gpus=gpus, tpu_cores=tpu_cores, logger=logger, precision=precision)
trainer.test(training_module, datamodule=dm)
# Run `tensorboard --logdir logs` in the terminal to view the results; in Jupyter, prefix the command with %


if __name__ == "__main__":
main('fit', max_epochs=5, batch_size=32, precision=16, seed=1234,
main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5,
# gpus=1,
# load_checkpoint_path='/version_4/checkpoints/epoch=4-step=7814.ckpt',
# version_nth=8,  # version number of the first version of this group of folds
# kth_fold_start=0  # when resuming training, specify the reloaded version and which fold of the k folds it is
)
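For orientation only (not part of the commit): each fold in the loop above gets its own TensorBoard version, so when resuming, version_nth + kth_fold names the version whose checkpoint get_ckpt_path looks up, and only the fold equal to kth_fold_start is resumed; every later fold trains from scratch. A tiny sketch of that decision, with version_nth=8, k_fold=5 and kth_fold_start=2 as example values:

version_nth, k_fold, kth_fold_start = 8, 5, 2    # example values only
for kth_fold in range(kth_fold_start, k_fold):
    ckpt_version = f'version_{version_nth + kth_fold}'
    resume = kth_fold == kth_fold_start          # later folds always start fresh
    print(ckpt_version, 'resume' if resume else 'train from scratch')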

+ 5
- 3
network/MLP.py View File

@@ -1,5 +1,5 @@
import torch.nn as nn
from torch.nn import Mish
from network_module.activation import jdlu, JDLU


class MLPLayer(nn.Module):
@@ -7,13 +7,15 @@ class MLPLayer(nn.Module):
super().__init__()
self.linear = nn.Linear(dim_in, dim_out)
self.res_coef = res_coef
self.activation = Mish()
self.activation = nn.ReLU()
self.activation1 = JDLU(dim_out)
self.dropout = nn.Dropout(dropout_p)
self.ln = nn.LayerNorm(dim_out)

def forward(self, x):
y = self.linear(x)
y = self.activation(y)
y = self.activation1(y)
# y = jdlu(y)
y = self.dropout(y)
if self.res_coef == 0:
return self.ln(y)
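jdlu and JDLU are defined in network_module/activation.py, which is not among the files changed by this commit, so their implementation is not shown here. Purely to document the interface MLPLayer relies on (a module constructed with the output dimension and applied elementwise), a hypothetical stand-in could look like the following; the real JDLU is very likely different:

import torch
import torch.nn as nn


class JDLUPlaceholder(nn.Module):
    # Stand-in with the same constructor signature as JDLU(dim_out); NOT the real activation.
    def __init__(self, dim):
        super().__init__()
        self.beta = nn.Parameter(torch.ones(dim))  # hypothetical per-feature parameter

    def forward(self, x):
        return self.beta * nn.functional.softplus(x)  # arbitrary smooth nonlinearity as a placeholder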


+ 5
- 10
save_checkpoint.py View File

@@ -1,10 +1,8 @@
import os

import numpy.random
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl
import shutil
import random
from pytorch_lightning.utilities import rank_zero_info
from utils import zip_dir

@@ -14,7 +12,6 @@ class SaveCheckpoint(ModelCheckpoint):
max_epochs,
seed=None,
every_n_epochs=None,
save_name=None,
path_final_save=None,
monitor=None,
save_top_k=None,
@@ -27,7 +24,6 @@ class SaveCheckpoint(ModelCheckpoint):
:param max_epochs:
:param seed:
:param every_n_epochs:
:param save_name:
:param path_final_save:
:param monitor:
:param save_top_k:
@@ -39,7 +35,6 @@ class SaveCheckpoint(ModelCheckpoint):
numpy.random.seed(seed)
self.seeds = numpy.random.randint(0, 2000, max_epochs)
pl.seed_everything(seed)
self.save_name = save_name
self.path_final_save = path_final_save
self.monitor = monitor
self.save_top_k = save_top_k
@@ -71,11 +66,11 @@ class SaveCheckpoint(ModelCheckpoint):

if self.check_monitor_top_k(trainer, current):
self._update_best_and_save(current, trainer, monitor_candidates)
if self.save_name is not None and self.path_final_save is not None:
zip_dir('./logs/default/' + self.save_name, './' + self.save_name + '.zip')
if os.path.exists(self.path_final_save + '/' + self.save_name + '.zip'):
os.remove(self.path_final_save + '/' + self.save_name + '.zip')
shutil.move('./' + self.save_name + '.zip', self.path_final_save)
if self.path_final_save is not None:
zip_dir('./logs', './logs.zip')
if os.path.exists(self.path_final_save + '/logs.zip'):
os.remove(self.path_final_save + '/logs.zip')
shutil.move('./logs.zip', self.path_final_save)
elif self.verbose:
epoch = monitor_candidates.get("epoch")
step = monitor_candidates.get("step")
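With save_name gone, the new branch zips the entire ./logs directory and moves the archive to path_final_save whenever a better checkpoint is saved. zip_dir lives in utils.py and its body is not shown in this diff; as a rough sketch of what such a helper usually amounts to (os.walk plus zipfile, relative paths preserved; the name zip_dir_sketch is illustrative, not the author's function):

import os
import zipfile


def zip_dir_sketch(dir_path, zip_path):
    # Recursively add every file under dir_path to the archive, keeping paths relative to dir_path.
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(dir_path):
            for name in files:
                full = os.path.join(root, name)
                zf.write(full, os.path.relpath(full, dir_path))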


+ 42
- 104
utils.py View File

@@ -2,120 +2,25 @@
import glob
import os
import random
import string
import zipfile
import cv2
import numpy
import torch


def divide_dataset(dataset_path, rate_datasets):
"""
Split the dataset into training, validation, and test sets, generate the list files and save them as:
train_dataset_list, validate_dataset_list, test_dataset_list.
Every ratio must be greater than 0 and each set must contain at least one sample; the validation ratio may be 0.

:param dataset_path: path to the dataset
:param rate_datasets: ratios [train, validation, test] of the different sets
"""
# When the overall all_dataset_list file does not exist, generate all_dataset_list
if not os.path.exists(dataset_path + '/all_dataset_list.txt'):
def get_dataset_list(dataset_path):
if not os.path.exists(dataset_path + '/dataset_list.txt'):
all_list = glob.glob(dataset_path + '/labels' + '/*.png')
with open(dataset_path + '/all_dataset_list.txt', 'w', encoding='utf-8') as f:
random.shuffle(all_list)
with open(dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f:
for line in all_list:
f.write(os.path.basename(line.replace('\\', '/')) + '\n')
path_train_dataset_list = dataset_path + '/train_dataset_list.txt'
path_validate_dataset_list = dataset_path + '/validate_dataset_list.txt'
path_test_dataset_list = dataset_path + '/test_dataset_list.txt'
# If the validation ratio is 0, use the test set as the validation set and drop the separate test set
if rate_datasets[1] == 0:
# If the split list files do not exist, generate new list files
if not (os.path.exists(path_train_dataset_list) and
os.path.exists(path_validate_dataset_list) and
os.path.exists(path_test_dataset_list)):
all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
random.shuffle(all_list)
train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):]
with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
for line in train_dataset_list:
f.write(line)
with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
print('Generated a new dataset list')
else:
# Check whether the ratios are still correct; if not, regenerate the dataset lists
all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
with open(path_train_dataset_list) as f:
train_dataset_list_exist = f.readlines()
with open(path_validate_dataset_list) as f:
test_dataset_list_exist = f.readlines()
random.shuffle(all_list)
train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):]
if not (len(train_dataset_list_exist) == len(train_dataset_list) and
len(test_dataset_list_exist) == len(test_dataset_list)):
with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
for line in train_dataset_list:
f.write(line)
with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
print('Generated a new dataset list')
# If the validation ratio is not zero, both a validation set and a test set exist
return all_list
else:
# If the split list files do not exist, generate new list files
if not (os.path.exists(dataset_path + '/train_dataset_list.txt') and
os.path.exists(dataset_path + '/validate_dataset_list.txt') and
os.path.exists(dataset_path + '/test_dataset_list.txt')):
all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
random.shuffle(all_list)
train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):
int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))]
test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):]
with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
for line in train_dataset_list:
f.write(line)
with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
for line in validate_dataset_list:
f.write(line)
with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
print('Generated a new dataset list')
else:
# Check whether the ratios are still correct; if not, regenerate the dataset lists
all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
with open(path_train_dataset_list) as f:
train_dataset_list_exist = f.readlines()
with open(path_validate_dataset_list) as f:
validate_dataset_list_exist = f.readlines()
with open(path_test_dataset_list) as f:
test_dataset_list_exist = f.readlines()
random.shuffle(all_list)
train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):
int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))]
test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):]
if not (len(train_dataset_list_exist) == len(train_dataset_list) and
len(validate_dataset_list_exist) == len(validate_dataset_list) and
len(test_dataset_list_exist) == len(test_dataset_list)):
with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
for line in train_dataset_list:
f.write(line)
with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
for line in validate_dataset_list:
f.write(line)
with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
print('Generated a new dataset list')
all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
return all_list


def zip_dir(dir_path, zip_path):
@@ -142,6 +47,7 @@ def ncolors(num_colors):
:param num_colors: number of colors
:return:
"""

def get_n_hls_colors(num):
import random
hls_colors = []
@@ -187,3 +93,35 @@ def visual_label(dataset_path, n_classes):
trans_factory(torch.from_numpy(label_image).float() / n_classes).save(
dataset_path + '/visual_label/' + name,
quality=95)


def get_ckpt_path(version_name: str):
if version_name is None:
return None
else:
checkpoints_path = './logs/default/' + version_name + '/checkpoints'
ckpt_path = glob.glob(checkpoints_path + '/*.ckpt')
return ckpt_path[0].replace('\\', '/')


def rwxl(dataset_path, x, y, dim_in, dataset_len):
    # Excel read/write helper. Assumes openpyxl is available; the parameters
    # (dataset_path, x, y, dim_in, dataset_len) are passed in explicitly so the helper
    # is self-contained instead of relying on self.x / config as in the original sketch.
    import openpyxl as xl
    # Write: dump the dataset into dataset.xlsx
    dataset_xl = xl.Workbook(write_only=True)
    dataset_sh = dataset_xl.create_sheet('dataset', 0)
    for row in range(x.shape[0]):
        # write-only sheets only support append(), not random cell access
        dataset_sh.append([float(temp) for temp in x[row].tolist()] + [float(y[row])])
    dataset_xl.save(dataset_path + '/dataset.xlsx')
    dataset_xl.close()
    # Read: load the rows back into a nested list (dim_in feature columns plus the target)
    dataset_xl = xl.load_workbook(dataset_path + '/dataset.xlsx', read_only=True)
    dataset_sh = dataset_xl['dataset']
    temp = [[cell.value for cell in row_cells]
            for row_cells in dataset_sh.iter_rows(min_row=1, max_row=dataset_len, max_col=dim_in + 1)]
    dataset_xl.close()
    return temp


if __name__ == "__main__":
get_ckpt_path('version_0')
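A possible call to the rwxl helper as rewritten above, assuming tensors shaped like the ones produced in DataModule.get_fit_dataset_list (sizes and values are illustrative):

import torch

x = torch.randn(100, 2)
y = torch.randn(100)
rows = rwxl('./dataset', x, y, dim_in=2, dataset_len=100)
print(len(rows), len(rows[0]))  # 100 rows, dim_in + 1 columns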
