Browse Source

[MNT] resolve some comments in reasoning

pull/3/head
troyyyyy 1 year ago
parent
commit
aa9c949446
2 changed files with 47 additions and 48 deletions
  1. +10
    -11
      abl/reasoning/kb.py
  2. +37
    -37
      abl/reasoning/reasoner.py

+ 10
- 11
abl/reasoning/kb.py View File

@@ -1,17 +1,16 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import bisect import bisect
import numpy as np

from collections import defaultdict from collections import defaultdict
from itertools import product, combinations from itertools import product, combinations

from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, hashable_to_list

from multiprocessing import Pool from multiprocessing import Pool

from functools import lru_cache from functools import lru_cache

import numpy as np
import pyswip import pyswip


from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, hashable_to_list


class KBBase(ABC): class KBBase(ABC):
""" """
Base class for knowledge base. Base class for knowledge base.
@@ -23,7 +22,7 @@ class KBBase(ABC):
max_err : float, optional max_err : float, optional
The upper tolerance limit when comparing the similarity between a candidate's logical The upper tolerance limit when comparing the similarity between a candidate's logical
result and the ground truth. Especially relevant for regression problems where exact result and the ground truth. Especially relevant for regression problems where exact
matches might not be feasible. Default to 0.
matches might not be feasible. Defaults to 1e-10.
use_cache : bool, optional use_cache : bool, optional
Whether to use a cache for previously abduced candidates to speed up subsequent Whether to use a cache for previously abduced candidates to speed up subsequent
operations. Defaults to True. operations. Defaults to True.
@@ -36,7 +35,7 @@ class KBBase(ABC):
perform logical reasoning). After that, other operations (e.g. how to perform abductive perform logical reasoning). After that, other operations (e.g. how to perform abductive
reasoning) will be automatically set up. reasoning) will be automatically set up.
""" """
def __init__(self, pseudo_label_list, max_err=0, use_cache=True):
def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True):
if not isinstance(pseudo_label_list, list): if not isinstance(pseudo_label_list, list):
raise TypeError("pseudo_label_list should be list") raise TypeError("pseudo_label_list should be list")
self.pseudo_label_list = pseudo_label_list self.pseudo_label_list = pseudo_label_list
@@ -70,7 +69,7 @@ class KBBase(ABC):
Returns Returns
------- -------
List[List[Any]] List[List[Any]]
A list of candidates, i.e. revised pseudo label that are consistent with the
A list of candidates, i.e. revised pseudo labels that are consistent with the
knowledge base. knowledge base.
""" """
if not self.use_cache: if not self.use_cache:
@@ -191,7 +190,7 @@ class KBBase(ABC):
) )
class ground_KB(KBBase):
class GroundKB(KBBase):
""" """
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon
class initialization, stroing all potential candidates along with their respective class initialization, stroing all potential candidates along with their respective
@@ -310,7 +309,7 @@ class ground_KB(KBBase):
) )




class prolog_KB(KBBase):
class PrologKB(KBBase):
""" """
Knowledge base given by a prolog (pl) file. Knowledge base given by a prolog (pl) file.




+ 37
- 37
abl/reasoning/reasoner.py View File

@@ -25,7 +25,7 @@ class ReasonerBase:
A mapping from index to label. If not provided, a default order-based mapping is A mapping from index to label. If not provided, a default order-based mapping is
created. created.
use_zoopt : bool, optional use_zoopt : bool, optional
Whether to use the Zoopt library during abductive reasoning. Default to False.
Whether to use the Zoopt library during abductive reasoning. Defaults to False.
""" """
def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False): def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False):
@@ -203,7 +203,7 @@ class ReasonerBase:
max_revision : int or float, optional max_revision : int or float, optional
The upper limit on the number of revisions. If float, denotes the fraction of the The upper limit on the number of revisions. If float, denotes the fraction of the
total length that can be revised. A value of -1 implies no restriction on the number total length that can be revised. A value of -1 implies no restriction on the number
of revisions. Default to -1.
of revisions. Defaults to -1.
require_more_revision : int, optional require_more_revision : int, optional
Specifies additional number of revisions permitted beyond the minimum required. Specifies additional number of revisions permitted beyond the minimum required.
Defaults to 0. Defaults to 0.
@@ -267,7 +267,7 @@ class ReasonerBase:




if __name__ == "__main__": if __name__ == "__main__":
from kb import KBBase, ground_KB, prolog_KB
from kb import KBBase, GroundKB, PrologKB


prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
@@ -275,7 +275,7 @@ if __name__ == "__main__":
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]


class add_KB(KBBase):
class AddKB(KBBase):
def __init__(self, pseudo_label_list=list(range(10)), def __init__(self, pseudo_label_list=list(range(10)),
use_cache=True): use_cache=True):
super().__init__(pseudo_label_list, use_cache=use_cache) super().__init__(pseudo_label_list, use_cache=use_cache)
@@ -283,7 +283,7 @@ if __name__ == "__main__":
def logic_forward(self, nums): def logic_forward(self, nums):
return sum(nums) return sum(nums)
class add_ground_KB(ground_KB):
class AddGroundKB(GroundKB):
def __init__(self, pseudo_label_list=list(range(10)), def __init__(self, pseudo_label_list=list(range(10)),
GKB_len_list=[2]): GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list) super().__init__(pseudo_label_list, GKB_len_list)
@@ -304,36 +304,36 @@ if __name__ == "__main__":
print(res) print(res)
print() print()


print("add_KB with GKB:")
kb = add_ground_KB()
print("AddKB with GKB:")
kb = AddGroundKB()
reasoner = ReasonerBase(kb, "confidence") reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner) test_add(reasoner)


print("add_KB without GKB:")
kb = add_KB()
print("AddKB without GKB:")
kb = AddKB()
reasoner = ReasonerBase(kb, "confidence") reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner) test_add(reasoner)


print("add_KB without GKB, no cache")
kb = add_KB(use_cache=False)
print("AddKB without GKB, no cache")
kb = AddKB(use_cache=False)
reasoner = ReasonerBase(kb, "confidence") reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner) test_add(reasoner)


print("prolog_KB with add.pl:")
kb = prolog_KB(pseudo_label_list=list(range(10)),
print("PrologKB with add.pl:")
kb = PrologKB(pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl") pl_file="examples/mnist_add/datasets/add.pl")
reasoner = ReasonerBase(kb, "confidence") reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner) test_add(reasoner)


print("prolog_KB with add.pl using zoopt:")
kb = prolog_KB(
print("PrologKB with add.pl using zoopt:")
kb = PrologKB(
pseudo_label_list=list(range(10)), pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl", pl_file="examples/mnist_add/datasets/add.pl",
) )
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) reasoner = ReasonerBase(kb, "confidence", use_zoopt=True)
test_add(reasoner) test_add(reasoner)


print("add_KB with multiple inputs at once:")
print("AddKB with multiple inputs at once:")
multiple_prob = [[ multiple_prob = [[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
@@ -343,7 +343,7 @@ if __name__ == "__main__":
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]] ]]


kb = add_KB()
kb = AddKB()
reasoner = ReasonerBase(kb, "confidence") reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce( res = reasoner.batch_abduce(
multiple_prob, multiple_prob,
@@ -363,7 +363,7 @@ if __name__ == "__main__":
print(res) print(res)
print() print()


class HWF_KB(KBBase):
class HwfKB(KBBase):
def __init__( def __init__(
self, self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
@@ -390,7 +390,7 @@ if __name__ == "__main__":
formula = [mapping[f] for f in formula] formula = [mapping[f] for f in formula]
return eval("".join(formula)) return eval("".join(formula))
class HWF_ground_KB(ground_KB):
class HwfGroundKB(GroundKB):
def __init__( def __init__(
self, self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
@@ -472,35 +472,35 @@ if __name__ == "__main__":
print(res) print(res)
print() print()


print("HWF_KB with GKB, max_err=0.1")
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=0.1)
print("HwfKB with GKB, max_err=0.1")
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming") reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner) test_hwf(reasoner)


print("HWF_KB without GKB, max_err=0.1")
kb = HWF_KB(max_err=0.1)
print("HwfKB without GKB, max_err=0.1")
kb = HwfKB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming") reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner) test_hwf(reasoner)


print("HWF_KB with GKB, max_err=1")
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=1)
print("HwfKB with GKB, max_err=1")
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1)
reasoner = ReasonerBase(kb, "hamming") reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner) test_hwf(reasoner)


print("HWF_KB without GKB, max_err=1")
kb = HWF_KB(max_err=1)
print("HwfKB without GKB, max_err=1")
kb = HwfKB(max_err=1)
reasoner = ReasonerBase(kb, "hamming") reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner) test_hwf(reasoner)


print("HWF_KB with multiple inputs at once:")
kb = HWF_KB(max_err=0.1)
print("HwfKB with multiple inputs at once:")
kb = HwfKB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming") reasoner = ReasonerBase(kb, "hamming")
test_hwf_multiple(reasoner, max_revisions=[1,3,3]) test_hwf_multiple(reasoner, max_revisions=[1,3,3])
print("max_revision is float") print("max_revision is float")
test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9])


class HED_prolog_KB(prolog_KB):
class HedKB(PrologKB):
def __init__(self, pseudo_label_list, pl_file): def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file) super().__init__(pseudo_label_list, pl_file)


@@ -518,7 +518,7 @@ if __name__ == "__main__":
rules = [rule.value for rule in prolog_rules] rules = [rule.value for rule in prolog_rules]
return rules return rules


class HED_Reasoner(ReasonerBase):
class HedReasoner(ReasonerBase):
def __init__(self, kb, dist_func="hamming"): def __init__(self, kb, dist_func="hamming"):
super().__init__(kb, dist_func, use_zoopt=True) super().__init__(kb, dist_func, use_zoopt=True)


@@ -573,11 +573,11 @@ if __name__ == "__main__":
def abduce_rules(self, pred_res): def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res) return self.kb.abduce_rules(pred_res)


kb = HED_prolog_KB(
kb = HedKB(
pseudo_label_list=[1, 0, "+", "="], pseudo_label_list=[1, 0, "+", "="],
pl_file="examples/hed/datasets/learn_add.pl", pl_file="examples/hed/datasets/learn_add.pl",
) )
reasoner = HED_Reasoner(kb)
reasoner = HedReasoner(kb)
consist_exs = [ consist_exs = [
[1, 1, "+", 0, "=", 1, 1], [1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0], [1, "+", 1, "=", 1, 0],
@@ -592,16 +592,16 @@ if __name__ == "__main__":
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"]


print("HED_kb logic forward")
print("HedKB logic forward")
print(kb.logic_forward(consist_exs)) print(kb.logic_forward(consist_exs))
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2))
print() print()
print("HED_kb consist rule")
print("HedKB consist rule")
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules))
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules))
print() print()


print("HED_Reasoner abduce")
print("HedReasoner abduce")
res = reasoner.abduce( res = reasoner.abduce(
[[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs) [[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs)
) )
@@ -616,6 +616,6 @@ if __name__ == "__main__":
print(res) print(res)
print() print()


print("HED_Reasoner abduce rules")
print("HedReasoner abduce rules")
abduced_rules = reasoner.abduce_rules(consist_exs) abduced_rules = reasoner.abduce_rules(consist_exs)
print(abduced_rules) print(abduced_rules)

Loading…
Cancel
Save