MNIST Addition
==============

.. raw:: html

    <p>For detailed code implementation, please view it on <a class="reference external" href="https://github.com/AbductiveLearning/ABLkit/tree/main/examples/mnist_add" target="_blank">GitHub</a>.</p>

Below is an implementation of `MNIST
Addition <https://arxiv.org/abs/1805.10872>`__. In this task, pairs of
MNIST handwritten images and their sums are given, along with a domain
knowledge base containing information on how to perform addition
operations. The task is to recognize the digits in the handwritten images
and accurately determine their sum.

Intuitively, we first use a machine learning model (the learning part) to
convert the input images to digits (we call them pseudo-labels), and
then use the knowledge base (the reasoning part) to calculate the sum of
these digits. Since we do not have ground-truth labels for the digits, in
Abductive Learning the reasoning part leverages domain knowledge
to revise the initial digits yielded by the learning part through
abductive reasoning. The revised digits are then used to update the
machine learning model.

.. code:: python

    # Import necessary libraries and modules
    import os.path as osp

    import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn
    from torch.optim import RMSprop, lr_scheduler

    from ablkit.bridge import SimpleBridge
    from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy
    from ablkit.learning import ABLModel, BasicNN
    from ablkit.reasoning import KBBase, Reasoner
    from ablkit.utils import ABLLogger, print_log

    from datasets import get_dataset
    from models.nn import LeNet5

Working with Data
-----------------

First, we get the training and testing datasets:

.. code:: python

    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

``train_data`` and ``test_data`` share the same structure: each is a
tuple with three components: ``X`` (a list where each element is a
list of two images), ``gt_pseudo_label`` (a list where each element
is a list of two digits, i.e., pseudo-labels) and ``Y`` (a list where
each element is the sum of the two digits). The lengths and structures
of the datasets are illustrated as follows.

.. note::

    ``gt_pseudo_label`` is only used to evaluate the performance of
    the learning part but not to train the model.

.. code:: python

    print("Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y")
    print("\n")
    train_X, train_gt_pseudo_label, train_Y = train_data
    print("Length of X, gt_pseudo_label, Y in train_data: " +
          f"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}")
    test_X, test_gt_pseudo_label, test_Y = test_data
    print("Length of X, gt_pseudo_label, Y in test_data: " +
          f"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}")
    print("\n")

    X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
    print(f"X is a {type(train_X).__name__}, " +
          f"with each element being a {type(X_0).__name__} " +
          f"of {len(X_0)} {type(X_0[0]).__name__}.")
    print(f"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, " +
          f"with each element being a {type(gt_pseudo_label_0).__name__} " +
          f"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.")
    print(f"Y is a {type(train_Y).__name__}, " +
          f"with each element being an {type(Y_0).__name__}.")

Out:
    .. code:: none
        :class: code-out

        Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y
        Length of X, gt_pseudo_label, Y in train_data: 30000, 30000, 30000
        Length of X, gt_pseudo_label, Y in test_data: 5000, 5000, 5000
        X is a list, with each element being a list of 2 Tensor.
        gt_pseudo_label is a list, with each element being a list of 2 int.
        Y is a list, with each element being an int.

The ith element of X, gt_pseudo_label, and Y together constitute the ith
data example. As an illustration, in the first data example of the
training set, we have:

.. code:: python

    X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
    print("X in the first data example (a list of two images):")
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.imshow(X_0[0].squeeze(), cmap='gray')
    plt.subplot(1, 2, 2)
    plt.axis('off')
    plt.imshow(X_0[1].squeeze(), cmap='gray')
    plt.show()
    print(f"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}")
    print(f"Y in the first data example (their sum result): {Y_0}")

Out:
    .. code:: none
        :class: code-out

        X in the first data example (a list of two images):

    .. image:: ../_static/img/mnist_add_datasets.png
        :width: 200px

    .. code:: none
        :class: code-out

        gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): [7, 5]
        Y in the first data example (their sum result): 12

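
As a rough illustration of this three-component layout, the sketch below
assembles a tiny dataset with the same structure from labelled single-digit
images. The helper ``make_addition_examples`` and the string placeholders
standing in for image tensors are hypothetical; they are not part of ABLkit
and do not describe how ``get_dataset`` is implemented.

.. code:: python

    import random

    def make_addition_examples(images, labels, num_examples, seed=0):
        """Pair up labelled digit images into (X, gt_pseudo_label, Y).

        Hypothetical helper for illustration only. ``images`` may hold any
        objects (tensors in the real task, placeholders here).
        """
        rng = random.Random(seed)
        X, gt_pseudo_label, Y = [], [], []
        for _ in range(num_examples):
            # Draw two image indices independently
            i, j = rng.randrange(len(images)), rng.randrange(len(images))
            X.append([images[i], images[j]])
            gt_pseudo_label.append([labels[i], labels[j]])
            Y.append(labels[i] + labels[j])
        return X, gt_pseudo_label, Y

    # Placeholder "images": one string per digit class
    images = [f"img_of_{d}" for d in range(10)]
    labels = list(range(10))
    toy_X, toy_gt, toy_Y = make_addition_examples(images, labels, num_examples=5)

Each entry of ``toy_Y`` is the sum of the corresponding pair in ``toy_gt``,
mirroring the relation between ``Y`` and ``gt_pseudo_label`` described above.
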
Building the Learning Part
--------------------------

To build the learning part, we need to first build a machine learning
base model. We use a simple `LeNet-5 neural
network <https://en.wikipedia.org/wiki/LeNet>`__, and encapsulate it
within a ``BasicNN`` object to create the base model. ``BasicNN`` is a
class that encapsulates a PyTorch model, transforming it into a base
model with a sklearn-style interface.

.. code:: python

    cls = LeNet5(num_classes=10)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100)

    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        scheduler=scheduler,
        device=device,
        batch_size=32,
        num_epochs=1,
    )

``BasicNN`` offers methods like ``predict`` and ``predict_proba``, which
predict the class index and the probabilities of each class
for images, respectively, as shown below:

.. code:: python

    data_instances = [torch.randn(1, 28, 28) for _ in range(32)]
    pred_idx = base_model.predict(X=data_instances)
    print(f"Predicted class index for a batch of 32 instances: np.ndarray with shape {pred_idx.shape}")
    pred_prob = base_model.predict_proba(X=data_instances)
    print(f"Predicted class probabilities for a batch of 32 instances: np.ndarray with shape {pred_prob.shape}")

Out:
    .. code:: none
        :class: code-out

        Predicted class index for a batch of 32 instances: np.ndarray with shape (32,)
        Predicted class probabilities for a batch of 32 instances: np.ndarray with shape (32, 10)

However, the base model built above deals with instance-level data
(i.e., individual images), and cannot directly deal with example-level
data (i.e., a pair of images). Therefore, we wrap the base model into
``ABLModel``, which enables the learning part to train, test, and
predict on example-level data.

.. code:: python

    model = ABLModel(base_model)

As an illustration, consider this example of predicting on example-level
data using the ``predict`` method in ``ABLModel``. In this process, the
method accepts data examples as input and outputs the class labels and
the probabilities of each class for all instances within these data
examples.

.. code:: python

    from ablkit.data.structures import ListData

    # ListData is a data structure provided by ABLkit that can be used to organize data examples
    data_examples = ListData()
    # We use the first 100 data examples in the training set as an illustration
    data_examples.X = train_X[:100]
    data_examples.gt_pseudo_label = train_gt_pseudo_label[:100]
    data_examples.Y = train_Y[:100]

    # Perform prediction on the 100 data examples (call predict once and reuse the result)
    pred = model.predict(data_examples)
    pred_label, pred_prob = pred['label'], pred['prob']
    print("Predicted class labels for the 100 data examples: \n" +
          f"a list of length {len(pred_label)}, and each element is " +
          f"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\n")
    print("Predicted class probabilities for the 100 data examples: \n" +
          f"a list of length {len(pred_prob)}, and each element is " +
          f"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.")

Out:
    .. code:: none
        :class: code-out

        Predicted class labels for the 100 data examples:
        a list of length 100, and each element is a ndarray of shape (2,).
        Predicted class probabilities for the 100 data examples:
        a list of length 100, and each element is a ndarray of shape (2, 10).

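
Conceptually, moving between the two levels amounts to flattening examples
into instances before calling the base model and regrouping the outputs
afterwards. The sketch below illustrates that idea in plain Python. The names
``flatten_examples``, ``regroup`` and the stand-in ``toy_predict`` are
hypothetical, and this is not a description of how ``ABLModel`` is actually
implemented.

.. code:: python

    def flatten_examples(examples):
        """Turn a list of examples (each a list of instances) into one flat list."""
        return [instance for example in examples for instance in example]

    def regroup(flat_outputs, examples):
        """Split flat per-instance outputs back into per-example lists."""
        grouped, start = [], 0
        for example in examples:
            grouped.append(flat_outputs[start:start + len(example)])
            start += len(example)
        return grouped

    def toy_predict(instances):
        """Stand-in for an instance-level model: reads the digit off a string."""
        return [int(s[-1]) for s in instances]

    examples = [["img_7", "img_5"], ["img_1", "img_2"]]
    flat = flatten_examples(examples)              # 4 instances
    labels = regroup(toy_predict(flat), examples)  # back to 2 examples
    print(labels)  # [[7, 5], [1, 2]]

Regrouping by the length of each example, rather than by a fixed pair size,
keeps the sketch usable for tasks whose examples contain a varying number of
instances.
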
Building the Reasoning Part
---------------------------

In the reasoning part, we first build a knowledge base which contains
information on how to perform addition operations. We build it by
creating a subclass of ``KBBase``. In the derived subclass, we
initialize the ``pseudo_label_list`` parameter, which specifies the list of
possible pseudo-labels, and override the ``logic_forward`` function,
which defines how to perform (deductive) reasoning.

.. code:: python

    class AddKB(KBBase):
        def __init__(self, pseudo_label_list=list(range(10))):
            super().__init__(pseudo_label_list)

        # Implement the deduction function
        def logic_forward(self, nums):
            return sum(nums)

    kb = AddKB()

The knowledge base can perform logical reasoning (both deductive
reasoning and abductive reasoning). Below is an example of performing
(deductive) reasoning, and users can refer to :ref:`Performing abductive
reasoning in the knowledge base <kb-abd>` for details of abductive reasoning.

.. code:: python

    pseudo_labels = [1, 2]
    reasoning_result = kb.logic_forward(pseudo_labels)
    print(f"Reasoning result of pseudo-labels {pseudo_labels} is {reasoning_result}.")

Out:
    .. code:: none
        :class: code-out

        Reasoning result of pseudo-labels [1, 2] is 3.

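
For contrast with deduction, abduction runs the same rule backwards: given an
observed sum, it looks for pseudo-label pairs that would produce it. The
brute-force sketch below only illustrates the idea. The function
``abduce_by_enumeration`` is hypothetical and is not ABLkit's abduction
routine.

.. code:: python

    from itertools import product

    def logic_forward(nums):
        # Same deduction rule as AddKB.logic_forward
        return sum(nums)

    def abduce_by_enumeration(target, pseudo_label_list=range(10), length=2):
        """Return every label sequence whose deduced result equals ``target``."""
        return [
            list(candidate)
            for candidate in product(pseudo_label_list, repeat=length)
            if logic_forward(candidate) == target
        ]

    candidates = abduce_by_enumeration(3)
    print(candidates)  # [[0, 3], [1, 2], [2, 1], [3, 0]]

A single observed sum is usually compatible with several pairs of digits,
which is why a further selection step is needed.
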
.. note::

    In addition to building a knowledge base based on ``KBBase``, we
    can also establish a knowledge base with a ground KB using ``GroundKB``,
    or a knowledge base implemented based on Prolog files using
    ``PrologKB``. The corresponding code for these implementations can be
    found in the ``main.py`` file. Those interested are encouraged to
    examine it for further insights.

Then, we create a reasoner by instantiating the class ``Reasoner``. Due
to the indeterminism of abductive reasoning, there could be multiple
candidates compatible with the knowledge base. When this happens, the reasoner
minimizes inconsistencies between the knowledge base and the
pseudo-labels predicted by the learning part, and returns only the
candidate that has the highest consistency.

.. code:: python

    reasoner = Reasoner(kb)

.. note::

    When creating the reasoner, the definition of “consistency” can be
    customized via the ``dist_func`` parameter. In the code above, we
    employ a consistency measurement based on confidence, which calculates
    the consistency between the data example and candidates based on the
    confidence derived from the predicted probability. In ``examples/mnist_add/main.py``, we
    provide options for utilizing other forms of consistency measurement.

    Also, during the process of inconsistency minimization, we can leverage the
    `ZOOpt library <https://github.com/polixir/ZOOpt>`__ for acceleration.
    Options for this are also available in ``examples/mnist_add/main.py``. Those interested are
    encouraged to explore these features.

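
To make the idea of a confidence-based consistency measure more concrete, the
sketch below scores each candidate by the product of the probabilities the
model assigns to its labels and keeps the best one. This is one plausible
reading of such a measure, written in plain Python with made-up
probabilities. The function ``select_by_confidence`` is hypothetical, and the
exact formula used by ``Reasoner`` may differ.

.. code:: python

    import math

    def select_by_confidence(candidates, pred_prob):
        """Pick the candidate whose labels the model finds most probable.

        pred_prob[i][c] is the predicted probability that instance i has label c.
        """
        def confidence(candidate):
            return math.prod(pred_prob[i][label] for i, label in enumerate(candidate))
        return max(candidates, key=confidence)

    # Made-up probabilities for two images: the model favours 7 for the first
    # image and is torn between 3 and 5 for the second.
    pred_prob = [
        [0.01] * 7 + [0.91, 0.01, 0.01],
        [0.02, 0.02, 0.02, 0.45, 0.02, 0.41, 0.02, 0.02, 0.01, 0.01],
    ]
    # All digit pairs that sum to 12
    candidates = [[3, 9], [4, 8], [5, 7], [6, 6], [7, 5], [8, 4], [9, 3]]
    best = select_by_confidence(candidates, pred_prob)
    print(best)  # [7, 5]

Here the model's own top guess would be ``[7, 3]``, which sums to 10 and so
contradicts the observed sum of 12. Among the consistent candidates,
``[7, 5]`` is the one the model considers most likely.
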
Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolAccuracy`` and ``ReasoningMetric``, which are
used to evaluate the accuracy of the machine learning model’s
predictions and the accuracy of the final reasoning results,
respectively.

.. code:: python

    metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

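
As a rough guide to what these two numbers mean, the sketch below computes a
symbol-level accuracy and a reasoning-level accuracy by hand on a toy batch.
The functions ``symbol_accuracy`` and ``reasoning_accuracy`` are hypothetical
illustrations and may differ in detail from ``SymbolAccuracy`` and
``ReasoningMetric``.

.. code:: python

    def symbol_accuracy(pred_labels, gt_labels):
        """Fraction of individual pseudo-labels predicted correctly."""
        pairs = [
            (p, g)
            for pred, gt in zip(pred_labels, gt_labels)
            for p, g in zip(pred, gt)
        ]
        return sum(p == g for p, g in pairs) / len(pairs)

    def reasoning_accuracy(pred_labels, targets, logic_forward=sum):
        """Fraction of examples whose deduced result matches the target."""
        correct = sum(logic_forward(pred) == y for pred, y in zip(pred_labels, targets))
        return correct / len(targets)

    pred = [[7, 5], [1, 3], [4, 4]]
    gt = [[7, 5], [1, 2], [3, 5]]
    Y = [12, 3, 8]
    print(symbol_accuracy(pred, gt))    # 0.5
    print(reasoning_accuracy(pred, Y))  # 0.6666666666666666

The third example shows why the two metrics can disagree: both digits are
misrecognized, yet their sum happens to match the target.
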
Bridging Learning and Reasoning
-------------------------------

Now, the last step is to bridge the learning and reasoning parts. We
proceed with this step by creating an instance of ``SimpleBridge``.

.. code:: python

    bridge = SimpleBridge(model, reasoner, metric_list)

Perform training and testing by invoking the ``train`` and ``test``
methods of ``SimpleBridge``.

.. code:: python

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    bridge.train(train_data, loops=1, segment_size=0.01, save_interval=1, save_dir=weights_dir)
    bridge.test(test_data)

The log will appear similar to the following:

Log:
    .. code:: none
        :class: code-out

        abl - INFO - Abductive Learning on the MNIST Addition example.
        abl - INFO - Working with Data.
        abl - INFO - Building the Learning Part.
        abl - INFO - Building the Reasoning Part.
        abl - INFO - Building Evaluation Metrics.
        abl - INFO - Bridge Learning and Reasoning.
        abl - INFO - loop(train) [1/2] segment(train) [1/100]
        abl - INFO - model loss: 2.25980
        abl - INFO - loop(train) [1/2] segment(train) [2/100]
        abl - INFO - model loss: 2.14168
        abl - INFO - loop(train) [1/2] segment(train) [3/100]
        abl - INFO - model loss: 2.02010
        ...
        abl - INFO - loop(train) [2/2] segment(train) [1/100]
        abl - INFO - model loss: 0.90260
        ...
        abl - INFO - Eval start: loop(val) [2]
        abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.993 mnist_add/reasoning_accuracy: 0.986
        abl - INFO - Test start:
        abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.980

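
The segment counter in the log can be reproduced with a little arithmetic,
assuming that a fractional ``segment_size`` is interpreted as a fraction of
the training set. That assumption is consistent with the ``[1/100]`` counters
above but is not confirmed here, and the helper ``segment_bounds`` below is
hypothetical.

.. code:: python

    import math

    def segment_bounds(num_examples, segment_size):
        """Return (start, end) index pairs, one per training segment."""
        # Assumption: a float is a fraction of the dataset,
        # an int is an absolute number of examples.
        if isinstance(segment_size, float):
            segment_size = max(1, round(num_examples * segment_size))
        num_segments = math.ceil(num_examples / segment_size)
        return [
            (k * segment_size, min((k + 1) * segment_size, num_examples))
            for k in range(num_segments)
        ]

    bounds = segment_bounds(30000, 0.01)
    print(len(bounds), bounds[0], bounds[-1])  # 100 (0, 300) (29700, 30000)

Under this reading, each of the 100 segments covers 300 of the 30000 training
examples.
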
Performance
-----------

We present the results of ABL as follows, which include the reasoning accuracy (the proportion of equations that are correctly summed), and the training time used to achieve this accuracy. These results are compared with the following methods:

- `NeurASP <https://github.com/azreasoners/NeurASP>`_: An extension of answer set programs by treating the neural network output as the probability distribution over atomic facts;

- `DeepProbLog <https://github.com/ML-KULeuven/deepproblog>`_: An extension of ProbLog by introducing neural predicates in Probabilistic Logic Programming;

- `DeepStochLog <https://github.com/ML-KULeuven/deepstochlog>`_: A neural-symbolic framework based on stochastic logic programs.

.. table::
   :class: centered

   +--------------+----------+------------------------------+
   | Method       | Accuracy | Time to achieve the Acc. (s) |
   +==============+==========+==============================+
   | NeurASP      | 0.964    | 354                          |
   +--------------+----------+------------------------------+
   | DeepProbLog  | 0.965    | 1965                         |
   +--------------+----------+------------------------------+
   | DeepStochLog | 0.975    | 727                          |
   +--------------+----------+------------------------------+
   | ABL          | 0.980    | 42                           |
   +--------------+----------+------------------------------+

