You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

reasoner.py 13 kB

2 years ago
2 years ago
Lines 1–323
  1. import inspect
  2. from typing import Callable, Any, List, Optional, Union
  3. import numpy as np
  4. from zoopt import Dimension, Objective, Opt, Parameter, Solution
  5. from ..reasoning import KBBase
  6. from ..data.structures import ListData
  7. from ..utils.utils import confidence_dist, hamming_dist
  class Reasoner:
      """
      Reasoner for minimizing the inconsistency between the knowledge base and learning models.

      Parameters
      ----------
      kb : KBBase
          The knowledge base to be used for reasoning.
      dist_func : Union[str, Callable], optional
          The distance function used to determine the cost list between each
          candidate and the given prediction. The cost is also referred to as a
          consistency measure, wherein the candidate with the lowest cost is
          selected as the final abduced label. It can be either a string naming a
          predefined distance function or a callable. The predefined distance
          functions are 'hamming' and 'confidence':

          - 'hamming': directly calculates the Hamming distance between the
            predicted pseudo-label in the data example and each candidate.
          - 'confidence': calculates the distance between the prediction and each
            candidate based on the confidence derived from the predicted
            probability in the data example.

          A callable must have the signature
          ``dist_func(data_example, candidates, candidate_idxs, reasoning_results)``
          and return a cost list holding one numerical cost per candidate (i.e.
          the same length as ``candidates``). Defaults to 'confidence'.
      idx_to_label : Optional[dict], optional
          A mapping from index in the base model to label. Keys must be ints and
          values must be distinct labels taken from ``kb.pseudo_label_list``. If
          not provided, a default order-based index-to-label mapping is created.
          Defaults to None.
      max_revision : Union[int, float], optional
          The upper limit on the number of revisions for each data example when
          performing abductive reasoning. If float, denotes the fraction of the
          total length that can be revised. A value of -1 implies no restriction
          on the number of revisions. Defaults to -1.
      require_more_revision : int, optional
          Specifies additional number of revisions permitted beyond the minimum
          required when performing abductive reasoning. Defaults to 0.
      use_zoopt : bool, optional
          Whether to use the ZOOpt library during abductive reasoning. Defaults
          to False.
      """

      def __init__(
          self,
          kb: KBBase,
          dist_func: Union[str, Callable] = "confidence",
          idx_to_label: Optional[dict] = None,
          max_revision: Union[int, float] = -1,
          require_more_revision: int = 0,
          use_zoopt: bool = False,
      ):
          self.kb = kb
          self._check_valid_dist(dist_func)
          self.dist_func = dist_func
          self.use_zoopt = use_zoopt
          self.max_revision = max_revision
          self.require_more_revision = require_more_revision
          if idx_to_label is None:
              # Default mapping: the i-th model output is the i-th entry of the
              # knowledge base's pseudo_label_list.
              self.idx_to_label = {
                  index: label for index, label in enumerate(self.kb.pseudo_label_list)
              }
          else:
              self._check_valid_idx_to_label(idx_to_label)
              self.idx_to_label = idx_to_label
          # Inverse mapping, used to translate candidate labels back to model
          # output indices when computing costs.
          self.label_to_idx = dict(zip(self.idx_to_label.values(), self.idx_to_label.keys()))

      def _check_valid_dist(self, dist_func):
          """
          Validate ``dist_func``.

          Raises
          ------
          NotImplementedError
              If ``dist_func`` is a string other than "hamming" or "confidence".
          ValueError
              If ``dist_func`` is a callable whose signature does not have exactly
              four parameters.
          TypeError
              If ``dist_func`` is neither a string nor a callable.
          """
          if isinstance(dist_func, str):
              if dist_func not in ["hamming", "confidence"]:
                  raise NotImplementedError(
                      f'Valid options for predefined dist_func include "hamming" and "confidence", but got {dist_func}.'
                  )
              return
          elif callable(dist_func):
              params = inspect.signature(dist_func).parameters.values()
              if len(params) != 4:
                  raise ValueError(
                      f"User-defined dist_func must have exactly four parameters, but got {len(params)}."
                  )
              return
          else:
              raise TypeError(
                  f"dist_func must be a string or a callable function, but got {type(dist_func)}."
              )

      def _check_valid_idx_to_label(self, idx_to_label):
          """
          Validate a user-supplied ``idx_to_label`` mapping.

          Raises
          ------
          TypeError
              If ``idx_to_label`` is not a dict.
          ValueError
              If a key is not an int, if a value is not in
              ``self.kb.pseudo_label_list``, or if two indices map to the same
              label. The last case would make the inverse ``label_to_idx``
              mapping silently drop one of the indices.
          """
          if not isinstance(idx_to_label, dict):
              raise TypeError(f"idx_to_label should be dict, but got {type(idx_to_label)}.")
          seen_labels = set()
          for key, value in idx_to_label.items():
              if not isinstance(key, int):
                  raise ValueError(f"All keys in the idx_to_label must be integers, but got {key}.")
              if value not in self.kb.pseudo_label_list:
                  raise ValueError(
                      f"All values in the idx_to_label must be in the pseudo_label_list, but got {value}."
                  )
              if value in seen_labels:
                  raise ValueError(
                      f"All values in the idx_to_label must be unique, but got {value} more than once."
                  )
              seen_labels.add(value)

      def _get_one_candidate(
          self,
          data_example: ListData,
          candidates: List[List[Any]],
          reasoning_results: List[Any],
      ) -> List[Any]:
          """
          Due to the nondeterminism of abductive reasoning, there could be multiple
          candidates satisfying the knowledge base. When this happens, return one
          candidate that has the minimum cost. If no candidates are provided, an
          empty list is returned.

          Parameters
          ----------
          data_example : ListData
              Data example.
          candidates : List[List[Any]]
              Multiple compatible candidates.
          reasoning_results : List[Any]
              Corresponding reasoning results of the candidates.

          Returns
          -------
          List[Any]
              A selected candidate.
          """
          if len(candidates) == 0:
              return []
          elif len(candidates) == 1:
              # A single candidate needs no cost evaluation.
              return candidates[0]
          else:
              cost_array = self._get_cost_list(data_example, candidates, reasoning_results)
              # np.argmin returns the first index on ties.
              candidate = candidates[np.argmin(cost_array)]
              return candidate

      def _get_cost_list(
          self,
          data_example: ListData,
          candidates: List[List[Any]],
          reasoning_results: List[Any],
      ) -> Union[List[Union[int, float]], np.ndarray]:
          """
          Get the list of costs between each candidate and the given data example.

          Parameters
          ----------
          data_example : ListData
              Data example.
          candidates : List[List[Any]]
              Multiple compatible candidates.
          reasoning_results : List[Any]
              Corresponding reasoning results of the candidates.

          Returns
          -------
          Union[List[Union[int, float]], np.ndarray]
              The list of costs.

          Raises
          ------
          ValueError
              If a user-defined ``dist_func`` returns a cost list whose length
              differs from the number of candidates.
          """
          if self.dist_func == "hamming":
              return hamming_dist(data_example.pred_pseudo_label, candidates)
          elif self.dist_func == "confidence":
              # confidence_dist works on model output indices, not labels.
              candidates = [[self.label_to_idx[x] for x in c] for c in candidates]
              return confidence_dist(data_example.pred_prob, candidates)
          else:
              candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
              cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results)
              if len(cost_list) != len(candidates):
                  raise ValueError(
                      f"The length of the array returned by dist_func must be equal to the number of candidates. "
                      f"Expected length {len(candidates)}, but got {len(cost_list)}."
                  )
              return cost_list

      @staticmethod
      def _solution_to_revision_idx(solution: Solution) -> np.ndarray:
          """
          Return the positions whose value is non-zero in a ZOOpt solution vector,
          i.e. the indices chosen to be revised.
          """
          # np.asarray makes the comparison element-wise even if get_x() hands
          # back a plain list (for which `x != 0` would be a single bool).
          return np.where(np.asarray(solution.get_x()) != 0)[0]

      def _zoopt_get_solution(
          self,
          symbol_num: int,
          data_example: ListData,
          max_revision_num: int,
      ) -> Solution:
          """
          Get the optimal solution using the ZOOpt library. From the solution, we
          can get a list of boolean values, where '1' (True) indicates the indices
          chosen to be revised. The search uses a fixed evaluation budget of 200.

          Parameters
          ----------
          symbol_num : int
              Number of total symbols.
          data_example : ListData
              Data example.
          max_revision_num : int
              Specifies the maximum number of revisions allowed.

          Returns
          -------
          Solution
              The solution for the ZOOpt library.
          """
          # One discrete {0, 1} decision variable per symbol.
          dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
          objective = Objective(
              lambda sol: self.zoopt_revision_score(symbol_num, data_example, sol),
              dim=dimension,
              constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
          )
          parameter = Parameter(budget=200, intermediate_result=False, autoset=True)
          solution = Opt.min(objective, parameter)
          return solution

      def zoopt_revision_score(
          self,
          symbol_num: int,
          data_example: ListData,
          sol: Solution,
      ) -> Union[int, float]:
          """
          Get the revision score for a solution. A lower score suggests that the
          ZOOpt library has a higher preference for this solution.

          Parameters
          ----------
          symbol_num : int
              Number of total symbols.
          data_example : ListData
              Data example.
          sol : Solution
              The solution for the ZOOpt library.

          Returns
          -------
          Union[int, float]
              The minimum candidate cost obtained by revising the positions
              selected in ``sol``, or ``symbol_num`` if that revision yields no
              candidates.
          """
          revision_idx = self._solution_to_revision_idx(sol)
          candidates, reasoning_results = self.kb.revise_at_idx(
              pseudo_label=data_example.pred_pseudo_label,
              y=data_example.Y,
              x=data_example.X,
              revision_idx=revision_idx,
          )
          if len(candidates) > 0:
              return np.min(self._get_cost_list(data_example, candidates, reasoning_results))
          # No candidate is consistent with the knowledge base for this revision
          # set; score it with symbol_num instead.
          return symbol_num

      def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int:
          """
          Constrain that the total number of revisions chosen by the solution does
          not exceed the maximum number of revisions allowed. The returned value is
          ``max_revision_num`` minus the number of selected positions, so it is
          negative exactly when too many positions are selected.
          """
          x = np.asarray(solution.get_x())
          return max_revision_num - x.sum()

      def _get_max_revision_num(self, max_revision: Union[int, float], symbol_num: int) -> int:
          """
          Get the maximum revision number according to input ``max_revision``.

          Parameters
          ----------
          max_revision : Union[int, float]
              -1 for no limit, a float in [0, 1] for a fraction of ``symbol_num``,
              or a non-negative int for an absolute limit.
          symbol_num : int
              Number of total symbols.

          Returns
          -------
          int
              The maximum number of revisions allowed.

          Raises
          ------
          TypeError
              If ``max_revision`` is not an int or float. Booleans are rejected
              even though ``bool`` subclasses ``int``.
          ValueError
              If a float is outside [0, 1] or an int is negative (other than -1).
          """
          # bool is a subclass of int; without this check True/False would be
          # silently treated as 1/0 revisions.
          if isinstance(max_revision, bool) or not isinstance(max_revision, (int, float)):
              raise TypeError(f"Parameter must be of type int or float, but got {type(max_revision)}")
          if max_revision == -1:
              return symbol_num
          elif isinstance(max_revision, float):
              if not (0 <= max_revision <= 1):
                  raise ValueError(
                      f"If max_revision is a float, it must be between 0 and 1, but got {max_revision}"
                  )
              return round(symbol_num * max_revision)
          else:
              if max_revision < 0:
                  raise ValueError(
                      f"If max_revision is an int, it must be non-negative, but got {max_revision}"
                  )
              return max_revision

      def abduce(self, data_example: ListData) -> List[Any]:
          """
          Perform abductive reasoning on the given data example.

          Parameters
          ----------
          data_example : ListData
              Data example.

          Returns
          -------
          List[Any]
              A revised pseudo-label example through abductive reasoning, which is
              compatible with the knowledge base. An empty list is returned when
              no candidate is found.
          """
          symbol_num = data_example.elements_num("pred_pseudo_label")
          max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)
          if self.use_zoopt:
              solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)
              revision_idx = self._solution_to_revision_idx(solution)
              candidates, reasoning_results = self.kb.revise_at_idx(
                  pseudo_label=data_example.pred_pseudo_label,
                  y=data_example.Y,
                  x=data_example.X,
                  revision_idx=revision_idx,
              )
          else:
              candidates, reasoning_results = self.kb.abduce_candidates(
                  pseudo_label=data_example.pred_pseudo_label,
                  y=data_example.Y,
                  x=data_example.X,
                  max_revision_num=max_revision_num,
                  require_more_revision=self.require_more_revision,
              )
          candidate = self._get_one_candidate(data_example, candidates, reasoning_results)
          return candidate

      def batch_abduce(self, data_examples: ListData) -> List[List[Any]]:
          """
          Perform abductive reasoning on the given prediction data examples.
          For detailed information, refer to ``abduce``.

          The results are also stored on ``data_examples.abduced_pseudo_label``.
          """
          abduced_pseudo_label = [self.abduce(data_example) for data_example in data_examples]
          data_examples.abduced_pseudo_label = abduced_pseudo_label
          return abduced_pseudo_label

      def __call__(self, data_examples: ListData) -> List[List[Any]]:
          """Alias for ``batch_abduce``."""
          return self.batch_abduce(data_examples)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.