Browse Source

[ENH] integrate choice of cache in to abl_cache

pull/4/head
Gao Enhao 1 year ago
parent
commit
c95faf043d
2 changed files with 121 additions and 123 deletions
  1. +104
    -97
      abl/reasoning/kb.py
  2. +17
    -26
      abl/utils/cache.py

+ 104
- 97
abl/reasoning/kb.py View File

@@ -12,6 +12,7 @@ import pyswip
from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable
from ..utils.cache import abl_cache


class KBBase(ABC):
"""
Base class for knowledge base.
@@ -21,35 +22,36 @@ class KBBase(ABC):
pseudo_label_list : list
List of possible pseudo labels.
max_err : float, optional
The upper tolerance limit when comparing the similarity between a candidate's logical
result. This is only applicable when the logical result is of a numerical type.
This is particularly relevant for regression problems where exact matches might not be
feasible. Defaults to 1e-10.
The upper tolerance limit when comparing the similarity between a candidate's logical
result. This is only applicable when the logical result is of a numerical type.
This is particularly relevant for regression problems where exact matches might not be
feasible. Defaults to 1e-10.
use_cache : bool, optional
Whether to use a cache for previously abduced candidates to speed up subsequent
Whether to use a cache for previously abduced candidates to speed up subsequent
operations. Defaults to True.
Notes
-----
Users should inherit from this base class to build their own knowledge base. For the
user-built KB (an inherited subclass), it's only required for the user to provide the
`pseudo_label_list` and override the `logic_forward` function (specifying how to
perform logical reasoning). After that, other operations (e.g. how to perform abductive
reasoning) will be automatically set up.
Users should inherit from this base class to build their own knowledge base. For the
user-built KB (an inherited subclass), it's only required for the user to provide the
`pseudo_label_list` and override the `logic_forward` function (specifying how to
perform logical reasoning). After that, other operations (e.g. how to perform abductive
reasoning) will be automatically set up.
"""

def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True):
if not isinstance(pseudo_label_list, list):
raise TypeError("pseudo_label_list should be list")
self.pseudo_label_list = pseudo_label_list
self.max_err = max_err
self.use_cache = use_cache
self.use_cache = use_cache

@abstractmethod
def logic_forward(self, pseudo_label):
"""
How to perform (deductive) logical reasoning, i.e. matching each pseudo label to
How to perform (deductive) logical reasoning, i.e. matching each pseudo label to
their logical result. Users are required to provide this.
Parameters
----------
pred_pseudo_label : List[Any]
@@ -70,23 +72,22 @@ class KBBase(ABC):
max_revision_num : int
The upper limit on the number of revisions.
require_more_revision : int, optional
Specifies additional number of revisions permitted beyond the minimum required.
Specifies additional number of revisions permitted beyond the minimum required.
Defaults to 0.

Returns
-------
List[List[Any]]
A list of candidates, i.e. revised pseudo labels that are consistent with the
A list of candidates, i.e. revised pseudo labels that are consistent with the
knowledge base.
"""
if self.use_cache:
return self._abduce_by_search_cache(to_hashable(pred_pseudo_label),
to_hashable(y),
max_revision_num, require_more_revision)
else:
return self._abduce_by_search(pred_pseudo_label, y,
max_revision_num, require_more_revision)
# if self.use_cache:
# return self._abduce_by_search_cache(to_hashable(pred_pseudo_label),
# to_hashable(y),
# max_revision_num, require_more_revision)
# else:
return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision)

def _check_equal(self, logic_result, y):
"""
Check whether the logical result of a candidate is equal to the ground truth
@@ -94,12 +95,12 @@ class KBBase(ABC):
"""
if logic_result == None:
return False
if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)):
return abs(logic_result - y) <= self.max_err
else:
return logic_result == y
def revise_at_idx(self, pred_pseudo_label, y, revision_idx):
"""
Revise the predicted pseudo label at specified index positions.
@@ -125,7 +126,7 @@ class KBBase(ABC):

def _revision(self, revision_num, pred_pseudo_label, y):
"""
For a specified number of pseudo label to revise, iterate through all possible
For a specified number of pseudo label to revise, iterate through all possible
indices to find any candidates that are consistent with the knowledge base.
"""
new_candidates = []
@@ -136,12 +137,13 @@ class KBBase(ABC):
new_candidates.extend(candidates)
return new_candidates

def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
@abl_cache(max_size=4096)
def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
"""
Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and
continuously increase the number of pseudo labels to revise, until candidates
Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and
continuously increase the number of pseudo labels to revise, until candidates
that are consistent with the knowledge base are found.
Parameters
----------
pred_pseudo_label : List[Any]
@@ -151,16 +153,16 @@ class KBBase(ABC):
max_revision_num : int
The upper limit on the number of revisions.
require_more_revision : int
If larger than 0, then after having found any candidates consistent with the
knowledge base, continue to increase the number of pseudo labels to revise to
If larger than 0, then after having found any candidates consistent with the
knowledge base, continue to increase the number of pseudo labels to revise to
get more possible consistent candidates.

Returns
-------
List[List[Any]]
A list of candidates, i.e. revised pseudo label that are consistent with the
A list of candidates, i.e. revised pseudo label that are consistent with the
knowledge base.
"""
"""
candidates = []
for revision_num in range(len(pred_pseudo_label) + 1):
if revision_num == 0 and self._check_equal(self.logic_forward(pred_pseudo_label), y):
@@ -173,20 +175,22 @@ class KBBase(ABC):
if revision_num >= max_revision_num:
return []

for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1):
for revision_num in range(
min_revision_num + 1, min_revision_num + require_more_revision + 1
):
if revision_num > max_revision_num:
return candidates
candidates.extend(self._revision(revision_num, pred_pseudo_label, y))
return candidates
@abl_cache(max_size=4096)
def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
"""
`_abduce_by_search` with cache.
"""
pred_pseudo_label = restore_from_hashable(pred_pseudo_label)
y = restore_from_hashable(y)
return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision)
# @abl_cache(max_size=4096)
# def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
# """
# `_abduce_by_search` with cache.
# """
# pred_pseudo_label = restore_from_hashable(pred_pseudo_label)
# y = restore_from_hashable(y)
# return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision)

def __repr__(self):
return (
@@ -195,13 +199,13 @@ class KBBase(ABC):
f"max_err={self.max_err!r}, "
f"use_cache={self.use_cache!r}."
)
class GroundKB(KBBase):
"""
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon
class initialization, storing all potential candidates along with their respective
logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`.
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon
class initialization, storing all potential candidates along with their respective
logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`.

Parameters
----------
@@ -211,15 +215,16 @@ class GroundKB(KBBase):
List of possible lengths of pseudo label.
max_err : float, optional
Refer to class `KBBase`.
Notes
-----
Users can also inherit from this class to build their own knowledge base. Similar
to `KBBase`, users are only required to provide the `pseudo_label_list` and override
Users can also inherit from this class to build their own knowledge base. Similar
to `KBBase`, users are only required to provide the `pseudo_label_list` and override
the `logic_forward` function. Additionally, users should provide the `GKB_len_list`.
After that, other operations (e.g. auto-construction of GKB, and how to perform
After that, other operations (e.g. auto-construction of GKB, and how to perform
abductive reasoning) will be automatically set up.
"""

def __init__(self, pseudo_label_list, GKB_len_list, max_err=1e-10):
super().__init__(pseudo_label_list, max_err)
if not isinstance(GKB_len_list, list):
@@ -229,7 +234,6 @@ class GroundKB(KBBase):
X, Y = self._get_GKB()
for x, y in zip(X, Y):
self.GKB.setdefault(len(x), defaultdict(list))[y].append(x)

def _get_XY_list(self, args):
pre_x, post_x_it = args[0], args[1]
@@ -259,21 +263,21 @@ class GroundKB(KBBase):
part_X, part_Y = zip(*XY_list)
X.extend(part_X)
Y.extend(part_Y)
if Y and isinstance(Y[0], (int, float)):
if Y and isinstance(Y[0], (int, float)):
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
return X, Y
def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0):
"""
Perform abductive reasoning by directly retrieving consistent candidates from
the prebuilt GKB. In this way, the time-consuming exhaustive search can be
Perform abductive reasoning by directly retrieving consistent candidates from
the prebuilt GKB. In this way, the time-consuming exhaustive search can be
avoided.
This is an overridden function. For more information about the parameters and
This is an overridden function. For more information about the parameters and
returns, refer to the function of the same name in class `KBBase`.
"""
if self.GKB == {} or len(pred_pseudo_label) not in self.GKB_len_list:
return []
all_candidates = self._find_candidate_GKB(pred_pseudo_label, y)
if len(all_candidates) == 0:
return []
@@ -284,29 +288,30 @@ class GroundKB(KBBase):
idxs = np.where(cost_list <= revision_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates
def _find_candidate_GKB(self, pred_pseudo_label, y):
"""
Retrieve consistent candidates from the prebuilt GKB. For numerical logical results,
return all candidates whose logical results fall within the
Retrieve consistent candidates from the prebuilt GKB. For numerical logical results,
return all candidates whose logical results fall within the
[y - max_err, y + max_err] range.
"""
if isinstance(y, (int, float)):
potential_candidates = self.GKB[len(pred_pseudo_label)]
key_list = list(potential_candidates.keys())
low_key = bisect.bisect_left(key_list, y - self.max_err)
high_key = bisect.bisect_right(key_list, y + self.max_err)

all_candidates = [candidate
for key in key_list[low_key:high_key]
for candidate in potential_candidates[key]]
all_candidates = [
candidate
for key in key_list[low_key:high_key]
for candidate in potential_candidates[key]
]
return all_candidates
else:
return self.GKB[len(pred_pseudo_label)][y]

def __repr__(self):
return (
f"{self.__class__.__name__} is a KB with "
@@ -321,78 +326,80 @@ class GroundKB(KBBase):
class PrologKB(KBBase):
"""
Knowledge base provided by a Prolog (.pl) file.
Parameters
----------
pseudo_label_list : list
Refer to class `KBBase`.
pl_file :
Prolog file containing the KB.
pl_file :
Prolog file containing the KB.
max_err : float, optional
Refer to class `KBBase`.
Notes
-----
Users can instantiate this class to build their own knowledge base. During the
Users can instantiate this class to build their own knowledge base. During the
instantiation, users are only required to provide the `pseudo_label_list` and `pl_file`.
To use the default logic forward and abductive reasoning methods in this class, in the
Prolog (.pl) file, there needs to be a rule which is strictly formatted as
To use the default logic forward and abductive reasoning methods in this class, in the
Prolog (.pl) file, there needs to be a rule which is strictly formatted as
`logic_forward(Pseudo_labels, Res).`, e.g., `logic_forward([A,B], C) :- C is A+B`.
For specifics, refer to the `logic_forward` and `get_query_string` functions in this
For specifics, refer to the `logic_forward` and `get_query_string` functions in this
class. Users are also welcome to override related functions for more flexible support.
"""

def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list)
self.pl_file = pl_file
self.prolog = pyswip.Prolog()
if not os.path.exists(self.pl_file):
raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.")
self.prolog.consult(self.pl_file)

def logic_forward(self, pseudo_labels):
    """
    Consult Prolog with the query `logic_forward(pseudo_labels, Res).`, and return
    the `Res` of the first solution as the logical result. To use this default
    function, there must be a Prolog `logic_forward` rule in the pl file to perform
    logical reasoning. Otherwise, users should override this function.

    Parameters
    ----------
    pseudo_labels : List[Any]
        Pseudo labels, formatted into the query string with `%s`.

    Returns
    -------
    Any
        `True` / `False` if `Res` equals the string "true" / "false", `None` if
        the query yields no solution, otherwise the value of `Res` unchanged.
    """
    solutions = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))
    if not solutions:
        # No solution: report `None` instead of raising IndexError. The base
        # class's `_check_equal` treats a `None` logical result as inconsistent.
        return None
    result = solutions[0]["Res"]
    if result == "true":
        return True
    elif result == "false":
        return False
    return result
def _revision_pred_pseudo_label(self, pred_pseudo_label, revision_idx):
import re

revision_pred_pseudo_label = pred_pseudo_label.copy()
revision_pred_pseudo_label = flatten(revision_pred_pseudo_label)
for idx in revision_idx:
revision_pred_pseudo_label[idx] = 'P' + str(idx)
revision_pred_pseudo_label[idx] = "P" + str(idx)
revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label)
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label))
def get_query_string(self, pred_pseudo_label, y, revision_idx):
    """
    Build the Prolog query `logic_forward([kept_labels, Revise_labels], Res).`, whose
    returned `Revise_labels` together with the kept labels form the candidates. This
    is a default function for demo; users can override this function to adapt to
    their own Prolog file.

    Parameters
    ----------
    pred_pseudo_label : List[Any]
        Predicted pseudo label.
    y : Any
        Ground truth of the logical result. It is left out of the query when it
        is `None` or a list whose first element is `None`.
    revision_idx : array-like
        Indices of the pseudo labels to revise.

    Returns
    -------
    str
        The query string, terminated with `).`.
    """
    query_string = "logic_forward("
    query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx)
    # `len(y) > 0` keeps an empty list from raising IndexError on `y[0]`.
    key_is_none_flag = y is None or (isinstance(y, list) and len(y) > 0 and y[0] is None)
    if key_is_none_flag:
        return query_string + ")."
    # An f-string rather than `"%s" % y`, which raises TypeError when `y` is a
    # tuple because `%` treats a tuple as the argument list.
    return f"{query_string},{y})."
def revise_at_idx(self, pred_pseudo_label, y, revision_idx):
"""
Revise the predicted pseudo label at specified index positions by querying Prolog.
This is an overridden function. For more information about the parameters, refer to
This is an overridden function. For more information about the parameters, refer to
the function of the same name in class `KBBase`.
"""
candidates = []
@@ -414,4 +421,4 @@ class PrologKB(KBBase):
f"pseudo_label_list={self.pseudo_label_list!r}, "
f"defined by "
f"Prolog file {self.pl_file!r}."
)
)

+ 17
- 26
abl/utils/cache.py View File

@@ -3,6 +3,7 @@ from os import PathLike
from typing import Callable, Generic, Hashable, TypeVar, Union

from .logger import print_log
from .utils import to_hashable

K = TypeVar("K")
T = TypeVar("T")
@@ -13,7 +14,6 @@ class Cache(Generic[K, T]):
def __init__(
self,
func: Callable[[K], T],
cache: bool = True,
cache_file: Union[None, str, PathLike] = None,
key_func: Callable[[K], Hashable] = lambda x: x,
max_size: int = 4096,
@@ -27,23 +27,15 @@ class Cache(Generic[K, T]):
"""
self.func = func
self.key_func = key_func
self.cache = cache
if cache is True or cache_file is not None:
print_log("Caching is activated", logger="current")
self._init_cache(cache_file, max_size)
self.first = self.get_from_dict
else:
self.first = self.func

def __getitem__(self, item: K, *args) -> T:
return self.first(item, *args)
self._init_cache(cache_file, max_size)

def __getitem__(self, obj, *args) -> T:
return self.get_from_dict(obj, *args)

def invalidate(self):
def clear_cache(self):
"""Invalidate entire cache."""
self.cache_dict.clear()
if self.cache_file:
for p in self.cache_root.iterdir():
p.unlink()

def _init_cache(self, cache_file, max_size):
self.cache = True
@@ -66,13 +58,10 @@ class Cache(Generic[K, T]):
link = [last, self.root, cache_key, result]
last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link

def get(self, obj, item: K, *args) -> T:
return self.first(obj, item, *args)

def get_from_dict(self, obj, item: K, *args) -> T:
def get_from_dict(self, obj, *args) -> T:
"""Implements dict based cache."""
# result = self.func(obj, item, *args)
cache_key = (self.key_func(item), *args)
pred_pseudo_label, y, *res_args = args
cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args)
link = self.cache_dict.get(cache_key)
if link is not None:
# Move the link to the front of the circular queue
@@ -87,7 +76,7 @@ class Cache(Generic[K, T]):
return result
self.misses += 1

result = self.func(obj, item, *args)
result = self.func(obj, *args)

if self.full:
# Use the old root to store the new key and result.
@@ -113,16 +102,18 @@ class Cache(Generic[K, T]):


def abl_cache(
cache: bool = True,
cache_file: Union[None, str, PathLike] = None,
key_func: Callable[[K], Hashable] = lambda x: x,
key_func: Callable[[K], Hashable] = to_hashable,
max_size: int = 4096,
):
def decorator(func):
cache_instance = Cache(func, cache, cache_file, key_func, max_size)
cache_instance = Cache(func, cache_file, key_func, max_size)

def wrapper(self, *args, **kwargs):
return cache_instance.get(self, *args, **kwargs)
def wrapper(obj, *args):
if obj.use_cache:
return cache_instance.get_from_dict(obj, *args)
else:
return func(obj, *args)

return wrapper



Loading…
Cancel
Save