|
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# MNIST Addition\n",
- "\n",
- "This notebook shows an implementation of [MNIST Addition](https://arxiv.org/abs/1805.10872). In this task, pairs of MNIST handwritten digit images and their sums are given, along with a domain knowledge base containing information on how to perform addition. The task is to recognize the digits in the handwritten images and accurately determine their sum.\n",
- "\n",
- "Intuitively, we first use a machine learning model (the learning part) to convert the input images to digits (we call them pseudo-labels), and then use the knowledge base (the reasoning part) to calculate the sum of these digits. Since we do not have ground-truth labels for the digits, in Abductive Learning the reasoning part leverages domain knowledge to revise the initial digits produced by the learning part through abductive reasoning. The revised digits are then used to further update the machine learning model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Import necessary libraries and modules\n",
- "import os.path as osp\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "import matplotlib.pyplot as plt\n",
- "from examples.mnist_add.datasets import get_dataset\n",
- "from examples.models.nn import LeNet5\n",
- "from abl.learning import ABLModel, BasicNN\n",
- "from abl.reasoning import KBBase, Reasoner\n",
- "from abl.evaluation import ReasoningMetric, SymbolMetric\n",
- "from abl.utils import ABLLogger, print_log\n",
- "from abl.bridge import SimpleBridge"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Working with Data\n",
- "\n",
- "First, we get the training and testing datasets:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "train_data = get_dataset(train=True, get_pseudo_label=True)\n",
- "test_data = get_dataset(train=False, get_pseudo_label=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "`train_data` and `test_data` share the same structure: each is a tuple of three components, X (a list where each element is a list of two images), gt_pseudo_label (a list where each element is a list of two digits, i.e., pseudo-labels), and Y (a list where each element is the sum of the two digits). The lengths and structures of the datasets are illustrated below.\n",
- "\n",
- "Note: ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\n",
- "\n",
- "Length of X, gt_pseudo_label, Y in train_data: 30000, 30000, 30000\n",
- "Length of X, gt_pseudo_label, Y in test_data: 5000, 5000, 5000\n",
- "\n",
- "X is a list, with each element being a list of 2 Tensor.\n",
- "gt_pseudo_label is a list, with each element being a list of 2 int.\n",
- "Y is a list, with each element being a int.\n"
- ]
- }
- ],
- "source": [
- "print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n",
- "print()\n",
- "train_X, train_gt_pseudo_label, train_Y = train_data\n",
- "print(f\"Length of X, gt_pseudo_label, Y in train_data: \" +\n",
- " f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\")\n",
- "test_X, test_gt_pseudo_label, test_Y = test_data\n",
- "print(f\"Length of X, gt_pseudo_label, Y in test_data: \" +\n",
- " f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\")\n",
- "print()\n",
- "\n",
- "X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n",
- "print(f\"X is a {type(train_X).__name__}, \" +\n",
- " f\"with each element being a {type(X_0).__name__} \" +\n",
- " f\"of {len(X_0)} {type(X_0[0]).__name__}.\")\n",
- "print(f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \" +\n",
- " f\"with each element being a {type(gt_pseudo_label_0).__name__} \" +\n",
- " f\"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.\")\n",
- "print(f\"Y is a {type(train_Y).__name__}, \" +\n",
- " f\"with each element being a {type(Y_0).__name__}.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The ith element of X, gt_pseudo_label, and Y together constitute the ith data example. As an illustration, in the first data example of the training set, we have:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "X in the first data example (a list of two images):\n"
- ]
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAD1CAYAAADNj/Z6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAKUklEQVR4nO3df6zVdR3H8XPuDy4iCHcRUnKRQIFkOFHzV+Z0maVhM5dZyzY1JWCUbmk/tkwrXWXRQB0xayVUarOcthhmFJpL5GflYgTiryyV3wgIXrjnnv6x/kh4fy+de++58H48/n2d+z3ff+7hyXe7n1OuVqvVEgCQVkO9bwAAqC8xAADJiQEASE4MAEByYgAAkhMDAJCcGACA5MQAACQnBgAguaauvvADDZf15H0AXfC7zgfqfQsHzWcH1F/RZ4cnAwCQnBgAgOTEAAAkJwYAIDkxAADJiQEASE4MAEByYgAAkhMDAJCcGACA5MQAACQnBgAgOTEAAMmJAQBITgwAQHJiAACSEwMAkJwYAIDkxAAAJCcGACA5MQAAyYkBAEhODABAcmIAAJITAwCQnBgAgOTEAAAkJwYAIDkxAADJiQEASE4MAEByYgAAkhMDAJCcGACA5MQAACQnBgAgOTEAAMmJAQBIrqneNwBAbg2DBoV7+Z1H99KdHNiLlw4L9/ahneHetCv+v/foezeGe2Xt+nCvlScDAJCcGACA5MQAACQnBgAgOTEAAMmJAQBITgwAQHLOGehGjUMGh/tzd48M99XvnRfu7dWOcD971afDvVyuxvtDbwv3hko412zob9aFe2Xzlp69AaAunr9hYrj//do5vXQn9XP/J1rD/Sfjju3R9/dkAACSEwMAkJwYAIDkxAAAJCcGACA5MQAAyYkBAEjOOQPdqLJjV7hfevxfw72zFJ8D0FxuDPelp9wb7g2lcvz+J8fv39OmTD033Ddc/PZwr2za1I13A3RV07Ft4d7x4kvh3m9HfP3jfj4t3C86b0V8gW6w8A+nhnv/zfHn67CV7eHe/Hj870OpFJ8zUytPBgAgOTEAAMmJAQBITgwAQHJiAACSEwMAkJwYAIDkxAAAJOfQoe7UWQnnXzx2Vrh//eN/7s67OeTc3fZYuH/oxCnh3vR7hw5BPTTP3xvu65aeGe6jv/hkTe+/tqaf7prRpSU9ev36HvnmyQAApCcGACA5MQAAyYkBAEhODABAcmIAAJITAwCQnHMGetG4b8R/DTtx54yarj/0tA3hfuaw58O9s1oO9wdXnhLuyy+cFe6DG/qHO9A3bfh8fEbK0uNmh/v4Z6d15+3QAzwZAIDkxAAAJCcGACA5MQAAyYkBAEhODABAcmIAAJJzzkAvqmzbFu7Hfq1nvy/76cJXxN+o/e4JO8L9tQ/GPz+4ID2/s2VCuPdbsibcO+PLAwfQ2Noa7rd8bn64b620h/vYufFO/XkyAADJiQEASE4MAEByYgAAkhMDAJCcGACA5MQAACTnnAG6bM+sN8J9ZNMRNV1/0ZffF+4tu5fXdH1g/9bMHBPulxy5ONzHzrsx3N+1rGfPUKF2ngwAQHJiAACSEwMAkJwYAIDkxAAAJCcGACA5MQAAyTlngP+qnHdyuP903B0FV4jPGViwe3C4H7lmY7h3FLw7sH8v33hWuD9+/u0FVxgYrsfNXBfulYKrU3+eDABAcmIAAJITAwCQnBgAgOTEAAAkJwYAIDkxAADJOWcgkYb+/cP9nnnxOQJDG+NzBHZX94b7HdMuD/fm51eGO7B/uy47PdxXXD873FvK8TkCReauejjcZ286J9yX3faecB/w4NKDvicOjicDAJCcGACA5MQAACQnBgAgOTEAAMmJAQBITgwAQHLOGUhkz6+Hh3vROQJFzllxdbgPX+QcAfh/NI04Jtx/9r2ZBVdoCdevbpwY7le1LomvXo7ffeY7VoX7rG9uDfdH/zgm3Cubt8Q3QCFPBgAgOTEAAMmJAQBI
TgwAQHJiAACSEwMAkJwYAIDknDNwGNk446xwXzHhrnDvLLj+svb4j4lHTI3/Vrij4PrAAXTEvz1T1n8y3F9e1BbuI771ZLgvL50d7o1HDwv3zzzxVLhf3/pCuP/gugvDfdRN8TkIFPNkAACSEwMAkJwYAIDkxAAAJCcGACA5MQAAyYkBAEjOOQOHkPKkCeG+4ivxOQKN5YL2q8YnDVz74xnh3vZq/LfKcCja+PD4cB82cFfhNRo/9nq4V7ZtC/eOVzeEe8P74/cfUXopfkGNqrv3hPuWysCCK+wI136vxWecUDtPBgAgOTEAAMmJAQBITgwAQHJiAACSEwMAkJwYAIDknDPQl5xxYjjf98DccO8stcTXLzhH4PhfTg/3cd//S8H7w+Hn6EE7w/2R8QsKr3HC9Ph3q+22vn1GR7m5X7ivvTU+A2XK4CfC/Ve7jgr3th+tDvdKuNIVngwAQHJiAACSEwMAkJwYAIDkxAAAJCcGACA5MQAAyTlnoBc1DBgQ7ttvjr/zfGBDwTkCBS5/7oJwH3vDqnDv3Le3pveHQ9G61SPiF4wvvsa3r7wn3L/UdGW4D3kmPsVj4D/eCPddI/uHe0dLOdxPnx5/NjxyTHwGyu7O+LPj1juuCPdh2/v2OQyHA08GACA5MQAAyYkBAEhODABAcmIAAJITAwCQnBgAgOScM9CLXrnmpHBffuKdNV3/6b3xt3q/flF7uFedIwBvMX7O1nBfffGewmt85MiC/bNzDuaW3mJzJT6jZGhjwQ3UaFn7vnC/dvb14T78LucI1JsnAwCQnBgAgOTEAAAkJwYAIDkxAADJiQEASE4MAEByzhnoRo0njA33R2/8bsEV4u8c31eNzxGYfvN14T5k55KC9wf+V2XNM+E+efGMwmusv+CH4d5Yru3/ZT19jsD9O1vDff4l54f78DXOEejrPBkAgOTEAAAkJwYAIDkxAADJiQEASE4MAEByYgAAknPOwEFoGDAg3J+9OT4noLUh3ouc+tTV4d423zkC0NvGXrWy8DWTHvpUuD992n3ddTv7dcUL54b7qgUnhPuoBzaEe2VdfBYDfZ8nAwCQnBgAgOTEAAAkJwYAIDkxAADJiQEASE4MAEByYgAAknPo0EF45ZqTwv1vZ99Z0/X/9EZzuI+aVnDwR03vDvSUEVe+HO7nT4oPFHvhw/3CvW1R/Nvf77cr4p+vPhnuPlsOf54MAEByYgAAkhMDAJCcGACA5MQAACQnBgAgOTEAAMk5Z+A/TptY+JKFN9xe8IojwvWfHXvC/ZYZU8O9ZdPygvcH+qLK9tfCvXHxqnAfs7g77wbeypMBAEhODABAcmIAAJITAwCQnBgAgOTEAAAkJwYAIDnnDLxp31Hx94WXSqXS0Mb4HIEiN/1rcri3LHSOAAC9z5MBAEhODABAcmIAAJITAwCQnBgAgOTEAAAkJwYAIDnnDHSjL7xyRrhv/WjxWQYA0Ns8GQCA5MQAACQnBgAgOTEAAMmJAQBITgwAQHJiAACSc87Am5oXrSx8zeRjTil4xb6CfWOX7wcAeosnAwCQnBgAgOTEAAAkJwYAIDkxAADJiQEASE4MAEByYgAAkhMDAJCcGACA5MQAACQnBgAgOTEAAMmJAQBITgwAQHLlarVarfdNAAD148kAACQnBgAgOTEAAMmJAQBITgwAQHJiAACSEwMAkJwYAIDkxAAAJPdvx8SA9zSwVzUAAAAASUVORK5CYII=",
- "text/plain": [
- "<Figure size 640x480 with 2 Axes>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): [7, 5]\n",
- "Y in the first data example (their sum result): 12\n"
- ]
- }
- ],
- "source": [
- "X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n",
- "print(f\"X in the first data example (a list of two images):\")\n",
- "plt.subplot(1,2,1)\n",
- "plt.axis('off') \n",
- "plt.imshow(X_0[0].numpy().transpose(1, 2, 0))\n",
- "plt.subplot(1,2,2)\n",
- "plt.axis('off') \n",
- "plt.imshow(X_0[1].numpy().transpose(1, 2, 0))\n",
- "plt.show()\n",
- "print(f\"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}\")\n",
- "print(f\"Y in the first data example (their sum result): {Y_0}\")"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Building the Learning Part"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To build the learning part, we need to first build a machine learning base model. We use a simple [LeNet-5 neural network](https://en.wikipedia.org/wiki/LeNet), and encapsulate it within a `BasicNN` object to create the base model. `BasicNN` is a class that encapsulates a PyTorch model, transforming it into a base model with an sklearn-style interface. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "cls = LeNet5(num_classes=10)\n",
- "loss_fn = nn.CrossEntropyLoss()\n",
- "optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, alpha=0.9)\n",
- "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
- "\n",
- "base_model = BasicNN(\n",
- " cls,\n",
- " loss_fn,\n",
- " optimizer,\n",
- " device,\n",
- " batch_size=32,\n",
- " num_epochs=1,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "`BasicNN` offers methods such as `predict` and `predict_proba`, which predict the class index and the per-class probabilities for each image, respectively, as shown below:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Predicted class index for a batch of 32 instances: ndarray with shape (32,)\n",
- "Predicted class probabilities for a batch of 32 instances: ndarray with shape (32, 10)\n"
- ]
- }
- ],
- "source": [
- "data_instances = [torch.randn(1, 28, 28).to(device) for _ in range(32)]\n",
- "pred_idx = base_model.predict(X=data_instances)\n",
- "print(f\"Predicted class index for a batch of 32 instances: \" +\n",
- " f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\")\n",
- "pred_prob = base_model.predict_proba(X=data_instances)\n",
- "print(f\"Predicted class probabilities for a batch of 32 instances: \" +\n",
- " f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "However, the base model built above deals with instance-level data (i.e., individual images) and cannot directly handle example-level data (i.e., a pair of images). Therefore, we wrap the base model in `ABLModel`, which enables the learning part to train, test, and predict on example-level data."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "model = ABLModel(base_model)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As an illustration, consider this example of predicting on example-level data using the `predict` method of `ABLModel`. The method accepts data examples as input and outputs the class labels and the per-class probabilities for all instances within these data examples."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Predicted class labels for the 100 data examples: \n",
- "a list of length 100, and each element is a ndarray of shape (2,).\n",
- "\n",
- "Predicted class probabilities for the 100 data examples: \n",
- "a list of length 100, and each element is a ndarray of shape (2, 10).\n"
- ]
- }
- ],
- "source": [
- "from abl.structures import ListData\n",
- "# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n",
- "data_examples = ListData()\n",
- "# We use the first 100 data examples in the training set as an illustration\n",
- "data_examples.X = train_X[:100]\n",
- "data_examples.gt_pseudo_label = train_gt_pseudo_label[:100]\n",
- "data_examples.Y = train_Y[:100]\n",
- "\n",
- "# Perform prediction on the 100 data examples\n",
- "# Call predict once and read both outputs from the returned dict\n",
- "pred = model.predict(data_examples)\n",
- "pred_label, pred_prob = pred['label'], pred['prob']\n",
- "print(f\"Predicted class labels for the 100 data examples: \\n\" +\n",
- " f\"a list of length {len(pred_label)}, and each element is \" +\n",
- " f\"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\\n\")\n",
- "print(f\"Predicted class probabilities for the 100 data examples: \\n\" +\n",
- " f\"a list of length {len(pred_prob)}, and each element is \" +\n",
- " f\"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Building the Reasoning Part"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In the reasoning part, we first build a knowledge base that contains information on how to perform addition. We build it by creating a subclass of `KBBase`. In the derived subclass, we initialize the `pseudo_label_list` parameter, which specifies the list of possible pseudo-labels, and override the `logic_forward` function, which defines how to perform (deductive) reasoning."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [],
- "source": [
- "class AddKB(KBBase):\n",
- " def __init__(self, pseudo_label_list=list(range(10))):\n",
- " super().__init__(pseudo_label_list)\n",
- "\n",
- " # Implement the deduction function\n",
- " def logic_forward(self, nums):\n",
- " return sum(nums)\n",
- "\n",
- "kb = AddKB()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The knowledge base can perform logical reasoning (both deductive reasoning and abductive reasoning). Below is an example of performing (deductive) reasoning, and users can refer to [Documentation]() for details of abductive reasoning."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Reasoning result of pseudo-label example [1, 2] is 3.\n"
- ]
- }
- ],
- "source": [
- "pseudo_label_example = [1, 2]\n",
- "reasoning_result = kb.logic_forward(pseudo_label_example)\n",
- "print(f\"Reasoning result of pseudo-label example {pseudo_label_example} is {reasoning_result}.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Note: In addition to building a knowledge base based on `KBBase`, we can also establish a knowledge base with a ground KB using `GroundKB`, or a knowledge base implemented based on Prolog files using `PrologKB`. The corresponding code for these implementations can be found in the `main.py` file. Those interested are encouraged to examine it for further insights."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Then, we create a reasoner by instantiating the class ``Reasoner``. Because abductive reasoning is nondeterministic, there may be multiple candidates compatible with the knowledge base. When this happens, the reasoner minimizes the inconsistency between the knowledge base and the pseudo-labels predicted by the learning part, and returns only the one candidate with the highest consistency."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "reasoner = Reasoner(kb)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Note: When creating the reasoner, the definition of \"consistency\" can be customized via the `dist_func` parameter. In the code above, we employ a consistency measurement based on confidence, which calculates the consistency between the data example and each candidate from the predicted probabilities. In `main.py`, we provide options for other forms of consistency measurement.\n",
- "\n",
- "Note: During the process of inconsistency minimization, one can also leverage the [ZOOpt library](https://github.com/polixir/ZOOpt) for acceleration. Options for this are also available in `main.py`. Those interested are encouraged to explore these features."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Building Evaluation Metrics"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next, we set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing. Specifically, we use `SymbolMetric` and `ReasoningMetric`, which are used to evaluate the accuracy of the machine learning model’s predictions and the accuracy of the final reasoning results, respectively."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [],
- "source": [
- "metric_list = [SymbolMetric(prefix=\"mnist_add\"), ReasoningMetric(kb=kb, prefix=\"mnist_add\")]"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Bridging Learning and Reasoning\n",
- "\n",
- "Now, the last step is to bridge the learning and reasoning parts. We do this by creating an instance of `SimpleBridge`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [],
- "source": [
- "bridge = SimpleBridge(model, reasoner, metric_list)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Perform training and testing by invoking the `train` and `test` methods of `SimpleBridge`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Build logger\n",
- "print_log(\"Abductive Learning on the MNIST Addition example.\", logger=\"current\")\n",
- "log_dir = ABLLogger.get_current_instance().log_dir\n",
- "weights_dir = osp.join(log_dir, \"weights\")\n",
- "\n",
- "bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)\n",
- "bridge.test(test_data)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "abl",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.13"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "9c8d454494e49869a4ee4046edcac9a39ff683f7d38abf0769f648402670238e"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|
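
The abductive step described in the "Building the Reasoning Part" section can be sketched in plain Python, without ABL-Package. This is a minimal illustration, not the library's implementation: the helper names `abduce_candidates` and `select_by_confidence` and the probability values are hypothetical, and the confidence score is assumed here to be the product of the predicted probabilities of a candidate's labels.

```python
from itertools import product


def logic_forward(nums):
    """Deductive step, mirroring AddKB.logic_forward: sum the digits."""
    return sum(nums)


def abduce_candidates(y, num_digits=2, labels=range(10)):
    """Enumerate every pseudo-label list that the knowledge base maps to y."""
    return [
        list(cand)
        for cand in product(labels, repeat=num_digits)
        if logic_forward(cand) == y
    ]


def select_by_confidence(candidates, probs):
    """Return the candidate whose labels have the highest joint probability.

    Assumption: confidence is the product of per-position probabilities.
    """
    def confidence(cand):
        score = 1.0
        for position, label in enumerate(cand):
            score *= probs[position][label]
        return score

    return max(candidates, key=confidence)


# Hypothetical predicted probabilities for the two images of one example.
probs_first = [0.04] * 10
probs_first[7] = 0.64      # the model is fairly sure the first digit is 7
probs_second = [0.0125] * 10
probs_second[3] = 0.5      # the model slightly prefers 3 ...
probs_second[5] = 0.4      # ... over 5 for the second digit
probs = [probs_first, probs_second]

target_sum = 12

# Raw prediction of the learning part: the most probable label per image.
predicted = [max(range(10), key=p.__getitem__) for p in probs]

# Abduction: candidates consistent with the knowledge base, then selection.
candidates = abduce_candidates(target_sum)
revised = select_by_confidence(candidates, probs)

print(f"Predicted pseudo-labels: {predicted} (sum {logic_forward(predicted)})")
print(f"Candidates consistent with Y={target_sum}: {candidates}")
print(f"Revised pseudo-labels: {revised}")
```

In this sketch the raw prediction `[7, 3]` sums to 10, which contradicts the target 12. Among the seven candidates consistent with the knowledge base, `[7, 5]` has the highest confidence, so it becomes the revised pseudo-label that would be used to update the model.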