You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

kb.py 14 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2021 LAMDA All rights reserved.
  4. #
  5. # File Name :kb.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/03
  9. # Description :
  10. #
  11. #================================================================#
  12. from abc import ABC, abstractmethod
  13. import bisect
  14. import copy
  15. import numpy as np
  16. from collections import defaultdict
  17. from itertools import product, combinations
  18. import pyswip
  class KBBase(ABC):
      """Abstract base for knowledge bases that search for revised label lists.

      Subclasses must implement ``logic_forward`` and ``abduce_candidates``.
      The search helpers ``address`` and ``abduction`` also call
      ``self.address_by_idx(pred_res, key, address_idx, multiple_predictions)``,
      which is not defined here and has to be supplied by the subclass.
      """

      def __init__(self, pseudo_label_list=None):
          # The argument is accepted but not stored by the base class.
          pass

      @abstractmethod
      def logic_forward(self):
          pass

      @abstractmethod
      def abduce_candidates(self):
          pass

      def address(self, address_num, pred_res, key, multiple_predictions=False):
          """Collect candidates obtained by revising exactly ``address_num`` positions.

          Every combination of ``address_num`` positions is handed to
          ``self.address_by_idx`` and the returned candidates are concatenated.
          With ``multiple_predictions`` the positions index the flattened
          ``pred_res``; otherwise they index ``pred_res`` directly.
          """
          if multiple_predictions:
              position_count = len(self.flatten(pred_res))
          else:
              position_count = len(pred_res)

          new_candidates = []
          for address_idx in combinations(range(position_count), address_num):
              new_candidates += self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
          return new_candidates

      def correct_result(self, pred_res, key):
          """Check ``pred_res`` against ``key`` through ``logic_forward``.

          For a ``bool`` key the result of ``logic_forward`` is returned as is;
          for any other key the result must lie within ``1e-3`` of ``key``.
          """
          if isinstance(key, bool):
              return self.logic_forward(pred_res)
          return abs(self.logic_forward(pred_res) - key) <= 1e-3

      def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions=False):
          """Search for candidates with as few revised positions as possible.

          The number of revised positions grows from 0 until at least one
          candidate is found, then up to ``require_more_address`` further
          levels are explored (never beyond ``max_address_num``).

          Returns:
              A tuple ``(candidates, min_address_num, address_num)`` where
              ``min_address_num`` is the first level that produced a candidate
              and ``address_num`` is the last level explored. ``([], 0, 0)`` is
              returned when nothing is found.
          """
          # NOTE(review): the outer range uses len(pred_res) even when
          # multiple_predictions is set, whereas ``address`` counts flattened
          # positions in that case - confirm this is intended.
          candidates = []
          for address_num in range(len(pred_res) + 1):
              if address_num == 0:
                  if self.correct_result(pred_res, key):
                      candidates.append(pred_res)
              else:
                  candidates += self.address(address_num, pred_res, key, multiple_predictions)

              if candidates:
                  min_address_num = address_num
                  break

              if address_num >= max_address_num:
                  return [], 0, 0
          else:
              # Every level was tried without success. Without this branch
              # ``min_address_num`` would be unbound in the loop below.
              return [], 0, 0

          for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
              if address_num > max_address_num:
                  return candidates, min_address_num, address_num - 1
              candidates += self.address(address_num, pred_res, key, multiple_predictions)

          return candidates, min_address_num, address_num

      # for multiple predictions, modify from `learn_add.py`
      def flatten(self, l):
          """Concatenate the sublists of ``l`` into a single list."""
          return [item for sublist in l for item in sublist]

      # for multiple predictions, modify from `learn_add.py`
      def reform_ids(self, flatten_pred_res, save_pred_res):
          """Regroup a flat list into sublists sized like those of ``save_pred_res``."""
          re = []
          start = 0
          for group in save_pred_res:
              size = len(group)
              # Indexing (not slicing) so a too-short flat list still raises.
              re.append([flatten_pred_res[start + offset] for offset in range(size)])
              start += size
          return re

      def __len__(self):
          # Placeholder: returns None, so len() on an instance that does not
          # override this method raises TypeError.
          pass
  class ClsKB(KBBase):
      """Knowledge base over a finite set of pseudo labels.

      With ``GKB_flag`` set, every label sequence whose length is in
      ``len_list`` is evaluated once with ``logic_forward`` and stored in
      ``self.base`` as ``{length: {result: [sequence, ...]}}``. Otherwise
      ``self.all_address_candidate_dict`` maps each revision count up to
      ``max(len_list)`` to all label tuples of that size, for use by the
      search in ``KBBase.abduction``.
      """

      def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None):
          super().__init__()
          self.GKB_flag = GKB_flag
          self.pseudo_label_list = pseudo_label_list
          self.len_list = len_list
          self.prolog_flag = False

          if GKB_flag:
              self.base = {}
              X, Y = self.get_GKB(self.pseudo_label_list, self.len_list)
              for x, y in zip(X, Y):
                  self.base.setdefault(len(x), defaultdict(list))[y].append(x)
          else:
              self.all_address_candidate_dict = {}
              for address_num in range(max(self.len_list) + 1):
                  self.all_address_candidate_dict[address_num] = list(
                      product(self.pseudo_label_list, repeat=address_num)
                  )

      def get_GKB(self, pseudo_label_list, len_list):
          """Enumerate label tuples and their ``logic_forward`` results.

          Tuples for which ``logic_forward`` returns ``np.inf`` are dropped.

          Returns:
              Two parallel lists ``(X, Y)`` of label tuples and results.
          """
          all_X = []
          for length in len_list:
              all_X += list(product(pseudo_label_list, repeat=length))

          X = []
          Y = []
          for x in all_X:
              y = self.logic_forward(x)
              if y != np.inf:
                  X.append(x)
                  Y.append(y)
          return X, Y

      def logic_forward(self):
          pass

      def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
          """Return ``(candidates, min_address_num, address_num)`` for ``key``.

          Dispatches to the precomputed table when ``GKB_flag`` is set and to
          the incremental search otherwise. ``max_address_num`` is compared
          numerically in both paths; the default ``-1`` is not treated as
          "unlimited".
          """
          if self.GKB_flag:
              return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address)
          return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions)

      def hamming_dist(self, A, B):
          """Count, for each row of ``B``, the positions where it differs from ``A``."""
          B = np.array(B)
          A = np.expand_dims(A, axis=0).repeat(axis=0, repeats=(len(B)))
          return np.sum(A != B, axis=1)

      def abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address):
          """Look up candidates for ``key`` in the precomputed table.

          Candidates are the stored sequences of the same length as
          ``pred_res`` whose result equals ``key`` and whose Hamming distance
          to ``pred_res`` is at most
          ``min(max_address_num, min_distance + require_more_address)``.

          Returns:
              ``(candidates, min_address_num, address_num)``; ``([], 0, 0)``
              when no sequence is stored for that length and key.
          """
          # Always return a 3-tuple so callers can unpack every outcome the
          # same way (the early exit used to return a bare list).
          if self.base == {} or len(pred_res) not in self.len_list:
              return [], 0, 0

          # .get() avoids a KeyError for an empty length bucket and does not
          # insert empty entries into the defaultdict on a miss.
          all_candidates = self.base.get(len(pred_res), {}).get(key, [])
          if len(all_candidates) == 0:
              return [], 0, 0

          cost_list = self.hamming_dist(pred_res, all_candidates)
          min_address_num = np.min(cost_list)
          address_num = min(max_address_num, min_address_num + require_more_address)
          idxs = np.where(cost_list <= address_num)[0]
          candidates = [all_candidates[idx] for idx in idxs]
          return candidates, min_address_num, address_num

      def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
          """Try every label assignment at ``address_idx`` and keep the consistent ones.

          A candidate is kept when ``logic_forward(candidate) == key``. With
          ``multiple_predictions`` the positions refer to the flattened
          ``pred_res`` and each candidate is regrouped to the original nesting
          before it is evaluated.
          """
          candidates = []
          abduce_c = self.all_address_candidate_dict[len(address_idx)]

          if multiple_predictions:
              save_pred_res = pred_res
              pred_res = self.flatten(pred_res)

          for c in abduce_c:
              candidate = pred_res.copy()
              # c has exactly len(address_idx) entries (see __init__).
              for idx, label in zip(address_idx, c):
                  candidate[idx] = label

              if multiple_predictions:
                  candidate = self.reform_ids(candidate, save_pred_res)

              if self.logic_forward(candidate) == key:
                  candidates.append(candidate)
          return candidates

      def _dict_len(self, dic):
          """Return the total number of stored sequences in one length bucket."""
          if not self.GKB_flag:
              return 0
          return sum(len(c) for c in dic.values())

      def __len__(self):
          """Return the number of sequences in the table (0 without ``GKB_flag``)."""
          if not self.GKB_flag:
              return 0
          return sum(self._dict_len(v) for v in self.base.values())
  class add_KB(ClsKB):
      """Knowledge base whose ``logic_forward`` is the sum of the labels.

      Defaults: labels ``0..9`` and sequences of length 2.
      """

      def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None):
          # Build the defaults per instance; list defaults in the signature
          # would be shared by every instance because they are stored on self.
          if pseudo_label_list is None:
              pseudo_label_list = list(range(10))
          if len_list is None:
              len_list = [2]
          super().__init__(GKB_flag, pseudo_label_list, len_list)

      def logic_forward(self, nums):
          """Return the sum of ``nums``."""
          return sum(nums)
  class HWF_KB(ClsKB):
      """Knowledge base for formulas made of digits ``1``-``9`` and four operators.

      A formula is a sequence of symbols alternating digit, operator, digit, ...
      Defaults: the 13 symbols below and formula lengths 1, 3, 5 and 7.
      """

      # Tuples keep the original ``==``-based membership test.
      _DIGITS = ('1', '2', '3', '4', '5', '6', '7', '8', '9')
      _OPERATORS = ('+', '-', 'times', 'div')
      # Symbol -> token used in the evaluated Python expression.
      _SYMBOL_TO_TOKEN = {
          '1': '1', '2': '2', '3': '3', '4': '4', '5': '5',
          '6': '6', '7': '7', '8': '8', '9': '9',
          '+': '+', '-': '-', 'times': '*', 'div': '/',
      }

      def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None):
          # Build the defaults per instance; list defaults in the signature
          # would be shared by every instance because they are stored on self.
          if pseudo_label_list is None:
              pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div']
          if len_list is None:
              len_list = [1, 3, 5, 7]
          super().__init__(GKB_flag, pseudo_label_list, len_list)

      def valid_candidate(self, formula):
          """Return True when ``formula`` has odd length and alternates digit / operator."""
          if len(formula) % 2 == 0:
              return False
          for position, symbol in enumerate(formula):
              if position % 2 == 0 and symbol not in self._DIGITS:
                  return False
              if position % 2 != 0 and symbol not in self._OPERATORS:
                  return False
          return True

      def logic_forward(self, formula):
          """Evaluate ``formula`` and round the result to two decimals.

          Returns ``np.inf`` for a formula rejected by ``valid_candidate``.
          """
          if not self.valid_candidate(formula):
              return np.inf
          expression = ''.join(self._SYMBOL_TO_TOKEN[symbol] for symbol in formula)
          # eval is only reached after valid_candidate, so the expression
          # consists solely of the digits 1-9 and the operators + - * /.
          return round(eval(expression), 2)
  class prolog_KB(KBBase):
      """Knowledge base whose candidates come from queries to a ``pyswip.Prolog`` engine.

      Subclasses provide ``get_query_string`` and, for nested predictions,
      ``get_query_string_need_flatten``; neither is defined here.
      """

      def __init__(self, pseudo_label_list):
          super().__init__()
          self.pseudo_label_list = pseudo_label_list
          self.prolog = pyswip.Prolog()

      def logic_forward(self):
          pass

      def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
          """Delegate to ``abduction`` with the same arguments."""
          return self.abduction(
              pred_res, key, max_address_num, require_more_address, multiple_predictions
          )

      def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
          """Build one candidate per Prolog solution by filling ``address_idx``.

          The values of each solution are written, in the order the solution
          dict yields them, into the positions listed in ``address_idx``.
          With ``multiple_predictions`` the positions refer to the flattened
          ``pred_res`` and each candidate is regrouped afterwards.
          """
          if multiple_predictions:
              # The query is built from the nested form, before flattening.
              query_string = self.get_query_string_need_flatten(pred_res, key, address_idx)
              save_pred_res = pred_res
              pred_res = self.flatten(pred_res)
          else:
              query_string = self.get_query_string(pred_res, key, address_idx)

          solutions = list(self.prolog.query(query_string))
          abduce_c = [list(solution.values()) for solution in solutions]

          candidates = []
          for c in abduce_c:
              candidate = pred_res.copy()
              for i, idx in enumerate(address_idx):
                  candidate[idx] = c[i]
              if multiple_predictions:
                  candidate = self.reform_ids(candidate, save_pred_res)
              candidates.append(candidate)
          return candidates
  class add_prolog_KB(prolog_KB):
      """Prolog-backed addition knowledge base.

      Asserts one ``pseudo_label/1`` fact per label and an ``addition/3`` rule
      that adds two pseudo labels. Default labels are ``0..9``.
      """

      def __init__(self, pseudo_label_list=None):
          # Build the default per instance; a list default in the signature
          # would be shared by every instance because it is stored on self.
          if pseudo_label_list is None:
              pseudo_label_list = list(range(10))
          super().__init__(pseudo_label_list)
          for i in self.pseudo_label_list:
              self.prolog.assertz("pseudo_label(%s)" % i)
          self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2")

      def logic_forward(self, nums):
          """Return ``Res`` from the first solution of ``addition(nums[0], nums[1], Res)``.

          Raises ``IndexError`` when the query has no solution.
          """
          return list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))[0]['Res']

      def get_query_string(self, pred_res, key, address_idx):
          """Build the ``addition(...)`` query with variables at ``address_idx``.

          Position ``idx`` becomes the variable ``Z<idx>`` when it is listed in
          ``address_idx`` and the predicted value otherwise; ``key`` is the
          last argument.
          """
          # Each argument carries its own trailing comma, as before.
          arguments = ''.join(
              ('Z' + str(idx) if idx in address_idx else str(i)) + ','
              for idx, i in enumerate(pred_res)
          )
          return "addition(" + arguments + "%s)." % key
  class HED_prolog_KB(prolog_KB):
      """Prolog-backed knowledge base that consults ``../datasets/hed/learn_add.pl``.

      Default labels are ``0``, ``1``, ``'+'`` and ``'='``.
      """

      def __init__(self, pseudo_label_list=None):
          # Build the default per instance; a list default in the signature
          # would be shared by every instance because it is stored on self.
          if pseudo_label_list is None:
              pseudo_label_list = [0, 1, '+', '=']
          super().__init__(pseudo_label_list)
          self.prolog.consult('../datasets/hed/learn_add.pl')

      # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py`
      def logic_forward(self, exs):
          """Return True when ``abduce_consistent_insts(exs)`` has at least one solution."""
          return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0

      def get_query_string_need_flatten(self, pred_res, key, address_idx):
          """Build the ``abduce_consistent_insts`` query with variables at ``address_idx``.

          Positions refer to the flattened ``pred_res``; position ``idx`` is
          replaced by the variable ``X<idx>``. ``key`` is not used.
          """
          # flatten, then add variables for prolog
          flatten_pred_res = [
              'X' + str(idx) if idx in address_idx else item
              for idx, item in enumerate(self.flatten(pred_res))
          ]
          # unflatten
          new_pred_res = self.reform_ids(flatten_pred_res, pred_res)
          query_string = "abduce_consistent_insts(%s)." % new_pred_res
          # Drop Python's quotes, then re-quote the operator symbols.
          return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='")

      def consist_rule(self, exs, rules):
          """Return True when ``consistent_inst_feature(exs, rules)`` has a solution."""
          rule_str = "%s" % rules
          rule_str = rule_str.replace("'", "")
          return len(list(self.prolog.query("consistent_inst_feature(%s, %s)." % (exs, rule_str)))) != 0

      def abduce_rules(self, pred_res):
          """Return the ``.value`` of each rule bound to ``X`` in the first solution.

          Raises ``IndexError`` when ``consistent_inst_feature(pred_res, X)`` has
          no solution.
          """
          prolog_rules = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))[0]['X']
          return [rule.value for rule in prolog_rules]
  257. # def consist_rules(self, pred_res, rules):
  class RegKB(KBBase):
      """Knowledge base for numeric targets, indexed by input length.

      ``self.base`` maps each input length to ``(X, Y)`` where ``Y`` is the
      sorted list of distinct targets and ``X[i]`` is the list of inputs (as
      numpy arrays) whose target is ``Y[i]``.
      """

      def __init__(self, GKB_flag=False, X=None, Y=None):
          super().__init__()
          # GKB_flag is accepted but not used by this class.
          # Treat missing data as an empty base instead of failing in zip().
          X = [] if X is None else X
          Y = [] if Y is None else Y

          tmp_dict = {}
          for x, y in zip(X, Y):
              tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x))

          self.base = {}
          for l, groups in tmp_dict.items():
              # Targets are unique dict keys, so sorting never compares the lists.
              data = sorted(groups.items())
              self.base[l] = ([xs for _, xs in data], [y for y, _ in data])

      def valid_candidate(self):
          pass

      def logic_forward(self):
          pass

      def _length(self, length):
          """Normalise ``length`` to a list of lengths.

          ``None`` means every stored length and a single ``int`` is wrapped in
          a list. NOTE(review): this helper was missing although
          ``abduce_candidates`` calls it; these semantics are an assumption -
          confirm against the callers.
          """
          if length is None:
              return list(self.base.keys())
          if isinstance(length, int):
              return [length]
          return length

      def abduce_candidates(self, key, length=None):
          """Return the stored inputs whose target is closest to ``key``.

          Args:
              key: target value; ``None`` returns every stored input.
              length: a length, a list of lengths, or ``None`` for all lengths.
                  Lengths that are not stored are skipped.

          Returns:
              The inputs whose target has the smallest absolute error; targets
              whose error is within ``1e-9`` of that minimum are included too.
          """
          if key is None:
              return self.get_all_candidates()

          # Infinity instead of a fixed bound, so very distant targets still count.
          min_err = float("inf")
          candidates = []
          for l in self._length(length):
              if l not in self.base:
                  continue
              X, Y = self.base[l]
              pos = bisect.bisect_left(Y, key)
              # The nearest target is next to the insertion point.
              begin = max(0, pos - 1)
              end = min(pos + 2, len(X))
              for idx in range(begin, end):
                  err = abs(Y[idx] - key)
                  if abs(err - min_err) < 1e-9:
                      candidates.extend(X[idx])
                  elif err < min_err:
                      candidates = copy.deepcopy(X[idx])
                      min_err = err
          return candidates

      def get_all_candidates(self):
          """Return every stored input as one flat list."""
          return [x for X, _ in self.base.values() for group in X for x in group]

      def __len__(self):
          """Return the total number of stored inputs."""
          return sum(len(group) for X, _ in self.base.values() for group in X)
  297. import time
  if __name__ == "__main__":
      # Running the module as a script does nothing: the body is only `pass`.
      pass
  300. # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
  301. # Y = [2, 1, 1, 2, 2]
  302. # kb = ClsKB(X, Y)
  303. # print('len(kb):', len(kb))
  304. # res = kb.get_candidates(2, 5)
  305. # print(res)
  306. # res = kb.get_candidates(2, 3)
  307. # print(res)
  308. # res = kb.get_candidates(None)
  309. # print(res)
  310. # print()
  311. # X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
  312. # Y = [2, 1, 1, 2, 1.5, 1.5]
  313. # kb = RegKB(X, Y)
  314. # print('len(kb):', len(kb))
  315. # res = kb.get_candidates(1.6)
  316. # print(res)
  317. # res = kb.get_candidates(1.6, length = 9)
  318. # print(res)
  319. # res = kb.get_candidates(None)
  320. # print(res)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.