
[ENH] add lru mechanism to Cache

ab_data
Gao Enhao, 1 year ago
commit 99c9aa37b1
3 changed files with 74 additions and 46 deletions
  1. abl/reasoning/search_based_kb.py (+6 -4)
  2. abl/utils/cache.py (+64 -36)
  3. examples/mnist_add/mnist_add_kb.py (+4 -6)

abl/reasoning/search_based_kb.py (+6 -4)

@@ -34,12 +34,13 @@ class SearchBasedKB(BaseKB, ABC):
         pseudo_label_list: List,
         search_strategy: Callable[[ListData, int, int], Generator] = incremental_search_strategy,
         use_cache: bool = True,
-        cache_root: Optional[str] = None,
+        cache_file: Optional[str] = None,
+        cache_size: int = 4096
     ) -> None:
         super().__init__(pseudo_label_list)
         self.search_strategy = search_strategy
         self.use_cache = use_cache
-        self.cache_root = cache_root
+        self.cache_file = cache_file
         if self.use_cache:
             if not hasattr(self, "get_key"):
                 raise NotImplementedError("If use_cache is True, get_key should be implemented.")
@@ -48,9 +49,10 @@ class SearchBasedKB(BaseKB, ABC):
             key_func = lambda x: x
         self.cache = Cache[ListData, List[List[Any]]](
             func=self._abduce_by_search,
-            cache=use_cache,
-            cache_root=cache_root,
+            cache=self.use_cache,
+            cache_file=self.cache_file,
             key_func=key_func,
+            max_size=cache_size,
         )
 
     @abstractmethod
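
Note: the key_func/get_key indirection above exists because ListData queries are not hashable, so the KB reduces each query to a hashable tuple before it touches the cache. A minimal standalone sketch of that pattern (make_cached and the toy solver below are illustrative, not repository code):

    from typing import Any, Callable, Dict, Hashable


    def make_cached(func: Callable[[Any], Any], key_func: Callable[[Any], Hashable]):
        """Wrap func with a dict-backed cache keyed on key_func(query)."""
        store: Dict[Hashable, Any] = {}

        def wrapper(query: Any) -> Any:
            key = key_func(query)
            if key not in store:
                store[key] = func(query)
            return store[key]

        return wrapper


    # A list is not hashable, so key_func reduces it to a tuple first,
    # mirroring what get_key does for ListData queries above.
    solve = make_cached(func=sum, key_func=tuple)
    print(solve([1, 2, 3]))  # computed: 6
    print(solve([1, 2, 3]))  # served from the cache: 6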


abl/utils/cache.py (+64 -36)

@@ -7,15 +7,17 @@ from .logger import print_log

 K = TypeVar("K")
 T = TypeVar("T")
+PREV, NEXT, KEY, RESULT = 0, 1, 2, 3  # names for the link fields
 
 
-# TODO: add lru
 class Cache(Generic[K, T]):
     def __init__(
         self,
         func: Callable[[K], T],
         cache: bool,
-        cache_root: Union[None, str, PathLike],
+        cache_file: Union[None, str, PathLike],
         key_func: Callable[[K], Hashable] = lambda x: x,
+        max_size: int = 4096,
     ):
         """Create cache
@@ -27,16 +29,12 @@ class Cache(Generic[K, T]):
         self.func = func
         self.key_func = key_func
         self.cache = cache
-        if cache is True:
+        if cache is True or cache_file is not None:
             print_log("Caching is activated", logger="current")
-            self.cache_file = cache_root is not None
-            self.cache_dict = dict()
-            self.first = func
-            if self.cache_file:
-                self.cache_root = Path(cache_root)
-                self.first = self.get_from_file
-        if self.cache:
+            self._init_cache(cache_file, max_size)
             self.first = self.get_from_dict
         else:
             self.first = self.func
 
     def __getitem__(self, item: K, *args) -> T:
         return self.first(item, *args)
@@ -48,37 +46,67 @@ class Cache(Generic[K, T]):
             for p in self.cache_root.iterdir():
                 p.unlink()
 
+    def _init_cache(self, cache_file, max_size):
+        self.cache = True
+        self.cache_dict = dict()
+
+        self.hits, self.misses, self.maxsize = 0, 0, max_size
+        self.full = False
+        self.root = []  # root of the circular doubly linked list
+        self.root[:] = [self.root, self.root, None, None]
+
+        if cache_file is not None:
+            with open(cache_file, "rb") as f:
+                cache_dict_from_file = pickle.load(f)
+            self.maxsize += len(cache_dict_from_file)
+            print_log(
+                f"Max size of the cache has been enlarged to {self.maxsize}.", logger="current"
+            )
+            for cache_key, result in cache_dict_from_file.items():
+                last = self.root[PREV]
+                link = [last, self.root, cache_key, result]
+                last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link
+
     def get(self, item: K, *args) -> T:
         return self.first(item, *args)
 
     def get_from_dict(self, item: K, *args) -> T:
         """Implements dict based cache."""
         cache_key = (self.key_func(item), *args)
-        result = self.cache_dict.get(cache_key)
-        if result is None:
-            if self.cache_file:
-                result = self.get_from_file(item, *args, cache_key=cache_key)
-            else:
-                result = self.func(item, *args)
-            self.cache_dict[cache_key] = result
-        return result
+        link = self.cache_dict.get(cache_key)
+        if link is not None:
+            # Move the link to the front of the circular queue
+            link_prev, link_next, _key, result = link
+            link_prev[NEXT] = link_next
+            link_next[PREV] = link_prev
+            last = self.root[PREV]
+            last[NEXT] = self.root[PREV] = link
+            link[PREV] = last
+            link[NEXT] = self.root
+            self.hits += 1
+            return result
+        self.misses += 1
+
+        result = self.func(item, *args)
 
-    def get_from_file(self, item: K, *args, cache_key=None) -> T:
-        """Implements file based cache."""
-        if cache_key is None:
-            cache_key = (self.key_func(item), *args)
-        filepath = self.cache_root / str(hash(cache_key))
-        result = None
-        if filepath.exists():
-            with open(filepath, "rb") as f:
-                (key, result) = pickle.load(f)
-            if key == cache_key:
-                return result
-            else:
-                # Hash collision! Handle this by overwriting the cache with the new query
-                result = None
-        if result is None:
-            result = self.func(item, *args)
-            with open(filepath, "wb") as f:
-                pickle.dump((cache_key, result), f)
+        if self.full:
+            # Use the old root to store the new key and result.
+            oldroot = self.root
+            oldroot[KEY] = cache_key
+            oldroot[RESULT] = result
+            # Empty the oldest link and make it the new root.
+            self.root = oldroot[NEXT]
+            oldkey = self.root[KEY]
+            oldresult = self.root[RESULT]
+            self.root[KEY] = self.root[RESULT] = None
+            # Now update the cache dictionary.
+            del self.cache_dict[oldkey]
+            self.cache_dict[cache_key] = oldroot
+        else:
+            # Put result in a new link at the front of the queue.
+            last = self.root[PREV]
+            link = [last, self.root, cache_key, result]
+            last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link
+        if isinstance(self.maxsize, int):
+            self.full = len(self.cache_dict) >= self.maxsize
+        return result
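
The bookkeeping above is the circular doubly linked list that CPython's functools.lru_cache uses: a sentinel root whose PREV side holds the most recently used link and whose NEXT side the least recently used one, so a hit (move to front) and an eviction (recycle the root) are both O(1). The following self-contained sketch replays that mechanism outside the repository; TinyLRU and its names are illustrative only:

    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3  # same link layout as the diff above


    class TinyLRU:
        def __init__(self, maxsize: int):
            self.maxsize = maxsize
            self.full = False
            self.cache_dict = {}
            self.root = []  # sentinel of the circular doubly linked list
            self.root[:] = [self.root, self.root, None, None]

        def get(self, key, compute):
            link = self.cache_dict.get(key)
            if link is not None:
                # Hit: unlink, then re-insert just before root (most recent).
                link_prev, link_next, _key, result = link
                link_prev[NEXT] = link_next
                link_next[PREV] = link_prev
                last = self.root[PREV]
                last[NEXT] = self.root[PREV] = link
                link[PREV] = last
                link[NEXT] = self.root
                return result
            result = compute(key)
            if self.full:
                # Recycle the root for the new entry, then advance the root;
                # the entry right after the old root (least recent) is evicted.
                oldroot = self.root
                oldroot[KEY], oldroot[RESULT] = key, result
                self.root = oldroot[NEXT]
                oldkey = self.root[KEY]
                self.root[KEY] = self.root[RESULT] = None
                del self.cache_dict[oldkey]
                self.cache_dict[key] = oldroot
            else:
                # Append a fresh link just before root.
                last = self.root[PREV]
                link = [last, self.root, key, result]
                last[NEXT] = self.root[PREV] = self.cache_dict[key] = link
                self.full = len(self.cache_dict) >= self.maxsize
            return result


    lru = TinyLRU(maxsize=2)
    lru.get("a", len)
    lru.get("b", len)
    lru.get("a", len)  # touch "a": "b" is now least recently used
    lru.get("c", len)  # cache is full, so "b" is evicted
    print(sorted(lru.cache_dict))  # ['a', 'c']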

examples/mnist_add/mnist_add_kb.py (+4 -6)

@@ -5,12 +5,10 @@ from abl.structures import ListData


 class AddKB(SearchBasedKB):
-    def __init__(
-        self,
-        pseudo_label_list=list(range(10)),
-        use_cache=True,
-    ):
-        super().__init__(pseudo_label_list=pseudo_label_list, use_cache=use_cache)
+    def __init__(self, pseudo_label_list=list(range(10)), use_cache=True, cache_size=4096):
+        super().__init__(
+            pseudo_label_list=pseudo_label_list, use_cache=use_cache, cache_size=cache_size
+        )
 
     def get_key(self, data_sample: ListData):
         return (data_sample.to_tuple("pred_pseudo_label"), data_sample["Y"][0])
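
For callers, the new knob is a one-line change. A sketch assuming the abl package and this example module are importable; the hit/miss counters come from _init_cache in the diff above:

    from mnist_add_kb import AddKB

    # cache_size flows through SearchBasedKB into Cache(max_size=...),
    # bounding the in-memory LRU; use_cache=False bypasses caching entirely.
    kb = AddKB(use_cache=True, cache_size=1024)

    # After some abductions, the LRU exposes simple effectiveness counters:
    print(kb.cache.hits, kb.cache.misses)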

