You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hed_example.ipynb 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": null,
  6. "metadata": {},
  7. "outputs": [],
  8. "source": [
  9. "import os.path as osp\n",
  10. "\n",
  11. "import numpy as np\n",
  12. "import torch\n",
  13. "import torch.nn as nn\n",
  14. "from zoopt import Dimension, Objective, Opt, Parameter\n",
  15. "\n",
  16. "from abl.evaluation import ReasoningMetric, SymbolMetric\n",
  17. "from abl.learning import ABLModel, BasicNN\n",
  18. "from abl.reasoning import PrologKB, Reasoner\n",
  19. "from abl.utils import ABLLogger, print_log, reform_list\n",
  20. "from examples.hed.datasets.get_hed import get_hed, split_equation\n",
  21. "from examples.hed.hed_bridge import HEDBridge\n",
  22. "from examples.models.nn import SymbolNet"
  23. ]
  24. },
  25. {
  26. "cell_type": "code",
  27. "execution_count": null,
  28. "metadata": {},
  29. "outputs": [],
  30. "source": [
# Build logger
print_log("Abductive Learning on the HED example.", logger="current")

# Retrieve the directory of the log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
# weights_dir is the "weights" sub-directory of log_dir; it is passed to BasicNN as save_dir below.
weights_dir = osp.join(log_dir, "weights")
  37. ]
  38. },
  39. {
  40. "attachments": {},
  41. "cell_type": "markdown",
  42. "metadata": {},
  43. "source": [
  44. "### Logic Part"
  45. ]
  46. },
  47. {
  48. "cell_type": "code",
  49. "execution_count": null,
  50. "metadata": {},
  51. "outputs": [],
  52. "source": [
  53. "# Initialize knowledge base and abducer\n",
  54. "class HedKB(PrologKB):\n",
  55. " def __init__(self, pseudo_label_list, pl_file):\n",
  56. " super().__init__(pseudo_label_list, pl_file)\n",
  57. "\n",
  58. " def consist_rule(self, exs, rules):\n",
  59. " rules = str(rules).replace(\"'\", \"\")\n",
  60. " return len(list(self.prolog.query(\"eval_inst_feature(%s, %s).\" % (exs, rules)))) != 0\n",
  61. "\n",
  62. " def abduce_rules(self, pred_res):\n",
  63. " prolog_result = list(self.prolog.query(\"consistent_inst_feature(%s, X).\" % pred_res))\n",
  64. " if len(prolog_result) == 0:\n",
  65. " return None\n",
  66. " prolog_rules = prolog_result[0][\"X\"]\n",
  67. " rules = [rule.value for rule in prolog_rules]\n",
  68. " return rules\n",
  69. "\n",
  70. "\n",
  71. "class HedReasoner(Reasoner):\n",
  72. " def revise_at_idx(self, data_example):\n",
  73. " revision_idx = np.where(np.array(data_example.flatten(\"revision_flag\")) != 0)[0]\n",
  74. " candidate = self.kb.revise_at_idx(\n",
  75. " data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx\n",
  76. " )\n",
  77. " return candidate\n",
  78. "\n",
  79. " def zoopt_revision_score(self, symbol_num, data_example, sol):\n",
  80. " revision_flag = reform_list(\n",
  81. " list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label\n",
  82. " )\n",
  83. " data_example.revision_flag = revision_flag\n",
  84. "\n",
  85. " lefted_idxs = [i for i in range(len(data_example.pred_idx))]\n",
  86. " candidate_size = []\n",
  87. " max_consistent_idxs = []\n",
  88. " while lefted_idxs:\n",
  89. " idxs = []\n",
  90. " idxs.append(lefted_idxs.pop(0))\n",
  91. " max_candidate_idxs = []\n",
  92. " found = False\n",
  93. " for idx in range(-1, len(data_example.pred_idx)):\n",
  94. " if (not idx in idxs) and (idx >= 0):\n",
  95. " idxs.append(idx)\n",
  96. " candidates, _ = self.revise_at_idx(data_example[idxs])\n",
  97. " if len(candidates) == 0:\n",
  98. " if len(idxs) > 1:\n",
  99. " idxs.pop()\n",
  100. " else:\n",
  101. " if len(idxs) > len(max_candidate_idxs):\n",
  102. " found = True\n",
  103. " max_candidate_idxs = idxs.copy()\n",
  104. " removed = [i for i in lefted_idxs if i in max_candidate_idxs]\n",
  105. " if found:\n",
  106. " removed.insert(0, idxs[0])\n",
  107. " candidate_size.append(len(removed))\n",
  108. " max_consistent_idxs = max_candidate_idxs.copy()\n",
  109. " lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]\n",
  110. " candidate_size.sort()\n",
  111. " score = 0\n",
  112. " import math\n",
  113. "\n",
  114. " for i in range(0, len(candidate_size)):\n",
  115. " score -= math.exp(-i) * candidate_size[i]\n",
  116. " return score, max_consistent_idxs\n",
  117. " \n",
  118. " def _zoopt_get_solution(self, symbol_num, data_example, max_revision_num):\n",
  119. " dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)\n",
  120. " objective = Objective(\n",
  121. " lambda sol: self.zoopt_revision_score(symbol_num, data_example, sol)[0],\n",
  122. " dim=dimension,\n",
  123. " constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),\n",
  124. " )\n",
  125. " parameter = Parameter(budget=200, intermediate_result=False, autoset=True)\n",
  126. " solution = Opt.min(objective, parameter)\n",
  127. " return solution\n",
  128. "\n",
  129. " def abduce(self, data_example):\n",
  130. " symbol_num = data_example.elements_num(\"pred_pseudo_label\")\n",
  131. " max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n",
  132. "\n",
  133. " solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)\n",
  134. " _, max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_example, solution)\n",
  135. "\n",
  136. " abduced_pseudo_label = [[] for _ in range(len(data_example))]\n",
  137. "\n",
  138. " if len(max_candidate_idxs) > 0:\n",
  139. " candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs])\n",
  140. " for i, idx in enumerate(max_candidate_idxs):\n",
  141. " abduced_pseudo_label[idx] = candidates[0][i]\n",
  142. " data_example.abduced_pseudo_label = abduced_pseudo_label\n",
  143. " return abduced_pseudo_label\n",
  144. "\n",
  145. " def abduce_rules(self, pred_res):\n",
  146. " return self.kb.abduce_rules(pred_res)\n",
  147. "\n",
  148. "\n",
# Knowledge base over the four pseudo-labels 1, 0, "+" and "=", backed by the Prolog file below.
# NOTE(review): pl_file is a relative path; presumably resolved against the working directory. Confirm.
kb = HedKB(pseudo_label_list=[1, 0, "+", "="], pl_file="./datasets/learn_add.pl")
# Reasoner over that knowledge base, with dist_func="hamming", use_zoopt enabled and max_revision=10.
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=10)
  151. ]
  152. },
  153. {
  154. "attachments": {},
  155. "cell_type": "markdown",
  156. "metadata": {},
  157. "source": [
  158. "### Machine Learning Part"
  159. ]
  160. },
  161. {
  162. "cell_type": "code",
  163. "execution_count": null,
  164. "metadata": {},
  165. "outputs": [],
  166. "source": [
# Build necessary components for BasicNN
# num_classes=4: presumably one class per entry of the knowledge base's pseudo_label_list. Confirm.
cls = SymbolNet(num_classes=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4)
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  172. ]
  173. },
  174. {
  175. "cell_type": "code",
  176. "execution_count": null,
  177. "metadata": {},
  178. "outputs": [],
  179. "source": [
# Build BasicNN
# The function of BasicNN is to wrap NN models into the form of an sklearn estimator
# save_dir is weights_dir, i.e. the "weights" folder under the logger's log directory.
base_model = BasicNN(
    cls,
    loss_fn,
    optimizer,
    device,
    batch_size=32,
    num_epochs=1,
    save_interval=1,
    stop_loss=None,
    save_dir=weights_dir,
)
  193. ]
  194. },
  195. {
  196. "cell_type": "code",
  197. "execution_count": null,
  198. "metadata": {},
  199. "outputs": [],
  200. "source": [
# Build ABLModel
# The main function of the ABL model is to serialize data and
# provide a unified interface for different machine learning models
# Here it wraps the BasicNN instance built in the previous cell.
model = ABLModel(base_model)
  205. ]
  206. },
  207. {
  208. "attachments": {},
  209. "cell_type": "markdown",
  210. "metadata": {},
  211. "source": [
  212. "### Metric"
  213. ]
  214. },
  215. {
  216. "cell_type": "code",
  217. "execution_count": null,
  218. "metadata": {},
  219. "outputs": [],
  220. "source": [
# Set up metrics
# Both metrics share the prefix "hed"; ReasoningMetric additionally receives the knowledge base.
metric_list = [SymbolMetric(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")]
  223. ]
  224. },
  225. {
  226. "attachments": {},
  227. "cell_type": "markdown",
  228. "metadata": {},
  229. "source": [
  230. "### Bridge Machine Learning and Logic Reasoning"
  231. ]
  232. },
  233. {
  234. "cell_type": "code",
  235. "execution_count": null,
  236. "metadata": {},
  237. "outputs": [],
  238. "source": [
# Tie together the learning model, the reasoner and the metrics defined in earlier cells.
bridge = HEDBridge(model, reasoner, metric_list)
  240. ]
  241. },
  242. {
  243. "attachments": {},
  244. "cell_type": "markdown",
  245. "metadata": {},
  246. "source": [
  247. "### Dataset"
  248. ]
  249. },
  250. {
  251. "cell_type": "code",
  252. "execution_count": null,
  253. "metadata": {},
  254. "outputs": [],
  255. "source": [
# Load the training portion of the HED data.
total_train_data = get_hed(train=True)
# NOTE(review): the arguments 3 and 1 presumably give a 3:1 train/validation split; confirm against split_equation.
train_data, val_data = split_equation(total_train_data, 3, 1)
# Load the test portion (train=False).
test_data = get_hed(train=False)
  259. ]
  260. },
  261. {
  262. "attachments": {},
  263. "cell_type": "markdown",
  264. "metadata": {},
  265. "source": [
  266. "### Train and Test"
  267. ]
  268. },
  269. {
  270. "cell_type": "code",
  271. "execution_count": null,
  272. "metadata": {},
  273. "outputs": [],
  274. "source": [
# NOTE(review): "./weights" is a relative path and differs from weights_dir (under the log directory); confirm this is intended.
bridge.pretrain("./weights")
# Train on the training split, passing the validation split as the second argument.
bridge.train(train_data, val_data)
  277. ]
  278. }
  279. ],
  280. "metadata": {
  281. "kernelspec": {
  282. "display_name": "ABL",
  283. "language": "python",
  284. "name": "python3"
  285. },
  286. "language_info": {
  287. "codemirror_mode": {
  288. "name": "ipython",
  289. "version": 3
  290. },
  291. "file_extension": ".py",
  292. "mimetype": "text/x-python",
  293. "name": "python",
  294. "nbconvert_exporter": "python",
  295. "pygments_lexer": "ipython3",
  296. "version": "3.8.18"
  297. },
  298. "orig_nbformat": 4,
  299. "vscode": {
  300. "interpreter": {
  301. "hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58"
  302. }
  303. }
  304. },
  305. "nbformat": 4,
  306. "nbformat_minor": 2
  307. }

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.