|
- # Copyright (c) Microsoft Corporation.
- # Licensed under the MIT license.
-
- from torchvision import transforms
- from torchvision.datasets import CIFAR10
-
-
def get_dataset(cls, datadir):
    """Build the train/validation dataset pair for the named dataset.

    Args:
        cls: Name of the dataset to load. Only ``"cifar10"`` is supported
            (the comparison is exact and case-sensitive).
        datadir: Root directory passed to ``CIFAR10``. Both splits are
            constructed with ``download=True``.

    Returns:
        A ``(dataset_train, dataset_valid)`` tuple. The training split
        applies ``RandomCrop(32, padding=4)`` and ``RandomHorizontalFlip()``,
        then ``ToTensor()`` and ``Normalize(mean, std)``. The validation split
        applies only ``ToTensor()`` and ``Normalize(mean, std)``.

    Raises:
        NotImplementedError: If ``cls`` is not ``"cifar10"``. The message
            names the rejected value.
    """
    # Fail fast: reject unsupported names before building any transforms,
    # and say which value was rejected.
    if cls != "cifar10":
        raise NotImplementedError(f"Unsupported dataset: {cls!r}")

    # Per-channel (3-channel) normalization statistics used by Normalize().
    mean = [0.49139968, 0.48215827, 0.44653124]
    std = [0.24703233, 0.24348505, 0.26158768]

    # Augmentation is applied to the training split only.
    augment = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    # Tensor conversion + normalization shared by both splits.
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    train_transform = transforms.Compose(augment + normalize)
    valid_transform = transforms.Compose(normalize)

    dataset_train = CIFAR10(
        root=datadir, train=True, download=True, transform=train_transform
    )
    dataset_valid = CIFAR10(
        root=datadir, train=False, download=True, transform=valid_transform
    )
    return dataset_train, dataset_valid
|