|
|
@@ -44,12 +44,14 @@ class KBBase(ABC): |
|
|
|
def __len__(self): |
|
|
|
pass |
|
|
|
|
|
|
|
class add_KB(KBBase): |
|
|
|
def __init__(self, kb_max_len = -1): |
|
|
|
|
|
|
|
class ClsKB(KBBase): |
|
|
|
def __init__(self, pseudo_label_list, kb_max_len = -1): |
|
|
|
super().__init__() |
|
|
|
self.pseudo_label_list = list(range(10)) |
|
|
|
self.pseudo_label_list = pseudo_label_list |
|
|
|
self.base = {} |
|
|
|
self.kb_max_len = kb_max_len |
|
|
|
|
|
|
|
if(self.kb_max_len > 0): |
|
|
|
X = self.get_X(self.pseudo_label_list, self.kb_max_len) |
|
|
|
Y = self.get_Y(X, self.logic_forward) |
|
|
@@ -57,9 +59,6 @@ class add_KB(KBBase): |
|
|
|
for x, y in zip(X, Y): |
|
|
|
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) |
|
|
|
|
|
|
|
def logic_forward(self, nums): |
|
|
|
return sum(nums) |
|
|
|
|
|
|
|
def get_X(self, pseudo_label_list, max_len): |
|
|
|
res = [] |
|
|
|
assert(max_len >= 2) |
|
|
@@ -69,6 +68,9 @@ class add_KB(KBBase): |
|
|
|
|
|
|
|
def get_Y(self, X, logic_forward): |
|
|
|
return [logic_forward(nums) for nums in X] |
|
|
|
|
|
|
|
def logic_forward(self): |
|
|
|
return None |
|
|
|
|
|
|
|
def get_candidates(self, key, length = None): |
|
|
|
if(self.base == {}): |
|
|
@@ -76,7 +78,7 @@ class add_KB(KBBase): |
|
|
|
|
|
|
|
if key is None: |
|
|
|
return self.get_all_candidates() |
|
|
|
|
|
|
|
|
|
|
|
length = self._length(length) |
|
|
|
if(self.kb_max_len < min(length)): |
|
|
|
return [] |
|
|
@@ -84,163 +86,72 @@ class add_KB(KBBase): |
|
|
|
|
|
|
|
def get_all_candidates(self): |
|
|
|
return sum([sum(v.values(), []) for v in self.base.values()], []) |
|
|
|
|
|
|
|
|
|
|
|
def _dict_len(self, dic): |
|
|
|
return sum(len(c) for c in dic.values()) |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return sum(self._dict_len(v) for v in self.base.values()) |
|
|
|
|
|
|
|
class hwf_KB(KBBase): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class add_KB(ClsKB): |
|
|
|
def __init__(self, kb_max_len = -1): |
|
|
|
super().__init__() |
|
|
|
self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] |
|
|
|
self.base = {} |
|
|
|
self.kb_max_len = kb_max_len |
|
|
|
if(self.kb_max_len > 0): |
|
|
|
X = self.get_X(self.pseudo_label_list, self.kb_max_len) |
|
|
|
Y = self.get_Y(X, self.logic_forward) |
|
|
|
self.pseudo_label_list = list(range(10)) |
|
|
|
super().__init__(self.pseudo_label_list, kb_max_len) |
|
|
|
|
|
|
|
def logic_forward(self, nums): |
|
|
|
return sum(nums) |
|
|
|
|
|
|
|
for x, y in zip(X, Y): |
|
|
|
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) |
|
|
|
def get_candidates(self, key, length = None): |
|
|
|
return super().get_candidates(key, length) |
|
|
|
|
|
|
|
def calculate(self, formula): |
|
|
|
stack = [] |
|
|
|
postfix = [] |
|
|
|
priority = {'+': 0, '-': 0, |
|
|
|
'*': 1, '/': 1} |
|
|
|
skip_flag = 0 |
|
|
|
for i in range(len(formula)): |
|
|
|
if formula[i] == '-': |
|
|
|
if i == 0: |
|
|
|
formula.insert(0, 0) |
|
|
|
for i in range(len(formula)): |
|
|
|
if skip_flag: |
|
|
|
skip_flag -= 1 |
|
|
|
continue |
|
|
|
char = formula[i] |
|
|
|
if char in priority.keys(): |
|
|
|
while stack and (priority[char] <= priority[stack[-1]]): |
|
|
|
postfix.append(stack.pop()) |
|
|
|
stack.append(char) |
|
|
|
else: |
|
|
|
num = int(char) |
|
|
|
while (i + 1) < len(formula): |
|
|
|
if formula[i + 1] not in priority.keys(): |
|
|
|
skip_flag += 1 |
|
|
|
num = num * 10 + int(formula[i + 1]) |
|
|
|
i += 1 |
|
|
|
else: |
|
|
|
break |
|
|
|
postfix.append(num) |
|
|
|
while stack: |
|
|
|
postfix.append(stack.pop()) |
|
|
|
def get_all_candidates(self): |
|
|
|
return super().get_all_candidates() |
|
|
|
|
|
|
|
def _dict_len(self, dic): |
|
|
|
return super()._dict_len(dic) |
|
|
|
|
|
|
|
for i in postfix: |
|
|
|
if i in priority.keys(): |
|
|
|
num2 = stack.pop() |
|
|
|
num1 = stack.pop() |
|
|
|
if i == '+': |
|
|
|
res = num1 + num2 |
|
|
|
elif i == '-': |
|
|
|
res = num1 - num2 |
|
|
|
elif i == '*': |
|
|
|
res = num1 * num2 |
|
|
|
elif i == '/': |
|
|
|
if(num2 == 0): |
|
|
|
return np.inf |
|
|
|
res = num1 / num2 |
|
|
|
stack.append(res) |
|
|
|
else: |
|
|
|
stack.append(i) |
|
|
|
return round(stack[0], 2) |
|
|
|
def __len__(self): |
|
|
|
return super().__len__() |
|
|
|
|
|
|
|
class hwf_KB(ClsKB): |
|
|
|
def __init__(self, kb_max_len = -1): |
|
|
|
self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] |
|
|
|
super().__init__(self.pseudo_label_list, kb_max_len) |
|
|
|
|
|
|
|
def valid_formula(self, formula): |
|
|
|
symbol_idx_list = [] |
|
|
|
for idx, c in enumerate(formula): |
|
|
|
if(idx == 0 and c == '-'): |
|
|
|
if(len(formula) == 1 or formula[1] in ['+', '-', '*', '/']): |
|
|
|
return False |
|
|
|
continue |
|
|
|
if(c in ['+', '-', '*', '/']): |
|
|
|
if(idx - 1 in symbol_idx_list): |
|
|
|
return False |
|
|
|
symbol_idx_list.append(idx) |
|
|
|
if(0 in symbol_idx_list or len(formula) - 1 in symbol_idx_list): |
|
|
|
if(len(formula) % 2 == 0): |
|
|
|
return False |
|
|
|
for i in range(len(formula)): |
|
|
|
if(i % 2 == 0 and formula[i] not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']): |
|
|
|
return False |
|
|
|
if(i % 2 != 0 and formula[i] not in ['+', '-', '*', '/']): |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
def logic_forward(self, formula): |
|
|
|
if(self.valid_formula(formula) == False): |
|
|
|
return np.inf |
|
|
|
return self.calculate(list(formula)) |
|
|
|
try: |
|
|
|
return eval(''.join(formula)) |
|
|
|
except ZeroDivisionError: |
|
|
|
return np.inf |
|
|
|
|
|
|
|
def get_X(self, pseudo_label_list, max_len): |
|
|
|
res = [] |
|
|
|
assert(max_len >= 2) |
|
|
|
for len in range(2, max_len + 1): |
|
|
|
res += list(product(pseudo_label_list, repeat = len)) |
|
|
|
return res |
|
|
|
|
|
|
|
def get_Y(self, X, logic_forward): |
|
|
|
return [logic_forward(formula) for formula in X] |
|
|
|
|
|
|
|
def get_candidates(self, key, length = None): |
|
|
|
if(self.base == {}): |
|
|
|
return [] |
|
|
|
|
|
|
|
if key is None: |
|
|
|
return self.get_all_candidates() |
|
|
|
|
|
|
|
length = self._length(length) |
|
|
|
if(self.kb_max_len < min(length)): |
|
|
|
return [] |
|
|
|
return sum([self.base[l][key] for l in length], []) |
|
|
|
return super().get_candidates(key, length) |
|
|
|
|
|
|
|
def get_all_candidates(self): |
|
|
|
return sum([sum(v.values(), []) for v in self.base.values()], []) |
|
|
|
|
|
|
|
def _dict_len(self, dic): |
|
|
|
return sum(len(c) for c in dic.values()) |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return sum(self._dict_len(v) for v in self.base.values()) |
|
|
|
|
|
|
|
class cls_KB(KBBase): |
|
|
|
def __init__(self, X, Y = None): |
|
|
|
super().__init__() |
|
|
|
self.base = {} |
|
|
|
|
|
|
|
if X is None: |
|
|
|
return |
|
|
|
|
|
|
|
if Y is None: |
|
|
|
Y = [None] * len(X) |
|
|
|
|
|
|
|
for x, y in zip(X, Y): |
|
|
|
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) |
|
|
|
return super().get_all_candidates() |
|
|
|
|
|
|
|
def logic_forward(self): |
|
|
|
return None |
|
|
|
|
|
|
|
def get_candidates(self, key, length = None): |
|
|
|
if key is None: |
|
|
|
return self.get_all_candidates() |
|
|
|
|
|
|
|
length = self._length(length) |
|
|
|
|
|
|
|
return sum([self.base[l][key] for l in length], []) |
|
|
|
|
|
|
|
def get_all_candidates(self): |
|
|
|
return sum([sum(v.values(), []) for v in self.base.values()], []) |
|
|
|
|
|
|
|
def _dict_len(self, dic): |
|
|
|
return sum(len(c) for c in dic.values()) |
|
|
|
return super()._dict_len(dic) |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return sum(self._dict_len(v) for v in self.base.values()) |
|
|
|
return super().__len__() |
|
|
|
|
|
|
|
|
|
|
|
class reg_KB(KBBase): |
|
|
|
class RegKB(KBBase): |
|
|
|
def __init__(self, X, Y = None): |
|
|
|
super().__init__() |
|
|
|
tmp_dict = {} |
|
|
@@ -323,26 +234,26 @@ if __name__ == "__main__": |
|
|
|
print() |
|
|
|
|
|
|
|
|
|
|
|
X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] |
|
|
|
Y = [2, 1, 1, 2, 2] |
|
|
|
kb = cls_KB(X, Y) |
|
|
|
print('len(kb):', len(kb)) |
|
|
|
res = kb.get_candidates(2, 5) |
|
|
|
print(res) |
|
|
|
res = kb.get_candidates(2, 3) |
|
|
|
print(res) |
|
|
|
res = kb.get_candidates(None) |
|
|
|
print(res) |
|
|
|
print() |
|
|
|
# X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] |
|
|
|
# Y = [2, 1, 1, 2, 2] |
|
|
|
# kb = ClsKB(X, Y) |
|
|
|
# print('len(kb):', len(kb)) |
|
|
|
# res = kb.get_candidates(2, 5) |
|
|
|
# print(res) |
|
|
|
# res = kb.get_candidates(2, 3) |
|
|
|
# print(res) |
|
|
|
# res = kb.get_candidates(None) |
|
|
|
# print(res) |
|
|
|
# print() |
|
|
|
|
|
|
|
X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] |
|
|
|
Y = [2, 1, 1, 2, 1.5, 1.5] |
|
|
|
kb = reg_KB(X, Y) |
|
|
|
print('len(kb):', len(kb)) |
|
|
|
res = kb.get_candidates(1.6) |
|
|
|
print(res) |
|
|
|
res = kb.get_candidates(1.6, length = 9) |
|
|
|
print(res) |
|
|
|
res = kb.get_candidates(None) |
|
|
|
print(res) |
|
|
|
# X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] |
|
|
|
# Y = [2, 1, 1, 2, 1.5, 1.5] |
|
|
|
# kb = RegKB(X, Y) |
|
|
|
# print('len(kb):', len(kb)) |
|
|
|
# res = kb.get_candidates(1.6) |
|
|
|
# print(res) |
|
|
|
# res = kb.get_candidates(1.6, length = 9) |
|
|
|
# print(res) |
|
|
|
# res = kb.get_candidates(None) |
|
|
|
# print(res) |
|
|
|
|