diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 10aa559..8ef6a25 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -44,7 +44,6 @@ class KBBase(ABC): pseudo_label_list, max_err=1e-10, use_cache=True, - cache_file=None, key_func=to_hashable, max_cache_size=4096, ): @@ -54,7 +53,6 @@ class KBBase(ABC): self.max_err = max_err self.use_cache = use_cache - self.cache_file = cache_file self.key_func = key_func self.max_cache_size = max_cache_size diff --git a/abl/utils/cache.py b/abl/utils/cache.py index bbd15cf..a927e1f 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -36,7 +36,6 @@ class Cache(Generic[K, T]): self.cache = True self.cache_dict = dict() self.key_func = obj.key_func - self.cache_file = obj.cache_file self.max_size = obj.max_cache_size self.hits, self.misses = 0, 0 @@ -44,18 +43,6 @@ class Cache(Generic[K, T]): self.root = [] # root of the circular doubly linked list self.root[:] = [self.root, self.root, None, None] - if self.cache_file is not None: - with open(self.cache_file, "rb") as f: - cache_dict_from_file = pickle.load(f) - self.max_size += len(cache_dict_from_file) - print_log( - f"Max size of the cache has been enlarged to {self.max_size}.", logger="current" - ) - for cache_key, result in cache_dict_from_file.items(): - last = self.root[PREV] - link = [last, self.root, cache_key, result] - last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link - self.has_init = True def get_from_dict(self, obj, *args) -> T: @@ -98,14 +85,6 @@ class Cache(Generic[K, T]): last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link if isinstance(self.max_size, int): self.full = len(self.cache_dict) >= self.max_size - if self.full: - log_dir = ABLLogger.get_current_instance().log_dir - cache_dir = osp.join(log_dir, "cache") - os.makedirs(cache_dir, exist_ok=True) - cache_path = osp.join(cache_dir, "abduce_by_search_cache_res.pth") - with open(cache_path, "wb") as file: - pickle.dump(self.cache_dict, file, protocol=pickle.HIGHEST_PROTOCOL) - print_log(f"Cache will be saved to {cache_path}", logger="current") return result