From b2ac2ad2c1ebe3bc4d546689b81496510b410ecf Mon Sep 17 00:00:00 2001 From: shenyan <23357320@qq.com> Date: Thu, 21 Oct 2021 19:55:25 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E8=AE=AD=E7=BB=83=E9=98=B6?= =?UTF-8?q?=E6=AE=B5=E7=9A=84=E6=95=B0=E6=8D=AE=E5=A2=9E=E5=BC=BA;=20?= =?UTF-8?q?=E5=B0=86=E8=BE=93=E5=85=A5=E6=95=B0=E6=8D=AE=E6=A0=87=E5=87=86?= =?UTF-8?q?=E5=8C=96;=20=E5=88=A0=E9=99=A4main=E4=B8=AD=E7=9A=84config?= =?UTF-8?q?=E4=B8=8D=E9=9C=80=E8=A6=81=E7=9A=84=E9=87=8F=E5=B9=B6=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E8=BE=93=E5=85=A5=E7=BB=B4=E5=BA=A6;?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_module.py | 30 ++++++++++++++++++++++++------ main.py | 7 +------ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/data_module.py b/data_module.py index 2069df2..ce8e1f7 100644 --- a/data_module.py +++ b/data_module.py @@ -22,11 +22,11 @@ class DataModule(pl.LightningDataModule): k_fold_dataset_list = self.get_k_fold_dataset_list() if stage == 'fit' or stage is None: dataset_train, dataset_val = self.get_fit_dataset_lists(k_fold_dataset_list) - self.train_dataset = CustomDataset(self.dataset_path, dataset_train, self.config, 'train') - self.val_dataset = CustomDataset(self.dataset_path, dataset_val, self.config, 'train') + self.train_dataset = CustomDataset(self.dataset_path, dataset_train, 'train', self.config,) + self.val_dataset = CustomDataset(self.dataset_path, dataset_val, 'val', self.config,) if stage == 'test' or stage is None: dataset_test = self.get_test_dataset_lists(k_fold_dataset_list) - self.test_dataset = CustomDataset(self.dataset_path, dataset_test, self.config, 'test') + self.test_dataset = CustomDataset(self.dataset_path, dataset_test, 'test', self.config,) def get_k_fold_dataset_list(self): # 得到用于K折分割的数据的list, 并生成文件夹进行保存 @@ -72,13 +72,31 @@ class DataModule(pl.LightningDataModule): class CustomDataset(Dataset): - def __init__(self, dataset_path, dataset, config, type): + def __init__(self, dataset_path, dataset, stage, config, ): super().__init__() self.dataset = dataset - self.trans = transforms.ToTensor() - self.labels = open(dataset_path + '/' + type + '/label.txt').readlines() + # 此处的均值和方差来源于ImageNet + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + if stage == 'train': + self.trans = transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.RandomCrop(config['dim_in'], 4), + transforms.ToTensor(), + normalize, ]) + elif stage == 'val': + stage = 'train' + self.trans = transforms.Compose([ + transforms.ToTensor(), + normalize, ]) + else: + self.trans = transforms.Compose([ + transforms.ToTensor(), + normalize, ]) + self.labels = open(dataset_path + '/' + stage + '/label.txt').readlines() def __getitem__(self, idx): + # 注意: 为了满足初始化权重算法的要求, 需要输入参数的均值为0. 可以使用transforms.Normalize() image_path = self.dataset[idx] image_name = os.path.basename(image_path) image = Image.open(image_path) diff --git a/main.py b/main.py index b7e62f4..9ba1bf0 100644 --- a/main.py +++ b/main.py @@ -59,12 +59,7 @@ def main(stage, # TODO 获得最优的batch size num_workers = cpu_count() # 获得非通用参数 - config = {'dim_in': 5, - 'dim': 10, - 'res_coef': 0.5, - 'dropout_p': 0.1, - 'n_layers': 3, - 'dataset_len': 100000} + config = {'dim_in': 32, } for kth_fold in range(kth_fold_start, k_fold): load_checkpoint_path = get_ckpt_path(version_nth, kth_fold) logger = pl_loggers.TensorBoardLogger('logs/')