@@ -0,0 +1,21 @@
"""
Measure the running time of each function in a .py file.
"""
import pstats

# In a terminal, run: python -m cProfile -o result.cprofile main.py
# Create the Stats object
p = pstats.Stats('./result.cprofile')
# Sort by cumulative running time, then by function name.
# print_stats(n) only prints the first n rows of function statistics.
p.strip_dirs().sort_stats("cumulative", "name").print_stats(15)
# The argument can also be a fraction below 1, in which case only that fraction of the rows is printed.
# p.strip_dirs().sort_stats("cumulative", "name").print_stats(0.5)
# Show the callers of main()
# p.print_callers(0.5, "main")
# Show the functions called inside main()
# p.print_callees("main")
# After pip-installing snakeviz, run the following command in a terminal:
# snakeviz result.cprofile
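# A minimal sketch (an assumption, not part of the original workflow) of profiling a
# single call programmatically instead of via "python -m cProfile"; _toy_workload is a
# placeholder for the code to be measured.
# import cProfile
#
# def _toy_workload():
#     return sum(i * i for i in range(10000))
#
# profiler = cProfile.Profile()
# profiler.enable()
# _toy_workload()
# profiler.disable()
# profiler.dump_stats('./toy.cprofile')  # load with pstats.Stats('./toy.cprofile') as above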
@@ -0,0 +1,65 @@
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


class DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, num_workers, config=None):
        super().__init__()
        # TODO use k-fold cross-validation (a sketch of the index split is given at the end of this file)
        # divide_dataset(config['dataset_path'], [0.8, 0, 0.2])
        if config['flag']:
            # Simple toy data: the label only depends on a linear combination of the two features plus noise.
            x = torch.randn(100000, 2)
            noise = torch.randn(100000, )
            y = ((1.0 * x[:, 0] + 2.0 * x[:, 1] + noise) > 0).type(torch.int64)
        else:
            # Harder toy data: only a nonlinear combination of x_1 and x_2 is informative,
            # the remaining 60 columns are noise weakly correlated with x_1 or x_2.
            x_1 = torch.randn(100000)
            x_2 = torch.randn(100000)
            x_useful = torch.cos(1.5 * x_1) * (x_2 ** 2)
            x_1_rest_small = torch.randn(100000, 15) + 0.01 * x_1.unsqueeze(1)
            x_1_rest_large = torch.randn(100000, 15) + 0.1 * x_1.unsqueeze(1)
            x_2_rest_small = torch.randn(100000, 15) + 0.01 * x_2.unsqueeze(1)
            x_2_rest_large = torch.randn(100000, 15) + 0.1 * x_2.unsqueeze(1)
            x = torch.cat([x_1[:, None], x_2[:, None], x_1_rest_small, x_1_rest_large, x_2_rest_small, x_2_rest_large],
                          dim=1)
            y = ((10 * x_useful) + 5 * torch.randn(100000) > 0.0).type(torch.int64)
        self.y_train, self.y_test = y[:50000], y[50000:]
        self.x_train, self.x_test = x[:50000, :], x[50000:, :]
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.config = config
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None) -> None:
        if stage == 'fit' or stage is None:
            self.train_dataset = CustomDataset(self.x_train, self.y_train, self.config)
            self.val_dataset = CustomDataset(self.x_test, self.y_test, self.config)
        if stage == 'test' or stage is None:
            self.test_dataset = CustomDataset(self.x_test, self.y_test, self.config)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers)


class CustomDataset(Dataset):
    def __init__(self, x, y, config):
        super().__init__()
        self.x = x
        self.y = y
        self.config = config

    def __getitem__(self, idx):
        return self.x[idx, :], self.y[idx]

    def __len__(self):
        return self.x.shape[0]
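

# A minimal sketch (an assumption, not part of the original project) of the k-fold split
# mentioned in the TODO of DataModule.__init__: it only produces index splits, wiring them
# into setup() is left to the caller.
def k_fold_indices(n_samples, k, fold):
    """Return (train_idx, val_idx) index tensors for the given fold, with 0 <= fold < k."""
    indices = torch.randperm(n_samples)
    folds = torch.chunk(indices, k)
    val_idx = folds[fold]
    train_idx = torch.cat([f for i, f in enumerate(folds) if i != fold])
    return train_idx, val_idx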
@@ -0,0 +1,38 @@
"""
Evaluate the predictions stored under a given folder; none of the metrics include the background class.
"""
import numpy as np
from PIL import Image
from os.path import join

from network_module.iou import IOU
from network_module.pix_acc import calculate_acc


def evalute(n_classes, dataset_path, verbose=False):
    iou = IOU(n_classes)
    test_list = open(join(dataset_path, 'test_dataset_list.txt').replace('\\', '/')).readlines()
    for ind in range(len(test_list)):
        pred = np.array(Image.open(join(dataset_path, 'prediction', test_list[ind].strip('\n')).replace('\\', '/')))
        label = np.array(Image.open(join(dataset_path, 'labels', test_list[ind].strip('\n')).replace('\\', '/')))
        if len(label.flatten()) != len(pred.flatten()):
            print('Skipping {:s}: pred len {:d} != label len {:d}'.format(
                test_list[ind].strip('\n'), len(pred.flatten()), len(label.flatten())))
            continue
        iou.add_data(pred, label)
    # Must be called before iou.get_miou(), because get_miou() resets the accumulated hist.
    overall_acc, acc = calculate_acc(iou.hist)
    mIoU, IoUs = iou.get_miou()
    if verbose:
        for ind_class in range(n_classes):
            print('===> ' + str(ind_class) + ':\t' + str(float(IoUs[ind_class])))
        print('===> mIoU: ' + str(mIoU))
        print('===> overall accuracy:', overall_acc)
        print('===> accuracy of each class:', acc)


if __name__ == "__main__":
    evalute(9, './dataset/MFNet(RGB-T)-mini', verbose=True)
@@ -0,0 +1,19 @@
#!/bin/bash
module load anaconda/2020.11
module load cuda/10.2
module load cudnn/8.1.1.33_CUDA10.2
#conda create --name py37 python=3.7
source activate py37
cd run
cd machineLearningScaffold
#pip install -r requirements.txt
python main.py
# Submit this script with: sbatch --gpus=1 ./bscc-run.sh
# Current job ID: 67417
# Query jobs: parajobs    Cancel a job: scancel ID
@@ -0,0 +1,92 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c4d04db-6fe0-4511-abb3-c746ba869863",
   "metadata": {},
   "outputs": [],
   "source": [
    "!apt-get update\n",
    "!apt-get install p7zip-full -y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c027a62-8116-4c55-9a90-95c4d90dd087",
   "metadata": {},
   "outputs": [],
   "source": [
    "!7z x dataset.zip -o/hy-tmp\n",
    "!7z x machineLearningScaffold.zip -o/hy-tmp\n",
    "!7z x vit_pre-train_checkpoint.zip -o/hy-tmp\n",
    "!mv -f /hy-tmp/pre-train_checkpoint/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz /hy-tmp/pre-train_checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4c6dc6a1-0208-4efb-beaf-b1264a94ba71",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: You are using pip version 21.0.1; however, version 21.2.4 is available.\n",
      "You should consider upgrading via the '/usr/bin/python3.8 -m pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install -r /hy-tmp/requirements.txt > /dev/null"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c074f52-d27e-4605-8b4f-45b11a2f1eb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import main\n",
    "%load_ext tensorboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5a8d64e-d90e-467f-a3b4-db8b31874044",
   "metadata": {},
   "outputs": [],
   "source": [
    "main.main('fit', gpus=1, dataset_path='./dataset/MFNet(RGB-T)', num_workers=2, max_epochs=60, batch_size=16, precision=16, seed=1234, # tpu_cores=8,\n",
    "          checkpoint_every_n_val=1, save_name='version_0', #checkpoint_path='/version_15/checkpoints/epoch=59-step=4739.ckpt',\n",
    "          path_final_save='./drive/MyDrive', save_top_k=1\n",
    "          )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,99 @@
from save_checkpoint import SaveCheckpoint
from data_module import DataModule
from pytorch_lightning import loggers as pl_loggers
import pytorch_lightning as pl
from train_model import TrainModule


def main(stage,
         num_workers,
         max_epochs,
         batch_size,
         precision,
         seed,
         dataset_path=None,
         gpus=None,
         tpu_cores=None,
         load_checkpoint_path=None,
         save_name=None,
         path_final_save=None,
         every_n_epochs=1,
         save_top_k=1):
    """
    Entry point of the scaffold. Sets the hyper-parameters, splits the dataset and selects training or testing.
    The parameters of this function are the ones that change frequently between runs.

    :param stage: which stage to run: 'fit' for training, 'test' for testing
    :param num_workers:
    :param max_epochs:
    :param batch_size:
    :param precision: training precision: 32 is full precision, 16 is half precision, 64 is also possible.
                      The precision is the number of bits used to store each parameter.
    :param seed:
    :param dataset_path: path of the dataset; its directory contains the data, the labels and the list of all samples
    :param gpus:
    :param tpu_cores:
    :param load_checkpoint_path:
    :param save_name:
    :param path_final_save:
    :param every_n_epochs:
    :param save_top_k:
    """
    # config holds the model-specific parameters that rarely change once the model is fixed;
    # generic parameters that rarely change are declared directly.
    # Toggle this flag to switch between the two toy configurations in DataModule.
    if False:
        config = {'dataset_path': dataset_path,
                  'dim_in': 2,
                  'dim': 10,
                  'res_coef': 0.5,
                  'dropout_p': 0.1,
                  'n_layers': 2,
                  'flag': True}
    else:
        config = {'dataset_path': dataset_path,
                  'dim_in': 62,
                  'dim': 32,
                  'res_coef': 0.5,
                  'dropout_p': 0.1,
                  'n_layers': 20,
                  'flag': False}
    # TODO find the optimal batch size automatically
    # TODO read the number of CPU cores automatically and use it as num_workers
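    # A possible sketch for the second TODO (an assumption, not part of the original code):
    # cap num_workers by the number of available CPU cores, e.g.
    # num_workers = min(num_workers, os.cpu_count() or 1)  # requires "import os"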
    precision = 32 if (gpus is None and tpu_cores is None) else precision
    dm = DataModule(batch_size=batch_size, num_workers=num_workers, config=config)
    logger = pl_loggers.TensorBoardLogger('logs/')
    if stage == 'fit':
        training_module = TrainModule(config=config)
        save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs,
                                         save_name=save_name, path_final_save=path_final_save,
                                         every_n_epochs=every_n_epochs, verbose=True,
                                         monitor='Validation acc', save_top_k=save_top_k,
                                         mode='max')
        if load_checkpoint_path is None:
            print('Training from scratch')
            trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
                                 logger=logger, precision=precision, callbacks=[save_checkpoint])
            training_module.load_pretrain_parameters()
        else:
            print('Resuming training from a checkpoint')
            trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
                                 resume_from_checkpoint='./logs/default' + load_checkpoint_path,
                                 logger=logger, precision=precision, callbacks=[save_checkpoint])
        trainer.fit(training_module, datamodule=dm)
    if stage == 'test':
        if load_checkpoint_path is None:
            print('No checkpoint was given, cannot test')
        else:
            print('Testing')
            training_module = TrainModule.load_from_checkpoint(
                checkpoint_path='./logs/default' + load_checkpoint_path,
                **{'config': config})
            trainer = pl.Trainer(gpus=gpus, tpu_cores=tpu_cores, logger=logger, precision=precision)
            trainer.test(training_module, datamodule=dm)
    # Run "tensorboard --logdir logs" in a terminal to inspect the results; in Jupyter prefix the command with %.


if __name__ == "__main__":
    main('fit', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234,
         # gpus=1,
         # load_checkpoint_path='/version_5/checkpoints/epoch=149-step=7949.ckpt',
         )
@@ -0,0 +1,40 @@
import torch.nn as nn
from torch.nn import Mish


class MLPLayer(nn.Module):
    def __init__(self, dim_in, dim_out, res_coef=0.0, dropout_p=0.1):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out)
        self.res_coef = res_coef
        self.activation = Mish()
        self.dropout = nn.Dropout(dropout_p)
        self.ln = nn.LayerNorm(dim_out)

    def forward(self, x):
        y = self.linear(x)
        y = self.activation(y)
        y = self.dropout(y)
        if self.res_coef == 0:
            return self.ln(y)
        else:
            # The weighted residual connection requires dim_in == dim_out.
            return self.ln(self.res_coef * x + y)


class MLP(nn.Module):
    def __init__(self, dim_in, dim, res_coef=0.5, dropout_p=0.1, n_layers=10):
        super().__init__()
        self.mlp = nn.ModuleList()
        self.first_linear = MLPLayer(dim_in, dim)
        self.n_layers = n_layers
        for i in range(n_layers):
            self.mlp.append(MLPLayer(dim, dim, res_coef, dropout_p))
        self.final = nn.Linear(dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.first_linear(x)
        for layer in self.mlp:
            x = layer(x)
        x = self.sigmoid(self.final(x))
        return x.squeeze()
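
# A minimal usage sketch (an assumption, matching the larger config used elsewhere in the
# scaffold): a batch of 4 samples with 62 features mapped to per-sample probabilities.
# import torch
# net = MLP(dim_in=62, dim=32, res_coef=0.5, dropout_p=0.1, n_layers=20)
# probs = net(torch.randn(4, 62))  # shape (4,), values in (0, 1) from the final sigmoid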
@@ -0,0 +1 @@
# Holds generic network-related modules that are not tied to the life cycle of a single network, e.g. optimizers, losses and evaluation.
@@ -0,0 +1,32 @@
import torch


def one_hot_encoder(input_tensor, n_classes):
    """
    Convert the input tensor into one-hot form.
    :param input_tensor:
    :param n_classes:
    :return:
    """
    tensor_list = []
    for i in range(n_classes):
        temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
        tensor_list.append(temp_prob.unsqueeze(1))
    output_tensor = torch.cat(tensor_list, dim=1)
    return output_tensor.long()


def torch_nanmean(x):
    """
    Mean of a tensor, ignoring NaN entries.
    :param x:
    :return:
    """
    num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum()
    value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum()
    # num == 0 means every entry is NaN; set num to 1 so we do not divide by zero.
    if num == 0:
        num = 1
    return value / num
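
# Shape/behaviour notes (illustrative, not part of the original file):
# one_hot_encoder(torch.tensor([[0, 2], [1, 1]]), 3) has shape (2, 3, 2) because the class
# dimension is inserted at dim=1; torch_nanmean(torch.tensor([1.0, float('nan'), 3.0]))
# returns 2.0, whereas a plain mean() would return nan.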
@@ -0,0 +1,52 @@
"""
Functions and classes related to computing IoU, including an IoU loss.
"""
import torch
from torch import nn

from network_module.compute_utils import torch_nanmean


def fast_hist(pred, label, n_classes):
    # torch.bincount counts how often each of the n**2 values from 0 to n**2 - 1 occurs;
    # the result is reshaped to the (n, n) confusion matrix (rows: label, columns: prediction).
    return torch.bincount(n_classes * label + pred, minlength=n_classes ** 2).reshape(n_classes, n_classes)


def per_class_iu(hist):
    # Compute the per-class IoU from the accumulated confusion matrix hist of shape (n, n):
    # diagonal / (row sum + column sum - diagonal), returned as a tensor of shape (n,).
    # hist.sum(0) sums over columns, hist.sum(1) over rows; rows are labels, columns are predictions.
    return (torch.diag(hist)) / (torch.sum(hist, 1) + torch.sum(hist, 0) - torch.diag(hist))


def get_ious(pred, label, n_classes):
    hist = fast_hist(pred.flatten(), label.flatten(), n_classes)
    IoUs = per_class_iu(hist)
    mIoU = torch_nanmean(IoUs[1:n_classes])  # the background class 0 is excluded
    return mIoU, IoUs


class IOU_loss(nn.Module):
    def __init__(self, n_classes):
        super(IOU_loss, self).__init__()
        self.n_classes = n_classes

    def forward(self, pred, label):
        mIoU, _ = get_ious(pred, label, self.n_classes)
        return 1 - mIoU


class IOU:
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.hist = None

    def add_data(self, preds, label):
        # Accumulate the confusion matrix over several images / batches.
        preds = torch.as_tensor(preds).flatten().long()
        label = torch.as_tensor(label).flatten().long()
        if self.hist is None:
            self.hist = torch.zeros((self.n_classes, self.n_classes)).type_as(preds)
        self.hist = self.hist + fast_hist(preds, label, self.n_classes)

    def get_miou(self):
        IoUs = per_class_iu(self.hist)
        self.hist = None  # reset the accumulator
        mIoU = torch_nanmean(IoUs[1:self.n_classes])
        return mIoU, IoUs
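
# Worked toy example (illustrative, not part of the original file): with n_classes = 2,
# pred = [0, 1, 1] and label = [0, 1, 0], fast_hist gives [[1, 1], [0, 1]]
# (rows: label, columns: prediction), so per_class_iu returns [0.5, 0.5].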
@@ -0,0 +1,25 @@
import torch


def calculate_acc(hist):
    """
    Compute accuracies (not IoU) from the confusion matrix.
    :param hist:
    :return:
    """
    n_class = hist.size()[0]
    conf = torch.zeros((n_class, n_class))
    for cid in range(n_class):
        if torch.sum(hist[:, cid]) > 0:
            conf[:, cid] = hist[:, cid] / torch.sum(hist[:, cid])
    # Can be read as the accuracy over the non-background pixels, though it is biased towards
    # the classes that are predicted most often.
    # NaN means no pixel carries a non-background label; when non-background labels exist but
    # everything is predicted as background the value is 0 instead (with the background
    # included it would be 1).
    overall_acc = torch.sum(torch.diag(hist[1:, 1:])) / torch.sum(hist[1:, :])
    # acc[c] is the probability that a prediction of class c is correct (per-class precision).
    # NaN means no pixel was predicted as class c: if the class actually occurs this amounts to
    # an accuracy of 0; if it does not, there is simply nothing to evaluate.
    acc = torch.diag(conf)
    return overall_acc, acc
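
# Worked toy example (illustrative, not part of the original file): with
# hist = [[1, 1], [0, 1]] and class 0 treated as background,
# overall_acc = 1 / 1 = 1.0 and acc = [1.0, 0.5].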
@@ -0,0 +1,90 @@
import os
import random
import shutil

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_info

from utils import zip_dir


class SaveCheckpoint(ModelCheckpoint):
    def __init__(self,
                 max_epochs,
                 seed=None,
                 every_n_epochs=None,
                 save_name=None,
                 path_final_save=None,
                 monitor=None,
                 save_top_k=None,
                 verbose=False,
                 mode='min',
                 no_save_before_epoch=0):
        """
        Implements the checkpoint-saving logic as a callback, while still providing the usual
        callback hooks such as on_validation_end.
        :param max_epochs:
        :param seed:
        :param every_n_epochs:
        :param save_name:
        :param path_final_save:
        :param monitor:
        :param save_top_k:
        :param verbose:
        :param mode:
        :param no_save_before_epoch:
        """
        super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode)
        # Pre-generate one sampling seed per epoch from the given seed, so that resumed training
        # still gets a different, but reproducible, sampling order in every epoch.
        random.seed(seed)
        self.seeds = []
        for i in range(max_epochs):
            self.seeds.append(random.randint(0, 2000))
        self.seeds.append(0)
        pl.seed_everything(seed)
        self.save_name = save_name
        self.path_final_save = path_final_save
        self.monitor = monitor
        self.save_top_k = save_top_k
        self.flag_sanity_check = 0
        self.no_save_before_epoch = no_save_before_epoch

    def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """
        Adjusts the random-seed logic: the network seed is fixed, while the sampling seed of each
        epoch is generated from the given seed, so every epoch uses a different sampling sequence
        even when training is resumed from a checkpoint. Also saves the checkpoint.
        :param trainer:
        :param pl_module:
        :return:
        """
        if self.flag_sanity_check == 0:
            # The first call comes from the sanity check before epoch 0.
            pl.seed_everything(self.seeds[trainer.current_epoch])
            self.flag_sanity_check = 1
        else:
            pl.seed_everything(self.seeds[trainer.current_epoch + 1])
        super().on_validation_end(trainer, pl_module)

    def _save_top_k_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates) -> None:
        epoch = monitor_candidates.get("epoch")
        if self.monitor is None or self.save_top_k == 0 or epoch < self.no_save_before_epoch:
            return
        current = monitor_candidates.get(self.monitor)
        if self.check_monitor_top_k(trainer, current):
            self._update_best_and_save(current, trainer, monitor_candidates)
            if self.save_name is not None and self.path_final_save is not None:
                # Zip the log directory and move the archive to the final location (e.g. a mounted drive).
                zip_dir('./logs/default/' + self.save_name, './' + self.save_name + '.zip')
                if os.path.exists(self.path_final_save + '/' + self.save_name + '.zip'):
                    os.remove(self.path_final_save + '/' + self.save_name + '.zip')
                shutil.move('./' + self.save_name + '.zip', self.path_final_save)
        elif self.verbose:
            epoch = monitor_candidates.get("epoch")
            step = monitor_candidates.get("step")
            best_model_values = 'now best model:'
            for cou_best_model in self.best_k_models:
                best_model_values = ' '.join(
                    (best_model_values, str(round(float(self.best_k_models[cou_best_model]), 4))))
            rank_zero_info(
                f"\nEpoch {epoch:d}, global step {step:d}: {self.monitor} ({float(current):f}) was not in "
                f"top {self.save_top_k:d}({best_model_values:s})")
@@ -0,0 +1,47 @@
import torch
from torch import nn
import pytorch_lightning as pl
from torchmetrics.classification.accuracy import Accuracy

from network.MLP import MLP


class TrainModule(pl.LightningModule):
    def __init__(self, config=None):
        super().__init__()
        self.time_sum = None
        self.config = config
        self.net = MLP(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'], config['n_layers'])
        # TODO initialise the network with Kaiming or Xavier initialisation (a sketch is given at the end of this file)
        self.loss = nn.BCELoss()
        self.accuracy = Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = self.net(x)
        loss = self.loss(x, y.type(torch.float32))
        acc = self.accuracy(x, y)
        self.log("Training loss", loss)
        self.log("Training acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = self.net(x)
        loss = self.loss(x, y.type(torch.float32))
        acc = self.accuracy(x, y)
        self.log("Validation loss", loss)
        self.log("Validation acc", acc)
        return loss, acc

    def test_step(self, batch, batch_idx):
        return 0

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def load_pretrain_parameters(self):
        """
        Load pre-trained parameters.
        """
        pass
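

# A minimal sketch (an assumption, not part of the original module) for the TODO above:
# apply Kaiming initialisation to every linear layer of a network, e.g.
# TrainModule(config).net.apply(init_weights_kaiming).
def init_weights_kaiming(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)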
@@ -0,0 +1,189 @@
# Generic utilities that are independent of the network.
import glob
import os
import random
import zipfile

import cv2
import torch


def divide_dataset(dataset_path, rate_datasets):
    """
    Split the dataset into training, validation and test sets and save the lists as
    train_dataset_list, validate_dataset_list and test_dataset_list.
    Every ratio must be greater than 0 and leave at least one sample in each set; only the
    validation ratio may be 0.
    :param dataset_path: path of the dataset
    :param rate_datasets: ratios of the subsets [training, validation, test]
    """
    # When the global all_dataset_list file does not exist yet, generate it.
    if not os.path.exists(dataset_path + '/all_dataset_list.txt'):
        all_list = glob.glob(dataset_path + '/labels' + '/*.png')
        with open(dataset_path + '/all_dataset_list.txt', 'w', encoding='utf-8') as f:
            for line in all_list:
                f.write(os.path.basename(line.replace('\\', '/')) + '\n')
    path_train_dataset_list = dataset_path + '/train_dataset_list.txt'
    path_validate_dataset_list = dataset_path + '/validate_dataset_list.txt'
    path_test_dataset_list = dataset_path + '/test_dataset_list.txt'
    # If the validation ratio is 0, use the test set as validation set and drop the separate test split.
    if rate_datasets[1] == 0:
        # If the split list files do not exist yet, generate new ones.
        if not (os.path.exists(path_train_dataset_list) and
                os.path.exists(path_validate_dataset_list) and
                os.path.exists(path_test_dataset_list)):
            all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):]
            with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                for line in train_dataset_list:
                    f.write(line)
            with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                for line in test_dataset_list:
                    f.write(line)
            with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                for line in test_dataset_list:
                    f.write(line)
            print('Generated new dataset lists')
        else:
            # Check whether the existing split still matches the requested ratios; if not, regenerate it.
            all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
            with open(path_train_dataset_list) as f:
                train_dataset_list_exist = f.readlines()
            with open(path_validate_dataset_list) as f:
                test_dataset_list_exist = f.readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):]
            if not (len(train_dataset_list_exist) == len(train_dataset_list) and
                    len(test_dataset_list_exist) == len(test_dataset_list)):
                with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                    for line in train_dataset_list:
                        f.write(line)
                with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                    for line in test_dataset_list:
                        f.write(line)
                with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                    for line in test_dataset_list:
                        f.write(line)
                print('Generated new dataset lists')
    # If the validation ratio is not 0, keep separate validation and test sets.
    else:
        # If the split list files do not exist yet, generate new ones.
        if not (os.path.exists(dataset_path + '/train_dataset_list.txt') and
                os.path.exists(dataset_path + '/validate_dataset_list.txt') and
                os.path.exists(dataset_path + '/test_dataset_list.txt')):
            all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):
                                             int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))]
            test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):]
            with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                for line in train_dataset_list:
                    f.write(line)
            with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                for line in validate_dataset_list:
                    f.write(line)
            with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                for line in test_dataset_list:
                    f.write(line)
            print('Generated new dataset lists')
        else:
            # Check whether the existing split still matches the requested ratios; if not, regenerate it.
            all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
            with open(path_train_dataset_list) as f:
                train_dataset_list_exist = f.readlines()
            with open(path_validate_dataset_list) as f:
                validate_dataset_list_exist = f.readlines()
            with open(path_test_dataset_list) as f:
                test_dataset_list_exist = f.readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):
                                             int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))]
            test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):]
            if not (len(train_dataset_list_exist) == len(train_dataset_list) and
                    len(validate_dataset_list_exist) == len(validate_dataset_list) and
                    len(test_dataset_list_exist) == len(test_dataset_list)):
                with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                    for line in train_dataset_list:
                        f.write(line)
                with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                    for line in validate_dataset_list:
                        f.write(line)
                with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                    for line in test_dataset_list:
                        f.write(line)
                print('Generated new dataset lists')


def zip_dir(dir_path, zip_path):
    """
    Zip a directory.
    :param dir_path: path of the directory to compress
    :param zip_path: path of the resulting zip file
    """
    ziper = zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED)
    for root, dirnames, filenames in os.walk(dir_path):
        # Drop the root path so that only the contents of dir_path are stored in the archive.
        file_path = root.replace(dir_path, '')
        for filename in filenames:
            ziper.write(os.path.join(root, filename), os.path.join(file_path, filename))
    ziper.close()


def ncolors(num_colors):
    """
    Generate a set of colours that are easy to tell apart.
    copy: https://blog.csdn.net/choumin/article/details/90320297
    :param num_colors: number of colours
    :return:
    """
    def get_n_hls_colors(num):
        import random
        hls_colors = []
        i = 0
        step = 360.0 / num
        while i < 360:
            h = i
            s = 90 + random.random() * 10
            li = 50 + random.random() * 10
            _hlsc = [h / 360.0, li / 100.0, s / 100.0]
            hls_colors.append(_hlsc)
            i += step
        return hls_colors

    import colorsys
    rgb_colors = []
    if num_colors < 1:
        return rgb_colors
    for hlsc in get_n_hls_colors(num_colors):
        _r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2])
        r, g, b = [int(x * 255.0) for x in (_r, _g, _b)]
        rgb_colors.append([r, g, b])
    return rgb_colors


def visual_label(dataset_path, n_classes):
    """
    Visualise the labels.
    :param dataset_path: path of the dataset
    :param n_classes: number of classes
    """
    label_path = os.path.join(dataset_path, 'test', 'labels').replace('\\', '/')
    label_image_list = glob.glob(label_path + '/*.png')
    label_image_list.sort()
    from torchvision import transforms
    trans_factory = transforms.ToPILImage()
    if not os.path.exists(dataset_path + '/visual_label'):
        os.mkdir(dataset_path + '/visual_label')
    for index in range(len(label_image_list)):
        label_image = cv2.imread(label_image_list[index], -1)
        name = os.path.basename(label_image_list[index])
        trans_factory(torch.from_numpy(label_image).float() / n_classes).save(
            dataset_path + '/visual_label/' + name,
            quality=95)
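
# A minimal usage sketch (the paths reuse the MFNet(RGB-T)-mini example from the evaluation
# script; whether they exist locally is an assumption):
# if __name__ == "__main__":
#     divide_dataset('./dataset/MFNet(RGB-T)-mini', [0.8, 0, 0.2])
#     visual_label('./dataset/MFNet(RGB-T)-mini', 9)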