|
|
@@ -275,12 +275,11 @@ class ReasonerBase: |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
from kb import KBBase, GroundKB, PrologKB |
|
|
|
|
|
|
|
prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], |
|
|
|
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] |
|
|
|
from abl.structures import ListData |
|
|
|
|
|
|
|
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], |
|
|
|
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] |
|
|
|
################################ |
|
|
|
# Test for MNIST Add reasoning # |
|
|
|
################################ |
|
|
|
|
|
|
|
class AddKB(KBBase): |
|
|
|
def __init__(self, pseudo_label_list=list(range(10)), |
|
|
@@ -290,38 +289,54 @@ if __name__ == "__main__": |
|
|
|
def logic_forward(self, nums): |
|
|
|
return sum(nums) |
|
|
|
|
|
|
|
class AddGroundKB(GroundKB): |
|
|
|
class AddGroundKB(GroundKB, AddKB): |
|
|
|
def __init__(self, pseudo_label_list=list(range(10)), |
|
|
|
GKB_len_list=[2]): |
|
|
|
super().__init__(pseudo_label_list, GKB_len_list) |
|
|
|
|
|
|
|
|
|
|
|
def logic_forward(self, nums): |
|
|
|
return sum(nums) |
|
|
|
|
|
|
|
|
|
|
|
def logic_forward(self, nums): |
|
|
|
return sum(nums) |
|
|
|
|
|
|
|
def test_add(reasoner): |
|
|
|
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) |
|
|
|
# favor 1 in first one |
|
|
|
prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], |
|
|
|
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] |
|
|
|
|
|
|
|
# favor 7 in first one |
|
|
|
prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], |
|
|
|
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] |
|
|
|
|
|
|
|
data_samples_add = ListData() |
|
|
|
data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] |
|
|
|
data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] |
|
|
|
data_samples_add.Y = [8, 8, 17, 10] |
|
|
|
|
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) |
|
|
|
print(res) # due to more revision allowed, for the 4th, it will favor [7,3] over [1,9] |
|
|
|
print() |
|
|
|
|
|
|
|
print("AddKB with GKB:") |
|
|
|
print("AddGroundKB:") |
|
|
|
kb = AddGroundKB() |
|
|
|
reasoner = ReasonerBase(kb, "confidence") |
|
|
|
test_add(reasoner) |
|
|
|
|
|
|
|
print("AddKB without GKB:") |
|
|
|
print("AddKB:") |
|
|
|
kb = AddKB() |
|
|
|
reasoner = ReasonerBase(kb, "confidence") |
|
|
|
test_add(reasoner) |
|
|
|
|
|
|
|
print("AddKB without GKB, no cache") |
|
|
|
print("AddKB, no cache") |
|
|
|
kb = AddKB(use_cache=False) |
|
|
|
reasoner = ReasonerBase(kb, "confidence") |
|
|
|
test_add(reasoner) |
|
|
@@ -339,45 +354,20 @@ if __name__ == "__main__": |
|
|
|
) |
|
|
|
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) |
|
|
|
test_add(reasoner) |
|
|
|
|
|
|
|
print("AddKB with multiple inputs at once:") |
|
|
|
multiple_prob = [[ |
|
|
|
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], |
|
|
|
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], |
|
|
|
], |
|
|
|
[ |
|
|
|
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], |
|
|
|
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], |
|
|
|
]] |
|
|
|
|
|
|
|
kb = AddKB() |
|
|
|
reasoner = ReasonerBase(kb, "confidence") |
|
|
|
res = reasoner.batch_abduce( |
|
|
|
multiple_prob, |
|
|
|
[[1, 1], [1, 2]], |
|
|
|
[4, 8], |
|
|
|
max_revision=2, |
|
|
|
require_more_revision=0, |
|
|
|
) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce( |
|
|
|
multiple_prob, |
|
|
|
[[1, 1], [1, 2]], |
|
|
|
[4, 8], |
|
|
|
max_revision=2, |
|
|
|
require_more_revision=1, |
|
|
|
) |
|
|
|
print(res) |
|
|
|
print() |
|
|
|
|
|
|
|
|
|
|
|
################################ |
|
|
|
#### Test for HWF reasoning #### |
|
|
|
################################ |
|
|
|
|
|
|
|
class HwfKB(KBBase): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", |
|
|
|
"+", "-", "times", "div"], |
|
|
|
max_err=1e-3, |
|
|
|
use_cache=False, |
|
|
|
): |
|
|
|
super().__init__(pseudo_label_list, max_err) |
|
|
|
super().__init__(pseudo_label_list, max_err, use_cache) |
|
|
|
|
|
|
|
def _valid_candidate(self, formula): |
|
|
|
if len(formula) % 2 == 0: |
|
|
@@ -397,7 +387,7 @@ if __name__ == "__main__": |
|
|
|
formula = [mapping[f] for f in formula] |
|
|
|
return eval("".join(formula)) |
|
|
|
|
|
|
|
class HwfGroundKB(GroundKB): |
|
|
|
class HwfGroundKB(GroundKB, HwfKB): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", |
|
|
@@ -407,6 +397,17 @@ if __name__ == "__main__": |
|
|
|
): |
|
|
|
super().__init__(pseudo_label_list, GKB_len_list, max_err) |
|
|
|
|
|
|
|
|
|
|
|
def _valid_candidate(self, formula): |
|
|
|
if len(formula) % 2 == 0: |
|
|
|
return False |
|
|
|
for i in range(len(formula)): |
|
|
|
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: |
|
|
|
return False |
|
|
|
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
def _valid_candidate(self, formula): |
|
|
|
if len(formula) % 2 == 0: |
|
|
|
return False |
|
|
@@ -417,6 +418,16 @@ if __name__ == "__main__": |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
def logic_forward(self, formula): |
|
|
|
if not self._valid_candidate(formula): |
|
|
|
return None |
|
|
|
mapping = {str(i): str(i) for i in range(1, 10)} |
|
|
|
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) |
|
|
|
formula = [mapping[f] for f in formula] |
|
|
|
return eval("".join(formula)) |
|
|
|
|
|
|
|
|
|
|
|
def logic_forward(self, formula): |
|
|
|
if not self._valid_candidate(formula): |
|
|
|
return None |
|
|
@@ -426,87 +437,46 @@ if __name__ == "__main__": |
|
|
|
return eval("".join(formula)) |
|
|
|
|
|
|
|
def test_hwf(reasoner): |
|
|
|
res = reasoner.batch_abduce( |
|
|
|
[None], |
|
|
|
[["5", "+", "2"]], |
|
|
|
[3], |
|
|
|
max_revision=2, |
|
|
|
require_more_revision=0, |
|
|
|
) |
|
|
|
data_samples_hwf = ListData() |
|
|
|
data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]] |
|
|
|
data_samples_hwf.pred_prob = [None, None, None, None] |
|
|
|
data_samples_hwf.Y = [3, 64, 65, 3.17] |
|
|
|
|
|
|
|
res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce( |
|
|
|
[None], |
|
|
|
[["5", "+", "9"]], |
|
|
|
[65], |
|
|
|
max_revision=3, |
|
|
|
require_more_revision=0, |
|
|
|
) |
|
|
|
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce( |
|
|
|
[None], |
|
|
|
[["5", "8", "8", "8", "8"]], |
|
|
|
[3.17], |
|
|
|
max_revision=5, |
|
|
|
require_more_revision=3, |
|
|
|
) |
|
|
|
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) |
|
|
|
print(res) |
|
|
|
print() |
|
|
|
|
|
|
|
def test_hwf_multiple(reasoner, max_revisions): |
|
|
|
res = reasoner.batch_abduce( |
|
|
|
[None, None], |
|
|
|
[["5", "+", "2"], ["5", "+", "9"]], |
|
|
|
[3, 64], |
|
|
|
max_revision=max_revisions[0], |
|
|
|
require_more_revision=0, |
|
|
|
) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce( |
|
|
|
[None, None], |
|
|
|
[["5", "+", "2"], ["5", "+", "9"]], |
|
|
|
[3, 64], |
|
|
|
max_revision=max_revisions[1], |
|
|
|
require_more_revision=0, |
|
|
|
) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce( |
|
|
|
[None, None], |
|
|
|
[["5", "+", "2"], ["5", "+", "9"]], |
|
|
|
[3, 65], |
|
|
|
max_revision=max_revisions[2], |
|
|
|
require_more_revision=0, |
|
|
|
) |
|
|
|
print(res) |
|
|
|
print() |
|
|
|
|
|
|
|
print("HwfKB with GKB, max_err=0.1") |
|
|
|
print("HwfGroundKB, max_err=0.1:") |
|
|
|
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) |
|
|
|
reasoner = ReasonerBase(kb, "hamming") |
|
|
|
test_hwf(reasoner) |
|
|
|
|
|
|
|
print("HwfKB without GKB, max_err=0.1") |
|
|
|
print("HwfKB, max_err=0.1:") |
|
|
|
kb = HwfKB(max_err=0.1) |
|
|
|
reasoner = ReasonerBase(kb, "hamming") |
|
|
|
test_hwf(reasoner) |
|
|
|
|
|
|
|
print("HwfKB with GKB, max_err=1") |
|
|
|
print("HwfGroundKB, max_err=1:") |
|
|
|
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1) |
|
|
|
reasoner = ReasonerBase(kb, "hamming") |
|
|
|
test_hwf(reasoner) |
|
|
|
|
|
|
|
print("HwfKB without GKB, max_err=1") |
|
|
|
print("HwfKB, max_err=1:") |
|
|
|
kb = HwfKB(max_err=1) |
|
|
|
reasoner = ReasonerBase(kb, "hamming") |
|
|
|
test_hwf(reasoner) |
|
|
|
|
|
|
|
print("HwfKB with multiple inputs at once:") |
|
|
|
kb = HwfKB(max_err=0.1) |
|
|
|
reasoner = ReasonerBase(kb, "hamming") |
|
|
|
test_hwf_multiple(reasoner, max_revisions=[1,3,3]) |
|
|
|
|
|
|
|
print("max_revision is float") |
|
|
|
test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) |
|
|
|
|
|
|
|
|
|
|
|
################################ |
|
|
|
#### Test for HED reasoning #### |
|
|
|
################################ |
|
|
|
|
|
|
|
|
|
|
|
class HedKB(PrologKB): |
|
|
|
def __init__(self, pseudo_label_list, pl_file): |
|
|
|
super().__init__(pseudo_label_list, pl_file) |
|
|
@@ -599,28 +569,24 @@ if __name__ == "__main__": |
|
|
|
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] |
|
|
|
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] |
|
|
|
|
|
|
|
print("HedKB logic forward") |
|
|
|
print(kb.logic_forward(consist_exs)) |
|
|
|
print("HedKB logic forward:") |
|
|
|
print(kb.logic_forward(consist_exs), end=" ") |
|
|
|
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) |
|
|
|
print() |
|
|
|
print("HedKB consist rule") |
|
|
|
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) |
|
|
|
print("HedKB consist rule:") |
|
|
|
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules), end=" ") |
|
|
|
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) |
|
|
|
print() |
|
|
|
|
|
|
|
data_sample_hed = ListData() |
|
|
|
data_sample_hed.pred_pseudo_label = [consist_exs, inconsist_exs1, inconsist_exs2] |
|
|
|
data_sample_hed.pred_prob = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] |
|
|
|
data_sample_hed.Y = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] |
|
|
|
|
|
|
|
print("HedReasoner abduce") |
|
|
|
res = reasoner.abduce( |
|
|
|
[[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs) |
|
|
|
) |
|
|
|
print(res) |
|
|
|
res = reasoner.abduce( |
|
|
|
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) |
|
|
|
) |
|
|
|
print(res) |
|
|
|
res = reasoner.abduce( |
|
|
|
[[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) |
|
|
|
) |
|
|
|
print(res) |
|
|
|
res = reasoner.batch_abduce(data_sample_hed) |
|
|
|
for r in res: |
|
|
|
print(r) |
|
|
|
print() |
|
|
|
|
|
|
|
print("HedReasoner abduce rules") |
|
|
|