From e2b0b330afd27fba6d9613ea1cf0e3adf27edf00 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Wed, 15 Nov 2023 14:07:42 +0800 Subject: [PATCH] [MNT] support ListData in reasoning --- abl/evaluation/semantics_metric.py | 4 +- abl/reasoning/kb.py | 2 +- abl/reasoning/reasoner.py | 26 ++++---- abl/utils/utils.py | 4 +- examples/mnist_add/mnist_add_example.ipynb | 71 ++++++++++++++++------ 5 files changed, 70 insertions(+), 37 deletions(-) diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index 718cfea..ae7aac8 100644 --- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -1,11 +1,11 @@ from typing import Optional, Sequence -from ..reasoning import BaseKB +from ..reasoning import KBBase from .base_metric import BaseMetric class SemanticsMetric(BaseMetric): - def __init__(self, kb: BaseKB = None, prefix: Optional[str] = None) -> None: + def __init__(self, kb: KBBase = None, prefix: Optional[str] = None) -> None: super().__init__(prefix) self.kb = kb diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 37ba5b6..b626504 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -9,7 +9,7 @@ from functools import lru_cache import numpy as np import pyswip -from abl.utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable +from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable class KBBase(ABC): diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 9cc24f0..686e9dd 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,6 +1,6 @@ import numpy as np from zoopt import Dimension, Objective, Parameter, Opt -from abl.utils.utils import ( +from ..utils.utils import ( confidence_dist, flatten, reform_idx, @@ -191,7 +191,7 @@ class ReasonerBase: return max_revision def abduce( - self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 + self, data_sample, max_revision=-1, require_more_revision=0 ): """ Perform abductive reasoning on the given prediction data. @@ -219,9 +219,13 @@ class ReasonerBase: A revised pseudo label through abductive reasoning, which is consistent with the knowledge base. """ - symbol_num = len(flatten(pred_pseudo_label)) + symbol_num = data_sample.elements_num("pred_pseudo_label") max_revision_num = self._get_max_revision_num(max_revision, symbol_num) - + + pred_pseudo_label = data_sample.pred_pseudo_label[0] + pred_prob = data_sample.pred_prob[0] + y = data_sample.Y[0] + if self.use_zoopt: solution = self.zoopt_get_solution( symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num @@ -237,20 +241,18 @@ class ReasonerBase: return candidate def batch_abduce( - self, pred_probs, pred_pseudo_labels, Ys, max_revision=-1, require_more_revision=0 + self, data_samples, max_revision=-1, require_more_revision=0 ): """ Perform abductive reasoning on the given prediction data in batches. For detailed information, refer to `abduce`. """ - return [ - self.abduce( - pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision - ) - for pred_prob, pred_pseudo_label, Y in zip( - pred_probs, pred_pseudo_labels, Ys - ) + abduced_pseudo_label = [ + self.abduce(data_sample, max_revision, require_more_revision) + for data_sample in data_samples ] + data_samples.abduced_pseudo_label = abduced_pseudo_label + return abduced_pseudo_label # def _batch_abduce_helper(self, args): # z, prob, y, max_revision, require_more_revision = args diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 8192bf9..1480045 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -192,7 +192,7 @@ def to_hashable(x): return x -def hashable_to_list(x): +def restore_from_hashable(x): """ Convert a nested tuple back to a nested list. @@ -208,7 +208,7 @@ def hashable_to_list(x): otherwise the original input. """ if isinstance(x, tuple): - return [hashable_to_list(item) for item in x] + return [restore_from_hashable(item) for item in x] return x diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index 146bd88..fc06d34 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -13,16 +13,16 @@ "\n", "from abl.learning import BasicNN, ABLModel\n", "from abl.bridge import SimpleBridge\n", - "from abl.evaluation import SymbolMetric, ABLMetric\n", + "from abl.evaluation import SymbolMetric\n", "from abl.utils import ABLLogger\n", "\n", - "from models.nn import LeNet5\n", + "from examples.models.nn import LeNet5\n", "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -40,19 +40,19 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Initialize knowledge base and abducer\n", "class add_KB(KBBase):\n", - " def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n", - " super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", + " def __init__(self, pseudo_label_list=list(range(10)), max_err=0, use_cache=True):\n", + " super().__init__(pseudo_label_list, max_err, use_cache)\n", "\n", " def logic_forward(self, nums):\n", " return sum(nums)\n", "\n", - "kb = add_KB(prebuild_GKB=True)\n", + "kb = add_KB()\n", "\n", "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", "abducer = ReasonerBase(kb, dist_func=\"confidence\")" @@ -68,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -93,7 +93,6 @@ " optimizer,\n", " device,\n", " save_interval=1,\n", - " save_dir=logger.save_dir,\n", " batch_size=32,\n", " num_epochs=1,\n", ")" @@ -109,7 +108,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -129,12 +128,12 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Add metric\n", - "metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]" + "metric = [SymbolMetric(prefix=\"mnist_add\")]" ] }, { @@ -147,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -166,7 +165,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -181,15 +180,47 @@ "### Train and Test" ] }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11/15 13:36:00 - abl - WARNING - Transform used in the training phase will be used in prediction.\n" + ] + }, + { + "ename": "TypeError", + "evalue": "Input must be of type list.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/huwc/ABL-Package/examples/mnist_add/mnist_add_example.ipynb 单元格 17\u001b[0m line \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m bridge\u001b[39m.\u001b[39;49mtrain(train_data, loops\u001b[39m=\u001b[39;49m\u001b[39m5\u001b[39;49m, segment_size\u001b[39m=\u001b[39;49m\u001b[39m10000\u001b[39;49m)\n\u001b[1;32m 2\u001b[0m bridge\u001b[39m.\u001b[39mtest(test_data)\n", + "File \u001b[0;32m~/ABL-Package/abl/bridge/simple_bridge.py:92\u001b[0m, in \u001b[0;36mSimpleBridge.train\u001b[0;34m(self, train_data, loops, segment_size, eval_interval, save_interval, save_dir)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpredict(sub_data_samples)\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39midx_to_pseudo_label(sub_data_samples)\n\u001b[0;32m---> 92\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabduce_pseudo_label(sub_data_samples)\n\u001b[1;32m 93\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpseudo_label_to_idx(sub_data_samples)\n\u001b[1;32m 94\u001b[0m loss \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mtrain(sub_data_samples)\n", + "File \u001b[0;32m~/ABL-Package/abl/bridge/simple_bridge.py:36\u001b[0m, in \u001b[0;36mSimpleBridge.abduce_pseudo_label\u001b[0;34m(self, data_samples, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mabduce_pseudo_label\u001b[39m(\n\u001b[1;32m 31\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 32\u001b[0m data_samples: ListData,\n\u001b[1;32m 33\u001b[0m max_revision: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m,\n\u001b[1;32m 34\u001b[0m require_more_revision: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m,\n\u001b[1;32m 35\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m List[List[Any]]:\n\u001b[0;32m---> 36\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabducer\u001b[39m.\u001b[39;49mbatch_abduce(data_samples, max_revision, require_more_revision)\n\u001b[1;32m 37\u001b[0m \u001b[39mreturn\u001b[39;00m data_samples[\u001b[39m\"\u001b[39m\u001b[39mabduced_pseudo_label\u001b[39m\u001b[39m\"\u001b[39m]\n", + "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:246\u001b[0m, in \u001b[0;36mReasonerBase.batch_abduce\u001b[0;34m(self, data_samples, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_abduce\u001b[39m(\n\u001b[1;32m 240\u001b[0m \u001b[39mself\u001b[39m, data_samples, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 241\u001b[0m ):\n\u001b[1;32m 242\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 243\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data in batches.\u001b[39;00m\n\u001b[1;32m 244\u001b[0m \u001b[39m For detailed information, refer to `abduce`.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 246\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 247\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mabduce(data_sample, max_revision, require_more_revision)\n\u001b[1;32m 248\u001b[0m \u001b[39mfor\u001b[39;00m data_sample \u001b[39min\u001b[39;00m data_samples\n\u001b[1;32m 249\u001b[0m ]\n", + "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:247\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_abduce\u001b[39m(\n\u001b[1;32m 240\u001b[0m \u001b[39mself\u001b[39m, data_samples, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 241\u001b[0m ):\n\u001b[1;32m 242\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 243\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data in batches.\u001b[39;00m\n\u001b[1;32m 244\u001b[0m \u001b[39m For detailed information, refer to `abduce`.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m 246\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[0;32m--> 247\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabduce(data_sample, max_revision, require_more_revision)\n\u001b[1;32m 248\u001b[0m \u001b[39mfor\u001b[39;00m data_sample \u001b[39min\u001b[39;00m data_samples\n\u001b[1;32m 249\u001b[0m ]\n", + "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:222\u001b[0m, in \u001b[0;36mReasonerBase.abduce\u001b[0;34m(self, pred_prob, pred_pseudo_label, y, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mabduce\u001b[39m(\n\u001b[1;32m 194\u001b[0m \u001b[39mself\u001b[39m, pred_prob, pred_pseudo_label, y, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 195\u001b[0m ):\n\u001b[1;32m 196\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 197\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data.\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[39m knowledge base.\u001b[39;00m\n\u001b[1;32m 221\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 222\u001b[0m symbol_num \u001b[39m=\u001b[39m \u001b[39mlen\u001b[39m(flatten(pred_pseudo_label))\n\u001b[1;32m 223\u001b[0m max_revision_num \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_get_max_revision_num(max_revision, symbol_num)\n\u001b[1;32m 225\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39muse_zoopt:\n", + "File \u001b[0;32m~/ABL-Package/abl/utils/utils.py:26\u001b[0m, in \u001b[0;36mflatten\u001b[0;34m(nested_list)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[39mFlattens a nested list.\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[39m If the input object is not a list.\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(nested_list, \u001b[39mlist\u001b[39m):\n\u001b[0;32m---> 26\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInput must be of type list.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 28\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m nested_list \u001b[39mor\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(nested_list[\u001b[39m0\u001b[39m], (\u001b[39mlist\u001b[39m, \u001b[39mtuple\u001b[39m)):\n\u001b[1;32m 29\u001b[0m \u001b[39mreturn\u001b[39;00m nested_list\n", + "\u001b[0;31mTypeError\u001b[0m: Input must be of type list." + ] + } + ], + "source": [ + "bridge.train(train_data, loops=5, segment_size=10000)\n", + "bridge.test(test_data)" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "bridge.train(train_data, epochs=5, batch_size=10000)\n", - "bridge.test(test_data)" - ] + "source": [] } ], "metadata": {