You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

get_mnist_add.py 1.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import os.path as osp
  2. import torchvision
  3. from torchvision.transforms import transforms
  # Absolute directory containing this module; MNIST data and the
  # train/test index files are resolved relative to it.
  CURRENT_DIR = osp.dirname(osp.abspath(__file__))
  def get_data(file, img_dataset, get_pseudo_label):
      """Load MNIST-addition samples listed in ``file``.

      Each non-blank line of ``file`` holds three whitespace-separated
      integers: two indices into ``img_dataset`` followed by the target ``y``
      (presumably the sum of the two digits -- confirm against the data files).

      Args:
          file: Path to the text file of ``"<idx0> <idx1> <y>"`` lines.
          img_dataset: Indexable dataset whose items are ``(image, label)``
              pairs, e.g. a ``torchvision.datasets.MNIST`` instance.
          get_pseudo_label: If true, also collect the dataset labels of both
              images of each pair.

      Returns:
          A tuple ``(X, Z, Y)``: ``X`` is a list of ``[image0, image1]`` pairs,
          ``Z`` is the matching list of ``[label0, label1]`` pairs (``None``
          when ``get_pseudo_label`` is false) and ``Y`` is the list of integer
          targets.
      """
      X, Y = [], []
      Z = [] if get_pseudo_label else None
      with open(file, encoding="utf-8") as f:
          for line in f:
              # split() without a separator tolerates repeated spaces/tabs.
              fields = line.split()
              if not fields:
                  # Skip blank lines (e.g. a trailing empty line) instead of
                  # crashing on int("").
                  continue
              # Fetch each sample once: when the dataset has a transform
              # attached (as get_mnist_add does), every lookup re-applies it.
              sample0 = img_dataset[int(fields[0])]
              sample1 = img_dataset[int(fields[1])]
              X.append([sample0[0], sample1[0]])
              if Z is not None:
                  Z.append([sample0[1], sample1[1]])
              Y.append(int(fields[2]))
      return X, Z, Y
  def get_mnist_add(train=True, get_pseudo_label=False):
      """Build the MNIST-addition split selected by ``train``.

      Loads the MNIST train or test images under ``CURRENT_DIR`` (with
      ``download=True``), converting them to tensors and normalizing them,
      then pairs them according to ``train_data.txt`` or ``test_data.txt``.

      Args:
          train: Use the training split when true, otherwise the test split.
          get_pseudo_label: Forwarded to ``get_data``; when true the dataset
              labels of each image pair are returned as well.

      Returns:
          The ``(X, Z, Y)`` tuple produced by ``get_data``.
      """
      # Mean and std passed to Normalize (the values commonly used for MNIST).
      normalize = transforms.Normalize((0.1307,), (0.3081,))
      pipeline = transforms.Compose([transforms.ToTensor(), normalize])
      mnist = torchvision.datasets.MNIST(
          root=CURRENT_DIR,
          train=train,
          download=True,
          transform=pipeline,
      )
      index_file = osp.join(
          CURRENT_DIR, "train_data.txt" if train else "test_data.txt"
      )
      return get_data(index_file, mnist, get_pseudo_label)
  if __name__ == "__main__":
      # Smoke run: load both splits, then report their sizes and the shapes
      # of the first training pair together with its target.
      train_X, _, train_Y = get_mnist_add(train=True)
      test_X, _, _ = get_mnist_add(train=False)
      print(len(train_X), len(test_X))
      left_img, right_img = train_X[0]
      print(left_img.shape, right_img.shape, train_Y[0])

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.