You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kb.py 19 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. from abc import ABC, abstractmethod
  2. import bisect
  3. import os
  4. from collections import defaultdict
  5. from itertools import product, combinations
  6. from multiprocessing import Pool
  7. from functools import lru_cache
  8. import numpy as np
  9. import pyswip
  10. from ..utils.utils import flatten, reform_list, hamming_dist, to_hashable
  11. from ..utils.cache import abl_cache
  12. class KBBase(ABC):
  13. """
  14. Base class for knowledge base.
  15. Parameters
  16. ----------
  17. pseudo_label_list : list
  18. List of possible pseudo labels. It's recommended to arrange the pseudo labels in this
  19. list so that each aligns with its corresponding index in the base model: the first with
  20. the 0th index, the second with the 1st, and so forth.
  21. max_err : float, optional
  22. The upper tolerance limit when comparing the similarity between a pseudo label sample's reasoning
  23. result and the ground truth. This is only applicable when the reasoning result is of a numerical type.
  24. This is particularly relevant for regression problems where exact matches might not be
  25. feasible. Defaults to 1e-10.
  26. use_cache : bool, optional
  27. Whether to use abl_cache for previously abduced candidates to speed up subsequent
  28. operations. Defaults to True.
  29. key_func : func, optional
  30. A function employed for hashing in abl_cache. This is only operational when use_cache
  31. is set to True. Defaults to to_hashable.
  32. cache_size: int, optional
  33. The cache size in abl_cache. This is only operational when use_cache is set to
  34. True. Defaults to 4096.
  35. Notes
  36. -----
  37. Users should inherit from this base class to build their own knowledge base. For the
  38. user-build KB (an inherited subclass), it's only required for the user to provide the
  39. `pseudo_label_list` and override the `logic_forward` function (specifying how to
  40. perform logical reasoning). After that, other operations (e.g. how to perform abductive
  41. reasoning) will be automatically set up.
  42. """
  43. def __init__(
  44. self,
  45. pseudo_label_list,
  46. max_err=1e-10,
  47. use_cache=True,
  48. key_func=to_hashable,
  49. cache_size=4096,
  50. ):
  51. if not isinstance(pseudo_label_list, list):
  52. raise TypeError("pseudo_label_list should be list")
  53. self.pseudo_label_list = pseudo_label_list
  54. self.max_err = max_err
  55. self.use_cache = use_cache
  56. self.key_func = key_func
  57. self.cache_size = cache_size
  58. @abstractmethod
  59. def logic_forward(self, pseudo_label):
  60. """
  61. How to perform (deductive) logical reasoning, i.e. matching each pseudo label sample to
  62. their reasoning result. Users are required to provide this.
  63. Parameters
  64. ----------
  65. pseudo_label : List[Any]
  66. Pseudo label sample.
  67. """
  68. pass
  69. def abduce_candidates(self, pseudo_label, y, max_revision_num, require_more_revision):
  70. """
  71. Perform abductive reasoning to get a candidate compatible with the knowledge base.
  72. Parameters
  73. ----------
  74. pseudo_label : List[Any]
  75. Pseudo label sample (to be revised by abductive reasoning).
  76. y : any
  77. Ground truth of the reasoning result for the sample.
  78. max_revision_num : int
  79. The upper limit on the number of revised labels for each sample.
  80. require_more_revision : int
  81. Specifies additional number of revisions permitted beyond the minimum required.
  82. Returns
  83. -------
  84. List[List[Any]]
  85. A list of candidates, i.e. revised pseudo label samples that are compatible with the
  86. knowledge base.
  87. """
  88. return self._abduce_by_search(pseudo_label, y, max_revision_num, require_more_revision)
  89. def _check_equal(self, logic_result, y):
  90. """
  91. Check whether the reasoning result of a pseduo label sample is equal to the ground truth
  92. (or, within the maximum error allowed for numerical results).
  93. Returns
  94. -------
  95. bool
  96. The result of the check.
  97. """
  98. if logic_result is None:
  99. return False
  100. if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)):
  101. return abs(logic_result - y) <= self.max_err
  102. else:
  103. return logic_result == y
  104. def revise_at_idx(self, pseudo_label, y, revision_idx):
  105. """
  106. Revise the pseudo label sample at specified index positions.
  107. Parameters
  108. ----------
  109. pseudo_label : List[Any]
  110. Pseudo label sample (to be revised).
  111. y : Any
  112. Ground truth of the reasoning result for the sample.
  113. revision_idx : array-like
  114. Indices of where revisions should be made to the pseudo label sample.
  115. Returns
  116. -------
  117. List[List[Any]]
  118. A list of candidates, i.e. revised pseudo label samples that are compatible with the
  119. knowledge base.
  120. """
  121. candidates = []
  122. abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx))
  123. for c in abduce_c:
  124. candidate = pseudo_label.copy()
  125. for i, idx in enumerate(revision_idx):
  126. candidate[idx] = c[i]
  127. if self._check_equal(self.logic_forward(candidate), y):
  128. candidates.append(candidate)
  129. return candidates
  130. def _revision(self, revision_num, pseudo_label, y):
  131. """
  132. For a specified number of labels in a pseudo label sample to revise, iterate through all possible
  133. indices to find any candidates that are compatible with the knowledge base.
  134. """
  135. new_candidates = []
  136. revision_idx_list = combinations(range(len(pseudo_label)), revision_num)
  137. for revision_idx in revision_idx_list:
  138. candidates = self.revise_at_idx(pseudo_label, y, revision_idx)
  139. new_candidates.extend(candidates)
  140. return new_candidates
  141. @abl_cache()
  142. def _abduce_by_search(self, pseudo_label, y, max_revision_num, require_more_revision):
  143. """
  144. Perform abductive reasoning by exhastive search. Specifically, begin with 0 and
  145. continuously increase the number of labels in a pseudo label sample to revise, until candidates
  146. that are compatible with the knowledge base are found.
  147. Parameters
  148. ----------
  149. pseudo_label : List[Any]
  150. Pseudo label sample (to be revised).
  151. y : Any
  152. Ground truth of the reasoning result for the sample.
  153. max_revision_num : int
  154. The upper limit on the number of revisions.
  155. require_more_revision : int
  156. If larger than 0, then after having found any candidates compatible with the
  157. knowledge base, continue to increase the number of labels in a pseudo label sample to revise to
  158. get more possible compatible candidates.
  159. Returns
  160. -------
  161. List[List[Any]]
  162. A list of candidates, i.e. revised pseudo label samples that are compatible with the
  163. knowledge base.
  164. """
  165. candidates = []
  166. for revision_num in range(len(pseudo_label) + 1):
  167. if revision_num == 0 and self._check_equal(self.logic_forward(pseudo_label), y):
  168. candidates.append(pseudo_label)
  169. elif revision_num > 0:
  170. candidates.extend(self._revision(revision_num, pseudo_label, y))
  171. if len(candidates) > 0:
  172. min_revision_num = revision_num
  173. break
  174. if revision_num >= max_revision_num:
  175. return []
  176. for revision_num in range(
  177. min_revision_num + 1, min_revision_num + require_more_revision + 1
  178. ):
  179. if revision_num > max_revision_num:
  180. return candidates
  181. candidates.extend(self._revision(revision_num, pseudo_label, y))
  182. return candidates
  183. def __repr__(self):
  184. return (
  185. f"{self.__class__.__name__} is a KB with "
  186. f"pseudo_label_list={self.pseudo_label_list!r}, "
  187. f"max_err={self.max_err!r}, "
  188. f"use_cache={self.use_cache!r}."
  189. )
  190. class GroundKB(KBBase):
  191. """
  192. Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon
  193. class initialization, storing all potential candidates along with their respective
  194. reasoning result. Ground KB can accelerate abductive reasoning in `abduce_candidates`.
  195. Parameters
  196. ----------
  197. pseudo_label_list : list
  198. Refer to class `KBBase`.
  199. GKB_len_list : list
  200. List of possible lengths for a pseudo label sample.
  201. max_err : float, optional
  202. Refer to class `KBBase`.
  203. Notes
  204. -----
  205. Users can also inherit from this class to build their own knowledge base. Similar
  206. to `KBBase`, users are only required to provide the `pseudo_label_list` and override
  207. the `logic_forward` function. Additionally, users should provide the `GKB_len_list`.
  208. After that, other operations (e.g. auto-construction of GKB, and how to perform
  209. abductive reasoning) will be automatically set up.
  210. """
  211. def __init__(self, pseudo_label_list, GKB_len_list, max_err=1e-10):
  212. super().__init__(pseudo_label_list, max_err)
  213. if not isinstance(GKB_len_list, list):
  214. raise TypeError("GKB_len_list should be list")
  215. self.GKB_len_list = GKB_len_list
  216. self.GKB = {}
  217. X, Y = self._get_GKB()
  218. for x, y in zip(X, Y):
  219. self.GKB.setdefault(len(x), defaultdict(list))[y].append(x)
  220. def _get_XY_list(self, args):
  221. pre_x, post_x_it = args[0], args[1]
  222. XY_list = []
  223. for post_x in post_x_it:
  224. x = (pre_x,) + post_x
  225. y = self.logic_forward(x)
  226. if y is not None:
  227. XY_list.append((x, y))
  228. return XY_list
  229. def _get_GKB(self):
  230. """
  231. Prebuild the GKB according to `pseudo_label_list` and `GKB_len_list`.
  232. """
  233. X, Y = [], []
  234. for length in self.GKB_len_list:
  235. arg_list = []
  236. for pre_x in self.pseudo_label_list:
  237. post_x_it = product(self.pseudo_label_list, repeat=length - 1)
  238. arg_list.append((pre_x, post_x_it))
  239. with Pool(processes=len(arg_list)) as pool:
  240. ret_list = pool.map(self._get_XY_list, arg_list)
  241. for XY_list in ret_list:
  242. if len(XY_list) == 0:
  243. continue
  244. part_X, part_Y = zip(*XY_list)
  245. X.extend(part_X)
  246. Y.extend(part_Y)
  247. if Y and isinstance(Y[0], (int, float)):
  248. X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
  249. return X, Y
  250. def abduce_candidates(self, pseudo_label, y, max_revision_num, require_more_revision):
  251. """
  252. Perform abductive reasoning by directly retrieving compatible candidates from
  253. the prebuilt GKB. In this way, the time-consuming exhaustive search can be
  254. avoided.
  255. Parameters
  256. ----------
  257. pseudo_label : List[Any]
  258. Pseudo label sample (to be revised by abductive reasoning).
  259. y : any
  260. Ground truth of the reasoning result for the sample.
  261. max_revision_num : int
  262. The upper limit on the number of revised labels for each sample.
  263. require_more_revision : int, optional
  264. Specifies additional number of revisions permitted beyond the minimum required.
  265. Returns
  266. -------
  267. List[List[Any]]
  268. A list of candidates, i.e. revised pseudo label samples that are compatible with the
  269. knowledge base.
  270. """
  271. if self.GKB == {} or len(pseudo_label) not in self.GKB_len_list:
  272. return []
  273. all_candidates = self._find_candidate_GKB(pseudo_label, y)
  274. if len(all_candidates) == 0:
  275. return []
  276. cost_list = hamming_dist(pseudo_label, all_candidates)
  277. min_revision_num = np.min(cost_list)
  278. revision_num = min(max_revision_num, min_revision_num + require_more_revision)
  279. idxs = np.where(cost_list <= revision_num)[0]
  280. candidates = [all_candidates[idx] for idx in idxs]
  281. return candidates
  282. def _find_candidate_GKB(self, pseudo_label, y):
  283. """
  284. Retrieve compatible candidates from the prebuilt GKB. For numerical reasoning results,
  285. return all candidates whose reasoning results fall within the
  286. [y - max_err, y + max_err] range.
  287. """
  288. if isinstance(y, (int, float)):
  289. potential_candidates = self.GKB[len(pseudo_label)]
  290. key_list = list(potential_candidates.keys())
  291. low_key = bisect.bisect_left(key_list, y - self.max_err)
  292. high_key = bisect.bisect_right(key_list, y + self.max_err)
  293. all_candidates = [
  294. candidate
  295. for key in key_list[low_key:high_key]
  296. for candidate in potential_candidates[key]
  297. ]
  298. return all_candidates
  299. else:
  300. return self.GKB[len(pseudo_label)][y]
  301. def __repr__(self):
  302. GKB_info_parts = []
  303. for i in self.GKB_len_list:
  304. num_candidates = len(self.GKB[i]) if i in self.GKB else 0
  305. GKB_info_parts.append(f"{num_candidates} candidates of length {i}")
  306. GKB_info = ", ".join(GKB_info_parts)
  307. return (
  308. f"{self.__class__.__name__} is a KB with "
  309. f"pseudo_label_list={self.pseudo_label_list!r}, "
  310. f"max_err={self.max_err!r}, "
  311. f"use_cache={self.use_cache!r}. "
  312. f"It has a prebuilt GKB with "
  313. f"GKB_len_list={self.GKB_len_list!r}, "
  314. f"and there are "
  315. f"{GKB_info}"
  316. f" in the GKB."
  317. )
  318. class PrologKB(KBBase):
  319. """
  320. Knowledge base provided by a Prolog (.pl) file.
  321. Parameters
  322. ----------
  323. pseudo_label_list : list
  324. Refer to class `KBBase`.
  325. pl_file :
  326. Prolog file containing the KB.
  327. max_err : float, optional
  328. Refer to class `KBBase`.
  329. Notes
  330. -----
  331. Users can instantiate this class to build their own knowledge base. During the
  332. instantiation, users are only required to provide the `pseudo_label_list` and `pl_file`.
  333. To use the default logic forward and abductive reasoning methods in this class, in the
  334. Prolog (.pl) file, there needs to be a rule which is strictly formatted as
  335. `logic_forward(Pseudo_labels, Res).`, e.g., `logic_forward([A,B], C) :- C is A+B`.
  336. For specifics, refer to the `logic_forward` and `get_query_string` functions in this
  337. class. Users are also welcome to override related functions for more flexible support.
  338. """
  339. def __init__(self, pseudo_label_list, pl_file):
  340. super().__init__(pseudo_label_list)
  341. self.pl_file = pl_file
  342. self.prolog = pyswip.Prolog()
  343. if not os.path.exists(self.pl_file):
  344. raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.")
  345. self.prolog.consult(self.pl_file)
  346. def logic_forward(self, pseudo_labels):
  347. """
  348. Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the
  349. returned `Res` as the reasoning results. To use this default function, there must be
  350. a `logic_forward` method in the pl file to perform reasoning.
  351. Otherwise, users would override this function.
  352. Parameters
  353. ----------
  354. pseudo_label : List[Any]
  355. Pseudo label sample.
  356. """
  357. result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]["Res"]
  358. if result == "true":
  359. return True
  360. elif result == "false":
  361. return False
  362. return result
  363. def _revision_pseudo_label(self, pseudo_label, revision_idx):
  364. import re
  365. revision_pseudo_label = pseudo_label.copy()
  366. revision_pseudo_label = flatten(revision_pseudo_label)
  367. for idx in revision_idx:
  368. revision_pseudo_label[idx] = "P" + str(idx)
  369. revision_pseudo_label = reform_list(revision_pseudo_label, pseudo_label)
  370. regex = r"'P\d+'"
  371. return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pseudo_label))
  372. def get_query_string(self, pseudo_label, y, revision_idx):
  373. """
  374. Get the query to be used for consulting Prolog.
  375. This is a default function for demo, users would override this function to adapt to their own
  376. Prolog file. In this demo function, return query `logic_forward([kept_labels, Revise_labels], Res).`.
  377. Parameters
  378. ----------
  379. pseudo_label : List[Any]
  380. Pseudo label sample (to be revised by abductive reasoning).
  381. y : any
  382. Ground truth of the reasoning result for the sample.
  383. revision_idx : array-like
  384. Indices of where revisions should be made to the pseudo label sample.
  385. Returns
  386. -------
  387. str
  388. A string of the query.
  389. """
  390. query_string = "logic_forward("
  391. query_string += self._revision_pseudo_label(pseudo_label, revision_idx)
  392. key_is_none_flag = y is None or (type(y) == list and y[0] is None)
  393. query_string += ",%s)." % y if not key_is_none_flag else ")."
  394. return query_string
  395. def revise_at_idx(self, pseudo_label, y, revision_idx):
  396. """
  397. Revise the pseudo label sample at specified index positions by querying Prolog.
  398. Parameters
  399. ----------
  400. pseudo_label : List[Any]
  401. Pseudo label sample (to be revised).
  402. y : Any
  403. Ground truth of the reasoning result for the sample.
  404. revision_idx : array-like
  405. Indices of where revisions should be made to the pseudo label sample.
  406. Returns
  407. -------
  408. List[List[Any]]
  409. A list of candidates, i.e. revised pseudo label samples that are compatible with the
  410. knowledge base.
  411. """
  412. candidates = []
  413. query_string = self.get_query_string(pseudo_label, y, revision_idx)
  414. save_pseudo_label = pseudo_label
  415. pseudo_label = flatten(pseudo_label)
  416. abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]
  417. for c in abduce_c:
  418. candidate = pseudo_label.copy()
  419. for i, idx in enumerate(revision_idx):
  420. candidate[idx] = c[i]
  421. candidate = reform_list(candidate, save_pseudo_label)
  422. candidates.append(candidate)
  423. return candidates
  424. def __repr__(self):
  425. return (
  426. f"{self.__class__.__name__} is a KB with "
  427. f"pseudo_label_list={self.pseudo_label_list!r}, "
  428. f"defined by "
  429. f"Prolog file {self.pl_file!r}."
  430. )

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.