Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
523d741c8d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 5 deletions
  1. +11
    -5
      abducer/kb.py

+ 11
- 5
abducer/kb.py View File

@@ -115,7 +115,7 @@ class add_KB(KBBase):
def __len__(self):
return sum(self._dict_len(v) for v in self.base.values())

class ClsKB(KBBase):
class cls_KB(KBBase):
def __init__(self, X, Y = None):
super().__init__()
self.base = {}
@@ -128,6 +128,9 @@ class ClsKB(KBBase):

for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
def logic_forward(self):
return None

def get_candidates(self, key, length = None):
if key is None:
@@ -146,7 +149,7 @@ class ClsKB(KBBase):
def __len__(self):
return sum(self._dict_len(v) for v in self.base.values())

class RegKB(KBBase):
class reg_KB(KBBase):
def __init__(self, X, Y = None):
super().__init__()
tmp_dict = {}
@@ -159,7 +162,10 @@ class RegKB(KBBase):
X = [x for y, x in data]
Y = [y for y, x in data]
self.base[l] = (X, Y)

def logic_forward(self):
return None
def get_candidates(self, key, length = None):
if key is None:
return self.get_all_candidates()
@@ -207,7 +213,7 @@ if __name__ == "__main__":
X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
Y = [2, 1, 1, 2, 2]
kb = ClsKB(X, Y)
kb = cls_KB(X, Y)
print(len(kb))
res = kb.get_candidates(2, 5)
print(res)
@@ -219,7 +225,7 @@ if __name__ == "__main__":
X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
Y = [2, 1, 1, 2, 1.5, 1.5]
kb = RegKB(X, Y)
kb = reg_KB(X, Y)
print(len(kb))
res = kb.get_candidates(1.6)
print(res)


Loading…
Cancel
Save