@@ -11,3 +11,11 @@ class BaseKB(ABC): | |||||
def abduce_candidates(self, data_sample: ListData): | def abduce_candidates(self, data_sample: ListData): | ||||
"""Placeholder for abduction of the knowledge base.""" | """Placeholder for abduction of the knowledge base.""" | ||||
pass | pass | ||||
# TODO: When the output is excessively long, use ellipses as a substitute. | |||||
def __repr__(self): | |||||
return ( | |||||
f"<{self.__class__.__name__}(\n" | |||||
f" pseudo_label_list: {self.pseudo_label_list!r}\n" | |||||
f") at {hex(id(self))}>" | |||||
) |
@@ -9,7 +9,7 @@ from .base_kb import BaseKB | |||||
class GroundKB(BaseKB, ABC): | class GroundKB(BaseKB, ABC): | ||||
def __init__(self, pseudo_label_list: List) -> None: | def __init__(self, pseudo_label_list: List) -> None: | ||||
super().__init__(pseudo_label_list) | super().__init__(pseudo_label_list) | ||||
self.base = self.construct_base() | |||||
self.GKB = self.construct_base() | |||||
@abstractmethod | @abstractmethod | ||||
def construct_base(self) -> dict: | def construct_base(self) -> dict: | ||||
@@ -20,7 +20,7 @@ class GroundKB(BaseKB, ABC): | |||||
pass | pass | ||||
def key2candidates(self, key: Hashable) -> List[List[Any]]: | def key2candidates(self, key: Hashable) -> List[List[Any]]: | ||||
return self.base[key] | |||||
return self.GKB[key] | |||||
def filter_candidates( | def filter_candidates( | ||||
self, | self, | ||||
@@ -50,3 +50,12 @@ class GroundKB(BaseKB, ABC): | |||||
require_more_revision=require_more_revision, | require_more_revision=require_more_revision, | ||||
candidates=candidates, | candidates=candidates, | ||||
) | ) | ||||
# TODO: When the output is excessively long, use ellipses as a substitute. | |||||
def __repr__(self): | |||||
return ( | |||||
f"<{self.__class__.__name__}(\n" | |||||
f" pseudo_label_list: {self.pseudo_label_list!r}\n" | |||||
f" GKB: {self.GKB!r}\n" | |||||
f") at {hex(id(self))}>" | |||||
) |
@@ -39,6 +39,7 @@ class SearchBasedKB(BaseKB, ABC): | |||||
super().__init__(pseudo_label_list) | super().__init__(pseudo_label_list) | ||||
self.search_strategy = search_strategy | self.search_strategy = search_strategy | ||||
self.use_cache = use_cache | self.use_cache = use_cache | ||||
self.cache_root = cache_root | |||||
if self.use_cache: | if self.use_cache: | ||||
if not hasattr(self, "get_key"): | if not hasattr(self, "get_key"): | ||||
raise NotImplementedError("If use_cache is True, get_key should be implemented.") | raise NotImplementedError("If use_cache is True, get_key should be implemented.") | ||||
@@ -100,3 +101,27 @@ class SearchBasedKB(BaseKB, ABC): | |||||
break | break | ||||
return candidates | return candidates | ||||
# TODO: When the output is excessively long, use ellipses as a substitute. | |||||
def __repr__(self): | |||||
return ( | |||||
f"<{self.__class__.__name__}(\n" | |||||
f" pseudo_label_list: {self.pseudo_label_list!r}\n" | |||||
f" search_strategy: {self.search_strategy!r}\n" | |||||
f" use_cache: {self.use_cache!r}\n" | |||||
f" cache_root: {self.cache_root!r}\n" | |||||
f") at {hex(id(self))}>" | |||||
) | |||||
class Test(SearchBasedKB): | |||||
def __init__( | |||||
self, | |||||
pseudo_label_list: List, | |||||
search_strategy: Callable[[ListData, int, int], Generator] = incremental_search_strategy, | |||||
use_cache: bool = True, | |||||
cache_root: Optional[str] = None, | |||||
) -> None: | |||||
super().__init__(pseudo_label_list, search_strategy, use_cache, cache_root) | |||||
def check_equal(self, data_sample: ListData, y: Any): | |||||
return data_sample["pred_pseudo_label"][0] == y |
@@ -56,10 +56,10 @@ class HWF_KB(GroundKB): | |||||
Y.extend(part_Y) | Y.extend(part_Y) | ||||
if Y and isinstance(Y[0], (int, float)): | if Y and isinstance(Y[0], (int, float)): | ||||
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | ||||
base = {} | |||||
GKB = {} | |||||
for x, y in zip(X, Y): | for x, y in zip(X, Y): | ||||
base.setdefault(len(x), defaultdict(list))[y].append(x) | |||||
return base | |||||
GKB.setdefault(len(x), defaultdict(list))[y].append(x) | |||||
return GKB | |||||
@staticmethod | @staticmethod | ||||
def get_key(data_sample: ListData) -> Hashable: | def get_key(data_sample: ListData) -> Hashable: | ||||
@@ -68,9 +68,9 @@ class HWF_KB(GroundKB): | |||||
def key2candidates(self, key: Hashable) -> List[List[Any]]: | def key2candidates(self, key: Hashable) -> List[List[Any]]: | ||||
equation_len, y = key | equation_len, y = key | ||||
if self.max_err == 0: | if self.max_err == 0: | ||||
return self.base[equation_len][y] | |||||
return self.GKB[equation_len][y] | |||||
else: | else: | ||||
potential_candidates = self.base[equation_len] | |||||
potential_candidates = self.GKB[equation_len] | |||||
key_list = list(potential_candidates.keys()) | key_list = list(potential_candidates.keys()) | ||||
key_idx = bisect.bisect_left(key_list, y) | key_idx = bisect.bisect_left(key_list, y) | ||||