From 24dcf02b33a822b0bd3a8bf3379ec7c22a5c1382 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Sat, 19 Nov 2022 09:51:37 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 229 +++++++++++++++----------------------------------- 1 file changed, 70 insertions(+), 159 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index 12772e5..a448a83 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -44,12 +44,14 @@ class KBBase(ABC): def __len__(self): pass -class add_KB(KBBase): - def __init__(self, kb_max_len = -1): + +class ClsKB(KBBase): + def __init__(self, pseudo_label_list, kb_max_len = -1): super().__init__() - self.pseudo_label_list = list(range(10)) + self.pseudo_label_list = pseudo_label_list self.base = {} self.kb_max_len = kb_max_len + if(self.kb_max_len > 0): X = self.get_X(self.pseudo_label_list, self.kb_max_len) Y = self.get_Y(X, self.logic_forward) @@ -57,9 +59,6 @@ class add_KB(KBBase): for x, y in zip(X, Y): self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) - def logic_forward(self, nums): - return sum(nums) - def get_X(self, pseudo_label_list, max_len): res = [] assert(max_len >= 2) @@ -69,6 +68,9 @@ class add_KB(KBBase): def get_Y(self, X, logic_forward): return [logic_forward(nums) for nums in X] + + def logic_forward(self): + return None def get_candidates(self, key, length = None): if(self.base == {}): @@ -76,7 +78,7 @@ class add_KB(KBBase): if key is None: return self.get_all_candidates() - + length = self._length(length) if(self.kb_max_len < min(length)): return [] @@ -84,163 +86,72 @@ class add_KB(KBBase): def get_all_candidates(self): return sum([sum(v.values(), []) for v in self.base.values()], []) - + def _dict_len(self, dic): return sum(len(c) for c in dic.values()) def __len__(self): return sum(self._dict_len(v) for v in self.base.values()) - -class hwf_KB(KBBase): + + + +class add_KB(ClsKB): def __init__(self, kb_max_len = -1): - super().__init__() - self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] - self.base = {} - self.kb_max_len = kb_max_len - if(self.kb_max_len > 0): - X = self.get_X(self.pseudo_label_list, self.kb_max_len) - Y = self.get_Y(X, self.logic_forward) + self.pseudo_label_list = list(range(10)) + super().__init__(self.pseudo_label_list, kb_max_len) + + def logic_forward(self, nums): + return sum(nums) - for x, y in zip(X, Y): - self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) + def get_candidates(self, key, length = None): + return super().get_candidates(key, length) - def calculate(self, formula): - stack = [] - postfix = [] - priority = {'+': 0, '-': 0, - '*': 1, '/': 1} - skip_flag = 0 - for i in range(len(formula)): - if formula[i] == '-': - if i == 0: - formula.insert(0, 0) - for i in range(len(formula)): - if skip_flag: - skip_flag -= 1 - continue - char = formula[i] - if char in priority.keys(): - while stack and (priority[char] <= priority[stack[-1]]): - postfix.append(stack.pop()) - stack.append(char) - else: - num = int(char) - while (i + 1) < len(formula): - if formula[i + 1] not in priority.keys(): - skip_flag += 1 - num = num * 10 + int(formula[i + 1]) - i += 1 - else: - break - postfix.append(num) - while stack: - postfix.append(stack.pop()) + def get_all_candidates(self): + return super().get_all_candidates() + + def _dict_len(self, dic): + return super()._dict_len(dic) - for i in postfix: - if i in priority.keys(): - num2 = stack.pop() - num1 = stack.pop() - if i == '+': - res = num1 + num2 - elif i == '-': - res = num1 - num2 - elif i == '*': - res = num1 * num2 - elif i == '/': - if(num2 == 0): - return np.inf - res = num1 / num2 - stack.append(res) - else: - stack.append(i) - return round(stack[0], 2) + def __len__(self): + return super().__len__() +class hwf_KB(ClsKB): + def __init__(self, kb_max_len = -1): + self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] + super().__init__(self.pseudo_label_list, kb_max_len) + def valid_formula(self, formula): - symbol_idx_list = [] - for idx, c in enumerate(formula): - if(idx == 0 and c == '-'): - if(len(formula) == 1 or formula[1] in ['+', '-', '*', '/']): - return False - continue - if(c in ['+', '-', '*', '/']): - if(idx - 1 in symbol_idx_list): - return False - symbol_idx_list.append(idx) - if(0 in symbol_idx_list or len(formula) - 1 in symbol_idx_list): + if(len(formula) % 2 == 0): return False + for i in range(len(formula)): + if(i % 2 == 0 and formula[i] not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']): + return False + if(i % 2 != 0 and formula[i] not in ['+', '-', '*', '/']): + return False return True def logic_forward(self, formula): if(self.valid_formula(formula) == False): return np.inf - return self.calculate(list(formula)) + try: + return eval(''.join(formula)) + except ZeroDivisionError: + return np.inf - def get_X(self, pseudo_label_list, max_len): - res = [] - assert(max_len >= 2) - for len in range(2, max_len + 1): - res += list(product(pseudo_label_list, repeat = len)) - return res - - def get_Y(self, X, logic_forward): - return [logic_forward(formula) for formula in X] - def get_candidates(self, key, length = None): - if(self.base == {}): - return [] - - if key is None: - return self.get_all_candidates() - - length = self._length(length) - if(self.kb_max_len < min(length)): - return [] - return sum([self.base[l][key] for l in length], []) + return super().get_candidates(key, length) def get_all_candidates(self): - return sum([sum(v.values(), []) for v in self.base.values()], []) - - def _dict_len(self, dic): - return sum(len(c) for c in dic.values()) - - def __len__(self): - return sum(self._dict_len(v) for v in self.base.values()) - -class cls_KB(KBBase): - def __init__(self, X, Y = None): - super().__init__() - self.base = {} - - if X is None: - return - - if Y is None: - Y = [None] * len(X) - - for x, y in zip(X, Y): - self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) + return super().get_all_candidates() - def logic_forward(self): - return None - - def get_candidates(self, key, length = None): - if key is None: - return self.get_all_candidates() - - length = self._length(length) - - return sum([self.base[l][key] for l in length], []) - - def get_all_candidates(self): - return sum([sum(v.values(), []) for v in self.base.values()], []) - def _dict_len(self, dic): - return sum(len(c) for c in dic.values()) + return super()._dict_len(dic) def __len__(self): - return sum(self._dict_len(v) for v in self.base.values()) + return super().__len__() + -class reg_KB(KBBase): +class RegKB(KBBase): def __init__(self, X, Y = None): super().__init__() tmp_dict = {} @@ -323,26 +234,26 @@ if __name__ == "__main__": print() - X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] - Y = [2, 1, 1, 2, 2] - kb = cls_KB(X, Y) - print('len(kb):', len(kb)) - res = kb.get_candidates(2, 5) - print(res) - res = kb.get_candidates(2, 3) - print(res) - res = kb.get_candidates(None) - print(res) - print() + # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] + # Y = [2, 1, 1, 2, 2] + # kb = ClsKB(X, Y) + # print('len(kb):', len(kb)) + # res = kb.get_candidates(2, 5) + # print(res) + # res = kb.get_candidates(2, 3) + # print(res) + # res = kb.get_candidates(None) + # print(res) + # print() - X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] - Y = [2, 1, 1, 2, 1.5, 1.5] - kb = reg_KB(X, Y) - print('len(kb):', len(kb)) - res = kb.get_candidates(1.6) - print(res) - res = kb.get_candidates(1.6, length = 9) - print(res) - res = kb.get_candidates(None) - print(res) + # X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] + # Y = [2, 1, 1, 2, 1.5, 1.5] + # kb = RegKB(X, Y) + # print('len(kb):', len(kb)) + # res = kb.get_candidates(1.6) + # print(res) + # res = kb.get_candidates(1.6, length = 9) + # print(res) + # res = kb.get_candidates(None) + # print(res)