You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

search_based_kb.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from abc import ABC, abstractmethod
  2. from itertools import combinations, product
  3. from typing import Any, Callable, Generator, List, Optional, Tuple, Union
  4. import numpy
  5. from abl.structures import ListData
  6. from ..structures import ListData
  7. from ..utils import Cache
  8. from .base_kb import BaseKB
  9. def incremental_search_strategy(
  10. data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
  11. ):
  12. symbol_num = data_sample["symbol_num"]
  13. max_revision_num = min(max_revision_num, symbol_num)
  14. real_end = max_revision_num
  15. for revision_num in range(max_revision_num + 1):
  16. if revision_num > real_end:
  17. break
  18. revision_idx_tuple = combinations(range(symbol_num), revision_num)
  19. for revision_idx in revision_idx_tuple:
  20. received = yield revision_idx
  21. if received == "success":
  22. real_end = min(symbol_num, revision_num + require_more_revision)
  23. class SearchBasedKB(BaseKB, ABC):
  24. def __init__(
  25. self,
  26. pseudo_label_list: List,
  27. search_strategy: Callable[[ListData, int, int], Generator] = incremental_search_strategy,
  28. use_cache: bool = True,
  29. cache_file: Optional[str] = None,
  30. cache_size: int = 4096
  31. ) -> None:
  32. super().__init__(pseudo_label_list)
  33. self.search_strategy = search_strategy
  34. self.use_cache = use_cache
  35. self.cache_file = cache_file
  36. if self.use_cache:
  37. if not hasattr(self, "get_key"):
  38. raise NotImplementedError("If use_cache is True, get_key should be implemented.")
  39. key_func = self.get_key
  40. else:
  41. key_func = lambda x: x
  42. self.cache = Cache[ListData, List[List[Any]]](
  43. func=self._abduce_by_search,
  44. cache=self.use_cache,
  45. cache_file=self.cache_file,
  46. key_func=key_func,
  47. max_size=cache_size,
  48. )
  49. @abstractmethod
  50. def check_equal(self, data_sample: ListData, y: Any):
  51. """Placeholder for check_equal."""
  52. pass
  53. def abduce_candidates(
  54. self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
  55. ):
  56. return self.cache.get(data_sample, max_revision_num, require_more_revision)
  57. def revise_at_idx(
  58. self,
  59. data_sample: ListData,
  60. revision_idx: Union[List, Tuple, numpy.ndarray],
  61. ):
  62. candidates = []
  63. abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx))
  64. for c in abduce_c:
  65. new_data_sample = data_sample.clone()
  66. candidate = new_data_sample["pred_pseudo_label"][0].copy()
  67. for i, idx in enumerate(revision_idx):
  68. candidate[idx] = c[i]
  69. new_data_sample["pred_pseudo_label"][0] = candidate
  70. if self.check_equal(new_data_sample, new_data_sample["Y"][0]):
  71. candidates.append(candidate)
  72. return candidates
  73. def _abduce_by_search(
  74. self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
  75. ):
  76. candidates = []
  77. gen = self.search_strategy(
  78. data_sample,
  79. max_revision_num=max_revision_num,
  80. require_more_revision=require_more_revision,
  81. )
  82. send_signal = True
  83. for revision_idx in gen:
  84. candidates.extend(self.revise_at_idx(data_sample, revision_idx))
  85. if len(candidates) > 0 and send_signal:
  86. try:
  87. revision_idx = gen.send("success")
  88. candidates.extend(self.revise_at_idx(data_sample, revision_idx))
  89. send_signal = False
  90. except StopIteration:
  91. break
  92. return candidates
  93. # TODO: When the output is excessively long, use ellipses as a substitute.
  94. def __repr__(self):
  95. return (
  96. f"<{self.__class__.__name__}(\n"
  97. f" pseudo_label_list: {self.pseudo_label_list!r}\n"
  98. f" search_strategy: {self.search_strategy!r}\n"
  99. f" use_cache: {self.use_cache!r}\n"
  100. f" cache_root: {self.cache_root!r}\n"
  101. f") at {hex(id(self))}>"
  102. )

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.