From 52ed4770bb4480c0cc5fdf4f4f9b053383d4881e Mon Sep 17 00:00:00 2001
From: shenyan <23357320@qq.com>
Date: Thu, 14 Oct 2021 15:29:39 +0800
Subject: [PATCH] Complete the neural-network training framework; complete the
 classification task
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 custom_profile.py               |  21 ++++
 data_module.py                  |  65 +++++++++++
 evalute.py                      |  38 +++++++
 free_servers/bscc-run.sh        |  19 ++++
 free_servers/hy-scaffold.ipynb  |  92 ++++++++++++++++
 main.py                         |  99 +++++++++++++++++
 network/MLP.py                  |  40 +++++++
 network_module/__init__.py      |   1 +
 network_module/compute_utils.py |  32 ++++++
 network_module/iou.py           |  52 +++++++++
 network_module/pix_acc.py       |  25 +++++
 requirements.txt                | Bin 0 -> 1990 bytes
 save_checkpoint.py              |  90 +++++++++++++++
 train_model.py                  |  47 ++++++++
 utils.py                        | 189 ++++++++++++++++++++++++++++
 15 files changed, 810 insertions(+)
 create mode 100644 custom_profile.py
 create mode 100644 data_module.py
 create mode 100644 evalute.py
 create mode 100644 free_servers/bscc-run.sh
 create mode 100644 free_servers/hy-scaffold.ipynb
 create mode 100644 main.py
 create mode 100644 network/MLP.py
 create mode 100644 network_module/__init__.py
 create mode 100644 network_module/compute_utils.py
 create mode 100644 network_module/iou.py
 create mode 100644 network_module/pix_acc.py
 create mode 100644 requirements.txt
 create mode 100644 save_checkpoint.py
 create mode 100644 train_model.py
 create mode 100644 utils.py

diff --git a/custom_profile.py b/custom_profile.py
new file mode 100644
index 0000000..c8e26dd
--- /dev/null
+++ b/custom_profile.py
@@ -0,0 +1,21 @@
+"""
+Profile the running time of each function in a .py file
+"""
+import pstats
+
+# In a terminal, run: python -m cProfile -o result.cprofile main.py
+
+# Create the Stats object
+p = pstats.Stats('./result.cprofile')
+# Sort by cumulative time and function name,
+# and only print the first n functions, where n is the argument of print_stats(n)
+p.strip_dirs().sort_stats("cumulative", "name").print_stats(15)
+# The argument may also be a float smaller than 1, meaning the top fraction n of all functions
+# p.strip_dirs().sort_stats("cumulative", "name").print_stats(0.5)
+# Show the functions that call main()
+# p.print_callers(0.5, "main")
+# Show the functions called inside main()
+# p.print_callees("main")
+
+# After installing snakeviz with pip, run the following command in a terminal:
+# snakeviz result.cprofile
diff --git a/data_module.py b/data_module.py
new file mode 100644
index 0000000..fe3b41a
--- /dev/null
+++ b/data_module.py
@@ -0,0 +1,65 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+import pytorch_lightning as pl
+
+
+class DataModule(pl.LightningDataModule):
+    def __init__(self, batch_size, num_workers, config=None):
+        super().__init__()
+        # TODO: use k-fold cross-validation
+        # divide_dataset(config['dataset_path'], [0.8, 0, 0.2])
+        if config['flag']:
+            x = torch.randn(100000, 2)
+            noise = torch.randn(100000, )
+            y = ((1.0 * x[:, 0] + 2.0 * x[:, 1] + noise) > 0).type(torch.int64)
+        else:
+            x_1 = torch.randn(100000)
+            x_2 = torch.randn(100000)
+            x_useful = torch.cos(1.5 * x_1) * (x_2 ** 2)
+            x_1_rest_small = torch.randn(100000, 15) + 0.01 * x_1.unsqueeze(1)
+            x_1_rest_large = torch.randn(100000, 15) + 0.1 * x_1.unsqueeze(1)
+            x_2_rest_small = torch.randn(100000, 15) + 0.01 * x_2.unsqueeze(1)
+            x_2_rest_large = torch.randn(100000, 15) + 0.1 * x_2.unsqueeze(1)
+            x = torch.cat([x_1[:, None], x_2[:, None], x_1_rest_small, x_1_rest_large, x_2_rest_small, x_2_rest_large],
+                          dim=1)
+            y = ((10 * x_useful) + 5 * torch.randn(100000) > 0.0).type(torch.int64)
+
+        self.y_train, self.y_test = y[:50000], y[50000:]
+        self.x_train, 
self.x_test = x[:50000, :], x[50000:, :] + + self.batch_size = batch_size + self.num_workers = num_workers + self.config = config + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + + def setup(self, stage=None) -> None: + if stage == 'fit' or stage is None: + self.train_dataset = CustomDataset(self.x_train, self.y_train, self.config) + self.val_dataset = CustomDataset(self.x_test, self.y_test, self.config) + if stage == 'test' or stage is None: + self.test_dataset = CustomDataset(self.x_test, self.y_test, self.config) + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers) + + def test_dataloader(self): + return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers) + + +class CustomDataset(Dataset): + def __init__(self, x, y, config): + super().__init__() + self.x = x + self.y = y + self.config = config + + def __getitem__(self, idx): + return self.x[idx, :], self.y[idx] + + def __len__(self): + return self.x.shape[0] diff --git a/evalute.py b/evalute.py new file mode 100644 index 0000000..be91362 --- /dev/null +++ b/evalute.py @@ -0,0 +1,38 @@ +""" +评估指定文件夹下的预测结果, 评价结果均不计算背景类 +""" +import numpy as np +from PIL import Image +from os.path import join +from network_module.iou import IOU +from network_module.pix_acc import calculate_acc + + +def evalute(n_classes, dataset_path, verbose=False): + iou = IOU(n_classes) + + test_list = open(join(dataset_path, 'test_dataset_list.txt').replace('\\', '/')).readlines() + for ind in range(len(test_list)): + pred = np.array(Image.open(join(dataset_path, 'prediction', test_list[ind].strip('\n')).replace('\\', '/'))) + label = np.array(Image.open(join(dataset_path, 'labels', test_list[ind].strip('\n')).replace('\\', '/'))) + if len(label.flatten()) != len(pred.flatten()): + print('跳过{:s}: pred len {:d} != label len {:d},'.format( + test_list[ind].strip('\n'), len(label.flatten()), len(pred.flatten()))) + continue + + iou.add_data(pred, label) + + # 必须置于iou_loss.forward前,因为forward会清除hist + overall_acc, acc = calculate_acc(iou.hist) + mIoU, IoUs = iou.get_miou() + + if verbose: + for ind_class in range(n_classes): + print('===>' + str(ind_class) + ':\t' + str(IoUs[ind_class].float())) + print('===> mIoU: ' + str(mIoU)) + print('===> overall accuracy:', overall_acc) + print('===> accuracy of each class:', acc) + + +if __name__ == "__main__": + evalute(9, './dataset/MFNet(RGB-T)-mini', verbose=True) diff --git a/free_servers/bscc-run.sh b/free_servers/bscc-run.sh new file mode 100644 index 0000000..84b07c6 --- /dev/null +++ b/free_servers/bscc-run.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +module load anaconda/2020.11 +module load cuda/10.2 +module load cudnn/8.1.1.33_CUDA10.2 + +#conda create --name py37 python=3.7 +source activate py37 + +cd run +cd machineLearningScaffold + +#pip install -r requirements.txt + +python main.py + +#sbatch --gpus=1 ./bscc-run.sh +#当前作业ID: 67417 +#查询作业: parajobs 取消作业: scancel ID \ No newline at end of file diff --git a/free_servers/hy-scaffold.ipynb b/free_servers/hy-scaffold.ipynb new file mode 100644 index 0000000..067788c --- /dev/null +++ b/free_servers/hy-scaffold.ipynb @@ -0,0 +1,92 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "4c4d04db-6fe0-4511-abb3-c746ba869863", + "metadata": {}, + "outputs": [], + 
"source": [ + "!apt-get update\n", + "!apt-get install p7zip-full -y" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c027a62-8116-4c55-9a90-95c4d90dd087", + "metadata": {}, + "outputs": [], + "source": [ + "!7z x dataset.zip -o/hy-tmp\n", + "!7z x machineLearningScaffold.zip -o/hy-tmp\n", + "!7z x vit_pre-train_checkpoint.zip -o/hy-tmp\n", + "!mv -f /hy-tmp/pre-train_checkpoint/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz /hy-tmp/pre-train_checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4c6dc6a1-0208-4efb-beaf-b1264a94ba71", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARNING: You are using pip version 21.0.1; however, version 21.2.4 is available.\n", + "You should consider upgrading via the '/usr/bin/python3.8 -m pip install --upgrade pip' command.\u001b[0m\n" + ] + } + ], + "source": [ + "!pip install -r /hy-tmp/requirements.txt > /dev/null" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c074f52-d27e-4605-8b4f-45b11a2f1eb1", + "metadata": {}, + "outputs": [], + "source": [ + "import main\n", + "%load_ext tensorboard" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5a8d64e-d90e-467f-a3b4-db8b31874044", + "metadata": {}, + "outputs": [], + "source": [ + "main.main('fit', gpus=1, dataset_path='./dataset/MFNet(RGB-T)', num_workers=2, max_epochs=60, batch_size=16, precision=16, seed=1234, # tpu_cores=8,\n", + " checkpoint_every_n_val=1, save_name='version_0', #checkpoint_path='/version_15/checkpoints/epoch=59-step=4739.ckpt',\n", + " path_final_save='./drive/MyDrive', save_top_k=1\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/main.py b/main.py new file mode 100644 index 0000000..22b24e8 --- /dev/null +++ b/main.py @@ -0,0 +1,99 @@ +from save_checkpoint import SaveCheckpoint +from data_module import DataModule +from pytorch_lightning import loggers as pl_loggers +import pytorch_lightning as pl +from train_model import TrainModule + + +def main(stage, + num_workers, + max_epochs, + batch_size, + precision, + seed, + dataset_path=None, + gpus=None, + tpu_cores=None, + load_checkpoint_path=None, + save_name=None, + path_final_save=None, + every_n_epochs=1, + save_top_k=1,): + """ + 框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程 + 该函数的参数为训练过程中需要经常改动的参数 + + :param stage: 表示处于训练阶段还是测试阶段, fit表示训练, test表示测试 + :param num_workers: + :param max_epochs: + :param batch_size: + :param precision: 训练精度, 正常精度为32, 半精度为16, 也可以是64. 
精度代表每个参数的类型所占的位数 + :param seed: + + :param dataset_path: 数据集地址, 其目录下包含数据集, 标签, 全部数据的命名list + :param gpus: + :param tpu_cores: + :param load_checkpoint_path: + :param save_name: + :param path_final_save: + :param every_n_epochs: + :param save_top_k: + """ + # config存放确定模型后不常改动的非通用的参数, 通用参数且不经常带动的直接进行声明 + if False: + config = {'dataset_path': dataset_path, + 'dim_in': 2, + 'dim': 10, + 'res_coef': 0.5, + 'dropout_p': 0.1, + 'n_layers': 2, + 'flag': True} + else: + config = {'dataset_path': dataset_path, + 'dim_in': 62, + 'dim': 32, + 'res_coef': 0.5, + 'dropout_p': 0.1, + 'n_layers': 20, + 'flag': False} + # TODO 获得最优的batch size + # TODO 自动获取CPU核心数并设置num workers + precision = 32 if (gpus is None and tpu_cores is None) else precision + dm = DataModule(batch_size=batch_size, num_workers=num_workers, config=config) + logger = pl_loggers.TensorBoardLogger('logs/') + if stage == 'fit': + training_module = TrainModule(config=config) + save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs, + save_name=save_name, path_final_save=path_final_save, + every_n_epochs=every_n_epochs, verbose=True, + monitor='Validation acc', save_top_k=save_top_k, + mode='max') + if load_checkpoint_path is None: + print('进行初始训练') + trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores, + logger=logger, precision=precision, callbacks=[save_checkpoint]) + training_module.load_pretrain_parameters() + else: + print('进行重载训练') + trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores, + resume_from_checkpoint='./logs/default' + load_checkpoint_path, + logger=logger, precision=precision, callbacks=[save_checkpoint]) + trainer.fit(training_module, datamodule=dm) + if stage == 'test': + if load_checkpoint_path is None: + print('未载入权重信息,不能测试') + else: + print('进行测试') + training_module = TrainModule.load_from_checkpoint( + checkpoint_path='./logs/default' + load_checkpoint_path, + **{'config': config}) + trainer = pl.Trainer(gpus=gpus, tpu_cores=tpu_cores, logger=logger, precision=precision) + trainer.test(training_module, datamodule=dm) + # 在cmd中使用tensorboard --logdir logs命令可以查看结果,在Jupyter格式下需要加%前缀 + + +if __name__ == "__main__": + main('fit', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234, + # gpus=1, + # load_checkpoint_path='/version_5/checkpoints/epoch=149-step=7949.ckpt', + ) diff --git a/network/MLP.py b/network/MLP.py new file mode 100644 index 0000000..7b1058d --- /dev/null +++ b/network/MLP.py @@ -0,0 +1,40 @@ +import torch.nn as nn +from torch.nn import Mish + + +class MLPLayer(nn.Module): + def __init__(self, dim_in, dim_out, res_coef=0.0, dropout_p=0.1): + super().__init__() + self.linear = nn.Linear(dim_in, dim_out) + self.res_coef = res_coef + self.activation = Mish() + self.dropout = nn.Dropout(dropout_p) + self.ln = nn.LayerNorm(dim_out) + + def forward(self, x): + y = self.linear(x) + y = self.activation(y) + y = self.dropout(y) + if self.res_coef == 0: + return self.ln(y) + else: + return self.ln(self.res_coef * x + y) + + +class MLP(nn.Module): + def __init__(self, dim_in, dim, res_coef=0.5, dropout_p=0.1, n_layers=10): + super().__init__() + self.mlp = nn.ModuleList() + self.first_linear = MLPLayer(dim_in, dim) + self.n_layers = n_layers + for i in range(n_layers): + self.mlp.append(MLPLayer(dim, dim, res_coef, dropout_p)) + self.final = nn.Linear(dim, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + x = self.first_linear(x) + for layer in self.mlp: + x = layer(x) + x = self.sigmoid(self.final(x)) + return x.squeeze() diff --git 
a/network_module/__init__.py b/network_module/__init__.py
new file mode 100644
index 0000000..f3426ad
--- /dev/null
+++ b/network_module/__init__.py
@@ -0,0 +1 @@
+# General network-related modules whose lifetime is not tied to a single network, e.g. optimizers, losses and evaluation metrics
diff --git a/network_module/compute_utils.py b/network_module/compute_utils.py
new file mode 100644
index 0000000..d231b9a
--- /dev/null
+++ b/network_module/compute_utils.py
@@ -0,0 +1,32 @@
+import torch
+
+
+def one_hot_encoder(input_tensor, n_classes):
+    """
+    Convert the input tensor to one-hot form
+
+    :param input_tensor:
+    :param n_classes:
+    :return:
+    """
+    tensor_list = []
+    for i in range(n_classes):
+        temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
+        tensor_list.append(temp_prob.unsqueeze(1))
+    output_tensor = torch.cat(tensor_list, dim=1)
+    return output_tensor.long()
+
+
+def torch_nanmean(x):
+    """
+    Mean of a tensor with NaN entries ignored
+
+    :param x:
+    :return:
+    """
+    num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum()
+    value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum()
+    # num == 0 means every entry is NaN; set num to 1 so the denominator is never zero
+    if num == 0:
+        num = 1
+    return value / num
diff --git a/network_module/iou.py b/network_module/iou.py
new file mode 100644
index 0000000..41ef768
--- /dev/null
+++ b/network_module/iou.py
@@ -0,0 +1,52 @@
+"""
+Functions and classes related to IoU, including an IoU loss
+"""
+import torch
+from torch import nn
+from network_module.compute_utils import torch_nanmean
+
+
+def fast_hist(pred, label, n_classes):
+    # torch.bincount counts how often each value in [0, n**2 - 1] occurs; the result is reshaped to (n, n)
+    return torch.bincount(n_classes * label + pred, minlength=n_classes ** 2).reshape(n_classes, n_classes)
+
+
+def per_class_iu(hist):
+    # Per-class IoU over all validation images
+    # One IoU per class; hist has shape (n, n)
+    # diagonal elements / (row sum + column sum - diagonal), returned with shape (n,)
+    # hist.sum(0) sums over columns, hist.sum(1) sums over rows; rows are labels, columns are predictions
+    return (torch.diag(hist)) / (torch.sum(hist, 1) + torch.sum(hist, 0) - torch.diag(hist))
+
+
+def get_ious(pred, label, n_classes):
+    hist = fast_hist(pred.flatten(), label.flatten(), n_classes)
+    IoUs = per_class_iu(hist)
+    mIoU = torch_nanmean(IoUs[1:n_classes])
+    return mIoU, IoUs
+
+
+class IOU_loss(nn.Module):
+    def __init__(self, n_classes):
+        super(IOU_loss, self).__init__()
+        self.n_classes = n_classes
+
+    def forward(self, pred, label):
+        mIoU, _ = get_ious(pred, label, self.n_classes)
+        return 1 - mIoU
+
+
+class IOU:
+    def __init__(self, n_classes):
+        self.n_classes = n_classes
+        self.hist = None
+
+    def add_data(self, preds, label):
+        self.hist = torch.zeros((self.n_classes, self.n_classes)).type_as(preds) if self.hist is None else self.hist
+        self.hist = self.hist + fast_hist(preds.flatten().int(), label.flatten(), self.n_classes)
+
+    def get_miou(self):
+        IoUs = per_class_iu(self.hist)
+        self.hist = None
+        mIoU = torch_nanmean(IoUs[1:self.n_classes])
+        return mIoU, IoUs
diff --git a/network_module/pix_acc.py b/network_module/pix_acc.py
new file mode 100644
index 0000000..4264fe7
--- /dev/null
+++ b/network_module/pix_acc.py
@@ -0,0 +1,25 @@
+import torch
+
+
+def calculate_acc(hist):
+    """
+    Compute pixel accuracy rather than IoU
+
+    :param hist:
+    :return:
+    """
+    n_class = hist.size()[0]
+    conf = torch.zeros((n_class, n_class))
+    for cid in range(n_class):
+        if torch.sum(hist[:, cid]) > 0:
+            conf[:, cid] = hist[:, cid] / torch.sum(hist[:, cid])
+
+    # Can be viewed as the prediction accuracy over all non-background pixels, but it is biased towards the accuracy of the classes that are actually predicted. 
+ # nan表示均判断为背景, 如果存在除背景外的类别, 则正确率为0; 如果不存在, 则表示nan(无结果,若不去除背景,则正确率为1) + overall_acc = torch.sum(torch.diag(hist[1:, 1:])) / torch.sum(hist[1:, :]) + + # acc为某类预测结果是正确的概率 + # nan表示无像素判断为该类, 若存在该类, 则表示正确率为0; 若不存在, 则表示nan(无法判断为该类结果的正确率) + acc = torch.diag(conf) + + return overall_acc, acc diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0c5a37323abcdd2f95d611b9274b689540b52db0 GIT binary patch literal 1990 zcmZvdTWi}u5QXQt&|gv%VmsG|J{0;=+JK>j(nm*@FGQ9c%TAO0__pVp(RxK{gk+f3-{(AWz~^w|RgYu~VL+mbj=!>vd-ip8Cn6ZYY8aWQ_{tDc!1v@UKgy)i}2;>&su{Q$EjiN;b5jVppqMNKFlkauaC^hpF*sGMX!@__yBKm7fdqUiFYM` z>FlNc?Uh;Ppc*;5srLM`Pj+YD91;Inje3#8qp~H}eV=cfGSL@g@;2u5Y{5$Q(MMz* zlLJ?-L?`8I_lE4)!x90{aW*l9t?MMe zT;&|aSA;JG(;8n3BB@NuDD?t4<4u}0;)3^s_HM<7M}dq_oJ#1!zohL^Z#VuHFRr{g ziHBZ}4lAUIZ$)@HJZ+-)VU8O{PIzXQ@5;onack_nm#*ojkanuutA24e)dP%B#_WQU I{LCHm4^b*1;s5{u literal 0 HcmV?d00001 diff --git a/save_checkpoint.py b/save_checkpoint.py new file mode 100644 index 0000000..ba28633 --- /dev/null +++ b/save_checkpoint.py @@ -0,0 +1,90 @@ +import os +from pytorch_lightning.callbacks import ModelCheckpoint +import pytorch_lightning +import pytorch_lightning as pl +import shutil +import random +from pytorch_lightning.utilities import rank_zero_info +from utils import zip_dir + + +class SaveCheckpoint(ModelCheckpoint): + def __init__(self, + max_epochs, + seed=None, + every_n_epochs=None, + save_name=None, + path_final_save=None, + monitor=None, + save_top_k=None, + verbose=False, + mode='min', + no_save_before_epoch=0): + """ + 通过回调实现checkpoint的保存逻辑, 同时具有回调函数中定义on_validation_end等功能. + + :param max_epochs: + :param seed: + :param every_n_epochs: + :param save_name: + :param path_final_save: + :param monitor: + :param save_top_k: + :param verbose: + :param mode: + :param no_save_before_epoch: + """ + super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode) + random.seed(seed) + self.seeds = [] + for i in range(max_epochs): + self.seeds.append(random.randint(0, 2000)) + self.seeds.append(0) + pytorch_lightning.seed_everything(seed) + self.save_name = save_name + self.path_final_save = path_final_save + self.monitor = monitor + self.save_top_k = save_top_k + self.flag_sanity_check = 0 + self.no_save_before_epoch = no_save_before_epoch + + def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: + """ + 修改随机数逻辑,网络的随机种子给定,取样本的随机种子由给定的随机种子生成,保证即使重载训练每个epoch具有不同的抽样序列. + 同时保存checkpoint. 
+ + :param trainer: + :param pl_module: + :return: + """ + if self.flag_sanity_check == 0: + pytorch_lightning.seed_everything(self.seeds[trainer.current_epoch]) + self.flag_sanity_check = 1 + else: + pytorch_lightning.seed_everything(self.seeds[trainer.current_epoch + 1]) + super().on_validation_end(trainer, pl_module) + + def _save_top_k_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates) -> None: + epoch = monitor_candidates.get("epoch") + if self.monitor is None or self.save_top_k == 0 or epoch < self.no_save_before_epoch: + return + + current = monitor_candidates.get(self.monitor) + + if self.check_monitor_top_k(trainer, current): + self._update_best_and_save(current, trainer, monitor_candidates) + if self.save_name is not None and self.path_final_save is not None: + zip_dir('./logs/default/' + self.save_name, './' + self.save_name + '.zip') + if os.path.exists(self.path_final_save + '/' + self.save_name + '.zip'): + os.remove(self.path_final_save + '/' + self.save_name + '.zip') + shutil.move('./' + self.save_name + '.zip', self.path_final_save) + elif self.verbose: + epoch = monitor_candidates.get("epoch") + step = monitor_candidates.get("step") + best_model_values = 'now best model:' + for cou_best_model in self.best_k_models: + best_model_values = ' '.join( + (best_model_values, str(round(float(self.best_k_models[cou_best_model]), 4)))) + rank_zero_info( + f"\nEpoch {epoch:d}, global step {step:d}: {self.monitor} ({float(current):f}) was not in " + f"top {self.save_top_k:d}({best_model_values:s})") diff --git a/train_model.py b/train_model.py new file mode 100644 index 0000000..90f91a4 --- /dev/null +++ b/train_model.py @@ -0,0 +1,47 @@ +import pytorch_lightning as pl +from torch import nn +import torch +from torchmetrics.classification.accuracy import Accuracy +from network.MLP import MLP + + +class TrainModule(pl.LightningModule): + def __init__(self, config=None): + super().__init__() + self.time_sum = None + self.config = config + self.net = MLP(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'], config['n_layers']) + # TODO 修改网络初始化方式为kaiming分布或者xavier分布 + self.loss = nn.BCELoss() + self.accuracy = Accuracy() + + def training_step(self, batch, batch_idx): + x, y = batch + x = self.net(x) + loss = self.loss(x, y.type(torch.float32)) + acc = self.accuracy(x, y) + self.log("Training loss", loss) + self.log("Training acc", acc) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + x = self.net(x) + loss = self.loss(x, y.type(torch.float32)) + acc = self.accuracy(x, y) + self.log("Validation loss", loss) + self.log("Validation acc", acc) + return loss, acc + + def test_step(self, batch, batch_idx): + return 0 + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + return optimizer + + def load_pretrain_parameters(self): + """ + 载入预训练参数 + """ + pass diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..24af011 --- /dev/null +++ b/utils.py @@ -0,0 +1,189 @@ +# 包含一些与网络无关的工具 +import glob +import os +import random +import zipfile +import cv2 +import torch + + +def divide_dataset(dataset_path, rate_datasets): + """ + 切分数据集, 划分为训练集,验证集,测试集生成list文件并保存为: + train_dataset_list、validate_dataset_list、test_dataset_list. + 每个比例必须大于0且保证至少每个数据集中具有一个样本, 验证集可以为0. 
+ + :param dataset_path: 数据集的地址 + :param rate_datasets: 不同数据集[训练集,验证集,测试集]的比例 + """ + # 当不存在总的all_dataset_list文件时, 生成all_dataset_list + if not os.path.exists(dataset_path + '/all_dataset_list.txt'): + all_list = glob.glob(dataset_path + '/labels' + '/*.png') + with open(dataset_path + '/all_dataset_list.txt', 'w', encoding='utf-8') as f: + for line in all_list: + f.write(os.path.basename(line.replace('\\', '/')) + '\n') + path_train_dataset_list = dataset_path + '/train_dataset_list.txt' + path_validate_dataset_list = dataset_path + '/validate_dataset_list.txt' + path_test_dataset_list = dataset_path + '/test_dataset_list.txt' + # 如果验证集的比例为0,则将测试集设置为验证集并取消测试集; + if rate_datasets[1] == 0: + # 如果无切分后的list文件, 则生成新的list文件 + if not (os.path.exists(path_train_dataset_list) and + os.path.exists(path_validate_dataset_list) and + os.path.exists(path_test_dataset_list)): + all_list = open(dataset_path + '/all_dataset_list.txt').readlines() + random.shuffle(all_list) + train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])] + test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):] + with open(path_train_dataset_list, 'w', encoding='utf-8') as f: + for line in train_dataset_list: + f.write(line) + with open(path_validate_dataset_list, 'w', encoding='utf-8') as f: + for line in test_dataset_list: + f.write(line) + with open(path_test_dataset_list, 'w', encoding='utf-8') as f: + for line in test_dataset_list: + f.write(line) + print('已生成新的数据list') + else: + # 判断比例是否正确,如果不正确,则重新生成数据集 + all_list = open(dataset_path + '/all_dataset_list.txt').readlines() + with open(path_train_dataset_list) as f: + train_dataset_list_exist = f.readlines() + with open(path_validate_dataset_list) as f: + test_dataset_list_exist = f.readlines() + random.shuffle(all_list) + train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])] + test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):] + if not (len(train_dataset_list_exist) == len(train_dataset_list) and + len(test_dataset_list_exist) == len(test_dataset_list)): + with open(path_train_dataset_list, 'w', encoding='utf-8') as f: + for line in train_dataset_list: + f.write(line) + with open(path_validate_dataset_list, 'w', encoding='utf-8') as f: + for line in test_dataset_list: + f.write(line) + with open(path_test_dataset_list, 'w', encoding='utf-8') as f: + for line in test_dataset_list: + f.write(line) + print('已生成新的数据list') + # 如果验证集比例不为零,则同时存在验证集和测试集 + else: + # 如果无切分后的list文件, 则生成新的list文件 + if not (os.path.exists(dataset_path + '/train_dataset_list.txt') and + os.path.exists(dataset_path + '/validate_dataset_list.txt') and + os.path.exists(dataset_path + '/test_dataset_list.txt')): + all_list = open(dataset_path + '/all_dataset_list.txt').readlines() + random.shuffle(all_list) + train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])] + validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]): + int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))] + test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):] + with open(path_train_dataset_list, 'w', encoding='utf-8') as f: + for line in train_dataset_list: + f.write(line) + with open(path_validate_dataset_list, 'w', encoding='utf-8') as f: + for line in validate_dataset_list: + f.write(line) + with open(path_test_dataset_list, 'w', encoding='utf-8') as f: + for line in test_dataset_list: + f.write(line) + print('已生成新的数据list') + else: + # 判断比例是否正确,如果不正确,则重新生成数据集 + all_list = open(dataset_path + 
'/all_dataset_list.txt').readlines() + with open(path_train_dataset_list) as f: + train_dataset_list_exist = f.readlines() + with open(path_validate_dataset_list) as f: + validate_dataset_list_exist = f.readlines() + with open(path_test_dataset_list) as f: + test_dataset_list_exist = f.readlines() + random.shuffle(all_list) + train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])] + validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]): + int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))] + test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):] + if not (len(train_dataset_list_exist) == len(train_dataset_list) and + len(validate_dataset_list_exist) == len(validate_dataset_list) and + len(test_dataset_list_exist) == len(test_dataset_list)): + with open(path_train_dataset_list, 'w', encoding='utf-8') as f: + for line in train_dataset_list: + f.write(line) + with open(path_validate_dataset_list, 'w', encoding='utf-8') as f: + for line in validate_dataset_list: + f.write(line) + with open(path_test_dataset_list, 'w', encoding='utf-8') as f: + for line in test_dataset_list: + f.write(line) + print('已生成新的数据list') + + +def zip_dir(dir_path, zip_path): + """ + 压缩文件 + + :param dir_path: 目标文件夹路径 + :param zip_path: 压缩后的文件夹路径 + """ + ziper = zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) + for root, dirnames, filenames in os.walk(dir_path): + file_path = root.replace(dir_path, '') # 去掉根路径,只对目标文件夹下的文件及文件夹进行压缩 + # 循环出一个个文件名 + for filename in filenames: + ziper.write(os.path.join(root, filename), os.path.join(file_path, filename)) + ziper.close() + + +def ncolors(num_colors): + """ + 生成区别度较大的几种颜色 + copy: https://blog.csdn.net/choumin/article/details/90320297 + + :param num_colors: 颜色数 + :return: + """ + def get_n_hls_colors(num): + import random + hls_colors = [] + i = 0 + step = 360.0 / num + while i < 360: + h = i + s = 90 + random.random() * 10 + li = 50 + random.random() * 10 + _hlsc = [h / 360.0, li / 100.0, s / 100.0] + hls_colors.append(_hlsc) + i += step + return hls_colors + + import colorsys + rgb_colors = [] + if num_colors < 1: + return rgb_colors + for hlsc in get_n_hls_colors(num_colors): + _r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2]) + r, g, b = [int(x * 255.0) for x in (_r, _g, _b)] + rgb_colors.append([r, g, b]) + return rgb_colors + + +def visual_label(dataset_path, n_classes): + """ + 将标签可视化 + + :param dataset_path: 地址 + :param n_classes: 类别数 + """ + label_path = os.path.join(dataset_path, 'test', 'labels').replace('\\', '/') + label_image_list = glob.glob(label_path + '/*.png') + label_image_list.sort() + from torchvision import transforms + trans_factory = transforms.ToPILImage() + if not os.path.exists(dataset_path + '/visual_label'): + os.mkdir(dataset_path + '/visual_label') + for index in range(len(label_image_list)): + label_image = cv2.imread(label_image_list[index], -1) + name = os.path.basename(label_image_list[index]) + trans_factory(torch.from_numpy(label_image).float() / n_classes).save( + dataset_path + '/visual_label/' + name, + quality=95)
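Usage note (not part of the patch): a minimal sketch of how the scaffold is driven end to end, assuming it is run from the repository root with the synthetic-data defaults in main.py; the checkpoint path in the test call is a placeholder, not an artifact produced by this patch.

import main

# Initial training on the synthetic classification data (mirrors the __main__ block of main.py).
main.main('fit', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234)

# Testing reloads a checkpoint written under ./logs/default/; substitute the file your own run produced.
# main.main('test', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234,
#           load_checkpoint_path='/version_0/checkpoints/<your-checkpoint>.ckpt')

# Training curves can then be inspected with: tensorboard --logdir logs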
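Similarly, a small sketch of the standalone evaluation pieces in network_module, assuming integer class maps are already available as torch tensors (evalute.py does the same thing but reads PNG files from disk); the tensor shapes are purely illustrative.

import torch
from network_module.iou import IOU
from network_module.pix_acc import calculate_acc

n_classes = 9
pred = torch.randint(0, n_classes, (480, 640))   # predicted class map (illustrative size)
label = torch.randint(0, n_classes, (480, 640))  # ground-truth class map

iou = IOU(n_classes)
iou.add_data(pred, label)
overall_acc, acc = calculate_acc(iou.hist)  # read hist before get_miou(), which clears it
mIoU, IoUs = iou.get_miou()
print(mIoU, overall_acc)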
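Finally, a sketch of how the dataset-splitting and label-visualisation helpers in utils.py are meant to be called once a real image dataset is on disk; the dataset path is the one used in evalute.py and the split ratios are only an example.

from utils import divide_dataset, visual_label

dataset_path = './dataset/MFNet(RGB-T)-mini'
# Write train/validate/test list files next to the data; a validation share of 0 reuses the test split as validation.
divide_dataset(dataset_path, [0.8, 0.1, 0.1])
# Render the label PNGs under <dataset_path>/test/labels as grayscale images in <dataset_path>/visual_label.
visual_label(dataset_path, n_classes=9)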