Handwritten Equation Decipherment (HED) ======================================= .. raw:: html

For detailed code implementation, please view it on GitHub.

Below shows an implementation of `Handwritten Equation Decipherment `__. In this task, the handwritten equations are given, which consist of sequential pictures of characters. The equations are generated with unknown operation rules from images of symbols (‘0’, ‘1’, ‘+’ and ‘=’), and each equation is associated with a label indicating whether the equation is correct (i.e., positive) or not (i.e., negative). Also, we are given a knowledge base which involves the structure of the equations and a recursive definition of bit-wise operations. The task is to learn from a training set of above-mentioned equations and then to predict labels of unseen equations. Intuitively, we first use a machine learning model (learning part) to obtain the pseudo-labels (‘0’, ‘1’, ‘+’ and ‘=’) for the observed pictures. We then use the knowledge base (reasoning part) to perform abductive reasoning so as to yield ground hypotheses as possible explanations to the observed facts, suggesting some pseudo-labels to be revised. This process enables us to further update the machine learning model. .. code:: python # Import necessary libraries and modules import os.path as osp import matplotlib.pyplot as plt import torch import torch.nn as nn from ablkit.learning import ABLModel, BasicNN from ablkit.utils import ABLLogger, print_log from bridge import HedBridge from consistency_metric import ConsistencyMetric from datasets import get_dataset, split_equation from models.nn import SymbolNet from reasoning import HedKB, HedReasoner Working with Data ----------------- First, we get the datasets of handwritten equations: .. code:: python total_train_data = get_dataset(train=True) train_data, val_data = split_equation(total_train_data, 3, 1) test_data = get_dataset(train=False) The datasets are shown below: .. code:: python true_train_equation = train_data[1] false_train_equation = train_data[0] print(f"Equations in the dataset is organized by equation length, " + f"from {min(train_data[0].keys())} to {max(train_data[0].keys())}.") print() true_train_equation_with_length_5 = true_train_equation[5] false_train_equation_with_length_5 = false_train_equation[5] print(f"For each eqaation length, there are {len(true_train_equation_with_length_5)} " + f"true equations and {len(false_train_equation_with_length_5)} false equations " + f"in the training set.") true_val_equation = val_data[1] false_val_equation = val_data[0] true_val_equation_with_length_5 = true_val_equation[5] false_val_equation_with_length_5 = false_val_equation[5] print(f"For each equation length, there are {len(true_val_equation_with_length_5)} " + f"true equations and {len(false_val_equation_with_length_5)} false equations " + f"in the validation set.") true_test_equation = test_data[1] false_test_equation = test_data[0] true_test_equation_with_length_5 = true_test_equation[5] false_test_equation_with_length_5 = false_test_equation[5] print(f"For each equation length, there are {len(true_test_equation_with_length_5)} " + f"true equations and {len(false_test_equation_with_length_5)} false equations " + f"in the test set.") Out: .. code:: none :class: code-out Equations in the dataset is organized by equation length, from 5 to 26. For each equation length, there are 225 true equations and 225 false equations in the training set. For each equation length, there are 75 true equations and 75 false equations in the validation set. For each equation length, there are 300 true equations and 300 false equations in the test set. As illustrations, we show four equations in the training dataset: .. code:: python true_train_equation_with_length_5 = true_train_equation[5] true_train_equation_with_length_8 = true_train_equation[8] print(f"First true equation with length 5 in the training dataset:") for i, x in enumerate(true_train_equation_with_length_5[0]): plt.subplot(1, 5, i+1) plt.axis('off') plt.imshow(x.squeeze(), cmap='gray') plt.show() print(f"First true equation with length 8 in the training dataset:") for i, x in enumerate(true_train_equation_with_length_8[0]): plt.subplot(1, 8, i+1) plt.axis('off') plt.imshow(x.squeeze(), cmap='gray') plt.show() false_train_equation_with_length_5 = false_train_equation[5] false_train_equation_with_length_8 = false_train_equation[8] print(f"First false equation with length 5 in the training dataset:") for i, x in enumerate(false_train_equation_with_length_5[0]): plt.subplot(1, 5, i+1) plt.axis('off') plt.imshow(x.squeeze(), cmap='gray') plt.show() print(f"First false equation with length 8 in the training dataset:") for i, x in enumerate(false_train_equation_with_length_8[0]): plt.subplot(1, 8, i+1) plt.axis('off') plt.imshow(x.squeeze(), cmap='gray') plt.show() Out: .. code:: none :class: code-out First true equation with length 5 in the training dataset: .. image:: ../_static/img/hed_dataset1.png :width: 300px .. code:: none :class: code-out First true equation with length 8 in the training dataset: .. image:: ../_static/img/hed_dataset2.png :width: 480px .. code:: none :class: code-out First false equation with length 5 in the training dataset: .. image:: ../_static/img/hed_dataset3.png :width: 300px .. code:: none :class: code-out First false equation with length 8 in the training dataset: .. image:: ../_static/img/hed_dataset4.png :width: 480px Building the Learning Part -------------------------- To build the learning part, we need to first build a machine learning base model. We use SymbolNet, and encapsulate it within a ``BasicNN`` object to create the base model. ``BasicNN`` is a class that encapsulates a PyTorch model, transforming it into a base model with an sklearn-style interface. .. code:: python # class of symbol may be one of ['0', '1', '+', '='], total of 4 classes cls = SymbolNet(num_classes=4) loss_fn = nn.CrossEntropyLoss() optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") base_model = BasicNN( cls, loss_fn, optimizer, device=device, batch_size=32, num_epochs=1, stop_loss=None, ) However, the base model built above deals with instance-level data (i.e., individual images), and can not directly deal with example-level data (i.e., a list of images comprising the equation). Therefore, we wrap the base model into ``ABLModel``, which enables the learning part to train, test, and predict on example-level data. .. code:: python model = ABLModel(base_model) Building the Reasoning Part --------------------------- In the reasoning part, we first build a knowledge base. As mentioned before, the knowledge base in this task involves the structure of the equations and a recursive definition of bit-wise operations, which are defined in Prolog file ``examples/hed/reasoning/BK.pl`` and ``examples/hed/reasoning/learn_add.pl``, respectively. Specifically, the knowledge about the structure of equations is a set of DCG rules recursively define that a digit is a sequence of ‘0’ and ‘1’, and equations share the structure of X+Y=Z, though the length of X, Y and Z can be varied. The knowledge about bit-wise operations is a recursive logic program, which reversely calculates X+Y, i.e., it operates on X and Y digit-by-digit and from the last digit to the first. The knowledge base is already built in ``HedKB``. ``HedKB`` is derived from class ``PrologKB``, and is built upon the aformentioned Prolog files. .. code:: python kb = HedKB() .. note:: Please notice that, the specific rules for calculating the operations are undefined in the knowledge base, i.e., results of ‘0+0’, ‘0+1’ and ‘1+1’ could be ‘0’, ‘1’, ‘00’, ‘01’ or even ‘10’. The missing calculation rules are required to be learned from the data. Therefore, ``HedKB`` incorporates methods for abducing rules from data. Users interested can refer to the specific implementation of ``HedKB`` in ``examples/hed/reasoning/reasoning.py`` Then, we create a reasoner. Due to the indeterminism of abductive reasoning, there could be multiple candidates compatible with the knowledge base. When this happens, reasoner can minimize inconsistencies between the knowledge base and pseudo-labels predicted by the learning part, and then return only one candidate that has the highest consistency. In this task, we create the reasoner by instantiating the class ``HedReasoner``, which is a reasoner derived from ``Reasoner`` and tailored specifically for this task. ``HedReasoner`` leverages `ZOOpt library `__ for acceleration, and has designed a specific strategy to better harness ZOOpt’s capabilities. Additionally, methods for abducing rules from data have been incorporated. Users interested can refer to the specific implementation of ``HedReasoner`` in ``reasoning/reasoning.py``. .. code:: python reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=10) Building Evaluation Metrics --------------------------- Next, we set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing. Specifically, we use ``SymbolAccuracy`` and ``ReasoningMetric``, which are used to evaluate the accuracy of the machine learning model’s predictions and the accuracy of the final reasoning results, respectively. .. code:: python # Set up metrics metric_list = [SymbolAccuracy(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")] Bridging Learning and Reasoning ------------------------------- Now, the last step is to bridge the learning and reasoning part. We proceed with this step by creating an instance of ``HedBridge``, which is derived from ``SimpleBridge`` and tailored specific for this task. .. code:: python bridge = HedBridge(model, reasoner, metric_list) Perform pretraining, training and testing by invoking the ``pretrain``, ``train`` and ``test`` methods of ``HedBridge``. .. code:: python # Build logger print_log("Abductive Learning on the HED example.", logger="current") # Retrieve the directory of the Log file and define the directory for saving the model weights. log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") bridge.pretrain("./weights") bridge.train(train_data, val_data, save_dir=weights_dir) bridge.test(test_data)