|
|
@@ -59,8 +59,8 @@ def split_equation(equations_by_len, prop_train, prop_val): |
|
|
|
""" |
|
|
|
Split the equations in each length to training and validation data according to the proportion |
|
|
|
""" |
|
|
|
train_equations_by_len = {1: dict(), 0: dict()} |
|
|
|
val_equations_by_len = {1: dict(), 0: dict()} |
|
|
|
train_equations_by_len = {1: {}, 0: {}} |
|
|
|
val_equations_by_len = {1: {}, 0: {}} |
|
|
|
|
|
|
|
for label in range(2): |
|
|
|
for equation_len, equations in equations_by_len[label].items(): |
|
|
@@ -80,7 +80,7 @@ def get_dataset(dataset="mnist", train=True): |
|
|
|
|
|
|
|
if not os.path.exists(data_dir): |
|
|
|
print("Dataset not exist, downloading it...") |
|
|
|
url = "https://drive.google.com/u/0/uc?id=1XoJDjO3cNUdytqVgXUKOBe9dOcUBobom&export=download" |
|
|
|
url = "https://drive.google.com/u/0/uc?id=1W2AUn_fnXa4XkgLk4d17K3bEgpae8GMg&export=download" |
|
|
|
download_and_unzip(url, os.path.join(CURRENT_DIR, "HED.zip")) |
|
|
|
print("Download and extraction complete.") |
|
|
|
|
|
|
|