|
|
|
@@ -22,11 +22,11 @@ class DataModule(pl.LightningDataModule): |
|
|
|
k_fold_dataset_list = self.get_k_fold_dataset_list() |
|
|
|
if stage == 'fit' or stage is None: |
|
|
|
dataset_train, dataset_val = self.get_fit_dataset_lists(k_fold_dataset_list) |
|
|
|
self.train_dataset = CustomDataset(self.dataset_path, dataset_train, self.config, 'train') |
|
|
|
self.val_dataset = CustomDataset(self.dataset_path, dataset_val, self.config, 'train') |
|
|
|
self.train_dataset = CustomDataset(self.dataset_path, dataset_train, 'train', self.config,) |
|
|
|
self.val_dataset = CustomDataset(self.dataset_path, dataset_val, 'val', self.config,) |
|
|
|
if stage == 'test' or stage is None: |
|
|
|
dataset_test = self.get_test_dataset_lists(k_fold_dataset_list) |
|
|
|
self.test_dataset = CustomDataset(self.dataset_path, dataset_test, self.config, 'test') |
|
|
|
self.test_dataset = CustomDataset(self.dataset_path, dataset_test, 'test', self.config,) |
|
|
|
|
|
|
|
def get_k_fold_dataset_list(self): |
|
|
|
# 得到用于K折分割的数据的list, 并生成文件夹进行保存 |
|
|
|
@@ -72,13 +72,31 @@ class DataModule(pl.LightningDataModule): |
|
|
|
|
|
|
|
|
|
|
|
class CustomDataset(Dataset): |
|
|
|
def __init__(self, dataset_path, dataset, config, type): |
|
|
|
def __init__(self, dataset_path, dataset, stage, config, ): |
|
|
|
super().__init__() |
|
|
|
self.dataset = dataset |
|
|
|
self.trans = transforms.ToTensor() |
|
|
|
self.labels = open(dataset_path + '/' + type + '/label.txt').readlines() |
|
|
|
# 此处的均值和方差来源于ImageNet |
|
|
|
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
|
std=[0.229, 0.224, 0.225]) |
|
|
|
if stage == 'train': |
|
|
|
self.trans = transforms.Compose([ |
|
|
|
transforms.RandomHorizontalFlip(), |
|
|
|
transforms.RandomCrop(config['dim_in'], 4), |
|
|
|
transforms.ToTensor(), |
|
|
|
normalize, ]) |
|
|
|
elif stage == 'val': |
|
|
|
stage = 'train' |
|
|
|
self.trans = transforms.Compose([ |
|
|
|
transforms.ToTensor(), |
|
|
|
normalize, ]) |
|
|
|
else: |
|
|
|
self.trans = transforms.Compose([ |
|
|
|
transforms.ToTensor(), |
|
|
|
normalize, ]) |
|
|
|
self.labels = open(dataset_path + '/' + stage + '/label.txt').readlines() |
|
|
|
|
|
|
|
def __getitem__(self, idx): |
|
|
|
# 注意: 为了满足初始化权重算法的要求, 需要输入参数的均值为0. 可以使用transforms.Normalize() |
|
|
|
image_path = self.dataset[idx] |
|
|
|
image_name = os.path.basename(image_path) |
|
|
|
image = Image.open(image_path) |
|
|
|
|