Handwritten Formula (HWF)
=========================

.. raw:: html

   <p>For detailed code implementation, please view it on <a class="reference external" href="https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hwf" target="_blank">GitHub</a>.</p>

Below is an implementation of `Handwritten
Formula <https://arxiv.org/abs/2006.06649>`__. In this task, handwritten
images of decimal formulas and their computed results are given, along
with a domain knowledge base containing information on how to compute a
decimal formula. The task is to recognize the symbols (which can be
digits or the operators ‘+’, ‘-’, ‘×’, ‘÷’) in the handwritten images
and accurately determine their results.

Intuitively, we first use a machine learning model (learning part) to
convert the input images to symbols (we call them pseudo-labels), and
then use the knowledge base (reasoning part) to calculate the results of
these symbols. Since we do not have the ground truth of the symbols, in
Abductive Learning, the reasoning part leverages domain knowledge and
revises the initial symbols yielded by the learning part through
abductive reasoning. This process enables us to further update the
machine learning model.
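
To make the revision step concrete, here is a minimal, self-contained
sketch of the abduction idea in plain Python, independent of ABLkit (all
names in it are illustrative): given predicted symbols and the known
result, it searches for the revision with the fewest changed symbols
whose evaluation matches the result.

.. code:: python

   # Illustrative sketch only -- not part of ABLkit.
   from itertools import product

   SYMBOLS = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"]

   def deduce(symbols):
       """Deduction: compute the result of a formula given as a list of symbols."""
       return eval("".join(symbols))

   def abduce(pred_symbols, result):
       """Abduction: return the revision of pred_symbols with the fewest
       changed positions whose evaluation equals the known result."""
       best, best_changes = None, len(pred_symbols) + 1
       for candidate in product(SYMBOLS, repeat=len(pred_symbols)):
           try:
               if abs(deduce(list(candidate)) - result) > 1e-10:
                   continue
           except (SyntaxError, ZeroDivisionError):
               continue  # skip malformed formulas such as '5++'
           changes = sum(a != b for a, b in zip(candidate, pred_symbols))
           if changes < best_changes:
               best, best_changes = list(candidate), changes
       return best

   # Suppose the model misreads '5 - 3' as '5 + 3'; knowing the result is 2,
   # abduction revises '+' back to '-'.
   print(abduce(["5", "+", "3"], 2))  # ['5', '-', '3']

The rest of this page builds this loop with ABLkit. We start by
importing the necessary libraries and modules: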
.. code:: python

   # Import necessary libraries and modules
   import os.path as osp

   import matplotlib.pyplot as plt
   import numpy as np
   import torch
   import torch.nn as nn

   from ablkit.bridge import SimpleBridge
   from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy
   from ablkit.learning import ABLModel, BasicNN
   from ablkit.reasoning import KBBase, Reasoner
   from ablkit.utils import ABLLogger, print_log

   from datasets import get_dataset
   from models.nn import SymbolNet
Working with Data
-----------------

First, we get the training and testing datasets:

.. code:: python

   train_data = get_dataset(train=True, get_pseudo_label=True)
   test_data = get_dataset(train=False, get_pseudo_label=True)

Both ``train_data`` and ``test_data`` have the same structure: tuples
with three components: X (a list where each element is a list of
images), gt_pseudo_label (a list where each element is a list of
symbols, i.e., pseudo-labels), and Y (a list where each element is the
computed result). The lengths and structures of the datasets are
illustrated as follows.

.. note::

   ``gt_pseudo_label`` is only used to evaluate the performance of
   the learning part; it is not used to train the model.
.. code:: python

   print(f"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y")
   print()

   train_X, train_gt_pseudo_label, train_Y = train_data
   print(f"Length of X, gt_pseudo_label, Y in train_data: " +
         f"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}")
   test_X, test_gt_pseudo_label, test_Y = test_data
   print(f"Length of X, gt_pseudo_label, Y in test_data: " +
         f"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}")
   print()

   X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
   print(f"X is a {type(train_X).__name__}, " +
         f"with each element being a {type(X_0).__name__} of {type(X_0[0]).__name__}.")
   print(f"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, " +
         f"with each element being a {type(gt_pseudo_label_0).__name__} " +
         f"of {type(gt_pseudo_label_0[0]).__name__}.")
   print(f"Y is a {type(train_Y).__name__}, " +
         f"with each element being an {type(Y_0).__name__}.")
Out:

.. code:: none
   :class: code-out

   Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y

   Length of X, gt_pseudo_label, Y in train_data: 10000, 10000, 10000
   Length of X, gt_pseudo_label, Y in test_data: 2000, 2000, 2000

   X is a list, with each element being a list of Tensor.
   gt_pseudo_label is a list, with each element being a list of str.
   Y is a list, with each element being an int.
The ith element of X, gt_pseudo_label, and Y together constitute the
ith data example. Here we use two of them (the 1001st and the 3001st)
as illustrations:

.. code:: python

   X_1000, gt_pseudo_label_1000, Y_1000 = train_X[1000], train_gt_pseudo_label[1000], train_Y[1000]
   print(f"X in the 1001st data example (a list of images):")
   for i, x in enumerate(X_1000):
       plt.subplot(1, len(X_1000), i+1)
       plt.axis('off')
       plt.imshow(x.squeeze(), cmap='gray')
   plt.show()
   print(f"gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_1000}")
   print(f"Y in the 1001st data example (the computed result): {Y_1000}")
   print()

   X_3000, gt_pseudo_label_3000, Y_3000 = train_X[3000], train_gt_pseudo_label[3000], train_Y[3000]
   print(f"X in the 3001st data example (a list of images):")
   for i, x in enumerate(X_3000):
       plt.subplot(1, len(X_3000), i+1)
       plt.axis('off')
       plt.imshow(x.squeeze(), cmap='gray')
   plt.show()
   print(f"gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_3000}")
   print(f"Y in the 3001st data example (the computed result): {Y_3000}")
Out:

.. code:: none
   :class: code-out

   X in the 1001st data example (a list of images):

.. image:: ../_static/img/hwf_dataset1.png
   :width: 210px

.. code:: none
   :class: code-out

   gt_pseudo_label in the 1001st data example (a list of pseudo-labels): ['5', '-', '3']
   Y in the 1001st data example (the computed result): 2

.. code:: none
   :class: code-out

   X in the 3001st data example (a list of images):

.. image:: ../_static/img/hwf_dataset2.png
   :width: 350px

.. code:: none
   :class: code-out

   gt_pseudo_label in the 3001st data example (a list of pseudo-labels): ['4', '/', '6', '*', '5']
   Y in the 3001st data example (the computed result): 3.333333333333333
.. note::

   The symbols in the HWF dataset can be digits or one of the operators
   '+', '-', '×', '÷'.

   We can see that in the 1001st data example, the length of the
   formula is 3, while in the 3001st data example, the length of the
   formula is 5. In the HWF dataset, the lengths of the formulas are 1,
   3, 5, and 7 (specifically, 10% of the formulas have a length of 1,
   10% have a length of 3, 20% have a length of 5, and 60% have a
   length of 7).
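
This distribution can be checked directly from the data. The snippet
below relies only on ``train_X`` defined earlier (the printed counts
are omitted here):

.. code:: python

   # Count how many training formulas there are of each length
   from collections import Counter

   length_counts = Counter(len(x) for x in train_X)
   print(sorted(length_counts.items()))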
Building the Learning Part
--------------------------

To build the learning part, we first need a machine learning base
model. We use SymbolNet, and encapsulate it within a ``BasicNN`` object
to create the base model. ``BasicNN`` is a class that encapsulates a
PyTorch model, transforming it into a base model with a
scikit-learn-style interface.

.. code:: python

   # Each symbol class may be one of ['1', ..., '9', '+', '-', '*', '/'], 13 classes in total
   cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))
   loss_fn = nn.CrossEntropyLoss()
   optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

   base_model = BasicNN(
       model=cls,
       loss_fn=loss_fn,
       optimizer=optimizer,
       device=device,
       batch_size=128,
       num_epochs=3,
   )
``BasicNN`` offers methods like ``predict`` and ``predict_proba``, which
are used to predict the class index and the probabilities of each class
for images, as shown below:

.. code:: python

   data_instances = [torch.randn(1, 45, 45) for _ in range(32)]
   pred_idx = base_model.predict(X=data_instances)
   print(f"Predicted class index for a batch of 32 instances: " +
         f"{type(pred_idx).__name__} with shape {pred_idx.shape}")
   pred_prob = base_model.predict_proba(X=data_instances)
   print(f"Predicted class probabilities for a batch of 32 instances: " +
         f"{type(pred_prob).__name__} with shape {pred_prob.shape}")

Out:

.. code:: none
   :class: code-out

   Predicted class index for a batch of 32 instances: ndarray with shape (32,)
   Predicted class probabilities for a batch of 32 instances: ndarray with shape (32, 13)
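
Since ``BasicNN`` provides a scikit-learn-style interface, the base
model can also be trained directly on instance-level data. Below is a
minimal sketch, assuming ``BasicNN`` exposes a ``fit`` method that
accepts instance-level images ``X`` and integer labels ``y`` (check the
ABLkit API reference for the exact signature); the random data is
purely illustrative.

.. code:: python

   # Sketch: instance-level training on random data (illustrative only).
   # Assumes BasicNN.fit(X=..., y=...) exists, mirroring scikit-learn.
   dummy_X = [torch.randn(1, 45, 45) for _ in range(32)]
   dummy_y = [0] * 32  # integer class indices in [0, 13)
   base_model.fit(X=dummy_X, y=dummy_y)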
However, the base model built above deals with instance-level data
(i.e., individual images), and cannot directly deal with example-level
data (i.e., a list of images comprising the formula). Therefore, we wrap
the base model into ``ABLModel``, which enables the learning part to
train, test, and predict on example-level data.

.. code:: python

   model = ABLModel(base_model)

As an illustration, consider this example of predicting on
example-level data using the ``predict`` method of ``ABLModel``. In this
process, the method accepts data examples as input and outputs the class
labels and the probabilities of each class for all instances within
these data examples.
.. code:: python

   from ablkit.data.structures import ListData

   # ListData is a data structure provided by ABLkit that can be used to organize data examples
   data_examples = ListData()
   # We use the 1001st and 3001st data examples in the training set as an illustration
   data_examples.X = [X_1000, X_3000]
   data_examples.gt_pseudo_label = [gt_pseudo_label_1000, gt_pseudo_label_3000]
   data_examples.Y = [Y_1000, Y_3000]

   # Perform prediction on the two data examples.
   # Recall that in the 1001st data example, the length of the formula is 3,
   # while in the 3001st data example, the length of the formula is 5.
   pred_result = model.predict(data_examples)
   pred_label, pred_prob = pred_result['label'], pred_result['prob']
   print(f"Predicted class labels for the two data examples: a list of length {len(pred_label)}, \n" +
         f"the first element is a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}, " +
         f"and the second element is a {type(pred_label[1]).__name__} of shape {pred_label[1].shape}.\n")
   print(f"Predicted class probabilities for the two data examples: a list of length {len(pred_prob)}, \n" +
         f"the first element is a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}, " +
         f"and the second element is a {type(pred_prob[1]).__name__} of shape {pred_prob[1].shape}.")

Out:

.. code:: none
   :class: code-out

   Predicted class labels for the two data examples: a list of length 2,
   the first element is a ndarray of shape (3,), and the second element is a ndarray of shape (5,).

   Predicted class probabilities for the two data examples: a list of length 2,
   the first element is a ndarray of shape (3, 13), and the second element is a ndarray of shape (5, 13).
Building the Reasoning Part
---------------------------

In the reasoning part, we first build a knowledge base that contains
information on how to compute a formula. We build it by creating a
subclass of ``KBBase``. In the derived subclass, we initialize the
``pseudo_label_list`` parameter with the list of possible
pseudo-labels, and override the ``logic_forward`` function to define
how to perform (deductive) reasoning.

.. code:: python

   class HwfKB(KBBase):
       def __init__(self, pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"]):
           super().__init__(pseudo_label_list)

       def _valid_candidate(self, formula):
           if len(formula) % 2 == 0:
               return False
           for i in range(len(formula)):
               if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
                   return False
               if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
                   return False
           return True

       # Implement the deduction function
       def logic_forward(self, formula):
           if not self._valid_candidate(formula):
               return np.inf
           return eval("".join(formula))

   kb = HwfKB()
The knowledge base can perform logical reasoning (both deductive
reasoning and abductive reasoning). Below is an example of performing
(deductive) reasoning; users can refer to :ref:`Performing abductive
reasoning in the knowledge base <kb-abd>` for details of abductive
reasoning.

.. code:: python

   pseudo_labels = ["1", "-", "2", "*", "5"]
   reasoning_result = kb.logic_forward(pseudo_labels)
   print(f"Reasoning result of pseudo-labels {pseudo_labels} is {reasoning_result}.")

Out:

.. code:: none
   :class: code-out

   Reasoning result of pseudo-labels ['1', '-', '2', '*', '5'] is -9.
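
To give a flavor of what abductive reasoning does with this knowledge
base, here is a brute-force sketch (illustrative only; ABLkit's built-in
abduction, described in the linked section, is more general and
efficient). It revises at most one symbol of the predicted
pseudo-labels so that ``logic_forward`` matches the given result; we
assume the ``pseudo_label_list`` passed to ``KBBase`` is available as
``kb.pseudo_label_list``.

.. code:: python

   # A brute-force sketch of abduction (illustrative only; not ABLkit's
   # actual abduction interface).
   def abduce_one_revision(kb, pseudo_labels, y, tol=1e-10):
       """Revise at most one symbol so that logic_forward matches y."""
       if abs(kb.logic_forward(pseudo_labels) - y) <= tol:
           return pseudo_labels  # already consistent with y
       for i in range(len(pseudo_labels)):
           for symbol in kb.pseudo_label_list:
               candidate = list(pseudo_labels)
               candidate[i] = symbol
               if abs(kb.logic_forward(candidate) - y) <= tol:
                   return candidate
       return None  # no single-symbol revision is consistent

   # Suppose the learning part misreads '1 - 2 * 5' as '1 + 2 * 5'.
   # Given the known result Y = -9, abduction revises '+' back to '-'.
   print(abduce_one_revision(kb, ["1", "+", "2", "*", "5"], -9))  # ['1', '-', '2', '*', '5']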
.. note::

   In addition to building a knowledge base based on ``KBBase``, we can
   also establish a knowledge base with a ground KB using ``GroundKB``.
   The corresponding code can be found in the ``examples/hwf/main.py``
   file. Those interested are encouraged to examine it for further
   insights.

   Also, when building the knowledge base, we can set the ``max_err``
   parameter during initialization, as shown in the
   ``examples/hwf/main.py`` file. This parameter specifies the upper
   tolerance limit when comparing the similarity between the reasoning
   result of pseudo-labels and the ground truth during abductive
   reasoning, with a default value of 1e-10.
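
For instance, a variant of ``HwfKB`` that relaxes this tolerance might
pass ``max_err`` through to ``KBBase``. The sketch below assumes
``KBBase.__init__`` accepts ``max_err`` as a keyword argument,
consistent with the note above; ``TolerantHwfKB`` is a hypothetical
name.

.. code:: python

   # Sketch: an HwfKB variant with a larger tolerance for abductive
   # reasoning. Assumes KBBase.__init__ accepts max_err (default 1e-10).
   class TolerantHwfKB(HwfKB):
       def __init__(self, max_err=1e-3):
           KBBase.__init__(
               self,
               pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
               max_err=max_err,
           )

   kb_tolerant = TolerantHwfKB(max_err=1e-3)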
Then, we create a reasoner by instantiating the class ``Reasoner``. Due
to the indeterminism of abductive reasoning, there could be multiple
candidates compatible with the knowledge base. When this happens, the
reasoner minimizes the inconsistency between the knowledge base and the
pseudo-labels predicted by the learning part, and returns only the
candidate with the highest consistency.

.. code:: python

   reasoner = Reasoner(kb)

.. note::

   When creating the reasoner, the definition of "consistency" can be
   customized via the ``dist_func`` parameter. In the code above, we
   employ a consistency measurement based on confidence, which
   calculates the consistency between a data example and a candidate
   based on the confidence derived from the predicted probability. In
   ``examples/hwf/main.py``, we provide options for utilizing other
   forms of consistency measurement.

   Also, during the process of minimizing inconsistency, we can
   leverage the `ZOOpt library <https://github.com/polixir/ZOOpt>`__
   for acceleration. Options for this are also available in
   ``examples/hwf/main.py``. Those interested are encouraged to explore
   these features.
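
As a sketch of what these options might look like when instantiating
the reasoner (parameter names such as ``dist_func="hamming"`` and
``use_zoopt`` are assumptions on our part; consult
``examples/hwf/main.py`` and the API reference for the authoritative
options):

.. code:: python

   # Sketch: a reasoner with Hamming distance as the consistency measure
   # and ZOOpt-based acceleration. Parameter names are assumptions; see
   # examples/hwf/main.py for the exact options.
   reasoner_alt = Reasoner(kb, dist_func="hamming", use_zoopt=True)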
Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolAccuracy`` and ``ReasoningMetric``, which
are used to evaluate the accuracy of the machine learning model's
predictions and the accuracy of the final reasoning results,
respectively.

.. code:: python

   metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]

Bridging Learning and Reasoning
-------------------------------

Now, the last step is to bridge the learning and reasoning parts. We
proceed with this step by creating an instance of ``SimpleBridge``.

.. code:: python

   bridge = SimpleBridge(model, reasoner, metric_list)

Perform training and testing by invoking the ``train`` and ``test``
methods of ``SimpleBridge``:

.. code:: python

   # Build logger
   print_log("Abductive Learning on the HWF example.", logger="current")
   log_dir = ABLLogger.get_current_instance().log_dir
   weights_dir = osp.join(log_dir, "weights")

   bridge.train(train_data, loops=3, segment_size=1000, save_dir=weights_dir)
   bridge.test(test_data)
The log will appear similar to the following:

Log:

.. code:: none
   :class: code-out

   abl - INFO - Abductive Learning on the HWF example.
   abl - INFO - loop(train) [1/3] segment(train) [1/10]
   abl - INFO - model loss: 0.00024
   abl - INFO - loop(train) [1/3] segment(train) [2/10]
   abl - INFO - model loss: 0.00011
   abl - INFO - loop(train) [1/3] segment(train) [3/10]
   abl - INFO - model loss: 0.00332
   ...
   abl - INFO - Eval start: loop(val) [1]
   abl - INFO - Evaluation ended, hwf/character_accuracy: 0.997 hwf/reasoning_accuracy: 0.985
   abl - INFO - loop(train) [2/3] segment(train) [1/10]
   abl - INFO - model loss: 0.00126
   ...
   abl - INFO - Eval start: loop(val) [2]
   abl - INFO - Evaluation ended, hwf/character_accuracy: 0.998 hwf/reasoning_accuracy: 0.989
   abl - INFO - loop(train) [3/3] segment(train) [1/10]
   abl - INFO - model loss: 0.00030
   ...
   abl - INFO - Eval start: loop(val) [3]
   abl - INFO - Evaluation ended, hwf/character_accuracy: 0.999 hwf/reasoning_accuracy: 0.996
   abl - INFO - Test start:
   abl - INFO - Evaluation ended, hwf/character_accuracy: 0.997 hwf/reasoning_accuracy: 0.986
Environment
-----------

For all experiments, we used a single Linux server. Details of the specifications are listed in the table below.

.. raw:: html

   <style type="text/css">
   .tg {border-collapse:collapse;border-spacing:0;margin-bottom:20px;}
   .tg td, .tg th {border:1px solid #ddd;padding:8px 22px;text-align:center;}
   .tg th {background-color:#f5f5f5;color:#333333;}
   .tg tr:nth-child(even) {background-color:#f9f9f9;}
   .tg tr:nth-child(odd) {background-color:#ffffff;}
   </style>

   <table class="tg" style="margin-left: auto; margin-right: auto;">
   <thead>
   <tr>
      <th>CPU</th>
      <th>GPU</th>
      <th>Memory</th>
      <th>OS</th>
   </tr>
   </thead>
   <tbody>
   <tr>
      <td>2 * Xeon Platinum 8358, 32 Cores, 2.6 GHz Base Frequency</td>
      <td>A100 80GB</td>
      <td>512GB</td>
      <td>Ubuntu 20.04</td>
   </tr>
   </tbody>
   </table>
Performance
-----------

We present the results of ABL below, including the reasoning accuracy (for different equation lengths in the HWF dataset), training time (to achieve the accuracy using all equation lengths), and average memory usage (using all equation lengths). These results are compared with the following methods:

- `NGS <https://github.com/liqing-ustc/NGS>`_: a neural-symbolic framework that uses a grammar model and a back-search algorithm to improve its computing process;
- `DeepProbLog <https://github.com/ML-KULeuven/deepproblog/tree/master>`_: an extension of ProbLog that introduces neural predicates into Probabilistic Logic Programming;
- `DeepStochLog <https://github.com/ML-KULeuven/deepstochlog/tree/main>`_: a neural-symbolic framework based on stochastic logic programs.

.. raw:: html

   <style type="text/css">
   .tg {border-collapse:collapse;border-spacing:0;margin-bottom:20px;}
   .tg td, .tg th {border:1px solid #ddd;padding:10px 15px;text-align:center;}
   .tg th {background-color:#f5f5f5;color:#333333;}
   .tg tr:nth-child(even) {background-color:#f9f9f9;}
   .tg tr:nth-child(odd) {background-color:#ffffff;}
   </style>

   <table class="tg" style="margin-left: auto; margin-right: auto;">
   <thead>
   <tr>
      <th rowspan="2"></th>
      <th colspan="5">Reasoning Accuracy<br><span style="font-weight: normal; font-size: smaller;">(for different equation lengths)</span></th>
      <th rowspan="2">Training Time (s)<br><span style="font-weight: normal; font-size: smaller;">(to achieve the Acc. using all lengths)</span></th>
      <th rowspan="2">Average Memory Usage (MB)<br><span style="font-weight: normal; font-size: smaller;">(using all lengths)</span></th>
   </tr>
   <tr>
      <th>1</th>
      <th>3</th>
      <th>5</th>
      <th>7</th>
      <th>All</th>
   </tr>
   </thead>
   <tbody>
   <tr>
      <td>NGS</td>
      <td>91.2</td>
      <td>89.1</td>
      <td>92.7</td>
      <td>5.2</td>
      <td>98.4</td>
      <td>426.2</td>
      <td>3705</td>
   </tr>
   <tr>
      <td>DeepProbLog</td>
      <td>90.8</td>
      <td>85.6</td>
      <td>timeout*</td>
      <td>timeout</td>
      <td>timeout</td>
      <td>timeout</td>
      <td>4315</td>
   </tr>
   <tr>
      <td>DeepStochLog</td>
      <td>92.8</td>
      <td>87.5</td>
      <td>92.1</td>
      <td>timeout</td>
      <td>timeout</td>
      <td>timeout</td>
      <td>4355</td>
   </tr>
   <tr>
      <td>ABL</td>
      <td><span style="font-weight:bold">94.0</span></td>
      <td><span style="font-weight:bold">89.7</span></td>
      <td><span style="font-weight:bold">96.5</span></td>
      <td><span style="font-weight:bold">97.2</span></td>
      <td><span style="font-weight:bold">98.6</span></td>
      <td><span style="font-weight:bold">77.3</span></td>
      <td><span style="font-weight:bold">3074</span></td>
   </tr>
   </tbody>
   </table>
   <p style="font-size: 13px;">* timeout: needs more than 1 hour to execute</p>
