You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

zoo_example.ipynb 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": null,
  6. "metadata": {},
  7. "outputs": [],
  8. "source": [
  9. "import os.path as osp\n",
  10. "\n",
  11. "import numpy as np\n",
  12. "from sklearn.ensemble import RandomForestClassifier\n",
  13. "from sklearn.metrics import accuracy_score\n",
  14. "from z3 import Solver, Int, If, Not, Implies, Sum, sat\n",
  15. "import openml\n",
  16. "\n",
  17. "from abl.learning import ABLModel\n",
  18. "from abl.reasoning import KBBase, Reasoner\n",
  19. "from abl.evaluation import ReasoningMetric, SymbolMetric\n",
  20. "from abl.bridge import SimpleBridge\n",
  21. "from abl.utils.utils import confidence_dist\n",
  22. "from abl.utils import ABLLogger, print_log"
  23. ]
  24. },
  25. {
  26. "cell_type": "code",
  27. "execution_count": null,
  28. "metadata": {},
  29. "outputs": [],
  30. "source": [
  31. "# Build logger\n",
  32. "print_log(\"Abductive Learning on the Zoo example.\", logger=\"current\")\n",
  33. "\n",
  34. "# Retrieve the directory of the log file and define the directory for saving the model weights.\n",
  35. "log_dir = ABLLogger.get_current_instance().log_dir\n",
  36. "weights_dir = osp.join(log_dir, \"weights\")"
  37. ]
  38. },
  39. {
  40. "cell_type": "markdown",
  41. "metadata": {},
  42. "source": [
  43. "### Learning Part"
  44. ]
  45. },
  46. {
  47. "cell_type": "code",
  48. "execution_count": null,
  49. "metadata": {},
  50. "outputs": [],
  51. "source": [
  52. "rf = RandomForestClassifier()\n",
  53. "model = ABLModel(rf)"
  54. ]
  55. },
  56. {
  57. "cell_type": "markdown",
  58. "metadata": {},
  59. "source": [
  60. "### Logic Part"
  61. ]
  62. },
  63. {
  64. "cell_type": "code",
  65. "execution_count": null,
  66. "metadata": {},
  67. "outputs": [],
  68. "source": [
  69. "class ZooKB(KBBase):\n",
  70. " def __init__(self):\n",
  71. " super().__init__(pseudo_label_list=list(range(7)), use_cache=False)\n",
  72. " \n",
  73. " # Use z3 solver \n",
  74. " self.solver = Solver()\n",
  75. "\n",
  76. " # Load information of Zoo dataset\n",
  77. " dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False)\n",
  78. " X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
  79. " self.attribute_names = attribute_names\n",
  80. " self.target_names = y.cat.categories.tolist()\n",
  81. " print(\"Attribute names are: \", self.attribute_names)\n",
  82. " print(\"Target names are: \", self.target_names)\n",
  83. " # self.attribute_names = [\"hair\", \"feathers\", \"eggs\", \"milk\", \"airborne\", \"aquatic\", \"predator\", \"toothed\", \"backbone\", \"breathes\", \"venomous\", \"fins\", \"legs\", \"tail\", \"domestic\", \"catsize\"]\n",
  84. " # self.target_names = [\"mammal\", \"bird\", \"reptile\", \"fish\", \"amphibian\", \"insect\", \"invertebrate\"]\n",
  85. "\n",
  86. " # Define variables\n",
  87. " for name in self.attribute_names+self.target_names:\n",
  88. " exec(f\"globals()['{name}'] = Int('{name}')\") ## or use dict to create var and modify rules\n",
  89. " # Define rules\n",
  90. " rules = [\n",
  91. " Implies(milk == 1, mammal == 1),\n",
  92. " Implies(mammal == 1, milk == 1),\n",
  93. " Implies(mammal == 1, backbone == 1),\n",
  94. " Implies(mammal == 1, breathes == 1),\n",
  95. " Implies(feathers == 1, bird == 1),\n",
  96. " Implies(bird == 1, feathers == 1),\n",
  97. " Implies(bird == 1, eggs == 1),\n",
  98. " Implies(bird == 1, backbone == 1),\n",
  99. " Implies(bird == 1, breathes == 1),\n",
  100. " Implies(bird == 1, legs == 2),\n",
  101. " Implies(bird == 1, tail == 1),\n",
  102. " Implies(reptile == 1, backbone == 1),\n",
  103. " Implies(reptile == 1, breathes == 1),\n",
  104. " Implies(reptile == 1, tail == 1),\n",
  105. " Implies(fish == 1, aquatic == 1),\n",
  106. " Implies(fish == 1, toothed == 1),\n",
  107. " Implies(fish == 1, backbone == 1),\n",
  108. " Implies(fish == 1, Not(breathes == 1)),\n",
  109. " Implies(fish == 1, fins == 1),\n",
  110. " Implies(fish == 1, legs == 0),\n",
  111. " Implies(fish == 1, tail == 1),\n",
  112. " Implies(amphibian == 1, eggs == 1),\n",
  113. " Implies(amphibian == 1, aquatic == 1),\n",
  114. " Implies(amphibian == 1, backbone == 1),\n",
  115. " Implies(amphibian == 1, breathes == 1),\n",
  116. " Implies(amphibian == 1, legs == 4),\n",
  117. " Implies(insect == 1, eggs == 1),\n",
  118. " Implies(insect == 1, Not(backbone == 1)),\n",
  119. " Implies(insect == 1, legs == 6),\n",
  120. " Implies(invertebrate == 1, Not(backbone == 1))\n",
  121. " ]\n",
  122. " # Define weights and sum of violated weights\n",
  123. " self.weights = {rule: 1 for rule in rules}\n",
  124. " self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights])\n",
  125. " \n",
  126. " def logic_forward(self, pseudo_label, data_point):\n",
  127. " attribute_names, target_names = self.attribute_names, self.target_names\n",
  128. " solver = self.solver\n",
  129. " total_violation_weight = self.total_violation_weight\n",
  130. " pseudo_label, data_point = pseudo_label[0], data_point[0]\n",
  131. " \n",
  132. " self.solver.reset()\n",
  133. " for name, value in zip(attribute_names, data_point):\n",
  134. " solver.add(eval(f\"{name} == {value}\"))\n",
  135. " for cate, name in zip(self.pseudo_label_list,target_names):\n",
  136. " value = 1 if (cate == pseudo_label) else 0\n",
  137. " solver.add(eval(f\"{name} == {value}\"))\n",
  138. " \n",
  139. " if solver.check() == sat:\n",
  140. " model = solver.model()\n",
  141. " total_weight = model.evaluate(total_violation_weight)\n",
  142. " # violated_rules = [str(rule) for rule in self.weights if model.evaluate(Not(rule))]\n",
  143. " # print(\"Total violation weight for the given data point:\", total_weight)\n",
  144. " # print(\"Violated rules:\", violated_rules)\n",
  145. " return total_weight.as_long()\n",
  146. " else:\n",
  147. " # No solution found\n",
  148. " return 1e10\n",
  149. " \n",
  150. "def consitency(data_sample, candidates, candidate_idxs, reasoning_results):\n",
  151. " pred_prob = data_sample.pred_prob\n",
  152. " model_scores = confidence_dist(pred_prob, candidate_idxs)\n",
  153. " rule_scores = np.array(reasoning_results)\n",
  154. " scores = model_scores + rule_scores\n",
  155. " return scores\n",
  156. "\n",
  157. "kb = ZooKB()\n",
  158. "reasoner = Reasoner(kb, dist_func=consitency)"
  159. ]
  160. },
  161. {
  162. "cell_type": "markdown",
  163. "metadata": {},
  164. "source": [
  165. "### Datasets and Evaluation Metrics"
  166. ]
  167. },
  168. {
  169. "cell_type": "code",
  170. "execution_count": null,
  171. "metadata": {},
  172. "outputs": [],
  173. "source": [
  174. "# Function to load and preprocess the dataset\n",
  175. "def load_and_preprocess_dataset(dataset_id):\n",
  176. " dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)\n",
  177. " X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
  178. " # Convert data types\n",
  179. " for col in X.select_dtypes(include='bool').columns:\n",
  180. " X[col] = X[col].astype(int)\n",
  181. " y = y.cat.codes.astype(int)\n",
  182. " X, y = X.to_numpy(), y.to_numpy()\n",
  183. " return X, y\n",
  184. "\n",
  185. "# Function to split data (one shot)\n",
  186. "def split_dataset(X, y, test_size = 0.3):\n",
  187. " # For every class, split labeled : unlabeled : test as 1 : (1-test_size)*(len-1) : test_size*(len-1)\n",
  188. " label_indices, unlabel_indices, test_indices = [], [], []\n",
  189. " for class_label in np.unique(y):\n",
  190. " idxs = np.where(y == class_label)[0]\n",
  191. " np.random.shuffle(idxs)\n",
  192. " n_train_unlabel = int((1-test_size)*(len(idxs)-1))\n",
  193. " label_indices.append(idxs[0])\n",
  194. " unlabel_indices.extend(idxs[1:1+n_train_unlabel])\n",
  195. " test_indices.extend(idxs[1+n_train_unlabel:])\n",
  196. " X_label, y_label = X[label_indices], y[label_indices]\n",
  197. " X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]\n",
  198. " X_test, y_test = X[test_indices], y[test_indices]\n",
  199. " return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test\n",
  200. "\n",
  201. "# Load and preprocess the Zoo dataset\n",
  202. "X, y = load_and_preprocess_dataset(dataset_id=62)\n",
  203. "\n",
  204. "# Split data into labeled/unlabeled/test data\n",
  205. "X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)\n",
  206. "\n",
  207. "# Transform tabular data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results)\n",
  208. "# For tabular data in ABL, each sample contains a single instance (a row from the dataset).\n",
  209. "# For these tabular data samples, the reasoning results are expected to be 0, indicating no rules are violated.\n",
  210. "def transform_tab_data(X, y):\n",
  211. " return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n",
  212. "label_data = transform_tab_data(X_label, y_label)\n",
  213. "test_data = transform_tab_data(X_test, y_test)\n",
  214. "train_data = transform_tab_data(X_unlabel, y_unlabel)"
  215. ]
  216. },
  217. {
  218. "cell_type": "code",
  219. "execution_count": null,
  220. "metadata": {},
  221. "outputs": [],
  222. "source": [
  223. "# Set up metrics\n",
  224. "metric_list = [SymbolMetric(prefix=\"zoo\"), ReasoningMetric(kb=kb, prefix=\"zoo\")]"
  225. ]
  226. },
  227. {
  228. "cell_type": "markdown",
  229. "metadata": {},
  230. "source": [
  231. "### Bridge Machine Learning and Logic Reasoning"
  232. ]
  233. },
  234. {
  235. "cell_type": "code",
  236. "execution_count": null,
  237. "metadata": {},
  238. "outputs": [],
  239. "source": [
  240. "bridge = SimpleBridge(model, reasoner, metric_list)"
  241. ]
  242. },
  243. {
  244. "cell_type": "markdown",
  245. "metadata": {},
  246. "source": [
  247. "### Train and Test"
  248. ]
  249. },
  250. {
  251. "cell_type": "code",
  252. "execution_count": null,
  253. "metadata": {},
  254. "outputs": [],
  255. "source": [
  256. "# Pre-train the machine learning model\n",
  257. "rf.fit(X_label, y_label)"
  258. ]
  259. },
  260. {
  261. "cell_type": "code",
  262. "execution_count": null,
  263. "metadata": {},
  264. "outputs": [],
  265. "source": [
  266. "# Test the initial model\n",
  267. "print(\"------- Test the initial model -----------\")\n",
  268. "bridge.test(test_data)\n",
  269. "print(\"------- Use ABL to train the model -----------\")\n",
  270. "# Use ABL to train the model\n",
  271. "bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)\n",
  272. "print(\"------- Test the final model -----------\")\n",
  273. "# Test the final model\n",
  274. "bridge.test(test_data)"
  275. ]
  276. }
  277. ],
  278. "metadata": {
  279. "kernelspec": {
  280. "display_name": "abl",
  281. "language": "python",
  282. "name": "python3"
  283. },
  284. "language_info": {
  285. "codemirror_mode": {
  286. "name": "ipython",
  287. "version": 3
  288. },
  289. "file_extension": ".py",
  290. "mimetype": "text/x-python",
  291. "name": "python",
  292. "nbconvert_exporter": "python",
  293. "pygments_lexer": "ipython3",
  294. "version": "3.8.13"
  295. }
  296. },
  297. "nbformat": 4,
  298. "nbformat_minor": 2
  299. }

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.