MNIST Addition
==============

Below is an implementation of `MNIST
Addition <https://arxiv.org/abs/1805.10872>`__. In this task, pairs of
MNIST handwritten images and their sums are given, along with a domain
knowledge base containing information on how to perform addition
operations. The task is to recognize the digits of the handwritten
images and accurately determine their sum.

Intuitively, we first use a machine learning model (learning part) to
convert the input images to digits (we call them pseudo-labels), and
then use the knowledge base (reasoning part) to calculate the sum of
these digits. Since we do not have ground-truth labels for the digits,
in Abductive Learning the reasoning part leverages domain knowledge and
revises the initial digits yielded by the learning part through
abductive reasoning. This process enables us to further update the
machine learning model.
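
Conceptually, each iteration follows the loop sketched below. This is
only a self-contained toy sketch of the workflow: ``predict`` and
``abduce`` here are hypothetical stand-ins written for illustration,
not the ABL-Package API; their concrete counterparts appear later in
this example as ``ABLModel``, ``Reasoner``, and ``SimpleBridge``.

.. code:: python

    # Toy sketch of one Abductive Learning step on a single data example.
    # Both functions below are illustrative stand-ins, not ABL-Package calls.
    def predict(images):
        # Learning part stand-in: pretend the model reads the two images as 1 and 2.
        return [1, 2]

    def abduce(pseudo_labels, y):
        # Reasoning part stand-in: minimally revise the digits so they sum to y
        # (here, keep the first digit and recompute the second).
        return [pseudo_labels[0], y - pseudo_labels[0]]

    images, y = ["img_a", "img_b"], 8
    pseudo_labels = predict(images)            # e.g., [1, 2]
    revised_labels = abduce(pseudo_labels, y)  # e.g., [1, 7]; these would retrain the model
    print(pseudo_labels, "->", revised_labels)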

.. code:: ipython3

    # Import necessary libraries and modules
    import os.path as osp

    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt
    from torch.optim import RMSprop, lr_scheduler

    from datasets import get_dataset
    from models.nn import LeNet5
    from abl.learning import ABLModel, BasicNN
    from abl.reasoning import KBBase, Reasoner
    from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
    from abl.utils import ABLLogger, print_log
    from abl.bridge import SimpleBridge

Working with Data
-----------------

First, we get the training and testing datasets:

.. code:: ipython3

    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

``train_data`` and ``test_data`` share an identical structure: each is
a tuple with three components: X (a list where each element is a list
of two images), gt_pseudo_label (a list where each element is a list of
two digits, i.e., pseudo-labels), and Y (a list where each element is
the sum of the two digits). The lengths and structures of the datasets
are illustrated below.

.. note::

    ``gt_pseudo_label`` is only used to evaluate the performance of
    the learning part but not to train the model.

.. code:: ipython3

    print("Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y")
    print("\n")

    train_X, train_gt_pseudo_label, train_Y = train_data
    print(f"Length of X, gt_pseudo_label, Y in train_data: " +
          f"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}")
    test_X, test_gt_pseudo_label, test_Y = test_data
    print(f"Length of X, gt_pseudo_label, Y in test_data: " +
          f"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}")
    print("\n")

    X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
    print(f"X is a {type(train_X).__name__}, " +
          f"with each element being a {type(X_0).__name__} " +
          f"of {len(X_0)} {type(X_0[0]).__name__}.")
    print(f"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, " +
          f"with each element being a {type(gt_pseudo_label_0).__name__} " +
          f"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.")
    print(f"Y is a {type(train_Y).__name__}, " +
          f"with each element being a {type(Y_0).__name__}.")

Out:

.. code:: none
    :class: code-out

    Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y

    Length of X, gt_pseudo_label, Y in train_data: 30000, 30000, 30000
    Length of X, gt_pseudo_label, Y in test_data: 5000, 5000, 5000

    X is a list, with each element being a list of 2 Tensor.
    gt_pseudo_label is a list, with each element being a list of 2 int.
    Y is a list, with each element being a int.

The i-th element of X, gt_pseudo_label, and Y together constitute the
i-th data example. As an illustration, in the first data example of the
training set, we have:

.. code:: ipython3

    X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
    print("X in the first data example (a list of two images):")
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.imshow(X_0[0].squeeze(), cmap='gray')
    plt.subplot(1, 2, 2)
    plt.axis('off')
    plt.imshow(X_0[1].squeeze(), cmap='gray')
    plt.show()
    print(f"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}")
    print(f"Y in the first data example (their sum result): {Y_0}")

Out:

.. code:: none
    :class: code-out

    X in the first data example (a list of two images):

.. image:: ../img/mnist_add_datasets.png
    :width: 200px

.. code:: none
    :class: code-out

    gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): [7, 5]
    Y in the first data example (their sum result): 12

Building the Learning Part
--------------------------

To build the learning part, we first need to build a machine learning
base model. We use a simple `LeNet-5 neural
network <https://en.wikipedia.org/wiki/LeNet>`__ and encapsulate it
within a ``BasicNN`` object to create the base model. ``BasicNN`` is a
class that encapsulates a PyTorch model, transforming it into a base
model with a scikit-learn-style interface.

.. code:: ipython3

    cls = LeNet5(num_classes=10)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100)

    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        scheduler=scheduler,
        device=device,
        batch_size=32,
        num_epochs=1,
    )

``BasicNN`` offers methods like ``predict`` and ``predict_proba``,
which are used to predict the class index and the probabilities of each
class for images, as shown below:

.. code:: ipython3

    data_instances = [torch.randn(1, 28, 28).to(device) for _ in range(32)]
    pred_idx = base_model.predict(X=data_instances)
    print(f"Predicted class index for a batch of 32 instances: np.ndarray with shape {pred_idx.shape}")
    pred_prob = base_model.predict_proba(X=data_instances)
    print(f"Predicted class probabilities for a batch of 32 instances: np.ndarray with shape {pred_prob.shape}")

Out:

.. code:: none
    :class: code-out

    Predicted class index for a batch of 32 instances: np.ndarray with shape (32,)
    Predicted class probabilities for a batch of 32 instances: np.ndarray with shape (32, 10)

However, the base model built above deals with instance-level data
(i.e., individual images) and cannot directly deal with example-level
data (i.e., pairs of images). Therefore, we wrap the base model into
``ABLModel``, which enables the learning part to train, test, and
predict on example-level data.

.. code:: ipython3

    model = ABLModel(base_model)

As an illustration, consider this example of performing prediction on
example-level data using the ``predict`` method of ``ABLModel``. In
this process, the method accepts data examples as input and outputs the
class labels and the probabilities of each class for all instances
within these data examples.

.. code:: ipython3

    from abl.data.structures import ListData

    # ListData is a data structure provided by ABL-Package that can be
    # used to organize data examples
    data_examples = ListData()
    # We use the first 100 data examples in the training set as an illustration
    data_examples.X = train_X[:100]
    data_examples.gt_pseudo_label = train_gt_pseudo_label[:100]
    data_examples.Y = train_Y[:100]

    # Perform prediction on the 100 data examples
    # (predict once and reuse the result, rather than predicting twice)
    pred_result = model.predict(data_examples)
    pred_label, pred_prob = pred_result["label"], pred_result["prob"]
    print(f"Predicted class labels for the 100 data examples: \n" +
          f"a list of length {len(pred_label)}, and each element is " +
          f"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\n")
    print(f"Predicted class probabilities for the 100 data examples: \n" +
          f"a list of length {len(pred_prob)}, and each element is " +
          f"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.")

Out:

.. code:: none
    :class: code-out

    Predicted class labels for the 100 data examples:
    a list of length 100, and each element is a ndarray of shape (2,).

    Predicted class probabilities for the 100 data examples:
    a list of length 100, and each element is a ndarray of shape (2, 10).

Building the Reasoning Part
---------------------------

In the reasoning part, we first build a knowledge base which contains
information on how to perform addition operations. We build it by
creating a subclass of ``KBBase``. In the derived subclass, we
initialize the ``pseudo_label_list`` parameter to specify the list of
possible pseudo-labels, and override the ``logic_forward`` function to
define how to perform (deductive) reasoning.

.. code:: ipython3

    class AddKB(KBBase):
        def __init__(self, pseudo_label_list=list(range(10))):
            super().__init__(pseudo_label_list)

        # Implement the deduction function
        def logic_forward(self, nums):
            return sum(nums)

    kb = AddKB()

The knowledge base can perform logical reasoning (both deductive
reasoning and abductive reasoning). Below is an example of performing
(deductive) reasoning; users can refer to :ref:`Performing abductive
reasoning in the knowledge base <kb-abd>` for details of abductive
reasoning.

.. code:: ipython3

    pseudo_labels = [1, 2]
    reasoning_result = kb.logic_forward(pseudo_labels)
    print(f"Reasoning result of pseudo-labels {pseudo_labels} is {reasoning_result}.")

Out:

.. code:: none
    :class: code-out

    Reasoning result of pseudo-labels [1, 2] is 3.
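
To also give a flavor of abductive reasoning, below is a minimal
plain-Python sketch of the idea (``naive_abduce`` is a hypothetical
helper written for illustration, not the library's API; it assumes the
``pseudo_label_list`` passed to ``KBBase`` is accessible as
``kb.pseudo_label_list``): given the predicted pseudo-labels and the
known sum, it enumerates candidate revisions whose deductive result
matches that sum.

.. code:: python

    from itertools import product

    def naive_abduce(kb, pred_pseudo_labels, y, max_revision_num=1):
        # Hypothetical brute-force abduction for illustration: enumerate all
        # label assignments, keeping those that (a) differ from the prediction
        # in at most max_revision_num positions and (b) whose deductive result
        # (logic_forward) equals the given sum y.
        candidates = []
        for candidate in product(kb.pseudo_label_list, repeat=len(pred_pseudo_labels)):
            num_revisions = sum(c != p for c, p in zip(candidate, pred_pseudo_labels))
            if num_revisions <= max_revision_num and kb.logic_forward(list(candidate)) == y:
                candidates.append(list(candidate))
        return candidates

    # Suppose the learning part predicts [1, 2] but the given sum is 8.
    # Revising a single digit yields two compatible candidates:
    print(naive_abduce(kb, [1, 2], y=8))  # [[1, 7], [6, 2]]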

.. note::

    In addition to building a knowledge base based on ``KBBase``, we
    can also establish a knowledge base with a ground KB using ``GroundKB``,
    or a knowledge base implemented based on Prolog files using
    ``PrologKB``. The corresponding code for these implementations can be
    found in the ``main.py`` file. Those interested are encouraged to
    examine it for further insights.

Then, we create a reasoner by instantiating the class ``Reasoner``. Due
to the indeterminism of abductive reasoning, there could be multiple
candidates compatible with the knowledge base. When this happens, the
reasoner can minimize inconsistency between the knowledge base and the
pseudo-labels predicted by the learning part, and then return only the
candidate with the highest consistency.

.. code:: ipython3

    reasoner = Reasoner(kb)

.. note::

    When creating the reasoner, the definition of "consistency" can be
    customized through the ``dist_func`` parameter. In the code above, we
    employ a consistency measurement based on confidence, which calculates
    the consistency between a data example and a candidate based on the
    confidence derived from the predicted probability. In
    ``examples/mnist_add/main.py``, we provide options for utilizing other
    forms of consistency measurement.

    Also, during the process of inconsistency minimization, we can leverage
    the `ZOOpt library <https://github.com/polixir/ZOOpt>`__ for
    acceleration. Options for this are also available in
    ``examples/mnist_add/main.py``. Those interested are encouraged to
    explore these features.
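
To make the confidence-based measurement concrete, below is a minimal
plain-Python sketch of the idea (the ``confidence_of`` helper is
hypothetical and written for illustration; it is not the ``dist_func``
implementation used by the library): the confidence of a candidate is
taken as the product of the predicted probabilities of its labels, and
the most consistent candidate is the one with the highest confidence.

.. code:: python

    import numpy as np

    def confidence_of(candidate, pred_prob):
        # Hypothetical helper for illustration. pred_prob has shape
        # (num_instances, num_classes): one row of predicted class
        # probabilities per image in the data example.
        return float(np.prod([pred_prob[i, label] for i, label in enumerate(candidate)]))

    # Toy probabilities where the model leans towards [1, 7]:
    pred_prob = np.full((2, 10), 0.05)
    pred_prob[0, 1], pred_prob[1, 7] = 0.55, 0.55

    # Among the abduced candidates, pick the one the model finds most likely.
    candidates = [[1, 7], [6, 2]]
    best = max(candidates, key=lambda c: confidence_of(c, pred_prob))
    print(best)  # [1, 7]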

Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolAccuracy`` and ``ReasoningMetric``, which
are used to evaluate the accuracy of the machine learning model's
predictions and the accuracy of the final reasoning results,
respectively.

.. code:: ipython3

    metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

Bridge Learning and Reasoning
-----------------------------

Now, the last step is to bridge the learning and reasoning parts. We
do this by creating an instance of ``SimpleBridge``.

.. code:: ipython3

    bridge = SimpleBridge(model, reasoner, metric_list)

Perform training and testing by invoking the ``train`` and ``test``
methods of ``SimpleBridge``.

.. code:: ipython3

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    # Train for one loop; with segment_size=0.01, the model is updated
    # after each 1% segment of the training data (100 segments per loop,
    # cf. the log below)
    bridge.train(train_data, loops=1, segment_size=0.01, save_interval=1, save_dir=weights_dir)
    bridge.test(test_data)

Out:

.. code:: none
    :class: code-out

    abl - INFO - Abductive Learning on the MNIST Addition example.
    abl - INFO - loop(train) [1/1] segment(train) [1/100]
    abl - INFO - model loss: 2.23587
    abl - INFO - loop(train) [1/1] segment(train) [2/100]
    abl - INFO - model loss: 2.23756
    abl - INFO - loop(train) [1/1] segment(train) [3/100]
    abl - INFO - model loss: 2.04475
    abl - INFO - loop(train) [1/1] segment(train) [4/100]
    abl - INFO - model loss: 2.01035
    abl - INFO - loop(train) [1/1] segment(train) [5/100]
    abl - INFO - model loss: 1.97584
    abl - INFO - loop(train) [1/1] segment(train) [6/100]
    abl - INFO - model loss: 1.91570
    abl - INFO - loop(train) [1/1] segment(train) [7/100]
    abl - INFO - model loss: 1.90268
    abl - INFO - loop(train) [1/1] segment(train) [8/100]
    abl - INFO - model loss: 1.77436
    abl - INFO - loop(train) [1/1] segment(train) [9/100]
    abl - INFO - model loss: 1.73454
    abl - INFO - loop(train) [1/1] segment(train) [10/100]
    abl - INFO - model loss: 1.62495
    abl - INFO - loop(train) [1/1] segment(train) [11/100]
    abl - INFO - model loss: 1.58456
    abl - INFO - loop(train) [1/1] segment(train) [12/100]
    abl - INFO - model loss: 1.62575
    ...
    abl - INFO - Eval start: loop(val) [1]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.986 mnist_add/reasoning_accuracy: 0.973
    abl - INFO - Saving model: loop(save) [1]
    abl - INFO - Checkpoints will be saved to results/20231222_22_25_07/weights/model_checkpoint_loop_1.pth
    abl - INFO - Test start:
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.983 mnist_add/reasoning_accuracy: 0.967

More concrete examples are available in ``examples/mnist_add/main.py``
and ``examples/mnist_add/mnist_add.ipynb``.
