|
|
@@ -25,7 +25,7 @@ class ReasonerBase: |
|
|
|
A mapping from index to label. If not provided, a default order-based mapping is |
|
|
|
created. |
|
|
|
use_zoopt : bool, optional |
|
|
|
Whether to use the Zoopt library during abductive reasoning. Default to False. |
|
|
|
Whether to use the Zoopt library during abductive reasoning. Defaults to False. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False): |
|
|
@@ -203,7 +203,7 @@ class ReasonerBase: |
|
|
|
max_revision : int or float, optional |
|
|
|
The upper limit on the number of revisions. If float, denotes the fraction of the |
|
|
|
total length that can be revised. A value of -1 implies no restriction on the number |
|
|
|
of revisions. Default to -1. |
|
|
|
of revisions. Defaults to -1. |
|
|
|
require_more_revision : int, optional |
|
|
|
Specifies additional number of revisions permitted beyond the minimum required. |
|
|
|
Defaults to 0. |
|
|
@@ -267,7 +267,7 @@ class ReasonerBase: |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
from kb import KBBase, ground_KB, prolog_KB |
|
|
|
from kb import KBBase, GroundKB, PrologKB |
|
|
|
|
|
|
|
prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], |
|
|
|
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] |
|
|
@@ -275,7 +275,7 @@ if __name__ == "__main__": |
|
|
|
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], |
|
|
|
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] |
|
|
|
|
|
|
|
class add_KB(KBBase): |
|
|
|
class AddKB(KBBase): |
|
|
|
def __init__(self, pseudo_label_list=list(range(10)), |
|
|
|
use_cache=True): |
|
|
|
super().__init__(pseudo_label_list, use_cache=use_cache) |
|
|
@@ -283,7 +283,7 @@ if __name__ == "__main__": |
|
|
|
def logic_forward(self, nums): |
|
|
|
return sum(nums) |
|
|
|
|
|
|
|
class add_ground_KB(ground_KB): |
|
|
|
class AddGroundKB(GroundKB): |
|
|
|
def __init__(self, pseudo_label_list=list(range(10)), |
|
|
|
GKB_len_list=[2]): |
|
|
|
super().__init__(pseudo_label_list, GKB_len_list) |
|
|
@@ -304,36 +304,36 @@ if __name__ == "__main__": |
|
|
|
print(res) |
|
|
|
print() |
|
|
|
|
|
|
|
print("add_KB with GKB:") |
|
|
|
kb = add_ground_KB() |
|
|
|
print("AddKB with GKB:") |
|
|
|
kb = AddGroundKB() |
|
|
|
reasoner = ReasonerBase(kb, "confidence") |
|
|
|
test_add(reasoner) |
|
|
|
|
|
|
|
print("add_KB without GKB:") |
|
|
|
kb = add_KB() |
|
|
|
print("AddKB without GKB:") |
|
|
|
kb = AddKB() |
|
|
|
reasoner = ReasonerBase(kb, "confidence") |
|
|
|
test_add(reasoner) |
|
|
|
|
|
|
|
print("add_KB without GKB, no cache") |
|
|
|
kb = add_KB(use_cache=False) |
|
|
|
print("AddKB without GKB, no cache") |
|
|
|
kb = AddKB(use_cache=False) |
|
|
|
reasoner = ReasonerBase(kb, "confidence") |
|
|
|
test_add(reasoner) |
|
|
|
|
|
|
|
print("prolog_KB with add.pl:") |
|
|
|
kb = prolog_KB(pseudo_label_list=list(range(10)), |
|
|
|
print("PrologKB with add.pl:") |
|
|
|
kb = PrologKB(pseudo_label_list=list(range(10)), |
|
|
|
pl_file="examples/mnist_add/datasets/add.pl") |
|
|
|
reasoner = ReasonerBase(kb, "confidence") |
|
|
|
test_add(reasoner) |
|
|
|
|
|
|
|
print("prolog_KB with add.pl using zoopt:") |
|
|
|
kb = prolog_KB( |
|
|
|
print("PrologKB with add.pl using zoopt:") |
|
|
|
kb = PrologKB( |
|
|
|
pseudo_label_list=list(range(10)), |
|
|
|
pl_file="examples/mnist_add/datasets/add.pl", |
|
|
|
) |
|
|
|
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) |
|
|
|
test_add(reasoner) |
|
|
|
|
|
|
|
print("add_KB with multiple inputs at once:") |
|
|
|
print("AddKB with multiple inputs at once:") |
|
|
|
multiple_prob = [[ |
|
|
|
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], |
|
|
|
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], |
|
|
@@ -343,7 +343,7 @@ if __name__ == "__main__": |
|
|
|
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], |
|
|
|
]] |
|
|
|
|
|
|
|
kb = add_KB() |
|
|
|
kb = AddKB() |
|
|
|
reasoner = ReasonerBase(kb, "confidence") |
|
|
|
res = reasoner.batch_abduce( |
|
|
|
multiple_prob, |
|
|
@@ -363,7 +363,7 @@ if __name__ == "__main__": |
|
|
|
print(res) |
|
|
|
print() |
|
|
|
|
|
|
|
class HWF_KB(KBBase): |
|
|
|
class HwfKB(KBBase): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", |
|
|
@@ -390,7 +390,7 @@ if __name__ == "__main__": |
|
|
|
formula = [mapping[f] for f in formula] |
|
|
|
return eval("".join(formula)) |
|
|
|
|
|
|
|
class HWF_ground_KB(ground_KB): |
|
|
|
class HwfGroundKB(GroundKB): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", |
|
|
@@ -472,35 +472,35 @@ if __name__ == "__main__": |
|
|
|
print(res) |
|
|
|
print() |
|
|
|
|
|
|
|
print("HWF_KB with GKB, max_err=0.1") |
|
|
|
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=0.1) |
|
|
|
print("HwfKB with GKB, max_err=0.1") |
|
|
|
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) |
|
|
|
reasoner = ReasonerBase(kb, "hamming") |
|
|
|
test_hwf(reasoner) |
|
|
|
|
|
|
|
print("HWF_KB without GKB, max_err=0.1") |
|
|
|
kb = HWF_KB(max_err=0.1) |
|
|
|
print("HwfKB without GKB, max_err=0.1") |
|
|
|
kb = HwfKB(max_err=0.1) |
|
|
|
reasoner = ReasonerBase(kb, "hamming") |
|
|
|
test_hwf(reasoner) |
|
|
|
|
|
|
|
print("HWF_KB with GKB, max_err=1") |
|
|
|
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=1) |
|
|
|
print("HwfKB with GKB, max_err=1") |
|
|
|
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1) |
|
|
|
reasoner = ReasonerBase(kb, "hamming") |
|
|
|
test_hwf(reasoner) |
|
|
|
|
|
|
|
print("HWF_KB without GKB, max_err=1") |
|
|
|
kb = HWF_KB(max_err=1) |
|
|
|
print("HwfKB without GKB, max_err=1") |
|
|
|
kb = HwfKB(max_err=1) |
|
|
|
reasoner = ReasonerBase(kb, "hamming") |
|
|
|
test_hwf(reasoner) |
|
|
|
|
|
|
|
print("HWF_KB with multiple inputs at once:") |
|
|
|
kb = HWF_KB(max_err=0.1) |
|
|
|
print("HwfKB with multiple inputs at once:") |
|
|
|
kb = HwfKB(max_err=0.1) |
|
|
|
reasoner = ReasonerBase(kb, "hamming") |
|
|
|
test_hwf_multiple(reasoner, max_revisions=[1,3,3]) |
|
|
|
|
|
|
|
print("max_revision is float") |
|
|
|
test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) |
|
|
|
|
|
|
|
class HED_prolog_KB(prolog_KB): |
|
|
|
class HedKB(PrologKB): |
|
|
|
def __init__(self, pseudo_label_list, pl_file): |
|
|
|
super().__init__(pseudo_label_list, pl_file) |
|
|
|
|
|
|
@@ -518,7 +518,7 @@ if __name__ == "__main__": |
|
|
|
rules = [rule.value for rule in prolog_rules] |
|
|
|
return rules |
|
|
|
|
|
|
|
class HED_Reasoner(ReasonerBase): |
|
|
|
class HedReasoner(ReasonerBase): |
|
|
|
def __init__(self, kb, dist_func="hamming"): |
|
|
|
super().__init__(kb, dist_func, use_zoopt=True) |
|
|
|
|
|
|
@@ -573,11 +573,11 @@ if __name__ == "__main__": |
|
|
|
def abduce_rules(self, pred_res): |
|
|
|
return self.kb.abduce_rules(pred_res) |
|
|
|
|
|
|
|
kb = HED_prolog_KB( |
|
|
|
kb = HedKB( |
|
|
|
pseudo_label_list=[1, 0, "+", "="], |
|
|
|
pl_file="examples/hed/datasets/learn_add.pl", |
|
|
|
) |
|
|
|
reasoner = HED_Reasoner(kb) |
|
|
|
reasoner = HedReasoner(kb) |
|
|
|
consist_exs = [ |
|
|
|
[1, 1, "+", 0, "=", 1, 1], |
|
|
|
[1, "+", 1, "=", 1, 0], |
|
|
@@ -592,16 +592,16 @@ if __name__ == "__main__": |
|
|
|
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] |
|
|
|
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] |
|
|
|
|
|
|
|
print("HED_kb logic forward") |
|
|
|
print("HedKB logic forward") |
|
|
|
print(kb.logic_forward(consist_exs)) |
|
|
|
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) |
|
|
|
print() |
|
|
|
print("HED_kb consist rule") |
|
|
|
print("HedKB consist rule") |
|
|
|
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) |
|
|
|
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) |
|
|
|
print() |
|
|
|
|
|
|
|
print("HED_Reasoner abduce") |
|
|
|
print("HedReasoner abduce") |
|
|
|
res = reasoner.abduce( |
|
|
|
[[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs) |
|
|
|
) |
|
|
@@ -616,6 +616,6 @@ if __name__ == "__main__": |
|
|
|
print(res) |
|
|
|
print() |
|
|
|
|
|
|
|
print("HED_Reasoner abduce rules") |
|
|
|
print("HedReasoner abduce rules") |
|
|
|
abduced_rules = reasoner.abduce_rules(consist_exs) |
|
|
|
print(abduced_rules) |