You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

reasoner.py 9.7 kB

2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. from typing import Any, List, Mapping, Optional
  2. import numpy as np
  3. from ..structures import ListData
  4. from ..utils import Cache, calculate_revision_num, confidence_dist, hamming_dist
  5. from .base_kb import BaseKB
  6. from .search_engine import BFS, BaseSearchEngine
  7. class ReasonerBase:
  8. def __init__(
  9. self,
  10. kb: BaseKB,
  11. dist_func: str = "confidence",
  12. mapping: Optional[Mapping] = None,
  13. search_engine: Optional[BaseSearchEngine] = None,
  14. use_cache: bool = False,
  15. cache_file: Optional[str] = None,
  16. cache_size: Optional[int] = 4096,
  17. ):
  18. """
  19. Base class for all reasoner in the ABL system.
  20. Parameters
  21. ----------
  22. kb : BaseKB
  23. The knowledge base to be used for reasoning.
  24. dist_func : str, optional
  25. The distance function to be used. Can be "hamming" or "confidence". Default is "confidence".
  26. mapping : dict, optional
  27. A mapping of indices to labels. If None, a default mapping is generated.
  28. use_zoopt : bool, optional
  29. Whether to use the Zoopt library for optimization. Default is False.
  30. Raises
  31. ------
  32. NotImplementedError
  33. If the specified distance function is neither "hamming" nor "confidence".
  34. """
  35. if not isinstance(kb, BaseKB):
  36. raise ValueError("The kb should be of type BaseKB.")
  37. self.kb = kb
  38. if dist_func not in ["hamming", "confidence"]:
  39. raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.")
  40. self.dist_func = dist_func
  41. if mapping is None:
  42. self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
  43. else:
  44. if not isinstance(mapping, dict):
  45. raise ValueError("mapping must be of type dict")
  46. for key, value in mapping.items():
  47. if not isinstance(key, int):
  48. raise ValueError("All keys in the mapping must be integers")
  49. if value not in self.kb.pseudo_label_list:
  50. raise ValueError("All values in the mapping must be in the pseudo_label_list")
  51. self.mapping = mapping
  52. self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))
  53. if search_engine is None:
  54. self.search_engine = BFS()
  55. else:
  56. if not isinstance(search_engine, BaseSearchEngine):
  57. raise ValueError("The search_engine should be of type BaseSearchEngine.")
  58. else:
  59. self.search_engine = search_engine
  60. self.use_cache = use_cache
  61. self.cache_file = cache_file
  62. if self.use_cache:
  63. if not hasattr(self, "get_key"):
  64. raise NotImplementedError("If use_cache is True, get_key should be implemented.")
  65. key_func = self.get_key
  66. else:
  67. key_func = lambda x: x
  68. self.cache = Cache[ListData, List[List[Any]]](
  69. func=self.abduce_candidates,
  70. cache=self.use_cache,
  71. cache_file=self.cache_file,
  72. key_func=key_func,
  73. max_size=cache_size,
  74. )
  75. def abduce(
  76. self,
  77. data_sample: ListData,
  78. max_revision: int = -1,
  79. require_more_revision: int = 0,
  80. ):
  81. """
  82. Perform revision by abduction on the given data.
  83. Parameters
  84. ----------
  85. pred_prob : list
  86. List of probabilities for predicted results.
  87. pred_pseudo_label : list
  88. List of predicted pseudo labels.
  89. y : any
  90. Ground truth for the predicted results.
  91. max_revision : int or float, optional
  92. Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
  93. If -1, any revisions are allowed. Defaults to -1.
  94. require_more_revision : int, optional
  95. Number of additional revisions to require. Defaults to 0.
  96. Returns
  97. -------
  98. list
  99. The abduced revisions.
  100. """
  101. symbol_num = data_sample.elements_num("pred_pseudo_label")
  102. max_revision_num = calculate_revision_num(max_revision, symbol_num)
  103. data_sample.set_metainfo(dict(symbol_num=symbol_num))
  104. candidates = self.cache.get(data_sample, max_revision_num, require_more_revision)
  105. candidate = self.select_one_candidate(data_sample, candidates)
  106. return candidate
  107. def abduce_candidates(
  108. self,
  109. data_sample: ListData,
  110. max_revision_num: int = -1,
  111. require_more_revision: int = 0,
  112. ):
  113. """
  114. Perform revision by abduction on the given data.
  115. Parameters
  116. ----------
  117. pred_prob : list
  118. List of probabilities for predicted results.
  119. pred_pseudo_label : list
  120. List of predicted pseudo labels.
  121. y : any
  122. Ground truth for the predicted results.
  123. max_revision : int or float, optional
  124. Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
  125. If -1, any revisions are allowed. Defaults to -1.
  126. require_more_revision : int, optional
  127. Number of additional revisions to require. Defaults to 0.
  128. Returns
  129. -------
  130. list
  131. The abduced revisions.
  132. """
  133. if hasattr(self.kb, "abduce_candidates"):
  134. candidates = self.kb.abduce_candidates(
  135. data_sample, max_revision_num, require_more_revision
  136. )
  137. elif hasattr(self.kb, "revise_at_idx"):
  138. candidates = []
  139. gen = self.search_engine.generator(
  140. data_sample,
  141. max_revision_num=max_revision_num,
  142. require_more_revision=require_more_revision,
  143. )
  144. send_signal = True
  145. for revision_idx in gen:
  146. candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx))
  147. if len(candidates) > 0 and send_signal:
  148. try:
  149. revision_idx = gen.send("success")
  150. candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx))
  151. send_signal = False
  152. except StopIteration:
  153. break
  154. else:
  155. raise NotImplementedError(
  156. "The kb should either implement abduce_candidates or revise_at_idx."
  157. )
  158. return candidates
  159. def select_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]):
  160. """
  161. Get one candidate. If multiple candidates exist, return the one with minimum cost.
  162. Parameters
  163. ----------
  164. pred_pseudo_label : list
  165. The pseudo label to be used for selecting a candidate.
  166. pred_prob : list
  167. Probabilities of the predictions.
  168. candidates : list
  169. List of candidate abduction result.
  170. Returns
  171. -------
  172. list
  173. The chosen candidate based on minimum cost.
  174. If no candidates, an empty list is returned.
  175. """
  176. if len(candidates) == 0:
  177. return []
  178. elif len(candidates) == 1:
  179. return candidates[0]
  180. else:
  181. cost_array = self._get_dist_list(data_sample, candidates)
  182. candidate = candidates[np.argmin(cost_array)]
  183. return candidate
  184. def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]):
  185. """
  186. Get the list of costs between each pseudo label and candidate.
  187. Parameters
  188. ----------
  189. pred_pseudo_label : list
  190. The pseudo label to be used for computing costs of candidates.
  191. pred_prob : list
  192. Probabilities of the predictions. Used when distance function is "confidence".
  193. candidates : list
  194. List of candidate abduction result.
  195. Returns
  196. -------
  197. numpy.ndarray
  198. Array of computed costs for each candidate.
  199. """
  200. if self.dist_func == "hamming":
  201. return hamming_dist(data_sample["pred_pseudo_label"][0], candidates)
  202. elif self.dist_func == "confidence":
  203. candidates = [[self.remapping[x] for x in c] for c in candidates]
  204. return confidence_dist(data_sample["pred_prob"][0], candidates)
  205. def batch_abduce(
  206. self,
  207. data_samples: ListData,
  208. max_revision: int = -1,
  209. require_more_revision: int = 0,
  210. ):
  211. """
  212. Perform abduction on the given data in batches.
  213. Parameters
  214. ----------
  215. pred_prob : list
  216. List of probabilities for predicted results.
  217. pred_pseudo_label : list
  218. List of predicted pseudo labels.
  219. Y : list
  220. List of ground truths for the predicted results.
  221. max_revision : int or float, optional
  222. Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
  223. If -1, use all revisions. Defaults to -1.
  224. require_more_revision : int, optional
  225. Number of additional revisions to require. Defaults to 0.
  226. Returns
  227. -------
  228. list
  229. The abduced revisions in batches.
  230. """
  231. abduced_pseudo_label = [
  232. self.abduce(
  233. data_sample,
  234. max_revision=max_revision,
  235. require_more_revision=require_more_revision,
  236. )
  237. for data_sample in data_samples
  238. ]
  239. data_samples.abduced_pseudo_label = abduced_pseudo_label
  240. return abduced_pseudo_label

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.