

MNIST Addition
==============

Below is an implementation of `MNIST
Addition <https://arxiv.org/abs/1805.10872>`__. In this task, pairs of
MNIST handwritten images and their sums are given, along with a domain
knowledge base containing information on how to perform addition
operations. The task is to recognize the digits of the handwritten images
and accurately determine their sum.

Intuitively, we first use a machine learning model (learning part) to
convert the input images to digits (we call them pseudo-labels), and
then use the knowledge base (reasoning part) to calculate the sum of
these digits. Since we do not have ground-truth labels for the digits, in
Abductive Learning, the reasoning part leverages domain knowledge
and revises the initial digits yielded by the learning part through
abductive reasoning. This process enables us to further update the
machine learning model.

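To make this intuition concrete, below is a toy walk-through of one
iteration of this loop on a single data example. It uses plain Python only
and is purely illustrative; it is not part of the package API, which the
rest of this tutorial introduces.

.. code:: ipython3

    # Toy walk-through of one Abductive Learning iteration (illustrative only).
    known_sum = 12        # supervision: the ground-truth sum of the two images
    pred_digits = [7, 4]  # learning part: (possibly wrong) pseudo-labels

    # Reasoning part, deduction: check consistency with the knowledge base.
    assert sum(pred_digits) != known_sum  # 7 + 4 = 11, inconsistent with 12

    # Reasoning part, abduction: revise the pseudo-labels so that they become
    # consistent with the knowledge base, changing as little as possible
    # (here, 4 is revised to 5).
    revised_digits = [7, 5]
    assert sum(revised_digits) == known_sum

    # Learning part: the revised pseudo-labels are then used as training
    # targets to update the image classifier, and the loop repeats.
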
.. code:: ipython3

    # Import necessary libraries and modules
    import os.path as osp
    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt
    from examples.mnist_add.datasets import get_dataset
    from examples.models.nn import LeNet5
    from abl.learning import ABLModel, BasicNN
    from abl.reasoning import KBBase, Reasoner
    from abl.evaluation import ReasoningMetric, SymbolMetric
    from abl.utils import ABLLogger, print_log
    from abl.bridge import SimpleBridge

Working with Data
-----------------

First, we get the training and testing datasets:

.. code:: ipython3

    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

``train_data`` and ``test_data`` share an identical structure:
tuples with three components: X (a list where each element is a
list of two images), gt_pseudo_label (a list where each element
is a list of two digits, i.e., pseudo-labels), and Y (a list where
each element is the sum of the two digits). The length and structure
of the datasets are illustrated as follows.

.. note::

    ``gt_pseudo_label`` is only used to evaluate the performance of
    the learning part but not to train the model.

.. code:: ipython3

    print(f"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y")
    print("\n")

    train_X, train_gt_pseudo_label, train_Y = train_data
    print(f"Length of X, gt_pseudo_label, Y in train_data: " +
          f"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}")
    test_X, test_gt_pseudo_label, test_Y = test_data
    print(f"Length of X, gt_pseudo_label, Y in test_data: " +
          f"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}")
    print("\n")

    X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
    print(f"X is a {type(train_X).__name__}, " +
          f"with each element being a {type(X_0).__name__} " +
          f"of {len(X_0)} {type(X_0[0]).__name__}.")
    print(f"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, " +
          f"with each element being a {type(gt_pseudo_label_0).__name__} " +
          f"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.")
    print(f"Y is a {type(train_Y).__name__}, " +
          f"with each element being a {type(Y_0).__name__}.")

Out:

.. code:: none
    :class: code-out

    Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y

    Length of X, gt_pseudo_label, Y in train_data: 30000, 30000, 30000
    Length of X, gt_pseudo_label, Y in test_data: 5000, 5000, 5000

    X is a list, with each element being a list of 2 Tensor.
    gt_pseudo_label is a list, with each element being a list of 2 int.
    Y is a list, with each element being a int.

The ith element of X, gt_pseudo_label, and Y together constitute the ith
data example. As an illustration, in the first data example of the
training set, we have:

.. code:: ipython3

    X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
    print(f"X in the first data example (a list of two images):")
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.imshow(X_0[0].numpy().transpose(1, 2, 0))
    plt.subplot(1, 2, 2)
    plt.axis('off')
    plt.imshow(X_0[1].numpy().transpose(1, 2, 0))
    plt.show()
    print(f"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}")
    print(f"Y in the first data example (their sum result): {Y_0}")

Out:

.. code:: none
    :class: code-out

    X in the first data example (a list of two images):

.. image:: ../img/mnist_add_datasets.png
    :width: 400px

.. parsed-literal::

    gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): [7, 5]
    Y in the first data example (their sum result): 12

Building the Learning Part
--------------------------

To build the learning part, we need to first build a machine learning
base model. We use a simple `LeNet-5 neural
network <https://en.wikipedia.org/wiki/LeNet>`__, and encapsulate it
within a ``BasicNN`` object to create the base model. ``BasicNN`` is a
class that encapsulates a PyTorch model, transforming it into a base
model with an sklearn-style interface.

.. code:: ipython3

    cls = LeNet5(num_classes=10)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        device,
        batch_size=32,
        num_epochs=1,
    )

``BasicNN`` offers methods like ``predict`` and ``predict_proba``, which
are used to predict the class index and the probabilities of each class
for images, as shown below:

.. code:: ipython3

    data_instances = [torch.randn(1, 28, 28).to(device) for _ in range(32)]
    pred_idx = base_model.predict(X=data_instances)
    print(f"Predicted class index for a batch of 32 instances: np.ndarray with shape {pred_idx.shape}")
    pred_prob = base_model.predict_proba(X=data_instances)
    print(f"Predicted class probabilities for a batch of 32 instances: np.ndarray with shape {pred_prob.shape}")

Out:

.. code:: none
    :class: code-out

    Predicted class index for a batch of 32 instances: np.ndarray with shape (32,)
    Predicted class probabilities for a batch of 32 instances: np.ndarray with shape (32, 10)

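The two outputs are related: for a softmax classifier such as LeNet-5, the
class index returned by ``predict`` is expected to be the argmax of the
probabilities returned by ``predict_proba``. The snippet below is only an
illustrative sanity check of this relationship, built from the calls already
shown above; treat the expected result as an assumption about how ``predict``
is implemented rather than a documented guarantee.

.. code:: ipython3

    import numpy as np

    # Illustrative sanity check: the predicted class index should correspond
    # to the highest predicted probability for each instance.
    pred_idx = base_model.predict(X=data_instances)
    pred_prob = base_model.predict_proba(X=data_instances)
    print(np.array_equal(pred_idx, pred_prob.argmax(axis=1)))  # expected: True
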
However, the base model built above deals with instance-level data
(i.e., individual images), and cannot directly deal with example-level
data (i.e., a pair of images). Therefore, we wrap the base model into
``ABLModel``, which enables the learning part to train, test, and
predict on example-level data.

.. code:: ipython3

    model = ABLModel(base_model)

As an illustration, consider this example of performing prediction on
example-level data using the ``predict`` method of ``ABLModel``. In this
process, the method accepts data examples as input and outputs the class
labels and the probabilities of each class for all instances within these
data examples.

.. code:: ipython3

    from abl.structures import ListData

    # ListData is a data structure provided by ABL-Package that can be used to organize data examples
    data_examples = ListData()
    # We use the first 100 data examples in the training set as an illustration
    data_examples.X = train_X[:100]
    data_examples.gt_pseudo_label = train_gt_pseudo_label[:100]
    data_examples.Y = train_Y[:100]

    # Perform prediction on the 100 data examples
    pred = model.predict(data_examples)
    pred_label, pred_prob = pred['label'], pred['prob']
    print(f"Predicted class labels for the 100 data examples: \n" +
          f"a list of length {len(pred_label)}, and each element is " +
          f"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\n")
    print(f"Predicted class probabilities for the 100 data examples: \n" +
          f"a list of length {len(pred_prob)}, and each element is " +
          f"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.")

Out:

.. code:: none
    :class: code-out

    Predicted class labels for the 100 data examples:
    a list of length 100, and each element is a ndarray of shape (2,).

    Predicted class probabilities for the 100 data examples:
    a list of length 100, and each element is a ndarray of shape (2, 10).

Building the Reasoning Part
---------------------------

In the reasoning part, we first build a knowledge base which contains
information on how to perform addition operations. We build it by
creating a subclass of ``KBBase``. In the derived subclass, we
initialize the ``pseudo_label_list`` parameter, which specifies the list of
possible pseudo-labels, and override the ``logic_forward`` function, which
defines how to perform (deductive) reasoning.

.. code:: ipython3

    class AddKB(KBBase):
        def __init__(self, pseudo_label_list=list(range(10))):
            super().__init__(pseudo_label_list)

        # Implement the deduction function
        def logic_forward(self, nums):
            return sum(nums)

    kb = AddKB()

The knowledge base can perform logical reasoning (both deductive
reasoning and abductive reasoning). Below is an example of performing
(deductive) reasoning; users can refer to :ref:`Performing abductive
reasoning in the knowledge base <kb-abd>` for details of abductive reasoning.

.. code:: ipython3

    pseudo_label_example = [1, 2]
    reasoning_result = kb.logic_forward(pseudo_label_example)
    print(f"Reasoning result of pseudo-label example {pseudo_label_example} is {reasoning_result}.")

Out:

.. code:: none
    :class: code-out

    Reasoning result of pseudo-label example [1, 2] is 3.

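To give a flavour of what abductive reasoning does here (the full mechanism
is described in the section referenced above), the standalone sketch below
searches the pseudo-label space for candidates whose deductive result matches
a given reasoning result and keeps the one requiring the fewest revisions. It
is a conceptual illustration written only against ``logic_forward``, not the
knowledge base's built-in abduction interface, and it assumes that the list
passed to the constructor is stored as ``kb.pseudo_label_list``.

.. code:: ipython3

    from itertools import product

    # Conceptual sketch of abduction over AddKB (not the package's built-in
    # abduction interface): find pseudo-labels consistent with the reasoning
    # result y, preferring those that revise the current prediction least.
    def abduce_sketch(kb, pred_pseudo_label, y):
        candidates = [
            list(c)
            for c in product(kb.pseudo_label_list, repeat=len(pred_pseudo_label))
            if kb.logic_forward(list(c)) == y
        ]
        return min(candidates,
                   key=lambda c: sum(p != q for p, q in zip(pred_pseudo_label, c)))

    print(abduce_sketch(kb, [1, 2], y=4))  # -> [1, 3]: consistent with 4, one digit revised
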
.. note::

    In addition to building a knowledge base based on ``KBBase``, we
    can also establish a knowledge base with a ground KB using ``GroundKB``,
    or a knowledge base implemented based on Prolog files using
    ``PrologKB``. The corresponding code for these implementations can be
    found in the ``main.py`` file. Those interested are encouraged to
    examine it for further insights.

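As a rough intuition for the ground-KB variant (the actual ``GroundKB`` class
and its constructor arguments are shown in ``main.py``), a ground knowledge
base pre-enumerates all groundings of the rule, so that checking or revising
candidates becomes a table lookup rather than a search. The plain-Python
sketch below illustrates only this idea and is not the ``GroundKB`` API.

.. code:: ipython3

    from itertools import product

    # Conceptual sketch of a ground KB (not the GroundKB API): enumerate every
    # grounding of the addition rule in advance, indexed by the reasoning result.
    ground_kb = {}
    for digits in product(range(10), repeat=2):
        ground_kb.setdefault(sum(digits), []).append(list(digits))

    # All pseudo-label pairs consistent with the reasoning result 12:
    print(ground_kb[12])  # [[3, 9], [4, 8], [5, 7], [6, 6], [7, 5], [8, 4], [9, 3]]
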
Then, we create a reasoner by instantiating the class ``Reasoner``. Due
to the indeterminism of abductive reasoning, there could be multiple
candidates compatible with the knowledge base. When this happens, the
reasoner can minimize inconsistencies between the knowledge base and the
pseudo-labels predicted by the learning part, and then return only the
candidate with the highest consistency.

.. code:: ipython3

    reasoner = Reasoner(kb)

.. note::

    When creating the reasoner, the definition of “consistency” can be
    customized via the ``dist_func`` parameter. In the code above, we
    employ a consistency measurement based on confidence, which calculates
    the consistency between a data example and a candidate based on the
    confidence derived from the predicted probability. In ``examples/mnist_add/main.py``, we
    provide options for utilizing other forms of consistency measurement.

    Also, during the process of inconsistency minimization, we can leverage the
    `ZOOpt library <https://github.com/polixir/ZOOpt>`__ for acceleration.
    Options for this are also available in ``examples/mnist_add/main.py``. Those interested are
    encouraged to explore these features.

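As a rough intuition for the confidence-based measurement (the exact formula
and the accepted values of ``dist_func`` are documented in
``examples/mnist_add/main.py`` and the API reference), the standalone sketch
below scores each candidate by how much probability the model assigns to its
digits; candidates with a smaller distance are treated as more consistent.
This is only a conceptual sketch, not the package's implementation.

.. code:: ipython3

    import numpy as np

    # Conceptual sketch of confidence-based consistency (not the package's
    # exact formula): a candidate is more consistent when the model assigns
    # higher probability to each of its digits.
    pred_prob_example = np.array([
        [0.05, 0.02, 0.03, 0.05, 0.25, 0.30, 0.10, 0.10, 0.05, 0.05],  # image 1
        [0.02, 0.03, 0.05, 0.05, 0.05, 0.10, 0.10, 0.30, 0.25, 0.05],  # image 2
    ])

    def confidence_distance(candidate, prob):
        return sum(1.0 - prob[i, digit] for i, digit in enumerate(candidate))

    # Among candidates compatible with Y = 12, pick the one the model is
    # most confident about.
    candidates = [[4, 8], [5, 7], [7, 5]]
    best = min(candidates, key=lambda c: confidence_distance(c, pred_prob_example))
    print(best)  # -> [5, 7]: the model favours 5 for image 1 and 7 for image 2
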
Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolMetric`` and ``ReasoningMetric``, which are
used to evaluate the accuracy of the machine learning model’s
predictions and the accuracy of the final reasoning results,
respectively.

.. code:: ipython3

    metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

Bridge Learning and Reasoning
-----------------------------

Now, the last step is to bridge the learning and reasoning parts. We
do this by creating an instance of ``SimpleBridge``.

.. code:: ipython3

    bridge = SimpleBridge(model, reasoner, metric_list)

Perform training and testing by invoking the ``train`` and ``test``
methods of ``SimpleBridge``.

.. code:: ipython3

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)
    bridge.test(test_data)

Out:

.. code:: none
    :class: code-out

    abl - INFO - Abductive Learning on the MNIST Addition example.
    abl - INFO - loop(train) [1/5] segment(train) [1/3]
    abl - INFO - model loss: 1.49104
    abl - INFO - loop(train) [1/5] segment(train) [2/3]
    abl - INFO - model loss: 1.24945
    abl - INFO - loop(train) [1/5] segment(train) [3/3]
    abl - INFO - model loss: 0.87861
    abl - INFO - Evaluation start: loop(val) [1]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.818 mnist_add/reasoning_accuracy: 0.672
    abl - INFO - Saving model: loop(save) [1]
    abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_1.pth
    abl - INFO - loop(train) [2/5] segment(train) [1/3]
    abl - INFO - model loss: 0.31148
    abl - INFO - loop(train) [2/5] segment(train) [2/3]
    abl - INFO - model loss: 0.09520
    abl - INFO - loop(train) [2/5] segment(train) [3/3]
    abl - INFO - model loss: 0.07402
    abl - INFO - Evaluation start: loop(val) [2]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.982 mnist_add/reasoning_accuracy: 0.964
    abl - INFO - Saving model: loop(save) [2]
    abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_2.pth
    abl - INFO - loop(train) [3/5] segment(train) [1/3]
    abl - INFO - model loss: 0.06027
    abl - INFO - loop(train) [3/5] segment(train) [2/3]
    abl - INFO - model loss: 0.05341
    abl - INFO - loop(train) [3/5] segment(train) [3/3]
    abl - INFO - model loss: 0.04915
    abl - INFO - Evaluation start: loop(val) [3]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.975
    abl - INFO - Saving model: loop(save) [3]
    abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_3.pth
    abl - INFO - loop(train) [4/5] segment(train) [1/3]
    abl - INFO - model loss: 0.04413
    abl - INFO - loop(train) [4/5] segment(train) [2/3]
    abl - INFO - model loss: 0.04181
    abl - INFO - loop(train) [4/5] segment(train) [3/3]
    abl - INFO - model loss: 0.04127
    abl - INFO - Evaluation start: loop(val) [4]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.990 mnist_add/reasoning_accuracy: 0.980
    abl - INFO - Saving model: loop(save) [4]
    abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_4.pth
    abl - INFO - loop(train) [5/5] segment(train) [1/3]
    abl - INFO - model loss: 0.03544
    abl - INFO - loop(train) [5/5] segment(train) [2/3]
    abl - INFO - model loss: 0.03092
    abl - INFO - loop(train) [5/5] segment(train) [3/3]
    abl - INFO - model loss: 0.03663
    abl - INFO - Evaluation start: loop(val) [5]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.982
    abl - INFO - Saving model: loop(save) [5]
    abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_5.pth
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.974

More concrete examples are available in ``examples/mnist_add/main.py`` and ``examples/mnist_add/mnist_add.ipynb``.
