From d72fc51bbd4566af60729980ee95734b5ca9f247 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Thu, 16 Nov 2023 16:04:01 +0800 Subject: [PATCH] [ENH] refine reasoning test --- abl/reasoning/reasoner.py | 222 ++++++++++++++++---------------------- 1 file changed, 94 insertions(+), 128 deletions(-) diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 2e57570..b16a595 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -219,13 +219,13 @@ class ReasonerBase: A revised pseudo label through abductive reasoning, which is consistent with the knowledge base. """ - symbol_num = data_sample.elements_num("pred_pseudo_label") - max_revision_num = self._get_max_revision_num(max_revision, symbol_num) - pred_pseudo_label = data_sample.pred_pseudo_label[0] pred_prob = data_sample.pred_prob[0] y = data_sample.Y[0] + symbol_num = len(flatten(pred_pseudo_label)) + max_revision_num = self._get_max_revision_num(max_revision, symbol_num) + if self.use_zoopt: solution = self.zoopt_get_solution( symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num @@ -275,12 +275,11 @@ class ReasonerBase: if __name__ == "__main__": from kb import KBBase, GroundKB, PrologKB - - prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + from abl.structures import ListData - prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + ################################ + # Test for MNIST Add reasoning # + ################################ class AddKB(KBBase): def __init__(self, pseudo_label_list=list(range(10)), @@ -290,38 +289,54 @@ if __name__ == "__main__": def logic_forward(self, nums): return sum(nums) - class AddGroundKB(GroundKB): + class AddGroundKB(GroundKB, AddKB): def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): super().__init__(pseudo_label_list, GKB_len_list) + + def logic_forward(self, nums): + return sum(nums) + + def logic_forward(self, nums): return sum(nums) def test_add(reasoner): - res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) - print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) + # favor 1 in first one + prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + + # favor 7 in first one + prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + + data_samples_add = ListData() + data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] + data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] + data_samples_add.Y = [8, 8, 17, 10] + + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) print(res) - res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) + res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) + print(res) + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) print(res) + res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) + print(res) # due to more revision allowed, for the 4th, it will favor [7,3] over [1,9] print() - print("AddKB with GKB:") + print("AddGroundKB:") kb = AddGroundKB() reasoner = ReasonerBase(kb, "confidence") test_add(reasoner) - print("AddKB without GKB:") + print("AddKB:") kb = AddKB() reasoner = ReasonerBase(kb, "confidence") test_add(reasoner) - print("AddKB without GKB, no cache") + print("AddKB, no cache") kb = AddKB(use_cache=False) reasoner = ReasonerBase(kb, "confidence") test_add(reasoner) @@ -339,45 +354,20 @@ if __name__ == "__main__": ) reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) test_add(reasoner) - - print("AddKB with multiple inputs at once:") - multiple_prob = [[ - [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ], - [ - [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ]] - - kb = AddKB() - reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce( - multiple_prob, - [[1, 1], [1, 2]], - [4, 8], - max_revision=2, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - multiple_prob, - [[1, 1], [1, 2]], - [4, 8], - max_revision=2, - require_more_revision=1, - ) - print(res) - print() - + + ################################ + #### Test for HWF reasoning #### + ################################ + class HwfKB(KBBase): def __init__( self, pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "times", "div"], max_err=1e-3, + use_cache=False, ): - super().__init__(pseudo_label_list, max_err) + super().__init__(pseudo_label_list, max_err, use_cache) def _valid_candidate(self, formula): if len(formula) % 2 == 0: @@ -397,7 +387,7 @@ if __name__ == "__main__": formula = [mapping[f] for f in formula] return eval("".join(formula)) - class HwfGroundKB(GroundKB): + class HwfGroundKB(GroundKB, HwfKB): def __init__( self, pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", @@ -407,6 +397,7 @@ if __name__ == "__main__": ): super().__init__(pseudo_label_list, GKB_len_list, max_err) + def _valid_candidate(self, formula): if len(formula) % 2 == 0: return False @@ -416,6 +407,17 @@ if __name__ == "__main__": if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: return False return True + + def _valid_candidate(self, formula): + if len(formula) % 2 == 0: + return False + for i in range(len(formula)): + if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: + return False + if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: + return False + return True + def logic_forward(self, formula): if not self._valid_candidate(formula): @@ -425,88 +427,56 @@ if __name__ == "__main__": formula = [mapping[f] for f in formula] return eval("".join(formula)) + + def logic_forward(self, formula): + if not self._valid_candidate(formula): + return None + mapping = {str(i): str(i) for i in range(1, 10)} + mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) + formula = [mapping[f] for f in formula] + return eval("".join(formula)) + def test_hwf(reasoner): - res = reasoner.batch_abduce( - [None], - [["5", "+", "2"]], - [3], - max_revision=2, - require_more_revision=0, - ) + data_samples_hwf = ListData() + data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]] + data_samples_hwf.pred_prob = [None, None, None, None] + data_samples_hwf.Y = [3, 64, 65, 3.17] + + res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) print(res) - res = reasoner.batch_abduce( - [None], - [["5", "+", "9"]], - [65], - max_revision=3, - require_more_revision=0, - ) + res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) print(res) - res = reasoner.batch_abduce( - [None], - [["5", "8", "8", "8", "8"]], - [3.17], - max_revision=5, - require_more_revision=3, - ) + res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) print(res) print() - def test_hwf_multiple(reasoner, max_revisions): - res = reasoner.batch_abduce( - [None, None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 64], - max_revision=max_revisions[0], - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [None, None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 64], - max_revision=max_revisions[1], - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [None, None], - [["5", "+", "2"], ["5", "+", "9"]], - [3, 65], - max_revision=max_revisions[2], - require_more_revision=0, - ) - print(res) - print() - print("HwfKB with GKB, max_err=0.1") + print("HwfGroundKB, max_err=0.1:") kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - print("HwfKB without GKB, max_err=0.1") + print("HwfKB, max_err=0.1:") kb = HwfKB(max_err=0.1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - print("HwfKB with GKB, max_err=1") + print("HwfGroundKB, max_err=1:") kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - print("HwfKB without GKB, max_err=1") + print("HwfKB, max_err=1:") kb = HwfKB(max_err=1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner) - - print("HwfKB with multiple inputs at once:") - kb = HwfKB(max_err=0.1) - reasoner = ReasonerBase(kb, "hamming") - test_hwf_multiple(reasoner, max_revisions=[1,3,3]) - print("max_revision is float") - test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) - + + ################################ + #### Test for HED reasoning #### + ################################ + + class HedKB(PrologKB): def __init__(self, pseudo_label_list, pl_file): super().__init__(pseudo_label_list, pl_file) @@ -599,28 +569,24 @@ if __name__ == "__main__": inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] - print("HedKB logic forward") - print(kb.logic_forward(consist_exs)) + print("HedKB logic forward:") + print(kb.logic_forward(consist_exs), end=" ") print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) print() - print("HedKB consist rule") - print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) + print("HedKB consist rule:") + print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules), end=" ") print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) print() + data_sample_hed = ListData() + data_sample_hed.pred_pseudo_label = [consist_exs, inconsist_exs1, inconsist_exs2] + data_sample_hed.pred_prob = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] + data_sample_hed.Y = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] + print("HedReasoner abduce") - res = reasoner.abduce( - [[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs) - ) - print(res) - res = reasoner.abduce( - [[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) - ) - print(res) - res = reasoner.abduce( - [[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) - ) - print(res) + res = reasoner.batch_abduce(data_sample_hed) + for r in res: + print(r) print() print("HedReasoner abduce rules")