
reasoner.py
  1. """
  2. This module contains the class Reasoner, which is used for minimizing the inconsistency
  3. between the knowledge base and learning models.
  4. Copyright (c) 2024 LAMDA. All rights reserved.
  5. """
  6. import inspect
  7. from typing import Any, Callable, List, Optional, Union
  8. import numpy as np
  9. from zoopt import Dimension, Objective, Opt, Parameter, Solution
  10. from ..data.structures import ListData
  11. from ..reasoning import KBBase
  12. from ..utils.utils import hamming_dist, confidence_dist, avg_confidence_dist


class Reasoner:
    """
    Reasoner for minimizing the inconsistency between the knowledge base and learning models.

    Parameters
    ----------
    kb : class KBBase
        The knowledge base to be used for reasoning.
    dist_func : Union[str, Callable], optional
        The distance function used to determine the cost list between each
        candidate and the given prediction. The cost is also referred to as a consistency
        measure, wherein the candidate with the lowest cost is selected as the final
        abduced label. It can be either a string representing a predefined distance
        function or a callable function. The available predefined distance functions are
        'hamming' | 'confidence' | 'avg_confidence'. 'hamming' directly calculates the
        Hamming distance between the predicted pseudo-label in the data example and each
        candidate. 'confidence' and 'avg_confidence' calculate the confidence distance
        between the predicted probabilities in the data example and each candidate, where
        the confidence distance is defined as 1 - the product of prediction probabilities
        in 'confidence' and 1 - the average of prediction probabilities in 'avg_confidence'.
        Alternatively, the callable function should have the signature
        ``dist_func(data_example, candidates, candidate_idxs, reasoning_results)`` and must
        return a cost list. Each element in this cost list should be a numerical value
        representing the cost for each candidate, and the list should have the same length
        as candidates. Defaults to 'confidence'.
    idx_to_label : dict, optional
        A mapping from indices in the base model to labels. If not provided, a default
        order-based index-to-label mapping is created. Defaults to None.
    max_revision : Union[int, float], optional
        The upper limit on the number of revisions for each data example when
        performing abductive reasoning. If float, denotes the fraction of the total
        length that can be revised. A value of -1 implies no restriction on the
        number of revisions. Defaults to -1.
    require_more_revision : int, optional
        Specifies the additional number of revisions permitted beyond the minimum required
        when performing abductive reasoning. Defaults to 0.
    use_zoopt : bool, optional
        Whether to use the ZOOpt library during abductive reasoning. Defaults to False.
    """

    def __init__(
        self,
        kb: KBBase,
        dist_func: Union[str, Callable] = "confidence",
        idx_to_label: Optional[dict] = None,
        max_revision: Union[int, float] = -1,
        require_more_revision: int = 0,
        use_zoopt: bool = False,
    ):
        self.kb = kb
        self._check_valid_dist(dist_func)
        self.dist_func = dist_func
        self.use_zoopt = use_zoopt
        self.max_revision = max_revision
        self.require_more_revision = require_more_revision
        if idx_to_label is None:
            self.idx_to_label = {
                index: label for index, label in enumerate(self.kb.pseudo_label_list)
            }
        else:
            self._check_valid_idx_to_label(idx_to_label)
            self.idx_to_label = idx_to_label
        self.label_to_idx = dict(zip(self.idx_to_label.values(), self.idx_to_label.keys()))

    def _check_valid_dist(self, dist_func):
        if isinstance(dist_func, str):
            if dist_func not in ["hamming", "confidence", "avg_confidence"]:
                raise NotImplementedError(
                    'Valid options for predefined dist_func include "hamming", '
                    + f'"confidence" and "avg_confidence", but got {dist_func}.'
                )
            return
        elif callable(dist_func):
            params = inspect.signature(dist_func).parameters.values()
            if len(params) != 4:
                raise ValueError(
                    "User-defined dist_func must have exactly four parameters, "
                    + f"but got {len(params)}."
                )
            return
        else:
            raise TypeError(
                f"dist_func must be a string or a callable function, but got {type(dist_func)}."
            )

    def _check_valid_idx_to_label(self, idx_to_label):
        if not isinstance(idx_to_label, dict):
            raise TypeError(f"idx_to_label should be dict, but got {type(idx_to_label)}.")
        for key, value in idx_to_label.items():
            if not isinstance(key, int):
                raise ValueError(f"All keys in the idx_to_label must be integers, but got {key}.")
            if value not in self.kb.pseudo_label_list:
                raise ValueError(
                    "All values in the idx_to_label must be in the pseudo_label_list, "
                    + f"but got {value}."
                )

    def _get_one_candidate(
        self,
        data_example: ListData,
        candidates: List[List[Any]],
        reasoning_results: List[Any],
    ) -> List[Any]:
        """
        Due to the nondeterminism of abductive reasoning, there could be multiple candidates
        satisfying the knowledge base. When this happens, return one candidate that has the
        minimum cost. If no candidates are provided, an empty list is returned.

        Parameters
        ----------
        data_example : ListData
            Data example.
        candidates : List[List[Any]]
            Multiple possible candidates.
        reasoning_results : List[Any]
            Corresponding reasoning results of the candidates.

        Returns
        -------
        List[Any]
            A selected candidate.
        """
        if len(candidates) == 0:
            return []
        elif len(candidates) == 1:
            return candidates[0]
        else:
            cost_array = self._get_cost_list(data_example, candidates, reasoning_results)
            candidate = candidates[np.argmin(cost_array)]
            return candidate

    def _get_cost_list(
        self,
        data_example: ListData,
        candidates: List[List[Any]],
        reasoning_results: List[Any],
    ) -> Union[List[Union[int, float]], np.ndarray]:
        """
        Get the list of costs between each candidate and the given data example.

        Parameters
        ----------
        data_example : ListData
            Data example.
        candidates : List[List[Any]]
            Multiple possible candidates.
        reasoning_results : List[Any]
            Corresponding reasoning results of the candidates.

        Returns
        -------
        Union[List[Union[int, float]], np.ndarray]
            The list of costs.
        """
        if self.dist_func == "hamming":
            return hamming_dist(data_example.pred_pseudo_label, candidates)
        elif self.dist_func == "confidence":
            candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
            return confidence_dist(data_example.pred_prob, candidates_idxs)
        elif self.dist_func == "avg_confidence":
            candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
            return avg_confidence_dist(data_example.pred_prob, candidates_idxs)
        else:
            candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
            cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results)
            if len(cost_list) != len(candidates):
                raise ValueError(
                    "The length of the array returned by dist_func must be equal to the number "
                    + f"of candidates. Expected length {len(candidates)}, but got {len(cost_list)}."
                )
            return cost_list

    def _zoopt_get_solution(
        self,
        symbol_num: int,
        data_example: ListData,
        max_revision_num: int,
    ) -> Solution:
        """
        Get the optimal solution using the ZOOpt library. From the solution, we can get a list of
        boolean values, where '1' (True) indicates the indices chosen to be revised.

        Parameters
        ----------
        symbol_num : int
            Number of total symbols.
        data_example : ListData
            Data example.
        max_revision_num : int
            Specifies the maximum number of revisions allowed.

        Returns
        -------
        Solution
            The solution found by the ZOOpt library.
        """
        dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
        objective = Objective(
            lambda sol: self.zoopt_score(symbol_num, data_example, sol),
            dim=dimension,
            constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
        )
        parameter = Parameter(
            budget=self.zoopt_budget(symbol_num), intermediate_result=False, autoset=True
        )
        solution = Opt.min(objective, parameter)
        return solution

    def zoopt_score(
        self,
        symbol_num: int,
        data_example: ListData,
        sol: Solution,
    ) -> int:
        """
        Set the score for a solution. A lower score suggests that the ZOOpt library
        has a higher preference for this solution.

        Parameters
        ----------
        symbol_num : int
            Number of total symbols.
        data_example : ListData
            Data example.
        sol : Solution
            The solution from the ZOOpt library.

        Returns
        -------
        int
            The score for the solution.
        """
        revision_idx = np.where(sol.get_x() != 0)[0]
        candidates, reasoning_results = self.kb.revise_at_idx(
            data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx
        )
        if len(candidates) > 0:
            return np.min(self._get_cost_list(data_example, candidates, reasoning_results))
        else:
            return symbol_num

    def zoopt_budget(self, symbol_num: int) -> int:
        """
        Set the budget for ZOOpt optimization. The budget can be dynamic, depending on
        the number of symbols considered, e.g., the default implementation shown below.
        Alternatively, it can be a fixed value, such as simply setting it to 100.

        Parameters
        ----------
        symbol_num : int
            The number of symbols to be considered in the ZOOpt optimization process.

        Returns
        -------
        int
            The budget for ZOOpt optimization.
        """
        return 10 * symbol_num

    def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int:
        """
        Constrain that the total number of revisions chosen by the solution does not exceed
        the maximum number of revisions allowed.
        """
        x = solution.get_x()
        return max_revision_num - x.sum()

    def _get_max_revision_num(self, max_revision: Union[int, float], symbol_num: int) -> int:
        """
        Get the maximum revision number according to input ``max_revision``.
        """
        if not isinstance(max_revision, (int, float)):
            raise TypeError(f"Parameter must be of type int or float, but got {type(max_revision)}")
        if max_revision == -1:
            return symbol_num
        if isinstance(max_revision, float):
            if not 0 <= max_revision <= 1:
                raise ValueError(
                    "If max_revision is a float, it must be between 0 and 1, "
                    + f"but got {max_revision}"
                )
            return round(symbol_num * max_revision)
        if max_revision < 0:
            raise ValueError(
                f"If max_revision is an int, it must be non-negative, but got {max_revision}"
            )
        return max_revision

    def abduce(self, data_example: ListData) -> List[Any]:
        """
        Perform abductive reasoning on the given data example.

        Parameters
        ----------
        data_example : ListData
            Data example.

        Returns
        -------
        List[Any]
            The revised pseudo-labels of the example obtained through abductive reasoning,
            which are compatible with the knowledge base.
        """
        symbol_num = data_example.elements_num("pred_pseudo_label")
        max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)
        if self.use_zoopt:
            solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)
            revision_idx = np.where(solution.get_x() != 0)[0]
            candidates, reasoning_results = self.kb.revise_at_idx(
                pseudo_label=data_example.pred_pseudo_label,
                y=data_example.Y,
                x=data_example.X,
                revision_idx=revision_idx,
            )
        else:
            candidates, reasoning_results = self.kb.abduce_candidates(
                pseudo_label=data_example.pred_pseudo_label,
                y=data_example.Y,
                x=data_example.X,
                max_revision_num=max_revision_num,
                require_more_revision=self.require_more_revision,
            )
        candidate = self._get_one_candidate(data_example, candidates, reasoning_results)
        return candidate

    def batch_abduce(self, data_examples: ListData) -> List[List[Any]]:
        """
        Perform abductive reasoning on the given prediction data examples.
        For detailed information, refer to ``abduce``.
        """
        abduced_pseudo_label = [self.abduce(data_example) for data_example in data_examples]
        data_examples.abduced_pseudo_label = abduced_pseudo_label
        return abduced_pseudo_label

    def __call__(self, data_examples: ListData) -> List[List[Any]]:
        return self.batch_abduce(data_examples)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.
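
As a usage illustration, the ``dist_func`` docstring above specifies that a user-defined distance function must take exactly four arguments, ``(data_example, candidates, candidate_idxs, reasoning_results)``, and return one numerical cost per candidate. Below is a minimal sketch of such a function together with a Reasoner built on it. The import path, the ``AddKB`` knowledge base, and the already-populated ``data_examples`` are assumptions for illustration only; they are not defined in reasoner.py.

# Hypothetical setup: import path and knowledge base are assumed, not part of this module.
import numpy as np
from ablkit.reasoning import KBBase, Reasoner  # assumed package layout

class AddKB(KBBase):
    """Toy knowledge base: pseudo-labels are digits whose sum must reason to the target."""
    def __init__(self):
        super().__init__(pseudo_label_list=list(range(10)))  # constructor argument assumed

    def logic_forward(self, nums):
        return sum(nums)

def my_dist(data_example, candidates, candidate_idxs, reasoning_results):
    # One cost per candidate: the number of positions revised away from the
    # currently predicted pseudo-label (a Hamming-style cost, mirroring hamming_dist).
    pred = np.array(data_example.pred_pseudo_label)
    return np.sum(pred != np.array(candidates), axis=1)

reasoner = Reasoner(kb=AddKB(), dist_func=my_dist, max_revision=0.5)
# With 4 symbols, max_revision=0.5 allows round(4 * 0.5) = 2 revisions per example.
# Given a ListData whose pred_pseudo_label, pred_prob, X and Y fields have been filled
# by the learning part, abduction is then a single call:
# abduced = reasoner.batch_abduce(data_examples)   # or simply reasoner(data_examples)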
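
Continuing the sketch above, the ``zoopt_budget`` docstring notes that the search budget may also be a fixed value instead of scaling with the number of symbols. One way to do this is a small subclass override; the class name here is illustrative only.

class FixedBudgetReasoner(Reasoner):
    def zoopt_budget(self, symbol_num: int) -> int:
        # Fixed evaluation budget for ZOOpt, replacing the default 10 * symbol_num.
        return 100

# Enable ZOOpt-based search for the revision indices during abduction:
# reasoner = FixedBudgetReasoner(kb=AddKB(), use_zoopt=True)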