|
- import torch
- from torch.utils.data import Dataset, DataLoader
- import pytorch_lightning as pl
-
-
class DataModule(pl.LightningDataModule):
    """LightningDataModule serving a synthetic binary-classification dataset.

    All samples are generated in ``__init__`` with ``torch.randn``; the first
    half is used for training and the second half is used for both validation
    and testing.

    Two tasks are available, selected by ``config['flag']``:

    * truthy -- 2 features, label is ``1.0 * x0 + 2.0 * x1 + noise > 0``;
    * falsy (or ``config is None``) -- 62 features: two informative columns
      followed by four groups of 15 weakly correlated noisy copies; the label
      is ``10 * cos(1.5 * x0) * x1 ** 2 + 5 * noise > 0``.

    Args:
        batch_size: Batch size passed to the train and validation loaders.
        num_workers: Worker count passed to every ``DataLoader``.
        config: Optional mapping-like object read via ``config['flag']`` and
            forwarded unchanged to ``CustomDataset``.
    """

    # Total number of generated samples and how many of them form the
    # training split (the remainder is the held-out split).
    _NUM_SAMPLES = 100000
    _NUM_TRAIN = 50000
    # Number of noisy copies generated per (source feature, noise scale) pair.
    _NUM_REDUNDANT = 15

    def __init__(self, batch_size, num_workers, config=None):
        super().__init__()
        # TODO: use k-fold cross-validation; original placeholder was
        #   divide_dataset(config['dataset_path'], [0.8, 0, 0.2])

        # ``config`` defaults to None, so guard the lookup instead of failing
        # with ``TypeError: 'NoneType' object is not subscriptable``.
        use_linear_task = config is not None and config['flag']
        if use_linear_task:
            x, y = self._make_linear_data(self._NUM_SAMPLES)
        else:
            x, y = self._make_nonlinear_data(self._NUM_SAMPLES, self._NUM_REDUNDANT)

        self.y_train, self.y_test = y[:self._NUM_TRAIN], y[self._NUM_TRAIN:]
        self.x_train, self.x_test = x[:self._NUM_TRAIN, :], x[self._NUM_TRAIN:, :]

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.config = config
        # Populated lazily by ``setup``.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    @staticmethod
    def _make_linear_data(num_samples):
        """Return ``(x, y)`` for the 2-feature noisy linear task."""
        x = torch.randn(num_samples, 2)
        noise = torch.randn(num_samples)
        y = ((1.0 * x[:, 0] + 2.0 * x[:, 1] + noise) > 0).type(torch.int64)
        return x, y

    @staticmethod
    def _make_nonlinear_data(num_samples, num_redundant):
        """Return ``(x, y)`` for the nonlinear task with redundant features."""
        # NOTE: the order of the ``torch.randn`` calls is kept exactly as in
        # the original code so a fixed seed yields the same data.
        x_1 = torch.randn(num_samples)
        x_2 = torch.randn(num_samples)
        x_useful = torch.cos(1.5 * x_1) * (x_2 ** 2)
        # Independent noise plus a weak (0.01) or stronger (0.1) trace of the
        # informative feature.
        x_1_rest_small = torch.randn(num_samples, num_redundant) + 0.01 * x_1.unsqueeze(1)
        x_1_rest_large = torch.randn(num_samples, num_redundant) + 0.1 * x_1.unsqueeze(1)
        x_2_rest_small = torch.randn(num_samples, num_redundant) + 0.01 * x_2.unsqueeze(1)
        x_2_rest_large = torch.randn(num_samples, num_redundant) + 0.1 * x_2.unsqueeze(1)
        x = torch.cat(
            [
                x_1[:, None],
                x_2[:, None],
                x_1_rest_small,
                x_1_rest_large,
                x_2_rest_small,
                x_2_rest_large,
            ],
            dim=1,
        )
        y = ((10 * x_useful) + 5 * torch.randn(num_samples) > 0.0).type(torch.int64)
        return x, y

    def setup(self, stage=None) -> None:
        """Build the datasets required for ``stage``.

        Args:
            stage: ``'fit'``, ``'validate'``, ``'test'`` or ``None``; ``None``
                builds every dataset. Other values build nothing.
        """
        if stage in ('fit', None):
            self.train_dataset = CustomDataset(self.x_train, self.y_train, self.config)
        # 'validate' is included so that a standalone validation run also
        # finds a dataset instead of ``None``.
        if stage in ('fit', 'validate', None):
            # Validation shares the held-out half with the test set.
            self.val_dataset = CustomDataset(self.x_test, self.y_test, self.config)
        if stage in ('test', None):
            self.test_dataset = CustomDataset(self.x_test, self.y_test, self.config)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return the unshuffled validation loader."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Return the unshuffled test loader, which always uses batch size 1."""
        return DataLoader(
            self.test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=self.num_workers,
        )
-
-
class CustomDataset(Dataset):
    """Map-style dataset pairing rows of ``x`` with entries of ``y``.

    Args:
        x: Feature container exposing ``shape`` and supporting ``x[idx, :]``
            indexing (e.g. a 2-D tensor); its first dimension is the sample
            axis.
        y: Label container supporting ``len`` and ``y[idx]``.
        config: Stored on the instance as ``self.config``; not otherwise used
            by this class.

    Raises:
        ValueError: If ``x`` and ``y`` hold a different number of samples.
    """

    def __init__(self, x, y, config):
        super().__init__()
        # Fail fast: a mismatch would otherwise surface as an IndexError in
        # the middle of an epoch (y too short) or go unnoticed (y too long).
        num_features, num_labels = x.shape[0], len(y)
        if num_features != num_labels:
            raise ValueError(
                f"x and y must contain the same number of samples, "
                f"got {num_features} and {num_labels}"
            )
        self.x = x
        self.y = y
        self.config = config

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        features = self.x[idx, :]
        label = self.y[idx]
        return features, label

    def __len__(self):
        """Return the number of samples (size of the first dimension of ``x``)."""
        return self.x.shape[0]
|