You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

get_dataset.py 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import os
  2. import os.path as osp
  3. import pickle
  4. import random
  5. from collections import defaultdict
  6. import cv2
  7. import numpy as np
  8. from torchvision.transforms import transforms
  9. CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
  10. def get_data(img_dataset, train):
  11. X, Y = [], []
  12. if train:
  13. positive = img_dataset["train:positive"]
  14. negative = img_dataset["train:negative"]
  15. else:
  16. positive = img_dataset["test:positive"]
  17. negative = img_dataset["test:negative"]
  18. for equation in positive:
  19. equation = equation.astype(np.float32)
  20. img_list = np.vsplit(equation, equation.shape[0])
  21. X.append(img_list)
  22. Y.append(1)
  23. for equation in negative:
  24. equation = equation.astype(np.float32)
  25. img_list = np.vsplit(equation, equation.shape[0])
  26. X.append(img_list)
  27. Y.append(0)
  28. return X, None, Y
  29. def get_pretrain_data(labels, image_size=(28, 28, 1)):
  30. transform = transforms.Compose([transforms.ToTensor()])
  31. X = []
  32. img_dir = osp.join(CURRENT_DIR, "mnist_images")
  33. for label in labels:
  34. label_path = osp.join(img_dir, label)
  35. img_path_list = os.listdir(label_path)
  36. for img_path in img_path_list:
  37. img = cv2.imread(osp.join(label_path, img_path), cv2.IMREAD_GRAYSCALE)
  38. img = cv2.resize(img, (image_size[1], image_size[0]))
  39. X.append(np.array(img, dtype=np.float32))
  40. X = [((img[:, :, np.newaxis] - 127) / 128.0) for img in X]
  41. Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X]
  42. X = [transform(img) for img in X]
  43. return X, Y
  44. def divide_equations_by_len(equations, labels):
  45. equations_by_len = {1: defaultdict(list), 0: defaultdict(list)}
  46. for i, equation in enumerate(equations):
  47. equations_by_len[labels[i]][len(equation)].append(equation)
  48. return equations_by_len
  49. def split_equation(equations_by_len, prop_train, prop_val):
  50. """
  51. Split the equations in each length to training and validation data according to the proportion
  52. """
  53. train_equations_by_len = {1: dict(), 0: dict()}
  54. val_equations_by_len = {1: dict(), 0: dict()}
  55. for label in range(2):
  56. for equation_len, equations in equations_by_len[label].items():
  57. random.shuffle(equations)
  58. train_equations_by_len[label][equation_len] = equations[
  59. : len(equations) // (prop_train + prop_val) * prop_train
  60. ]
  61. val_equations_by_len[label][equation_len] = equations[
  62. len(equations) // (prop_train + prop_val) * prop_train :
  63. ]
  64. return train_equations_by_len, val_equations_by_len
  65. def get_dataset(dataset="mnist", train=True):
  66. if dataset == "mnist":
  67. file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk")
  68. elif dataset == "random":
  69. file = osp.join(CURRENT_DIR, "random_equation_data_train_len_26_test_len_26_sys_2_.pk")
  70. else:
  71. raise ValueError("Undefined dataset")
  72. with open(file, "rb") as f:
  73. img_dataset = pickle.load(f)
  74. X, _, Y = get_data(img_dataset, train)
  75. equations_by_len = divide_equations_by_len(X, Y)
  76. return equations_by_len
  77. if __name__ == "__main__":
  78. get_hed()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.