Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
6ffdf44c35
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 115 additions and 43 deletions
  1. +115
    -43
      abducer/kb.py

+ 115
- 43
abducer/kb.py View File

@@ -91,52 +91,120 @@ class add_KB(KBBase):
def __len__(self): def __len__(self):
return sum(self._dict_len(v) for v in self.base.values()) return sum(self._dict_len(v) for v in self.base.values())
# class hwf_KB(KBBase):
# def __init__(self, pseudo_label_list, kb_max_len = -1):
# super().__init__()
# self.pseudo_label_list = pseudo_label_list
# self.base = {}
# self.kb_max_len = kb_max_len
# if(self.kb_max_len > 0):
# X = self.get_X(self.pseudo_label_list, self.kb_max_len)
# Y = self.get_Y(X, self.logic_forward)
# for x, y in zip(X, Y):
# self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
class hwf_KB(KBBase):
def __init__(self, pseudo_label_list, kb_max_len = -1):
super().__init__()
self.pseudo_label_list = pseudo_label_list
self.base = {}
self.kb_max_len = kb_max_len
if(self.kb_max_len > 0):
X = self.get_X(self.pseudo_label_list, self.kb_max_len)
Y = self.get_Y(X, self.logic_forward)
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
# def logic_forward(self, nums):
# return sum(nums)
def calculate(self, formula):
stack = []
postfix = []
priority = {'+': 0, '-': 0,
'*': 1, '/': 1}
skip_flag = 0
for i in range(len(formula)):
if formula[i] == '-':
if i == 0:
formula.insert(0, 0)
for i in range(len(formula)):
if skip_flag:
skip_flag -= 1
continue
char = formula[i]
if char in priority.keys():
while stack and (priority[char] <= priority[stack[-1]]):
postfix.append(stack.pop())
stack.append(char)
else:
num = char
while (i + 1) < len(formula):
if formula[i + 1] not in priority.keys():
skip_flag += 1
num = num * 10 + formula[i + 1]
i += 1
else:
break
postfix.append(num)
while stack:
postfix.append(stack.pop())

for i in postfix:
if i in priority.keys():
num2 = stack.pop()
num1 = stack.pop()
if i == '+':
res = num1 + num2
elif i == '-':
res = num1 - num2
elif i == '*':
res = num1 * num2
elif i == '/':
if(num2 == 0):
return np.inf
res = num1 / num2
stack.append(res)
else:
stack.append(i)
return round(stack[0], 2)
# def get_X(self, pseudo_label_list, max_len):
# res = []
# assert(max_len >= 2)
# for len in range(2, max_len + 1):
# res += list(product(pseudo_label_list, repeat = len))
# return res

# def get_Y(self, X, logic_forward):
# return [logic_forward(nums) for nums in X]

# def get_candidates(self, key, length = None):
# if(self.base == {}):
# return []
def valid_formula(self, formula):
symbol_idx_list = []
first_minus_flag = 0
for idx, c in enumerate(formula):
if(idx == 0 and c == '-'):
first_minus_flag = 1
continue
if(c in ['+', '-', '*', '/']):
if(idx - 1 in symbol_idx_list or (idx == 1 and first_minus_flag == 1)):
return False
symbol_idx_list.append(idx)
if(0 in symbol_idx_list or len(formula) - 1 in symbol_idx_list):
return False
return True
def logic_forward(self, formula):
if(self.valid_formula(formula) == False):
return np.inf
return self.calculate(list(formula))
# if key is None:
# return self.get_all_candidates()
def get_X(self, pseudo_label_list, max_len):
res = []
assert(max_len >= 2)
for len in range(2, max_len + 1):
res += list(product(pseudo_label_list, repeat = len))
return res


# length = self._length(length)
# if(self.kb_max_len < min(length)):
# return []
# return sum([self.base[l][key] for l in length], [])
def get_Y(self, X, logic_forward):
return [logic_forward(formula) for formula in X]

def get_candidates(self, key, length = None):
if(self.base == {}):
return []
if key is None:
return self.get_all_candidates()

length = self._length(length)
if(self.kb_max_len < min(length)):
return []
return sum([self.base[l][key] for l in length], [])
# def get_all_candidates(self):
# return sum([sum(v.values(), []) for v in self.base.values()], [])
def get_all_candidates(self):
return sum([sum(v.values(), []) for v in self.base.values()], [])


# def _dict_len(self, dic):
# return sum(len(c) for c in dic.values())
def _dict_len(self, dic):
return sum(len(c) for c in dic.values())


# def __len__(self):
# return sum(self._dict_len(v) for v in self.base.values())
def __len__(self):
return sum(self._dict_len(v) for v in self.base.values())
class cls_KB(KBBase): class cls_KB(KBBase):
def __init__(self, X, Y = None): def __init__(self, X, Y = None):
@@ -248,10 +316,14 @@ if __name__ == "__main__":
print(res) print(res)
print() print()
# pseudo_label_list = list(range(10)) + ['+', '-', '*', '/']
# kb = hwf_KB(pseudo_label_list, max_len = 5)
# print('len(kb):', len(kb))
# print()
pseudo_label_list = list(range(10)) + ['+', '-', '*', '/']
kb = hwf_KB(pseudo_label_list, kb_max_len = 5)
print('len(kb):', len(kb))
res = kb.get_candidates(1, length = 3)
print(res)
res = kb.get_candidates(3.67, length = 5)
print(res)
print()
X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]


Loading…
Cancel
Save