from abc import ABC, abstractmethod
import bisect
import os
import re
from collections import defaultdict
from itertools import product, combinations
from multiprocessing import Pool
from functools import lru_cache

import numpy as np
import pyswip

from ..utils.utils import flatten, reform_list, hamming_dist, to_hashable
from ..utils.cache import abl_cache


class KBBase(ABC):
    """
    Base class for knowledge base.

    Parameters
    ----------
    pseudo_label_list : list
        List of possible pseudo labels. It's recommended to arrange the pseudo labels in this
        list so that each aligns with its corresponding index in the base model: the first with
        the 0th index, the second with the 1st, and so forth.
    max_err : float, optional
        The upper tolerance limit when comparing the similarity between a pseudo label sample's
        reasoning result and the ground truth. This is only applicable when the reasoning
        result is of a numerical type. This is particularly relevant for regression problems
        where exact matches might not be feasible. Defaults to 1e-10.
    use_cache : bool, optional
        Whether to use abl_cache for previously abduced candidates to speed up subsequent
        operations. Defaults to True.
    key_func : func, optional
        A function employed for hashing in abl_cache. This is only operational when use_cache
        is set to True. Defaults to to_hashable.
    cache_size: int, optional
        The cache size in abl_cache. This is only operational when use_cache is set to True.
        Defaults to 4096.

    Notes
    -----
    Users should inherit from this base class to build their own knowledge base. For the
    user-build KB (an inherited subclass), it's only required for the user to provide the
    `pseudo_label_list` and override the `logic_forward` function (specifying how to
    perform logical reasoning). After that, other operations (e.g. how to perform abductive
    reasoning) will be automatically set up.
    """

    def __init__(
        self,
        pseudo_label_list,
        max_err=1e-10,
        use_cache=True,
        key_func=to_hashable,
        cache_size=4096,
    ):
        if not isinstance(pseudo_label_list, list):
            raise TypeError("pseudo_label_list should be list")
        self.pseudo_label_list = pseudo_label_list
        self.max_err = max_err

        # Read by the `abl_cache` decorator applied to `_abduce_by_search`
        # (see the parameter descriptions in the class docstring).
        self.use_cache = use_cache
        self.key_func = key_func
        self.cache_size = cache_size

    @abstractmethod
    def logic_forward(self, pseudo_label):
        """
        How to perform (deductive) logical reasoning, i.e. matching each pseudo label sample to
        their reasoning result. Users are required to provide this.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo label sample.
        """
        pass

    def abduce_candidates(self, pseudo_label, y, max_revision_num, require_more_revision):
        """
        Perform abductive reasoning to get a candidate compatible with the knowledge base.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo label sample (to be revised by abductive reasoning).
        y : any
            Ground truth of the reasoning result for the sample.
        max_revision_num : int
            The upper limit on the number of revised labels for each sample.
        require_more_revision : int
            Specifies additional number of revisions permitted beyond the minimum required.

        Returns
        -------
        List[List[Any]]
            A list of candidates, i.e. revised pseudo label samples that are compatible with the
            knowledge base.
        """
        return self._abduce_by_search(pseudo_label, y, max_revision_num, require_more_revision)

    def _check_equal(self, logic_result, y):
        """
        Check whether the reasoning result of a pseudo label sample is equal to the ground truth
        (or, within the maximum error allowed for numerical results).

        Returns
        -------
        bool
            The result of the check.
        """
        # A reasoning result of None never matches any ground truth.
        if logic_result is None:
            return False

        if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)):
            return abs(logic_result - y) <= self.max_err
        return logic_result == y

    def revise_at_idx(self, pseudo_label, y, revision_idx):
        """
        Revise the pseudo label sample at specified index positions.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo label sample (to be revised).
        y : Any
            Ground truth of the reasoning result for the sample.
        revision_idx : array-like
            Indices of where revisions should be made to the pseudo label sample.

        Returns
        -------
        List[List[Any]]
            A list of candidates, i.e. revised pseudo label samples that are compatible with the
            knowledge base.
        """
        candidates = []
        # Try every assignment of pseudo labels to the positions being revised.
        for replacement in product(self.pseudo_label_list, repeat=len(revision_idx)):
            candidate = pseudo_label.copy()
            for new_label, idx in zip(replacement, revision_idx):
                candidate[idx] = new_label
            if self._check_equal(self.logic_forward(candidate), y):
                candidates.append(candidate)
        return candidates

    def _revision(self, revision_num, pseudo_label, y):
        """
        For a specified number of labels in a pseudo label sample to revise, iterate through
        all possible indices to find any candidates that are compatible with the knowledge base.
        """
        new_candidates = []
        for revision_idx in combinations(range(len(pseudo_label)), revision_num):
            new_candidates.extend(self.revise_at_idx(pseudo_label, y, revision_idx))
        return new_candidates

    @abl_cache()
    def _abduce_by_search(self, pseudo_label, y, max_revision_num, require_more_revision):
        """
        Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and
        continuously increase the number of labels in a pseudo label sample to revise, until
        candidates that are compatible with the knowledge base are found.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo label sample (to be revised).
        y : Any
            Ground truth of the reasoning result for the sample.
        max_revision_num : int
            The upper limit on the number of revisions.
        require_more_revision : int
            If larger than 0, then after having found any candidates compatible with the
            knowledge base, continue to increase the number of labels in a pseudo label sample
            to revise to get more possible compatible candidates.

        Returns
        -------
        List[List[Any]]
            A list of candidates, i.e. revised pseudo label samples that are compatible with the
            knowledge base. Empty if no compatible candidate exists within the allowed number
            of revisions.
        """
        candidates = []
        for revision_num in range(len(pseudo_label) + 1):
            if revision_num == 0:
                if self._check_equal(self.logic_forward(pseudo_label), y):
                    candidates.append(pseudo_label)
            else:
                candidates.extend(self._revision(revision_num, pseudo_label, y))

            if candidates:
                min_revision_num = revision_num
                break
            if revision_num >= max_revision_num:
                return []
        else:
            # Every possible number of revisions was tried without success (this happens
            # when max_revision_num exceeds len(pseudo_label)). Without this branch
            # `min_revision_num` would be unbound below.
            return []

        # Optionally allow a few more revisions than the minimum that succeeded.
        for revision_num in range(
            min_revision_num + 1, min_revision_num + require_more_revision + 1
        ):
            if revision_num > max_revision_num:
                return candidates
            candidates.extend(self._revision(revision_num, pseudo_label, y))
        return candidates

    def __repr__(self):
        return (
            f"{self.__class__.__name__} is a KB with "
            f"pseudo_label_list={self.pseudo_label_list!r}, "
            f"max_err={self.max_err!r}, "
            f"use_cache={self.use_cache!r}."
        )


class GroundKB(KBBase):
    """
    Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon
    class initialization, storing all potential candidates along with their respective
    reasoning result. Ground KB can accelerate abductive reasoning in `abduce_candidates`.

    Parameters
    ----------
    pseudo_label_list : list
        Refer to class `KBBase`.
    GKB_len_list : list
        List of possible lengths for a pseudo label sample.
    max_err : float, optional
        Refer to class `KBBase`.

    Notes
    -----
    Users can also inherit from this class to build their own knowledge base. Similar
    to `KBBase`, users are only required to provide the `pseudo_label_list` and override
    the `logic_forward` function. Additionally, users should provide the `GKB_len_list`.
    After that, other operations (e.g. auto-construction of GKB, and how to perform
    abductive reasoning) will be automatically set up.
    """

    def __init__(self, pseudo_label_list, GKB_len_list, max_err=1e-10):
        super().__init__(pseudo_label_list, max_err)
        if not isinstance(GKB_len_list, list):
            raise TypeError("GKB_len_list should be list")
        self.GKB_len_list = GKB_len_list

        # GKB layout: {sample length: {reasoning result: [samples with that result]}}
        self.GKB = {}
        X, Y = self._get_GKB()
        for x, y in zip(X, Y):
            self.GKB.setdefault(len(x), defaultdict(list))[y].append(x)

    def _get_XY_list(self, args):
        """
        Worker used while building the GKB: `args` is a pair of a leading pseudo label and
        an iterable of tuples holding the remaining labels. Return the list of
        `(sample, reasoning result)` pairs whose reasoning result is not None.
        """
        pre_x, post_x_it = args[0], args[1]
        XY_list = []
        for post_x in post_x_it:
            x = (pre_x,) + post_x
            y = self.logic_forward(x)
            if y is not None:
                XY_list.append((x, y))
        return XY_list

    def _get_GKB(self):
        """
        Prebuild the GKB according to `pseudo_label_list` and `GKB_len_list`.
        """
        X, Y = [], []
        for length in self.GKB_len_list:
            # One task per possible leading label; each enumerates the remaining positions.
            arg_list = [
                (pre_x, product(self.pseudo_label_list, repeat=length - 1))
                for pre_x in self.pseudo_label_list
            ]
            if not arg_list:
                # Nothing to enumerate; Pool(processes=0) would raise ValueError.
                continue
            with Pool(processes=len(arg_list)) as pool:
                ret_list = pool.map(self._get_XY_list, arg_list)
            for XY_list in ret_list:
                if not XY_list:
                    continue
                part_X, part_Y = zip(*XY_list)
                X.extend(part_X)
                Y.extend(part_Y)

        # For numerical results, sort by result so that the per-length dict keys are inserted
        # in ascending order; `_find_candidate_GKB` bisects over that key order.
        if Y and isinstance(Y[0], (int, float)):
            X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
        return X, Y

    def abduce_candidates(self, pseudo_label, y, max_revision_num, require_more_revision):
        """
        Perform abductive reasoning by directly retrieving compatible candidates from
        the prebuilt GKB. In this way, the time-consuming exhaustive search can be
        avoided.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo label sample (to be revised by abductive reasoning).
        y : any
            Ground truth of the reasoning result for the sample.
        max_revision_num : int
            The upper limit on the number of revised labels for each sample.
        require_more_revision : int, optional
            Specifies additional number of revisions permitted beyond the minimum required.

        Returns
        -------
        List[List[Any]]
            A list of candidates, i.e. revised pseudo label samples that are compatible with the
            knowledge base.
        """
        sample_len = len(pseudo_label)
        # The second test also covers an empty GKB and lengths for which no sample produced
        # a reasoning result.
        if sample_len not in self.GKB_len_list or sample_len not in self.GKB:
            return []

        all_candidates = self._find_candidate_GKB(pseudo_label, y)
        if len(all_candidates) == 0:
            return []

        cost_list = hamming_dist(pseudo_label, all_candidates)
        min_revision_num = np.min(cost_list)
        revision_num = min(max_revision_num, min_revision_num + require_more_revision)
        idxs = np.where(cost_list <= revision_num)[0]
        return [all_candidates[idx] for idx in idxs]

    def _find_candidate_GKB(self, pseudo_label, y):
        """
        Retrieve compatible candidates from the prebuilt GKB. For numerical reasoning results,
        return all candidates whose reasoning results fall within the
        [y - max_err, y + max_err] range.
        """
        potential_candidates = self.GKB.get(len(pseudo_label))
        if not potential_candidates:
            return []

        if isinstance(y, (int, float)):
            key_list = list(potential_candidates.keys())
            low_key = bisect.bisect_left(key_list, y - self.max_err)
            high_key = bisect.bisect_right(key_list, y + self.max_err)
            return [
                candidate
                for key in key_list[low_key:high_key]
                for candidate in potential_candidates[key]
            ]

        # Use .get so that looking up an unknown result does not insert an empty entry
        # into the underlying defaultdict.
        return potential_candidates.get(y, [])

    def __repr__(self):
        GKB_info_parts = []
        for i in self.GKB_len_list:
            # Count the stored samples, not the number of distinct reasoning results.
            num_candidates = (
                sum(len(samples) for samples in self.GKB[i].values()) if i in self.GKB else 0
            )
            GKB_info_parts.append(f"{num_candidates} candidates of length {i}")
        GKB_info = ", ".join(GKB_info_parts)
        return (
            f"{self.__class__.__name__} is a KB with "
            f"pseudo_label_list={self.pseudo_label_list!r}, "
            f"max_err={self.max_err!r}, "
            f"use_cache={self.use_cache!r}. "
            f"It has a prebuilt GKB with "
            f"GKB_len_list={self.GKB_len_list!r}, "
            f"and there are "
            f"{GKB_info}"
            f" in the GKB."
        )


class PrologKB(KBBase):
    """
    Knowledge base provided by a Prolog (.pl) file.

    Parameters
    ----------
    pseudo_label_list : list
        Refer to class `KBBase`.
    pl_file : str
        Path of the Prolog file containing the KB.

    Raises
    ------
    FileNotFoundError
        If `pl_file` does not exist.

    Notes
    -----
    Users can instantiate this class to build their own knowledge base. During the
    instantiation, users are only required to provide the `pseudo_label_list` and `pl_file`.
    The remaining `KBBase` options (e.g. `max_err`) keep their default values.
    To use the default logic forward and abductive reasoning methods in this class, in the
    Prolog (.pl) file, there needs to be a rule which is strictly formatted as
    `logic_forward(Pseudo_labels, Res).`, e.g., `logic_forward([A,B], C) :- C is A+B`.
    For specifics, refer to the `logic_forward` and `get_query_string` functions in this
    class. Users are also welcome to override related functions for more flexible support.
    """

    def __init__(self, pseudo_label_list, pl_file):
        super().__init__(pseudo_label_list)
        self.pl_file = pl_file
        # Validate the path before creating the Prolog engine.
        if not os.path.exists(self.pl_file):
            raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.")
        self.prolog = pyswip.Prolog()
        self.prolog.consult(self.pl_file)

    def logic_forward(self, pseudo_labels):
        """
        Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the
        returned `Res` as the reasoning results. To use this default function, there must be
        a `logic_forward` method in the pl file to perform reasoning.
        Otherwise, users would override this function.

        Parameters
        ----------
        pseudo_labels : List[Any]
            Pseudo label sample.

        Returns
        -------
        Any
            The value bound to `Res` in the first solution, with the strings "true" and
            "false" mapped to the corresponding booleans, or None when the query has no
            solution.
        """
        solutions = list(self.prolog.query(f"logic_forward({pseudo_labels}, Res)."))
        if not solutions:
            # No solution: report "no reasoning result" rather than raising IndexError.
            return None

        result = solutions[0]["Res"]
        if result == "true":
            return True
        if result == "false":
            return False
        return result

    def _revision_pseudo_label(self, pseudo_label, revision_idx):
        """
        Return the string form of `pseudo_label` in which every position listed in
        `revision_idx` (indices into the flattened sample) is replaced by an unquoted
        variable named `P<idx>`.
        """
        revision_pseudo_label = flatten(pseudo_label.copy())
        for idx in revision_idx:
            revision_pseudo_label[idx] = "P" + str(idx)
        revision_pseudo_label = reform_list(revision_pseudo_label, pseudo_label)

        # str() renders the placeholders as quoted strings ('P3'); strip the quotes so they
        # appear as bare variable names in the query.
        regex = r"'P\d+'"
        return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pseudo_label))

    def get_query_string(self, pseudo_label, y, revision_idx):
        """
        Get the query to be used for consulting Prolog.
        This is a default function for demo, users would override this function to adapt to
        their own Prolog file. In this demo function, return query
        `logic_forward([kept_labels, Revise_labels], Res).`.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo label sample (to be revised by abductive reasoning).
        y : any
            Ground truth of the reasoning result for the sample.
        revision_idx : array-like
            Indices of where revisions should be made to the pseudo label sample.

        Returns
        -------
        str
            A string of the query.
        """
        revised = self._revision_pseudo_label(pseudo_label, revision_idx)
        # The result argument is omitted when y is None or a list whose first item is None.
        key_is_none_flag = y is None or (isinstance(y, list) and len(y) > 0 and y[0] is None)
        if key_is_none_flag:
            return f"logic_forward({revised})."
        return f"logic_forward({revised},{y})."

    def revise_at_idx(self, pseudo_label, y, revision_idx):
        """
        Revise the pseudo label sample at specified index positions by querying Prolog.

        Parameters
        ----------
        pseudo_label : List[Any]
            Pseudo label sample (to be revised).
        y : Any
            Ground truth of the reasoning result for the sample.
        revision_idx : array-like
            Indices of where revisions should be made to the pseudo label sample.

        Returns
        -------
        List[List[Any]]
            A list of candidates, i.e. revised pseudo label samples that are compatible with the
            knowledge base.
        """
        query_string = self.get_query_string(pseudo_label, y, revision_idx)
        flat_pseudo_label = flatten(pseudo_label)

        # NOTE(review): the values of each solution are assumed to come back in the same
        # order as `revision_idx` -- confirm against pyswip's binding order.
        abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]

        candidates = []
        for c in abduce_c:
            candidate = flat_pseudo_label.copy()
            for i, idx in enumerate(revision_idx):
                candidate[idx] = c[i]
            # Restore the nesting structure of the original sample.
            candidates.append(reform_list(candidate, pseudo_label))
        return candidates

    def __repr__(self):
        return (
            f"{self.__class__.__name__} is a KB with "
            f"pseudo_label_list={self.pseudo_label_list!r}, "
            f"defined by "
            f"Prolog file {self.pl_file!r}."
        )