From 6ffdf44c351c31ea0692688f0efdb0eafb9b92f1 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 16 Nov 2022 10:16:13 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 158 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 115 insertions(+), 43 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index 95e9692..6007a5b 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -91,52 +91,120 @@ class add_KB(KBBase): def __len__(self): return sum(self._dict_len(v) for v in self.base.values()) -# class hwf_KB(KBBase): -# def __init__(self, pseudo_label_list, kb_max_len = -1): -# super().__init__() -# self.pseudo_label_list = pseudo_label_list -# self.base = {} -# self.kb_max_len = kb_max_len -# if(self.kb_max_len > 0): -# X = self.get_X(self.pseudo_label_list, self.kb_max_len) -# Y = self.get_Y(X, self.logic_forward) - -# for x, y in zip(X, Y): -# self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) +class hwf_KB(KBBase): + def __init__(self, pseudo_label_list, kb_max_len = -1): + super().__init__() + self.pseudo_label_list = pseudo_label_list + self.base = {} + self.kb_max_len = kb_max_len + if(self.kb_max_len > 0): + X = self.get_X(self.pseudo_label_list, self.kb_max_len) + Y = self.get_Y(X, self.logic_forward) + + for x, y in zip(X, Y): + self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) -# def logic_forward(self, nums): -# return sum(nums) + def calculate(self, formula): + stack = [] + postfix = [] + priority = {'+': 0, '-': 0, + '*': 1, '/': 1} + skip_flag = 0 + for i in range(len(formula)): + if formula[i] == '-': + if i == 0: + formula.insert(0, 0) + for i in range(len(formula)): + if skip_flag: + skip_flag -= 1 + continue + char = formula[i] + if char in priority.keys(): + while stack and (priority[char] <= priority[stack[-1]]): + postfix.append(stack.pop()) + stack.append(char) + else: + num = char + while (i + 1) < len(formula): + if formula[i + 1] not in priority.keys(): + skip_flag += 1 + num = num * 10 + formula[i + 1] + i += 1 + else: + break + postfix.append(num) + while stack: + postfix.append(stack.pop()) + + for i in postfix: + if i in priority.keys(): + num2 = stack.pop() + num1 = stack.pop() + if i == '+': + res = num1 + num2 + elif i == '-': + res = num1 - num2 + elif i == '*': + res = num1 * num2 + elif i == '/': + if(num2 == 0): + return np.inf + res = num1 / num2 + stack.append(res) + else: + stack.append(i) + return round(stack[0], 2) -# def get_X(self, pseudo_label_list, max_len): -# res = [] -# assert(max_len >= 2) -# for len in range(2, max_len + 1): -# res += list(product(pseudo_label_list, repeat = len)) -# return res - -# def get_Y(self, X, logic_forward): -# return [logic_forward(nums) for nums in X] - -# def get_candidates(self, key, length = None): -# if(self.base == {}): -# return [] + def valid_formula(self, formula): + symbol_idx_list = [] + first_minus_flag = 0 + for idx, c in enumerate(formula): + if(idx == 0 and c == '-'): + first_minus_flag = 1 + continue + if(c in ['+', '-', '*', '/']): + if(idx - 1 in symbol_idx_list or (idx == 1 and first_minus_flag == 1)): + return False + symbol_idx_list.append(idx) + if(0 in symbol_idx_list or len(formula) - 1 in symbol_idx_list): + return False + return True + + def logic_forward(self, formula): + if(self.valid_formula(formula) == False): + return np.inf + return self.calculate(list(formula)) -# if key is None: -# return self.get_all_candidates() + def get_X(self, pseudo_label_list, max_len): + res = [] + assert(max_len >= 2) + for len in range(2, max_len + 1): + res += list(product(pseudo_label_list, repeat = len)) + return res -# length = self._length(length) -# if(self.kb_max_len < min(length)): -# return [] -# return sum([self.base[l][key] for l in length], []) + def get_Y(self, X, logic_forward): + return [logic_forward(formula) for formula in X] + + def get_candidates(self, key, length = None): + if(self.base == {}): + return [] + + if key is None: + return self.get_all_candidates() + + length = self._length(length) + if(self.kb_max_len < min(length)): + return [] + return sum([self.base[l][key] for l in length], []) -# def get_all_candidates(self): -# return sum([sum(v.values(), []) for v in self.base.values()], []) + def get_all_candidates(self): + return sum([sum(v.values(), []) for v in self.base.values()], []) -# def _dict_len(self, dic): -# return sum(len(c) for c in dic.values()) + def _dict_len(self, dic): + return sum(len(c) for c in dic.values()) -# def __len__(self): -# return sum(self._dict_len(v) for v in self.base.values()) + def __len__(self): + return sum(self._dict_len(v) for v in self.base.values()) class cls_KB(KBBase): def __init__(self, X, Y = None): @@ -248,10 +316,14 @@ if __name__ == "__main__": print(res) print() - # pseudo_label_list = list(range(10)) + ['+', '-', '*', '/'] - # kb = hwf_KB(pseudo_label_list, max_len = 5) - # print('len(kb):', len(kb)) - # print() + pseudo_label_list = list(range(10)) + ['+', '-', '*', '/'] + kb = hwf_KB(pseudo_label_list, kb_max_len = 5) + print('len(kb):', len(kb)) + res = kb.get_candidates(1, length = 3) + print(res) + res = kb.get_candidates(3.67, length = 5) + print(res) + print() X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]