
[MNT] support ListData in reasoning

pull/4/head
troyyyyy 1 year ago
commit e2b0b330af
5 changed files with 70 additions and 37 deletions
  1. abl/evaluation/semantics_metric.py (+2, -2)
  2. abl/reasoning/kb.py (+1, -1)
  3. abl/reasoning/reasoner.py (+14, -12)
  4. abl/utils/utils.py (+2, -2)
  5. examples/mnist_add/mnist_add_example.ipynb (+51, -20)

abl/evaluation/semantics_metric.py (+2, -2)

@@ -1,11 +1,11 @@
 from typing import Optional, Sequence

-from ..reasoning import BaseKB
+from ..reasoning import KBBase
 from .base_metric import BaseMetric


 class SemanticsMetric(BaseMetric):
-    def __init__(self, kb: BaseKB = None, prefix: Optional[str] = None) -> None:
+    def __init__(self, kb: KBBase = None, prefix: Optional[str] = None) -> None:
         super().__init__(prefix)
         self.kb = kb
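
Review note: only the type annotation changes here (BaseKB -> KBBase). A minimal sketch of how the re-typed `kb` argument is used; `AddKB` mirrors the notebook's `add_KB`, and the SemanticsMetric module path is taken from the file path above (treat the exact import locations as assumptions):

# Hedged sketch: constructing SemanticsMetric with a KBBase subclass.
from abl.reasoning import KBBase
from abl.evaluation.semantics_metric import SemanticsMetric

class AddKB(KBBase):
    def __init__(self, pseudo_label_list=list(range(10)), max_err=0, use_cache=True):
        super().__init__(pseudo_label_list, max_err, use_cache)

    def logic_forward(self, nums):
        # MNIST-Add rule: the two digits must sum to the label.
        return sum(nums)

metric = SemanticsMetric(kb=AddKB(), prefix="mnist_add")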




abl/reasoning/kb.py (+1, -1)

@@ -9,7 +9,7 @@ from functools import lru_cache
 import numpy as np
 import pyswip

-from abl.utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable
+from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable


 class KBBase(ABC):


abl/reasoning/reasoner.py (+14, -12)

@@ -1,6 +1,6 @@
 import numpy as np
 from zoopt import Dimension, Objective, Parameter, Opt
-from abl.utils.utils import (
+from ..utils.utils import (
     confidence_dist,
     flatten,
     reform_idx,
@@ -191,7 +191,7 @@ class ReasonerBase:
         return max_revision

     def abduce(
-        self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0
+        self, data_sample, max_revision=-1, require_more_revision=0
     ):
         """
         Perform abductive reasoning on the given prediction data.
@@ -219,9 +219,13 @@ class ReasonerBase:
         A revised pseudo label through abductive reasoning, which is consistent with the
         knowledge base.
         """
-        symbol_num = len(flatten(pred_pseudo_label))
+        symbol_num = data_sample.elements_num("pred_pseudo_label")
         max_revision_num = self._get_max_revision_num(max_revision, symbol_num)

+        pred_pseudo_label = data_sample.pred_pseudo_label[0]
+        pred_prob = data_sample.pred_prob[0]
+        y = data_sample.Y[0]
+
         if self.use_zoopt:
             solution = self.zoopt_get_solution(
                 symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
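
Review note: `abduce` no longer takes `pred_prob`, `pred_pseudo_label`, and `y` separately; it reads them from a single ListData sample. A hedged sketch of the expected per-sample layout; ListData's import path and keyword construction are assumptions, only the field names and the `[0]` indexing mirror the diff, and `abducer` is the ReasonerBase instance from the example notebook:

# Hedged sketch: what abduce now expects from one data_sample.
import numpy as np

probs = [np.full(10, 0.1), np.full(10, 0.1)]   # hypothetical per-digit probability vectors

data_sample = ListData(                        # construction is an assumption
    pred_pseudo_label=[[7, 9]],                # pseudo labels for a single two-digit sample
    pred_prob=[probs],
    Y=[16],                                    # ground-truth sum
)
revised = abducer.abduce(data_sample, max_revision=-1, require_more_revision=0)
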
@@ -237,20 +241,18 @@ class ReasonerBase:
         return candidate

     def batch_abduce(
-        self, pred_probs, pred_pseudo_labels, Ys, max_revision=-1, require_more_revision=0
+        self, data_samples, max_revision=-1, require_more_revision=0
     ):
         """
         Perform abductive reasoning on the given prediction data in batches.
         For detailed information, refer to `abduce`.
         """
-        return [
-            self.abduce(
-                pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision
-            )
-            for pred_prob, pred_pseudo_label, Y in zip(
-                pred_probs, pred_pseudo_labels, Ys
-            )
+        abduced_pseudo_label = [
+            self.abduce(data_sample, max_revision, require_more_revision)
+            for data_sample in data_samples
         ]
+        data_samples.abduced_pseudo_label = abduced_pseudo_label
+        return abduced_pseudo_label

     # def _batch_abduce_helper(self, args):
     #     z, prob, y, max_revision, require_more_revision = args
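
Review note: `batch_abduce` now takes a ListData batch, iterates over its per-sample views, writes the result back to `data_samples.abduced_pseudo_label`, and still returns the list. A hedged usage sketch under the same assumptions as above (ListData construction is assumed; the call and the write-back follow the diff):

# Hedged sketch: batching two MNIST-Add samples through the new interface.
import numpy as np

uniform = [np.full(10, 0.1), np.full(10, 0.1)]   # hypothetical per-digit probabilities

data_samples = ListData(                          # construction is an assumption
    pred_pseudo_label=[[7, 9], [1, 2]],
    pred_prob=[uniform, uniform],
    Y=[16, 3],
)
abduced = abducer.batch_abduce(data_samples, max_revision=-1, require_more_revision=0)
assert data_samples.abduced_pseudo_label == abduced   # new side effect of this commit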


abl/utils/utils.py (+2, -2)

@@ -192,7 +192,7 @@ def to_hashable(x):
     return x


-def hashable_to_list(x):
+def restore_from_hashable(x):
     """
     Convert a nested tuple back to a nested list.

@@ -208,7 +208,7 @@ def hashable_to_list(x):
     otherwise the original input.
     """
     if isinstance(x, tuple):
-        return [hashable_to_list(item) for item in x]
+        return [restore_from_hashable(item) for item in x]
     return x
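
Review note: the rename makes the pairing with `to_hashable` explicit. A hedged sketch of the round trip these helpers support; `to_hashable`'s list-to-tuple behavior is inferred from its name and from the documented reverse conversion, and the lru_cache pairing is suggested by the kb.py context above rather than shown:

# Hedged sketch: nested lists are unhashable, so they are converted to tuples
# (e.g. usable as lru_cache keys) and converted back afterwards.
from functools import lru_cache
from abl.utils.utils import to_hashable, restore_from_hashable

nested = [[1, 2], [3, 4]]
key = to_hashable(nested)                     # assumed to give ((1, 2), (3, 4))
assert restore_from_hashable(key) == nested   # documented: nested tuple -> nested list

@lru_cache(maxsize=None)
def cached_sum(hashable_label):
    # hypothetical memoized computation over the hashable form
    return sum(restore_from_hashable(hashable_label))

print(cached_sum(to_hashable([1, 2, 3])))     # 6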






examples/mnist_add/mnist_add_example.ipynb (+51, -20)

@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -13,16 +13,16 @@
     "\n",
     "from abl.learning import BasicNN, ABLModel\n",
     "from abl.bridge import SimpleBridge\n",
-    "from abl.evaluation import SymbolMetric, ABLMetric\n",
+    "from abl.evaluation import SymbolMetric\n",
     "from abl.utils import ABLLogger\n",
     "\n",
-    "from models.nn import LeNet5\n",
+    "from examples.models.nn import LeNet5\n",
     "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -40,19 +40,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
     "# Initialize knowledge base and abducer\n",
     "class add_KB(KBBase):\n",
-    "    def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n",
-    "        super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n",
+    "    def __init__(self, pseudo_label_list=list(range(10)), max_err=0, use_cache=True):\n",
+    "        super().__init__(pseudo_label_list, max_err, use_cache)\n",
     "\n",
     "    def logic_forward(self, nums):\n",
     "        return sum(nums)\n",
     "\n",
-    "kb = add_KB(prebuild_GKB=True)\n",
+    "kb = add_KB()\n",
     "\n",
     "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n",
     "abducer = ReasonerBase(kb, dist_func=\"confidence\")"
@@ -68,7 +68,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -81,7 +81,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -93,7 +93,6 @@
     "    optimizer,\n",
     "    device,\n",
     "    save_interval=1,\n",
-    "    save_dir=logger.save_dir,\n",
     "    batch_size=32,\n",
     "    num_epochs=1,\n",
     ")"
@@ -109,7 +108,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -129,12 +128,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
     "# Add metric\n",
-    "metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]"
+    "metric = [SymbolMetric(prefix=\"mnist_add\")]"
    ]
   },
   {
@@ -147,7 +146,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -166,7 +165,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -181,15 +180,47 @@
     "### Train and Test"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "11/15 13:36:00 - abl - WARNING - Transform used in the training phase will be used in prediction.\n"
+     ]
+    },
+    {
+     "ename": "TypeError",
+     "evalue": "Input must be of type list.",
+     "output_type": "error",
+     "traceback": [
+      "---------------------------------------------------------------------------",
+      "TypeError                                 Traceback (most recent call last)",
+      "/home/huwc/ABL-Package/examples/mnist_add/mnist_add_example.ipynb Cell 17 line 1, in <cell line: 1>()",
+      "----> 1 bridge.train(train_data, loops=5, segment_size=10000)",
+      "      2 bridge.test(test_data)",
+      "File ~/ABL-Package/abl/bridge/simple_bridge.py:92, in SimpleBridge.train(self, train_data, loops, segment_size, eval_interval, save_interval, save_dir)",
+      "     90 self.predict(sub_data_samples)",
+      "     91 self.idx_to_pseudo_label(sub_data_samples)",
+      "---> 92 self.abduce_pseudo_label(sub_data_samples)",
+      "File ~/ABL-Package/abl/bridge/simple_bridge.py:36, in SimpleBridge.abduce_pseudo_label(self, data_samples, max_revision, require_more_revision)",
+      "---> 36 self.abducer.batch_abduce(data_samples, max_revision, require_more_revision)",
+      "File ~/ABL-Package/abl/reasoning/reasoner.py:246, in ReasonerBase.batch_abduce(self, data_samples, max_revision, require_more_revision)",
+      "File ~/ABL-Package/abl/reasoning/reasoner.py:247, in <listcomp>(.0)",
+      "File ~/ABL-Package/abl/reasoning/reasoner.py:222, in ReasonerBase.abduce(self, pred_prob, pred_pseudo_label, y, max_revision, require_more_revision)",
+      "--> 222 symbol_num = len(flatten(pred_pseudo_label))",
+      "File ~/ABL-Package/abl/utils/utils.py:26, in flatten(nested_list)",
+      "---> 26     raise TypeError(\"Input must be of type list.\")",
+      "TypeError: Input must be of type list."
+     ]
+    }
+   ],
+   "source": [
+    "bridge.train(train_data, loops=5, segment_size=10000)\n",
+    "bridge.test(test_data)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": [
-    "bridge.train(train_data, epochs=5, batch_size=10000)\n",
-    "bridge.test(test_data)"
-   ]
+   "source": []
   }
  ],
  "metadata": {

