|
|
@@ -39,6 +39,7 @@ class SearchBasedKB(BaseKB, ABC): |
|
|
|
super().__init__(pseudo_label_list) |
|
|
|
self.search_strategy = search_strategy |
|
|
|
self.use_cache = use_cache |
|
|
|
self.cache_root = cache_root |
|
|
|
if self.use_cache: |
|
|
|
if not hasattr(self, "get_key"): |
|
|
|
raise NotImplementedError("If use_cache is True, get_key should be implemented.") |
|
|
@@ -100,3 +101,27 @@ class SearchBasedKB(BaseKB, ABC): |
|
|
|
break |
|
|
|
|
|
|
|
return candidates |
|
|
|
|
|
|
|
# TODO: When the output is excessively long, use ellipses as a substitute. |
|
|
|
def __repr__(self): |
|
|
|
return ( |
|
|
|
f"<{self.__class__.__name__}(\n" |
|
|
|
f" pseudo_label_list: {self.pseudo_label_list!r}\n" |
|
|
|
f" search_strategy: {self.search_strategy!r}\n" |
|
|
|
f" use_cache: {self.use_cache!r}\n" |
|
|
|
f" cache_root: {self.cache_root!r}\n" |
|
|
|
f") at {hex(id(self))}>" |
|
|
|
) |
|
|
|
|
|
|
|
class Test(SearchBasedKB): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
pseudo_label_list: List, |
|
|
|
search_strategy: Callable[[ListData, int, int], Generator] = incremental_search_strategy, |
|
|
|
use_cache: bool = True, |
|
|
|
cache_root: Optional[str] = None, |
|
|
|
) -> None: |
|
|
|
super().__init__(pseudo_label_list, search_strategy, use_cache, cache_root) |
|
|
|
|
|
|
|
def check_equal(self, data_sample: ListData, y: Any): |
|
|
|
return data_sample["pred_pseudo_label"][0] == y |