You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

get_dataset.py 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import json
  2. import os
  3. import zipfile
  4. import gdown
  5. from PIL import Image
  6. from torchvision.transforms import transforms
  7. CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
  8. img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])
  9. def download_and_unzip(url, zip_file_name):
  10. try:
  11. gdown.download(url, zip_file_name)
  12. with zipfile.ZipFile(zip_file_name, "r") as zip_ref:
  13. zip_ref.extractall(CURRENT_DIR)
  14. os.remove(zip_file_name)
  15. except Exception as e:
  16. if os.path.exists(zip_file_name):
  17. os.remove(zip_file_name)
  18. raise Exception(
  19. f"An error occurred during download or unzip: {e}. Instead, you can download "
  20. + f"the dataset from {url} and unzip it in 'examples/hwf/datasets' folder"
  21. )
  22. def get_dataset(train=True, get_pseudo_label=False):
  23. data_dir = CURRENT_DIR + "/data"
  24. if not os.path.exists(data_dir):
  25. print("Dataset not exist, downloading it...")
  26. url = "https://drive.google.com/u/0/uc?id=1t52OE2Wdm5GdShX1jD2Wy8phCllk0r8I&export=download"
  27. download_and_unzip(url, os.path.join(CURRENT_DIR, "HWF.zip"))
  28. print("Download and extraction complete.")
  29. if train:
  30. file = os.path.join(data_dir, "expr_train.json")
  31. else:
  32. file = os.path.join(data_dir, "expr_test.json")
  33. X = []
  34. pseudo_label = [] if get_pseudo_label else None
  35. Y = []
  36. img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/")
  37. with open(file) as f:
  38. data = json.load(f)
  39. for idx in range(len(data)):
  40. imgs = []
  41. if get_pseudo_label:
  42. imgs_pseudo_label = []
  43. for img_path in data[idx]["img_paths"]:
  44. img = Image.open(img_dir + img_path).convert("L")
  45. img = img_transform(img)
  46. imgs.append(img)
  47. if get_pseudo_label:
  48. label_mappings = {"times": "*", "div": "/"}
  49. label = img_path.split("/")[0]
  50. label = label_mappings.get(label, label)
  51. imgs_pseudo_label.append(label)
  52. X.append(imgs)
  53. if get_pseudo_label:
  54. pseudo_label.append(imgs_pseudo_label)
  55. Y.append(data[idx]["res"])
  56. return X, pseudo_label, Y

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.