You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hed_example.ipynb 4.3 kB

Lines 1–191
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": null,
  6. "metadata": {},
  7. "outputs": [],
  8. "source": [
  9. "import sys\n",
  10. "\n",
  11. "sys.path.append(\"../../\")\n",
  12. "\n",
  13. "import torch.nn as nn\n",
  14. "import torch\n",
  15. "\n",
  16. "from abl.abducer.abducer_base import HED_Abducer\n",
  17. "from abl.abducer.kb import HED_prolog_KB\n",
  18. "\n",
  19. "from abl.utils.plog import logger\n",
  20. "from abl.models.basic_model import BasicModel\n",
  21. "from abl.models.wabl_models import WABLBasicModel\n",
  22. "\n",
  23. "from models.nn import SymbolNet\n",
  24. "from datasets.get_hed import get_hed, split_equation\n",
  25. "import framework_hed"
  26. ]
  27. },
  28. {
  29. "cell_type": "code",
  30. "execution_count": null,
  31. "metadata": {},
  32. "outputs": [],
  33. "source": [
  34. "# Initialize logger\n",
  35. "recorder = logger()"
  36. ]
  37. },
  38. {
  39. "attachments": {},
  40. "cell_type": "markdown",
  41. "metadata": {},
  42. "source": [
  43. "### Logic Part"
  44. ]
  45. },
  46. {
  47. "cell_type": "code",
  48. "execution_count": null,
  49. "metadata": {},
  50. "outputs": [],
  51. "source": [
  52. "# Initialize knowledge base and abducer\n",
  53. "kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/learn_add.pl')\n",
  54. "abducer = HED_Abducer(kb)"
  55. ]
  56. },
  57. {
  58. "attachments": {},
  59. "cell_type": "markdown",
  60. "metadata": {},
  61. "source": [
  62. "### Machine Learning Part"
  63. ]
  64. },
  65. {
  66. "cell_type": "code",
  67. "execution_count": null,
  68. "metadata": {},
  69. "outputs": [],
  70. "source": [
  71. "# Initialize necessary component for machine learning part\n",
  72. "cls = SymbolNet(\n",
  73. " num_classes=len(kb.pseudo_label_list),\n",
  74. " image_size=(28, 28, 1),\n",
  75. ")\n",
  76. "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
  77. "criterion = nn.CrossEntropyLoss()\n",
  78. "optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)"
  79. ]
  80. },
  81. {
  82. "cell_type": "code",
  83. "execution_count": null,
  84. "metadata": {},
  85. "outputs": [],
  86. "source": [
  87. "# Pretrain NN classifier\n",
  88. "framework_hed.hed_pretrain(kb, cls, recorder)"
  89. ]
  90. },
  91. {
  92. "cell_type": "code",
  93. "execution_count": null,
  94. "metadata": {},
  95. "outputs": [],
  96. "source": [
  97. "# Initialize BasicModel\n",
  98. "# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n",
  99. "base_model = BasicModel(\n",
  100. " cls,\n",
  101. " criterion,\n",
  102. " optimizer,\n",
  103. " device,\n",
  104. " save_interval=1,\n",
  105. " save_dir=recorder.save_dir,\n",
  106. " batch_size=32,\n",
  107. " num_epochs=1,\n",
  108. " recorder=recorder,\n",
  109. ")"
  110. ]
  111. },
  112. {
  113. "attachments": {},
  114. "cell_type": "markdown",
  115. "metadata": {},
  116. "source": [
  117. "### Use the WABL model to join the two parts"
  118. ]
  119. },
  120. {
  121. "cell_type": "code",
  122. "execution_count": null,
  123. "metadata": {},
  124. "outputs": [],
  125. "source": [
  126. "model = WABLBasicModel(base_model, kb.pseudo_label_list)"
  127. ]
  128. },
  129. {
  130. "attachments": {},
  131. "cell_type": "markdown",
  132. "metadata": {},
  133. "source": [
  134. "### Dataset"
  135. ]
  136. },
  137. {
  138. "cell_type": "code",
  139. "execution_count": null,
  140. "metadata": {},
  141. "outputs": [],
  142. "source": [
  143. "total_train_data = get_hed(train=True)\n",
  144. "train_data, val_data = split_equation(total_train_data, 3, 1)\n",
  145. "test_data = get_hed(train=False)"
  146. ]
  147. },
  148. {
  149. "attachments": {},
  150. "cell_type": "markdown",
  151. "metadata": {},
  152. "source": [
  153. "### Train, test, and save"
  154. ]
  155. },
  156. {
  157. "cell_type": "code",
  158. "execution_count": null,
  159. "metadata": {},
  160. "outputs": [],
  161. "source": [
  162. "model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)\n",
  163. "framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)\n",
  164. "\n",
  165. "recorder.dump()"
  166. ]
  167. }
  168. ],
  169. "metadata": {
  170. "kernelspec": {
  171. "display_name": "ABL",
  172. "language": "python",
  173. "name": "python3"
  174. },
  175. "language_info": {
  176. "codemirror_mode": {
  177. "name": "ipython",
  178. "version": 3
  179. },
  180. "file_extension": ".py",
  181. "mimetype": "text/x-python",
  182. "name": "python",
  183. "nbconvert_exporter": "python",
  184. "pygments_lexer": "ipython3",
  185. "version": "3.8.16"
  186. },
  187. "orig_nbformat": 4
  188. },
  189. "nbformat": 4,
  190. "nbformat_minor": 2
  191. }

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.