{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Handwritten Equation Decipherment (HED)\n", "\n", "This notebook shows an implementation of [Handwritten Equation Decipherment](https://proceedings.neurips.cc/paper_files/paper/2019/file/9c19a2aa1d84e04b0bd4bc888792bd1e-Paper.pdf). In this task, we are given handwritten equations, each of which is a sequence of images of characters. The equations are generated under unknown operation rules from images of the symbols ('0', '1', '+' and '='), and each equation carries a label indicating whether it is correct (i.e., positive) or not (i.e., negative). We are also given a knowledge base that encodes the structure of the equations and a recursive definition of bit-wise operations. The task is to learn from a training set of such equations and then predict the labels of unseen equations.\n", "\n", "Intuitively, we first use a machine learning model (the learning part) to obtain pseudo-labels ('0', '1', '+' and '=') for the observed images. We then use the knowledge base (the reasoning part) to perform abductive reasoning, which yields ground hypotheses as possible explanations of the observed facts and suggests which pseudo-labels should be revised. This process enables us to further update the machine learning model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Import necessary libraries and modules\n", "import os.path as osp\n", "import torch\n", "import torch.nn as nn\n", "import matplotlib.pyplot as plt\n", "from examples.hed.datasets import get_dataset, split_equation\n", "from examples.models.nn import SymbolNet\n", "from abl.learning import ABLModel, BasicNN\n", "from examples.hed.reasoning import HedKB, HedReasoner\n", "from abl.data.evaluation import SymbolAccuracy\n", "from examples.hed.consistency_metric import ConsistencyMetric\n", "from abl.utils import ABLLogger, print_log\n", "from examples.hed.bridge import HedBridge" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Working with Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we load the datasets of handwritten equations and split the training data into training and validation sets:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "total_train_data = get_dataset(train=True)\n", "train_data, val_data = split_equation(total_train_data, 3, 1)\n", "test_data = get_dataset(train=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The datasets are summarized below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Equations in the dataset are organized by equation length, from 5 to 26\n", "\n", "For each equation length, there are 225 true equations and 225 false equations in the training set\n", "For each equation length, there are 75 true equations and 75 false equations in the validation set\n", "For each equation length, there are 300 true equations and 300 false equations in the test set\n" ] } ], "source": [ "true_train_equation = train_data[1]\n", "false_train_equation = train_data[0]\n", "print(f\"Equations in the dataset are organized by equation length, \" +\n", " f\"from {min(train_data[0].keys())} to {max(train_data[0].keys())}\")\n", 
"print()\n", "\n", "true_train_equation_with_length_5 = true_train_equation[5]\n", "false_train_equation_with_length_5 = false_train_equation[5]\n", "print(f\"For each equation length, there are {len(true_train_equation_with_length_5)} \" +\n", " f\"true equations and {len(false_train_equation_with_length_5)} false equations \" +\n", " f\"in the training set\")\n", "\n", "true_val_equation = val_data[1]\n", "false_val_equation = val_data[0]\n", "true_val_equation_with_length_5 = true_val_equation[5]\n", "false_val_equation_with_length_5 = false_val_equation[5]\n", "print(f\"For each equation length, there are {len(true_val_equation_with_length_5)} \" +\n", " f\"true equations and {len(false_val_equation_with_length_5)} false equations \" +\n", " f\"in the validation set\")\n", "\n", "true_test_equation = test_data[1]\n", "false_test_equation = test_data[0]\n", "true_test_equation_with_length_5 = true_test_equation[5]\n", "false_test_equation_with_length_5 = false_test_equation[5]\n", "print(f\"For each equation length, there are {len(true_test_equation_with_length_5)} \" +\n", " f\"true equations and {len(false_test_equation_with_length_5)} false equations \" +\n", " f\"in the test set\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As an illustration, we show four equations from the training dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "First true equation with length 5 in the training dataset:\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAgQAAABpCAYAAABF9zs7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAANF0lEQVR4nO3de3CU1RnH8bO7WXIBJAki14BIQAQZURFBEcdOFW0VBisD6CgdqQwqIFClDjhjp3Z6wYpS0BGhilQ7WmuxUxURtQpeuIiGooJIuClgBCEQkizZ7Pv2H+d5ztLdXEj23dv389dvsyfZk2V3OTnPe87xua7rGgAAkNX8ye4AAABIPgYEAACAAQEAAGBAAAAADAMCAABgGBAAAADDgAAAABgGBAAAwBiT09SGV/vHJbIfWWuN81KLf8YLO4e0Qk9aV6G/RvKfJt0Ss829zz4vudIpSHifmmtC6cct+n7eM4nRGu+ZTXt7tUJPcKpLeu1t0ffznkmMpr5nmCEAAAAMCAAAAAMCAABgGBAAAADDgAAAABgGBAAAwDAgAAAAhgEBAAAwDAgAAIBhQAAAAAwDAgAAYJpxlgHQGgLVdcnuAn4QGHiu5O13FUru23+/5H1HiiT3HLfVk34B6axi+mWSN92/KGabIQ9Pl9zlsQ8T3qemYoYAAAAwIAAAAFlQMqgZe6nk3nO2SV7Ra63kUd0Ge9mlrJDnD0uO5AYku7+pklzj5nrap6zl1+c/cuUFkn+59K+Sr8oPaXPjk1zraonn2S/6Sv6qtrPk11dfIrnHO9o++NbmlvQaSBvHJw6TvHz2o5LDri9Wc2PcRPfo9DBDAAAAGBAAAIAsKBkcGKlTNuusMgFaX/eco5Jn/PFuyZ2/+U7y5JL3JNe5OpWNxLHLBKueW9qs7831BSVP6bBH8sqc45IfnrRB8ucT6yXPvWZidD++2tWsx06koM/x9PEC1hxxtasfuzWOls38HvepKRxX/2Zs79eyUp4vIjli4kyLZ5GaLvo89Qum7/PBDAEAAGBAAAAAsqBkAO/Y06L53+v0p69epxernTaS7ZUIaF11o4ZIfnbpY9Y9+ZIm7holuXaC/rscurqX5GO6sMC0/Vpzlxe3S35iWKnkt5ctkXxwVJeoPp2VQiWDd6r7e/p4JcEjkuf+/RbJfeZ/ro3y8zRH9D3jNV+O/rfgVNdI3r5QN7L67eUrJR+qby95cGK7hgRjhgAAADAgAAAADAgAAIDJgmsIysc/GfPrt+0dad06HrMNTp/rT9+lN5kgNFOXgHYN6HUDK6uLJdeMtpaOHT0guWi5leP8fLvCnbtqU8w2t059I+r26sVnNNRlT60aWOjxI+rj9c7R5yvi6HU39UP0go2TRfrR7MVqRHsFcMF+XV7o2/CZ5H6/+FTyCqck5s+ZnXorJz0xdMKWZHehVTBDAAAAGBAAAIAMLRnYBxoZUxazzQfrB0guNesT2yHAAwfu1XPYt17whOTycK3k5WOukRw5+lWr96HvK3dKfuEni6PuW20uObV50lTMuKzxRh4IhLRk8PN7Xpc8uYP+29S4iV+eW+TXstKUr7WcuuUvQyXX51MGtIWu1+dmaclTksNxdmCdtOfHkrss/DBxHWsBZggAAAADAgAAkKElg3WPL2m0TeksygRIf76g7jDYd7ROM0dcvdz7mnXTJZdu0yvFE6Hn6/q4A8ZE77ZX/TMt5bV9eYNJpvfmPJLUx4/lGz0XynwW1o/mgAnGaN26dlnrRn7VdbXkzg+uaeZPmtVKPUpNzpUXSp732DOSw24kZt5mVXv2Legnua1J7us/HmYIAAAAAwIAAJChJYN42IwoOVzrgBS0rn1z9BCjLX0WSS6r02nLkueT8zbP9UVPdde1078/2nrdmVOErZJKyNUr/Q9EciXbh3V5zevHth+vyglauXk/J3XWkSRGqFhLdCPyqq17Yq8sGP/PGZL7vJz6ZWpmCAAAAAMCAACQQSWDnY8Os26VxWx
TMZwyQSIV+PWSWp+1R/vuxZ0lnxHQfdILfCe96VgjqhzdlCVi0mvzlboBtTG/Pm/3WMnxzhpIhFBh7KnTVHPT1JmSD5+vU+Rv3D1fckWkjckkdlmgwF8f977msksMmSgw8FzJox96K4k9STxmCAAAAAMCAACQQSWDeLJtZYE99Rf01TfQsnX0zNFjdsc/M1vyOe+XS758XoXkdcd1c47Ppw6U7LTxdqo5cELLFcE/6+8wrfvbnvajuQID+kXdXj1ikXVLSx+1C7pLzjP7E90tcXS0rijZFo7eg7/Tm7slJ/6V2bDcVZ9IPtO9SLJ9UX21m1klgyonT/LCm8dH3eev038RX6Tx8oHPan/dS7rJzkWxGqe5LyfrIeD/KNzWaPuP6/R10/uV1CiLNhUzBAAAgAEBAADIoJJB+fgnY349G445jt5URKcFt1b3kJzjb+YOIw046ejLpihHp4gLDsaeatw8SycSXb9exR9oE4nVHA2oXRg9BXl2ToFk+9javFc3etYn25dXrJC84MiAqPvqD37rdXfiCnTqKLlg817JUy8ak4zueC5w4suo2762+jpyzu6qX29C+SAT2SsLto9/3L4nZvugT7/+wOwpkvPXJed9eLqYIQAAAAwIAABAmpcMmrIZUbe1mT/lVRjQPbUX771KcpurdSo0UFRkWsKtt65CLtEpxcpBxZI77Tqh3xDQKbQrFmqpZkQ7naqMuKkxHrU3Jqp0ChpomXyuG71xkmOVi9btOUdyb/Nfz/pks49dfnXmVVH3Bc1mr7sTl3OkUnLgTH0NH7uit2RfBle0AnXRJcTDg3RzofXTFkiuiDS+HiTTNyayjzOOZ8yOGyS3+2iP5HR7CaXGJzIAAEgqBgQAACC9Swbx2JsRFazc0EDLzBBydcquX4fvJG+86zLJTq5pEb91cXvlkDrJu6/V1R2DFtwluWTZPsk1jm7UUW11xO430teJcZdat3TDnzaV0SsiUql4V3GHHht9vK9On9urlQ5Hqk22sI+A/iJsHwGdnX8z7h7XsfFGdvvD2r5nxdbW7o5nsvNfGwAARGFAAAAA0rtkkM2bEdnsafjRxZ9KvvO+dyW39Fhfe/OjSuuq/PKwrixwrT07Isf03Ih6h3FnpvEX6GqMC+eUSX65WlezBPYfjvqeZJ9fYHt37iOST1orIzae1PJWwGRnSaslRyGns+/vGC55ze3zrXsy60yLhvBJDQAAGBAAAIA0LBk0ZTOi0lmZXyaIxy4f7HDOSshjnBWoknzjp3dILlldKbn89/rvdEvblySH3bR7ySEGd2AfyY92Wy55+IPTJHc8+JGXXWqWnWF7T/qAlbJjuvzU37NHTuyCTsAqNX5r7bJTZa0QypTnLNxOf9dif/PKBJ2X5zXeKA0wQwAAABgQAACANCwZxFtZYG9GZMzxmG3QOtr7Q5JrtxVKdsp0injwYmujjuARyd9F2ie2cxnO54uenvW3cPVIc1RN0DLQmHlvx+xDx2WpWyawFfrrGm+UYYLWS2V9qHvUffdMHy/ZsRoGq7SU4Mz5XvKK/s9JPhTJjKvwt9z3hOSwG/uY47Uh/V0fmnG75NzXNiWuYx5ihgAAADAgAAAAaVgyiCfbNiNKJnuTIyeoU9i+oE6nHTuZH7M9WubrLV2jbp88T6d0Vw5bInnmiDsl+98vO+3Hqx0zVPJND7wpeURbPcb6R1PvlpxnNp72Y3lp2pgpye5CctVHH39csNM6LjuoKwicKl1RtHfKIG2SuJ4ljX3Mcbwjj5d+e6XkTCkT2JghAAAADAgAAECalAzYjCi1tDE63RjnYlwkSJ97o1/nI/rdKnnDxX/Trz+uU/f/uf9yybmrdJoz0LFY8uEbzpV8aJhOl7513QLJ5WE9p2DS8nsk9/z3h03/BZAacqL/FvT1LpG858ZOkl+bonv6V0Q2Sz6QISsLEI0ZAgAAwIAAAAAwIAAAACZNriGItzthnxenSmapYWLl+cKSN4Z6Sy44qGN
KN6y7v9W7jDW90O7pDpJ/10uXhT1w5meSJy/R6wkORPTwqzyfXitwXnCNZHvnwUWV50t+/pHrJPd8Or2vG1j8r6eS3YWUFbIuDDri6LUCQV/spXjZZGP52ZL7miPxG6YpPrUBAAADAgAAkCYlAyRfx8AJybPW60EofRfodHT1jZdKvrDwE8khNxP3NUsN+a/o87+xrFRy6Vxdarjjp1py6xy1TFT/Hrh221jJh1/rIbnbsq2Si6vS4+Cipqh0WDaHaAPemyx5y8jYJaX+f9DPwUwsoDBDAAAAGBAAAIA0KRnctnek5BW91kpmd8Lk8Af0QCO3Xg/XqbxVD0KZ1PEDyeVh3fkMiVO/Z5/kflM0X28ubvR7c4y272JlJ1ZjIAOdc3OZ5LFmaJxWOzzpS7IwQwAAABgQAACANCkZVAw/LnmUGZy8juD/+XQTm1CtXrkdctPipQUA+AEzBAAAgAEBAABIk5IBki9gdGWBP8C15wCQaZghAAAADAgAAAAlA8RR6K+Juj3/m2sl9/l1rd5RXORVlwAACcQMAQAAYEAAAACM8bmu6zbeDAAAZDJmCAAAAAMCAADAgAAAABgGBAAAwDAgAAAAhgEBAAAwDAgAAIBhQAAAAAwDAgAAYIz5HxK9QIKCV9rsAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "First true equation with length 8 in the training dataset:\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "First false equation with length 5 in the training dataset:\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgQAAABpCAYAAABF9zs7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAPzklEQVR4nO3deXTU5b3H8WdmEkJCAmE1LIFIIGwiIpUWuJQ1eqqWq2ilUtriRbGISlWqt1fusdjFVq+eVmtdClpjLUgXoOBy9VBoK0WiKBVkly1BDGsIAplMZn73j97z/T7DmZDJZOY3M8n79dcnyW8yD+E3yTPP91k8juM4BgAAtGreZDcAAAAkHx0CAABAhwAAANAhAAAAhg4BAAAwdAgAAIChQwAAAAwdAgAAYIzJiPbCUu/XEtmOVuvt0O+b/T3W7e8fh5Y0X1HG55Kn33uf5Pbr9uhFHTtI9Jw+I/ngjGLJb9z9qOS99bnxbmbUxhftbtbjec0kRjxeM5sPFsahJTjfZb0rmvV4XjOJEe1rhhECAABAhwAAADShZAA0ptbxSD4xwCfZE+onuX15pattAgBEhxECAABAhwAAAFAyQBx9FmwnecXsxyT7jJ6wPfWJ+yX3XGqtPgBcZN+TvazfgiFOg4+K16Plwcr6JDYEccUIAQAAoEMAAAAoGSBB7PJBO0+d5O/fsURyWdllbjYJEEGjQ97DV8yT7PgoGUTDE9Kf38+vfFny0GQ0BnHDCAEAAKBDAAAAWlDJIDR2uOS9t+vnd0xYFPH6SXPnSM5eUZ6wdiF8eLYi0Em/4IQkhjL10528bSTvTWjL0Bhfvp49UbO0s+QnBrwq+QdXT5cc3N688x/cMuPJeyUPWqJ3mXPunF7k4f1SQxy/X/Le8d2S2JLUc2TlQMkfXrFU8o2fTJZ8Zk4XyaGtO9xpWBS44wEAAB0CAACQhiUDewhzz7N9JJeN1NLA8CxrKLqB7/PMz38h+Ru99ajenssPSq6vYN/9RPK005UIvVedkHzZpbdJfmP005IPBZN3FHJr5b9cz6FYM/RZyV7rvURdQZ5k33Z32tVcvV7RTbFO/dvFkkc9qOXDY3Xcbw3Ts0quyKaw5xk+RPL6ES9KDjj6c1rS938lf7xKd3N6cOxUyfWVhxLVxKgwQgAAAOgQAACANCkZ2GWCU6U6g3PL2F9KtocwGyoT2EoydSb7xge0fPDv112nz3uzzp4NVh2Jtrm4AHsPeZOl/weeQ1WSQ58OkJypCxSQBAduDUq2X2Nek97/Mdes1dpG14yNkke31SHbaH6PwJjqUFr8GYk7Z/QwyfeU6WqCTI+WCW45MEny/hpdYbV26O8lV0zT0nf3xykZAACAJKNDAAAAUrhkYB2vaa8msMsEibBywArJs5aXSj465rwhUo5JjVqto7sOTc7dJnndiyWS6yafkdxKRyBTXsgaRH+2WlcftNmkM/aDJj1MzNkZ8fMnuPlwAZ4RuprgvrLfSZ6QXSv52/t1A6JTN2hZtH1na/e1tzQ+P/cpyQ89PiJeTY0JIwQAAIAOAQAAoEMAAABMqs0hsOYNnHqtWPKWYYtj/palW2+SPKFgl+QFXT5q9LGL+7wt+bZ/TAr72rEpWZKDR4/G3L7Wxl522Maru3XVJaMxiOj4rFGSd47TnSJD1vuHw3W6FDhYU+NOw+IowHshRMk3RJdBz331j5LteQOzK8ZLrr5G59oEq3U5ta+D7uh5MqSPHW4tv67+pr72jDEm/+UNMbY6NrwqAAAAHQIAAJBqJYMvDpX41yaWCZaf0V
2gnrnra5JzN2iZ4P32RZJHTBsnuXT6u5J/UqC7ltkW914b9vGwZ74tufBGSgZoOa68c73kkFXisZcdfjB7mPWILW40C0iKHXfkS74yW5dH7wn4JR+e01uyU/1xxO8T3KnLc0f/fa7k7eP0b10wyyQVIwQAAIAOAQAASHLJIDR2eNjHsxataNLjB625XXL/XwYktyl/X3LYzmnWbGj7EIn3PhkpufIXf5PcK6Ph8ZvCjtVNaiuQyg4sHC15dTfdOS38ECN9/+C8l95lAnu1S5+MyAc1nQjpKpgTQd1lzuthl9KW7tx1+jdh3Vcfl1ypt4SZP+Ebkp19kcsEDSl+UstvlWPOSW57Y1X4hYua9G2bjRECAABAhwAAACS5ZJD5cPjwyPW5Rxp9zEa/Dt0VP2+dWF4e+xBm9opyyXfN042MlpesbPAxt/TSmdgvjrw2Lu2IVVHG564/Z6w6ePWs8M5ZOmP3dFCLO06m/r/2zsi1Hp0+/850M+Yq3ajLXk1gv2cY8AedGd3fvGvS2cd1BZJvek6Hfn3WDlldv1Ipec3gP0s+GTzb5Oc74+jP9Bjlh5RXMP8Tyd192ZKHvHyn5Iv3NWPToHf19bY70FGft134Jl+nY3+GmDBCAAAA6BAAAIAklAxqpn9J8t9Knj7vq5H7J1/dOUWyM1FXB3jN5ng2zRhjzJ6NffT7l2h7Mj2+sOtuyD0m+cxLeubBskEFxm3THpjv+nPGzJrQ3ea0DqPm9vxUcvEyncr7pfXf0YfaI9ku2PiKu8/nNvvMgtcLI59Z8NARXQnUf156lwlsVYF8yX0W7ZbsaaPD+bWbu0v+QtEcyd6wpUsNy6jVcsDhcXrz7piiP+tP6/0GqaHiQV1ps6xIVxasOZcvuf9z+nvKWnDQZN5LB0ouzPiH5J0rS8Ku62GOGTcxQgAAAOgQAACAJJQMRtzzoeTw2cwN8/9Mh+7amEMXuLL5+j70geRh/fS8gn+OeinsOrvtBZnVkn2Dvyw5uG2XcUOnjZ+58jzx5visjW5ydSZv1oHjVrYe4Im8gQxiE82ZBav2XyK5h9nmTsNc8MUc3Ve+x/qTkr+/6XrJ/f77hOSuVU3fZN7j6M8092BbyVetvkOy12/9DvSm//297s1kt6BpfBd1k/yTmWX6eau2eU/ZLMm99+nwfnMcv1xXFvTL1Hsrqzq5q04YIQAAAHQIAACASyWDkzN1NvN93R6zvtLwMFzpVt0gyD7COMoJvjFz/Drr91xN2wtcqfK8tZIDnXIku9XbuuutN1x6pvgq8OkmHNPKb5Nc9HUdzt2/VI/EfnXkryV/Fmyf4NYZY8xjjV+SZmpu1lU+C7vpbHf7zIJNfr1zezwSvrqmpWjr0d8kQ7MOS/7DqOckn36rjWT77INYdPXpfvWT3/6u5P4v6Fz1+txMA3ft/J+ekq/JOSV5yN9nS7744fiUCWxnp9Q0flESMEIAAADoEAAAAJdKBp8X6nBkjwscKRzmV10lBmv2xbtJDcoo7CV55oj4DxUlQo43PTc3yfPqkdU+nzXb2pqdbX/evr7GSc9/c7INnrdVckNnFszYcKvk4vIPTUtX60Qui+R56yJ+PhZnHP1V+/rkJyW3LXV5t62EeyDZDWhUaJxutvXOuKesr+hKp06rc0y8ZXTXTesWXrIq7t8/HhghAAAAdAgAAIBLJQPH2m/DG2UfJHtleeMXJYDnZZ19/F9d9Cjj888yCFiTjnf7dSjI+87mhLUNiIVdBnu+UI/xtc8sqArqLPiOa6JbXYPYBKyfe8DhPZnbDnxF7+8u1tHGl5d/U3Kv5fq7P15Fnb2z+0qe0u41ybsCWpq66M2DYY9pznkJseBuBAAAdAgAAIBLJQOPNbwe7fkFifbZPD3q8u7v/EnyxJx3JIesjZMC5+1LssGvJYTFC6+TnGdazhGxaBkOTO8tuaEzCyYs+Z7kvi9scKdhgAsyevUM+/hHU38X8brsFR0kh86cictz2ysaVv/Ho/
azSZq58F7JnSqT+9pjhAAAANAhAAAASTj+OFr2zOj6isomPdY3ZIDk7fN0z/tB/fXo5Fu6vS55RvsK69HRbZz0433XSs5bSpkAqcU3uETyott18xX7zAL7/UDWifQ/eheIJNg1P+zj69udiHxhnDhjLpOc+YMqyVnWS6zkzdslD1z2keRkF9QZIQAAAHQIAABACpcMvrBazy8o2zTqAlf+P2spw8oJeqzrgMzI+5TbGyTFMkxz+DWdud3dNK2kASSavyBP8vAsvcNDYfd9sgcogZbBf/UVki99eLPkwrZanrjp/vmSS6wycyq9ChkhAAAAdAgAAEAKn2WwoIvOvFxw1UcXuPJf7LMGAk5mE6+PfM2sgxMkV42qCftad5MeRyOnsqA14z0YjHxfOA6z32Phv/+kZPs119Aqg55rT7vRLCDtZBRcJPnkuIslOzOPSn5ukK7kGZSpf39KVs3RnAar0RghAAAAdAgAAEALOssg0MTnWPZ5J8mLK8ZKrv5NoeTOb+6JT+MgOnv1mN0FFVMk95unG3g4/XRYLjfbLzloKB9c0MihEtcO/Y3k8NeDvgd4urpYP12+xcTb8Vnhq4M6L07NMxJ81vkOed7gBa78l9MhLTe2xHsy07pf2nlTaQ58bDzbPwn7eOBfbpW8Y+IiyaPvfk/yX64fJPnRS/Ssm0nZuqGd7ayj99CgJXP1uRZ+LDkdfpKMEAAAADoEAACADgEAADAuzSEo+uMxyT+9cZjk/+zyz4Q8X2W91p0frSqVvOUJfe68vdZ511b9NN/adbDxaiKaI2QtKQweOy5514Iiye8Oe0LyjkA7V9qVrg5N0N0JG1pWay87XH3nRMk+80FiG5di7HkDFfX5kn97ROc9ZFrzCfxB/VX53R5vSe6fEZAcMA2sX04DmdZ9sdHfUXJZ1WjJWb76Rr/PS70avcR1odrasI8HPKK/+w9+Wec0PVawUS+ycwNeOd1d8q8euUFy8Us6VyYd5g3YGCEAAAB0CAAAgEslg+C2XZI33nyJ5OFTx4ddd/nV2yQv7vN2k55j2DN3Sc47qEN3+WU6fJNnUn+nqFbLY/VNgzp8Weuk7zCs2zpv0yHdp072kTw7X5fPjt8yTXKHTfr5RJTHUnWZoTHGdLDKAT+sGiP55Bg9jMaX30EfYK0unPZjXVbWtY/uCBmoj3yQWjrIyaqTfOx93Zmv+Cl7yZ61A6xfy7Jhr12t/KUs++9R6V/vlrxz0q8jXr/2XFvJ97xwm+SiF/dK7ng4de/1pmCEAAAA0CEAAAAulQxs9nBNoZWNMebojzRPMVeYpijksCG0cm1XlUtevUpniq+2Xku5Roc5W/MqmrPWCpexHXdLfvzJqyU77fQnVLhK3zsNnL9Vv5EvfcsEYUI6H75D9in9vPXvOzS9n+Sawbq6wtSn726N/b+lq2uuNSMavb6X9Xem8TUX6YcRAgAAQIcAAAAkoWQAAMlW6+hQ+MScnZJnTNUyZq4nS/IdQ3Ulwvpv9ZWc6WvZhZc6a+XEA0OWSZ6Wd1jy2ZBVPjDfc6NZSBBGCAAAAB0CAABAyQBAKxew3hcdsKaOhxwdCn+wQDdKy+mus+rTba/6prLfMZ6w/rHb6/RPh9famKizC21C4jBCAAAA6BAAAABKBgAQkdej52hUh/RXZXUS2pJq7J8NWg5GCAAAAB0CAABgjMdxOF8WAIDWjhECAABAhwAAANAhAAAAhg4BAAAwdAgAAIChQwAAAAwdAgAAYOgQAAAAQ4cAAAAYY/4P1F7bW+utHi0AAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "First false equation with length 8 in the training dataset:\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "true_train_equation_with_length_5 = true_train_equation[5]\n", "true_train_equation_with_length_8 = true_train_equation[8]\n", "print(f\"First true equation with length 5 in the training dataset:\")\n", "for i, x in enumerate(true_train_equation_with_length_5[0]):\n", " plt.subplot(1, 5, i+1)\n", " plt.axis('off') \n", " plt.imshow(x.transpose(1, 2, 0))\n", "plt.show()\n", "print(f\"First true equation with length 8 in the training dataset:\")\n", "for i, x in enumerate(true_train_equation_with_length_8[0]):\n", " plt.subplot(1, 8, i+1)\n", " plt.axis('off') \n", " plt.imshow(x.transpose(1, 2, 0))\n", "plt.show()\n", "\n", "false_train_equation_with_length_5 = false_train_equation[5]\n", "false_train_equation_with_length_8 = false_train_equation[8]\n", "print(f\"First false equation with length 5 in the training dataset:\")\n", "for i, x in enumerate(false_train_equation_with_length_5[0]):\n", " plt.subplot(1, 5, i+1)\n", " plt.axis('off') \n", " plt.imshow(x.transpose(1, 2, 0))\n", "plt.show()\n", "print(f\"First false equation with length 8 in the training dataset:\")\n", "for i, x in enumerate(false_train_equation_with_length_8[0]):\n", " plt.subplot(1, 8, i+1)\n", " plt.axis('off') \n", " plt.imshow(x.transpose(1, 2, 0))\n", "plt.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Building the Learning Part" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To build the learning part, we need to first build a machine learning base model. We use SymbolNet, and encapsulate it within a `BasicNN` object to create the base model. `BasicNN` is a class that encapsulates a PyTorch model, transforming it into a base model with an sklearn-style interface. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Each symbol belongs to one of 4 classes: '0', '1', '+', '='\n", "cls = SymbolNet(num_classes=4)\n", "loss_fn = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "base_model = BasicNN(\n", " cls,\n", " loss_fn,\n", " optimizer,\n", " device=device,\n", " batch_size=32,\n", " num_epochs=1,\n", " stop_loss=None,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, the base model built above deals with instance-level data (i.e., individual images) and cannot directly handle example-level data (i.e., the list of images that make up an equation). Therefore, we wrap the base model in `ABLModel`, which enables the learning part to train, test, and predict on example-level data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = ABLModel(base_model)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Building the Reasoning Part" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the reasoning part, we first build a knowledge base. As mentioned before, the knowledge base in this task encodes the structure of the equations and a recursive definition of bit-wise operations. The knowledge base is already defined in `HedKB`, which is derived from `PrologKB` and is built upon the Prolog files `reasoning/BK.pl` and `reasoning/learn_add.pl`.\n", "\n", "Specifically, the knowledge about the structure of equations (in `reasoning/BK.pl`) is a set of DCG (definite clause grammar) rules that recursively define a digit as a sequence of '0' and '1', and state that equations share the structure X+Y=Z, though the lengths of X, Y and Z may vary. 
The knowledge about bit-wise operations (in `reasoning/learn_add.pl`) is a recursive logic program that calculates X+Y in reverse, i.e., it operates on X and Y digit by digit, from the last digit to the first.\n", "\n", "Note: The specific rules for calculating the operations are undefined in the knowledge base, i.e., the results of '0+0', '0+1' and '1+1' could be '0', '1', '00', '01' or even '10'. The missing calculation rules must be learned from the data. Therefore, `HedKB` incorporates methods for abducing rules from data. Interested users can refer to the implementation of `HedKB` in `reasoning/reasoning.py`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kb = HedKB()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we create a reasoner. Because abductive reasoning is nondeterministic, there can be multiple candidates compatible with the knowledge base. When this happens, the reasoner minimizes the inconsistency between the knowledge base and the pseudo-labels predicted by the learning part, and returns only the candidate with the highest consistency.\n", "\n", "In this task, we create the reasoner by instantiating the class `HedReasoner`, which is derived from `Reasoner` and tailored specifically for this task. `HedReasoner` leverages the [ZOOpt library](https://github.com/polixir/ZOOpt) for acceleration, and uses a dedicated strategy to better harness ZOOpt’s capabilities. Additionally, it incorporates methods for abducing rules from data. Interested users can refer to the implementation of `HedReasoner` in `reasoning/reasoning.py`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=10)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Building Evaluation Metrics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we set up the evaluation metrics, which are used to evaluate model performance during training and testing. In this notebook the metric list contains the task-specific `ConsistencyMetric`, which is constructed with the knowledge base `kb`; `SymbolAccuracy` is imported above but is not added to `metric_list`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "metric_list = [ConsistencyMetric(kb=kb)]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Bridge Learning and Reasoning\n", "\n", "Now, the last step is to bridge the learning and reasoning parts. We do this by creating an instance of `HedBridge`, which is derived from `SimpleBridge` and tailored specifically for this task." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bridge = HedBridge(model, reasoner, metric_list)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we perform pretraining, training and testing by invoking the `pretrain`, `train` and `test` methods of `HedBridge`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Build logger\n", "print_log(\"Abductive Learning on the HED example.\", logger=\"current\")\n", "\n", "# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", "log_dir = ABLLogger.get_current_instance().log_dir\n", "weights_dir = osp.join(log_dir, \"weights\")\n", "\n", "bridge.pretrain(weights_dir)\n", "bridge.train(train_data, val_data)\n", "bridge.test(test_data)" ] } ], "metadata": { "kernelspec": { "display_name": "ABL", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58" } } }, "nbformat": 4, "nbformat_minor": 2 }