You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_module.py 2.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import torch
  2. from torch.utils.data import Dataset, DataLoader
  3. import pytorch_lightning as pl
  4. class DataModule(pl.LightningDataModule):
  5. def __init__(self, batch_size, num_workers, config=None):
  6. super().__init__()
  7. # TODO 使用k折交叉验证
  8. # divide_dataset(config['dataset_path'], [0.8, 0, 0.2])
  9. if config['flag']:
  10. x = torch.randn(100000, 2)
  11. noise = torch.randn(100000, )
  12. y = ((1.0 * x[:, 0] + 2.0 * x[:, 1] + noise) > 0).type(torch.int64)
  13. else:
  14. x_1 = torch.randn(100000)
  15. x_2 = torch.randn(100000)
  16. x_useful = torch.cos(1.5 * x_1) * (x_2 ** 2)
  17. x_1_rest_small = torch.randn(100000, 15) + 0.01 * x_1.unsqueeze(1)
  18. x_1_rest_large = torch.randn(100000, 15) + 0.1 * x_1.unsqueeze(1)
  19. x_2_rest_small = torch.randn(100000, 15) + 0.01 * x_2.unsqueeze(1)
  20. x_2_rest_large = torch.randn(100000, 15) + 0.1 * x_2.unsqueeze(1)
  21. x = torch.cat([x_1[:, None], x_2[:, None], x_1_rest_small, x_1_rest_large, x_2_rest_small, x_2_rest_large],
  22. dim=1)
  23. y = ((10 * x_useful) + 5 * torch.randn(100000) > 0.0).type(torch.int64)
  24. self.y_train, self.y_test = y[:50000], y[50000:]
  25. self.x_train, self.x_test = x[:50000, :], x[50000:, :]
  26. self.batch_size = batch_size
  27. self.num_workers = num_workers
  28. self.config = config
  29. self.train_dataset = None
  30. self.val_dataset = None
  31. self.test_dataset = None
  32. def setup(self, stage=None) -> None:
  33. if stage == 'fit' or stage is None:
  34. self.train_dataset = CustomDataset(self.x_train, self.y_train, self.config)
  35. self.val_dataset = CustomDataset(self.x_test, self.y_test, self.config)
  36. if stage == 'test' or stage is None:
  37. self.test_dataset = CustomDataset(self.x_test, self.y_test, self.config)
  38. def train_dataloader(self):
  39. return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
  40. def val_dataloader(self):
  41. return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
  42. def test_dataloader(self):
  43. return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers)
  44. class CustomDataset(Dataset):
  45. def __init__(self, x, y, config):
  46. super().__init__()
  47. self.x = x
  48. self.y = y
  49. self.config = config
  50. def __getitem__(self, idx):
  51. return self.x[idx, :], self.y[idx]
  52. def __len__(self):
  53. return self.x.shape[0]

基于 PyTorch Lightning 的机器学习模板，用于对机器学习算法进行训练、验证、测试等。目前实现了神经网络、深度学习、k 折交叉验证、自动保存训练信息等功能。

Contributors (1)