Browse Source

[MNT] resolve some comments in reasoning

pull/3/head
troyyyyy 1 year ago
parent
commit
aa9c949446
2 changed files with 47 additions and 48 deletions
  1. +10
    -11
      abl/reasoning/kb.py
  2. +37
    -37
      abl/reasoning/reasoner.py

+ 10
- 11
abl/reasoning/kb.py View File

@@ -1,17 +1,16 @@
from abc import ABC, abstractmethod
import bisect
import numpy as np

from collections import defaultdict
from itertools import product, combinations

from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, hashable_to_list

from multiprocessing import Pool

from functools import lru_cache

import numpy as np
import pyswip

from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, hashable_to_list


class KBBase(ABC):
"""
Base class for knowledge base.
@@ -23,7 +22,7 @@ class KBBase(ABC):
max_err : float, optional
The upper tolerance limit when comparing the similarity between a candidate's logical
result and the ground truth. Especially relevant for regression problems where exact
matches might not be feasible. Default to 0.
matches might not be feasible. Defaults to 1e-10.
use_cache : bool, optional
Whether to use a cache for previously abduced candidates to speed up subsequent
operations. Defaults to True.
@@ -36,7 +35,7 @@ class KBBase(ABC):
perform logical reasoning). After that, other operations (e.g. how to perform abductive
reasoning) will be automatically set up.
"""
def __init__(self, pseudo_label_list, max_err=0, use_cache=True):
def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True):
if not isinstance(pseudo_label_list, list):
raise TypeError("pseudo_label_list should be list")
self.pseudo_label_list = pseudo_label_list
@@ -70,7 +69,7 @@ class KBBase(ABC):
Returns
-------
List[List[Any]]
A list of candidates, i.e. revised pseudo label that are consistent with the
A list of candidates, i.e. revised pseudo labels that are consistent with the
knowledge base.
"""
if not self.use_cache:
@@ -191,7 +190,7 @@ class KBBase(ABC):
)
class ground_KB(KBBase):
class GroundKB(KBBase):
"""
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon
class initialization, stroing all potential candidates along with their respective
@@ -310,7 +309,7 @@ class ground_KB(KBBase):
)


class prolog_KB(KBBase):
class PrologKB(KBBase):
"""
Knowledge base given by a prolog (pl) file.



+ 37
- 37
abl/reasoning/reasoner.py View File

@@ -25,7 +25,7 @@ class ReasonerBase:
A mapping from index to label. If not provided, a default order-based mapping is
created.
use_zoopt : bool, optional
Whether to use the Zoopt library during abductive reasoning. Default to False.
Whether to use the Zoopt library during abductive reasoning. Defaults to False.
"""
def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False):
@@ -203,7 +203,7 @@ class ReasonerBase:
max_revision : int or float, optional
The upper limit on the number of revisions. If float, denotes the fraction of the
total length that can be revised. A value of -1 implies no restriction on the number
of revisions. Default to -1.
of revisions. Defaults to -1.
require_more_revision : int, optional
Specifies additional number of revisions permitted beyond the minimum required.
Defaults to 0.
@@ -267,7 +267,7 @@ class ReasonerBase:


if __name__ == "__main__":
from kb import KBBase, ground_KB, prolog_KB
from kb import KBBase, GroundKB, PrologKB

prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
@@ -275,7 +275,7 @@ if __name__ == "__main__":
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]

class add_KB(KBBase):
class AddKB(KBBase):
def __init__(self, pseudo_label_list=list(range(10)),
use_cache=True):
super().__init__(pseudo_label_list, use_cache=use_cache)
@@ -283,7 +283,7 @@ if __name__ == "__main__":
def logic_forward(self, nums):
return sum(nums)
class add_ground_KB(ground_KB):
class AddGroundKB(GroundKB):
def __init__(self, pseudo_label_list=list(range(10)),
GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list)
@@ -304,36 +304,36 @@ if __name__ == "__main__":
print(res)
print()

print("add_KB with GKB:")
kb = add_ground_KB()
print("AddKB with GKB:")
kb = AddGroundKB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("add_KB without GKB:")
kb = add_KB()
print("AddKB without GKB:")
kb = AddKB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("add_KB without GKB, no cache")
kb = add_KB(use_cache=False)
print("AddKB without GKB, no cache")
kb = AddKB(use_cache=False)
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("prolog_KB with add.pl:")
kb = prolog_KB(pseudo_label_list=list(range(10)),
print("PrologKB with add.pl:")
kb = PrologKB(pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl")
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("prolog_KB with add.pl using zoopt:")
kb = prolog_KB(
print("PrologKB with add.pl using zoopt:")
kb = PrologKB(
pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl",
)
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True)
test_add(reasoner)

print("add_KB with multiple inputs at once:")
print("AddKB with multiple inputs at once:")
multiple_prob = [[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
@@ -343,7 +343,7 @@ if __name__ == "__main__":
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]]

kb = add_KB()
kb = AddKB()
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
multiple_prob,
@@ -363,7 +363,7 @@ if __name__ == "__main__":
print(res)
print()

class HWF_KB(KBBase):
class HwfKB(KBBase):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
@@ -390,7 +390,7 @@ if __name__ == "__main__":
formula = [mapping[f] for f in formula]
return eval("".join(formula))
class HWF_ground_KB(ground_KB):
class HwfGroundKB(GroundKB):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
@@ -472,35 +472,35 @@ if __name__ == "__main__":
print(res)
print()

print("HWF_KB with GKB, max_err=0.1")
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=0.1)
print("HwfKB with GKB, max_err=0.1")
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB without GKB, max_err=0.1")
kb = HWF_KB(max_err=0.1)
print("HwfKB without GKB, max_err=0.1")
kb = HwfKB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB with GKB, max_err=1")
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=1)
print("HwfKB with GKB, max_err=1")
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB without GKB, max_err=1")
kb = HWF_KB(max_err=1)
print("HwfKB without GKB, max_err=1")
kb = HwfKB(max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB with multiple inputs at once:")
kb = HWF_KB(max_err=0.1)
print("HwfKB with multiple inputs at once:")
kb = HwfKB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf_multiple(reasoner, max_revisions=[1,3,3])
print("max_revision is float")
test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9])

class HED_prolog_KB(prolog_KB):
class HedKB(PrologKB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)

@@ -518,7 +518,7 @@ if __name__ == "__main__":
rules = [rule.value for rule in prolog_rules]
return rules

class HED_Reasoner(ReasonerBase):
class HedReasoner(ReasonerBase):
def __init__(self, kb, dist_func="hamming"):
super().__init__(kb, dist_func, use_zoopt=True)

@@ -573,11 +573,11 @@ if __name__ == "__main__":
def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)

kb = HED_prolog_KB(
kb = HedKB(
pseudo_label_list=[1, 0, "+", "="],
pl_file="examples/hed/datasets/learn_add.pl",
)
reasoner = HED_Reasoner(kb)
reasoner = HedReasoner(kb)
consist_exs = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
@@ -592,16 +592,16 @@ if __name__ == "__main__":
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"]

print("HED_kb logic forward")
print("HedKB logic forward")
print(kb.logic_forward(consist_exs))
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2))
print()
print("HED_kb consist rule")
print("HedKB consist rule")
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules))
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules))
print()

print("HED_Reasoner abduce")
print("HedReasoner abduce")
res = reasoner.abduce(
[[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs)
)
@@ -616,6 +616,6 @@ if __name__ == "__main__":
print(res)
print()

print("HED_Reasoner abduce rules")
print("HedReasoner abduce rules")
abduced_rules = reasoner.abduce_rules(consist_exs)
print(abduced_rules)

Loading…
Cancel
Save