You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

data_module.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import glob
  2. import os
  3. import random
  4. import torch
  5. from torch.utils.data import Dataset, DataLoader
  6. import pytorch_lightning as pl
  7. from PIL import Image
  8. from torchvision import transforms
  class DataModule(pl.LightningDataModule):
      """Lightning data module serving one k-fold train/val split plus a test set.

      Training and validation samples are both drawn from
      ``<dataset_path>/train/image/*.png``; test samples come from
      ``<dataset_path>/test/image/*.png``. The shuffled order used for fold
      splitting is persisted in ``<dataset_path>/k_fold_dataset_list.txt`` so
      that every run (and every fold) sees the same ordering.

      Args:
          batch_size: Batch size for the train and validation loaders.
          num_workers: Worker process count passed to every ``DataLoader``.
          k_fold: Total number of folds.
          kth_fold: Zero-based index of the fold used for validation.
          dataset_path: Root directory of the dataset.
          config: Passed through to ``CustomDataset``; the training transform
              reads ``config['dim_in']``.
      """

      def __init__(self, batch_size, num_workers, k_fold, kth_fold, dataset_path, config=None):
          super().__init__()
          self.batch_size = batch_size
          self.num_workers = num_workers
          self.config = config
          self.k_fold = k_fold
          self.kth_fold = kth_fold
          self.dataset_path = dataset_path

      def setup(self, stage=None) -> None:
          """Create the datasets needed for ``stage``.

          Args:
              stage: ``'fit'`` builds the train/val datasets, ``'test'`` builds
                  the test dataset, and ``None`` builds all of them.
          """
          k_fold_dataset_list = self.get_k_fold_dataset_list()
          if stage == 'fit' or stage is None:
              dataset_train, dataset_val = self.get_fit_dataset_lists(k_fold_dataset_list)
              self.train_dataset = CustomDataset(self.dataset_path, dataset_train, 'train', self.config)
              self.val_dataset = CustomDataset(self.dataset_path, dataset_val, 'val', self.config)
          if stage == 'test' or stage is None:
              dataset_test = self.get_test_dataset_lists(k_fold_dataset_list)
              self.test_dataset = CustomDataset(self.dataset_path, dataset_test, 'test', self.config)

      def get_k_fold_dataset_list(self):
          """Return the shuffled list of training image paths used for k-fold splitting.

          On the first call the list is built from the training image folder,
          shuffled, and written to ``k_fold_dataset_list.txt``; later calls read
          that file back so the ordering stays fixed between runs.

          Returns:
              list[str]: Image paths using ``/`` as the separator.
          """
          list_path = self.dataset_path + '/k_fold_dataset_list.txt'
          if not os.path.exists(list_path):
              dataset = glob.glob(self.dataset_path + '/train/image/*.png')
              random.shuffle(dataset)
              # Normalise separators before both saving and returning, so the
              # first run yields exactly the same paths as later runs that
              # read the saved file.
              dataset = [path.replace('\\', '/') for path in dataset]
              with open(list_path, 'w', encoding='utf-8') as f:
                  f.writelines(path + '\n' for path in dataset)
              print('已生成新的k折数据list')
          else:
              # Read with the same encoding the file was written with, and
              # close the handle deterministically.
              with open(list_path, encoding='utf-8') as f:
                  dataset = [line.rstrip('\n') for line in f if line.strip()]
          return dataset

      def get_fit_dataset_lists(self, dataset_list: list):
          """Split ``dataset_list`` into training and validation lists for ``kth_fold``.

          The input list is left unmodified.

          Args:
              dataset_list: Ordered list of sample paths.

          Returns:
              tuple[list, list]: ``(dataset_train, dataset_val)``.
          """
          # Size of one fold, and the leftover samples that do not fill a fold.
          num_1fold, remainder = divmod(len(dataset_list), self.k_fold)
          val_start = num_1fold * self.kth_fold
          # NOTE(review): `remainder` is added for every fold, so for any fold
          # other than the last the validation slice reaches into the next
          # fold's first `remainder` samples. Confirm whether only the last
          # fold was meant to absorb the leftover.
          val_end = num_1fold * (self.kth_fold + 1) + remainder
          dataset_val = dataset_list[val_start:val_end]
          # Delete from a copy so the caller's list is not mutated.
          dataset_train = list(dataset_list)
          del dataset_train[val_start:val_end]
          return dataset_train, dataset_val

      def get_test_dataset_lists(self, dataset_list):
          """Return the paths of all test images.

          Args:
              dataset_list: Unused; kept for interface compatibility.

          Returns:
              list[str]: Paths matching ``<dataset_path>/test/image/*.png``.
          """
          return glob.glob(self.dataset_path + '/test/image/*.png')

      def train_dataloader(self):
          """Return the shuffled training ``DataLoader``."""
          return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                            num_workers=self.num_workers, pin_memory=True)

      def val_dataloader(self):
          """Return the validation ``DataLoader`` (no shuffling)."""
          return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                            num_workers=self.num_workers, pin_memory=True)

      def test_dataloader(self):
          """Return the test ``DataLoader``, which yields one sample per batch."""
          return DataLoader(self.test_dataset, batch_size=1, shuffle=False,
                            num_workers=self.num_workers, pin_memory=True)
  class CustomDataset(Dataset):
      """Image classification dataset backed by PNG files and a ``label.txt`` file.

      Each image file is named ``<index>.png``; its label is the integer found
      on line ``<index>`` (zero-based) of ``<dataset_path>/<stage>/label.txt``.

      Args:
          dataset_path: Root directory of the dataset.
          dataset: List of image paths served by this dataset.
          stage: ``'train'``, ``'val'`` or ``'test'``. ``'val'`` reads the
              ``train`` label file and applies no augmentation.
          config: Mapping; ``config['dim_in']`` is the random-crop size used
              for the training transform.
      """

      def __init__(self, dataset_path, dataset, stage, config, ):
          super().__init__()
          self.dataset = dataset
          # Normalisation statistics taken from ImageNet.
          normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                           std=[0.229, 0.224, 0.225])
          if stage == 'train':
              self.trans = transforms.Compose([
                  transforms.RandomHorizontalFlip(),
                  transforms.RandomCrop(config['dim_in'], 4),
                  transforms.ToTensor(),
                  normalize,
              ])
          else:
              # Validation and test share the same augmentation-free pipeline.
              self.trans = transforms.Compose([
                  transforms.ToTensor(),
                  normalize,
              ])
          # Validation samples are named like training samples, so they are
          # looked up in the training label file.
          label_stage = 'train' if stage == 'val' else stage
          with open(dataset_path + '/' + label_stage + '/label.txt') as f:
              self.labels = f.readlines()

      def __getitem__(self, idx):
          """Return ``(image_name, image, label)`` for the sample at ``idx``.

          Returns:
              tuple: The file name, the transformed image tensor, and the label
              as an ``int64`` tensor holding one element.
          """
          # Note: the weight-initialisation scheme expects zero-mean inputs,
          # which transforms.Normalize() provides.
          image_path = self.dataset[idx]
          image_name = os.path.basename(image_path)
          # Close the file once decoded, and force three channels so the
          # 3-channel normalisation also fits grayscale/palette/RGBA PNGs.
          with Image.open(image_path) as img:
              image = self.trans(img.convert('RGB'))
          # splitext removes exactly the extension; str.strip('.png') would
          # instead strip any of the characters '.', 'p', 'n', 'g'.
          label_index = int(os.path.splitext(image_name)[0])
          label = torch.tensor([int(self.labels[label_index].strip())], dtype=torch.long)
          return image_name, image, label

      def __len__(self):
          """Return the number of samples."""
          return len(self.dataset)

基于 PyTorch Lightning 的机器学习模板, 用于对机器学习算法进行训练、验证、测试等, 目前实现了神经网络、深度学习、k 折交叉验证、自动保存训练信息等功能.

Contributors (1)