
get_data.py

import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
from scipy.ndimage import rotate as scipyrotate  # scipy.ndimage.interpolation was removed in SciPy 1.10
import numpy as np


def get_fashion_mnist(data_root="./data", output_channels=1, image_size=28):
    ds_train = datasets.FashionMNIST(
        data_root,
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])]),
    )
    # Note: reading .data bypasses the transform pipeline, so images stay 28x28 uint8 here.
    X_train = ds_train.data
    y_train = ds_train.targets
    ds_test = datasets.FashionMNIST(
        data_root,
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])]),
    )
    X_test = ds_test.data
    y_test = ds_test.targets
    X_train = X_train[:, None, :, :].float()
    X_test = X_test[:, None, :, :].float()
    if output_channels > 1:
        X_train = torch.cat([X_train for i in range(output_channels)], 1)
        X_test = torch.cat([X_test for i in range(output_channels)], 1)
    # Standardize with training-set statistics (test split first, while X_train is still unnormalized).
    X_test = (X_test - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (torch.std(X_train, [0, 2, 3], keepdim=True))
    X_train = (X_train - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (torch.std(X_train, [0, 2, 3], keepdim=True))
    return X_train, y_train, X_test, y_test

def get_mnist(data_root="./data/", output_channels=1, image_size=28):
    ds_train = datasets.MNIST(
        data_root,
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])]),
    )
    # Iterate the dataset so the ToTensor/Resize transforms are actually applied.
    X_train = []
    for x, _ in ds_train:
        X_train.append(x)
    X_train = torch.stack(X_train)
    y_train = ds_train.targets
    ds_test = datasets.MNIST(
        data_root,
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])]),
    )
    X_test = []
    for x, _ in ds_test:
        X_test.append(x)
    X_test = torch.stack(X_test)
    y_test = ds_test.targets
    if output_channels > 1:
        X_train = torch.cat([X_train for i in range(output_channels)], 1)
        X_test = torch.cat([X_test for i in range(output_channels)], 1)
    X_test = (X_test - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (torch.std(X_train, [0, 2, 3], keepdim=True))
    X_train = (X_train - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (torch.std(X_train, [0, 2, 3], keepdim=True))
    return X_train, y_train, X_test, y_test

def get_cifar10(data_root="./data/", output_channels=3, image_size=32):
    ds_train = datasets.CIFAR10(
        data_root,
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])]),
    )
    X_train = ds_train.data
    y_train = ds_train.targets
    ds_test = datasets.CIFAR10(
        data_root,
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])]),
    )
    X_test = ds_test.data
    y_test = ds_test.targets
    # CIFAR .data is a uint8 numpy array of shape (N, 32, 32, 3); move channels to dim 1.
    X_train = torch.Tensor(np.moveaxis(X_train, 3, 1))
    y_train = torch.Tensor(y_train).long()
    X_test = torch.Tensor(np.moveaxis(X_test, 3, 1))
    y_test = torch.Tensor(y_test).long()
    if output_channels == 1:
        X_train = torch.mean(X_train, 1, keepdim=True)
        X_test = torch.mean(X_test, 1, keepdim=True)
    X_test = (X_test - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (torch.std(X_train, [0, 2, 3], keepdim=True))
    X_train = (X_train - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (torch.std(X_train, [0, 2, 3], keepdim=True))
    return X_train, y_train, X_test, y_test

def get_svhn(output_channels=1, image_size=32):
    ds_train = datasets.SVHN(
        "./data/",
        split="train",
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])]),
    )
    X_train = ds_train.data
    y_train = ds_train.labels
    ds_test = datasets.SVHN(
        "./data/",
        split="test",
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])]),
    )
    X_test = ds_test.data
    y_test = ds_test.labels
    X_train = torch.Tensor(X_train)
    y_train = torch.Tensor(y_train).long()
    X_test = torch.Tensor(X_test)
    y_test = torch.Tensor(y_test).long()
    if output_channels == 1:
        X_train = torch.mean(X_train, 1, keepdim=True)
        X_test = torch.mean(X_test, 1, keepdim=True)
    X_test = (X_test - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (torch.std(X_train, [0, 2, 3], keepdim=True))
    X_train = (X_train - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (torch.std(X_train, [0, 2, 3], keepdim=True))
    return X_train, y_train, X_test, y_test

def get_cifar100(data_root="./data/", output_channels=3, image_size=32):
    ds_train = datasets.CIFAR100(
        data_root,
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])]),
    )
    X_train = ds_train.data
    y_train = ds_train.targets
    ds_test = datasets.CIFAR100(
        data_root,
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize([image_size, image_size])]),
    )
    X_test = ds_test.data
    y_test = ds_test.targets
    X_train = torch.Tensor(np.moveaxis(X_train, 3, 1))
    y_train = torch.Tensor(y_train).long()
    X_test = torch.Tensor(np.moveaxis(X_test, 3, 1))
    y_test = torch.Tensor(y_test).long()
    if output_channels == 1:
        X_train = torch.mean(X_train, 1, keepdim=True)
        X_test = torch.mean(X_test, 1, keepdim=True)
    X_test = (X_test - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (torch.std(X_train, [0, 2, 3], keepdim=True))
    X_train = (X_train - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (torch.std(X_train, [0, 2, 3], keepdim=True))
    return X_train, y_train, X_test, y_test

def get_zca_matrix(X, reg_coef=0.1):
    # Estimate the regularized feature covariance and build the ZCA whitening matrix U S^(-1/2) U^T.
    X_flat = X.reshape(X.shape[0], -1)
    cov = (X_flat.T @ X_flat) / X_flat.shape[0]
    reg_amount = reg_coef * torch.trace(cov) / cov.shape[0]
    u, s, _ = torch.svd(cov.cuda() + reg_amount * torch.eye(cov.shape[0]).cuda())  # requires a GPU
    inv_sqrt_zca_eigs = s ** (-0.5)
    whitening_transform = torch.einsum("ij,j,kj->ik", u, inv_sqrt_zca_eigs, u)
    return whitening_transform.cpu()

def layernorm_data(X):
    X_processed = X - torch.mean(X, [1, 2, 3], keepdim=True)
    X_processed = X_processed / torch.sqrt(torch.sum(X_processed**2, [1, 2, 3], keepdim=True))
    return X_processed

def transform_data(X, whitening_transform):
    if len(whitening_transform.shape) == 2:
        # Single shared whitening matrix of shape (D, D), applied to flattened images.
        X_flat = X.reshape(X.shape[0], -1)
        X_flat = X_flat @ whitening_transform
        return X_flat.view(*X.shape)
    else:
        # Per-sample whitening matrices of shape (N, D, D).
        X_flat = X.reshape(X.shape[0], -1)
        X_flat = torch.einsum("nd, ndi->ni", X_flat, whitening_transform)
        return X_flat.view(*X.shape)

def scale_to_zero_one(X):
    mins = torch.min(X.view(X.shape[0], -1), 1)[0].view(-1, 1, 1, 1)
    maxes = torch.max(X.view(X.shape[0], -1), 1)[0].view(-1, 1, 1, 1)
    return (X - mins) / (maxes - mins)

def augment(images, dc_aug_param, device):
    # This can be sped up in the future.
    if dc_aug_param is not None and dc_aug_param["strategy"] != "none":
        scale = dc_aug_param["scale"]
        crop = dc_aug_param["crop"]
        rotate = dc_aug_param["rotate"]
        noise = dc_aug_param["noise"]
        strategy = dc_aug_param["strategy"]
        shape = images.shape
        # Per-channel means, used as fill values for cropping and rotation.
        mean = []
        for c in range(shape[1]):
            mean.append(float(torch.mean(images[:, c])))

        def cropfun(i):
            im_ = torch.zeros(shape[1], shape[2] + crop * 2, shape[3] + crop * 2, dtype=torch.float, device=device)
            for c in range(shape[1]):
                im_[c] = mean[c]
            im_[:, crop : crop + shape[2], crop : crop + shape[3]] = images[i]
            r, c = np.random.permutation(crop * 2)[0], np.random.permutation(crop * 2)[0]
            images[i] = im_[:, r : r + shape[2], c : c + shape[3]]

        def scalefun(i):
            h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2])
            w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2])
            tmp = F.interpolate(
                images[i : i + 1],
                [h, w],
            )[0]
            mhw = max(h, w, shape[2], shape[3])
            im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device)
            r = int((mhw - h) / 2)
            c = int((mhw - w) / 2)
            im_[:, r : r + h, c : c + w] = tmp
            r = int((mhw - shape[2]) / 2)
            c = int((mhw - shape[3]) / 2)
            images[i] = im_[:, r : r + shape[2], c : c + shape[3]]

        def rotatefun(i):
            im_ = scipyrotate(
                images[i].cpu().data.numpy(),
                angle=np.random.randint(-rotate, rotate),
                axes=(-2, -1),
                cval=np.mean(mean),
            )
            r = int((im_.shape[-2] - shape[-2]) / 2)
            c = int((im_.shape[-1] - shape[-1]) / 2)
            images[i] = torch.tensor(im_[:, r : r + shape[-2], c : c + shape[-1]], dtype=torch.float, device=device)

        def noisefun(i):
            images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device)

        augs = strategy.split("_")
        for i in range(shape[0]):
            choice = np.random.permutation(augs)[0]  # randomly implement one augmentation
            if choice == "crop":
                cropfun(i)
            elif choice == "scale":
                scalefun(i)
            elif choice == "rotate":
                rotatefun(i)
            elif choice == "noise":
                noisefun(i)
    return images
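
Example usage (a minimal sketch, not part of get_data.py): the snippet below shows one way the loaders above might be combined with the ZCA utilities and the augment routine. The choice of CIFAR-10, the reg_coef value, and the dc_aug_param settings are illustrative assumptions, not prescribed by this file.

# Illustrative only: load CIFAR-10, ZCA-whiten it with training-set statistics, then augment a batch.
from get_data import get_cifar10, get_zca_matrix, transform_data, augment

X_train, y_train, X_test, y_test = get_cifar10(data_root="./data/")

# get_zca_matrix moves the covariance onto CUDA internally, so this step needs a GPU.
W = get_zca_matrix(X_train, reg_coef=0.1)
X_train = transform_data(X_train, W)
X_test = transform_data(X_test, W)

# augment modifies its input in place and applies one randomly chosen op per image.
dc_aug_param = {"strategy": "crop_scale_rotate_noise", "crop": 4, "scale": 0.2, "rotate": 15, "noise": 0.001}
batch = augment(X_train[:128].clone(), dc_aug_param, device="cpu")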