You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

reasoner.py 13 kB

2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. import inspect
  2. from typing import Callable, Any, List, Optional, Union
  3. import numpy as np
  4. from zoopt import Dimension, Objective, Opt, Parameter, Solution
  5. from ..reasoning import KBBase
  6. from ..structures import ListData
  7. from ..utils.utils import confidence_dist, hamming_dist
  class Reasoner:
      """
      Reasoner for minimizing the inconsistency between the knowledge base and learning models.

      Parameters
      ----------
      kb : class KBBase
          The knowledge base to be used for reasoning.
      dist_func : Union[str, Callable], optional
          The distance function used to determine the cost list between each
          candidate and the given prediction. The cost is also referred to as a
          consistency measure, wherein the candidate with the lowest cost is
          selected as the final abduced label. It can be either a string naming a
          predefined distance function or a callable. The predefined distance
          functions are 'hamming' and 'confidence':
          'hamming' directly calculates the Hamming distance between the predicted
          pseudo-label in the data example and each candidate; 'confidence'
          calculates the distance between the prediction and each candidate based
          on the confidence derived from the predicted probability in the data
          example. A callable must take exactly four parameters,
          ``dist_func(data_example, candidates, candidate_idxs, reasoning_results)``,
          and return a cost list holding one numerical cost per candidate (i.e. the
          list must have the same length as ``candidates``).
          Defaults to 'confidence'.
      idx_to_label : Optional[dict], optional
          A mapping from index in the base model to label. Keys must be ints and
          values must be distinct members of ``kb.pseudo_label_list``. If not
          provided, a default order-based index-to-label mapping is created from
          ``kb.pseudo_label_list``. Defaults to None.
      max_revision : Union[int, float], optional
          The upper limit on the number of revisions for each data example when
          performing abductive reasoning. If a float, it must lie in [0, 1] and
          denotes the fraction of the total length that can be revised. If an int,
          it must be non-negative. A value of -1 implies no restriction on the
          number of revisions. Defaults to -1.
      require_more_revision : int, optional
          Specifies the additional number of revisions permitted beyond the minimum
          required when performing abductive reasoning. Defaults to 0.
      use_zoopt : bool, optional
          Whether to use the ZOOpt library during abductive reasoning.
          Defaults to False.
      """

      def __init__(
          self,
          kb: KBBase,
          dist_func: Union[str, Callable] = "confidence",
          idx_to_label: Optional[dict] = None,
          max_revision: Union[int, float] = -1,
          require_more_revision: int = 0,
          use_zoopt: bool = False,
      ):
          self.kb = kb
          self._check_valid_dist(dist_func)
          self.dist_func = dist_func
          self.use_zoopt = use_zoopt
          self.max_revision = max_revision
          self.require_more_revision = require_more_revision
          if idx_to_label is None:
              # Default: position i in the KB's pseudo-label list is model index i.
              self.idx_to_label = dict(enumerate(self.kb.pseudo_label_list))
          else:
              self._check_valid_idx_to_label(idx_to_label)
              self.idx_to_label = idx_to_label
          # Inverse mapping, used to translate candidate labels back into model indices.
          self.label_to_idx = {label: idx for idx, label in self.idx_to_label.items()}

      def _check_valid_dist(self, dist_func):
          """
          Validate ``dist_func``.

          Raises
          ------
          NotImplementedError
              If ``dist_func`` is a string other than "hamming" or "confidence".
          ValueError
              If ``dist_func`` is callable but does not have exactly four parameters.
          TypeError
              If ``dist_func`` is neither a string nor callable.
          """
          if isinstance(dist_func, str):
              if dist_func not in ["hamming", "confidence"]:
                  raise NotImplementedError(
                      f'Valid options for predefined dist_func include "hamming" and "confidence", but got {dist_func}.'
                  )
              return
          elif callable(dist_func):
              params = inspect.signature(dist_func).parameters.values()
              if len(params) != 4:
                  raise ValueError(f"User-defined dist_func must have exactly four parameters, but got {len(params)}.")
              return
          else:
              raise TypeError(
                  f"dist_func must be a string or a callable function, but got {type(dist_func)}."
              )

      def _check_valid_idx_to_label(self, idx_to_label):
          """
          Validate a user-supplied ``idx_to_label`` mapping.

          Raises
          ------
          TypeError
              If ``idx_to_label`` is not a dict.
          ValueError
              If a key is not an int, a value is not in ``kb.pseudo_label_list``, or
              the same label is mapped from more than one index.
          """
          if not isinstance(idx_to_label, dict):
              raise TypeError(f"idx_to_label should be dict, but got {type(idx_to_label)}.")
          seen_labels = set()
          for key, value in idx_to_label.items():
              if not isinstance(key, int):
                  raise ValueError(f"All keys in the idx_to_label must be integers, but got {key}.")
              if value not in self.kb.pseudo_label_list:
                  raise ValueError(
                      f"All values in the idx_to_label must be in the pseudo_label_list, but got {value}."
                  )
              # Duplicate labels would make the inverse ``label_to_idx`` mapping
              # silently keep only one of the indices.
              if value in seen_labels:
                  raise ValueError(
                      f"All values in the idx_to_label must be unique, but {value} appears more than once."
                  )
              seen_labels.add(value)

      def _get_one_candidate(
          self,
          data_example: ListData,
          candidates: List[List[Any]],
          reasoning_results: List[Any],
      ) -> List[Any]:
          """
          Due to the nondeterminism of abductive reasoning, there could be multiple candidates
          satisfying the knowledge base. When this happens, return one candidate that has the
          minimum cost. If no candidates are provided, an empty list is returned.

          Parameters
          ----------
          data_example : ListData
              Data example.
          candidates : List[List[Any]]
              Multiple compatible candidates.
          reasoning_results : List[Any]
              Corresponding reasoning results of the candidates.

          Returns
          -------
          List[Any]
              A selected candidate.
          """
          if len(candidates) == 0:
              return []
          elif len(candidates) == 1:
              # A single candidate needs no cost computation.
              return candidates[0]
          else:
              cost_array = self._get_cost_list(data_example, candidates, reasoning_results)
              # np.argmin returns the first index on ties.
              return candidates[np.argmin(cost_array)]

      def _get_cost_list(
          self,
          data_example: ListData,
          candidates: List[List[Any]],
          reasoning_results: List[Any],
      ) -> Union[List[Union[int, float]], np.ndarray]:
          """
          Get the list of costs between each candidate and the given data example.

          Parameters
          ----------
          data_example : ListData
              Data example.
          candidates : List[List[Any]]
              Multiple compatible candidates.
          reasoning_results : List[Any]
              Corresponding reasoning results of the candidates.

          Returns
          -------
          Union[List[Union[int, float]], np.ndarray]
              The list of costs.

          Raises
          ------
          ValueError
              If a user-defined ``dist_func`` returns a cost list whose length
              differs from the number of candidates.
          """
          if self.dist_func == "hamming":
              return hamming_dist(data_example.pred_pseudo_label, candidates)
          elif self.dist_func == "confidence":
              # confidence_dist works on model indices, not labels.
              candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
              return confidence_dist(data_example.pred_prob, candidate_idxs)
          else:
              candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
              cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results)
              if len(cost_list) != len(candidates):
                  raise ValueError(
                      f"The length of the array returned by dist_func must be equal to the number of candidates. "
                      f"Expected length {len(candidates)}, but got {len(cost_list)}."
                  )
              return cost_list

      @staticmethod
      def _solution_to_revision_idx(sol: Solution) -> np.ndarray:
          """
          Return the indices whose entry in the ZOOpt solution vector is non-zero,
          i.e. the positions chosen to be revised.
          """
          # np.asarray makes the element-wise comparison valid whether get_x()
          # returns a list or an ndarray; on a plain list, `x != 0` would collapse
          # to a single bool and np.where would wrongly yield index [0].
          return np.where(np.asarray(sol.get_x()) != 0)[0]

      def _zoopt_get_solution(
          self,
          symbol_num: int,
          data_example: ListData,
          max_revision_num: int,
      ) -> Solution:
          """
          Get the optimal solution using ZOOpt library. From the solution, we can get a list of
          boolean values, where '1' (True) indicates the indices chosen to be revised.

          Parameters
          ----------
          symbol_num : int
              Number of total symbols.
          data_example : ListData
              Data example.
          max_revision_num : int
              Specifies the maximum number of revisions allowed.

          Returns
          -------
          Solution
              The solution for ZOOpt library.
          """
          # One binary (0/1) decision variable per symbol: revise it or not.
          dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
          objective = Objective(
              lambda sol: self.zoopt_revision_score(symbol_num, data_example, sol),
              dim=dimension,
              constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
          )
          parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
          solution = Opt.min(objective, parameter)
          return solution

      def zoopt_revision_score(
          self,
          symbol_num: int,
          data_example: ListData,
          sol: Solution,
      ) -> Union[int, float]:
          """
          Get the revision score for a solution. A lower score suggests that ZOOpt library
          has a higher preference for this solution.

          Parameters
          ----------
          symbol_num : int
              Number of total symbols.
          data_example : ListData
              Data example.
          sol: Solution
              The solution for ZOOpt library.

          Returns
          -------
          Union[int, float]
              The minimum cost among the candidates obtained by revising the chosen
              positions, or ``symbol_num`` if no candidate is found.
          """
          revision_idx = self._solution_to_revision_idx(sol)
          candidates, reasoning_results = self.kb.revise_at_idx(
              data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx
          )
          if len(candidates) > 0:
              return np.min(self._get_cost_list(data_example, candidates, reasoning_results))
          else:
              # Penalty used when the chosen positions yield no candidate.
              return symbol_num

      def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int:
          """
          Constrain that the total number of revisions chosen by the solution does not exceed
          maximum number of revisions allowed.

          Returns ``max_revision_num`` minus the number of positions selected, which is
          non-negative exactly when the selection is within the limit.
          """
          x = np.asarray(solution.get_x())
          return max_revision_num - x.sum()

      def _get_max_revision_num(self, max_revision: Union[int, float], symbol_num: int) -> int:
          """
          Convert ``max_revision`` into an absolute number of revisions for an example
          with ``symbol_num`` symbols.

          -1 yields ``symbol_num`` (no restriction); a float in [0, 1] is treated as a
          fraction of ``symbol_num`` and rounded with ``round``; a non-negative int is
          returned unchanged.

          Raises
          ------
          TypeError
              If ``max_revision`` is not an int or float.
          ValueError
              If a float is outside [0, 1] or an int (other than -1) is negative.
          """
          if not isinstance(max_revision, (int, float)):
              raise TypeError(f"Parameter must be of type int or float, but got {type(max_revision)}")
          if max_revision == -1:
              return symbol_num
          elif isinstance(max_revision, float):
              if not (0 <= max_revision <= 1):
                  raise ValueError(
                      f"If max_revision is a float, it must be between 0 and 1, but got {max_revision}"
                  )
              return round(symbol_num * max_revision)
          else:
              if max_revision < 0:
                  raise ValueError(
                      f"If max_revision is an int, it must be non-negative, but got {max_revision}"
                  )
              return max_revision

      def abduce(self, data_example: ListData) -> List[Any]:
          """
          Perform abductive reasoning on the given data example.

          Parameters
          ----------
          data_example : ListData
              Data example.

          Returns
          -------
          List[Any]
              A revised pseudo-label example through abductive reasoning, which is compatible
              with the knowledge base (an empty list if no candidate is found).
          """
          symbol_num = data_example.elements_num("pred_pseudo_label")
          max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)

          if self.use_zoopt:
              # ZOOpt chooses which positions to revise; the KB then revises exactly those.
              solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)
              revision_idx = self._solution_to_revision_idx(solution)
              candidates, reasoning_results = self.kb.revise_at_idx(
                  pseudo_label=data_example.pred_pseudo_label,
                  y=data_example.Y,
                  x=data_example.X,
                  revision_idx=revision_idx,
              )
          else:
              candidates, reasoning_results = self.kb.abduce_candidates(
                  pseudo_label=data_example.pred_pseudo_label,
                  y=data_example.Y,
                  x=data_example.X,
                  max_revision_num=max_revision_num,
                  require_more_revision=self.require_more_revision,
              )

          return self._get_one_candidate(data_example, candidates, reasoning_results)

      def batch_abduce(self, data_examples: ListData) -> List[List[Any]]:
          """
          Perform abductive reasoning on the given prediction data examples.
          For detailed information, refer to `abduce`.

          The result is also stored on ``data_examples.abduced_pseudo_label``.
          """
          abduced_pseudo_label = [self.abduce(data_example) for data_example in data_examples]
          data_examples.abduced_pseudo_label = abduced_pseudo_label
          return abduced_pseudo_label

      def __call__(self, data_examples: ListData) -> List[List[Any]]:
          """Alias for `batch_abduce`."""
          return self.batch_abduce(data_examples)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.