diff --git a/data_module.py b/data_module.py index 2069df2..ce8e1f7 100644 --- a/data_module.py +++ b/data_module.py @@ -22,11 +22,11 @@ class DataModule(pl.LightningDataModule): k_fold_dataset_list = self.get_k_fold_dataset_list() if stage == 'fit' or stage is None: dataset_train, dataset_val = self.get_fit_dataset_lists(k_fold_dataset_list) - self.train_dataset = CustomDataset(self.dataset_path, dataset_train, self.config, 'train') - self.val_dataset = CustomDataset(self.dataset_path, dataset_val, self.config, 'train') + self.train_dataset = CustomDataset(self.dataset_path, dataset_train, 'train', self.config,) + self.val_dataset = CustomDataset(self.dataset_path, dataset_val, 'val', self.config,) if stage == 'test' or stage is None: dataset_test = self.get_test_dataset_lists(k_fold_dataset_list) - self.test_dataset = CustomDataset(self.dataset_path, dataset_test, self.config, 'test') + self.test_dataset = CustomDataset(self.dataset_path, dataset_test, 'test', self.config,) def get_k_fold_dataset_list(self): # 得到用于K折分割的数据的list, 并生成文件夹进行保存 @@ -72,13 +72,31 @@ class DataModule(pl.LightningDataModule): class CustomDataset(Dataset): - def __init__(self, dataset_path, dataset, config, type): + def __init__(self, dataset_path, dataset, stage, config, ): super().__init__() self.dataset = dataset - self.trans = transforms.ToTensor() - self.labels = open(dataset_path + '/' + type + '/label.txt').readlines() + # 此处的均值和方差来源于ImageNet + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + if stage == 'train': + self.trans = transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.RandomCrop(config['dim_in'], 4), + transforms.ToTensor(), + normalize, ]) + elif stage == 'val': + stage = 'train' + self.trans = transforms.Compose([ + transforms.ToTensor(), + normalize, ]) + else: + self.trans = transforms.Compose([ + transforms.ToTensor(), + normalize, ]) + self.labels = open(dataset_path + '/' + stage + '/label.txt').readlines() def __getitem__(self, idx): + # 注意: 为了满足初始化权重算法的要求, 需要输入参数的均值为0. 可以使用transforms.Normalize() image_path = self.dataset[idx] image_name = os.path.basename(image_path) image = Image.open(image_path) diff --git a/main.py b/main.py index b7e62f4..9ba1bf0 100644 --- a/main.py +++ b/main.py @@ -59,12 +59,7 @@ def main(stage, # TODO 获得最优的batch size num_workers = cpu_count() # 获得非通用参数 - config = {'dim_in': 5, - 'dim': 10, - 'res_coef': 0.5, - 'dropout_p': 0.1, - 'n_layers': 3, - 'dataset_len': 100000} + config = {'dim_in': 32, } for kth_fold in range(kth_fold_start, k_fold): load_checkpoint_path = get_ckpt_path(version_nth, kth_fold) logger = pl_loggers.TensorBoardLogger('logs/')