import json
import os

from PIL import Image
from torchvision.transforms import transforms

# Directory containing this module; all dataset paths are resolved against it.
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

# Convert a PIL image to a tensor, then normalize with mean 0.5 and std 1.
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])


def get_data(file, get_pseudo_label):
    """Load the samples described by a JSON annotation file.

    The JSON document is iterated as a sequence of records; each record must
    provide ``"img_paths"`` (image paths relative to
    ``data/Handwritten_Math_Symbols/`` next to this module) and ``"res"``.

    Args:
        file: Path of the JSON annotation file to read.
        get_pseudo_label: When truthy, also collect one label per image,
            taken from the first ``/``-separated component of its path.

    Returns:
        A tuple ``(X, Z, Y)`` where ``X`` holds, per record, the list of
        transformed grayscale images; ``Z`` holds, per record, the list of
        per-image labels, or is ``None`` when ``get_pseudo_label`` is falsy;
        and ``Y`` holds each record's ``"res"`` value.
    """
    img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/")

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(file, encoding="utf-8") as f:
        data = json.load(f)

    X, Y = [], []
    Z = [] if get_pseudo_label else None

    for sample in data:
        img_paths = sample["img_paths"]

        imgs = []
        for img_path in img_paths:
            # The context manager releases the underlying file handle as soon
            # as the image has been converted; convert() returns a new image
            # that does not depend on the opened file.
            with Image.open(os.path.join(img_dir, img_path)) as raw:
                imgs.append(img_transform(raw.convert("L")))
        X.append(imgs)

        if get_pseudo_label:
            # The first path component is used as the image's label.
            Z.append([img_path.split("/")[0] for img_path in img_paths])

        Y.append(sample["res"])

    return X, Z, Y


def get_hwf(train=True, get_gt_pseudo_label=False):
    """Load the train or test split through :func:`get_data`.

    Args:
        train: Truthy selects ``data/expr_train.json``; falsy selects
            ``data/expr_test.json`` (both next to this module).
        get_gt_pseudo_label: Forwarded to :func:`get_data` as
            ``get_pseudo_label``.

    Returns:
        The ``(X, Z, Y)`` tuple produced by :func:`get_data`.
    """
    split_file = "data/expr_train.json" if train else "data/expr_test.json"
    return get_data(os.path.join(CURRENT_DIR, split_file), get_gt_pseudo_label)