|
|
@@ -112,16 +112,3 @@ class SearchBasedKB(BaseKB, ABC): |
|
|
|
f" cache_root: {self.cache_root!r}\n" |
|
|
|
f") at {hex(id(self))}>" |
|
|
|
) |
|
|
|
|
|
|
|
class Test(SearchBasedKB): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
pseudo_label_list: List, |
|
|
|
search_strategy: Callable[[ListData, int, int], Generator] = incremental_search_strategy, |
|
|
|
use_cache: bool = True, |
|
|
|
cache_root: Optional[str] = None, |
|
|
|
) -> None: |
|
|
|
super().__init__(pseudo_label_list, search_strategy, use_cache, cache_root) |
|
|
|
|
|
|
|
def check_equal(self, data_sample: ListData, y: Any): |
|
|
|
return data_sample["pred_pseudo_label"][0] == y |