You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cache.py 3.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
import functools
import weakref
from typing import Callable, Generic, TypeVar
  2. K = TypeVar("K")
  3. T = TypeVar("T")
  4. PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
  5. class Cache(Generic[K, T]):
  6. def __init__(self, func: Callable[[K], T]):
  7. """Create cache
  8. :param func: Function this cache evaluates
  9. :param cache: If true, do in memory caching.
  10. :param cache_root: If not None, cache to files at the provided path.
  11. :param key_func: Convert the key into a hashable object if needed
  12. """
  13. self.func = func
  14. self.has_init = False
  15. def __getitem__(self, obj, *args) -> T:
  16. return self.get_from_dict(obj, *args)
  17. def clear_cache(self):
  18. """Invalidate entire cache."""
  19. self.cache_dict.clear()
  20. def _init_cache(self, obj):
  21. if self.has_init:
  22. return
  23. self.cache = True
  24. self.cache_dict = dict()
  25. self.key_func = obj.key_func
  26. self.max_size = obj.cache_size
  27. self.hits, self.misses = 0, 0
  28. self.full = False
  29. self.root = [] # root of the circular doubly linked list
  30. self.root[:] = [self.root, self.root, None, None]
  31. self.has_init = True
  32. def get_from_dict(self, obj, *args) -> T:
  33. """Implements dict based cache."""
  34. pred_pseudo_label, y, *res_args = args
  35. cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args)
  36. link = self.cache_dict.get(cache_key)
  37. if link is not None:
  38. # Move the link to the front of the circular queue
  39. link_prev, link_next, _key, result = link
  40. link_prev[NEXT] = link_next
  41. link_next[PREV] = link_prev
  42. last = self.root[PREV]
  43. last[NEXT] = self.root[PREV] = link
  44. link[PREV] = last
  45. link[NEXT] = self.root
  46. self.hits += 1
  47. return result
  48. self.misses += 1
  49. result = self.func(obj, *args)
  50. if self.full:
  51. # Use the old root to store the new key and result.
  52. oldroot = self.root
  53. oldroot[KEY] = cache_key
  54. oldroot[RESULT] = result
  55. # Empty the oldest link and make it the new root.
  56. self.root = oldroot[NEXT]
  57. oldkey = self.root[KEY]
  58. self.root[KEY] = self.root[RESULT] = None
  59. # Now update the cache dictionary.
  60. del self.cache_dict[oldkey]
  61. self.cache_dict[cache_key] = oldroot
  62. else:
  63. # Put result in a new link at the front of the queue.
  64. last = self.root[PREV]
  65. link = [last, self.root, cache_key, result]
  66. last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link
  67. if isinstance(self.max_size, int):
  68. self.full = len(self.cache_dict) >= self.max_size
  69. return result
  70. def abl_cache():
  71. def decorator(func):
  72. cache_instance = Cache(func)
  73. def wrapper(obj, *args):
  74. if obj.use_cache:
  75. cache_instance._init_cache(obj)
  76. return cache_instance.get_from_dict(obj, *args)
  77. else:
  78. return func(obj, *args)
  79. return wrapper
  80. return decorator

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.