import os
import os.path as osp
import pickle
import random
import zipfile
from collections import defaultdict

import numpy as np

# Directory containing this file; datasets are downloaded and extracted here.
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))


def download_and_unzip(url, zip_file_name):
    """Download a zip archive from ``url`` to ``zip_file_name``, extract it
    into ``CURRENT_DIR`` and delete the archive.

    Raises:
        Exception: if the download or extraction fails; any partially
            downloaded archive is removed before re-raising.
    """
    # gdown is only needed when the dataset actually has to be fetched,
    # so import it lazily to keep this module importable without it.
    import gdown

    try:
        gdown.download(url, zip_file_name)
        with zipfile.ZipFile(zip_file_name, "r") as zip_ref:
            zip_ref.extractall(CURRENT_DIR)
        os.remove(zip_file_name)
    except Exception as e:
        # Clean up a partial download before reporting the failure.
        if os.path.exists(zip_file_name):
            os.remove(zip_file_name)
        raise Exception(
            f"An error occurred during download or unzip: {e}. Instead, you can download "
            + f"the dataset from {url} and unzip it in 'examples/hed/datasets' folder"
        ) from e


def get_pretrain_data(labels, image_size=(28, 28, 1)):
    """Load grayscale symbol images for the given label sub-directories.

    Args:
        labels: iterable of sub-directory names under ``mnist_images``.
        image_size: (height, width, channels) of the images to produce.

    Returns:
        A pair ``(X, Y)`` where ``X`` is a list of normalized image tensors
        (one per image, converted via ``transforms.ToTensor``) and ``Y`` is
        a list of the same normalized images flattened to 1-D float32
        arrays (reconstruction targets).
    """
    # PIL / torchvision are only needed for pre-training data, so import
    # them lazily to keep this module importable without them.
    from PIL import Image
    from torchvision.transforms import transforms

    transform = transforms.Compose([transforms.ToTensor()])
    height, width, channels = image_size
    X = []
    img_dir = osp.join(CURRENT_DIR, "mnist_images")
    for label in labels:
        label_path = osp.join(img_dir, label)
        for img_name in os.listdir(label_path):
            with Image.open(osp.join(label_path, img_name)) as img:
                img = img.convert("L")
                # PIL's resize takes (width, height).
                img = img.resize((width, height))
                img_array = np.array(img, dtype=np.float32)
            # Map pixel values from [0, 255] into roughly [-1, 1].
            X.append((img_array - 127) / 128.0)
    # Flattened copies of the normalized images serve as targets.
    Y = [img.copy().reshape(height * width * channels) for img in X]
    X = [transform(img[:, :, np.newaxis]) for img in X]
    return X, Y


def divide_equations_by_len(equations, labels):
    """Group equations by their binary label and by equation length.

    Args:
        equations: sequence of equations (each a sized sequence of symbols).
        labels: parallel sequence of labels, each 0 or 1.

    Returns:
        dict mapping label -> {equation length -> list of equations}.
    """
    equations_by_len = {1: defaultdict(list), 0: defaultdict(list)}
    for equation, label in zip(equations, labels):
        equations_by_len[label][len(equation)].append(equation)
    return equations_by_len


def split_equation(equations_by_len, prop_train, prop_val):
    """Split the equations of each length into training and validation data
    according to the proportion ``prop_train : prop_val``.

    Note: shuffles the per-length lists in ``equations_by_len`` in place.

    Returns:
        ``(train_equations_by_len, val_equations_by_len)`` with the same
        label -> {length -> equations} structure as the input.
    """
    train_equations_by_len = {1: {}, 0: {}}
    val_equations_by_len = {1: {}, 0: {}}
    for label in (0, 1):
        for equation_len, equations in equations_by_len[label].items():
            random.shuffle(equations)
            # Integer split point: train takes prop_train out of every
            # (prop_train + prop_val) equations; the remainder goes to val.
            n_train = len(equations) // (prop_train + prop_val) * prop_train
            train_equations_by_len[label][equation_len] = equations[:n_train]
            val_equations_by_len[label][equation_len] = equations[n_train:]
    return train_equations_by_len, val_equations_by_len


def get_dataset(dataset="mnist", train=True):
    """Load handwritten-equation image data, grouped by label and length.

    Downloads and extracts the HED archive on first use.

    Args:
        dataset: "mnist" or "random" — which pickled equation file to load.
        train: load the training split if True, otherwise the test split.

    Returns:
        dict mapping label (1 = positive, 0 = negative) to
        {equation length -> list of equations}, where each equation is a
        list of per-symbol float32 row arrays (from ``np.vsplit``).

    Raises:
        ValueError: if ``dataset`` is neither "mnist" nor "random".
    """
    data_dir = os.path.join(CURRENT_DIR, "mnist_images")
    if not os.path.exists(data_dir):
        print("Dataset not exist, downloading it...")
        url = "https://drive.google.com/u/0/uc?id=1W2AUn_fnXa4XkgLk4d17K3bEgpae8GMg&export=download"
        download_and_unzip(url, os.path.join(CURRENT_DIR, "HED.zip"))
        print("Download and extraction complete.")

    if dataset == "mnist":
        file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk")
    elif dataset == "random":
        file = osp.join(CURRENT_DIR, "random_equation_data_train_len_26_test_len_26_sys_2_.pk")
    else:
        raise ValueError("Undefined dataset")

    # NOTE(review): pickle.load is only safe because this file comes from
    # the project's own download above; never use it on untrusted data.
    with open(file, "rb") as f:
        img_dataset = pickle.load(f)

    if train:
        positive = img_dataset["train:positive"]
        negative = img_dataset["train:negative"]
    else:
        positive = img_dataset["test:positive"]
        negative = img_dataset["test:negative"]

    X, Y = [], []
    # All positives (label 1) first, then all negatives (label 0),
    # matching the original ordering.
    for label, equations in ((1, positive), (0, negative)):
        for equation in equations:
            equation = equation.astype(np.float32)
            # One (1, n) row array per symbol in the equation.
            X.append(np.vsplit(equation, equation.shape[0]))
            Y.append(label)

    return divide_equations_by_len(X, Y)