
Complete the neural-network training framework; complete the classification task

master
shenyan 4 years ago
commit
52ed4770bb
15 changed files with 810 additions and 0 deletions
  1. +21 -0   custom_profile.py
  2. +65 -0   data_module.py
  3. +38 -0   evalute.py
  4. +19 -0   free_servers/bscc-run.sh
  5. +92 -0   free_servers/hy-scaffold.ipynb
  6. +99 -0   main.py
  7. +40 -0   network/MLP.py
  8. +1 -0    network_module/__init__.py
  9. +32 -0   network_module/compute_utils.py
  10. +52 -0  network_module/iou.py
  11. +25 -0  network_module/pix_acc.py
  12. BIN     requirements.txt
  13. +90 -0  save_checkpoint.py
  14. +47 -0  train_model.py
  15. +189 -0 utils.py

+21 -0   custom_profile.py

@@ -0,0 +1,21 @@
"""
测试.py文件各个函数的运行时间
"""
import pstats

# 在cmd中使用python -m cProfile -o result.cprofile main.py

# 创建 Stats 对象
p = pstats.Stats('./result.cprofile')
# 按照运行时间和函数名进行排序
# 按照函数名排序,只打印前n行函数(其中n为print_stats(n)的输入参数)的信息,
p.strip_dirs().sort_stats("cumulative", "name").print_stats(15)
# 参数还可为小数, 表示前n(其中n为一个小于1的百分数, 是print_stats(n)的输入参数)的函数信息
# p.strip_dirs().sort_stats("cumulative", "name").print_stats(0.5)
# 查看调用main()的函数
# p.print_callers(0.5, "main")
# 查看main()函数中调用的函数
# p.print_callees("main")

# pip安装snakeviz后,在cmd里运行如下命令:
# snakeviz result.out
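
# A minimal sketch (illustrative, not part of this commit): the same profile file can be
# produced programmatically instead of via cmd, then inspected with the Stats call above.
# import cProfile
# cProfile.run('import main; main.main("fit", num_workers=0, max_epochs=1, '
#              'batch_size=32, precision=32, seed=1234)', './result.cprofile')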

+65 -0   data_module.py

@@ -0,0 +1,65 @@
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


class DataModule(pl.LightningDataModule):
def __init__(self, batch_size, num_workers, config=None):
super().__init__()
# TODO use k-fold cross-validation
# divide_dataset(config['dataset_path'], [0.8, 0, 0.2])
if config['flag']:
x = torch.randn(100000, 2)
noise = torch.randn(100000, )
y = ((1.0 * x[:, 0] + 2.0 * x[:, 1] + noise) > 0).type(torch.int64)
else:
x_1 = torch.randn(100000)
x_2 = torch.randn(100000)
x_useful = torch.cos(1.5 * x_1) * (x_2 ** 2)
x_1_rest_small = torch.randn(100000, 15) + 0.01 * x_1.unsqueeze(1)
x_1_rest_large = torch.randn(100000, 15) + 0.1 * x_1.unsqueeze(1)
x_2_rest_small = torch.randn(100000, 15) + 0.01 * x_2.unsqueeze(1)
x_2_rest_large = torch.randn(100000, 15) + 0.1 * x_2.unsqueeze(1)
x = torch.cat([x_1[:, None], x_2[:, None], x_1_rest_small, x_1_rest_large, x_2_rest_small, x_2_rest_large],
dim=1)
y = ((10 * x_useful) + 5 * torch.randn(100000) > 0.0).type(torch.int64)

self.y_train, self.y_test = y[:50000], y[50000:]
self.x_train, self.x_test = x[:50000, :], x[50000:, :]

self.batch_size = batch_size
self.num_workers = num_workers
self.config = config
self.train_dataset = None
self.val_dataset = None
self.test_dataset = None

def setup(self, stage=None) -> None:
if stage == 'fit' or stage is None:
self.train_dataset = CustomDataset(self.x_train, self.y_train, self.config)
self.val_dataset = CustomDataset(self.x_test, self.y_test, self.config)
if stage == 'test' or stage is None:
self.test_dataset = CustomDataset(self.x_test, self.y_test, self.config)

def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers)


class CustomDataset(Dataset):
def __init__(self, x, y, config):
super().__init__()
self.x = x
self.y = y
self.config = config

def __getitem__(self, idx):
return self.x[idx, :], self.y[idx]

def __len__(self):
return self.x.shape[0]
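
# A minimal usage sketch (illustrative, not part of this commit), using the same config
# keys that main.py passes in; flag=False selects the 62-feature synthetic problem.
if __name__ == "__main__":
    dm = DataModule(batch_size=32, num_workers=0,
                    config={'dataset_path': None, 'flag': False})
    dm.setup('fit')
    xb, yb = next(iter(dm.train_dataloader()))
    print(xb.shape, yb.shape)  # torch.Size([32, 62]) torch.Size([32])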

+38 -0   evalute.py

@@ -0,0 +1,38 @@
"""
评估指定文件夹下的预测结果, 评价结果均不计算背景类
"""
import numpy as np
from PIL import Image
from os.path import join
from network_module.iou import IOU
from network_module.pix_acc import calculate_acc


def evalute(n_classes, dataset_path, verbose=False):
iou = IOU(n_classes)

test_list = open(join(dataset_path, 'test_dataset_list.txt').replace('\\', '/')).readlines()
for ind in range(len(test_list)):
pred = np.array(Image.open(join(dataset_path, 'prediction', test_list[ind].strip('\n')).replace('\\', '/')))
label = np.array(Image.open(join(dataset_path, 'labels', test_list[ind].strip('\n')).replace('\\', '/')))
if len(label.flatten()) != len(pred.flatten()):
print('Skipping {:s}: label len {:d} != pred len {:d}'.format(
test_list[ind].strip('\n'), len(label.flatten()), len(pred.flatten())))
continue

iou.add_data(pred, label)

# Must be computed before iou.get_miou(), because get_miou() clears hist
overall_acc, acc = calculate_acc(iou.hist)
mIoU, IoUs = iou.get_miou()

if verbose:
for ind_class in range(n_classes):
print('===>' + str(ind_class) + ':\t' + str(IoUs[ind_class].float()))
print('===> mIoU: ' + str(mIoU))
print('===> overall accuracy:', overall_acc)
print('===> accuracy of each class:', acc)


if __name__ == "__main__":
evalute(9, './dataset/MFNet(RGB-T)-mini', verbose=True)

+19 -0   free_servers/bscc-run.sh

@@ -0,0 +1,19 @@
#!/bin/bash

module load anaconda/2020.11
module load cuda/10.2
module load cudnn/8.1.1.33_CUDA10.2

#conda create --name py37 python=3.7
source activate py37

cd run
cd machineLearningScaffold

#pip install -r requirements.txt

python main.py

#sbatch --gpus=1 ./bscc-run.sh
#Current job ID: 67417
#Query jobs: parajobs    Cancel a job: scancel ID

+92 -0   free_servers/hy-scaffold.ipynb

@@ -0,0 +1,92 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "4c4d04db-6fe0-4511-abb3-c746ba869863",
"metadata": {},
"outputs": [],
"source": [
"!apt-get update\n",
"!apt-get install p7zip-full -y"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c027a62-8116-4c55-9a90-95c4d90dd087",
"metadata": {},
"outputs": [],
"source": [
"!7z x dataset.zip -o/hy-tmp\n",
"!7z x machineLearningScaffold.zip -o/hy-tmp\n",
"!7z x vit_pre-train_checkpoint.zip -o/hy-tmp\n",
"!mv -f /hy-tmp/pre-train_checkpoint/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz /hy-tmp/pre-train_checkpoint"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4c6dc6a1-0208-4efb-beaf-b1264a94ba71",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: You are using pip version 21.0.1; however, version 21.2.4 is available.\n",
"You should consider upgrading via the '/usr/bin/python3.8 -m pip install --upgrade pip' command.\u001b[0m\n"
]
}
],
"source": [
"!pip install -r /hy-tmp/requirements.txt > /dev/null"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c074f52-d27e-4605-8b4f-45b11a2f1eb1",
"metadata": {},
"outputs": [],
"source": [
"import main\n",
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5a8d64e-d90e-467f-a3b4-db8b31874044",
"metadata": {},
"outputs": [],
"source": [
"main.main('fit', gpus=1, dataset_path='./dataset/MFNet(RGB-T)', num_workers=2, max_epochs=60, batch_size=16, precision=16, seed=1234, # tpu_cores=8,\n",
" checkpoint_every_n_val=1, save_name='version_0', #checkpoint_path='/version_15/checkpoints/epoch=59-step=4739.ckpt',\n",
" path_final_save='./drive/MyDrive', save_top_k=1\n",
" )"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

+99 -0   main.py

@@ -0,0 +1,99 @@
from save_checkpoint import SaveCheckpoint
from data_module import DataModule
from pytorch_lightning import loggers as pl_loggers
import pytorch_lightning as pl
from train_model import TrainModule


def main(stage,
num_workers,
max_epochs,
batch_size,
precision,
seed,
dataset_path=None,
gpus=None,
tpu_cores=None,
load_checkpoint_path=None,
save_name=None,
path_final_save=None,
every_n_epochs=1,
save_top_k=1,):
"""
框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程
该函数的参数为训练过程中需要经常改动的参数

:param stage: 表示处于训练阶段还是测试阶段, fit表示训练, test表示测试
:param num_workers:
:param max_epochs:
:param batch_size:
:param precision: 训练精度, 正常精度为32, 半精度为16, 也可以是64. 精度代表每个参数的类型所占的位数
:param seed:

:param dataset_path: 数据集地址, 其目录下包含数据集, 标签, 全部数据的命名list
:param gpus:
:param tpu_cores:
:param load_checkpoint_path:
:param save_name:
:param path_final_save:
:param every_n_epochs:
:param save_top_k:
"""
# config holds the model-specific parameters that rarely change once the model is fixed; general parameters that rarely change are declared directly below
if False:  # toggle between the simple 2-feature config (flag=True) and the 62-feature config (flag=False)
config = {'dataset_path': dataset_path,
'dim_in': 2,
'dim': 10,
'res_coef': 0.5,
'dropout_p': 0.1,
'n_layers': 2,
'flag': True}
else:
config = {'dataset_path': dataset_path,
'dim_in': 62,
'dim': 32,
'res_coef': 0.5,
'dropout_p': 0.1,
'n_layers': 20,
'flag': False}
# TODO find the optimal batch size
# TODO detect the number of CPU cores automatically and set num_workers accordingly
precision = 32 if (gpus is None and tpu_cores is None) else precision
dm = DataModule(batch_size=batch_size, num_workers=num_workers, config=config)
logger = pl_loggers.TensorBoardLogger('logs/')
if stage == 'fit':
training_module = TrainModule(config=config)
save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs,
save_name=save_name, path_final_save=path_final_save,
every_n_epochs=every_n_epochs, verbose=True,
monitor='Validation acc', save_top_k=save_top_k,
mode='max')
if load_checkpoint_path is None:
print('Training from scratch')
trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
logger=logger, precision=precision, callbacks=[save_checkpoint])
training_module.load_pretrain_parameters()
else:
print('Resuming training from a checkpoint')
trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
resume_from_checkpoint='./logs/default' + load_checkpoint_path,
logger=logger, precision=precision, callbacks=[save_checkpoint])
trainer.fit(training_module, datamodule=dm)
if stage == 'test':
if load_checkpoint_path is None:
print('No checkpoint loaded, cannot run the test stage')
else:
print('Running test')
training_module = TrainModule.load_from_checkpoint(
checkpoint_path='./logs/default' + load_checkpoint_path,
**{'config': config})
trainer = pl.Trainer(gpus=gpus, tpu_cores=tpu_cores, logger=logger, precision=precision)
trainer.test(training_module, datamodule=dm)
# Run tensorboard --logdir logs in cmd to inspect the results; in Jupyter, prefix the command with %


if __name__ == "__main__":
main('fit', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234,
# gpus=1,
# load_checkpoint_path='/version_5/checkpoints/epoch=149-step=7949.ckpt',
)

+40 -0   network/MLP.py

@@ -0,0 +1,40 @@
import torch.nn as nn
from torch.nn import Mish


class MLPLayer(nn.Module):
def __init__(self, dim_in, dim_out, res_coef=0.0, dropout_p=0.1):
super().__init__()
self.linear = nn.Linear(dim_in, dim_out)
self.res_coef = res_coef
self.activation = Mish()
self.dropout = nn.Dropout(dropout_p)
self.ln = nn.LayerNorm(dim_out)

def forward(self, x):
y = self.linear(x)
y = self.activation(y)
y = self.dropout(y)
if self.res_coef == 0:
return self.ln(y)
else:
return self.ln(self.res_coef * x + y)


class MLP(nn.Module):
def __init__(self, dim_in, dim, res_coef=0.5, dropout_p=0.1, n_layers=10):
super().__init__()
self.mlp = nn.ModuleList()
self.first_linear = MLPLayer(dim_in, dim)
self.n_layers = n_layers
for i in range(n_layers):
self.mlp.append(MLPLayer(dim, dim, res_coef, dropout_p))
self.final = nn.Linear(dim, 1)
self.sigmoid = nn.Sigmoid()

def forward(self, x):
x = self.first_linear(x)
for layer in self.mlp:
x = layer(x)
x = self.sigmoid(self.final(x))
return x.squeeze()
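
# A quick shape check (illustrative, not part of this commit): the network maps a
# (batch, dim_in) tensor to one probability per sample, matching nn.BCELoss in train_model.py.
if __name__ == "__main__":
    import torch
    net = MLP(dim_in=62, dim=32, res_coef=0.5, dropout_p=0.1, n_layers=20)
    probs = net(torch.randn(8, 62))
    print(probs.shape)  # torch.Size([8]), values in [0, 1]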

+1 -0    network_module/__init__.py

@@ -0,0 +1 @@
# General network-related modules whose lifetime is not tied to a single network, e.g. optimizers, losses and evaluation metrics

+32 -0   network_module/compute_utils.py

@@ -0,0 +1,32 @@
import torch


def one_hot_encoder(input_tensor, n_classes):
"""
将输入tensor转化为one-hot形式

:param input_tensor:
:param n_classes:
:return:
"""
tensor_list = []
for i in range(n_classes):
temp_prob = input_tensor == i # * torch.ones_like(input_tensor)
tensor_list.append(temp_prob.unsqueeze(1))
output_tensor = torch.cat(tensor_list, dim=1)
return output_tensor.long()


def torch_nanmean(x):
"""
输出忽略nan的tensor均值

:param x:
:return:
"""
num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum()
value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum()
# num == 0 means every entry is NaN; since the denominator cannot be 0, set num to 1
if num == 0:
num = 1
return value / num
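
# Illustrative usage (not part of this commit) of the two helpers above.
if __name__ == "__main__":
    labels = torch.tensor([0, 2, 1])
    print(one_hot_encoder(labels, n_classes=3))   # rows are samples, columns are classes
    print(torch_nanmean(torch.tensor([1.0, float('nan'), 3.0])))  # tensor(2.)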

+52 -0   network_module/iou.py

@@ -0,0 +1,52 @@
"""
和计算iou相关的函数和类, 包括计算iou loss
"""
import torch
from torch import nn
from network_module.compute_utils import torch_nanmean


def fast_hist(pred, label, n_classes):
# torch.bincount counts how many times each value in 0..n**2-1 occurs; reshaped, the result has shape (n, n)
return torch.bincount(n_classes * label + pred, minlength=n_classes ** 2).reshape(n_classes, n_classes)


def per_class_iu(hist):
# Compute the per-class IoU over all validation images; hist has shape (n, n)
# For each class: diagonal value / (row sum + column sum - diagonal value), result has shape (n,)
# torch.sum(hist, 0) gives the column totals (pixels predicted as each class), torch.sum(hist, 1) the row totals (pixels labeled as each class); rows are labels, columns are predictions
return (torch.diag(hist)) / (torch.sum(hist, 1) + torch.sum(hist, 0) - torch.diag(hist))


def get_ious(pred, label, n_classes):
hist = fast_hist(pred.flatten(), label.flatten(), n_classes)
IoUs = per_class_iu(hist)
mIoU = torch_nanmean(IoUs[1:n_classes])
return mIoU, IoUs


class IOU_loss(nn.Module):
def __init__(self, n_classes):
super(IOU_loss, self).__init__()
self.n_classes = n_classes

def forward(self, pred, label):
mIoU, _ = get_ious(pred, label, self.n_classes)
return 1 - mIoU


class IOU:
def __init__(self, n_classes):
self.n_classes = n_classes
self.hist = None

def add_data(self, preds, label):
if self.hist is None:
    self.hist = torch.zeros((self.n_classes, self.n_classes)).type_as(preds)
self.hist = self.hist + fast_hist(preds.int(), label, self.n_classes)

def get_miou(self):
IoUs = per_class_iu(self.hist)
self.hist = None
mIoU = torch_nanmean(IoUs[1:self.n_classes])
return mIoU, IoUs
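
# Illustrative usage (not part of this commit): accumulate a 3-class confusion matrix
# and read the IoUs; class 0 is treated as background and excluded from mIoU.
if __name__ == "__main__":
    iou = IOU(n_classes=3)
    pred = torch.tensor([0, 1, 1, 2, 2, 2])
    label = torch.tensor([0, 1, 2, 2, 2, 1])
    iou.add_data(pred, label)
    mIoU, IoUs = iou.get_miou()
    print(IoUs)  # per-class IoU, tensor([1.0000, 0.3333, 0.5000])
    print(mIoU)  # mean over the non-background classes, tensor(0.4167)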

+25 -0   network_module/pix_acc.py

@@ -0,0 +1,25 @@
import torch


def calculate_acc(hist):
"""
计算准确率, 而不是iou

:param hist:
:return:
"""
n_class = hist.size()[0]
conf = torch.zeros((n_class, n_class))
for cid in range(n_class):
if torch.sum(hist[:, cid]) > 0:
conf[:, cid] = hist[:, cid] / torch.sum(hist[:, cid])

# Can be read as the accuracy over all non-background pixels, though it leans towards the accuracy of predicting certain classes.
# nan means everything was predicted as background: if non-background classes exist the accuracy is 0; if not, it stays nan (no result; including the background it would be 1)
overall_acc = torch.sum(torch.diag(hist[1:, 1:])) / torch.sum(hist[1:, :])

# acc is the probability that a prediction of a given class is correct (per-class precision)
# nan means no pixel was predicted as that class: if the class is present the accuracy is 0; if not, it stays nan (no predictions of that class to score)
acc = torch.diag(conf)

return overall_acc, acc
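
# A small worked example (not part of this commit), with rows = labels and columns = predictions:
if __name__ == "__main__":
    hist = torch.tensor([[1, 0, 0],
                         [0, 1, 1],
                         [0, 1, 2]])
    overall_acc, acc = calculate_acc(hist)
    print(overall_acc)  # tensor(0.6000): 3 of the 5 non-background pixels are correct
    print(acc)          # tensor([1.0000, 0.5000, 0.6667]): per-class precision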

BIN      requirements.txt


+90 -0   save_checkpoint.py

@@ -0,0 +1,90 @@
import os
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning
import pytorch_lightning as pl
import shutil
import random
from pytorch_lightning.utilities import rank_zero_info
from utils import zip_dir


class SaveCheckpoint(ModelCheckpoint):
def __init__(self,
max_epochs,
seed=None,
every_n_epochs=None,
save_name=None,
path_final_save=None,
monitor=None,
save_top_k=None,
verbose=False,
mode='min',
no_save_before_epoch=0):
"""
通过回调实现checkpoint的保存逻辑, 同时具有回调函数中定义on_validation_end等功能.

:param max_epochs:
:param seed:
:param every_n_epochs:
:param save_name:
:param path_final_save:
:param monitor:
:param save_top_k:
:param verbose:
:param mode:
:param no_save_before_epoch:
"""
super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode)
random.seed(seed)
self.seeds = []
for i in range(max_epochs):
self.seeds.append(random.randint(0, 2000))
self.seeds.append(0)
pytorch_lightning.seed_everything(seed)
self.save_name = save_name
self.path_final_save = path_final_save
self.monitor = monitor
self.save_top_k = save_top_k
self.flag_sanity_check = 0
self.no_save_before_epoch = no_save_before_epoch

def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""
修改随机数逻辑,网络的随机种子给定,取样本的随机种子由给定的随机种子生成,保证即使重载训练每个epoch具有不同的抽样序列.
同时保存checkpoint.

:param trainer:
:param pl_module:
:return:
"""
if self.flag_sanity_check == 0:
pytorch_lightning.seed_everything(self.seeds[trainer.current_epoch])
self.flag_sanity_check = 1
else:
pytorch_lightning.seed_everything(self.seeds[trainer.current_epoch + 1])
super().on_validation_end(trainer, pl_module)

def _save_top_k_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates) -> None:
epoch = monitor_candidates.get("epoch")
if self.monitor is None or self.save_top_k == 0 or epoch < self.no_save_before_epoch:
return

current = monitor_candidates.get(self.monitor)

if self.check_monitor_top_k(trainer, current):
self._update_best_and_save(current, trainer, monitor_candidates)
if self.save_name is not None and self.path_final_save is not None:
zip_dir('./logs/default/' + self.save_name, './' + self.save_name + '.zip')
if os.path.exists(self.path_final_save + '/' + self.save_name + '.zip'):
os.remove(self.path_final_save + '/' + self.save_name + '.zip')
shutil.move('./' + self.save_name + '.zip', self.path_final_save)
elif self.verbose:
epoch = monitor_candidates.get("epoch")
step = monitor_candidates.get("step")
best_model_values = 'now best model:'
for cou_best_model in self.best_k_models:
best_model_values = ' '.join(
(best_model_values, str(round(float(self.best_k_models[cou_best_model]), 4))))
rank_zero_info(
f"\nEpoch {epoch:d}, global step {step:d}: {self.monitor} ({float(current):f}) was not in "
f"top {self.save_top_k:d}({best_model_values:s})")

+47 -0   train_model.py

@@ -0,0 +1,47 @@
import pytorch_lightning as pl
from torch import nn
import torch
from torchmetrics.classification.accuracy import Accuracy
from network.MLP import MLP


class TrainModule(pl.LightningModule):
def __init__(self, config=None):
super().__init__()
self.time_sum = None
self.config = config
self.net = MLP(config['dim_in'], config['dim'], config['res_coef'], config['dropout_p'], config['n_layers'])
# TODO switch the network initialization to Kaiming or Xavier initialization
self.loss = nn.BCELoss()
self.accuracy = Accuracy()

def training_step(self, batch, batch_idx):
x, y = batch
x = self.net(x)
loss = self.loss(x, y.type(torch.float32))
acc = self.accuracy(x, y)
self.log("Training loss", loss)
self.log("Training acc", acc)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
x = self.net(x)
loss = self.loss(x, y.type(torch.float32))
acc = self.accuracy(x, y)
self.log("Validation loss", loss)
self.log("Validation acc", acc)
return loss, acc

def test_step(self, batch, batch_idx):
return 0

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer

def load_pretrain_parameters(self):
"""
载入预训练参数
"""
pass

+189 -0  utils.py

@@ -0,0 +1,189 @@
# Utilities that are independent of the network
import glob
import os
import random
import zipfile
import cv2
import torch


def divide_dataset(dataset_path, rate_datasets):
"""
切分数据集, 划分为训练集,验证集,测试集生成list文件并保存为:
train_dataset_list、validate_dataset_list、test_dataset_list.
每个比例必须大于0且保证至少每个数据集中具有一个样本, 验证集可以为0.

:param dataset_path: 数据集的地址
:param rate_datasets: 不同数据集[训练集,验证集,测试集]的比例
"""
# 当不存在总的all_dataset_list文件时, 生成all_dataset_list
if not os.path.exists(dataset_path + '/all_dataset_list.txt'):
all_list = glob.glob(dataset_path + '/labels' + '/*.png')
with open(dataset_path + '/all_dataset_list.txt', 'w', encoding='utf-8') as f:
for line in all_list:
f.write(os.path.basename(line.replace('\\', '/')) + '\n')
path_train_dataset_list = dataset_path + '/train_dataset_list.txt'
path_validate_dataset_list = dataset_path + '/validate_dataset_list.txt'
path_test_dataset_list = dataset_path + '/test_dataset_list.txt'
# If the validation ratio is 0, use the test set as the validation set and drop the separate test set
if rate_datasets[1] == 0:
# If the split list files do not exist yet, generate new ones
if not (os.path.exists(path_train_dataset_list) and
os.path.exists(path_validate_dataset_list) and
os.path.exists(path_test_dataset_list)):
all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
random.shuffle(all_list)
train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):]
with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
for line in train_dataset_list:
f.write(line)
with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
print('Generated new dataset lists')
else:
# Check whether the existing split matches the requested ratios; if not, regenerate the lists
all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
with open(path_train_dataset_list) as f:
train_dataset_list_exist = f.readlines()
with open(path_validate_dataset_list) as f:
test_dataset_list_exist = f.readlines()
random.shuffle(all_list)
train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):]
if not (len(train_dataset_list_exist) == len(train_dataset_list) and
len(test_dataset_list_exist) == len(test_dataset_list)):
with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
for line in train_dataset_list:
f.write(line)
with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
print('Generated new dataset lists')
# If the validation ratio is not 0, keep separate validation and test sets
else:
# If the split list files do not exist yet, generate new ones
if not (os.path.exists(dataset_path + '/train_dataset_list.txt') and
os.path.exists(dataset_path + '/validate_dataset_list.txt') and
os.path.exists(dataset_path + '/test_dataset_list.txt')):
all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
random.shuffle(all_list)
train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):
int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))]
test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):]
with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
for line in train_dataset_list:
f.write(line)
with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
for line in validate_dataset_list:
f.write(line)
with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
print('Generated new dataset lists')
else:
# Check whether the existing split matches the requested ratios; if not, regenerate the lists
all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
with open(path_train_dataset_list) as f:
train_dataset_list_exist = f.readlines()
with open(path_validate_dataset_list) as f:
validate_dataset_list_exist = f.readlines()
with open(path_test_dataset_list) as f:
test_dataset_list_exist = f.readlines()
random.shuffle(all_list)
train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):
int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))]
test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):]
if not (len(train_dataset_list_exist) == len(train_dataset_list) and
len(validate_dataset_list_exist) == len(validate_dataset_list) and
len(test_dataset_list_exist) == len(test_dataset_list)):
with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
for line in train_dataset_list:
f.write(line)
with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
for line in validate_dataset_list:
f.write(line)
with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
for line in test_dataset_list:
f.write(line)
print('Generated new dataset lists')


def zip_dir(dir_path, zip_path):
"""
压缩文件

:param dir_path: 目标文件夹路径
:param zip_path: 压缩后的文件夹路径
"""
ziper = zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED)
for root, dirnames, filenames in os.walk(dir_path):
file_path = root.replace(dir_path, '')  # strip the root path so only the files and folders inside the target directory are archived
# iterate over the file names
for filename in filenames:
ziper.write(os.path.join(root, filename), os.path.join(file_path, filename))
ziper.close()


def ncolors(num_colors):
"""
生成区别度较大的几种颜色
copy: https://blog.csdn.net/choumin/article/details/90320297

:param num_colors: 颜色数
:return:
"""
def get_n_hls_colors(num):
import random
hls_colors = []
i = 0
step = 360.0 / num
while i < 360:
h = i
s = 90 + random.random() * 10
li = 50 + random.random() * 10
_hlsc = [h / 360.0, li / 100.0, s / 100.0]
hls_colors.append(_hlsc)
i += step
return hls_colors

import colorsys
rgb_colors = []
if num_colors < 1:
return rgb_colors
for hlsc in get_n_hls_colors(num_colors):
_r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2])
r, g, b = [int(x * 255.0) for x in (_r, _g, _b)]
rgb_colors.append([r, g, b])
return rgb_colors


def visual_label(dataset_path, n_classes):
"""
将标签可视化

:param dataset_path: 地址
:param n_classes: 类别数
"""
label_path = os.path.join(dataset_path, 'test', 'labels').replace('\\', '/')
label_image_list = glob.glob(label_path + '/*.png')
label_image_list.sort()
from torchvision import transforms
trans_factory = transforms.ToPILImage()
if not os.path.exists(dataset_path + '/visual_label'):
os.mkdir(dataset_path + '/visual_label')
for index in range(len(label_image_list)):
label_image = cv2.imread(label_image_list[index], -1)
name = os.path.basename(label_image_list[index])
trans_factory(torch.from_numpy(label_image).float() / n_classes).save(
dataset_path + '/visual_label/' + name,
quality=95)
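
# Illustrative usage (not part of this commit), with the dataset layout assumed by this
# repository: <dataset_path>/labels/*.png plus test/labels for visual_label.
if __name__ == "__main__":
    divide_dataset('./dataset/MFNet(RGB-T)-mini', [0.8, 0, 0.2])
    zip_dir('./logs/default/version_0', './version_0.zip')
    print(ncolors(3))  # three visually distinct [r, g, b] triplets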
