Handwritten Equation Decipherment (HED)
=======================================

.. raw:: html

    <p>For detailed code implementation, please view it on <a class="reference external" href="https://github.com/AbductiveLearning/ABLKit/tree/main/examples/hed" target="_blank">GitHub</a>.</p>

Below is an implementation of `Handwritten Equation
Decipherment <https://proceedings.neurips.cc/paper_files/paper/2019/file/9c19a2aa1d84e04b0bd4bc888792bd1e-Paper.pdf>`__.
In this task, we are given handwritten equations, each of which is a
sequence of character images. The equations are generated with
unknown operation rules from images of symbols (‘0’, ‘1’, ‘+’ and ‘=’),
and each equation is associated with a label indicating whether the
equation is correct (i.e., positive) or not (i.e., negative). We are
also given a knowledge base that describes the structure of the equations
and a recursive definition of bit-wise operations. The task is to learn
from a training set of such equations and then predict the labels of
unseen equations.

Intuitively, we first use a machine learning model (the learning part) to
obtain pseudo-labels (‘0’, ‘1’, ‘+’ and ‘=’) for the observed
pictures. We then use the knowledge base (the reasoning part) to perform
abductive reasoning, which yields ground hypotheses as possible
explanations of the observed facts and suggests which pseudo-labels
should be revised. This process enables us to further update the machine
learning model.

.. code:: ipython3

    # Import necessary libraries and modules
    import os.path as osp

    import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn

    # SymbolAccuracy and ReasoningMetric are needed when building the metrics
    from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy
    from ablkit.learning import ABLModel, BasicNN
    from ablkit.utils import ABLLogger, print_log

    from bridge import HedBridge
    from consistency_metric import ConsistencyMetric
    from datasets import get_dataset, split_equation
    from models.nn import SymbolNet
    from reasoning import HedKB, HedReasoner

Working with Data
-----------------

First, we get the datasets of handwritten equations:

.. code:: ipython3

    total_train_data = get_dataset(train=True)
    train_data, val_data = split_equation(total_train_data, 3, 1)
    test_data = get_dataset(train=False)

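
Judging by the counts printed further down (225 training and 75 validation
equations per length), the ``3, 1`` arguments appear to act as a 3:1
train/validation ratio. The sketch below illustrates such a split under that
assumption. ``split_by_ratio`` is a hypothetical helper, not the
``split_equation`` function from the example repository, and it assumes the
dataset is a dict of the form ``{label: {length: [equation, ...]}}``.

.. code:: ipython3

    def split_by_ratio(dataset, train_parts, val_parts):
        """Split every list of equations into a training and a validation part.

        `dataset` is assumed to look like {label: {length: [equation, ...]}}.
        """
        train, val = {}, {}
        for label, by_length in dataset.items():
            train[label], val[label] = {}, {}
            for length, equations in by_length.items():
                # Index at which the training portion ends.
                cut = len(equations) * train_parts // (train_parts + val_parts)
                train[label][length] = equations[:cut]
                val[label][length] = equations[cut:]
        return train, val


    # Toy data: 8 placeholder "equations" per label and length.
    toy = {label: {5: list(range(8)), 6: list(range(8))} for label in (0, 1)}
    toy_train, toy_val = split_by_ratio(toy, 3, 1)
    print(len(toy_train[1][5]), len(toy_val[1][5]))  # 6 2
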
The datasets are shown below:

.. code:: ipython3

    true_train_equation = train_data[1]
    false_train_equation = train_data[0]
    print(f"Equations in the dataset are organized by equation length, " +
          f"from {min(train_data[0].keys())} to {max(train_data[0].keys())}.")
    print()

    true_train_equation_with_length_5 = true_train_equation[5]
    false_train_equation_with_length_5 = false_train_equation[5]
    print(f"For each equation length, there are {len(true_train_equation_with_length_5)} " +
          f"true equations and {len(false_train_equation_with_length_5)} false equations " +
          f"in the training set.")

    true_val_equation = val_data[1]
    false_val_equation = val_data[0]
    true_val_equation_with_length_5 = true_val_equation[5]
    false_val_equation_with_length_5 = false_val_equation[5]
    print(f"For each equation length, there are {len(true_val_equation_with_length_5)} " +
          f"true equations and {len(false_val_equation_with_length_5)} false equations " +
          f"in the validation set.")

    true_test_equation = test_data[1]
    false_test_equation = test_data[0]
    true_test_equation_with_length_5 = true_test_equation[5]
    false_test_equation_with_length_5 = false_test_equation[5]
    print(f"For each equation length, there are {len(true_test_equation_with_length_5)} " +
          f"true equations and {len(false_test_equation_with_length_5)} false equations " +
          f"in the test set.")

Out:

.. code:: none
    :class: code-out

    Equations in the dataset are organized by equation length, from 5 to 26.

    For each equation length, there are 225 true equations and 225 false equations in the training set.
    For each equation length, there are 75 true equations and 75 false equations in the validation set.
    For each equation length, there are 300 true equations and 300 false equations in the test set.

As illustrations, we show four equations in the training dataset:

.. code:: ipython3

    true_train_equation_with_length_5 = true_train_equation[5]
    true_train_equation_with_length_8 = true_train_equation[8]
    print("First true equation with length 5 in the training dataset:")
    for i, x in enumerate(true_train_equation_with_length_5[0]):
        plt.subplot(1, 5, i + 1)
        plt.axis('off')
        plt.imshow(x.squeeze(), cmap='gray')
    plt.show()
    print("First true equation with length 8 in the training dataset:")
    for i, x in enumerate(true_train_equation_with_length_8[0]):
        plt.subplot(1, 8, i + 1)
        plt.axis('off')
        plt.imshow(x.squeeze(), cmap='gray')
    plt.show()

    false_train_equation_with_length_5 = false_train_equation[5]
    false_train_equation_with_length_8 = false_train_equation[8]
    print("First false equation with length 5 in the training dataset:")
    for i, x in enumerate(false_train_equation_with_length_5[0]):
        plt.subplot(1, 5, i + 1)
        plt.axis('off')
        plt.imshow(x.squeeze(), cmap='gray')
    plt.show()
    print("First false equation with length 8 in the training dataset:")
    for i, x in enumerate(false_train_equation_with_length_8[0]):
        plt.subplot(1, 8, i + 1)
        plt.axis('off')
        plt.imshow(x.squeeze(), cmap='gray')
    plt.show()

Out:

.. code:: none
    :class: code-out

    First true equation with length 5 in the training dataset:

.. image:: ../_static/img/hed_dataset1.png
    :width: 300px

.. code:: none
    :class: code-out

    First true equation with length 8 in the training dataset:

.. image:: ../_static/img/hed_dataset2.png
    :width: 480px

.. code:: none
    :class: code-out

    First false equation with length 5 in the training dataset:

.. image:: ../_static/img/hed_dataset3.png
    :width: 300px

.. code:: none
    :class: code-out

    First false equation with length 8 in the training dataset:

.. image:: ../_static/img/hed_dataset4.png
    :width: 480px

Building the Learning Part
--------------------------

To build the learning part, we first need to build a machine learning
base model. We use SymbolNet and encapsulate it within a ``BasicNN``
object to create the base model. ``BasicNN`` is a class that
encapsulates a PyTorch model, transforming it into a base model with an
sklearn-style interface.

.. code:: ipython3

    # Each symbol is one of ['0', '1', '+', '='], so there are 4 classes in total
    cls = SymbolNet(num_classes=4)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        device=device,
        batch_size=32,
        num_epochs=1,
        stop_loss=None,
    )

However, the base model built above deals with instance-level data
(i.e., individual images) and cannot directly deal with example-level
data (i.e., a list of images comprising an equation). Therefore, we
wrap the base model into ``ABLModel``, which enables the learning part
to train, test, and predict on example-level data.

.. code:: ipython3

    model = ABLModel(base_model)

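
The difference between instance-level and example-level data can be
illustrated with a small, framework-free sketch. It is not the implementation
of ``ABLModel``; ``predict_examples`` and ``toy_predict`` are hypothetical
names, and characters stand in for images. The idea shown is simply to
flatten all examples into one batch of instances, predict, and regroup the
predictions per example.

.. code:: ipython3

    def predict_examples(predict_instances, examples):
        """Apply an instance-level predictor to example-level data.

        `examples` is a list of examples; each example is a list of instances.
        """
        # Flatten all instances into one batch for the base model.
        flat = [instance for example in examples for instance in example]
        flat_labels = predict_instances(flat)

        # Regroup the flat predictions so they line up with the examples.
        grouped, start = [], 0
        for example in examples:
            grouped.append(flat_labels[start:start + len(example)])
            start += len(example)
        return grouped


    SYMBOLS = ["0", "1", "+", "="]


    def toy_predict(batch):
        """Stand-in model: map each character to its class index."""
        return [SYMBOLS.index(ch) for ch in batch]


    toy_examples = [list("1+1=10"), list("0+1=1")]
    print(predict_examples(toy_predict, toy_examples))
    # [[1, 2, 1, 3, 1, 0], [0, 2, 1, 3, 1]]
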
Building the Reasoning Part
---------------------------

In the reasoning part, we first build a knowledge base. As mentioned
before, the knowledge base in this task involves the structure of the
equations and a recursive definition of bit-wise operations, which are
defined in the Prolog files ``examples/hed/reasoning/BK.pl``
and ``examples/hed/reasoning/learn_add.pl``, respectively.
Specifically, the knowledge about the structure of equations is a set of DCG
rules that recursively define a number as a sequence of ‘0’ and ‘1’ digits
and state that equations share the structure X+Y=Z, though the lengths of
X, Y and Z may vary. The knowledge about bit-wise operations is a recursive
logic program that calculates X+Y in reverse, i.e., it operates on
X and Y digit by digit, from the last digit to the first.

The knowledge base is already built in ``HedKB``.
``HedKB`` is derived from the class ``PrologKB`` and is built upon the
aforementioned Prolog files.

.. code:: ipython3

    kb = HedKB()

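
To give a feel for what "operating digit by digit, from the last digit to the
first" means, here is a minimal Python sketch of addition driven by a table of
per-digit rules. It is not a transcription of ``learn_add.pl``:
``add_reversed``, ``add_bitwise`` and ``binary_rules`` are hypothetical names,
and the way a carry is handled (added to the remaining digits with the same
rules) is an assumption of this sketch.

.. code:: ipython3

    def add_reversed(xs, ys, rules):
        """Add two digit lists given least-significant digit first."""
        if not xs:
            return list(ys)
        if not ys:
            return list(xs)
        result = rules[(xs[0], ys[0])]  # e.g. "1" or "10"
        digit, carry = result[-1], result[:-1]
        rest = add_reversed(xs[1:], ys[1:], rules)
        if carry:
            # Assumption: a carry is added to the remaining digits
            # using the same rule table.
            rest = add_reversed(rest, list(carry), rules)
        return [digit] + rest


    def add_bitwise(x, y, rules):
        """Add digit strings x and y, working from the last digit to the first."""
        reversed_sum = add_reversed(list(x)[::-1], list(y)[::-1], rules)
        return "".join(reversed(reversed_sum))


    # Ordinary binary addition, used here only as one possible rule set.
    binary_rules = {
        ("0", "0"): "0",
        ("0", "1"): "1",
        ("1", "0"): "1",
        ("1", "1"): "10",
    }
    print(add_bitwise("1", "1", binary_rules))    # 10
    print(add_bitwise("11", "11", binary_rules))  # 110

Swapping in a different rule table changes the result of every equation,
which is why the rules themselves have to be recovered from the data.
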
.. note::

    Note that the specific rules for calculating the
    operations are undefined in the knowledge base, i.e., the results of ‘0+0’,
    ‘0+1’ and ‘1+1’ could be ‘0’, ‘1’, ‘00’, ‘01’ or even ‘10’. The missing
    calculation rules must be learned from the data. Therefore,
    ``HedKB`` incorporates methods for abducing rules from data. Interested
    users can refer to the specific implementation of ``HedKB`` in
    ``examples/hed/reasoning/reasoning.py``.

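
As a rough illustration of what abducing rules from data involves, the
following brute-force sketch enumerates every candidate rule table and keeps
those that agree with a few positive equations. It is not how ``HedKB`` works
(which relies on Prolog): ``consistent_rule_sets`` and its constants are
hypothetical names, and the sketch only handles equations with single-digit
operands, so each positive equation directly fixes one rule.

.. code:: ipython3

    from itertools import product

    OPERAND_PAIRS = [("0", "0"), ("0", "1"), ("1", "1")]
    CANDIDATE_RESULTS = ["0", "1", "00", "01", "10"]


    def agrees(rules, equation):
        """Check one single-digit equation such as "1+1=10" against a rule table."""
        left, z = equation.split("=")
        x, y = left.split("+")
        key = tuple(sorted((x, y)))  # treat x+y and y+x alike
        return rules[key] == z


    def consistent_rule_sets(positive_equations):
        """Return every rule table that agrees with all the given equations."""
        survivors = []
        for results in product(CANDIDATE_RESULTS, repeat=len(OPERAND_PAIRS)):
            rules = dict(zip(OPERAND_PAIRS, results))
            if all(agrees(rules, eq) for eq in positive_equations):
                survivors.append(rules)
        return survivors


    print(len(consistent_rule_sets([])))          # 125
    print(len(consistent_rule_sets(["1+1=10"])))  # 25
    print(consistent_rule_sets(["0+0=0", "1+0=1", "1+1=10"]))
    # [{('0', '0'): '0', ('0', '1'): '1', ('1', '1'): '10'}]

Each additional positive equation shrinks the set of surviving rule tables,
which is the intuition behind learning the missing rules from labelled
equations.
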
Then, we create a reasoner. Due to the indeterminism of abductive
reasoning, there could be multiple candidates compatible with the
knowledge base. When this happens, the reasoner minimizes the inconsistency
between the knowledge base and the pseudo-labels predicted by the learning
part, and returns only the candidate with the highest consistency.

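
The selection step can be pictured with a small sketch that uses Hamming
distance as the inconsistency measure. This is only an illustration:
``hamming_distance`` and ``pick_most_consistent`` are hypothetical helpers,
and the list of candidates is made up rather than produced by a knowledge
base.

.. code:: ipython3

    def hamming_distance(a, b):
        """Number of positions at which two equal-length sequences differ."""
        return sum(x != y for x, y in zip(a, b))


    def pick_most_consistent(predicted, candidates):
        """Return the candidate closest to the predicted pseudo-labels."""
        return min(candidates, key=lambda c: hamming_distance(predicted, c))


    # Pseudo-labels from the learning part (one symbol is wrong).
    predicted = list("1+1=11")
    # Hypothetical revisions that are compatible with the knowledge base.
    candidates = [list("1+1=10"), list("0+0=00"), list("1+0=01")]

    print("".join(pick_most_consistent(predicted, candidates)))  # 1+1=10
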
In this task, we create the reasoner by instantiating the class
``HedReasoner``, which is a reasoner derived from ``Reasoner`` and
tailored specifically for this task. ``HedReasoner`` leverages the `ZOOpt
library <https://github.com/polixir/ZOOpt>`__ for acceleration and
implements a specific strategy to better harness ZOOpt’s capabilities.
Additionally, methods for abducing rules from data have been
incorporated. Interested users can refer to the specific implementation
of ``HedReasoner`` in ``examples/hed/reasoning/reasoning.py``.

.. code:: ipython3

    reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=10)

Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolAccuracy`` and ``ReasoningMetric``, which are
used to evaluate the accuracy of the machine learning model’s
predictions and the accuracy of the final reasoning results,
respectively.

.. code:: ipython3

    # Set up metrics
    metric_list = [SymbolAccuracy(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")]

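
For intuition, the accuracy of symbol predictions can be computed as the
fraction of symbols whose predicted label matches the true label. The sketch
below is a plain-Python illustration under that assumption, not the
implementation of ``SymbolAccuracy``; ``symbol_accuracy`` is a hypothetical
helper.

.. code:: ipython3

    def symbol_accuracy(predicted, ground_truth):
        """Fraction of symbols whose predicted label matches the true label."""
        pairs = [
            (p, t)
            for pred_eq, true_eq in zip(predicted, ground_truth)
            for p, t in zip(pred_eq, true_eq)
        ]
        return sum(p == t for p, t in pairs) / len(pairs)


    pred_labels = [list("1+1=10"), list("0+1=0")]
    true_labels = [list("1+1=10"), list("0+1=1")]
    # 10 of the 11 symbols match.
    print(symbol_accuracy(pred_labels, true_labels))
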
Bridging Learning and Reasoning
-------------------------------

Now, the last step is to bridge the learning and reasoning parts. We
proceed with this step by creating an instance of ``HedBridge``, which is
derived from ``SimpleBridge`` and tailored specifically for this task.

.. code:: ipython3

    bridge = HedBridge(model, reasoner, metric_list)

Perform pretraining, training and testing by invoking the ``pretrain``, ``train`` and ``test`` methods of ``HedBridge``.

.. code:: ipython3

    # Build logger
    print_log("Abductive Learning on the HED example.", logger="current")

    # Retrieve the directory of the log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    bridge.pretrain("./weights")
    bridge.train(train_data, val_data, save_dir=weights_dir)
    bridge.test(test_data)
