MNIST Addition
==============

This example shows an implementation of `MNIST
Addition <https://arxiv.org/abs/1805.10872>`__. In this task, pairs of
MNIST handwritten images and their sums are given, along with a domain
knowledge base containing information on how to perform addition
operations. The task is to recognize the digits in the handwritten images
and accurately determine their sum.

Intuitively, we first use a machine learning model (learning part) to
convert the input images to digits (we call them pseudo-labels), and
then use the knowledge base (reasoning part) to calculate the sum of
these digits. Since the ground-truth digits are not available, in
Abductive Learning the reasoning part leverages domain knowledge to
revise the initial digits yielded by the learning part through
abductive reasoning. This process enables us to further update the
machine learning model with the revised digits.

.. code:: ipython3

   # Import necessary libraries and modules
   import os.path as osp
   import torch
   import torch.nn as nn
   import matplotlib.pyplot as plt
   from examples.mnist_add.datasets import get_dataset
   from examples.models.nn import LeNet5
   from abl.learning import ABLModel, BasicNN
   from abl.reasoning import KBBase, Reasoner
   from abl.evaluation import ReasoningMetric, SymbolMetric
   from abl.utils import ABLLogger, print_log
   from abl.bridge import SimpleBridge

Working with Data
-----------------

First, we get the training and testing datasets:

.. code:: ipython3

   train_data = get_dataset(train=True, get_pseudo_label=True)
   test_data = get_dataset(train=False, get_pseudo_label=True)

``train_data`` and ``test_data`` share identical structures: each is a
tuple with three components: X (a list where each element is a list of
two images), gt_pseudo_label (a list where each element is a list of two
digits, i.e., pseudo-labels) and Y (a list where each element is the sum
of the two digits). The lengths and structures of the datasets are
illustrated as follows.

.. note::

   ``gt_pseudo_label`` is only used to evaluate the performance of
   the learning part, not to train the model.
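
To make this layout concrete without downloading MNIST, the sketch below builds a tiny synthetic dataset with the same three-component structure and checks that the *i*-th Y equals the sum of the *i*-th pair of ground-truth pseudo-labels. It is an illustration under assumptions: ``make_toy_dataset`` is a hypothetical helper (not part of ABL-Package), and each "image" is a placeholder string standing in for a tensor.

```python
import random


def make_toy_dataset(num_examples, seed=0):
    """Hypothetical helper: build (X, gt_pseudo_label, Y) shaped like the MNIST Addition data.

    Each "image" is a placeholder string instead of a 1x28x28 tensor.
    """
    rng = random.Random(seed)
    X, gt_pseudo_label, Y = [], [], []
    for _ in range(num_examples):
        digits = [rng.randint(0, 9), rng.randint(0, 9)]
        X.append([f"<image of {d}>" for d in digits])  # a list of two "images"
        gt_pseudo_label.append(digits)                  # a list of two digits
        Y.append(sum(digits))                           # their sum
    return X, gt_pseudo_label, Y


X, gt_pseudo_label, Y = make_toy_dataset(5)

# The i-th elements of X, gt_pseudo_label and Y form the i-th data example.
for x, labels, y in zip(X, gt_pseudo_label, Y):
    assert len(x) == len(labels) == 2
    assert sum(labels) == y
    print(x, labels, y)
```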

.. code:: ipython3

   print("Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y")
   print()
   train_X, train_gt_pseudo_label, train_Y = train_data
   print(f"Length of X, gt_pseudo_label, Y in train_data: " +
         f"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}")
   test_X, test_gt_pseudo_label, test_Y = test_data
   print(f"Length of X, gt_pseudo_label, Y in test_data: " +
         f"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}")
   print()

   X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
   print(f"X is a {type(train_X).__name__}, " +
         f"with each element being a {type(X_0).__name__} " +
         f"of {len(X_0)} {type(X_0[0]).__name__}.")
   print(f"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, " +
         f"with each element being a {type(gt_pseudo_label_0).__name__} " +
         f"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.")
   print(f"Y is a {type(train_Y).__name__}, " +
         f"with each element being a {type(Y_0).__name__}.")


Out:
   .. code:: none
      :class: code-out

      Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y

      Length of X, gt_pseudo_label, Y in train_data: 30000, 30000, 30000
      Length of X, gt_pseudo_label, Y in test_data: 5000, 5000, 5000

      X is a list, with each element being a list of 2 Tensor.
      gt_pseudo_label is a list, with each element being a list of 2 int.
      Y is a list, with each element being a int.


The *i*-th elements of X, gt_pseudo_label, and Y together constitute the
*i*-th data example. As an illustration, in the first data example of the
training set, we have:

.. code:: ipython3

   X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
   print("X in the first data example (a list of two images):")
   plt.subplot(1, 2, 1)
   plt.axis('off')
   # Convert the (1, 28, 28) CHW tensor to HWC layout for plotting
   plt.imshow(X_0[0].numpy().transpose(1, 2, 0))
   plt.subplot(1, 2, 2)
   plt.axis('off')
   plt.imshow(X_0[1].numpy().transpose(1, 2, 0))
   plt.show()
   print(f"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}")
   print(f"Y in the first data example (their sum result): {Y_0}")


Out:
   .. code:: none
      :class: code-out

      X in the first data example (a list of two images):

   .. image:: ../img/mnist_add_datasets.png
      :width: 400px


   .. parsed-literal::

      gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): [7, 5]
      Y in the first data example (their sum result): 12


Building the Learning Part
--------------------------

To build the learning part, we first need to build a machine learning
base model. We use a simple `LeNet-5 neural
network <https://en.wikipedia.org/wiki/LeNet>`__ and encapsulate it
within a ``BasicNN`` object to create the base model. ``BasicNN`` is a
class that wraps a PyTorch model, turning it into a base model with an
sklearn-style interface.

.. code:: ipython3

   # LeNet-5 with 10 output classes, one per digit
   cls = LeNet5(num_classes=10)
   loss_fn = nn.CrossEntropyLoss()
   optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
   # Use the first GPU if available, otherwise fall back to the CPU
   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

   # Wrap the PyTorch model into a base model with an sklearn-style interface
   base_model = BasicNN(
       cls,
       loss_fn,
       optimizer,
       device,
       batch_size=32,
       num_epochs=1,
   )
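
The ``batch_size`` and ``num_epochs`` arguments determine how many gradient steps one training call takes. As a rough sketch, assuming a standard shuffle-and-slice minibatch loop (``iter_minibatches`` below is a hypothetical helper, not ``BasicNN``'s source), 200 instances (100 examples of two images each) at ``batch_size=32`` give 7 batches per epoch, the last holding 8 instances.

```python
import math
import random


def iter_minibatches(instances, batch_size, num_epochs, seed=0):
    """Hypothetical minibatch loop: reshuffle once per epoch, then slice into batches."""
    rng = random.Random(seed)
    indices = list(range(len(instances)))
    for epoch in range(num_epochs):
        rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            yield epoch, [instances[i] for i in indices[start:start + batch_size]]


instances = list(range(200))  # e.g. 100 examples x 2 images each
batches = list(iter_minibatches(instances, batch_size=32, num_epochs=1))
print(len(batches), math.ceil(len(instances) / 32))  # 7 7
print([len(b) for _, b in batches])  # [32, 32, 32, 32, 32, 32, 8]
```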

``BasicNN`` offers methods such as ``predict`` and ``predict_proba``, which
predict the class index and the probability of each class, respectively,
for a list of images:

.. code:: ipython3

   # 32 random single-channel 28x28 inputs standing in for MNIST images
   data_instances = [torch.randn(1, 28, 28).to(device) for _ in range(32)]
   pred_idx = base_model.predict(X=data_instances)
   print(f"Predicted class index for a batch of 32 instances: np.ndarray with shape {pred_idx.shape}")
   pred_prob = base_model.predict_proba(X=data_instances)
   print(f"Predicted class probabilities for a batch of 32 instances: np.ndarray with shape {pred_prob.shape}")


Out:
   .. code:: none
      :class: code-out

      Predicted class index for a batch of 32 instances: np.ndarray with shape (32,)
      Predicted class probabilities for a batch of 32 instances: np.ndarray with shape (32, 10)
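
The two outputs are related: for a classifier like this one, the predicted index is normally the position of the largest predicted probability. A minimal pure-Python sketch, assuming the probabilities come from a softmax over the network's raw scores (logits); the function names and logit values are made up for illustration.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1 (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_index(probs):
    """Return the class index with the highest probability (argmax)."""
    return max(range(len(probs)), key=probs.__getitem__)


logits = [0.1, 2.0, -1.0, 0.5, 0.0, 0.3, -0.2, 1.5, 0.9, -0.5]  # one instance, 10 classes
probs = softmax(logits)
print(len(probs), round(sum(probs), 6), predict_index(probs))  # 10 1.0 1
```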


However, the base model built above deals with instance-level data
(i.e., individual images) and cannot directly deal with example-level
data (i.e., a pair of images). Therefore, we wrap the base model into
``ABLModel``, which enables the learning part to train, test, and
predict on example-level data.

.. code:: ipython3

   model = ABLModel(base_model)
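
Conceptually, this wrapping amounts to flattening the data examples into one list of instances, running the instance-level model on that list, and regrouping the per-instance outputs by example. The sketch below shows only that bookkeeping, under assumptions: ``flatten``, ``regroup`` and ``predict_examples`` are hypothetical helpers rather than ``ABLModel``'s actual code, and a dummy function stands in for ``base_model.predict``.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def flatten(examples: Sequence[Sequence[T]]) -> List[T]:
    """Concatenate all instances of all examples into one flat list."""
    return [inst for example in examples for inst in example]


def regroup(flat: Sequence[U], lengths: Sequence[int]) -> List[List[U]]:
    """Split a flat list back into consecutive chunks of the given lengths."""
    out, start = [], 0
    for n in lengths:
        out.append(list(flat[start:start + n]))
        start += n
    return out


def predict_examples(examples, instance_predict: Callable[[List[T]], List[U]]):
    """Run an instance-level predictor on example-level data."""
    lengths = [len(example) for example in examples]
    flat_pred = instance_predict(flatten(examples))
    return regroup(flat_pred, lengths)


# Dummy instance-level "model": each "image" here already spells its digit.
examples = [["7", "5"], ["3", "9"], ["0", "1"]]
print(predict_examples(examples, lambda xs: [int(x) for x in xs]))
# [[7, 5], [3, 9], [0, 1]]
```

Batching all instances into a single call lets the base model fill full minibatches instead of being invoked once per pair.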

As an illustration, consider making predictions on example-level data
with the ``predict`` method of ``ABLModel``. The method accepts data
examples as input and outputs the class labels and the probabilities of
each class for all instances within these data examples.

.. code:: ipython3

   from abl.structures import ListData
   # ListData is a data structure provided by ABL-Package that can be used to organize data examples
   data_examples = ListData()
   # We use the first 100 data examples in the training set as an illustration
   data_examples.X = train_X[:100]
   data_examples.gt_pseudo_label = train_gt_pseudo_label[:100]
   data_examples.Y = train_Y[:100]

   # Perform prediction on the 100 data examples (call predict once and reuse the result)
   pred = model.predict(data_examples)
   pred_label, pred_prob = pred["label"], pred["prob"]
   print(f"Predicted class labels for the 100 data examples: \n" +
         f"a list of length {len(pred_label)}, and each element is " +
         f"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\n")
   print(f"Predicted class probabilities for the 100 data examples: \n" +
         f"a list of length {len(pred_prob)}, and each element is " +
         f"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.")


Out:
   .. code:: none
      :class: code-out

      Predicted class labels for the 100 data examples:
      a list of length 100, and each element is a ndarray of shape (2,).

      Predicted class probabilities for the 100 data examples:
      a list of length 100, and each element is a ndarray of shape (2, 10).


Building the Reasoning Part
---------------------------

In the reasoning part, we first build a knowledge base that contains
information on how to perform addition operations. We build it by
creating a subclass of ``KBBase``. In the derived subclass, we
initialize the ``pseudo_label_list`` parameter, which specifies the list
of possible pseudo-labels, and override the ``logic_forward`` function,
which defines how to perform (deductive) reasoning.

.. code:: ipython3

   class AddKB(KBBase):
       def __init__(self, pseudo_label_list=list(range(10))):
           super().__init__(pseudo_label_list)

       # Implement the deduction function
       def logic_forward(self, nums):
           return sum(nums)

   kb = AddKB()

The knowledge base can perform logical reasoning (both deductive
reasoning and abductive reasoning). Below is an example of performing
(deductive) reasoning; refer to :ref:`Performing abductive
reasoning in the knowledge base <kb-abd>` for details of abductive reasoning.

.. code:: ipython3

   pseudo_label_example = [1, 2]
   reasoning_result = kb.logic_forward(pseudo_label_example)
   print(f"Reasoning result of pseudo-label example {pseudo_label_example} is {reasoning_result}.")


Out:
   .. code:: none
      :class: code-out

      Reasoning result of pseudo-label example [1, 2] is 3.
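
Abductive reasoning runs in the opposite direction: given a reasoning result such as ``Y = 12``, it looks for pseudo-labels whose deduction yields that result. A brute-force sketch of the idea in plain Python, assuming two pseudo-labels per example; ``abduce_candidates`` is a hypothetical helper, not the ``KBBase`` API (the section linked above describes the package's own interface).

```python
from itertools import product


def logic_forward(nums):
    """Deduction for MNIST Addition: the sum of the pseudo-labels."""
    return sum(nums)


def abduce_candidates(y, pseudo_label_list=range(10), length=2):
    """Return every pseudo-label list of the given length whose deduction equals y."""
    return [
        list(candidate)
        for candidate in product(pseudo_label_list, repeat=length)
        if logic_forward(candidate) == y
    ]


print(abduce_candidates(12))
# [[3, 9], [4, 8], [5, 7], [6, 6], [7, 5], [8, 4], [9, 3]]
```

Several candidates usually remain, which is why a reasoner must choose among them.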


.. note::

   Besides subclassing ``KBBase``, we can also build a knowledge base
   from a ground KB using ``GroundKB``, or from Prolog files using
   ``PrologKB``. The corresponding code for these implementations can be
   found in ``examples/mnist_add/main.py``. Those interested are encouraged to
   examine it for further insights.
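
The idea behind a ground KB, as we understand it, is to enumerate the knowledge base ahead of time so that abduction becomes a table lookup rather than a search. A plain-Python sketch of that idea for two-digit addition; the table layout and ``build_ground_table`` are illustrative assumptions, not ``GroundKB``'s internals.

```python
from collections import defaultdict
from itertools import product


def build_ground_table(pseudo_label_list=range(10), length=2):
    """Precompute: result -> all pseudo-label lists that produce that result."""
    table = defaultdict(list)
    for candidate in product(pseudo_label_list, repeat=length):
        table[sum(candidate)].append(list(candidate))
    return dict(table)


ground_table = build_ground_table()
print(len(ground_table))   # 19 possible sums: 0..18
print(ground_table[1])     # [[0, 1], [1, 0]]
```

The trade-off is memory for speed: the table holds every combination (100 here, and 10 to the power of the number of digits in general), so this approach suits small label spaces.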

Then, we create a reasoner by instantiating the class ``Reasoner``. Since
abduction may have more than one answer, there could be multiple
candidates consistent with the knowledge base (for example, both [7, 5]
and [3, 9] sum to 12). When this happens, the reasoner minimizes the
inconsistency between the knowledge base and the pseudo-labels predicted
by the learning part, and returns only the candidate with the highest
consistency.

.. code:: ipython3

   reasoner = Reasoner(kb)
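
One simple way to measure consistency is to count the positions in which a candidate differs from the predicted pseudo-labels (Hamming distance) and keep the closest candidate. The sketch below illustrates that idea only; it is not necessarily the measure ``Reasoner`` uses by default (the note below mentions a confidence-based one), and ``select_by_hamming`` is a hypothetical name.

```python
from itertools import product


def hamming(a, b):
    """Number of positions where two equal-length label lists differ."""
    return sum(x != y for x, y in zip(a, b))


def select_by_hamming(pred_pseudo_label, y, pseudo_label_list=range(10)):
    """Among pseudo-label lists that sum to y, return one closest to the prediction."""
    candidates = [
        list(c)
        for c in product(pseudo_label_list, repeat=len(pred_pseudo_label))
        if sum(c) == y
    ]
    # min() keeps the first candidate in enumeration order when distances tie
    return min(candidates, key=lambda c: hamming(c, pred_pseudo_label))


# The model read the pair as [7, 3], but the knowledge base says the sum is 12.
print(select_by_hamming([7, 3], 12))  # [7, 5]
```

Here [7, 5] and [9, 3] are both one position away from [7, 3]; ``min`` breaks the tie by enumeration order, which is one reason to prefer measures that also use the predicted probabilities.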

.. note::

   When creating the reasoner, the definition of “consistency” can be
   customized through the ``dist_func`` parameter. In the code above, we
   employ a confidence-based consistency measurement, which scores the
   consistency between the data example and each candidate using the
   confidence derived from the predicted probabilities. In
   ``examples/mnist_add/main.py``, we provide options for other forms of
   consistency measurement.

   Also, during inconsistency minimization, we can leverage the
   `ZOOpt library <https://github.com/polixir/ZOOpt>`__ for acceleration.
   Options for this are also available in ``examples/mnist_add/main.py``.
   Those interested are encouraged to explore these features.
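
A confidence-based measure can be sketched as follows: score each candidate by the joint probability the learning part assigns to it (here, the product of the per-position probabilities of the candidate's labels, computed as a sum of logs) and keep the highest-scoring one. This is an illustrative assumption about the general idea, not ``Reasoner``'s exact formula; ``select_by_confidence`` and ``make_prob_row`` are hypothetical, and the probability values are invented.

```python
import math
from itertools import product


def select_by_confidence(pred_prob, y, pseudo_label_list=range(10)):
    """Return the candidate summing to y with the highest joint predicted probability.

    pred_prob[i][k] is the predicted probability that position i has label k.
    Log-probabilities are summed instead of multiplying, to avoid underflow.
    """
    best, best_score = None, -math.inf
    for c in product(pseudo_label_list, repeat=len(pred_prob)):
        if sum(c) != y:
            continue
        # Clamp at 1e-12 so log() never receives 0
        score = sum(math.log(max(pred_prob[i][k], 1e-12)) for i, k in enumerate(c))
        if score > best_score:
            best, best_score = list(c), score
    return best


def make_prob_row(peaks):
    """Hypothetical helper: 10-class probabilities with given peaks, remainder spread evenly."""
    rest = (1.0 - sum(peaks.values())) / (10 - len(peaks))
    return [peaks.get(k, rest) for k in range(10)]


# First image: confidently 7. Second image: 3 or 5, leaning towards 3.
pred_prob = [make_prob_row({7: 0.91}), make_prob_row({3: 0.50, 5: 0.40})]
print(select_by_confidence(pred_prob, 12))  # [7, 5]
```

With these made-up probabilities, position-wise agreement with the top predictions [7, 3] would tie [7, 5] with [9, 3], while the joint probability clearly prefers [7, 5] (0.91 x 0.40 versus 0.01 x 0.50).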

Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolMetric`` and ``ReasoningMetric``, which
measure the accuracy of the machine learning model’s predictions and the
accuracy of the final reasoning results, respectively.

.. code:: ipython3

   metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]
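
Roughly, the two metrics compute the quantities sketched below (they appear as ``mnist_add/character_accuracy`` and ``mnist_add/reasoning_accuracy`` in the training log). The sketch rests on assumptions: the function names are hypothetical, symbol accuracy is taken as the fraction of correct digits over all examples, and an example counts as correct for reasoning accuracy when the deduction from its predicted pseudo-labels equals its Y.

```python
def symbol_accuracy(pred_pseudo_label, gt_pseudo_label):
    """Fraction of individual digits predicted correctly, over all examples."""
    pairs = [
        (p, g)
        for pred, gt in zip(pred_pseudo_label, gt_pseudo_label)
        for p, g in zip(pred, gt)
    ]
    return sum(p == g for p, g in pairs) / len(pairs)


def reasoning_accuracy(pred_pseudo_label, Y, logic_forward=sum):
    """Fraction of examples whose deduced result matches the given Y."""
    hits = [logic_forward(pred) == y for pred, y in zip(pred_pseudo_label, Y)]
    return sum(hits) / len(hits)


pred = [[7, 5], [3, 9], [1, 1], [2, 4]]
gt = [[7, 5], [3, 8], [1, 1], [4, 2]]
Y = [12, 11, 2, 6]
print(symbol_accuracy(pred, gt))     # 0.625
print(reasoning_accuracy(pred, Y))   # 0.75
```

The last example shows that a swapped pair can be wrong at the symbol level yet still yield the correct sum.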

Bridge Learning and Reasoning
-----------------------------

Now, the last step is to bridge the learning and reasoning parts. We
do this by creating an instance of ``SimpleBridge``.

.. code:: ipython3

   bridge = SimpleBridge(model, reasoner, metric_list)

We then perform training and testing by invoking the ``train`` and
``test`` methods of ``SimpleBridge``.

.. code:: ipython3

   # Build logger
   print_log("Abductive Learning on the MNIST Addition example.", logger="current")
   log_dir = ABLLogger.get_current_instance().log_dir
   weights_dir = osp.join(log_dir, "weights")

   # 5 loops, each split into 3 segments; save a checkpoint after every loop
   bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)
   # Evaluate the trained model on the test set
   bridge.test(test_data)

Out:
   .. code:: none
      :class: code-out

      abl - INFO - Abductive Learning on the MNIST Addition example.
      abl - INFO - loop(train) [1/5] segment(train) [1/3]
      abl - INFO - model loss: 1.49104
      abl - INFO - loop(train) [1/5] segment(train) [2/3]
      abl - INFO - model loss: 1.24945
      abl - INFO - loop(train) [1/5] segment(train) [3/3]
      abl - INFO - model loss: 0.87861
      abl - INFO - Evaluation start: loop(val) [1]
      abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.818 mnist_add/reasoning_accuracy: 0.672
      abl - INFO - Saving model: loop(save) [1]
      abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_1.pth
      abl - INFO - loop(train) [2/5] segment(train) [1/3]
      abl - INFO - model loss: 0.31148
      abl - INFO - loop(train) [2/5] segment(train) [2/3]
      abl - INFO - model loss: 0.09520
      abl - INFO - loop(train) [2/5] segment(train) [3/3]
      abl - INFO - model loss: 0.07402
      abl - INFO - Evaluation start: loop(val) [2]
      abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.982 mnist_add/reasoning_accuracy: 0.964
      abl - INFO - Saving model: loop(save) [2]
      abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_2.pth
      abl - INFO - loop(train) [3/5] segment(train) [1/3]
      abl - INFO - model loss: 0.06027
      abl - INFO - loop(train) [3/5] segment(train) [2/3]
      abl - INFO - model loss: 0.05341
      abl - INFO - loop(train) [3/5] segment(train) [3/3]
      abl - INFO - model loss: 0.04915
      abl - INFO - Evaluation start: loop(val) [3]
      abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.975
      abl - INFO - Saving model: loop(save) [3]
      abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_3.pth
      abl - INFO - loop(train) [4/5] segment(train) [1/3]
      abl - INFO - model loss: 0.04413
      abl - INFO - loop(train) [4/5] segment(train) [2/3]
      abl - INFO - model loss: 0.04181
      abl - INFO - loop(train) [4/5] segment(train) [3/3]
      abl - INFO - model loss: 0.04127
      abl - INFO - Evaluation start: loop(val) [4]
      abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.990 mnist_add/reasoning_accuracy: 0.980
      abl - INFO - Saving model: loop(save) [4]
      abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_4.pth
      abl - INFO - loop(train) [5/5] segment(train) [1/3]
      abl - INFO - model loss: 0.03544
      abl - INFO - loop(train) [5/5] segment(train) [2/3]
      abl - INFO - model loss: 0.03092
      abl - INFO - loop(train) [5/5] segment(train) [3/3]
      abl - INFO - model loss: 0.03663
      abl - INFO - Evaluation start: loop(val) [5]
      abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.982
      abl - INFO - Saving model: loop(save) [5]
      abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_5.pth
      abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.974
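
The log reflects the loop structure: each of the 5 loops walks through 3 segments (``segment_size=1/3``), and a checkpoint is saved after every loop (``save_interval=1``). Assuming a fractional ``segment_size`` is read as a fraction of the training set (our reading of the log, not a statement about ``SimpleBridge``'s implementation), the 30000 training examples would be processed in segments of 10000. The hypothetical helper ``segment_bounds`` below sketches such a split; rounding a fractional size up is our own choice.

```python
import math


def segment_bounds(num_examples, segment_size):
    """Split range(num_examples) into consecutive (start, end) segments.

    A float segment_size is treated as a fraction of the data and rounded up;
    an int is treated as an absolute number of examples.
    """
    if isinstance(segment_size, float):
        # The small tolerance guards against float error such as 0.1 * 30 = 3.0000000000000004
        segment_size = max(1, math.ceil(num_examples * segment_size - 1e-9))
    return [
        (start, min(start + segment_size, num_examples))
        for start in range(0, num_examples, segment_size)
    ]


print(segment_bounds(30000, 1 / 3))
# [(0, 10000), (10000, 20000), (20000, 30000)]
```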

More concrete examples are available in ``examples/mnist_add/main.py`` and ``examples/mnist_add/mnist_add.ipynb``.