Browse Source

[FIX] fix bugs and run mnist and hwf successfully

ab_data
Gao Enhao 1 year ago
parent
commit
831aa855e7
7 changed files with 94 additions and 69 deletions
  1. +1
    -1
      abl/reasoning/__init__.py
  2. +1
    -8
      abl/reasoning/base_kb.py
  3. +82
    -52
      abl/reasoning/reasoner.py
  4. +0
    -1
      abl/reasoning/search_engine/__init__.py
  5. +1
    -1
      abl/reasoning/search_engine/bfs.py
  6. +7
    -1
      examples/mnist_add/mnist_add_example.ipynb
  7. +2
    -5
      examples/mnist_add/mnist_add_kb.py

+ 1
- 1
abl/reasoning/__init__.py View File

@@ -3,4 +3,4 @@ from .ground_kb import GroundKB
from .prolog_based_kb import PrologBasedKB
from .reasoner import ReasonerBase
from .search_based_kb import SearchBasedKB
from .search_engine import BFS, BaseSearchEngine, Zoopt
from .search_engine import BFS, BaseSearchEngine

+ 1
- 8
abl/reasoning/base_kb.py View File

@@ -1,17 +1,10 @@
from abc import ABC, abstractmethod

from ..structures import ListData
from abc import ABC


class BaseKB(ABC):
def __init__(self, pseudo_label_list) -> None:
self.pseudo_label_list = pseudo_label_list

@abstractmethod
def abduce_candidates(self, data_sample: ListData):
"""Placeholder for abduction of the knowledge base."""
pass

# TODO: When the output is excessively long, use ellipses as a substitute.
def __repr__(self):
return (


+ 82
- 52
abl/reasoning/reasoner.py View File

@@ -3,8 +3,7 @@ from typing import Any, List, Mapping, Optional
import numpy as np

from ..structures import ListData
from ..utils import (Cache, calculate_revision_num, confidence_dist,
hamming_dist)
from ..utils import Cache, calculate_revision_num, confidence_dist, hamming_dist
from .base_kb import BaseKB
from .search_engine import BFS, BaseSearchEngine

@@ -81,70 +80,53 @@ class ReasonerBase:
else:
key_func = lambda x: x
self.cache = Cache[ListData, List[List[Any]]](
func=self.abduce,
func=self.abduce_candidates,
cache=self.use_cache,
cache_file=self.cache_file,
key_func=key_func,
max_size=cache_size,
)

def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]):
def abduce(
self,
data_sample: ListData,
max_revision: int = -1,
require_more_revision: int = 0,
):
"""
Get the list of costs between each pseudo label and candidate.
Perform revision by abduction on the given data.

Parameters
----------
pred_pseudo_label : list
The pseudo label to be used for computing costs of candidates.
pred_prob : list
Probabilities of the predictions. Used when distance function is "confidence".
candidates : list
List of candidate abduction result.

Returns
-------
numpy.ndarray
Array of computed costs for each candidate.
"""
if self.dist_func == "hamming":
return hamming_dist(data_sample["pred_pseudo_label"][0], candidates)

elif self.dist_func == "confidence":
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(data_sample["pred_prob"][0], candidates)

def select(self, data_sample: ListData, candidates: List[List[Any]]):
"""
Get one candidate. If multiple candidates exist, return the one with minimum cost.

Parameters
----------
List of probabilities for predicted results.
pred_pseudo_label : list
The pseudo label to be used for selecting a candidate.
pred_prob : list
Probabilities of the predictions.
candidates : list
List of candidate abduction result.
List of predicted pseudo labels.
y : any
Ground truth for the predicted results.
max_revision : int or float, optional
Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
If -1, any revisions are allowed. Defaults to -1.
require_more_revision : int, optional
Number of additional revisions to require. Defaults to 0.

Returns
-------
list
The chosen candidate based on minimum cost.
If no candidates, an empty list is returned.
The abduced revisions.
"""
if len(candidates) == 0:
return []
elif len(candidates) == 1:
return candidates[0]
else:
cost_array = self._get_dist_list(data_sample, candidates)
candidate = candidates[np.argmin(cost_array)]
return candidate
symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = calculate_revision_num(max_revision, symbol_num)
data_sample.set_metainfo(dict(symbol_num=symbol_num))

def abduce(
candidates = self.cache.get(data_sample, max_revision_num, require_more_revision)
candidate = self.select_one_candidate(data_sample, candidates)
return candidate

def abduce_candidates(
self,
data_sample: ListData,
max_revision: int = -1,
max_revision_num: int = -1,
require_more_revision: int = 0,
):
"""
@@ -169,9 +151,6 @@ class ReasonerBase:
list
The abduced revisions.
"""
symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = calculate_revision_num(max_revision, symbol_num)
data_sample.set_metainfo(dict(symbol_num=symbol_num))

if hasattr(self.kb, "abduce_candidates"):
candidates = self.kb.abduce_candidates(
@@ -198,9 +177,60 @@ class ReasonerBase:
raise NotImplementedError(
"The kb should either implement abduce_candidates or revise_at_idx."
)
return candidates

candidate = self.select(data_sample, candidates)
return candidate
def select_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]):
    """
    Pick a single candidate from a list of abduction results.

    Parameters
    ----------
    data_sample : ListData
        Data sample the candidates were abduced for. It is only forwarded to
        ``self._get_dist_list`` when more than one candidate has to be scored.
    candidates : list
        List of candidate abduction results.

    Returns
    -------
    list
        An empty list if there are no candidates, the sole candidate if there
        is exactly one, otherwise the candidate with the minimum cost as
        computed by ``self._get_dist_list``.
    """
    # ``len`` (rather than truthiness) keeps this working for array-like inputs.
    if len(candidates) == 0:
        return []
    if len(candidates) == 1:
        # A lone candidate needs no scoring.
        return candidates[0]

    cost_array = self._get_dist_list(data_sample, candidates)
    # ``np.argmin`` returns the first index of the minimum, so ties resolve
    # to the earliest candidate.
    return candidates[np.argmin(cost_array)]

def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]):
    """
    Compute the cost of every candidate with respect to the model prediction.

    Parameters
    ----------
    data_sample : ListData
        Data sample holding the prediction. ``data_sample["pred_pseudo_label"][0]``
        is read when ``self.dist_func`` is "hamming", and
        ``data_sample["pred_prob"][0]`` when it is "confidence".
    candidates : list
        List of candidate abduction results.

    Returns
    -------
    Costs for each candidate, as returned by ``hamming_dist`` or
    ``confidence_dist`` (presumably a numpy.ndarray, one entry per candidate;
    verify against ``..utils``).

    Raises
    ------
    ValueError
        If ``self.dist_func`` is neither "hamming" nor "confidence". Previously
        this case silently returned ``None``.
    """
    if self.dist_func == "hamming":
        return hamming_dist(data_sample["pred_pseudo_label"][0], candidates)

    if self.dist_func == "confidence":
        # Translate every pseudo label through ``self.remapping`` before
        # handing the candidates to ``confidence_dist``.
        remapped = [[self.remapping[label] for label in cand] for cand in candidates]
        return confidence_dist(data_sample["pred_prob"][0], remapped)

    raise ValueError(
        f"Unsupported dist_func {self.dist_func!r}; expected 'hamming' or 'confidence'."
    )

def batch_abduce(
self,
@@ -231,7 +261,7 @@ class ReasonerBase:
The abduced revisions in batches.
"""
abduced_pseudo_label = [
self.cache.get(
self.abduce(
data_sample,
max_revision=max_revision,
require_more_revision=require_more_revision,


+ 0
- 1
abl/reasoning/search_engine/__init__.py View File

@@ -1,3 +1,2 @@
from .base_search_engine import BaseSearchEngine
from .bfs import BFS
from .zoopt import Zoopt

+ 1
- 1
abl/reasoning/search_engine/bfs.py View File

@@ -12,7 +12,7 @@ class BFS(BaseSearchEngine):
pass

def generator(
data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
) -> Union[List, Tuple, numpy.ndarray]:
symbol_num = data_sample["symbol_num"]
max_revision_num = min(max_revision_num, symbol_num)


+ 7
- 1
examples/mnist_add/mnist_add_example.ipynb View File

@@ -51,7 +51,13 @@
"source": [
"# Initialize knowledge base and abducer\n",
"kb = AddKB()\n",
"abducer = ReasonerBase(kb, dist_func=\"confidence\")"
"\n",
"# If use cache, get_key should be implemented in the abducer\n",
"class AddAbducer(ReasonerBase):\n",
" def get_key(self, data_sample):\n",
" return (data_sample.to_tuple(\"pred_pseudo_label\"), data_sample[\"Y\"][0])\n",
"\n",
"abducer = AddAbducer(kb, dist_func=\"confidence\", use_cache=True)"
]
},
{


+ 2
- 5
examples/mnist_add/mnist_add_kb.py View File

@@ -5,14 +5,11 @@ from abl.structures import ListData


class AddKB(SearchBasedKB):
def __init__(self, pseudo_label_list=list(range(10)), use_cache=True, cache_size=4096):
def __init__(self, pseudo_label_list=list(range(10))):
super().__init__(
pseudo_label_list=pseudo_label_list, use_cache=use_cache, cache_size=cache_size
pseudo_label_list=pseudo_label_list
)

def get_key(self, data_sample: ListData):
return (data_sample.to_tuple("pred_pseudo_label"), data_sample["Y"][0])

def check_equal(self, data_sample: ListData, y: Any):
return self.logic_forward(data_sample) == y



Loading…
Cancel
Save