Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
24dcf02b33
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 70 additions and 159 deletions
  1. +70
    -159
      abducer/kb.py

+ 70
- 159
abducer/kb.py View File

@@ -44,12 +44,14 @@ class KBBase(ABC):
def __len__(self): def __len__(self):
pass pass


class add_KB(KBBase):
def __init__(self, kb_max_len = -1):

class ClsKB(KBBase):
def __init__(self, pseudo_label_list, kb_max_len = -1):
super().__init__() super().__init__()
self.pseudo_label_list = list(range(10))
self.pseudo_label_list = pseudo_label_list
self.base = {} self.base = {}
self.kb_max_len = kb_max_len self.kb_max_len = kb_max_len
if(self.kb_max_len > 0): if(self.kb_max_len > 0):
X = self.get_X(self.pseudo_label_list, self.kb_max_len) X = self.get_X(self.pseudo_label_list, self.kb_max_len)
Y = self.get_Y(X, self.logic_forward) Y = self.get_Y(X, self.logic_forward)
@@ -57,9 +59,6 @@ class add_KB(KBBase):
for x, y in zip(X, Y): for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
def logic_forward(self, nums):
return sum(nums)
def get_X(self, pseudo_label_list, max_len): def get_X(self, pseudo_label_list, max_len):
res = [] res = []
assert(max_len >= 2) assert(max_len >= 2)
@@ -69,6 +68,9 @@ class add_KB(KBBase):


def get_Y(self, X, logic_forward): def get_Y(self, X, logic_forward):
return [logic_forward(nums) for nums in X] return [logic_forward(nums) for nums in X]
def logic_forward(self):
return None


def get_candidates(self, key, length = None): def get_candidates(self, key, length = None):
if(self.base == {}): if(self.base == {}):
@@ -76,7 +78,7 @@ class add_KB(KBBase):
if key is None: if key is None:
return self.get_all_candidates() return self.get_all_candidates()
length = self._length(length) length = self._length(length)
if(self.kb_max_len < min(length)): if(self.kb_max_len < min(length)):
return [] return []
@@ -84,163 +86,72 @@ class add_KB(KBBase):
def get_all_candidates(self): def get_all_candidates(self):
return sum([sum(v.values(), []) for v in self.base.values()], []) return sum([sum(v.values(), []) for v in self.base.values()], [])
def _dict_len(self, dic): def _dict_len(self, dic):
return sum(len(c) for c in dic.values()) return sum(len(c) for c in dic.values())


def __len__(self): def __len__(self):
return sum(self._dict_len(v) for v in self.base.values()) return sum(self._dict_len(v) for v in self.base.values())
class hwf_KB(KBBase):



class add_KB(ClsKB):
def __init__(self, kb_max_len = -1): def __init__(self, kb_max_len = -1):
super().__init__()
self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/']
self.base = {}
self.kb_max_len = kb_max_len
if(self.kb_max_len > 0):
X = self.get_X(self.pseudo_label_list, self.kb_max_len)
Y = self.get_Y(X, self.logic_forward)
self.pseudo_label_list = list(range(10))
super().__init__(self.pseudo_label_list, kb_max_len)
def logic_forward(self, nums):
return sum(nums)


for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
def get_candidates(self, key, length = None):
return super().get_candidates(key, length)
def calculate(self, formula):
stack = []
postfix = []
priority = {'+': 0, '-': 0,
'*': 1, '/': 1}
skip_flag = 0
for i in range(len(formula)):
if formula[i] == '-':
if i == 0:
formula.insert(0, 0)
for i in range(len(formula)):
if skip_flag:
skip_flag -= 1
continue
char = formula[i]
if char in priority.keys():
while stack and (priority[char] <= priority[stack[-1]]):
postfix.append(stack.pop())
stack.append(char)
else:
num = int(char)
while (i + 1) < len(formula):
if formula[i + 1] not in priority.keys():
skip_flag += 1
num = num * 10 + int(formula[i + 1])
i += 1
else:
break
postfix.append(num)
while stack:
postfix.append(stack.pop())
def get_all_candidates(self):
return super().get_all_candidates()
def _dict_len(self, dic):
return super()._dict_len(dic)


for i in postfix:
if i in priority.keys():
num2 = stack.pop()
num1 = stack.pop()
if i == '+':
res = num1 + num2
elif i == '-':
res = num1 - num2
elif i == '*':
res = num1 * num2
elif i == '/':
if(num2 == 0):
return np.inf
res = num1 / num2
stack.append(res)
else:
stack.append(i)
return round(stack[0], 2)
def __len__(self):
return super().__len__()
class hwf_KB(ClsKB):
def __init__(self, kb_max_len = -1):
self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/']
super().__init__(self.pseudo_label_list, kb_max_len)
def valid_formula(self, formula): def valid_formula(self, formula):
symbol_idx_list = []
for idx, c in enumerate(formula):
if(idx == 0 and c == '-'):
if(len(formula) == 1 or formula[1] in ['+', '-', '*', '/']):
return False
continue
if(c in ['+', '-', '*', '/']):
if(idx - 1 in symbol_idx_list):
return False
symbol_idx_list.append(idx)
if(0 in symbol_idx_list or len(formula) - 1 in symbol_idx_list):
if(len(formula) % 2 == 0):
return False return False
for i in range(len(formula)):
if(i % 2 == 0 and formula[i] not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']):
return False
if(i % 2 != 0 and formula[i] not in ['+', '-', '*', '/']):
return False
return True return True
def logic_forward(self, formula): def logic_forward(self, formula):
if(self.valid_formula(formula) == False): if(self.valid_formula(formula) == False):
return np.inf return np.inf
return self.calculate(list(formula))
try:
return eval(''.join(formula))
except ZeroDivisionError:
return np.inf
def get_X(self, pseudo_label_list, max_len):
res = []
assert(max_len >= 2)
for len in range(2, max_len + 1):
res += list(product(pseudo_label_list, repeat = len))
return res

def get_Y(self, X, logic_forward):
return [logic_forward(formula) for formula in X]

def get_candidates(self, key, length = None): def get_candidates(self, key, length = None):
if(self.base == {}):
return []
if key is None:
return self.get_all_candidates()

length = self._length(length)
if(self.kb_max_len < min(length)):
return []
return sum([self.base[l][key] for l in length], [])
return super().get_candidates(key, length)
def get_all_candidates(self): def get_all_candidates(self):
return sum([sum(v.values(), []) for v in self.base.values()], [])

def _dict_len(self, dic):
return sum(len(c) for c in dic.values())

def __len__(self):
return sum(self._dict_len(v) for v in self.base.values())
class cls_KB(KBBase):
def __init__(self, X, Y = None):
super().__init__()
self.base = {}
if X is None:
return

if Y is None:
Y = [None] * len(X)

for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
return super().get_all_candidates()
def logic_forward(self):
return None

def get_candidates(self, key, length = None):
if key is None:
return self.get_all_candidates()

length = self._length(length)

return sum([self.base[l][key] for l in length], [])
def get_all_candidates(self):
return sum([sum(v.values(), []) for v in self.base.values()], [])

def _dict_len(self, dic): def _dict_len(self, dic):
return sum(len(c) for c in dic.values())
return super()._dict_len(dic)


def __len__(self): def __len__(self):
return sum(self._dict_len(v) for v in self.base.values())
return super().__len__()


class reg_KB(KBBase):
class RegKB(KBBase):
def __init__(self, X, Y = None): def __init__(self, X, Y = None):
super().__init__() super().__init__()
tmp_dict = {} tmp_dict = {}
@@ -323,26 +234,26 @@ if __name__ == "__main__":
print() print()
X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
Y = [2, 1, 1, 2, 2]
kb = cls_KB(X, Y)
print('len(kb):', len(kb))
res = kb.get_candidates(2, 5)
print(res)
res = kb.get_candidates(2, 3)
print(res)
res = kb.get_candidates(None)
print(res)
print()
# X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
# Y = [2, 1, 1, 2, 2]
# kb = ClsKB(X, Y)
# print('len(kb):', len(kb))
# res = kb.get_candidates(2, 5)
# print(res)
# res = kb.get_candidates(2, 3)
# print(res)
# res = kb.get_candidates(None)
# print(res)
# print()
X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
Y = [2, 1, 1, 2, 1.5, 1.5]
kb = reg_KB(X, Y)
print('len(kb):', len(kb))
res = kb.get_candidates(1.6)
print(res)
res = kb.get_candidates(1.6, length = 9)
print(res)
res = kb.get_candidates(None)
print(res)
# X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
# Y = [2, 1, 1, 2, 1.5, 1.5]
# kb = RegKB(X, Y)
# print('len(kb):', len(kb))
# res = kb.get_candidates(1.6)
# print(res)
# res = kb.get_candidates(1.6, length = 9)
# print(res)
# res = kb.get_candidates(None)
# print(res)



Loading…
Cancel
Save