Browse Source

增加训练阶段的数据增强; 将输入数据标准化; 删除main中的config不需要的量并修改输入维度;

master
shenyan 4 years ago
parent
commit
b2ac2ad2c1
2 changed files with 25 additions and 12 deletions
  1. +24
    -6
      data_module.py
  2. +1
    -6
      main.py

+ 24
- 6
data_module.py View File

@@ -22,11 +22,11 @@ class DataModule(pl.LightningDataModule):
k_fold_dataset_list = self.get_k_fold_dataset_list() k_fold_dataset_list = self.get_k_fold_dataset_list()
if stage == 'fit' or stage is None: if stage == 'fit' or stage is None:
dataset_train, dataset_val = self.get_fit_dataset_lists(k_fold_dataset_list) dataset_train, dataset_val = self.get_fit_dataset_lists(k_fold_dataset_list)
self.train_dataset = CustomDataset(self.dataset_path, dataset_train, self.config, 'train')
self.val_dataset = CustomDataset(self.dataset_path, dataset_val, self.config, 'train')
self.train_dataset = CustomDataset(self.dataset_path, dataset_train, 'train', self.config,)
self.val_dataset = CustomDataset(self.dataset_path, dataset_val, 'val', self.config,)
if stage == 'test' or stage is None: if stage == 'test' or stage is None:
dataset_test = self.get_test_dataset_lists(k_fold_dataset_list) dataset_test = self.get_test_dataset_lists(k_fold_dataset_list)
self.test_dataset = CustomDataset(self.dataset_path, dataset_test, self.config, 'test')
self.test_dataset = CustomDataset(self.dataset_path, dataset_test, 'test', self.config,)


def get_k_fold_dataset_list(self): def get_k_fold_dataset_list(self):
# 得到用于K折分割的数据的list, 并生成文件夹进行保存 # 得到用于K折分割的数据的list, 并生成文件夹进行保存
@@ -72,13 +72,31 @@ class DataModule(pl.LightningDataModule):




class CustomDataset(Dataset): class CustomDataset(Dataset):
def __init__(self, dataset_path, dataset, stage, config, ):
    """Dataset over a list of image paths with stage-dependent transforms.

    Args:
        dataset_path: root directory containing per-stage subfolders,
            each with a ``label.txt`` file.
        dataset: list of image file paths for this split.
        stage: one of ``'train'``, ``'val'``, or anything else (treated
            as test) — selects augmentation and the label file location.
        config: dict; ``config['dim_in']`` gives the crop size used for
            training augmentation.
    """
    super().__init__()
    self.dataset = dataset
    # Mean/std are the standard ImageNet statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if stage == 'train':
        # Augment only during training: random horizontal flip plus a
        # random crop to config['dim_in'] with 4px padding.
        self.trans = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(config['dim_in'], 4),
            transforms.ToTensor(),
            normalize, ])
    elif stage == 'val':
        # Validation samples are drawn from the training folder, so the
        # labels are read from 'train/label.txt' — TODO confirm this
        # matches how get_fit_dataset_lists() splits the data.
        stage = 'train'
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            normalize, ])
    else:
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            normalize, ])
    # Use a context manager so the label file handle is closed promptly
    # (the original open(...).readlines() leaked the file object).
    with open(dataset_path + '/' + stage + '/label.txt') as label_file:
        self.labels = label_file.readlines()


def __getitem__(self, idx): def __getitem__(self, idx):
# 注意: 为了满足初始化权重算法的要求, 需要输入参数的均值为0. 可以使用transforms.Normalize()
image_path = self.dataset[idx] image_path = self.dataset[idx]
image_name = os.path.basename(image_path) image_name = os.path.basename(image_path)
image = Image.open(image_path) image = Image.open(image_path)


+ 1
- 6
main.py View File

@@ -59,12 +59,7 @@ def main(stage,
# TODO 获得最优的batch size # TODO 获得最优的batch size
num_workers = cpu_count() num_workers = cpu_count()
# 获得非通用参数 # 获得非通用参数
config = {'dim_in': 5,
'dim': 10,
'res_coef': 0.5,
'dropout_p': 0.1,
'n_layers': 3,
'dataset_len': 100000}
config = {'dim_in': 32, }
for kth_fold in range(kth_fold_start, k_fold): for kth_fold in range(kth_fold_start, k_fold):
load_checkpoint_path = get_ckpt_path(version_nth, kth_fold) load_checkpoint_path = get_ckpt_path(version_nth, kth_fold)
logger = pl_loggers.TensorBoardLogger('logs/') logger = pl_loggers.TensorBoardLogger('logs/')


Loading…
Cancel
Save