You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_module.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os
  2. import numpy
  3. import torch
  4. from torch import Tensor
  5. from torch.utils.data import Dataset, DataLoader
  6. import pytorch_lightning as pl
  7. class DataModule(pl.LightningDataModule):
  8. def __init__(self, batch_size, num_workers, k_fold, kth_fold, dataset_path, config=None):
  9. super().__init__()
  10. self.batch_size = batch_size
  11. self.num_workers = num_workers
  12. self.config = config
  13. self.k_fold = k_fold
  14. self.kth_fold = kth_fold
  15. self.dataset_path = dataset_path
  16. def setup(self, stage=None) -> None:
  17. # 得到全部数据的list
  18. dataset_list = self.get_dataset_list()
  19. if stage == 'fit' or stage is None:
  20. dataset_train, dataset_val = self.get_dataset_lists(dataset_list)
  21. self.train_dataset = CustomDataset(dataset_train, self.config)
  22. self.val_dataset = CustomDataset(dataset_val, self.config)
  23. if stage == 'test' or stage is None:
  24. self.test_dataset = CustomDataset(dataset_list, self.config)
  25. def get_dataset_list(self):
  26. if not os.path.exists(self.dataset_path + '/dataset_list.txt'):
  27. # 针对数据拟合获得dataset
  28. dataset = torch.randn(self.config['dataset_len'], self.config['dim_in'] + 1)
  29. noise = torch.randn(self.config['dataset_len'])
  30. dataset[:, self.config['dim_in']] = torch.cos(1.5 * dataset[:, 0]) * (dataset[:, 1] ** 2.0) + torch.cos(
  31. torch.sin(dataset[:, 2] ** 3)) + torch.arctan(dataset[:, 4]) + noise
  32. assert (dataset[torch.isnan(dataset)].shape[0] == 0)
  33. written = [' '.join([str(temp) for temp in dataset[cou, :].tolist()]) for cou in range(dataset.shape[0])]
  34. with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f:
  35. for line in written:
  36. f.write(line + '\n')
  37. print('已生成新的数据list')
  38. else:
  39. dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines()
  40. # 针对数据拟合获得dataset
  41. dataset_list = [[float(temp) for temp in item.strip('\n').split(' ')] for item in dataset_list]
  42. dataset = torch.Tensor(dataset_list).float()
  43. return dataset
  44. def get_dataset_lists(self, dataset_list: Tensor):
  45. # 得到一个fold的数据量和不够组成一个fold的剩余数据的数据量
  46. num_1fold, remainder = divmod(self.config['dataset_len'], self.k_fold)
  47. # 分割全部数据, 得到训练集, 验证集, 测试集
  48. dataset_val = dataset_list[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder), :]
  49. temp = torch.ones(dataset_list.shape[0])
  50. temp[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)] = 0
  51. dataset_train = dataset_list[temp == 1]
  52. return dataset_train, dataset_val
  53. def train_dataloader(self):
  54. return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers,
  55. pin_memory=True)
  56. def val_dataloader(self):
  57. return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers,
  58. pin_memory=True)
  59. def test_dataloader(self):
  60. return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers,
  61. pin_memory=True)
  62. class CustomDataset(Dataset):
  63. def __init__(self, dataset, config):
  64. super().__init__()
  65. self.x = dataset[:, 0:config['dim_in']]
  66. self.y = dataset[:, config['dim_in']]
  67. def __getitem__(self, idx):
  68. return self.x[idx, :], self.y[idx]
  69. def __len__(self):
  70. return self.x.shape[0]

基于 PyTorch Lightning 的机器学习模板，用于对机器学习算法进行训练、验证、测试等，目前实现了神经网络、深度学习、k 折交叉验证、自动保存训练信息等功能。

Contributors (1)