| @@ -1,45 +1,59 @@ | |||
| import os | |||
| import numpy | |||
| import torch | |||
| from torch import Tensor | |||
| from torch.utils.data import Dataset, DataLoader | |||
| import pytorch_lightning as pl | |||
| class DataModule(pl.LightningDataModule): | |||
| def __init__(self, batch_size, num_workers, config=None): | |||
| def __init__(self, batch_size, num_workers, k_fold, kth_fold, dataset_path, config=None): | |||
| super().__init__() | |||
| # TODO 使用k折交叉验证 | |||
| # divide_dataset(config['dataset_path'], [0.8, 0, 0.2]) | |||
| if config['flag']: | |||
| self.x = torch.randn(config['dataset_len'], 2) | |||
| noise = torch.randn(config['dataset_len'], ) | |||
| self.y = 1.0 * self.x[:, 0] + 2.0 * self.x[:, 1] + noise | |||
| else: | |||
| x_1 = torch.randn(config['dataset_len']) | |||
| x_2 = torch.randn(config['dataset_len']) | |||
| x_useful = torch.cos(1.5 * x_1) * (x_2 ** 2) | |||
| x_1_rest_small = torch.randn(config['dataset_len'], 15) + 0.01 * x_1.unsqueeze(1) | |||
| x_1_rest_large = torch.randn(config['dataset_len'], 15) + 0.1 * x_1.unsqueeze(1) | |||
| x_2_rest_small = torch.randn(config['dataset_len'], 15) + 0.01 * x_2.unsqueeze(1) | |||
| x_2_rest_large = torch.randn(config['dataset_len'], 15) + 0.1 * x_2.unsqueeze(1) | |||
| self.x = torch.cat([x_1[:, None], x_2[:, None], x_1_rest_small, x_1_rest_large, x_2_rest_small, x_2_rest_large], | |||
| dim=1) | |||
| self.y = (10 * x_useful) + 5 * torch.randn(config['dataset_len']) | |||
| self.y_train, self.y_test = self.y[:50000], self.y[50000:] | |||
| self.x_train, self.x_test = self.x[:50000, :], self.x[50000:, :] | |||
| self.batch_size = batch_size | |||
| self.num_workers = num_workers | |||
| self.config = config | |||
| self.train_dataset = None | |||
| self.val_dataset = None | |||
| self.test_dataset = None | |||
| self.k_fold = k_fold | |||
| self.kth_fold = kth_fold | |||
| self.dataset_path = dataset_path | |||
| def setup(self, stage=None) -> None: | |||
| # 得到全部数据的list | |||
| # dataset_list = get_dataset_list(dataset_path) | |||
| x, y = self.get_fit_dataset_list() | |||
| if stage == 'fit' or stage is None: | |||
| self.train_dataset = CustomDataset(self.x_train, self.y_train, self.config) | |||
| self.val_dataset = CustomDataset(self.x_test, self.y_test, self.config) | |||
| x_train, y_train, x_val, y_val = self.get_dataset_lists(x, y) | |||
| self.train_dataset = CustomDataset(x_train, y_train, self.config) | |||
| self.val_dataset = CustomDataset(x_val, y_val, self.config) | |||
| if stage == 'test' or stage is None: | |||
| self.test_dataset = CustomDataset(self.x, self.y, self.config) | |||
| self.test_dataset = CustomDataset(x, y, self.config) | |||
| def get_fit_dataset_list(self): | |||
| if not os.path.exists(self.dataset_path + '/dataset_list.txt'): | |||
| x = torch.randn(self.config['dataset_len'], self.config['dim_in']) | |||
| noise = torch.randn(self.config['dataset_len']) | |||
| y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + noise | |||
| with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: | |||
| for line in range(self.config['dataset_len']): | |||
| f.write(' '.join([str(temp) for temp in x[line].tolist()]) + ' ' + str(y[line].item()) + '\n') | |||
| print('已生成新的数据list') | |||
| else: | |||
| dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines() | |||
| dataset_list = [[float(temp) for temp in item.strip('\n').split(' ')] for item in dataset_list] | |||
| x = torch.from_numpy(numpy.array(dataset_list)[:, 0:self.config['dim_in']]).float() | |||
| y = torch.from_numpy(numpy.array(dataset_list)[:, self.config['dim_in']]).float() | |||
| return x, y | |||
| def get_dataset_lists(self, x: Tensor, y): | |||
| # 得到一个fold的数据量和不够组成一个fold的剩余数据的数据量 | |||
| num_1fold, remainder = divmod(self.config['dataset_len'], self.k_fold) | |||
| # 分割全部数据, 得到训练集, 验证集, 测试集 | |||
| x_val = x[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)] | |||
| y_val = y[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)] | |||
| temp = torch.ones(x.shape[0]) | |||
| temp[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)] = 0 | |||
| x_train = x[temp == 1] | |||
| y_train = y[temp == 1] | |||
| return x_train, y_train, x_val, y_val | |||
| def train_dataloader(self): | |||
| return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, | |||
| @@ -5,20 +5,23 @@ import pytorch_lightning as pl | |||
| from train_model import TrainModule | |||
| from multiprocessing import cpu_count | |||
| from utils import get_ckpt_path | |||
| def main(stage, | |||
| max_epochs, | |||
| batch_size, | |||
| precision, | |||
| seed, | |||
| dataset_path=None, | |||
| dataset_path, | |||
| gpus=None, | |||
| tpu_cores=None, | |||
| load_checkpoint_path=None, | |||
| save_name=None, | |||
| version_nth=None, | |||
| path_final_save=None, | |||
| every_n_epochs=1, | |||
| save_top_k=1,): | |||
| save_top_k=1, | |||
| k_fold=5, | |||
| kth_fold_start=0): | |||
| """ | |||
| 框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程 | |||
| 该函数的参数为训练过程中需要经常改动的参数 | |||
| @@ -32,74 +35,72 @@ def main(stage, | |||
| :param dataset_path: 数据集地址, 其目录下包含数据集, 标签, 全部数据的命名list | |||
| :param gpus: | |||
| :param tpu_cores: | |||
| :param load_checkpoint_path: | |||
| :param save_name: | |||
| :param version_nth: 该folds的第一个版本的版本号 | |||
| :param path_final_save: | |||
| :param every_n_epochs: | |||
| :param save_top_k: | |||
| :param kth_fold_start: 从第几个fold开始, 若使用重载训练, 则kth_fold_start为重载第几个fold, 第一个值为0 | |||
| :param k_fold: | |||
| """ | |||
| # 经常改动的 参数 作为main的输入参数 | |||
| # 不常改动的 非通用参数 存放在config | |||
| # 不经常改动的 通用参数 直接进行声明 | |||
| # 不常改动的 通用参数 直接进行声明 | |||
| # 通用参数指的是所有网络中共有的参数, 如time_sum等 | |||
| if True: | |||
| config = {'dataset_path': dataset_path, | |||
| 'dim_in': 2, | |||
| 'dim': 10, | |||
| 'res_coef': 0.5, | |||
| 'dropout_p': 0.1, | |||
| 'n_layers': 2, | |||
| 'dataset_len': 100000, | |||
| 'flag': True} | |||
| else: | |||
| config = {'dataset_path': dataset_path, | |||
| 'dim_in': 62, | |||
| 'dim': 32, | |||
| 'res_coef': 0.5, | |||
| 'dropout_p': 0.1, | |||
| 'n_layers': 20, | |||
| 'dataset_len': 100000, | |||
| 'flag': False} | |||
| # 处理输入数据 | |||
| precision = 32 if (gpus is None and tpu_cores is None) else precision | |||
| # 获得通用参数 | |||
| # TODO 获得最优的batch size | |||
| num_workers = cpu_count() | |||
| precision = 32 if (gpus is None and tpu_cores is None) else precision | |||
| dm = DataModule(batch_size=batch_size, num_workers=num_workers, config=config) | |||
| logger = pl_loggers.TensorBoardLogger('logs/') | |||
| if stage == 'fit': | |||
| # SaveCheckpoint的创建需要在TrainModule之前, 以保证网络参数初始化的确定性 | |||
| save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs, | |||
| save_name=save_name, path_final_save=path_final_save, | |||
| every_n_epochs=every_n_epochs, verbose=True, | |||
| monitor='Validation loss', save_top_k=save_top_k, | |||
| mode='min') | |||
| training_module = TrainModule(config=config) | |||
| if load_checkpoint_path is None: | |||
| print('进行初始训练') | |||
| trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores, | |||
| logger=logger, precision=precision, callbacks=[save_checkpoint]) | |||
| training_module.load_pretrain_parameters() | |||
| else: | |||
| print('进行重载训练') | |||
| trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores, | |||
| resume_from_checkpoint='./logs/default' + load_checkpoint_path, | |||
| logger=logger, precision=precision, callbacks=[save_checkpoint]) | |||
| print('训练过程中请注意gpu利用率等情况') | |||
| trainer.fit(training_module, datamodule=dm) | |||
| if stage == 'test': | |||
| if load_checkpoint_path is None: | |||
| print('未载入权重信息,不能测试') | |||
| else: | |||
| print('进行测试') | |||
| training_module = TrainModule.load_from_checkpoint( | |||
| checkpoint_path='./logs/default' + load_checkpoint_path, | |||
| **{'config': config}) | |||
| trainer = pl.Trainer(gpus=gpus, tpu_cores=tpu_cores, logger=logger, precision=precision) | |||
| trainer.test(training_module, datamodule=dm) | |||
| # 在cmd中使用tensorboard --logdir logs命令可以查看结果,在Jupyter格式下需要加%前缀 | |||
| # 获得非通用参数 | |||
| config = {'dim_in': 2, | |||
| 'dim': 10, | |||
| 'res_coef': 0.5, | |||
| 'dropout_p': 0.1, | |||
| 'n_layers': 2, | |||
| 'dataset_len': 100000} | |||
| # for kth_fold in range(kth_fold_start, k_fold): | |||
| for kth_fold in range(kth_fold_start, kth_fold_start+1): | |||
| load_checkpoint_path = get_ckpt_path(f'version_{version_nth+kth_fold}') | |||
| logger = pl_loggers.TensorBoardLogger('logs/') | |||
| dm = DataModule(batch_size=batch_size, num_workers=num_workers, k_fold=k_fold, kth_fold=kth_fold, | |||
| dataset_path=dataset_path, config=config) | |||
| if stage == 'fit': | |||
| # SaveCheckpoint的创建需要在TrainModule之前, 以保证网络参数初始化的确定性 | |||
| save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs, | |||
| path_final_save=path_final_save, | |||
| every_n_epochs=every_n_epochs, verbose=True, | |||
| monitor='Validation loss', save_top_k=save_top_k, | |||
| mode='min') | |||
| training_module = TrainModule(config=config) | |||
| if kth_fold != kth_fold_start or load_checkpoint_path is None: | |||
| print('进行初始训练') | |||
| trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores, | |||
| logger=logger, precision=precision, callbacks=[save_checkpoint]) | |||
| training_module.load_pretrain_parameters() | |||
| else: | |||
| print('进行重载训练') | |||
| trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores, | |||
| resume_from_checkpoint='./logs/default' + load_checkpoint_path, | |||
| logger=logger, precision=precision, callbacks=[save_checkpoint]) | |||
| print('训练过程中请注意gpu利用率等情况') | |||
| trainer.fit(training_module, datamodule=dm) | |||
| if stage == 'test': | |||
| if load_checkpoint_path is None: | |||
| print('未载入权重信息,不能测试') | |||
| else: | |||
| print('进行测试') | |||
| training_module = TrainModule.load_from_checkpoint( | |||
| checkpoint_path='./logs/default' + load_checkpoint_path, | |||
| **{'config': config}) | |||
| trainer = pl.Trainer(gpus=gpus, tpu_cores=tpu_cores, logger=logger, precision=precision) | |||
| trainer.test(training_module, datamodule=dm) | |||
| # 在cmd中使用tensorboard --logdir logs命令可以查看结果,在Jupyter格式下需要加%前缀 | |||
| if __name__ == "__main__": | |||
| main('fit', max_epochs=5, batch_size=32, precision=16, seed=1234, | |||
| main('fit', max_epochs=2, batch_size=32, precision=16, seed=1234, dataset_path='./dataset', k_fold=5 | |||
| # gpus=1, | |||
| # load_checkpoint_path='/version_4/checkpoints/epoch=4-step=7814.ckpt', | |||
| # version_nth=8, # 该folds的第一个版本的版本号 | |||
| # kth_fold_start=0 # 如果需要重载训练, 则指定重载的版本和其位于k_fold的fold数 | |||
| ) | |||
| @@ -1,5 +1,5 @@ | |||
| import torch.nn as nn | |||
| from torch.nn import Mish | |||
| from network_module.activation import jdlu, JDLU | |||
| class MLPLayer(nn.Module): | |||
| @@ -7,13 +7,15 @@ class MLPLayer(nn.Module): | |||
| super().__init__() | |||
| self.linear = nn.Linear(dim_in, dim_out) | |||
| self.res_coef = res_coef | |||
| self.activation = Mish() | |||
| self.activation = nn.ReLU() | |||
| self.activation1 = JDLU(dim_out) | |||
| self.dropout = nn.Dropout(dropout_p) | |||
| self.ln = nn.LayerNorm(dim_out) | |||
| def forward(self, x): | |||
| y = self.linear(x) | |||
| y = self.activation(y) | |||
| y = self.activation1(y) | |||
| # y = jdlu(y) | |||
| y = self.dropout(y) | |||
| if self.res_coef == 0: | |||
| return self.ln(y) | |||
| @@ -1,10 +1,8 @@ | |||
| import os | |||
| import numpy.random | |||
| from pytorch_lightning.callbacks import ModelCheckpoint | |||
| import pytorch_lightning as pl | |||
| import shutil | |||
| import random | |||
| from pytorch_lightning.utilities import rank_zero_info | |||
| from utils import zip_dir | |||
| @@ -14,7 +12,6 @@ class SaveCheckpoint(ModelCheckpoint): | |||
| max_epochs, | |||
| seed=None, | |||
| every_n_epochs=None, | |||
| save_name=None, | |||
| path_final_save=None, | |||
| monitor=None, | |||
| save_top_k=None, | |||
| @@ -27,7 +24,6 @@ class SaveCheckpoint(ModelCheckpoint): | |||
| :param max_epochs: | |||
| :param seed: | |||
| :param every_n_epochs: | |||
| :param save_name: | |||
| :param path_final_save: | |||
| :param monitor: | |||
| :param save_top_k: | |||
| @@ -39,7 +35,6 @@ class SaveCheckpoint(ModelCheckpoint): | |||
| numpy.random.seed(seed) | |||
| self.seeds = numpy.random.randint(0, 2000, max_epochs) | |||
| pl.seed_everything(seed) | |||
| self.save_name = save_name | |||
| self.path_final_save = path_final_save | |||
| self.monitor = monitor | |||
| self.save_top_k = save_top_k | |||
| @@ -71,11 +66,11 @@ class SaveCheckpoint(ModelCheckpoint): | |||
| if self.check_monitor_top_k(trainer, current): | |||
| self._update_best_and_save(current, trainer, monitor_candidates) | |||
| if self.save_name is not None and self.path_final_save is not None: | |||
| zip_dir('./logs/default/' + self.save_name, './' + self.save_name + '.zip') | |||
| if os.path.exists(self.path_final_save + '/' + self.save_name + '.zip'): | |||
| os.remove(self.path_final_save + '/' + self.save_name + '.zip') | |||
| shutil.move('./' + self.save_name + '.zip', self.path_final_save) | |||
| if self.path_final_save is not None: | |||
| zip_dir('./logs', './logs.zip') | |||
| if os.path.exists(self.path_final_save + '/logs.zip'): | |||
| os.remove(self.path_final_save + '/logs.zip') | |||
| shutil.move('./logs.zip', self.path_final_save) | |||
| elif self.verbose: | |||
| epoch = monitor_candidates.get("epoch") | |||
| step = monitor_candidates.get("step") | |||
| @@ -2,120 +2,25 @@ | |||
| import glob | |||
| import os | |||
| import random | |||
| import string | |||
| import zipfile | |||
| import cv2 | |||
| import numpy | |||
| import torch | |||
| def divide_dataset(dataset_path, rate_datasets): | |||
| """ | |||
| 切分数据集, 划分为训练集,验证集,测试集生成list文件并保存为: | |||
| train_dataset_list、validate_dataset_list、test_dataset_list. | |||
| 每个比例必须大于0且保证至少每个数据集中具有一个样本, 验证集可以为0. | |||
| :param dataset_path: 数据集的地址 | |||
| :param rate_datasets: 不同数据集[训练集,验证集,测试集]的比例 | |||
| """ | |||
| # 当不存在总的all_dataset_list文件时, 生成all_dataset_list | |||
| if not os.path.exists(dataset_path + '/all_dataset_list.txt'): | |||
| def get_dataset_list(dataset_path): | |||
| if not os.path.exists(dataset_path + '/dataset_list.txt'): | |||
| all_list = glob.glob(dataset_path + '/labels' + '/*.png') | |||
| with open(dataset_path + '/all_dataset_list.txt', 'w', encoding='utf-8') as f: | |||
| random.shuffle(all_list) | |||
| with open(dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: | |||
| for line in all_list: | |||
| f.write(os.path.basename(line.replace('\\', '/')) + '\n') | |||
| path_train_dataset_list = dataset_path + '/train_dataset_list.txt' | |||
| path_validate_dataset_list = dataset_path + '/validate_dataset_list.txt' | |||
| path_test_dataset_list = dataset_path + '/test_dataset_list.txt' | |||
| # 如果验证集的比例为0,则将测试集设置为验证集并取消测试集; | |||
| if rate_datasets[1] == 0: | |||
| # 如果无切分后的list文件, 则生成新的list文件 | |||
| if not (os.path.exists(path_train_dataset_list) and | |||
| os.path.exists(path_validate_dataset_list) and | |||
| os.path.exists(path_test_dataset_list)): | |||
| all_list = open(dataset_path + '/all_dataset_list.txt').readlines() | |||
| random.shuffle(all_list) | |||
| train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])] | |||
| test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):] | |||
| with open(path_train_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in train_dataset_list: | |||
| f.write(line) | |||
| with open(path_validate_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in test_dataset_list: | |||
| f.write(line) | |||
| with open(path_test_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in test_dataset_list: | |||
| f.write(line) | |||
| print('已生成新的数据list') | |||
| else: | |||
| # 判断比例是否正确,如果不正确,则重新生成数据集 | |||
| all_list = open(dataset_path + '/all_dataset_list.txt').readlines() | |||
| with open(path_train_dataset_list) as f: | |||
| train_dataset_list_exist = f.readlines() | |||
| with open(path_validate_dataset_list) as f: | |||
| test_dataset_list_exist = f.readlines() | |||
| random.shuffle(all_list) | |||
| train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])] | |||
| test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):] | |||
| if not (len(train_dataset_list_exist) == len(train_dataset_list) and | |||
| len(test_dataset_list_exist) == len(test_dataset_list)): | |||
| with open(path_train_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in train_dataset_list: | |||
| f.write(line) | |||
| with open(path_validate_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in test_dataset_list: | |||
| f.write(line) | |||
| with open(path_test_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in test_dataset_list: | |||
| f.write(line) | |||
| print('已生成新的数据list') | |||
| # 如果验证集比例不为零,则同时存在验证集和测试集 | |||
| return all_list | |||
| else: | |||
| # 如果无切分后的list文件, 则生成新的list文件 | |||
| if not (os.path.exists(dataset_path + '/train_dataset_list.txt') and | |||
| os.path.exists(dataset_path + '/validate_dataset_list.txt') and | |||
| os.path.exists(dataset_path + '/test_dataset_list.txt')): | |||
| all_list = open(dataset_path + '/all_dataset_list.txt').readlines() | |||
| random.shuffle(all_list) | |||
| train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])] | |||
| validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]): | |||
| int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))] | |||
| test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):] | |||
| with open(path_train_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in train_dataset_list: | |||
| f.write(line) | |||
| with open(path_validate_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in validate_dataset_list: | |||
| f.write(line) | |||
| with open(path_test_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in test_dataset_list: | |||
| f.write(line) | |||
| print('已生成新的数据list') | |||
| else: | |||
| # 判断比例是否正确,如果不正确,则重新生成数据集 | |||
| all_list = open(dataset_path + '/all_dataset_list.txt').readlines() | |||
| with open(path_train_dataset_list) as f: | |||
| train_dataset_list_exist = f.readlines() | |||
| with open(path_validate_dataset_list) as f: | |||
| validate_dataset_list_exist = f.readlines() | |||
| with open(path_test_dataset_list) as f: | |||
| test_dataset_list_exist = f.readlines() | |||
| random.shuffle(all_list) | |||
| train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])] | |||
| validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]): | |||
| int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))] | |||
| test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):] | |||
| if not (len(train_dataset_list_exist) == len(train_dataset_list) and | |||
| len(validate_dataset_list_exist) == len(validate_dataset_list) and | |||
| len(test_dataset_list_exist) == len(test_dataset_list)): | |||
| with open(path_train_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in train_dataset_list: | |||
| f.write(line) | |||
| with open(path_validate_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in validate_dataset_list: | |||
| f.write(line) | |||
| with open(path_test_dataset_list, 'w', encoding='utf-8') as f: | |||
| for line in test_dataset_list: | |||
| f.write(line) | |||
| print('已生成新的数据list') | |||
| all_list = open(dataset_path + '/all_dataset_list.txt').readlines() | |||
| return all_list | |||
| def zip_dir(dir_path, zip_path): | |||
| @@ -142,6 +47,7 @@ def ncolors(num_colors): | |||
| :param num_colors: 颜色数 | |||
| :return: | |||
| """ | |||
| def get_n_hls_colors(num): | |||
| import random | |||
| hls_colors = [] | |||
| @@ -187,3 +93,35 @@ def visual_label(dataset_path, n_classes): | |||
| trans_factory(torch.from_numpy(label_image).float() / n_classes).save( | |||
| dataset_path + '/visual_label/' + name, | |||
| quality=95) | |||
| def get_ckpt_path(version_name: string): | |||
| if version_name is None: | |||
| return None | |||
| else: | |||
| checkpoints_path = './logs/default/' + version_name + '/checkpoints' | |||
| ckpt_path = glob.glob(checkpoints_path + '/*.ckpt') | |||
| return ckpt_path[0].replace('\\', '/') | |||
| def rwxl(): | |||
| # 写 | |||
| # dataset_xl = xl.Workbook(write_only=True) | |||
| # dataset_sh = dataset_xl.create_sheet('dataset', 0) | |||
| # for row in range(self.x.shape[0]): | |||
| # for col in range(self.x.shape[1]): | |||
| # dataset_sh.cell(row + 1, col + 1).value = float(self.x[row, col]) | |||
| # dataset_sh.cell(row + 1, self.x.shape[1] + 1).value = float(self.y[row]) | |||
| # dataset_xl.save(dataset_path + '/dataset.xlsx') | |||
| # dataset_xl.close() | |||
| # 读 | |||
| # dataset_xl = xl.load_workbook(dataset_path + '/dataset_list.xlsx', read_only=True) | |||
| # dataset_sh = dataset_xl.get_sheet_by_name('dataset_list') | |||
| # temp = [[dataset_sh[row + 1][col].value for col in range(config['dim_in'] + 1)] for row in | |||
| # range(config['dataset_len'])] | |||
| # dataset_xl.close() | |||
| pass | |||
| if __name__ == "__main__": | |||
| get_ckpt_path('version_0') | |||