|
|
@@ -12,6 +12,7 @@ import pyswip |
|
|
|
from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable |
|
|
|
from ..utils.cache import abl_cache |
|
|
|
|
|
|
|
|
|
|
|
class KBBase(ABC): |
|
|
|
""" |
|
|
|
Base class for knowledge base. |
|
|
@@ -21,35 +22,36 @@ class KBBase(ABC): |
|
|
|
pseudo_label_list : list |
|
|
|
List of possible pseudo labels. |
|
|
|
max_err : float, optional |
|
|
|
The upper tolerance limit when comparing the similarity between a candidate's logical |
|
|
|
result. This is only applicable when the logical result is of a numerical type. |
|
|
|
This is particularly relevant for regression problems where exact matches might not be |
|
|
|
feasible. Defaults to 1e-10. |
|
|
|
The upper tolerance limit when comparing the similarity between a candidate's logical |
|
|
|
result. This is only applicable when the logical result is of a numerical type. |
|
|
|
This is particularly relevant for regression problems where exact matches might not be |
|
|
|
feasible. Defaults to 1e-10. |
|
|
|
use_cache : bool, optional |
|
|
|
Whether to use a cache for previously abduced candidates to speed up subsequent |
|
|
|
Whether to use a cache for previously abduced candidates to speed up subsequent |
|
|
|
operations. Defaults to True. |
|
|
|
|
|
|
|
|
|
|
|
Notes |
|
|
|
----- |
|
|
|
Users should inherit from this base class to build their own knowledge base. For the |
|
|
|
user-build KB (an inherited subclass), it's only required for the user to provide the |
|
|
|
`pseudo_label_list` and override the `logic_forward` function (specifying how to |
|
|
|
perform logical reasoning). After that, other operations (e.g. how to perform abductive |
|
|
|
reasoning) will be automatically set up. |
|
|
|
Users should inherit from this base class to build their own knowledge base. For the |
|
|
|
user-build KB (an inherited subclass), it's only required for the user to provide the |
|
|
|
`pseudo_label_list` and override the `logic_forward` function (specifying how to |
|
|
|
perform logical reasoning). After that, other operations (e.g. how to perform abductive |
|
|
|
reasoning) will be automatically set up. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True): |
|
|
|
if not isinstance(pseudo_label_list, list): |
|
|
|
raise TypeError("pseudo_label_list should be list") |
|
|
|
self.pseudo_label_list = pseudo_label_list |
|
|
|
self.max_err = max_err |
|
|
|
self.use_cache = use_cache |
|
|
|
self.use_cache = use_cache |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
def logic_forward(self, pseudo_label): |
|
|
|
""" |
|
|
|
How to perform (deductive) logical reasoning, i.e. matching each pseudo label to |
|
|
|
How to perform (deductive) logical reasoning, i.e. matching each pseudo label to |
|
|
|
their logical result. Users are required to provide this. |
|
|
|
|
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
pred_pseudo_label : List[Any] |
|
|
@@ -70,23 +72,22 @@ class KBBase(ABC): |
|
|
|
max_revision_num : int |
|
|
|
The upper limit on the number of revisions. |
|
|
|
require_more_revision : int, optional |
|
|
|
Specifies additional number of revisions permitted beyond the minimum required. |
|
|
|
Specifies additional number of revisions permitted beyond the minimum required. |
|
|
|
Defaults to 0. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
List[List[Any]] |
|
|
|
A list of candidates, i.e. revised pseudo labels that are consistent with the |
|
|
|
A list of candidates, i.e. revised pseudo labels that are consistent with the |
|
|
|
knowledge base. |
|
|
|
""" |
|
|
|
if self.use_cache: |
|
|
|
return self._abduce_by_search_cache(to_hashable(pred_pseudo_label), |
|
|
|
to_hashable(y), |
|
|
|
max_revision_num, require_more_revision) |
|
|
|
else: |
|
|
|
return self._abduce_by_search(pred_pseudo_label, y, |
|
|
|
max_revision_num, require_more_revision) |
|
|
|
|
|
|
|
# if self.use_cache: |
|
|
|
# return self._abduce_by_search_cache(to_hashable(pred_pseudo_label), |
|
|
|
# to_hashable(y), |
|
|
|
# max_revision_num, require_more_revision) |
|
|
|
# else: |
|
|
|
return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) |
|
|
|
|
|
|
|
def _check_equal(self, logic_result, y): |
|
|
|
""" |
|
|
|
Check whether the logical result of a candidate is equal to the ground truth |
|
|
@@ -94,12 +95,12 @@ class KBBase(ABC): |
|
|
|
""" |
|
|
|
if logic_result == None: |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)): |
|
|
|
return abs(logic_result - y) <= self.max_err |
|
|
|
else: |
|
|
|
return logic_result == y |
|
|
|
|
|
|
|
|
|
|
|
def revise_at_idx(self, pred_pseudo_label, y, revision_idx): |
|
|
|
""" |
|
|
|
Revise the predicted pseudo label at specified index positions. |
|
|
@@ -125,7 +126,7 @@ class KBBase(ABC): |
|
|
|
|
|
|
|
def _revision(self, revision_num, pred_pseudo_label, y): |
|
|
|
""" |
|
|
|
For a specified number of pseudo label to revise, iterate through all possible |
|
|
|
For a specified number of pseudo label to revise, iterate through all possible |
|
|
|
indices to find any candidates that are consistent with the knowledge base. |
|
|
|
""" |
|
|
|
new_candidates = [] |
|
|
@@ -136,12 +137,13 @@ class KBBase(ABC): |
|
|
|
new_candidates.extend(candidates) |
|
|
|
return new_candidates |
|
|
|
|
|
|
|
def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): |
|
|
|
@abl_cache(max_size=4096) |
|
|
|
def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): |
|
|
|
""" |
|
|
|
Perform abductive reasoning by exhastive search. Specifically, begin with 0 and |
|
|
|
continuously increase the number of pseudo labels to revise, until candidates |
|
|
|
Perform abductive reasoning by exhastive search. Specifically, begin with 0 and |
|
|
|
continuously increase the number of pseudo labels to revise, until candidates |
|
|
|
that are consistent with the knowledge base are found. |
|
|
|
|
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
pred_pseudo_label : List[Any] |
|
|
@@ -151,16 +153,16 @@ class KBBase(ABC): |
|
|
|
max_revision_num : int |
|
|
|
The upper limit on the number of revisions. |
|
|
|
require_more_revision : int |
|
|
|
If larger than 0, then after having found any candidates consistent with the |
|
|
|
knowledge base, continue to increase the number pseudo labels to revise to |
|
|
|
If larger than 0, then after having found any candidates consistent with the |
|
|
|
knowledge base, continue to increase the number pseudo labels to revise to |
|
|
|
get more possible consistent candidates. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
List[List[Any]] |
|
|
|
A list of candidates, i.e. revised pseudo label that are consistent with the |
|
|
|
A list of candidates, i.e. revised pseudo label that are consistent with the |
|
|
|
knowledge base. |
|
|
|
""" |
|
|
|
""" |
|
|
|
candidates = [] |
|
|
|
for revision_num in range(len(pred_pseudo_label) + 1): |
|
|
|
if revision_num == 0 and self._check_equal(self.logic_forward(pred_pseudo_label), y): |
|
|
@@ -173,20 +175,22 @@ class KBBase(ABC): |
|
|
|
if revision_num >= max_revision_num: |
|
|
|
return [] |
|
|
|
|
|
|
|
for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): |
|
|
|
for revision_num in range( |
|
|
|
min_revision_num + 1, min_revision_num + require_more_revision + 1 |
|
|
|
): |
|
|
|
if revision_num > max_revision_num: |
|
|
|
return candidates |
|
|
|
candidates.extend(self._revision(revision_num, pred_pseudo_label, y)) |
|
|
|
return candidates |
|
|
|
|
|
|
|
@abl_cache(max_size=4096) |
|
|
|
def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): |
|
|
|
""" |
|
|
|
`_abduce_by_search` with cache. |
|
|
|
""" |
|
|
|
pred_pseudo_label = restore_from_hashable(pred_pseudo_label) |
|
|
|
y = restore_from_hashable(y) |
|
|
|
return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) |
|
|
|
|
|
|
|
# @abl_cache(max_size=4096) |
|
|
|
# def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): |
|
|
|
# """ |
|
|
|
# `_abduce_by_search` with cache. |
|
|
|
# """ |
|
|
|
# pred_pseudo_label = restore_from_hashable(pred_pseudo_label) |
|
|
|
# y = restore_from_hashable(y) |
|
|
|
# return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
return ( |
|
|
@@ -195,13 +199,13 @@ class KBBase(ABC): |
|
|
|
f"max_err={self.max_err!r}, " |
|
|
|
f"use_cache={self.use_cache!r}." |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GroundKB(KBBase): |
|
|
|
""" |
|
|
|
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon |
|
|
|
class initialization, storing all potential candidates along with their respective |
|
|
|
logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`. |
|
|
|
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon |
|
|
|
class initialization, storing all potential candidates along with their respective |
|
|
|
logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
@@ -211,15 +215,16 @@ class GroundKB(KBBase): |
|
|
|
List of possible lengths of pseudo label. |
|
|
|
max_err : float, optional |
|
|
|
Refer to class `KBBase`. |
|
|
|
|
|
|
|
|
|
|
|
Notes |
|
|
|
----- |
|
|
|
Users can also inherit from this class to build their own knowledge base. Similar |
|
|
|
to `KBBase`, users are only required to provide the `pseudo_label_list` and override |
|
|
|
Users can also inherit from this class to build their own knowledge base. Similar |
|
|
|
to `KBBase`, users are only required to provide the `pseudo_label_list` and override |
|
|
|
the `logic_forward` function. Additionally, users should provide the `GKB_len_list`. |
|
|
|
After that, other operations (e.g. auto-construction of GKB, and how to perform |
|
|
|
After that, other operations (e.g. auto-construction of GKB, and how to perform |
|
|
|
abductive reasoning) will be automatically set up. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, pseudo_label_list, GKB_len_list, max_err=1e-10): |
|
|
|
super().__init__(pseudo_label_list, max_err) |
|
|
|
if not isinstance(GKB_len_list, list): |
|
|
@@ -229,7 +234,6 @@ class GroundKB(KBBase): |
|
|
|
X, Y = self._get_GKB() |
|
|
|
for x, y in zip(X, Y): |
|
|
|
self.GKB.setdefault(len(x), defaultdict(list))[y].append(x) |
|
|
|
|
|
|
|
|
|
|
|
def _get_XY_list(self, args): |
|
|
|
pre_x, post_x_it = args[0], args[1] |
|
|
@@ -259,21 +263,21 @@ class GroundKB(KBBase): |
|
|
|
part_X, part_Y = zip(*XY_list) |
|
|
|
X.extend(part_X) |
|
|
|
Y.extend(part_Y) |
|
|
|
if Y and isinstance(Y[0], (int, float)): |
|
|
|
if Y and isinstance(Y[0], (int, float)): |
|
|
|
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) |
|
|
|
return X, Y |
|
|
|
|
|
|
|
|
|
|
|
def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0): |
|
|
|
""" |
|
|
|
Perform abductive reasoning by directly retrieving consistent candidates from |
|
|
|
the prebuilt GKB. In this way, the time-consuming exhaustive search can be |
|
|
|
Perform abductive reasoning by directly retrieving consistent candidates from |
|
|
|
the prebuilt GKB. In this way, the time-consuming exhaustive search can be |
|
|
|
avoided. |
|
|
|
This is an overridden function. For more information about the parameters and |
|
|
|
This is an overridden function. For more information about the parameters and |
|
|
|
returns, refer to the function of the same name in class `KBBase`. |
|
|
|
""" |
|
|
|
if self.GKB == {} or len(pred_pseudo_label) not in self.GKB_len_list: |
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
|
|
all_candidates = self._find_candidate_GKB(pred_pseudo_label, y) |
|
|
|
if len(all_candidates) == 0: |
|
|
|
return [] |
|
|
@@ -284,29 +288,30 @@ class GroundKB(KBBase): |
|
|
|
idxs = np.where(cost_list <= revision_num)[0] |
|
|
|
candidates = [all_candidates[idx] for idx in idxs] |
|
|
|
return candidates |
|
|
|
|
|
|
|
|
|
|
|
def _find_candidate_GKB(self, pred_pseudo_label, y): |
|
|
|
""" |
|
|
|
Retrieve consistent candidates from the prebuilt GKB. For numerical logical results, |
|
|
|
return all candidates whose logical results fall within the |
|
|
|
Retrieve consistent candidates from the prebuilt GKB. For numerical logical results, |
|
|
|
return all candidates whose logical results fall within the |
|
|
|
[y - max_err, y + max_err] range. |
|
|
|
""" |
|
|
|
if isinstance(y, (int, float)): |
|
|
|
potential_candidates = self.GKB[len(pred_pseudo_label)] |
|
|
|
key_list = list(potential_candidates.keys()) |
|
|
|
|
|
|
|
|
|
|
|
low_key = bisect.bisect_left(key_list, y - self.max_err) |
|
|
|
high_key = bisect.bisect_right(key_list, y + self.max_err) |
|
|
|
|
|
|
|
all_candidates = [candidate |
|
|
|
for key in key_list[low_key:high_key] |
|
|
|
for candidate in potential_candidates[key]] |
|
|
|
all_candidates = [ |
|
|
|
candidate |
|
|
|
for key in key_list[low_key:high_key] |
|
|
|
for candidate in potential_candidates[key] |
|
|
|
] |
|
|
|
return all_candidates |
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
|
return self.GKB[len(pred_pseudo_label)][y] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
return ( |
|
|
|
f"{self.__class__.__name__} is a KB with " |
|
|
@@ -321,78 +326,80 @@ class GroundKB(KBBase): |
|
|
|
class PrologKB(KBBase): |
|
|
|
""" |
|
|
|
Knowledge base provided by a Prolog (.pl) file. |
|
|
|
|
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
pseudo_label_list : list |
|
|
|
Refer to class `KBBase`. |
|
|
|
pl_file : |
|
|
|
Prolog file containing the KB. |
|
|
|
pl_file : |
|
|
|
Prolog file containing the KB. |
|
|
|
max_err : float, optional |
|
|
|
Refer to class `KBBase`. |
|
|
|
|
|
|
|
|
|
|
|
Notes |
|
|
|
----- |
|
|
|
Users can instantiate this class to build their own knowledge base. During the |
|
|
|
Users can instantiate this class to build their own knowledge base. During the |
|
|
|
instantiation, users are only required to provide the `pseudo_label_list` and `pl_file`. |
|
|
|
To use the default logic forward and abductive reasoning methods in this class, in the |
|
|
|
Prolog (.pl) file, there needs to be a rule which is strictly formatted as |
|
|
|
To use the default logic forward and abductive reasoning methods in this class, in the |
|
|
|
Prolog (.pl) file, there needs to be a rule which is strictly formatted as |
|
|
|
`logic_forward(Pseudo_labels, Res).`, e.g., `logic_forward([A,B], C) :- C is A+B`. |
|
|
|
For specifics, refer to the `logic_forward` and `get_query_string` functions in this |
|
|
|
For specifics, refer to the `logic_forward` and `get_query_string` functions in this |
|
|
|
class. Users are also welcome to override related functions for more flexible support. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, pseudo_label_list, pl_file): |
|
|
|
super().__init__(pseudo_label_list) |
|
|
|
self.pl_file = pl_file |
|
|
|
self.prolog = pyswip.Prolog() |
|
|
|
|
|
|
|
|
|
|
|
if not os.path.exists(self.pl_file): |
|
|
|
raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.") |
|
|
|
self.prolog.consult(self.pl_file) |
|
|
|
|
|
|
|
def logic_forward(self, pseudo_labels): |
|
|
|
""" |
|
|
|
Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the |
|
|
|
returned `Res` as the logical results. To use this default function, there must be |
|
|
|
a Prolog `log_forward` method in the pl file to perform logical. reasoning. |
|
|
|
Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the |
|
|
|
returned `Res` as the logical results. To use this default function, there must be |
|
|
|
a Prolog `log_forward` method in the pl file to perform logical. reasoning. |
|
|
|
Otherwise, users would override this function. |
|
|
|
""" |
|
|
|
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] |
|
|
|
if result == 'true': |
|
|
|
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]["Res"] |
|
|
|
if result == "true": |
|
|
|
return True |
|
|
|
elif result == 'false': |
|
|
|
elif result == "false": |
|
|
|
return False |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def _revision_pred_pseudo_label(self, pred_pseudo_label, revision_idx): |
|
|
|
import re |
|
|
|
|
|
|
|
revision_pred_pseudo_label = pred_pseudo_label.copy() |
|
|
|
revision_pred_pseudo_label = flatten(revision_pred_pseudo_label) |
|
|
|
|
|
|
|
|
|
|
|
for idx in revision_idx: |
|
|
|
revision_pred_pseudo_label[idx] = 'P' + str(idx) |
|
|
|
revision_pred_pseudo_label[idx] = "P" + str(idx) |
|
|
|
revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label) |
|
|
|
|
|
|
|
|
|
|
|
regex = r"'P\d+'" |
|
|
|
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label)) |
|
|
|
|
|
|
|
|
|
|
|
def get_query_string(self, pred_pseudo_label, y, revision_idx): |
|
|
|
""" |
|
|
|
Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set |
|
|
|
the returned `Revise_labels` together with the kept labels as the candidates. This is |
|
|
|
Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set |
|
|
|
the returned `Revise_labels` together with the kept labels as the candidates. This is |
|
|
|
a default fuction for demo, users would override this function to adapt to their own |
|
|
|
Prolog file. |
|
|
|
Prolog file. |
|
|
|
""" |
|
|
|
query_string = "logic_forward(" |
|
|
|
query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx) |
|
|
|
key_is_none_flag = y is None or (type(y) == list and y[0] is None) |
|
|
|
query_string += ",%s)." % y if not key_is_none_flag else ")." |
|
|
|
return query_string |
|
|
|
|
|
|
|
|
|
|
|
def revise_at_idx(self, pred_pseudo_label, y, revision_idx): |
|
|
|
""" |
|
|
|
Revise the predicted pseudo label at specified index positions by querying Prolog. |
|
|
|
This is an overridden function. For more information about the parameters, refer to |
|
|
|
This is an overridden function. For more information about the parameters, refer to |
|
|
|
the function of the same name in class `KBBase`. |
|
|
|
""" |
|
|
|
candidates = [] |
|
|
@@ -414,4 +421,4 @@ class PrologKB(KBBase): |
|
|
|
f"pseudo_label_list={self.pseudo_label_list!r}, " |
|
|
|
f"defined by " |
|
|
|
f"Prolog file {self.pl_file!r}." |
|
|
|
) |
|
|
|
) |