MNIST Addition
==============

In this example, we show an implementation of `MNIST
Addition <https://arxiv.org/abs/1805.10872>`_. In this task, pairs of
MNIST handwritten images and their sums are given, along with a domain
knowledge base containing information on how to perform addition
operations. The task is to recognize the digits of the handwritten
images and accurately determine their sum.

Intuitively, we first use a machine learning model (learning part) to
convert the input images to digits (we call them pseudo labels), and
then use the knowledge base (reasoning part) to calculate the sum of
these digits. Since we do not have ground truth for the digits, in
abductive learning, the reasoning part will leverage domain knowledge
and revise the initial digits yielded by the learning part through
abductive reasoning. This process enables us to further update the
machine learning model.

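The overall pipeline can be summarized by the following loop
(illustrative pseudocode only; ``model``, ``reasoner``, and
``training_batches`` are placeholders for the objects constructed in
the rest of this example):

.. code:: python

    # Illustrative pseudocode of the abductive learning loop
    for X, Y in training_batches:                    # image pairs and their sums
        pseudo_labels = model.predict(X)             # learning part: images -> digits
        revised = reasoner.abduce(pseudo_labels, Y)  # reasoning part: revise digits via KB
        model.train(X, revised)                      # update the model on revised labels
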
.. code:: ipython3

    # Import necessary libraries and modules
    import os.path as osp

    import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn

    from abl.bridge import SimpleBridge
    from abl.evaluation import ReasoningMetric, SymbolMetric
    from abl.learning import ABLModel, BasicNN
    from abl.reasoning import KBBase, Reasoner
    from abl.utils import ABLLogger, print_log
    from examples.mnist_add.datasets import get_dataset
    from examples.models.nn import LeNet5

Working with Data
-----------------

First, we get the training and testing datasets:

.. code:: ipython3

    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

Both datasets contain several data examples. In each data example, we
have three components: X (a pair of images), gt_pseudo_label (a pair of
corresponding ground-truth digits, i.e., pseudo labels), and Y (their
sum). The datasets are illustrated as follows.

.. code:: ipython3

    print(f"There are {len(train_data[0])} data examples in the training set and {len(test_data[0])} data examples in the test set")
    print("As an illustration, in the first data example of the training set, we have:")
    print(f"X ({len(train_data[0][0])} images):")
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.imshow(train_data[0][0][0].numpy().transpose(1, 2, 0))
    plt.subplot(1, 2, 2)
    plt.axis('off')
    plt.imshow(train_data[0][0][1].numpy().transpose(1, 2, 0))
    plt.show()
    print(f"gt_pseudo_label ({len(train_data[1][0])} ground-truth pseudo labels): {train_data[1][0][0]}, {train_data[1][0][1]}")
    print(f"Y (their sum result): {train_data[2][0]}")

Out:

.. code:: none
    :class: code-out

    There are 30000 data examples in the training set and 5000 data examples in the test set
    As an illustration, in the first data example of the training set, we have:
    X (2 images):

.. image:: ../img/mnist_add_datasets.png
    :width: 400px

.. code:: none
    :class: code-out

    gt_pseudo_label (2 ground-truth pseudo labels): 7, 5
    Y (their sum result): 12

Building the Learning Part
--------------------------

To build the learning part, we need to first build a base machine
learning model. We use a simple `LeNet-5 neural
network <https://en.wikipedia.org/wiki/LeNet>`__ to complete this task,
and encapsulate it within a ``BasicNN`` object to create the base model.
``BasicNN`` is a class that encapsulates a PyTorch model, transforming
it into a base model with a scikit-learn-style interface.

.. code:: ipython3

    cls = LeNet5(num_classes=10)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        device,
        batch_size=32,
        num_epochs=1,
    )

``BasicNN`` offers methods like ``predict`` and ``predict_proba``, which
are used to predict the class index and the class probabilities for an
image, respectively, as shown below:

.. code:: ipython3

    pred_idx = base_model.predict(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])
    print(f"Shape of pred_idx for a batch of 32 examples: {pred_idx.shape}")
    pred_prob = base_model.predict_proba(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])
    print(f"Shape of pred_prob for a batch of 32 examples: {pred_prob.shape}")

Out:

.. code:: none
    :class: code-out

    Shape of pred_idx for a batch of 32 examples: (32,)
    Shape of pred_prob for a batch of 32 examples: (32, 10)

However, the base model built above deals with instance-level data
(i.e., a single image), and cannot directly deal with example-level
data (i.e., a pair of images). Therefore, we wrap the base model
into ``ABLModel``, which enables the learning part to train, test,
and predict on example-level data.

.. code:: ipython3

    model = ABLModel(base_model)

To illustrate the difference between the ``predict`` methods of
``ABLModel`` and the base model, note that ``ABLModel.predict`` takes
example-level data wrapped in a ``ListData`` structure, rather than a
list of individual images:

.. code:: ipython3

    from abl.structures import ListData

    # Construct three data examples, each consisting of a pair of images
    data_examples = ListData()
    data_examples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]

    # ABLModel.predict operates on example-level data
    model.predict(data_examples)

Building the Reasoning Part
---------------------------

In the reasoning part, we first build a knowledge base which contains
information on how to perform addition operations. We build it by
creating a subclass of ``KBBase``. In the derived subclass, we first
initialize the ``pseudo_label_list`` parameter specifying the list of
possible pseudo labels, and then override the ``logic_forward`` function
defining how to perform (deductive) reasoning.

.. code:: ipython3

    class AddKB(KBBase):
        def __init__(self, pseudo_label_list=list(range(10))):
            super().__init__(pseudo_label_list)

        # Implement the deduction function
        def logic_forward(self, nums):
            return sum(nums)

    kb = AddKB()

The knowledge base can perform logical reasoning. Below is an example of
performing (deductive) reasoning:

.. code:: ipython3

    pseudo_label_example = [1, 2]
    reasoning_result = kb.logic_forward(pseudo_label_example)
    print(f"Reasoning result of pseudo label example {pseudo_label_example} is {reasoning_result}.")

Out:

.. code:: none
    :class: code-out

    Reasoning result of pseudo label example [1, 2] is 3.

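The knowledge base also supports the reverse direction, abductive
reasoning: given an observed sum, it searches for pseudo labels that are
consistent with it. As a minimal sketch of the idea (using only
``logic_forward`` by brute-force enumeration, rather than the toolkit's
built-in abduction interface, and assuming ``KBBase`` stores the label
list as the ``pseudo_label_list`` attribute):

.. code:: ipython3

    from itertools import product

    # Suppose the observed sum is 8; enumerate all pairs of pseudo labels
    # whose deduced sum matches the observation.
    candidates = [
        list(nums)
        for nums in product(kb.pseudo_label_list, repeat=2)
        if kb.logic_forward(list(nums)) == 8
    ]
    print(candidates)  # [[0, 8], [1, 7], [2, 6], ..., [8, 0]]
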
.. note::

    In addition to building a knowledge base based on ``KBBase``, we
    can also establish a knowledge base with a ground KB using ``GroundKB``,
    or a knowledge base implemented based on Prolog files using
    ``PrologKB``. The corresponding code for these implementations can be
    found in the ``examples/mnist_add/main.py`` file. Those interested are
    encouraged to examine it for further insights.

Then, we create a reasoner by instantiating the class ``Reasoner``. Due
to the indeterminism of abductive reasoning, there could be multiple
candidates compatible with the knowledge base. When this happens, the
reasoner can minimize the inconsistency between the knowledge base and
the pseudo labels predicted by the learning part, and then return only
the candidate with the highest consistency.

.. code:: ipython3

    reasoner = Reasoner(kb)

.. note::

    When creating the reasoner, the definition of “consistency” can be
    customized via the ``dist_func`` parameter. In the code above, we
    employ a consistency measurement based on confidence, which calculates
    the consistency between a data example and candidates based on the
    confidence derived from the predicted probability. In
    ``examples/mnist_add/main.py``, we provide options for utilizing other
    forms of consistency measurement.

    Also, during the process of inconsistency minimization, one can leverage
    the `ZOOpt library <https://github.com/polixir/ZOOpt>`__ for acceleration.
    Options for this are also available in ``examples/mnist_add/main.py``.
    Those interested are encouraged to explore these features.

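For instance, a different measurement might be selected as follows. This
is only a sketch: the option name ``"hamming"`` is an assumption here,
and the option names actually supported are listed in
``examples/mnist_add/main.py``.

.. code:: ipython3

    # Hypothetical: measure consistency by the Hamming distance between
    # predicted and candidate pseudo labels, instead of by confidence.
    reasoner_hamming = Reasoner(kb, dist_func="hamming")
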
Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolMetric`` and ``ReasoningMetric``, which are
used to evaluate the accuracy of the machine learning model’s
predictions and the accuracy of the final reasoning results,
respectively.

.. code:: ipython3

    metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

Bridge Learning and Reasoning
-----------------------------

Now, the last step is to bridge the learning and reasoning parts. We do
this by creating an instance of ``SimpleBridge``.

.. code:: ipython3

    bridge = SimpleBridge(model, reasoner, metric_list)

Perform training and testing by invoking the ``train`` and ``test``
methods of ``SimpleBridge``.

.. code:: ipython3

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")

    # Retrieve the log directory and define the directory for saving weights
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)
    bridge.test(test_data)

Out:

.. code:: none
    :class: code-out

    abl - INFO - Abductive Learning on the MNIST Addition example.
    abl - INFO - loop(train) [1/5] segment(train) [1/3]
    abl - INFO - model loss: 1.81231
    abl - INFO - loop(train) [1/5] segment(train) [2/3]
    abl - INFO - model loss: 1.37639
    abl - INFO - loop(train) [1/5] segment(train) [3/3]
    abl - INFO - model loss: 1.14446
    abl - INFO - Evaluation start: loop(val) [1]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.207 mnist_add/reasoning_accuracy: 0.245
    abl - INFO - Saving model: loop(save) [1]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_1.pth
    abl - INFO - loop(train) [2/5] segment(train) [1/3]
    abl - INFO - model loss: 0.97430
    abl - INFO - loop(train) [2/5] segment(train) [2/3]
    abl - INFO - model loss: 0.91448
    abl - INFO - loop(train) [2/5] segment(train) [3/3]
    abl - INFO - model loss: 0.83089
    abl - INFO - Evaluation start: loop(val) [2]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.191 mnist_add/reasoning_accuracy: 0.353
    abl - INFO - Saving model: loop(save) [2]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_2.pth
    abl - INFO - loop(train) [3/5] segment(train) [1/3]
    abl - INFO - model loss: 0.79906
    abl - INFO - loop(train) [3/5] segment(train) [2/3]
    abl - INFO - model loss: 0.77949
    abl - INFO - loop(train) [3/5] segment(train) [3/3]
    abl - INFO - model loss: 0.75007
    abl - INFO - Evaluation start: loop(val) [3]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.148 mnist_add/reasoning_accuracy: 0.385
    abl - INFO - Saving model: loop(save) [3]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_3.pth
    abl - INFO - loop(train) [4/5] segment(train) [1/3]
    abl - INFO - model loss: 0.72659
    abl - INFO - loop(train) [4/5] segment(train) [2/3]
    abl - INFO - model loss: 0.70985
    abl - INFO - loop(train) [4/5] segment(train) [3/3]
    abl - INFO - model loss: 0.66337
    abl - INFO - Evaluation start: loop(val) [4]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.016 mnist_add/reasoning_accuracy: 0.494
    abl - INFO - Saving model: loop(save) [4]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_4.pth
    abl - INFO - loop(train) [5/5] segment(train) [1/3]
    abl - INFO - model loss: 0.61140
    abl - INFO - loop(train) [5/5] segment(train) [2/3]
    abl - INFO - model loss: 0.57534
    abl - INFO - loop(train) [5/5] segment(train) [3/3]
    abl - INFO - model loss: 0.57018
    abl - INFO - Evaluation start: loop(val) [5]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.002 mnist_add/reasoning_accuracy: 0.507
    abl - INFO - Saving model: loop(save) [5]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_5.pth
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.002 mnist_add/reasoning_accuracy: 0.482

More concrete examples are available in ``examples/mnist_add/main.py`` and ``examples/mnist_add/mnist_add.ipynb``.
