|
|
@@ -119,7 +119,7 @@ class add_KB(ClsKB): |
|
|
|
|
|
|
|
class hwf_KB(ClsKB): |
|
|
|
def __init__(self, GKB_flag = False, \ |
|
|
|
pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \ |
|
|
|
pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \ |
|
|
|
len_list = [1, 3, 5, 7]): |
|
|
|
super().__init__(GKB_flag, pseudo_label_list, len_list) |
|
|
|
|
|
|
@@ -127,7 +127,7 @@ class hwf_KB(ClsKB): |
|
|
|
if len(formula) % 2 == 0: |
|
|
|
return False |
|
|
|
for i in range(len(formula)): |
|
|
|
if i % 2 == 0 and formula[i] not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']: |
|
|
|
if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: |
|
|
|
return False |
|
|
|
if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: |
|
|
|
return False |
|
|
@@ -136,12 +136,9 @@ class hwf_KB(ClsKB): |
|
|
|
def logic_forward(self, formula): |
|
|
|
if not self.valid_candidate(formula): |
|
|
|
return np.inf |
|
|
|
try: |
|
|
|
mapping = {'0':'0', '1':'1', '2':'2', '3':'3', '4':'4', '5':'5', '6':'6', '7':'7', '8':'8', '9':'9', '+':'+', '-':'-', 'times':'*', 'div':'/'} |
|
|
|
formula = [mapping[f] for f in formula] |
|
|
|
return round(eval(''.join(formula)), 2) |
|
|
|
except ZeroDivisionError: |
|
|
|
return np.inf |
|
|
|
mapping = {'1':'1', '2':'2', '3':'3', '4':'4', '5':'5', '6':'6', '7':'7', '8':'8', '9':'9', '+':'+', '-':'-', 'times':'*', 'div':'/'} |
|
|
|
formula = [mapping[f] for f in formula] |
|
|
|
return round(eval(''.join(formula)), 2) |
|
|
|
|
|
|
|
|
|
|
|
class RegKB(KBBase): |
|
|
|