You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_module.py 3.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import os
  2. import numpy
  3. import torch
  4. from torch import Tensor
  5. from torch.utils.data import Dataset, DataLoader
  6. import pytorch_lightning as pl
  7. class DataModule(pl.LightningDataModule):
  8. def __init__(self, batch_size, num_workers, k_fold, kth_fold, dataset_path, config=None):
  9. super().__init__()
  10. self.batch_size = batch_size
  11. self.num_workers = num_workers
  12. self.config = config
  13. self.k_fold = k_fold
  14. self.kth_fold = kth_fold
  15. self.dataset_path = dataset_path
  16. def setup(self, stage=None) -> None:
  17. # 得到全部数据的list
  18. # dataset_list = get_dataset_list(dataset_path)
  19. x, y = self.get_fit_dataset_list()
  20. if stage == 'fit' or stage is None:
  21. x_train, y_train, x_val, y_val = self.get_dataset_lists(x, y)
  22. self.train_dataset = CustomDataset(x_train, y_train, self.config)
  23. self.val_dataset = CustomDataset(x_val, y_val, self.config)
  24. if stage == 'test' or stage is None:
  25. self.test_dataset = CustomDataset(x, y, self.config)
  26. def get_fit_dataset_list(self):
  27. if not os.path.exists(self.dataset_path + '/dataset_list.txt'):
  28. x = torch.randn(self.config['dataset_len'], self.config['dim_in'])
  29. noise = torch.randn(self.config['dataset_len'])
  30. y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + noise
  31. with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f:
  32. for line in range(self.config['dataset_len']):
  33. f.write(' '.join([str(temp) for temp in x[line].tolist()]) + ' ' + str(y[line].item()) + '\n')
  34. print('已生成新的数据list')
  35. else:
  36. dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines()
  37. dataset_list = [[float(temp) for temp in item.strip('\n').split(' ')] for item in dataset_list]
  38. x = torch.from_numpy(numpy.array(dataset_list)[:, 0:self.config['dim_in']]).float()
  39. y = torch.from_numpy(numpy.array(dataset_list)[:, self.config['dim_in']]).float()
  40. return x, y
  41. def get_dataset_lists(self, x: Tensor, y):
  42. # 得到一个fold的数据量和不够组成一个fold的剩余数据的数据量
  43. num_1fold, remainder = divmod(self.config['dataset_len'], self.k_fold)
  44. # 分割全部数据, 得到训练集, 验证集, 测试集
  45. x_val = x[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)]
  46. y_val = y[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)]
  47. temp = torch.ones(x.shape[0])
  48. temp[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)] = 0
  49. x_train = x[temp == 1]
  50. y_train = y[temp == 1]
  51. return x_train, y_train, x_val, y_val
  52. def train_dataloader(self):
  53. return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers,
  54. pin_memory=True)
  55. def val_dataloader(self):
  56. return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers,
  57. pin_memory=True)
  58. def test_dataloader(self):
  59. return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers,
  60. pin_memory=True)
  61. class CustomDataset(Dataset):
  62. def __init__(self, x, y, config):
  63. super().__init__()
  64. self.x = x
  65. self.y = y
  66. self.config = config
  67. def __getitem__(self, idx):
  68. return self.x[idx, :], self.y[idx]
  69. def __len__(self):
  70. return self.x.shape[0]

基于 PyTorch Lightning 的机器学习模板，用于对机器学习算法进行训练、验证、测试等，目前实现了神经网络、深度学习、k 折交叉验证、自动保存训练信息等功能。

Contributors (1)