You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kb.py 12 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago

  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2021 LAMDA All rights reserved.
  4. #
  5. # File Name :kb.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/03
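  9. # Description : Knowledge bases (KBBase and its add_KB / cls_KB / reg_KB subclasses) that store label sequences grouped by length and target, and retrieve matching candidate sequences.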
  10. #
  11. #================================================================#
  12. from abc import ABC, abstractmethod
  13. import bisect
  14. import copy
  15. import numpy as np
  16. from collections import defaultdict
  17. from itertools import product, combinations
  18. class KBBase(ABC):
  19. def __init__(self):
  20. pass
  21. @abstractmethod
  22. def get_candidates(self):
  23. pass
  24. @abstractmethod
  25. def get_all_candidates(self):
  26. pass
  27. @abstractmethod
  28. def logic_forward(self, X):
  29. pass
  30. def _length(self, length):
  31. if length is None:
  32. length = list(self.base.keys())
  33. if type(length) is int:
  34. length = [length]
  35. return length
  36. def __len__(self):
  37. pass
  38. class add_KB(KBBase):
  39. def __init__(self, pseudo_label_list, max_len = 5):
  40. super().__init__()
  41. self.pseudo_label_list = pseudo_label_list
  42. self.base = {}
  43. X = self.get_X(self.pseudo_label_list, max_len)
  44. Y = self.get_Y(X, self.logic_forward)
  45. for x, y in zip(X, Y):
  46. self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
  47. def logic_forward(self, nums):
  48. return sum(nums)
  49. def get_X(self, pseudo_label_list, max_len):
  50. res = []
  51. assert(max_len >= 2)
  52. for len in range(2, max_len + 1):
  53. res += list(product(pseudo_label_list, repeat = len))
  54. return res
  55. def get_Y(self, X, logic_forward):
  56. return [logic_forward(nums) for nums in X]
  57. def get_candidates(self, key, length = None):
  58. if key is None:
  59. return self.get_all_candidates()
  60. length = self._length(length)
  61. return sum([self.base[l][key] for l in length], [])
  62. def get_all_candidates(self):
  63. return sum([sum(v.values(), []) for v in self.base.values()], [])
  64. def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address):
  65. if key is None:
  66. return self.get_all_candidates()
  67. candidates = []
  68. for address_num in range(len(pred_res) + 1):
  69. if(address_num > max_address_num):
  70. print('No candidates found')
  71. return None, None, None
  72. if(address_num == 0):
  73. if(self.logic_forward(pred_res) == key):
  74. candidates.append(pred_res)
  75. else:
  76. all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num))
  77. address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
  78. for address_idx in address_idx_list:
  79. for c in all_address_candidate:
  80. pred_res_array = np.array(pred_res)
  81. if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num):
  82. pred_res_array[np.array(address_idx)] = c
  83. if(self.logic_forward(pred_res_array) == key):
  84. candidates.append(pred_res_array)
  85. if(len(candidates) > 0):
  86. min_address_num = address_num
  87. break
  88. for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
  89. if(address_num > max_address_num):
  90. return candidates, min_address_num, address_num - 1
  91. all_candidate = list(product(self.pseudo_label_list, repeat = address_num))
  92. address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
  93. for address_idx in address_idx_list:
  94. for c in all_candidate:
  95. pred_res_array = np.array(pred_res)
  96. if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num):
  97. pred_res_array[np.array(address_idx)] = c
  98. if(self.logic_forward(pred_res_array) == key):
  99. candidates.append(pred_res_array)
  100. return candidates, min_address_num, address_num
  101. def _dict_len(self, dic):
  102. return sum(len(c) for c in dic.values())
  103. def __len__(self):
  104. return sum(self._dict_len(v) for v in self.base.values())
  105. # class hwf_KB(KBBase):
  106. # def __init__(self, pseudo_label_list, max_len = 5):
  107. # super().__init__()
  108. # self.pseudo_label_list = pseudo_label_list
  109. # self.base = {}
  110. # X = self.get_X(self.pseudo_label_list, max_len)
  111. # Y = self.get_Y(X, self.logic_forward)
  112. # for x, y in zip(X, Y):
  113. # self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
  114. # def logic_forward(self, nums):
  115. # return sum(nums)
  116. # def get_X(self, pseudo_label_list, max_len):
  117. # res = []
  118. # assert(max_len >= 2)
  119. # for len in range(2, max_len + 1):
  120. # res += list(product(pseudo_label_list, repeat = len))
  121. # return res
  122. # def get_Y(self, X, logic_forward):
  123. # return [logic_forward(nums) for nums in X]
  124. # def get_candidates(self, key, length = None):
  125. # if key is None:
  126. # return self.get_all_candidates()
  127. # length = self._length(length)
  128. # return sum([self.base[l][key] for l in length], [])
  129. # def get_all_candidates(self):
  130. # return sum([sum(v.values(), []) for v in self.base.values()], [])
  131. # def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address):
  132. # if key is None:
  133. # return self.get_all_candidates()
  134. # candidates = []
  135. # # all_candidates = list(product(self.pseudo_label_list, repeat = len(pred_res)))
  136. # for address_num in range(length + 1):
  137. # if(address_num > max_address_num):
  138. # print('No candidates found')
  139. # return None, None, None
  140. # if(address_num == 0):
  141. # if(self.logic_forward(pred_res) == key):
  142. # candidates.append(pred_res)
  143. # else:
  144. # all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num))
  145. # address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
  146. # for address_idx in address_idx_list:
  147. # for c in all_address_candidate:
  148. # pred_res_array = np.array(pred_res)
  149. # pred_res_array[np.array(address_idx)] = c
  150. # if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key):
  151. # candidates.append(pred_res_array)
  152. # if(len(candidates) > 0):
  153. # min_address_num = address_num
  154. # break
  155. # for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
  156. # if(address_num > max_address_num):
  157. # return candidates, min_address_num, address_num - 1
  158. # all_candidate = list(product(self.pseudo_label_list, repeat = address_num))
  159. # address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
  160. # for address_idx in address_idx_list:
  161. # for c in all_candidate:
  162. # pred_res_array = np.array(pred_res)
  163. # pred_res_array[np.array(address_idx)] = c
  164. # if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key):
  165. # candidates.append(pred_res_array)
  166. # return candidates, min_address_num, address_num
  167. # def _dict_len(self, dic):
  168. # return sum(len(c) for c in dic.values())
  169. # def __len__(self):
  170. # return sum(self._dict_len(v) for v in self.base.values())
  171. class cls_KB(KBBase):
  172. def __init__(self, X, Y = None):
  173. super().__init__()
  174. self.base = {}
  175. if X is None:
  176. return
  177. if Y is None:
  178. Y = [None] * len(X)
  179. for x, y in zip(X, Y):
  180. self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
  181. def logic_forward(self):
  182. return None
  183. def get_candidates(self, key, length = None):
  184. if key is None:
  185. return self.get_all_candidates()
  186. length = self._length(length)
  187. return sum([self.base[l][key] for l in length], [])
  188. def get_all_candidates(self):
  189. return sum([sum(v.values(), []) for v in self.base.values()], [])
  190. def _dict_len(self, dic):
  191. return sum(len(c) for c in dic.values())
  192. def __len__(self):
  193. return sum(self._dict_len(v) for v in self.base.values())
  194. class reg_KB(KBBase):
  195. def __init__(self, X, Y = None):
  196. super().__init__()
  197. tmp_dict = {}
  198. for x, y in zip(X, Y):
  199. tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
  200. self.base = {}
  201. for l in tmp_dict.keys():
  202. data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values())))
  203. X = [x for y, x in data]
  204. Y = [y for y, x in data]
  205. self.base[l] = (X, Y)
  206. def logic_forward(self):
  207. return None
  208. def get_candidates(self, key, length = None):
  209. if key is None:
  210. return self.get_all_candidates()
  211. length = self._length(length)
  212. min_err = 999999
  213. candidates = []
  214. for l in length:
  215. X, Y = self.base[l]
  216. idx = bisect.bisect_left(Y, key)
  217. begin = max(0, idx - 1)
  218. end = min(idx + 2, len(X))
  219. for idx in range(begin, end):
  220. err = abs(Y[idx] - key)
  221. if abs(err - min_err) < 1e-9:
  222. candidates.extend(X[idx])
  223. elif err < min_err:
  224. candidates = copy.deepcopy(X[idx])
  225. min_err = err
  226. return candidates
  227. def get_all_candidates(self):
  228. return sum([sum(D[0], []) for D in self.base.values()], [])
  229. def __len__(self):
  230. return sum([sum(len(x) for x in D[0]) for D in self.base.values()])
  231. if __name__ == "__main__":
  232. pseudo_label_list = list(range(10))
  233. kb = add_KB(pseudo_label_list, max_len = 5)
  234. print('len(kb):', len(kb))
  235. print()
  236. res = kb.get_candidates(0)
  237. print(res)
  238. print()
  239. res = kb.get_candidates(18, length = 2)
  240. print(res)
  241. print()
  242. res = kb.get_candidates(7, length = 3)
  243. print(res)
  244. print()
  245. pseudo_label_list = list(range(10)) + ['+', '-', '*', '/']
  246. kb = hwf_KB(pseudo_label_list, max_len = 5)
  247. print('len(kb):', len(kb))
  248. print()
  249. X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
  250. Y = [2, 1, 1, 2, 2]
  251. kb = cls_KB(X, Y)
  252. print('len(kb):', len(kb))
  253. res = kb.get_candidates(2, 5)
  254. print(res)
  255. res = kb.get_candidates(2, 3)
  256. print(res)
  257. res = kb.get_candidates(None)
  258. print(res)
  259. print()
  260. X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
  261. Y = [2, 1, 1, 2, 1.5, 1.5]
  262. kb = reg_KB(X, Y)
  263. print('len(kb):', len(kb))
  264. res = kb.get_candidates(1.6)
  265. print(res)
  266. res = kb.get_candidates(1.6, length = 9)
  267. print(res)
  268. res = kb.get_candidates(None)
  269. print(res)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.