diff --git a/abl/reasoning/base_kb.py b/abl/reasoning/base_kb.py index 4822e45..848d641 100644 --- a/abl/reasoning/base_kb.py +++ b/abl/reasoning/base_kb.py @@ -11,3 +11,11 @@ class BaseKB(ABC): def abduce_candidates(self, data_sample: ListData): """Placeholder for abduction of the knowledge base.""" pass + + # TODO: When the output is excessively long, use ellipses as a substitute. + def __repr__(self): + return ( + f"<{self.__class__.__name__}(\n" + f" pseudo_label_list: {self.pseudo_label_list!r}\n" + f") at {hex(id(self))}>" + ) diff --git a/abl/reasoning/ground_kb.py b/abl/reasoning/ground_kb.py index ed32efd..9f9428d 100644 --- a/abl/reasoning/ground_kb.py +++ b/abl/reasoning/ground_kb.py @@ -9,7 +9,7 @@ from .base_kb import BaseKB class GroundKB(BaseKB, ABC): def __init__(self, pseudo_label_list: List) -> None: super().__init__(pseudo_label_list) - self.base = self.construct_base() + self.GKB = self.construct_base() @abstractmethod def construct_base(self) -> dict: @@ -20,7 +20,7 @@ class GroundKB(BaseKB, ABC): pass def key2candidates(self, key: Hashable) -> List[List[Any]]: - return self.base[key] + return self.GKB[key] def filter_candidates( self, @@ -50,3 +50,12 @@ class GroundKB(BaseKB, ABC): require_more_revision=require_more_revision, candidates=candidates, ) + + # TODO: When the output is excessively long, use ellipses as a substitute. + def __repr__(self): + return ( + f"<{self.__class__.__name__}(\n" + f" pseudo_label_list: {self.pseudo_label_list!r}\n" + f" GKB: {self.GKB!r}\n" + f") at {hex(id(self))}>" + ) diff --git a/abl/reasoning/search_based_kb.py b/abl/reasoning/search_based_kb.py index fb32e7f..e1540fd 100644 --- a/abl/reasoning/search_based_kb.py +++ b/abl/reasoning/search_based_kb.py @@ -39,6 +39,7 @@ class SearchBasedKB(BaseKB, ABC): super().__init__(pseudo_label_list) self.search_strategy = search_strategy self.use_cache = use_cache + self.cache_root = cache_root if self.use_cache: if not hasattr(self, "get_key"): raise NotImplementedError("If use_cache is True, get_key should be implemented.") @@ -100,3 +101,27 @@ class SearchBasedKB(BaseKB, ABC): break return candidates + + # TODO: When the output is excessively long, use ellipses as a substitute. + def __repr__(self): + return ( + f"<{self.__class__.__name__}(\n" + f" pseudo_label_list: {self.pseudo_label_list!r}\n" + f" search_strategy: {self.search_strategy!r}\n" + f" use_cache: {self.use_cache!r}\n" + f" cache_root: {self.cache_root!r}\n" + f") at {hex(id(self))}>" + ) + +class Test(SearchBasedKB): + def __init__( + self, + pseudo_label_list: List, + search_strategy: Callable[[ListData, int, int], Generator] = incremental_search_strategy, + use_cache: bool = True, + cache_root: Optional[str] = None, + ) -> None: + super().__init__(pseudo_label_list, search_strategy, use_cache, cache_root) + + def check_equal(self, data_sample: ListData, y: Any): + return data_sample["pred_pseudo_label"][0] == y diff --git a/examples/hwf/hwf_kb.py b/examples/hwf/hwf_kb.py index 9f91c8c..15b6312 100644 --- a/examples/hwf/hwf_kb.py +++ b/examples/hwf/hwf_kb.py @@ -56,10 +56,10 @@ class HWF_KB(GroundKB): Y.extend(part_Y) if Y and isinstance(Y[0], (int, float)): X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) - base = {} + GKB = {} for x, y in zip(X, Y): - base.setdefault(len(x), defaultdict(list))[y].append(x) - return base + GKB.setdefault(len(x), defaultdict(list))[y].append(x) + return GKB @staticmethod def get_key(data_sample: ListData) -> Hashable: @@ -68,9 +68,9 @@ class HWF_KB(GroundKB): def key2candidates(self, key: Hashable) -> List[List[Any]]: equation_len, y = key if self.max_err == 0: - return self.base[equation_len][y] + return self.GKB[equation_len][y] else: - potential_candidates = self.base[equation_len] + potential_candidates = self.GKB[equation_len] key_list = list(potential_candidates.keys()) key_idx = bisect.bisect_left(key_list, y)