|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- import torch
- from torchvision import datasets, transforms
- import torch.nn.functional as F
- from scipy.ndimage.interpolation import rotate as scipyrotate
-
- import numpy as np
-
-
def get_fashion_mnist(data_root="./data", output_channels=1, image_size=28):
    """Load Fashion-MNIST as standardized float tensors.

    Args:
        data_root: Directory the dataset is downloaded to / read from.
        output_channels: If greater than 1, the single grayscale channel is
            replicated this many times along dim 1.
        image_size: Side length handed to the ``Resize`` transform attached to
            the dataset objects.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``. Images carry an explicit
        channel axis at dim 1 and are standardized per channel with the
        *training* mean and standard deviation; labels are the datasets'
        ``targets``.

    NOTE(review): the returned images are taken from ``.data`` rather than by
    indexing the dataset, so the transform (and therefore ``image_size``) does
    not appear to affect them - confirm whether resizing was intended.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])])
    ds_train = datasets.FashionMNIST(data_root, train=True, download=True, transform=transform)
    ds_test = datasets.FashionMNIST(data_root, train=False, download=True, transform=transform)

    # Insert the channel axis and switch to float for the arithmetic below.
    X_train = ds_train.data[:, None, :, :].float()
    y_train = ds_train.targets
    X_test = ds_test.data[:, None, :, :].float()
    y_test = ds_test.targets

    if output_channels > 1:
        X_train = torch.cat([X_train] * output_channels, 1)
        X_test = torch.cat([X_test] * output_channels, 1)

    # Compute the training statistics once and reuse them for both splits.
    train_mean = torch.mean(X_train, [0, 2, 3], keepdim=True)
    train_std = torch.std(X_train, [0, 2, 3], keepdim=True)
    X_test = (X_test - train_mean) / train_std
    X_train = (X_train - train_mean) / train_std

    return X_train, y_train, X_test, y_test
-
-
def get_mnist(data_root="./data/", output_channels=1, image_size=28):
    """Load MNIST as standardized float tensors.

    Unlike the loaders that read ``.data`` directly, the images here are
    collected by iterating the datasets, so the ``ToTensor`` + ``Resize``
    transform is applied to every sample.

    Args:
        data_root: Directory the dataset is downloaded to / read from.
        output_channels: If greater than 1, the channel axis (dim 1) is
            replicated this many times.
        image_size: Side length passed to the ``Resize`` transform.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``. Images are standardized
        per channel with the *training* mean and standard deviation; labels
        are the datasets' ``targets``.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])])

    ds_train = datasets.MNIST(data_root, train=True, download=True, transform=transform)
    X_train = torch.stack([x for x, _ in ds_train])
    y_train = ds_train.targets

    ds_test = datasets.MNIST(data_root, train=False, download=True, transform=transform)
    X_test = torch.stack([x for x, _ in ds_test])
    y_test = ds_test.targets

    if output_channels > 1:
        X_train = torch.cat([X_train] * output_channels, 1)
        X_test = torch.cat([X_test] * output_channels, 1)

    # Compute the training statistics once and reuse them for both splits.
    train_mean = torch.mean(X_train, [0, 2, 3], keepdim=True)
    train_std = torch.std(X_train, [0, 2, 3], keepdim=True)
    X_test = (X_test - train_mean) / train_std
    X_train = (X_train - train_mean) / train_std

    return X_train, y_train, X_test, y_test
-
-
def get_cifar10(data_root="./data/", output_channels=3, image_size=32):
    """Load CIFAR-10 as standardized float tensors.

    Args:
        data_root: Directory the dataset is downloaded to / read from.
        output_channels: When equal to 1 the channels are averaged into a
            single channel; any other value leaves the channels untouched.
        image_size: Side length handed to the ``Resize`` transform attached to
            the dataset objects.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``. Images are float tensors
        with the original axis 3 moved to position 1, standardized per channel
        with the *training* mean and standard deviation; labels are long
        tensors.

    NOTE(review): the returned images are taken from ``.data`` rather than by
    indexing the dataset, so the transform (and therefore ``image_size``) does
    not appear to affect them - confirm whether resizing was intended.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])])
    ds_train = datasets.CIFAR10(data_root, train=True, download=True, transform=transform)
    ds_test = datasets.CIFAR10(data_root, train=False, download=True, transform=transform)

    # Move the last axis to position 1 (presumably channels-last -> channels-first).
    X_train = torch.Tensor(np.moveaxis(ds_train.data, 3, 1))
    X_test = torch.Tensor(np.moveaxis(ds_test.data, 3, 1))
    # Build integer labels directly instead of round-tripping through float32.
    y_train = torch.tensor(ds_train.targets, dtype=torch.long)
    y_test = torch.tensor(ds_test.targets, dtype=torch.long)

    if output_channels == 1:
        X_train = torch.mean(X_train, 1, keepdim=True)
        X_test = torch.mean(X_test, 1, keepdim=True)

    # Compute the training statistics once and reuse them for both splits.
    train_mean = torch.mean(X_train, [0, 2, 3], keepdim=True)
    train_std = torch.std(X_train, [0, 2, 3], keepdim=True)
    X_test = (X_test - train_mean) / train_std
    X_train = (X_train - train_mean) / train_std

    return X_train, y_train, X_test, y_test
-
-
def get_svhn(output_channels=1, image_size=32, data_root="./data/"):
    """Load SVHN as standardized float tensors.

    Args:
        output_channels: When equal to 1 the channels (dim 1) are averaged
            into a single channel; any other value leaves them untouched.
        image_size: Side length handed to the ``Resize`` transform attached to
            the dataset objects.
        data_root: Directory the dataset is downloaded to / read from. Added
            as a trailing keyword so existing positional calls keep working;
            the default matches the previously hard-coded path.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``. Images are float
        tensors standardized per channel with the *training* mean and standard
        deviation; labels are long tensors.

    NOTE(review): the returned images are taken from ``.data`` rather than by
    indexing the dataset, so the transform (and therefore ``image_size``) does
    not appear to affect them - confirm whether resizing was intended.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])])
    ds_train = datasets.SVHN(data_root, split="train", download=True, transform=transform)
    ds_test = datasets.SVHN(data_root, split="test", download=True, transform=transform)

    X_train = torch.Tensor(ds_train.data)
    X_test = torch.Tensor(ds_test.data)
    # Build integer labels directly instead of round-tripping through float32.
    y_train = torch.tensor(ds_train.labels, dtype=torch.long)
    y_test = torch.tensor(ds_test.labels, dtype=torch.long)

    if output_channels == 1:
        X_train = torch.mean(X_train, 1, keepdim=True)
        X_test = torch.mean(X_test, 1, keepdim=True)

    # Compute the training statistics once and reuse them for both splits.
    train_mean = torch.mean(X_train, [0, 2, 3], keepdim=True)
    train_std = torch.std(X_train, [0, 2, 3], keepdim=True)
    X_test = (X_test - train_mean) / train_std
    X_train = (X_train - train_mean) / train_std

    return X_train, y_train, X_test, y_test
-
-
def get_cifar100(data_root="./data/", output_channels=3, image_size=32):
    """Load CIFAR-100 as standardized float tensors.

    Args:
        data_root: Directory the dataset is downloaded to / read from.
        output_channels: When equal to 1 the channels are averaged into a
            single channel; any other value leaves the channels untouched.
        image_size: Side length handed to the ``Resize`` transform attached to
            the dataset objects.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``. Images are float tensors
        with the original axis 3 moved to position 1, standardized per channel
        with the *training* mean and standard deviation; labels are long
        tensors.

    NOTE(review): the returned images are taken from ``.data`` rather than by
    indexing the dataset, so the transform (and therefore ``image_size``) does
    not appear to affect them - confirm whether resizing was intended.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])])
    ds_train = datasets.CIFAR100(data_root, train=True, download=True, transform=transform)
    ds_test = datasets.CIFAR100(data_root, train=False, download=True, transform=transform)

    # Move the last axis to position 1 (presumably channels-last -> channels-first).
    X_train = torch.Tensor(np.moveaxis(ds_train.data, 3, 1))
    X_test = torch.Tensor(np.moveaxis(ds_test.data, 3, 1))
    # Build integer labels directly instead of round-tripping through float32.
    y_train = torch.tensor(ds_train.targets, dtype=torch.long)
    y_test = torch.tensor(ds_test.targets, dtype=torch.long)

    if output_channels == 1:
        X_train = torch.mean(X_train, 1, keepdim=True)
        X_test = torch.mean(X_test, 1, keepdim=True)

    # Compute the training statistics once and reuse them for both splits.
    train_mean = torch.mean(X_train, [0, 2, 3], keepdim=True)
    train_std = torch.std(X_train, [0, 2, 3], keepdim=True)
    X_test = (X_test - train_mean) / train_std
    X_train = (X_train - train_mean) / train_std

    return X_train, y_train, X_test, y_test
-
-
def get_zca_matrix(X, reg_coef=0.1):
    """Compute a regularized ZCA whitening matrix for ``X``.

    ``X`` is flattened to ``(X.shape[0], -1)`` and its second-moment matrix
    ``X_flat.T @ X_flat / n`` is formed (no mean is subtracted here). A
    multiple of the identity, ``reg_coef * trace(cov) / d``, is added before
    the decomposition, and the result is ``U diag(s ** -0.5) U^T``.

    Args:
        X: Tensor whose first dimension indexes samples.
        reg_coef: Regularization strength relative to the mean diagonal entry
            of the second-moment matrix.

    Returns:
        The ``(d, d)`` whitening matrix on the CPU, where ``d`` is the number
        of features per flattened sample.
    """
    X_flat = X.reshape(X.shape[0], -1)
    cov = (X_flat.T @ X_flat) / X_flat.shape[0]
    reg_amount = reg_coef * torch.trace(cov) / cov.shape[0]

    # Use the GPU for the decomposition when one exists; otherwise stay on the
    # input's device instead of failing on CPU-only machines.
    device = torch.device("cuda") if torch.cuda.is_available() else cov.device
    identity = torch.eye(cov.shape[0], device=device)
    u, s, _ = torch.svd(cov.to(device) + reg_amount * identity)

    inv_sqrt_zca_eigs = s ** (-0.5)
    whitening_transform = torch.einsum("ij,j,kj->ik", u, inv_sqrt_zca_eigs, u)

    return whitening_transform.cpu()
-
-
def layernorm_data(X):
    """Center each sample and scale it to unit L2 norm.

    The mean and the norm are taken over dims 1, 2 and 3, i.e. separately for
    every entry along dim 0. A sample that is constant has zero norm after
    centering, so the division is by zero for that sample.

    Args:
        X: 4-D tensor; dim 0 indexes samples.

    Returns:
        A new tensor of the same shape as ``X``.
    """
    reduce_dims = [1, 2, 3]
    centered = X - X.mean(dim=reduce_dims, keepdim=True)
    norms = centered.pow(2).sum(dim=reduce_dims, keepdim=True).sqrt()
    return centered / norms
-
-
def transform_data(X, whitening_transform):
    """Apply a whitening transform to ``X`` and restore its original shape.

    ``X`` is flattened to ``(X.shape[0], -1)`` before the transform is applied.

    Args:
        X: Tensor whose first dimension indexes samples.
        whitening_transform: Either a 2-D matrix applied to every sample by a
            plain matrix product, or a higher-rank tensor contracted per
            sample with the einsum pattern ``"nd, ndi->ni"``.

    Returns:
        The transformed data viewed with the same shape as ``X``.
    """
    flattened = X.reshape(X.shape[0], -1)
    shared_matrix = len(whitening_transform.shape) == 2

    if shared_matrix:
        transformed = flattened @ whitening_transform
    else:
        # One transform per sample: contract each row with its own matrix.
        transformed = torch.einsum("nd, ndi->ni", flattened, whitening_transform)

    return transformed.view(X.shape)
-
-
def scale_to_zero_one(X):
    """Rescale every sample independently so its values span ``[0, 1]``.

    The minimum and maximum are taken over all elements of each entry along
    dim 0. A sample whose values are all equal has ``max - min == 0``, so the
    division is by zero for that sample.

    Args:
        X: 4-D tensor; dim 0 indexes samples.

    Returns:
        A new tensor of the same shape as ``X``.
    """
    per_sample = X.view(X.shape[0], -1)
    # Reshape the per-sample extrema so they broadcast against the 4-D input.
    lows = per_sample.min(dim=1)[0].view(-1, 1, 1, 1)
    highs = per_sample.max(dim=1)[0].view(-1, 1, 1, 1)
    span = highs - lows
    return (X - lows) / span
-
-
def augment(images, dc_aug_param, device):
    """Randomly augment a batch of images in place, one operation per image.

    For every image one token is drawn from ``strategy.split("_")`` and the
    matching operation (``crop``, ``scale``, ``rotate`` or ``noise``) is
    applied; tokens that match none of these leave the image untouched.

    Args:
        images: 4-D float tensor indexed as ``(image, channel, row, column)``.
            It is modified in place.
        dc_aug_param: ``None`` or a dict with the keys ``"strategy"``,
            ``"scale"``, ``"crop"``, ``"rotate"`` and ``"noise"``. ``None`` or
            a strategy of ``"none"`` disables augmentation.
        device: Device on which temporary tensors are created.

    Returns:
        The same ``images`` tensor that was passed in.
    """
    # This can be sped up in the future.
    if dc_aug_param is None or dc_aug_param["strategy"] == "none":
        return images

    scale = dc_aug_param["scale"]
    crop = dc_aug_param["crop"]
    rotate = dc_aug_param["rotate"]
    noise = dc_aug_param["noise"]
    strategy = dc_aug_param["strategy"]

    shape = images.shape
    # Per-channel batch means, captured before any image is modified; they are
    # used as the fill value for cropping and rotation.
    mean = [float(torch.mean(images[:, c])) for c in range(shape[1])]

    def cropfun(i):
        # Pad with the channel means by `crop` pixels per side, then cut a
        # randomly placed window of the original size.
        im_ = torch.zeros(shape[1], shape[2] + crop * 2, shape[3] + crop * 2, dtype=torch.float, device=device)
        for c in range(shape[1]):
            im_[c] = mean[c]
        im_[:, crop : crop + shape[2], crop : crop + shape[3]] = images[i]
        r, c = np.random.permutation(crop * 2)[0], np.random.permutation(crop * 2)[0]
        images[i] = im_[:, r : r + shape[2], c : c + shape[3]]

    def scalefun(i):
        # Resize height and width by independent random factors, paste the
        # result centered on a zero canvas, then center-crop to the original
        # size.
        h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2])
        # The width must be derived from the image width (shape[3]); using the
        # height distorts non-square images.
        w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[3])
        tmp = F.interpolate(
            images[i : i + 1],
            [h, w],
        )[0]
        mhw = max(h, w, shape[2], shape[3])
        im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device)
        r = int((mhw - h) / 2)
        c = int((mhw - w) / 2)
        im_[:, r : r + h, c : c + w] = tmp
        r = int((mhw - shape[2]) / 2)
        c = int((mhw - shape[3]) / 2)
        images[i] = im_[:, r : r + shape[2], c : c + shape[3]]

    def rotatefun(i):
        # Rotate on the CPU via SciPy, filling uncovered pixels with the mean
        # of the channel means, then center-crop back to the original size.
        im_ = scipyrotate(
            images[i].cpu().data.numpy(),
            angle=np.random.randint(-rotate, rotate),
            axes=(-2, -1),
            cval=np.mean(mean),
        )
        r = int((im_.shape[-2] - shape[-2]) / 2)
        c = int((im_.shape[-1] - shape[-1]) / 2)
        images[i] = torch.tensor(im_[:, r : r + shape[-2], c : c + shape[-1]], dtype=torch.float, device=device)

    def noisefun(i):
        # Add Gaussian noise scaled by `noise`.
        images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device)

    operations = {
        "crop": cropfun,
        "scale": scalefun,
        "rotate": rotatefun,
        "noise": noisefun,
    }
    augs = strategy.split("_")

    for i in range(shape[0]):
        choice = np.random.permutation(augs)[0]  # randomly implement one augmentation
        operation = operations.get(str(choice))
        if operation is not None:
            operation(i)

    return images
|