You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

hwf_kb.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import bisect
  2. from collections import defaultdict
  3. from itertools import product
  4. from multiprocessing import Pool
  5. from typing import Any, Hashable, List
  6. import numpy as np
  7. from abl.reasoning import GroundKB
  8. from abl.structures import ListData
  9. from abl.utils import hamming_dist
  10. class HWF_KB(GroundKB):
  11. def __init__(
  12. self,
  13. pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "times", "div"],
  14. GKB_len_list=[1, 3, 5, 7],
  15. max_err=1e-10,
  16. ):
  17. self.GKB_len_list = GKB_len_list
  18. self.max_err = max_err
  19. self.label2evaluable = {str(i): str(i) for i in range(1, 10)}
  20. self.label2evaluable.update({"+": "+", "-": "-", "times": "*", "div": "/"})
  21. super().__init__(pseudo_label_list)
  22. def logic_forward(self, data_sample: ListData):
  23. if not self._valid_candidate(data_sample):
  24. return None
  25. formula = data_sample["pred_pseudo_label"][0]
  26. formula = [self.label2evaluable[f] for f in formula]
  27. data_sample["Y"] = [eval("".join(formula))]
  28. return data_sample["Y"][0]
  29. def check_equal(self, data_sample: ListData, y: Any):
  30. if not self._valid_candidate(data_sample):
  31. return False
  32. formula = data_sample["pred_pseudo_label"][0]
  33. formula = [self.label2evaluable[f] for f in formula]
  34. return abs(eval("".join(formula)) - y) < self.max_err
  35. def construct_base(self) -> dict:
  36. X, Y = [], []
  37. for length in self.GKB_len_list:
  38. arg_list = []
  39. for pre_x in self.pseudo_label_list:
  40. post_x_it = product(self.pseudo_label_list, repeat=length - 1)
  41. arg_list.append((pre_x, post_x_it))
  42. with Pool(processes=len(arg_list)) as pool:
  43. ret_list = pool.map(self._get_XY_list, arg_list)
  44. for XY_list in ret_list:
  45. if len(XY_list) == 0:
  46. continue
  47. part_X, part_Y = zip(*XY_list)
  48. X.extend(part_X)
  49. Y.extend(part_Y)
  50. if Y and isinstance(Y[0], (int, float)):
  51. X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
  52. GKB = {}
  53. for x, y in zip(X, Y):
  54. GKB.setdefault(len(x), defaultdict(list))[y].append(x)
  55. return GKB
  56. @staticmethod
  57. def get_key(data_sample: ListData) -> Hashable:
  58. return (data_sample["symbol_num"], data_sample["Y"][0])
  59. def key2candidates(self, key: Hashable) -> List[List[Any]]:
  60. equation_len, y = key
  61. if self.max_err == 0:
  62. return self.GKB[equation_len][y]
  63. else:
  64. potential_candidates = self.GKB[equation_len]
  65. key_list = list(potential_candidates.keys())
  66. key_idx = bisect.bisect_left(key_list, y)
  67. all_candidates = []
  68. for idx in range(key_idx - 1, -1, -1):
  69. k = key_list[idx]
  70. if abs(k - y) <= self.max_err:
  71. all_candidates.extend(potential_candidates[k])
  72. else:
  73. break
  74. for idx in range(key_idx, len(key_list)):
  75. k = key_list[idx]
  76. if abs(k - y) <= self.max_err:
  77. all_candidates.extend(potential_candidates[k])
  78. else:
  79. break
  80. return all_candidates
  81. def filter_candidates(
  82. self,
  83. data_sample: ListData,
  84. candidates: List[List[Any]],
  85. max_revision_num: int,
  86. require_more_revision: int = 0,
  87. ) -> List[List[Any]]:
  88. cost_list = hamming_dist(data_sample["pred_pseudo_label"][0], candidates)
  89. min_revision_num = np.min(cost_list)
  90. revision_num = min(max_revision_num, min_revision_num + require_more_revision)
  91. idxs = np.where(cost_list <= revision_num)[0]
  92. filtered_candidates = [candidates[idx] for idx in idxs]
  93. return filtered_candidates
  94. # TODO: change return value to List[ListData]
  95. def _get_XY_list(self, args):
  96. pre_x, post_x_it = args[0], args[1]
  97. XY_list = []
  98. for post_x in post_x_it:
  99. x = (pre_x,) + post_x
  100. data_sample = ListData(pred_pseudo_label=[x])
  101. y = self.logic_forward(data_sample)
  102. if y is not None:
  103. XY_list.append((x, y))
  104. return XY_list
  105. @staticmethod
  106. def _valid_candidate(data_sample):
  107. formula = data_sample["pred_pseudo_label"][0]
  108. if len(formula) % 2 == 0:
  109. return False
  110. for i in range(len(formula)):
  111. if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
  112. return False
  113. if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
  114. return False
  115. return True

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.