|
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Zoo\n",
- "\n",
- "This notebook shows an implementation of the [Zoo](https://archive.ics.uci.edu/dataset/111/zoo) task. In this task, the attributes of animals (such as the presence of hair, eggs, etc.) and their targets (the animal class they belong to) are given, along with a knowledge base that contains information about the relations between attributes and targets, e.g., Implies(milk == 1, mammal == 1).\n",
- "\n",
- "The goal of this task is to develop a learning model that can predict the targets of animals based on their attributes. In the initial stages, when the model is under-trained, it may produce incorrect predictions that conflict with the relations contained in the knowledge base. When this happens, abductive reasoning can be employed to adjust these results and retrain the model accordingly. This process enables us to further update the learning model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Import necessary libraries and modules\n",
- "import os.path as osp\n",
- "import numpy as np\n",
- "from sklearn.ensemble import RandomForestClassifier\n",
- "from get_dataset import load_and_preprocess_dataset, split_dataset\n",
- "from abl.learning import ABLModel\n",
- "from kb import ZooKB\n",
- "from abl.reasoning import Reasoner\n",
- "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n",
- "from abl.utils import ABLLogger, print_log, confidence_dist\n",
- "from abl.bridge import SimpleBridge"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Working with Data\n",
- "\n",
- "First, we load and preprocess the [Zoo dataset](https://archive.ics.uci.edu/dataset/111/zoo), and split it into labeled, unlabeled, and test data."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "X, y = load_and_preprocess_dataset(dataset_id=62)\n",
- "X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The Zoo dataset consists of tabular data. Each animal has 16 attributes: 15 boolean values (such as hair, feathers, eggs, milk, airborne, and aquatic) and one integer value (the number of legs). The target is an integer in the range [0, 6] representing 7 classes (e.g., mammal, bird, reptile, fish, amphibian, insect, and other). Below is an illustration:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Shape of X and y: (101, 16) (101,)\n",
- "First five elements of X:\n",
- "[[True False False True False False True True True True False False 4\n",
- " False False True]\n",
- " [True False False True False False False True True True False False 4\n",
- " True False True]\n",
- " [False False True False False True True True True False False True 0\n",
- " True False False]\n",
- " [True False False True False False True True True True False False 4\n",
- " False False True]\n",
- " [True False False True False False True True True True False False 4\n",
- " True False True]]\n",
- "First five elements of y:\n",
- "[0 0 3 0 0]\n"
- ]
- }
- ],
- "source": [
- "print(\"Shape of X and y:\", X.shape, y.shape)\n",
- "print(\"First five elements of X:\")\n",
- "print(X[:5])\n",
- "print(\"First five elements of y:\")\n",
- "print(y[:5])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next, we transform the tabular data into the format required by ABL-Package, which is a tuple of (X, gt_pseudo_label, Y). In this task, we treat the attributes as X and the targets as gt_pseudo_label (ground truth pseudo-labels). Y (the reasoning result) is expected to be 0, indicating that no rules are violated."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def transform_tab_data(X, y):\n",
- "    return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n",
- "\n",
- "\n",
- "label_data = transform_tab_data(X_label, y_label)\n",
- "test_data = transform_tab_data(X_test, y_test)\n",
- "train_data = transform_tab_data(X_unlabel, y_unlabel)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Building the Learning Part"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To build the learning part, we need to first build a machine learning base model. We use a [Random Forest](https://en.wikipedia.org/wiki/Random_forest) as the base model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "base_model = RandomForestClassifier()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "However, the base model built above deals with instance-level data and cannot directly handle example-level data. Therefore, we wrap the base model in `ABLModel`, which enables the learning part to train, test, and predict on example-level data."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "model = ABLModel(base_model)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Building the Reasoning Part"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In the reasoning part, we first build a knowledge base that contains information about the relations between attributes (X) and targets (pseudo-labels), e.g., Implies(milk == 1, mammal == 1). The knowledge base is implemented by the `ZooKB` class in the file `kb.py`, which is derived from the `KBBase` class."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "kb = ZooKB()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As mentioned, the reasoning result is expected to be 0 for every pair of attributes and target in the dataset, since real data should not violate the established knowledge. The first five examples illustrate this:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Example 0: the attributes are: [True False False True False False True True True True False False 4 False\n",
- " False True], and the target is 0.\n",
- "Reasoning result is 0.\n",
- "\n",
- "Example 1: the attributes are: [True False False True False False False True True True False False 4 True\n",
- " False True], and the target is 0.\n",
- "Reasoning result is 0.\n",
- "\n",
- "Example 2: the attributes are: [False False True False False True True True True False False True 0 True\n",
- " False False], and the target is 3.\n",
- "Reasoning result is 0.\n",
- "\n",
- "Example 3: the attributes are: [True False False True False False True True True True False False 4 False\n",
- " False True], and the target is 0.\n",
- "Reasoning result is 0.\n",
- "\n",
- "Example 4: the attributes are: [True False False True False False True True True True False False 4 True\n",
- " False True], and the target is 0.\n",
- "Reasoning result is 0.\n",
- "\n"
- ]
- }
- ],
- "source": [
- "for idx, (x, y_item) in enumerate(zip(X[:5], y[:5])):\n",
- "    print(f\"Example {idx}: the attributes are: {x}, and the target is {y_item}.\")\n",
- "    print(f\"Reasoning result is {kb.logic_forward([y_item], [x])}.\")\n",
- "    print()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Then, we create a reasoner by instantiating the class `Reasoner`. Because abductive reasoning is non-deterministic, there may be multiple candidates compatible with the knowledge base. When this happens, the reasoner minimizes the inconsistency between the knowledge base and the pseudo-labels predicted by the learning part, and returns only the candidate with the highest consistency. Here, a candidate's score is the sum of the model's confidence distance and its reasoning result, so a lower score means higher consistency."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def consistency(data_example, candidates, candidate_idxs, reasoning_results):\n",
- "    pred_prob = data_example.pred_prob\n",
- "    model_scores = confidence_dist(pred_prob, candidate_idxs)\n",
- "    rule_scores = np.array(reasoning_results)\n",
- "    scores = model_scores + rule_scores\n",
- "    return scores\n",
- "\n",
- "\n",
- "reasoner = Reasoner(kb, dist_func=consistency)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Building Evaluation Metrics"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next, we set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing. Specifically, we use `SymbolAccuracy` and `ReasoningMetric`, which are used to evaluate the accuracy of the machine learning model’s predictions and the accuracy of the final reasoning results, respectively."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "metric_list = [SymbolAccuracy(prefix=\"zoo\"), ReasoningMetric(kb=kb, prefix=\"zoo\")]"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Bridging Learning and Reasoning\n",
- "\n",
- "Now, the last step is to bridge the learning and reasoning parts. We do this by creating an instance of `SimpleBridge`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "bridge = SimpleBridge(model, reasoner, metric_list)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Perform training and testing by invoking the `train` and `test` methods of `SimpleBridge`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "12/22 11:48:01 - abl - INFO - Abductive Learning on the ZOO example.\n",
- "12/22 11:48:01 - abl - INFO - ------- Use labeled data to pretrain the model -----------\n",
- "12/22 11:48:01 - abl - INFO - ------- Test the initial model -----------\n",
- "12/22 11:48:01 - abl - INFO - Evaluation ended, zoo/character_accuracy: 0.903 zoo/reasoning_accuracy: 0.903 \n",
- "12/22 11:48:01 - abl - INFO - ------- Use ABL to train the model -----------\n",
- "12/22 11:48:01 - abl - INFO - loop(train) [1/3] segment(train) [1/1] \n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "12/22 11:48:02 - abl - INFO - Evaluation start: loop(val) [1]\n",
- "12/22 11:48:03 - abl - INFO - Evaluation ended, zoo/character_accuracy: 1.000 zoo/reasoning_accuracy: 1.000 \n",
- "12/22 11:48:03 - abl - INFO - loop(train) [2/3] segment(train) [1/1] \n",
- "12/22 11:48:04 - abl - INFO - Evaluation start: loop(val) [2]\n",
- "12/22 11:48:05 - abl - INFO - Evaluation ended, zoo/character_accuracy: 1.000 zoo/reasoning_accuracy: 1.000 \n",
- "12/22 11:48:05 - abl - INFO - loop(train) [3/3] segment(train) [1/1] \n",
- "12/22 11:48:05 - abl - INFO - Evaluation start: loop(val) [3]\n",
- "12/22 11:48:06 - abl - INFO - Evaluation ended, zoo/character_accuracy: 1.000 zoo/reasoning_accuracy: 1.000 \n",
- "12/22 11:48:06 - abl - INFO - ------- Test the final model -----------\n",
- "12/22 11:48:06 - abl - INFO - Evaluation ended, zoo/character_accuracy: 0.968 zoo/reasoning_accuracy: 0.968 \n"
- ]
- }
- ],
- "source": [
- "# Build logger\n",
- "print_log(\"Abductive Learning on the Zoo example.\", logger=\"current\")\n",
- "log_dir = ABLLogger.get_current_instance().log_dir\n",
- "weights_dir = osp.join(log_dir, \"weights\")\n",
- "\n",
- "print_log(\"------- Use labeled data to pretrain the model -----------\", logger=\"current\")\n",
- "base_model.fit(X_label, y_label)\n",
- "print_log(\"------- Test the initial model -----------\", logger=\"current\")\n",
- "bridge.test(test_data)\n",
- "print_log(\"------- Use ABL to train the model -----------\", logger=\"current\")\n",
- "bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)\n",
- "print_log(\"------- Test the final model -----------\", logger=\"current\")\n",
- "bridge.test(test_data)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The results show that after training with ABL, the model's accuracy on the test data improved from 0.903 to 0.968."
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "abl",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.18"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "9c8d454494e49869a4ee4046edcac9a39ff683f7d38abf0769f648402670238e"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|
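
The candidate selection described in the reasoning part of the notebook can be illustrated without ABL-Package installed. The sketch below is a simplified, standalone approximation: it assumes that the confidence distance of a candidate is one minus the product of the predicted probabilities of its labels, and that the reasoning result of a candidate is a non-negative number that is 0 when no rule is violated. The names `confidence_distance` and `pick_candidate`, and the example probabilities and violation counts, are hypothetical and chosen only for illustration.

```python
from math import prod


def confidence_distance(pred_prob, candidate):
    """Assumed form: 1 - product of the predicted probabilities of the candidate's labels."""
    return 1.0 - prod(pred_prob[i][label] for i, label in enumerate(candidate))


def pick_candidate(pred_prob, candidates, rule_violations):
    """Score each candidate as confidence distance plus rule violations; lower is better."""
    scores = [
        confidence_distance(pred_prob, cand) + violation
        for cand, violation in zip(candidates, rule_violations)
    ]
    best_idx = min(range(len(candidates)), key=scores.__getitem__)
    return candidates[best_idx], scores


# One instance with three possible classes (hypothetical probabilities).
pred_prob = [[0.6, 0.3, 0.1]]
candidates = [[0], [1], [2]]
# Hypothetical reasoning results: class 0 violates one rule, the others none.
rule_violations = [1, 0, 0]

best, scores = pick_candidate(pred_prob, candidates, rule_violations)
print("Selected candidate:", best)
print("Scores:", scores)
```

In this toy setting the model prefers class 0, but the rule violation raises its score above that of class 1, so the selected candidate is the most confident label that still satisfies the rules.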