You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

get_hed.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import os
  2. import cv2
  3. import torch
  4. import torchvision
  5. import pickle
  6. import numpy as np
  7. import random
  8. from collections import defaultdict
  9. from torch.utils.data import Dataset
  10. from torchvision.transforms import transforms
  def get_data(img_dataset, train):
      """Build per-equation image lists and binary labels from a dataset dict.

      Args:
          img_dataset: Mapping indexed with the keys ``"train:positive"`` and
              ``"train:negative"`` (when ``train`` is truthy) or
              ``"test:positive"`` and ``"test:negative"`` (otherwise). Each
              value is iterated, and every item must provide ``astype`` and
              ``shape`` and be splittable by ``np.vsplit`` (i.e. a NumPy
              array with at least two dimensions).
          train: Truthy selects the ``train:*`` entries, falsy the ``test:*``
              entries.

      Returns:
          A tuple ``(X, None, Y)``. ``X[i]`` is the list produced by splitting
          the i-th equation (cast to ``float32``) into ``shape[0]`` pieces
          along its first axis; ``Y[i]`` is ``1`` for positive equations and
          ``0`` for negative ones. All positives come before all negatives.
          The middle element is always ``None``.

      Raises:
          KeyError: If a required key is missing from ``img_dataset``.
      """
      split = "train" if train else "test"
      # Look both entries up before processing so a missing key fails early.
      positive = img_dataset[f"{split}:positive"]
      negative = img_dataset[f"{split}:negative"]

      X = []
      Y = []
      for label, equations in ((1, positive), (0, negative)):
          for equation in equations:
              equation = equation.astype(np.float32)
              # One piece per row along axis 0; each piece keeps a leading
              # axis of length 1.
              X.append(np.vsplit(equation, equation.shape[0]))
              Y.append(label)
      return X, None, Y
  def get_pretrain_data(labels, image_size=(28, 28, 1),
                        data_root="./datasets/mnist_images"):
      """Load grayscale images for the given labels as (input, target) pairs.

      For every label, all files in ``<data_root>/<label>`` are read with
      OpenCV in grayscale mode, resized, and shifted/scaled with
      ``(pixel - 127) / 128.0``.

      Args:
          labels: Iterable of directory names (strings) under ``data_root``.
          image_size: Indexable with at least three entries. Entries 0 and 1
              are passed to ``cv2.resize`` as ``(image_size[1],
              image_size[0])``; the product of entries 0, 1 and 2 is the
              length of each flattened target. Presumably
              ``(height, width, channels)`` -- confirm against callers.
          data_root: Directory that contains one sub-directory per label.
              Defaults to the location that was previously hard-coded.

      Returns:
          A tuple ``(X, Y)`` of equally long lists. ``Y[i]`` is the
          normalised image flattened to one dimension; ``X[i]`` is the same
          normalised image passed through torchvision's ``ToTensor``. Both
          lists are empty when no image could be loaded.

      Raises:
          OSError: If a label directory cannot be listed.
      """
      import warnings

      images = []
      for label in labels:
          label_dir = os.path.join(data_root, label)
          # os.listdir returns entries in arbitrary order; sort so the sample
          # order is reproducible across runs and machines.
          for file_name in sorted(os.listdir(label_dir)):
              file_path = os.path.join(label_dir, file_name)
              img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
              if img is None:
                  # imread signals an unreadable / non-image file with None;
                  # passing that on to cv2.resize would fail with an opaque
                  # OpenCV error.
                  warnings.warn(f"Skipping unreadable image file: {file_path}")
                  continue
              img = cv2.resize(img, (image_size[1], image_size[0]))
              images.append(np.array(img, dtype=np.float32))

      if not images:
          return [], []

      flat_len = image_size[0] * image_size[1] * image_size[2]
      to_tensor = transforms.Compose([transforms.ToTensor()])
      # Add a trailing channel axis before normalising.
      X = [(img[:, :, np.newaxis] - 127) / 128.0 for img in images]
      Y = [img.copy().reshape(flat_len) for img in X]
      X = [to_tensor(img) for img in X]
      return X, Y
  50. # def get_pretrain_data(train_data, image_size=(28, 28, 1)):
  51. # X = []
  52. # for label in [0, 1]:
  53. # for _, equation_list in train_data[label].items():
  54. # for equation in equation_list:
  55. # X = X + equation
  56. # X = np.array(X)
  57. # index = np.array(list(range(len(X))))
  58. # np.random.shuffle(index)
  59. # X = X[index]
  60. # X = [img for img in X]
  61. # Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X]
  62. # return X, Y
  def divide_equations_by_len(equations, labels):
      """Group equations first by label and then by their length.

      Args:
          equations: Sequence of equations; each must support ``len``.
          labels: Sequence indexable by position, where ``labels[i]`` is the
              label (``0`` or ``1``) of ``equations[i]``.

      Returns:
          ``{1: defaultdict(list), 0: defaultdict(list)}`` where
          ``result[label][n]`` lists, in input order, every equation with
          that label whose length is ``n``.

      Raises:
          IndexError: If ``labels`` is shorter than ``equations``.
          KeyError: If a label is neither ``0`` nor ``1``.
      """
      equations_by_len = {1: defaultdict(list), 0: defaultdict(list)}
      for index, equation in enumerate(equations):
          bucket = equations_by_len[labels[index]]
          bucket[len(equation)].append(equation)
      return equations_by_len
  def split_equation(equations_by_len, prop_train, prop_val):
      """
      Split the equations of each length into training and validation data
      according to the proportion ``prop_train : prop_val``.

      Args:
          equations_by_len: Mapping with the keys ``0`` and ``1``; each value
              maps an equation length to a list of equations.
          prop_train: Relative share of the training part (int or float).
          prop_val: Relative share of the validation part (int or float).

      Returns:
          A tuple ``(train_equations_by_len, val_equations_by_len)`` with the
          same ``{label: {length: [equations]}}`` layout. For every
          label/length the equations are shuffled and the first
          ``floor(n * prop_train / (prop_train + prop_val))`` go to the
          training part, the rest to the validation part.

      Raises:
          ZeroDivisionError: If ``prop_train + prop_val`` is zero and there is
              at least one length bucket to split.
      """
      total = prop_train + prop_val
      train_equations_by_len = {1: {}, 0: {}}
      val_equations_by_len = {1: {}, 0: {}}
      for label in (0, 1):
          for equation_len, equations in equations_by_len[label].items():
              # Shuffle a copy so the caller's lists are left untouched.
              shuffled = list(equations)
              random.shuffle(shuffled)
              # Multiply before dividing: flooring first (n // total * train)
              # discards up to total - 1 items from the ratio and skews small
              # buckets towards validation. int() keeps the slice index valid
              # when the proportions are floats.
              n_train = int(len(shuffled) * prop_train // total)
              train_equations_by_len[label][equation_len] = shuffled[:n_train]
              val_equations_by_len[label][equation_len] = shuffled[n_train:]
      return train_equations_by_len, val_equations_by_len
  def get_hed(dataset="mnist", train=True):
      """Load a pickled equation dataset and group its equations.

      Args:
          dataset: ``"mnist"`` or ``"random"``; selects which pickle file
              under ``./datasets`` is read.
          train: Forwarded to ``get_data`` to choose the train or test part.

      Returns:
          The result of ``divide_equations_by_len`` applied to the samples
          and labels returned by ``get_data``.

      Raises:
          ValueError: If ``dataset`` is not one of the supported names. This
              is checked before any file is opened.
          OSError: If the dataset file cannot be opened.
      """
      if dataset == "mnist":
          path = "./datasets/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk"
      elif dataset == "random":
          path = "./datasets/random_equation_data_train_len_26_test_len_26_sys_2_.pk"
      else:
          # ValueError is still an Exception, so existing broad handlers
          # keep working.
          raise ValueError("Undefined dataset")

      # NOTE: pickle.load can execute arbitrary code; only use dataset files
      # from a trusted source.
      with open(path, "rb") as f:
          img_dataset = pickle.load(f)

      X, _, Y = get_data(img_dataset, train)
      return divide_equations_by_len(X, Y)
  # Script entry point: load the default dataset ("mnist", training part)
  # when the file is run directly. The returned grouping is not used.
  if __name__ == "__main__":
      get_hed()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.