You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

get_dataset.py 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import json
  2. import os
  3. import zipfile
  4. import gdown
  5. from PIL import Image
  6. from torchvision.transforms import transforms
  7. CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
  8. img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])
  9. def download_and_unzip(url, zip_file_name):
  10. try:
  11. gdown.download(url, zip_file_name)
  12. with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
  13. zip_ref.extractall(CURRENT_DIR)
  14. os.remove(zip_file_name)
  15. except Exception as e:
  16. if os.path.exists(zip_file_name):
  17. os.remove(zip_file_name)
  18. raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in 'examples/hwf/datasets' folder")
  19. def get_dataset(train=True, get_pseudo_label=False):
  20. data_dir = CURRENT_DIR + '/data'
  21. if not os.path.exists(data_dir):
  22. print("Dataset not exist, downloading it...")
  23. url = 'https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download'
  24. download_and_unzip(url, os.path.join(CURRENT_DIR, "HWF.zip"))
  25. print("Download and extraction complete.")
  26. if train:
  27. file = os.path.join(data_dir, "expr_train.json")
  28. else:
  29. file = os.path.join(data_dir, "expr_test.json")
  30. X = []
  31. pseudo_label = [] if get_pseudo_label else None
  32. Y = []
  33. img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/")
  34. with open(file) as f:
  35. data = json.load(f)
  36. for idx in range(len(data)):
  37. imgs = []
  38. if get_pseudo_label:
  39. imgs_pseudo_label = []
  40. for img_path in data[idx]["img_paths"]:
  41. img = Image.open(img_dir + img_path).convert("L")
  42. img = img_transform(img)
  43. imgs.append(img)
  44. if get_pseudo_label:
  45. label_mappings = {"times": "*", "div": "/"}
  46. label = img_path.split("/")[0]
  47. label = label_mappings.get(label, label)
  48. imgs_pseudo_label.append(label)
  49. X.append(imgs)
  50. if get_pseudo_label:
  51. pseudo_label.append(imgs_pseudo_label)
  52. Y.append(data[idx]["res"])
  53. return X, pseudo_label, Y

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.