Handwritten Formula (HWF)
=========================

Below is an implementation of `Handwritten
Formula <https://arxiv.org/abs/2006.06649>`__. In this
task, handwritten images of decimal formulas and their computed results
are given, along with a domain knowledge base containing information on
how to compute the decimal formula. The task is to recognize the symbols
(which can be digits or operators ‘+’, ‘-’, ‘×’, ‘÷’) of handwritten
images and accurately determine their results.

Intuitively, we first use a machine learning model (learning part) to
convert the input images to symbols (we call them pseudo-labels), and
then use the knowledge base (reasoning part) to calculate the results of
these symbols. Since we do not have ground truth for the symbols, in
Abductive Learning, the reasoning part leverages domain knowledge
and revises the initial symbols yielded by the learning part through
abductive reasoning. This process enables us to further update the
machine learning model.
.. code:: ipython3

    # Import necessary libraries and modules
    import os.path as osp

    import numpy as np
    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt

    from datasets import get_dataset
    from models.nn import SymbolNet
    from abl.learning import ABLModel, BasicNN
    from abl.reasoning import KBBase, Reasoner
    from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
    from abl.utils import ABLLogger, print_log
    from abl.bridge import SimpleBridge
Working with Data
-----------------

First, we get the training and testing datasets:

.. code:: ipython3

    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

Both ``train_data`` and ``test_data`` share the same structure: tuples
with three components: X (a list where each element is a list of images),
gt_pseudo_label (a list where each element is a list of symbols, i.e.,
pseudo-labels), and Y (a list where each element is the computed result).
The lengths and structures of the datasets are illustrated as follows.
.. note::

    ``gt_pseudo_label`` is only used to evaluate the performance of
    the learning part, not to train the model.
.. code:: ipython3

    print(f"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y")
    print()

    train_X, train_gt_pseudo_label, train_Y = train_data
    print(f"Length of X, gt_pseudo_label, Y in train_data: " +
          f"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}")
    test_X, test_gt_pseudo_label, test_Y = test_data
    print(f"Length of X, gt_pseudo_label, Y in test_data: " +
          f"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}")
    print()

    X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
    print(f"X is a {type(train_X).__name__}, " +
          f"with each element being a {type(X_0).__name__} of {type(X_0[0]).__name__}.")
    print(f"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, " +
          f"with each element being a {type(gt_pseudo_label_0).__name__} " +
          f"of {type(gt_pseudo_label_0[0]).__name__}.")
    print(f"Y is a {type(train_Y).__name__}, " +
          f"with each element being a {type(Y_0).__name__}.")
Out:

.. code:: none
    :class: code-out

    Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y

    Length of X, gt_pseudo_label, Y in train_data: 10000, 10000, 10000
    Length of X, gt_pseudo_label, Y in test_data: 2000, 2000, 2000

    X is a list, with each element being a list of Tensor.
    gt_pseudo_label is a list, with each element being a list of str.
    Y is a list, with each element being a int.
The ith element of X, gt_pseudo_label, and Y together constitute the ith
data example. Here we use two of them (the 1001st and the 3001st) as
illustrations:
.. code:: ipython3

    X_1000, gt_pseudo_label_1000, Y_1000 = train_X[1000], train_gt_pseudo_label[1000], train_Y[1000]
    print(f"X in the 1001st data example (a list of images):")
    for i, x in enumerate(X_1000):
        plt.subplot(1, len(X_1000), i+1)
        plt.axis('off')
        plt.imshow(x.squeeze(), cmap='gray')
    plt.show()
    print(f"gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_1000}")
    print(f"Y in the 1001st data example (the computed result): {Y_1000}")
    print()

    X_3000, gt_pseudo_label_3000, Y_3000 = train_X[3000], train_gt_pseudo_label[3000], train_Y[3000]
    print(f"X in the 3001st data example (a list of images):")
    for i, x in enumerate(X_3000):
        plt.subplot(1, len(X_3000), i+1)
        plt.axis('off')
        plt.imshow(x.squeeze(), cmap='gray')
    plt.show()
    print(f"gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_3000}")
    print(f"Y in the 3001st data example (the computed result): {Y_3000}")
Out:

.. code:: none
    :class: code-out

    X in the 1001st data example (a list of images):

.. image:: ../img/hwf_dataset1.png
    :width: 210px

.. code:: none
    :class: code-out

    gt_pseudo_label in the 1001st data example (a list of pseudo-labels): ['5', '-', '3']
    Y in the 1001st data example (the computed result): 2

.. code:: none
    :class: code-out

    X in the 3001st data example (a list of images):

.. image:: ../img/hwf_dataset2.png
    :width: 350px

.. code:: none
    :class: code-out

    gt_pseudo_label in the 3001st data example (a list of pseudo-labels): ['4', '/', '6', '*', '5']
    Y in the 3001st data example (the computed result): 3.333333333333333
.. note::

    Each symbol in the HWF dataset is either a digit or one of the
    operators '+', '-', '×', '÷'.

We can see that, in the 1001st data example, the length of the
formula is 3, while in the 3001st data example, the length of the
formula is 5. In the HWF dataset, the length of a formula varies from
1 to 7.
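As a quick check (not part of the original example), we can count how many
training formulas there are of each length, using only the ``train_X`` list
loaded above:

.. code:: ipython3

    from collections import Counter

    # Each element of train_X is the list of symbol images in one formula,
    # so its length is the formula length.
    length_counts = Counter(len(x) for x in train_X)
    print(sorted(length_counts.items()))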
Building the Learning Part
--------------------------

To build the learning part, we first need to build a machine learning
base model. We use SymbolNet, and encapsulate it within a ``BasicNN``
object to create the base model. ``BasicNN`` is a class that
encapsulates a PyTorch model, transforming it into a base model with an
sklearn-style interface.
.. code:: ipython3

    # The class of a symbol may be one of ['1', ..., '9', '+', '-', '*', '/'], 13 classes in total
    cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    base_model = BasicNN(
        model=cls,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=device,
        batch_size=128,
        num_epochs=3,
    )
``BasicNN`` offers methods like ``predict`` and ``predict_proba``, which
are used to predict the class index and the probabilities of each class
for images, as shown below:
.. code:: ipython3

    data_instances = [torch.randn(1, 45, 45).to(device) for _ in range(32)]
    pred_idx = base_model.predict(X=data_instances)
    print(f"Predicted class index for a batch of 32 instances: " +
          f"{type(pred_idx).__name__} with shape {pred_idx.shape}")
    pred_prob = base_model.predict_proba(X=data_instances)
    print(f"Predicted class probabilities for a batch of 32 instances: " +
          f"{type(pred_prob).__name__} with shape {pred_prob.shape}")
Out:

.. code:: none
    :class: code-out

    Predicted class index for a batch of 32 instances: ndarray with shape (32,)
    Predicted class probabilities for a batch of 32 instances: ndarray with shape (32, 13)
However, the base model built above deals with instance-level data
(i.e., individual images), and cannot directly deal with example-level
data (i.e., a list of images comprising the formula). Therefore, we wrap
the base model into ``ABLModel``, which enables the learning part to
train, test, and predict on example-level data.
.. code:: ipython3

    model = ABLModel(base_model)
As an illustration, consider this example of making predictions on
example-level data using the ``predict`` method of ``ABLModel``. The
method accepts data examples as input and outputs the class labels and
the probabilities of each class for all instances within these data
examples.
.. code:: ipython3

    from abl.data.structures import ListData

    # ListData is a data structure provided by ABL-Package that can be used to organize data examples
    data_examples = ListData()
    # We use the 1001st and 3001st data examples in the training set as an illustration
    data_examples.X = [X_1000, X_3000]
    data_examples.gt_pseudo_label = [gt_pseudo_label_1000, gt_pseudo_label_3000]
    data_examples.Y = [Y_1000, Y_3000]

    # Perform prediction on the two data examples.
    # Recall that, in the 1001st data example, the length of the formula is 3,
    # while in the 3001st data example, the length of the formula is 5.
    predict_result = model.predict(data_examples)
    pred_label, pred_prob = predict_result["label"], predict_result["prob"]
    print(f"Predicted class labels for the two data examples: a list of length {len(pred_label)}, \n" +
          f"the first element is a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}, " +
          f"and the second element is a {type(pred_label[1]).__name__} of shape {pred_label[1].shape}.\n")
    print(f"Predicted class probabilities for the two data examples: a list of length {len(pred_prob)}, \n"
          f"the first element is a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}, " +
          f"and the second element is a {type(pred_prob[1]).__name__} of shape {pred_prob[1].shape}.")
Out:

.. code:: none
    :class: code-out

    Predicted class labels for the two data examples: a list of length 2,
    the first element is a ndarray of shape (3,), and the second element is a ndarray of shape (5,).

    Predicted class probabilities for the two data examples: a list of length 2,
    the first element is a ndarray of shape (3, 13), and the second element is a ndarray of shape (5, 13).
Building the Reasoning Part
---------------------------

In the reasoning part, we first build a knowledge base which contains
information on how to perform arithmetic operations. We build it by
creating a subclass of ``KBBase``. In the derived subclass, we
initialize the ``pseudo_label_list`` parameter, which specifies the list
of possible pseudo-labels, and override the ``logic_forward`` function,
which defines how to perform (deductive) reasoning.
.. code:: ipython3

    class HwfKB(KBBase):
        def __init__(self, pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"]):
            super().__init__(pseudo_label_list)

        # A formula is valid only if digits and operators alternate, starting and ending with a digit
        def _valid_candidate(self, formula):
            if len(formula) % 2 == 0:
                return False
            for i in range(len(formula)):
                if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
                    return False
                if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
                    return False
            return True

        # Implement the deduction function
        def logic_forward(self, formula):
            if not self._valid_candidate(formula):
                return np.inf
            return eval("".join(formula))

    kb = HwfKB()
The knowledge base can perform logical reasoning (both deductive
reasoning and abductive reasoning). Below is an example of performing
(deductive) reasoning; users can refer to :ref:`Performing abductive
reasoning in the knowledge base <kb-abd>` for details of abductive reasoning.
.. code:: ipython3

    pseudo_labels = ["1", "-", "2", "*", "5"]
    reasoning_result = kb.logic_forward(pseudo_labels)
    print(f"Reasoning result of pseudo-labels {pseudo_labels} is {reasoning_result}.")
Out:

.. code:: none
    :class: code-out

    Reasoning result of pseudo-labels ['1', '-', '2', '*', '5'] is -9.
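For a formula that violates the domain knowledge (for instance, one of even
length, which ``_valid_candidate`` rejects), ``logic_forward`` returns
``np.inf``. A quick check, not part of the original example:

.. code:: ipython3

    # ["1", "2", "+", "3"] has even length, so it is not a valid formula
    invalid_pseudo_labels = ["1", "2", "+", "3"]
    print(kb.logic_forward(invalid_pseudo_labels))  # prints inf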
.. note::

    In addition to building a knowledge base based on ``KBBase``, we
    can also establish a knowledge base with a ground KB using ``GroundKB``.
    The corresponding code can be found in the ``examples/hwf/main.py`` file. Those
    interested are encouraged to examine it for further insights.

    When building the knowledge base, we can also set the
    ``max_err`` parameter during initialization, as shown in the
    ``examples/hwf/main.py`` file. This parameter specifies the upper tolerance limit
    when comparing the similarity between the reasoning result of pseudo-labels and
    the ground truth during abductive reasoning, and its default
    value is 1e-10.
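As a minimal sketch (not taken from ``examples/hwf/main.py``), a variant of
``HwfKB`` that passes a looser error tolerance during initialization might
look as follows. We assume here that ``KBBase`` accepts ``max_err`` as a
keyword argument in its initializer, as described in the note above.

.. code:: ipython3

    class HwfKBWithTolerance(HwfKB):
        # Assumption: KBBase.__init__ accepts a ``max_err`` keyword (default 1e-10)
        def __init__(self, pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], max_err=1e-3):
            KBBase.__init__(self, pseudo_label_list, max_err=max_err)

    kb_tolerant = HwfKBWithTolerance()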
Then, we create a reasoner by instantiating the class ``Reasoner``. Due
to the indeterminism of abductive reasoning, there could be multiple
candidates compatible with the knowledge base. When this happens, the
reasoner minimizes the inconsistency between the knowledge base and the
pseudo-labels predicted by the learning part, and returns only the
candidate with the highest consistency.
.. code:: ipython3

    reasoner = Reasoner(kb)
.. note::

    When creating the reasoner, the definition of “consistency” can be
    customized via the ``dist_func`` parameter. In the code above, we
    employ a consistency measurement based on confidence, which calculates
    the consistency between a data example and candidates based on the
    confidence derived from the predicted probability. In ``examples/hwf/main.py``, we
    provide options for utilizing other forms of consistency measurement.

    Also, during the process of inconsistency minimization, we can
    leverage the `ZOOpt library <https://github.com/polixir/ZOOpt>`__ for
    acceleration. Options for this are also available in ``examples/hwf/main.py``. Those
    interested are encouraged to explore these features.
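As a hedged sketch of such a customization (the option name ``"hamming"`` is
an assumption; see ``examples/hwf/main.py`` for the options that are actually
supported), a reasoner using a different consistency measurement might be
constructed like this:

.. code:: ipython3

    # Assumption: "hamming" is one of the supported dist_func options,
    # measuring consistency by the number of differing pseudo-labels
    reasoner_hamming = Reasoner(kb, dist_func="hamming")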
Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolAccuracy`` and ``ReasoningMetric``, which are
used to evaluate the accuracy of the machine learning model’s
predictions and the accuracy of the final reasoning results,
respectively.

.. code:: ipython3

    metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]
Bridge Learning and Reasoning
-----------------------------

Now, the last step is to bridge the learning and reasoning parts. We
do this by creating an instance of ``SimpleBridge``.

.. code:: ipython3

    bridge = SimpleBridge(model, reasoner, metric_list)

Perform training and testing by invoking the ``train`` and ``test``
methods of ``SimpleBridge``.

.. code:: ipython3

    # Build logger
    print_log("Abductive Learning on the HWF example.", logger="current")
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    bridge.train(train_data, train_data, loops=3, segment_size=1000, save_dir=weights_dir)
    bridge.test(test_data)
Out:

.. code:: none
    :class: code-out

    abl - INFO - Abductive Learning on the HWF example.
    abl - INFO - loop(train) [1/3] segment(train) [1/10]
    abl - INFO - model loss: 0.00024
    abl - INFO - loop(train) [1/3] segment(train) [2/10]
    abl - INFO - model loss: 0.00053
    abl - INFO - loop(train) [1/3] segment(train) [3/10]
    abl - INFO - model loss: 0.00260
    abl - INFO - loop(train) [1/3] segment(train) [4/10]
    abl - INFO - model loss: 0.00162
    abl - INFO - loop(train) [1/3] segment(train) [5/10]
    abl - INFO - model loss: 0.00073
    abl - INFO - loop(train) [1/3] segment(train) [6/10]
    abl - INFO - model loss: 0.00055
    abl - INFO - loop(train) [1/3] segment(train) [7/10]
    abl - INFO - model loss: 0.00148
    abl - INFO - loop(train) [1/3] segment(train) [8/10]
    abl - INFO - model loss: 0.00034
    abl - INFO - loop(train) [1/3] segment(train) [9/10]
    abl - INFO - model loss: 0.00167
    abl - INFO - loop(train) [1/3] segment(train) [10/10]
    abl - INFO - model loss: 0.00185
    abl - INFO - Evaluation start: loop(val) [1]
    abl - INFO - Evaluation ended, hwf/character_accuracy: 1.000 hwf/reasoning_accuracy: 0.999
    abl - INFO - Saving model: loop(save) [1]
    abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_1.pth
    abl - INFO - loop(train) [2/3] segment(train) [1/10]
    abl - INFO - model loss: 0.00219
    abl - INFO - loop(train) [2/3] segment(train) [2/10]
    abl - INFO - model loss: 0.00069
    abl - INFO - loop(train) [2/3] segment(train) [3/10]
    abl - INFO - model loss: 0.00013
    abl - INFO - loop(train) [2/3] segment(train) [4/10]
    abl - INFO - model loss: 0.00013
    abl - INFO - loop(train) [2/3] segment(train) [5/10]
    abl - INFO - model loss: 0.00248
    abl - INFO - loop(train) [2/3] segment(train) [6/10]
    abl - INFO - model loss: 0.00010
    abl - INFO - loop(train) [2/3] segment(train) [7/10]
    abl - INFO - model loss: 0.00020
    abl - INFO - loop(train) [2/3] segment(train) [8/10]
    abl - INFO - model loss: 0.00076
    abl - INFO - loop(train) [2/3] segment(train) [9/10]
    abl - INFO - model loss: 0.00061
    abl - INFO - loop(train) [2/3] segment(train) [10/10]
    abl - INFO - model loss: 0.00117
    abl - INFO - Evaluation start: loop(val) [2]
    abl - INFO - Evaluation ended, hwf/character_accuracy: 1.000 hwf/reasoning_accuracy: 1.000
    abl - INFO - Saving model: loop(save) [2]
    abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_2.pth
    abl - INFO - loop(train) [3/3] segment(train) [1/10]
    abl - INFO - model loss: 0.00120
    abl - INFO - loop(train) [3/3] segment(train) [2/10]
    abl - INFO - model loss: 0.00114
    abl - INFO - loop(train) [3/3] segment(train) [3/10]
    abl - INFO - model loss: 0.00071
    abl - INFO - loop(train) [3/3] segment(train) [4/10]
    abl - INFO - model loss: 0.00027
    abl - INFO - loop(train) [3/3] segment(train) [5/10]
    abl - INFO - model loss: 0.00017
    abl - INFO - loop(train) [3/3] segment(train) [6/10]
    abl - INFO - model loss: 0.00018
    abl - INFO - loop(train) [3/3] segment(train) [7/10]
    abl - INFO - model loss: 0.00141
    abl - INFO - loop(train) [3/3] segment(train) [8/10]
    abl - INFO - model loss: 0.00099
    abl - INFO - loop(train) [3/3] segment(train) [9/10]
    abl - INFO - model loss: 0.00145
    abl - INFO - loop(train) [3/3] segment(train) [10/10]
    abl - INFO - model loss: 0.00215
    abl - INFO - Evaluation start: loop(val) [3]
    abl - INFO - Evaluation ended, hwf/character_accuracy: 1.000 hwf/reasoning_accuracy: 1.000
    abl - INFO - Saving model: loop(save) [3]
    abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_2.pth
    abl - INFO - Evaluation ended, hwf/character_accuracy: 0.996 hwf/reasoning_accuracy: 0.977
More concrete examples are available in ``examples/hwf/main.py`` and ``examples/hwf/hwf.ipynb``.
