diff --git a/abl/reasoning/search_based_kb.py b/abl/reasoning/search_based_kb.py index 0489519..cd5db40 100644 --- a/abl/reasoning/search_based_kb.py +++ b/abl/reasoning/search_based_kb.py @@ -34,12 +34,13 @@ class SearchBasedKB(BaseKB, ABC): pseudo_label_list: List, search_strategy: Callable[[ListData, int, int], Generator] = incremental_search_strategy, use_cache: bool = True, - cache_root: Optional[str] = None, + cache_file: Optional[str] = None, + cache_size: int = 4096 ) -> None: super().__init__(pseudo_label_list) self.search_strategy = search_strategy self.use_cache = use_cache - self.cache_root = cache_root + self.cache_file = cache_file if self.use_cache: if not hasattr(self, "get_key"): raise NotImplementedError("If use_cache is True, get_key should be implemented.") @@ -48,9 +49,10 @@ class SearchBasedKB(BaseKB, ABC): key_func = lambda x: x self.cache = Cache[ListData, List[List[Any]]]( func=self._abduce_by_search, - cache=use_cache, - cache_root=cache_root, + cache=self.use_cache, + cache_file=self.cache_file, key_func=key_func, + max_size=cache_size, ) @abstractmethod diff --git a/abl/utils/cache.py b/abl/utils/cache.py index 363e6d7..f4b3b0c 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -7,15 +7,17 @@ from .logger import print_log K = TypeVar("K") T = TypeVar("T") +PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields + -# TODO: add lru class Cache(Generic[K, T]): def __init__( self, func: Callable[[K], T], cache: bool, - cache_root: Union[None, str, PathLike], + cache_file: Union[None, str, PathLike], key_func: Callable[[K], Hashable] = lambda x: x, + max_size: int = 4096, ): """Create cache @@ -27,16 +29,12 @@ class Cache(Generic[K, T]): self.func = func self.key_func = key_func self.cache = cache - if cache is True: + if cache is True or cache_file is not None: print_log("Caching is activated", logger="current") - self.cache_file = cache_root is not None - self.cache_dict = dict() - self.first = func - if self.cache_file: - self.cache_root = Path(cache_root) - self.first = self.get_from_file - if self.cache: + self._init_cache(cache_file, max_size) self.first = self.get_from_dict + else: + self.first = self.func def __getitem__(self, item: K, *args) -> T: return self.first(item, *args) @@ -48,37 +46,67 @@ class Cache(Generic[K, T]): for p in self.cache_root.iterdir(): p.unlink() + def _init_cache(self, cache_file, max_size): + self.cache = True + self.cache_dict = dict() + + self.hits, self.misses, self.maxsize = 0, 0, max_size + self.full = False + self.root = [] # root of the circular doubly linked list + self.root[:] = [self.root, self.root, None, None] + + if cache_file is not None: + with open(cache_file, "rb") as f: + cache_dict_from_file = pickle.load(f) + self.maxsize += len(cache_dict_from_file) + print_log( + f"Max size of the cache has been enlarged to {self.maxsize}.", logger="current" + ) + for cache_key, result in cache_dict_from_file.items(): + last = self.root[PREV] + link = [last, self.root, cache_key, result] + last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link + def get(self, item: K, *args) -> T: return self.first(item, *args) def get_from_dict(self, item: K, *args) -> T: """Implements dict based cache.""" cache_key = (self.key_func(item), *args) - result = self.cache_dict.get(cache_key) - if result is None: - if self.cache_file: - result = self.get_from_file(item, *args, cache_key=cache_key) - else: - result = self.func(item, *args) - self.cache_dict[cache_key] = result - return result + link = self.cache_dict.get(cache_key) + if link is not None: + # Move the link to the front of the circular queue + link_prev, link_next, _key, result = link + link_prev[NEXT] = link_next + link_next[PREV] = link_prev + last = self.root[PREV] + last[NEXT] = self.root[PREV] = link + link[PREV] = last + link[NEXT] = self.root + self.hits += 1 + return result + self.misses += 1 + + result = self.func(item, *args) - def get_from_file(self, item: K, *args, cache_key=None) -> T: - """Implements file based cache.""" - if cache_key is None: - cache_key = (self.key_func(item), *args) - filepath = self.cache_root / str(hash(cache_key)) - result = None - if filepath.exists(): - with open(filepath, "rb") as f: - (key, result) = pickle.load(f) - if key == cache_key: - return result - else: - # Hash collision! Handle this by overwriting the cache with the new query - result = None - if result is None: - result = self.func(item, *args) - with open(filepath, "wb") as f: - pickle.dump((cache_key, result), f) + if self.full: + # Use the old root to store the new key and result. + oldroot = self.root + oldroot[KEY] = cache_key + oldroot[RESULT] = result + # Empty the oldest link and make it the new root. + self.root = oldroot[NEXT] + oldkey = self.root[KEY] + oldresult = self.root[RESULT] + self.root[KEY] = self.root[RESULT] = None + # Now update the cache dictionary. + del self.cache_dict[oldkey] + self.cache_dict[cache_key] = oldroot + else: + # Put result in a new link at the front of the queue. + last = self.root[PREV] + link = [last, self.root, cache_key, result] + last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link + if isinstance(self.maxsize, int): + self.full = len(self.cache_dict) >= self.maxsize return result diff --git a/examples/mnist_add/mnist_add_kb.py b/examples/mnist_add/mnist_add_kb.py index 258b92e..8a29927 100644 --- a/examples/mnist_add/mnist_add_kb.py +++ b/examples/mnist_add/mnist_add_kb.py @@ -5,12 +5,10 @@ from abl.structures import ListData class AddKB(SearchBasedKB): - def __init__( - self, - pseudo_label_list=list(range(10)), - use_cache=True, - ): - super().__init__(pseudo_label_list=pseudo_label_list, use_cache=use_cache) + def __init__(self, pseudo_label_list=list(range(10)), use_cache=True, cache_size=4096): + super().__init__( + pseudo_label_list=pseudo_label_list, use_cache=use_cache, cache_size=cache_size + ) def get_key(self, data_sample: ListData): return (data_sample.to_tuple("pred_pseudo_label"), data_sample["Y"][0])