You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they may include dashes ('-') and can be up to 35 characters long.

get_dataset.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import os
  2. import os.path as osp
  3. import pickle
  4. import random
  5. import zipfile
  6. from collections import defaultdict
  7. from PIL import Image
  8. import gdown
  9. import numpy as np
  10. from torchvision.transforms import transforms
# Absolute path of the directory containing this file; dataset files and the
# extracted archive are resolved relative to it.
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
  12. def download_and_unzip(url, zip_file_name):
  13. try:
  14. gdown.download(url, zip_file_name)
  15. with zipfile.ZipFile(zip_file_name, "r") as zip_ref:
  16. zip_ref.extractall(CURRENT_DIR)
  17. os.remove(zip_file_name)
  18. except Exception as e:
  19. if os.path.exists(zip_file_name):
  20. os.remove(zip_file_name)
  21. raise Exception(
  22. f"An error occurred during download or unzip: {e}. Instead, you can download "
  23. + f"the dataset from {url} and unzip it in 'examples/hed/datasets' folder"
  24. )
  25. def get_pretrain_data(labels, image_size=(28, 28, 1)):
  26. transform = transforms.Compose([transforms.ToTensor()])
  27. X = []
  28. img_dir = osp.join(CURRENT_DIR, "mnist_images")
  29. for label in labels:
  30. label_path = osp.join(img_dir, label)
  31. img_path_list = os.listdir(label_path)
  32. for img_path in img_path_list:
  33. with Image.open(osp.join(label_path, img_path)) as img:
  34. img = img.convert("L")
  35. img = img.resize((image_size[1], image_size[0]))
  36. img_array = np.array(img, dtype=np.float32)
  37. normalized_img = (img_array - 127) / 128.0
  38. X.append(normalized_img)
  39. Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X]
  40. X = [transform(img[:, :, np.newaxis]) for img in X]
  41. return X, Y
  42. def divide_equations_by_len(equations, labels):
  43. equations_by_len = {1: defaultdict(list), 0: defaultdict(list)}
  44. for i, equation in enumerate(equations):
  45. equations_by_len[labels[i]][len(equation)].append(equation)
  46. return equations_by_len
  47. def split_equation(equations_by_len, prop_train, prop_val):
  48. """
  49. Split the equations in each length to training and validation data according to the proportion
  50. """
  51. train_equations_by_len = {1: {}, 0: {}}
  52. val_equations_by_len = {1: {}, 0: {}}
  53. for label in range(2):
  54. for equation_len, equations in equations_by_len[label].items():
  55. random.shuffle(equations)
  56. train_equations_by_len[label][equation_len] = equations[
  57. : len(equations) // (prop_train + prop_val) * prop_train
  58. ]
  59. val_equations_by_len[label][equation_len] = equations[
  60. len(equations) // (prop_train + prop_val) * prop_train :
  61. ]
  62. return train_equations_by_len, val_equations_by_len
  63. def get_dataset(dataset="mnist", train=True):
  64. data_dir = CURRENT_DIR + "/mnist_images"
  65. if not os.path.exists(data_dir):
  66. print("Dataset not exist, downloading it...")
  67. url = "https://drive.google.com/u/0/uc?id=1W2AUn_fnXa4XkgLk4d17K3bEgpae8GMg&export=download"
  68. download_and_unzip(url, os.path.join(CURRENT_DIR, "HED.zip"))
  69. print("Download and extraction complete.")
  70. if train:
  71. file = os.path.join(data_dir, "expr_train.json")
  72. else:
  73. file = os.path.join(data_dir, "expr_test.json")
  74. if dataset == "mnist":
  75. file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk")
  76. elif dataset == "random":
  77. file = osp.join(CURRENT_DIR, "random_equation_data_train_len_26_test_len_26_sys_2_.pk")
  78. else:
  79. raise ValueError("Undefined dataset")
  80. with open(file, "rb") as f:
  81. img_dataset = pickle.load(f)
  82. X, Y = [], []
  83. if train:
  84. positive = img_dataset["train:positive"]
  85. negative = img_dataset["train:negative"]
  86. else:
  87. positive = img_dataset["test:positive"]
  88. negative = img_dataset["test:negative"]
  89. for equation in positive:
  90. equation = equation.astype(np.float32)
  91. img_list = np.vsplit(equation, equation.shape[0])
  92. X.append(img_list)
  93. Y.append(1)
  94. for equation in negative:
  95. equation = equation.astype(np.float32)
  96. img_list = np.vsplit(equation, equation.shape[0])
  97. X.append(img_list)
  98. Y.append(0)
  99. equations_by_len = divide_equations_by_len(X, Y)
  100. return equations_by_len

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.