Browse Source

[MNT] add repr to kbs and change base to GKB

ab_data
Gao Enhao 1 year ago
parent
commit
6022f702d9
4 changed files with 49 additions and 7 deletions
  1. +8
    -0
      abl/reasoning/base_kb.py
  2. +11
    -2
      abl/reasoning/ground_kb.py
  3. +25
    -0
      abl/reasoning/search_based_kb.py
  4. +5
    -5
      examples/hwf/hwf_kb.py

+ 8
- 0
abl/reasoning/base_kb.py View File

@@ -11,3 +11,11 @@ class BaseKB(ABC):
def abduce_candidates(self, data_sample: ListData): def abduce_candidates(self, data_sample: ListData):
"""Placeholder for abduction of the knowledge base.""" """Placeholder for abduction of the knowledge base."""
pass pass

# TODO: When the output is excessively long, use ellipses as a substitute.
def __repr__(self):
return (
f"<{self.__class__.__name__}(\n"
f" pseudo_label_list: {self.pseudo_label_list!r}\n"
f") at {hex(id(self))}>"
)

+ 11
- 2
abl/reasoning/ground_kb.py View File

@@ -9,7 +9,7 @@ from .base_kb import BaseKB
class GroundKB(BaseKB, ABC): class GroundKB(BaseKB, ABC):
def __init__(self, pseudo_label_list: List) -> None: def __init__(self, pseudo_label_list: List) -> None:
super().__init__(pseudo_label_list) super().__init__(pseudo_label_list)
self.base = self.construct_base()
self.GKB = self.construct_base()


@abstractmethod @abstractmethod
def construct_base(self) -> dict: def construct_base(self) -> dict:
@@ -20,7 +20,7 @@ class GroundKB(BaseKB, ABC):
pass pass


def key2candidates(self, key: Hashable) -> List[List[Any]]: def key2candidates(self, key: Hashable) -> List[List[Any]]:
return self.base[key]
return self.GKB[key]


def filter_candidates( def filter_candidates(
self, self,
@@ -50,3 +50,12 @@ class GroundKB(BaseKB, ABC):
require_more_revision=require_more_revision, require_more_revision=require_more_revision,
candidates=candidates, candidates=candidates,
) )

# TODO: When the output is excessively long, use ellipses as a substitute.
def __repr__(self):
return (
f"<{self.__class__.__name__}(\n"
f" pseudo_label_list: {self.pseudo_label_list!r}\n"
f" GKB: {self.GKB!r}\n"
f") at {hex(id(self))}>"
)

+ 25
- 0
abl/reasoning/search_based_kb.py View File

@@ -39,6 +39,7 @@ class SearchBasedKB(BaseKB, ABC):
super().__init__(pseudo_label_list) super().__init__(pseudo_label_list)
self.search_strategy = search_strategy self.search_strategy = search_strategy
self.use_cache = use_cache self.use_cache = use_cache
self.cache_root = cache_root
if self.use_cache: if self.use_cache:
if not hasattr(self, "get_key"): if not hasattr(self, "get_key"):
raise NotImplementedError("If use_cache is True, get_key should be implemented.") raise NotImplementedError("If use_cache is True, get_key should be implemented.")
@@ -100,3 +101,27 @@ class SearchBasedKB(BaseKB, ABC):
break break


return candidates return candidates

# TODO: When the output is excessively long, use ellipses as a substitute.
def __repr__(self):
return (
f"<{self.__class__.__name__}(\n"
f" pseudo_label_list: {self.pseudo_label_list!r}\n"
f" search_strategy: {self.search_strategy!r}\n"
f" use_cache: {self.use_cache!r}\n"
f" cache_root: {self.cache_root!r}\n"
f") at {hex(id(self))}>"
)

class Test(SearchBasedKB):
def __init__(
self,
pseudo_label_list: List,
search_strategy: Callable[[ListData, int, int], Generator] = incremental_search_strategy,
use_cache: bool = True,
cache_root: Optional[str] = None,
) -> None:
super().__init__(pseudo_label_list, search_strategy, use_cache, cache_root)

def check_equal(self, data_sample: ListData, y: Any):
return data_sample["pred_pseudo_label"][0] == y

+ 5
- 5
examples/hwf/hwf_kb.py View File

@@ -56,10 +56,10 @@ class HWF_KB(GroundKB):
Y.extend(part_Y) Y.extend(part_Y)
if Y and isinstance(Y[0], (int, float)): if Y and isinstance(Y[0], (int, float)):
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
base = {}
GKB = {}
for x, y in zip(X, Y): for x, y in zip(X, Y):
base.setdefault(len(x), defaultdict(list))[y].append(x)
return base
GKB.setdefault(len(x), defaultdict(list))[y].append(x)
return GKB


@staticmethod @staticmethod
def get_key(data_sample: ListData) -> Hashable: def get_key(data_sample: ListData) -> Hashable:
@@ -68,9 +68,9 @@ class HWF_KB(GroundKB):
def key2candidates(self, key: Hashable) -> List[List[Any]]: def key2candidates(self, key: Hashable) -> List[List[Any]]:
equation_len, y = key equation_len, y = key
if self.max_err == 0: if self.max_err == 0:
return self.base[equation_len][y]
return self.GKB[equation_len][y]
else: else:
potential_candidates = self.base[equation_len]
potential_candidates = self.GKB[equation_len]
key_list = list(potential_candidates.keys()) key_list = list(potential_candidates.keys())
key_idx = bisect.bisect_left(key_list, y) key_idx = bisect.bisect_left(key_list, y)




Loading…
Cancel
Save