Browse Source

[MNT] add docstring for class prolog_KB

pull/3/head
troyyyyy 1 year ago
parent
commit
01f00d225e
2 changed files with 65 additions and 42 deletions
  1. +58
    -35
      abl/reasoning/kb.py
  2. +7
    -7
      abl/reasoning/reasoner.py

+ 58
- 35
abl/reasoning/kb.py View File

@@ -21,9 +21,9 @@ class KBBase(ABC):
pseudo_label_list : list
List of possible pseudo labels.
max_err : float, optional
The upper tolerance limit when comparing the similarity between a candidate result
and the ground truth. Especially relevant for regression problems where exact matches
might not be feasible. Default to 0.
The upper tolerance limit when comparing the similarity between a candidate's logical
result and the ground truth. Especially relevant for regression problems where exact
matches might not be feasible. Default to 0.
use_cache : bool, optional
Whether to use a cache for previously abduced candidates to speed up subsequent
operations. Defaults to True.
@@ -46,7 +46,8 @@ class KBBase(ABC):
@abstractmethod
def logic_forward(self, pseudo_labels):
"""
How to perform logical reasoning. Users are required to provide this.
How to perform (deductive) logical reasoning, i.e. matching each pseudo label to
their logical result. Users are required to provide this.
"""
pass

@@ -59,7 +60,7 @@ class KBBase(ABC):
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : any
Ground truth for the result (after passing through the logic part).
Ground truth for the logical result.
max_revision_num : int
The upper limit on the number of revisions.
require_more_revision : int, optional
@@ -89,7 +90,7 @@ class KBBase(ABC):
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : Any
Ground truth for the result (after passing through the logic part).
Ground truth for the logical result.
revision_idx : array-like
Indices of where revisions should be made to the predicted pseudo label.
"""
@@ -127,7 +128,7 @@ class KBBase(ABC):
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : Any
Ground truth for the result (after passing through the logic part).
Ground truth for the logical result.
max_revision_num : int
The upper limit on the number of revisions.
require_more_revision : int
@@ -173,11 +174,9 @@ class KBBase(ABC):
class ground_KB(KBBase):
"""
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt
upon class initialization, stroing all potential candidates along with
their respective results after passing through the logic part. Ground KB can
enhance the speed of abductive reasoning. For more on this, refer to the
`abduce_candidates` method in this class.
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon
class initialization, stroing all potential candidates along with their respective
logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`.

Parameters
----------
@@ -190,11 +189,11 @@ class ground_KB(KBBase):
Notes
-----
Users can also inherit from this class to build their own knowledge base.
Similar to `KBBase`, users are only required to provide the `pseudo_label_list`
and override the `logic_forward` function. Additionally, users should provide
the `GKB_len_list`. After that, other operations (e.g. auto-construction of
GKB, and how to perform abductive reasoning) will be automatically set up.
Users can also inherit from this class to build their own knowledge base. Similar
to `KBBase`, users are only required to provide the `pseudo_label_list` and override
the `logic_forward` function. Additionally, users should provide the `GKB_len_list`.
After that, other operations (e.g. auto-construction of GKB, and how to perform
abductive reasoning) will be automatically set up.
"""
def __init__(self, pseudo_label_list, GKB_len_list, max_err=0):
super().__init__(pseudo_label_list, max_err)
@@ -272,32 +271,46 @@ class ground_KB(KBBase):
else:
potential_candidates = self.GKB[len(pred_pseudo_label)]
key_list = list(potential_candidates.keys())
key_idx = bisect.bisect_left(key_list, y)
all_candidates = []
for idx in range(key_idx - 1, 0, -1):
k = key_list[idx]
if abs(k - y) <= self.max_err:
all_candidates.extend(potential_candidates[k])
else:
break
for idx in range(key_idx, len(key_list)):
k = key_list[idx]
if abs(k - y) <= self.max_err:
all_candidates.extend(potential_candidates[k])
else:
break
low_key = bisect.bisect_left(key_list, y - self.max_err)
high_key = bisect.bisect_right(key_list, y + self.max_err)

all_candidates = [candidate
for key in key_list[low_key:high_key]
for candidate in potential_candidates[key]]
return all_candidates


class prolog_KB(KBBase):
def __init__(self, pseudo_label_list, pl_file, max_err=0):
super().__init__(pseudo_label_list, max_err)
"""
Knowledge base given by a prolog (pl) file.

Parameters
----------
pseudo_label_list : list
Refer to class `KBBase`.
pl_file :
Prolog file containing the KB.
max_err : float, optional
Refer to class `KBBase`.
Notes
-----
Users can also inherit from this class to build their own knowledge base. When using
this class, users are only required to provide the `pl_file`.
"""
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list)
self.prolog = pyswip.Prolog()
self.prolog.consult(pl_file)

def logic_forward(self, pseudo_labels):
"""
Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the
returned `Res` as the logical results. To use this default function, there must be
a Prolog `log_forward` method in the pl file to perform logical. reasoning. Otherwise,
users would override this function.
"""
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res']
if result == 'true':
return True
@@ -314,11 +327,16 @@ class prolog_KB(KBBase):
revision_pred_pseudo_label[idx] = 'P' + str(idx)
revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label)
# TODO:不知道有没有更简洁的方法
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label))
def get_query_string(self, pred_pseudo_label, y, revision_idx):
"""
Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set
the returned `Revise_labels` together with the kept labels as the candidates. This is
a default fuction for demo, users would override this function to adapt to their own
Prolog file.
"""
query_string = "logic_forward("
query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx)
key_is_none_flag = y is None or (type(y) == list and y[0] is None)
@@ -326,6 +344,11 @@ class prolog_KB(KBBase):
return query_string
def revise_at_idx(self, pred_pseudo_label, y, revision_idx):
"""
Revise the predicted pseudo label at specified index positions by querying Prolog.
This is an overridden function. For more information about the parameters, refer to
the function of the same name in class `KBBase`.
"""
candidates = []
query_string = self.get_query_string(pred_pseudo_label, y, revision_idx)
save_pred_pseudo_label = pred_pseudo_label


+ 7
- 7
abl/reasoning/reasoner.py View File

@@ -1,6 +1,6 @@
import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import (
from abl.utils.utils import (
confidence_dist,
flatten,
reform_idx,
@@ -60,7 +60,7 @@ class ReasonerBase:
Predicted probabilities of the prediction (Each sublist contains the probability
distribution over all pseudo labels).
candidates : List[List[Any]]
Multiple candidate abduction results.
Multiple consistent candidates.
"""
if len(candidates) == 0:
return []
@@ -88,7 +88,7 @@ class ReasonerBase:
Predicted probabilities of the prediction (Each sublist contains the probability
distribution over all pseudo labels). Used when distance function is "confidence".
candidates : List[List[Any]]
Multiple candidate abduction results.
Multiple consistent candidates.
"""
if self.dist_func == "hamming":
return hamming_dist(pred_pseudo_label, candidates)
@@ -115,7 +115,7 @@ class ReasonerBase:
Predicted probabilities of the prediction (Each sublist contains the probability
distribution over all pseudo labels).
y : Any
Ground truth for the result (after passing through the logic part).
Ground truth for the logical result.
max_revision_num : int
Specifies the maximum number of revisions allowed.
"""
@@ -162,7 +162,7 @@ class ReasonerBase:
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : Any
Ground truth for the result (after passing through the logic part).
Ground truth for the logical result.
revision_idx : array-like
Indices of where revisions should be made to the predicted pseudo label.
"""
@@ -182,7 +182,7 @@ class ReasonerBase:
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : Any
Ground truth for the result (after passing through the logic part).
Ground truth for the logical result.
max_revision : int or float, optional
The upper limit on the number of revisions. If float, denotes the fraction of the
total length that can be revised. A value of -1 implies no restriction on the number
@@ -456,7 +456,7 @@ if __name__ == "__main__":
print()

print("HWF_KB with GKB, max_err=0.1")
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5, 7], max_err=0.1)
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)



Loading…
Cancel
Save