@@ -0,0 +1,21 @@
"""
Measure the running time of each function in a .py file.
"""
import pstats

# First run in a terminal: python -m cProfile -o result.cprofile main.py
# Create the Stats object
p = pstats.Stats('./result.cprofile')
# Sort by cumulative time and function name.
# With this ordering, print only the first n functions, where n is the argument to print_stats(n).
p.strip_dirs().sort_stats("cumulative", "name").print_stats(15)
# The argument can also be a fraction below 1, printing only the top n (as a percentage) of the functions.
# p.strip_dirs().sort_stats("cumulative", "name").print_stats(0.5)
# Show the functions that call main()
# p.print_callers(0.5, "main")
# Show the functions called inside main()
# p.print_callees("main")
# After installing snakeviz with pip, run the following command in a terminal:
# snakeviz result.cprofile
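
# A minimal sketch (not part of the original script): the same kind of .cprofile file can also be
# produced from inside Python, which helps when the program cannot be launched via
# "python -m cProfile". The profiled function and the name 'result_inline.cprofile' are
# placeholders for illustration only.
import cProfile

def _dummy_work():
    return sum(i * i for i in range(100_000))

profiler = cProfile.Profile()
profiler.enable()
_dummy_work()
profiler.disable()
profiler.dump_stats('./result_inline.cprofile')
pstats.Stats('./result_inline.cprofile').strip_dirs().sort_stats("cumulative").print_stats(5)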
@@ -0,0 +1,65 @@
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


class DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, num_workers, config=None):
        super().__init__()
        # TODO use k-fold cross-validation
        # divide_dataset(config['dataset_path'], [0.8, 0, 0.2])
        if config['flag']:
            x = torch.randn(100000, 2)
            noise = torch.randn(100000, )
            y = ((1.0 * x[:, 0] + 2.0 * x[:, 1] + noise) > 0).type(torch.int64)
        else:
            x_1 = torch.randn(100000)
            x_2 = torch.randn(100000)
            x_useful = torch.cos(1.5 * x_1) * (x_2 ** 2)
            x_1_rest_small = torch.randn(100000, 15) + 0.01 * x_1.unsqueeze(1)
            x_1_rest_large = torch.randn(100000, 15) + 0.1 * x_1.unsqueeze(1)
            x_2_rest_small = torch.randn(100000, 15) + 0.01 * x_2.unsqueeze(1)
            x_2_rest_large = torch.randn(100000, 15) + 0.1 * x_2.unsqueeze(1)
            x = torch.cat([x_1[:, None], x_2[:, None], x_1_rest_small, x_1_rest_large, x_2_rest_small, x_2_rest_large],
                          dim=1)
            y = ((10 * x_useful) + 5 * torch.randn(100000) > 0.0).type(torch.int64)
        self.y_train, self.y_test = y[:50000], y[50000:]
        self.x_train, self.x_test = x[:50000, :], x[50000:, :]
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.config = config
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None) -> None:
        if stage == 'fit' or stage is None:
            self.train_dataset = CustomDataset(self.x_train, self.y_train, self.config)
            self.val_dataset = CustomDataset(self.x_test, self.y_test, self.config)
        if stage == 'test' or stage is None:
            self.test_dataset = CustomDataset(self.x_test, self.y_test, self.config)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers)


class CustomDataset(Dataset):
    def __init__(self, x, y, config):
        super().__init__()
        self.x = x
        self.y = y
        self.config = config

    def __getitem__(self, idx):
        return self.x[idx, :], self.y[idx]

    def __len__(self):
        return self.x.shape[0]
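
# A minimal usage sketch (not part of the original module), using the synthetic branch selected by
# config['flag'] = True; it only checks that the loaders yield batches of the expected shape.
if __name__ == "__main__":
    dm = DataModule(batch_size=32, num_workers=0, config={'flag': True})
    dm.setup('fit')
    xb, yb = next(iter(dm.train_dataloader()))
    print(xb.shape, yb.shape)  # expected: torch.Size([32, 2]) torch.Size([32])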
@@ -0,0 +1,38 @@
"""
Evaluate the predictions under a given folder; none of the metrics include the background class.
"""
import numpy as np
import torch
from PIL import Image
from os.path import join
from network_module.iou import IOU
from network_module.pix_acc import calculate_acc


def evalute(n_classes, dataset_path, verbose=False):
    iou = IOU(n_classes)
    test_list = open(join(dataset_path, 'test_dataset_list.txt').replace('\\', '/')).readlines()
    for ind in range(len(test_list)):
        pred = np.array(Image.open(join(dataset_path, 'prediction', test_list[ind].strip('\n')).replace('\\', '/')))
        label = np.array(Image.open(join(dataset_path, 'labels', test_list[ind].strip('\n')).replace('\\', '/')))
        if len(label.flatten()) != len(pred.flatten()):
            print('Skipping {:s}: pred len {:d} != label len {:d}'.format(
                test_list[ind].strip('\n'), len(pred.flatten()), len(label.flatten())))
            continue
        # The IOU accumulator works on torch tensors, so convert the numpy images first.
        iou.add_data(torch.from_numpy(pred.astype(np.int64)), torch.from_numpy(label.astype(np.int64)))
    # Must be computed before iou.get_miou(), because get_miou() clears the accumulated hist.
    overall_acc, acc = calculate_acc(iou.hist)
    mIoU, IoUs = iou.get_miou()
    if verbose:
        for ind_class in range(n_classes):
            print('===>' + str(ind_class) + ':\t' + str(IoUs[ind_class].float()))
        print('===> mIoU: ' + str(mIoU))
        print('===> overall accuracy:', overall_acc)
        print('===> accuracy of each class:', acc)


if __name__ == "__main__":
    evalute(9, './dataset/MFNet(RGB-T)-mini', verbose=True)
@@ -0,0 +1,19 @@
#!/bin/bash
module load anaconda/2020.11
module load cuda/10.2
module load cudnn/8.1.1.33_CUDA10.2
#conda create --name py37 python=3.7
source activate py37
cd run
cd machineLearningScaffold
#pip install -r requirements.txt
python main.py

# Submit with: sbatch --gpus=1 ./bscc-run.sh
# Current job ID: 67417
# List jobs: parajobs    Cancel a job: scancel ID
@@ -0,0 +1,92 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c4d04db-6fe0-4511-abb3-c746ba869863",
   "metadata": {},
   "outputs": [],
   "source": [
    "!apt-get update\n",
    "!apt-get install p7zip-full -y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c027a62-8116-4c55-9a90-95c4d90dd087",
   "metadata": {},
   "outputs": [],
   "source": [
    "!7z x dataset.zip -o/hy-tmp\n",
    "!7z x machineLearningScaffold.zip -o/hy-tmp\n",
    "!7z x vit_pre-train_checkpoint.zip -o/hy-tmp\n",
    "!mv -f /hy-tmp/pre-train_checkpoint/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz /hy-tmp/pre-train_checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4c6dc6a1-0208-4efb-beaf-b1264a94ba71",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: You are using pip version 21.0.1; however, version 21.2.4 is available.\n",
      "You should consider upgrading via the '/usr/bin/python3.8 -m pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install -r /hy-tmp/requirements.txt > /dev/null"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c074f52-d27e-4605-8b4f-45b11a2f1eb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import main\n",
    "%load_ext tensorboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5a8d64e-d90e-467f-a3b4-db8b31874044",
   "metadata": {},
   "outputs": [],
   "source": [
    "main.main('fit', gpus=1, dataset_path='./dataset/MFNet(RGB-T)', num_workers=2, max_epochs=60, batch_size=16, precision=16, seed=1234, # tpu_cores=8,\n",
    "          checkpoint_every_n_val=1, save_name='version_0', #checkpoint_path='/version_15/checkpoints/epoch=59-step=4739.ckpt',\n",
    "          path_final_save='./drive/MyDrive', save_top_k=1\n",
    "          )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,99 @@
from save_checkpoint import SaveCheckpoint
from data_module import DataModule
from pytorch_lightning import loggers as pl_loggers
import pytorch_lightning as pl
from train_model import TrainModule


def main(stage,
         num_workers,
         max_epochs,
         batch_size,
         precision,
         seed,
         dataset_path=None,
         gpus=None,
         tpu_cores=None,
         load_checkpoint_path=None,
         save_name=None,
         path_final_save=None,
         every_n_epochs=1,
         save_top_k=1):
    """
    Entry point of the framework. Covers setting the hyper-parameters, splitting the dataset,
    and choosing between training and testing.
    The arguments of this function are the ones that change frequently between experiments.
    :param stage: whether to train or to test; 'fit' means training, 'test' means testing
    :param num_workers:
    :param max_epochs:
    :param batch_size:
    :param precision: training precision; 32 is full precision, 16 is half precision, 64 is also possible.
                      The precision is the number of bits used to store each parameter.
    :param seed:
    :param dataset_path: dataset root; it contains the data, the labels and the list of all sample names
    :param gpus:
    :param tpu_cores:
    :param load_checkpoint_path:
    :param save_name:
    :param path_final_save:
    :param every_n_epochs:
    :param save_top_k:
    """
    # config holds the non-generic parameters that rarely change once the model is fixed;
    # generic parameters that rarely change are declared directly.
    if False:
        config = {'dataset_path': dataset_path,
                  'dim_in': 2,
                  'dim': 10,
                  'res_coef': 0.5,
                  'dropout_p': 0.1,
                  'n_layers': 2,
                  'flag': True}
    else:
        config = {'dataset_path': dataset_path,
                  'dim_in': 62,
                  'dim': 32,
                  'res_coef': 0.5,
                  'dropout_p': 0.1,
                  'n_layers': 20,
                  'flag': False}
    # TODO find the optimal batch size
    # TODO detect the number of CPU cores automatically and set num_workers accordingly
    precision = 32 if (gpus is None and tpu_cores is None) else precision
    dm = DataModule(batch_size=batch_size, num_workers=num_workers, config=config)
    logger = pl_loggers.TensorBoardLogger('logs/')
    if stage == 'fit':
        training_module = TrainModule(config=config)
        save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs,
                                         save_name=save_name, path_final_save=path_final_save,
                                         every_n_epochs=every_n_epochs, verbose=True,
                                         monitor='Validation acc', save_top_k=save_top_k,
                                         mode='max')
        if load_checkpoint_path is None:
            print('Training from scratch')
            trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
                                 logger=logger, precision=precision, callbacks=[save_checkpoint])
            training_module.load_pretrain_parameters()
        else:
            print('Resuming training from a checkpoint')
            trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
                                 resume_from_checkpoint='./logs/default' + load_checkpoint_path,
                                 logger=logger, precision=precision, callbacks=[save_checkpoint])
        trainer.fit(training_module, datamodule=dm)
    if stage == 'test':
        if load_checkpoint_path is None:
            print('No checkpoint given, cannot test')
        else:
            print('Testing')
            training_module = TrainModule.load_from_checkpoint(
                checkpoint_path='./logs/default' + load_checkpoint_path,
                **{'config': config})
            trainer = pl.Trainer(gpus=gpus, tpu_cores=tpu_cores, logger=logger, precision=precision)
            trainer.test(training_module, datamodule=dm)
    # Run `tensorboard --logdir logs` in a terminal to inspect the results; in Jupyter, prefix the command with %.


if __name__ == "__main__":
    main('fit', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234,
         # gpus=1,
         # load_checkpoint_path='/version_5/checkpoints/epoch=149-step=7949.ckpt',
         )
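    # A sketch of the corresponding test call (not in the original file); the checkpoint path is a
    # placeholder and must point at a file produced by a previous fit run.
    # main('test', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234,
    #      load_checkpoint_path='/version_0/checkpoints/epoch=4-step=7814.ckpt')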
@@ -0,0 +1,40 @@
import torch.nn as nn
from torch.nn import Mish


class MLPLayer(nn.Module):
    def __init__(self, dim_in, dim_out, res_coef=0.0, dropout_p=0.1):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out)
        self.res_coef = res_coef
        self.activation = Mish()
        self.dropout = nn.Dropout(dropout_p)
        self.ln = nn.LayerNorm(dim_out)

    def forward(self, x):
        y = self.linear(x)
        y = self.activation(y)
        y = self.dropout(y)
        if self.res_coef == 0:
            return self.ln(y)
        else:
            return self.ln(self.res_coef * x + y)


class MLP(nn.Module):
    def __init__(self, dim_in, dim, res_coef=0.5, dropout_p=0.1, n_layers=10):
        super().__init__()
        self.mlp = nn.ModuleList()
        self.first_linear = MLPLayer(dim_in, dim)
        self.n_layers = n_layers
        for i in range(n_layers):
            self.mlp.append(MLPLayer(dim, dim, res_coef, dropout_p))
        self.final = nn.Linear(dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.first_linear(x)
        for layer in self.mlp:
            x = layer(x)
        x = self.sigmoid(self.final(x))
        return x.squeeze()
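
# A minimal sketch (not part of the original module): the 62-feature configuration used in main.py,
# run on a dummy batch to check the output shape.
if __name__ == "__main__":
    import torch
    net = MLP(dim_in=62, dim=32, res_coef=0.5, dropout_p=0.1, n_layers=20)
    x = torch.randn(4, 62)
    y = net(x)
    print(y.shape, y.min().item(), y.max().item())  # shape (4,), values in (0, 1) after the sigmoid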
@@ -0,0 +1 @@
# Generic network-related modules that outlive a single network's life cycle, e.g. optimizers, losses and evaluation metrics.
@@ -0,0 +1,32 @@
import torch


def one_hot_encoder(input_tensor, n_classes):
    """
    Convert the input tensor to one-hot form.
    :param input_tensor:
    :param n_classes:
    :return:
    """
    tensor_list = []
    for i in range(n_classes):
        temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
        tensor_list.append(temp_prob.unsqueeze(1))
    output_tensor = torch.cat(tensor_list, dim=1)
    return output_tensor.long()


def torch_nanmean(x):
    """
    Mean of a tensor with nan entries ignored.
    :param x:
    :return:
    """
    num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum()
    value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum()
    # num == 0 means every entry is nan; since the denominator cannot be 0, set num to 1.
    if num == 0:
        num = 1
    return value / num
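
# A minimal sketch (not part of the original module) of both helpers on tiny tensors.
if __name__ == "__main__":
    labels = torch.tensor([[0, 1], [2, 1]])
    one_hot = one_hot_encoder(labels, n_classes=3)
    print(one_hot.shape)     # torch.Size([2, 3, 2]): the class dimension is inserted at dim 1
    x = torch.tensor([1.0, float('nan'), 3.0])
    print(torch_nanmean(x))  # tensor(2.) -- the nan entry is ignored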
@@ -0,0 +1,52 @@
"""
Functions and classes related to IoU, including an IoU loss.
"""
import torch
from torch import nn

from network_module.compute_utils import torch_nanmean


def fast_hist(pred, label, n_classes):
    # torch.bincount counts how often each of the n**2 values in [0, n**2 - 1] occurs; the result is reshaped to (n, n).
    return torch.bincount(n_classes * label + pred, minlength=n_classes ** 2).reshape(n_classes, n_classes)


def per_class_iu(hist):
    # Per-class IoU over all validation images; hist has shape (n, n).
    # The diagonal divided by (row sum + column sum - diagonal) gives a vector of shape (n,).
    # hist.sum(0) sums over columns, hist.sum(1) over rows; rows are labels, columns are predictions.
    return (torch.diag(hist)) / (torch.sum(hist, 1) + torch.sum(hist, 0) - torch.diag(hist))


def get_ious(pred, label, n_classes):
    hist = fast_hist(pred.flatten(), label.flatten(), n_classes)
    IoUs = per_class_iu(hist)
    mIoU = torch_nanmean(IoUs[1:n_classes])
    return mIoU, IoUs


class IOU_loss(nn.Module):
    def __init__(self, n_classes):
        super(IOU_loss, self).__init__()
        self.n_classes = n_classes

    def forward(self, pred, label):
        mIoU, _ = get_ious(pred, label, self.n_classes)
        return 1 - mIoU


class IOU:
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.hist = None

    def add_data(self, preds, label):
        # Accumulate the confusion matrix; initialise it lazily on the first call.
        if self.hist is None:
            self.hist = torch.zeros((self.n_classes, self.n_classes)).type_as(preds)
        self.hist = self.hist + fast_hist(preds.flatten().int(), label.flatten(), self.n_classes)

    def get_miou(self):
        IoUs = per_class_iu(self.hist)
        self.hist = None
        mIoU = torch_nanmean(IoUs[1:self.n_classes])
        return mIoU, IoUs
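
# A minimal sketch (not part of the original module): a hand-checkable 2-class example.
# pred and label are flat "images" of class ids; class 0 is treated as background, so the
# mIoU below averages over class 1 only.
if __name__ == "__main__":
    pred = torch.tensor([0, 1, 1, 0])
    label = torch.tensor([0, 1, 0, 1])
    miou, ious = get_ious(pred, label, n_classes=2)
    # class 1: intersection 1, union 3 -> IoU = 1/3 (the background IoU happens to be 1/3 as well)
    print(ious, miou)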
@@ -0,0 +1,25 @@
import torch


def calculate_acc(hist):
    """
    Compute accuracies (not IoU) from the confusion matrix.
    :param hist:
    :return:
    """
    n_class = hist.size()[0]
    conf = torch.zeros((n_class, n_class))
    for cid in range(n_class):
        if torch.sum(hist[:, cid]) > 0:
            conf[:, cid] = hist[:, cid] / torch.sum(hist[:, cid])
    # Accuracy over the non-background pixels; note it is biased towards the classes that are predicted most.
    # nan means everything counted as background: if non-background classes exist the accuracy is 0,
    # otherwise there is no result (it would be 1 if the background were not excluded).
    overall_acc = torch.sum(torch.diag(hist[1:, 1:])) / torch.sum(hist[1:, :])
    # acc is the probability that a prediction of a given class is correct (per-class precision).
    # nan means no pixel was predicted as that class: if the class exists the accuracy is 0,
    # otherwise there is no result (the precision for that class cannot be evaluated).
    acc = torch.diag(conf)
    return overall_acc, acc
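
# A minimal sketch (not part of the original module): a 3-class confusion matrix where rows are
# labels and columns are predictions, matching the hist produced by network_module.iou.fast_hist.
if __name__ == "__main__":
    hist = torch.tensor([[5., 1., 0.],
                         [0., 3., 1.],
                         [0., 1., 2.]])
    overall_acc, acc = calculate_acc(hist)
    # non-background pixels: 3 + 1 + 1 + 2 = 7, of which 3 + 2 = 5 are predicted correctly
    print(overall_acc)  # 5/7 ~ 0.714
    print(acc)          # per-class precision: [5/5, 3/5, 2/3]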
@@ -0,0 +1,90 @@
import os
import random
import shutil

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_info

from utils import zip_dir


class SaveCheckpoint(ModelCheckpoint):
    def __init__(self,
                 max_epochs,
                 seed=None,
                 every_n_epochs=None,
                 save_name=None,
                 path_final_save=None,
                 monitor=None,
                 save_top_k=None,
                 verbose=False,
                 mode='min',
                 no_save_before_epoch=0):
        """
        Implements the checkpoint-saving logic as a callback, so the callback hooks such as
        on_validation_end are available as well.
        :param max_epochs:
        :param seed:
        :param every_n_epochs:
        :param save_name:
        :param path_final_save:
        :param monitor:
        :param save_top_k:
        :param verbose:
        :param mode:
        :param no_save_before_epoch:
        """
        super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode)
        random.seed(seed)
        self.seeds = []
        for i in range(max_epochs):
            self.seeds.append(random.randint(0, 2000))
        self.seeds.append(0)
        pl.seed_everything(seed)
        self.save_name = save_name
        self.path_final_save = path_final_save
        self.monitor = monitor
        self.save_top_k = save_top_k
        self.flag_sanity_check = 0
        self.no_save_before_epoch = no_save_before_epoch

    def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """
        Adjust the random-seed logic: the network seed is fixed, while the sampling seed of each epoch is
        generated from the given seed, so every epoch uses a different sampling order even when training
        is resumed from a checkpoint. The checkpoint is saved here as well.
        :param trainer:
        :param pl_module:
        :return:
        """
        if self.flag_sanity_check == 0:
            pl.seed_everything(self.seeds[trainer.current_epoch])
            self.flag_sanity_check = 1
        else:
            pl.seed_everything(self.seeds[trainer.current_epoch + 1])
        super().on_validation_end(trainer, pl_module)

    def _save_top_k_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates) -> None:
        epoch = monitor_candidates.get("epoch")
        if self.monitor is None or self.save_top_k == 0 or epoch < self.no_save_before_epoch:
            return
        current = monitor_candidates.get(self.monitor)
        if self.check_monitor_top_k(trainer, current):
            self._update_best_and_save(current, trainer, monitor_candidates)
            if self.save_name is not None and self.path_final_save is not None:
                zip_dir('./logs/default/' + self.save_name, './' + self.save_name + '.zip')
                if os.path.exists(self.path_final_save + '/' + self.save_name + '.zip'):
                    os.remove(self.path_final_save + '/' + self.save_name + '.zip')
                shutil.move('./' + self.save_name + '.zip', self.path_final_save)
        elif self.verbose:
            epoch = monitor_candidates.get("epoch")
            step = monitor_candidates.get("step")
            best_model_values = 'now best model:'
            for cou_best_model in self.best_k_models:
                best_model_values = ' '.join(
                    (best_model_values, str(round(float(self.best_k_models[cou_best_model]), 4))))
            rank_zero_info(
                f"\nEpoch {epoch:d}, global step {step:d}: {self.monitor} ({float(current):f}) was not in "
                f"top {self.save_top_k:d}({best_model_values:s})")
@@ -0,0 +1,47 @@
import pytorch_lightning as pl
from torch import nn
import torch
from torchmetrics.classification.accuracy import Accuracy

from network.MLP import MLP


class TrainModule(pl.LightningModule):
    def __init__(self, config=None):
        super().__init__()
        self.time_sum = None
        self.config = config
        self.net = MLP(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'], config['n_layers'])
        # TODO switch the weight initialisation to a Kaiming or Xavier scheme
        self.loss = nn.BCELoss()
        self.accuracy = Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = self.net(x)
        loss = self.loss(x, y.type(torch.float32))
        acc = self.accuracy(x, y)
        self.log("Training loss", loss)
        self.log("Training acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = self.net(x)
        loss = self.loss(x, y.type(torch.float32))
        acc = self.accuracy(x, y)
        self.log("Validation loss", loss)
        self.log("Validation acc", acc)
        return loss, acc

    def test_step(self, batch, batch_idx):
        return 0

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def load_pretrain_parameters(self):
        """
        Load pre-trained parameters.
        """
        pass
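
# A minimal sketch (not part of the original module): build the module with the small synthetic config
# used in main.py and run one forward pass on a dummy batch to check shapes; the loss is computed
# directly with module.loss rather than through a Trainer.
if __name__ == "__main__":
    _config = {'dim_in': 2, 'dim': 10, 'res_coef': 0.5, 'dropout_p': 0.1, 'n_layers': 2}
    module = TrainModule(config=_config)
    x = torch.randn(8, 2)
    y = torch.randint(0, 2, (8,))
    pred = module.net(x)  # probabilities in (0, 1), shape (8,)
    print(pred.shape, module.loss(pred, y.float()))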
@@ -0,0 +1,189 @@
# General-purpose utilities that are not tied to the network.
import glob
import os
import random
import zipfile

import cv2
import torch


def divide_dataset(dataset_path, rate_datasets):
    """
    Split the dataset into training, validation and test sets and save the lists as:
    train_dataset_list, validate_dataset_list and test_dataset_list.
    Every ratio must be greater than 0 and leave at least one sample per set; the validation ratio may be 0.
    :param dataset_path: path of the dataset
    :param rate_datasets: ratios of the [training, validation, test] sets
    """
    # When the overall all_dataset_list file does not exist, generate it.
    if not os.path.exists(dataset_path + '/all_dataset_list.txt'):
        all_list = glob.glob(dataset_path + '/labels' + '/*.png')
        with open(dataset_path + '/all_dataset_list.txt', 'w', encoding='utf-8') as f:
            for line in all_list:
                f.write(os.path.basename(line.replace('\\', '/')) + '\n')
    path_train_dataset_list = dataset_path + '/train_dataset_list.txt'
    path_validate_dataset_list = dataset_path + '/validate_dataset_list.txt'
    path_test_dataset_list = dataset_path + '/test_dataset_list.txt'
    # If the validation ratio is 0, the test set doubles as the validation set and no separate test split is kept.
    if rate_datasets[1] == 0:
        # If the split list files do not exist yet, generate new ones.
        if not (os.path.exists(path_train_dataset_list) and
                os.path.exists(path_validate_dataset_list) and
                os.path.exists(path_test_dataset_list)):
            all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):]
            with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                for line in train_dataset_list:
                    f.write(line)
            with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                for line in test_dataset_list:
                    f.write(line)
            with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                for line in test_dataset_list:
                    f.write(line)
            print('Generated new dataset lists')
        else:
            # Check whether the existing split matches the requested ratios; if not, regenerate the lists.
            all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
            with open(path_train_dataset_list) as f:
                train_dataset_list_exist = f.readlines()
            with open(path_validate_dataset_list) as f:
                test_dataset_list_exist = f.readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):]
            if not (len(train_dataset_list_exist) == len(train_dataset_list) and
                    len(test_dataset_list_exist) == len(test_dataset_list)):
                with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                    for line in train_dataset_list:
                        f.write(line)
                with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                    for line in test_dataset_list:
                        f.write(line)
                with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                    for line in test_dataset_list:
                        f.write(line)
                print('Generated new dataset lists')
    # If the validation ratio is not 0, both a validation set and a test set are created.
    else:
        # If the split list files do not exist yet, generate new ones.
        if not (os.path.exists(dataset_path + '/train_dataset_list.txt') and
                os.path.exists(dataset_path + '/validate_dataset_list.txt') and
                os.path.exists(dataset_path + '/test_dataset_list.txt')):
            all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):
                                             int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))]
            test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):]
            with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                for line in train_dataset_list:
                    f.write(line)
            with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                for line in validate_dataset_list:
                    f.write(line)
            with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                for line in test_dataset_list:
                    f.write(line)
            print('Generated new dataset lists')
        else:
            # Check whether the existing split matches the requested ratios; if not, regenerate the lists.
            all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
            with open(path_train_dataset_list) as f:
                train_dataset_list_exist = f.readlines()
            with open(path_validate_dataset_list) as f:
                validate_dataset_list_exist = f.readlines()
            with open(path_test_dataset_list) as f:
                test_dataset_list_exist = f.readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):
                                             int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))]
            test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):]
            if not (len(train_dataset_list_exist) == len(train_dataset_list) and
                    len(validate_dataset_list_exist) == len(validate_dataset_list) and
                    len(test_dataset_list_exist) == len(test_dataset_list)):
                with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                    for line in train_dataset_list:
                        f.write(line)
                with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                    for line in validate_dataset_list:
                        f.write(line)
                with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                    for line in test_dataset_list:
                        f.write(line)
                print('Generated new dataset lists')


def zip_dir(dir_path, zip_path):
    """
    Zip a folder.
    :param dir_path: path of the folder to compress
    :param zip_path: path of the resulting zip file
    """
    ziper = zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED)
    for root, dirnames, filenames in os.walk(dir_path):
        file_path = root.replace(dir_path, '')  # strip the root so only the folder's own content is archived
        # Walk through the individual file names.
        for filename in filenames:
            ziper.write(os.path.join(root, filename), os.path.join(file_path, filename))
    ziper.close()


def ncolors(num_colors):
    """
    Generate a given number of clearly distinguishable colours.
    copy: https://blog.csdn.net/choumin/article/details/90320297
    :param num_colors: number of colours
    :return:
    """
    def get_n_hls_colors(num):
        import random
        hls_colors = []
        i = 0
        step = 360.0 / num
        while i < 360:
            h = i
            s = 90 + random.random() * 10
            li = 50 + random.random() * 10
            _hlsc = [h / 360.0, li / 100.0, s / 100.0]
            hls_colors.append(_hlsc)
            i += step
        return hls_colors

    import colorsys
    rgb_colors = []
    if num_colors < 1:
        return rgb_colors
    for hlsc in get_n_hls_colors(num_colors):
        _r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2])
        r, g, b = [int(x * 255.0) for x in (_r, _g, _b)]
        rgb_colors.append([r, g, b])
    return rgb_colors


def visual_label(dataset_path, n_classes):
    """
    Visualise the label images.
    :param dataset_path: dataset path
    :param n_classes: number of classes
    """
    label_path = os.path.join(dataset_path, 'test', 'labels').replace('\\', '/')
    label_image_list = glob.glob(label_path + '/*.png')
    label_image_list.sort()
    from torchvision import transforms
    trans_factory = transforms.ToPILImage()
    if not os.path.exists(dataset_path + '/visual_label'):
        os.mkdir(dataset_path + '/visual_label')
    for index in range(len(label_image_list)):
        label_image = cv2.imread(label_image_list[index], -1)
        name = os.path.basename(label_image_list[index])
        trans_factory(torch.from_numpy(label_image).float() / n_classes).save(
            dataset_path + '/visual_label/' + name,
            quality=95)
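
# A minimal sketch (not part of the original module): generate a palette with ncolors and zip a small
# temporary folder with zip_dir; the paths below are throwaway examples.
if __name__ == "__main__":
    print(ncolors(5))  # five visually distinct RGB triples
    os.makedirs('./tmp_zip_demo', exist_ok=True)
    with open('./tmp_zip_demo/hello.txt', 'w', encoding='utf-8') as f:
        f.write('hello')
    zip_dir('./tmp_zip_demo', './tmp_zip_demo.zip')
    print(os.path.exists('./tmp_zip_demo.zip'))  # True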