Browse Source

[ENH] refine reasoning test

pull/4/head
troyyyyy 1 year ago
parent
commit
d72fc51bbd
1 changed files with 94 additions and 128 deletions
  1. +94
    -128
      abl/reasoning/reasoner.py

+ 94
- 128
abl/reasoning/reasoner.py View File

@@ -219,13 +219,13 @@ class ReasonerBase:
A revised pseudo label through abductive reasoning, which is consistent with the
knowledge base.
"""
symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(max_revision, symbol_num)
pred_pseudo_label = data_sample.pred_pseudo_label[0]
pred_prob = data_sample.pred_prob[0]
y = data_sample.Y[0]
symbol_num = len(flatten(pred_pseudo_label))
max_revision_num = self._get_max_revision_num(max_revision, symbol_num)
if self.use_zoopt:
solution = self.zoopt_get_solution(
symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
@@ -275,12 +275,11 @@ class ReasonerBase:

if __name__ == "__main__":
from kb import KBBase, GroundKB, PrologKB

prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
from abl.structures import ListData
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
################################
# Test for MNIST Add reasoning #
################################

class AddKB(KBBase):
def __init__(self, pseudo_label_list=list(range(10)),
@@ -290,38 +289,54 @@ if __name__ == "__main__":
def logic_forward(self, nums):
return sum(nums)
class AddGroundKB(GroundKB):
class AddGroundKB(GroundKB, AddKB):
def __init__(self, pseudo_label_list=list(range(10)),
GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list)


def logic_forward(self, nums):
return sum(nums)
def logic_forward(self, nums):
return sum(nums)
def test_add(reasoner):
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
# favor 1 in first one
prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
# favor 7 in first one
prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
data_samples_add = ListData()
data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]]
data_samples_add.pred_prob = [prob1, prob2, prob1, prob2]
data_samples_add.Y = [8, 8, 17, 10]
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1)
print(res)
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1)
print(res) # due to more revision allowed, for the 4th, it will favor [7,3] over [1,9]
print()

print("AddKB with GKB:")
print("AddGroundKB:")
kb = AddGroundKB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("AddKB without GKB:")
print("AddKB:")
kb = AddKB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("AddKB without GKB, no cache")
print("AddKB, no cache")
kb = AddKB(use_cache=False)
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)
@@ -339,45 +354,20 @@ if __name__ == "__main__":
)
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True)
test_add(reasoner)

print("AddKB with multiple inputs at once:")
multiple_prob = [[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
],
[
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]]

kb = AddKB()
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=1,
)
print(res)
print()

################################
#### Test for HWF reasoning ####
################################
class HwfKB(KBBase):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
"+", "-", "times", "div"],
max_err=1e-3,
use_cache=False,
):
super().__init__(pseudo_label_list, max_err)
super().__init__(pseudo_label_list, max_err, use_cache)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
@@ -397,7 +387,7 @@ if __name__ == "__main__":
formula = [mapping[f] for f in formula]
return eval("".join(formula))
class HwfGroundKB(GroundKB):
class HwfGroundKB(GroundKB, HwfKB):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
@@ -407,6 +397,7 @@ if __name__ == "__main__":
):
super().__init__(pseudo_label_list, GKB_len_list, max_err)


def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
@@ -416,6 +407,17 @@ if __name__ == "__main__":
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True
def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True


def logic_forward(self, formula):
if not self._valid_candidate(formula):
@@ -425,88 +427,56 @@ if __name__ == "__main__":
formula = [mapping[f] for f in formula]
return eval("".join(formula))
def logic_forward(self, formula):
if not self._valid_candidate(formula):
return None
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))
def test_hwf(reasoner):
res = reasoner.batch_abduce(
[None],
[["5", "+", "2"]],
[3],
max_revision=2,
require_more_revision=0,
)
data_samples_hwf = ListData()
data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]]
data_samples_hwf.pred_prob = [None, None, None, None]
data_samples_hwf.Y = [3, 64, 65, 3.17]
res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "+", "9"]],
[65],
max_revision=3,
require_more_revision=0,
)
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "8", "8", "8", "8"]],
[3.17],
max_revision=5,
require_more_revision=3,
)
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0)
print(res)
print()
def test_hwf_multiple(reasoner, max_revisions):
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[0],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[1],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 65],
max_revision=max_revisions[2],
require_more_revision=0,
)
print(res)
print()

print("HwfKB with GKB, max_err=0.1")
print("HwfGroundKB, max_err=0.1:")
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HwfKB without GKB, max_err=0.1")
print("HwfKB, max_err=0.1:")
kb = HwfKB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HwfKB with GKB, max_err=1")
print("HwfGroundKB, max_err=1:")
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HwfKB without GKB, max_err=1")
print("HwfKB, max_err=1:")
kb = HwfKB(max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HwfKB with multiple inputs at once:")
kb = HwfKB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf_multiple(reasoner, max_revisions=[1,3,3])
print("max_revision is float")
test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9])

################################
#### Test for HED reasoning ####
################################
class HedKB(PrologKB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)
@@ -599,28 +569,24 @@ if __name__ == "__main__":
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"]

print("HedKB logic forward")
print(kb.logic_forward(consist_exs))
print("HedKB logic forward:")
print(kb.logic_forward(consist_exs), end=" ")
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2))
print()
print("HedKB consist rule")
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules))
print("HedKB consist rule:")
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules), end=" ")
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules))
print()

data_sample_hed = ListData()
data_sample_hed.pred_pseudo_label = [consist_exs, inconsist_exs1, inconsist_exs2]
data_sample_hed.pred_prob = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)]
data_sample_hed.Y = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)]

print("HedReasoner abduce")
res = reasoner.abduce(
[[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs)
)
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1)
)
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2)
)
print(res)
res = reasoner.batch_abduce(data_sample_hed)
for r in res:
print(r)
print()

print("HedReasoner abduce rules")


Loading…
Cancel
Save