|
|
@@ -21,9 +21,9 @@ class KBBase(ABC): |
|
|
|
pseudo_label_list : list |
|
|
|
List of possible pseudo labels. |
|
|
|
max_err : float, optional |
|
|
|
The upper tolerance limit when comparing the similarity between a candidate result |
|
|
|
and the ground truth. Especially relevant for regression problems where exact matches |
|
|
|
might not be feasible. Default to 0. |
|
|
|
The upper tolerance limit when comparing the similarity between a candidate's logical |
|
|
|
result and the ground truth. Especially relevant for regression problems where exact |
|
|
|
matches might not be feasible. Default to 0. |
|
|
|
use_cache : bool, optional |
|
|
|
Whether to use a cache for previously abduced candidates to speed up subsequent |
|
|
|
operations. Defaults to True. |
|
|
@@ -46,7 +46,8 @@ class KBBase(ABC): |
|
|
|
@abstractmethod |
|
|
|
def logic_forward(self, pseudo_labels): |
|
|
|
""" |
|
|
|
How to perform logical reasoning. Users are required to provide this. |
|
|
|
How to perform (deductive) logical reasoning, i.e. matching each pseudo label to |
|
|
|
their logical result. Users are required to provide this. |
|
|
|
""" |
|
|
|
pass |
|
|
|
|
|
|
@@ -59,7 +60,7 @@ class KBBase(ABC): |
|
|
|
pred_pseudo_label : List[Any] |
|
|
|
Predicted pseudo label. |
|
|
|
y : any |
|
|
|
Ground truth for the result (after passing through the logic part). |
|
|
|
Ground truth for the logical result. |
|
|
|
max_revision_num : int |
|
|
|
The upper limit on the number of revisions. |
|
|
|
require_more_revision : int, optional |
|
|
@@ -89,7 +90,7 @@ class KBBase(ABC): |
|
|
|
pred_pseudo_label : List[Any] |
|
|
|
Predicted pseudo label. |
|
|
|
y : Any |
|
|
|
Ground truth for the result (after passing through the logic part). |
|
|
|
Ground truth for the logical result. |
|
|
|
revision_idx : array-like |
|
|
|
Indices of where revisions should be made to the predicted pseudo label. |
|
|
|
""" |
|
|
@@ -127,7 +128,7 @@ class KBBase(ABC): |
|
|
|
pred_pseudo_label : List[Any] |
|
|
|
Predicted pseudo label. |
|
|
|
y : Any |
|
|
|
Ground truth for the result (after passing through the logic part). |
|
|
|
Ground truth for the logical result. |
|
|
|
max_revision_num : int |
|
|
|
The upper limit on the number of revisions. |
|
|
|
require_more_revision : int |
|
|
@@ -173,11 +174,9 @@ class KBBase(ABC): |
|
|
|
|
|
|
|
class ground_KB(KBBase): |
|
|
|
""" |
|
|
|
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt |
|
|
|
upon class initialization, stroing all potential candidates along with |
|
|
|
their respective results after passing through the logic part. Ground KB can |
|
|
|
enhance the speed of abductive reasoning. For more on this, refer to the |
|
|
|
`abduce_candidates` method in this class. |
|
|
|
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon |
|
|
|
class initialization, stroing all potential candidates along with their respective |
|
|
|
logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
@@ -190,11 +189,11 @@ class ground_KB(KBBase): |
|
|
|
|
|
|
|
Notes |
|
|
|
----- |
|
|
|
Users can also inherit from this class to build their own knowledge base. |
|
|
|
Similar to `KBBase`, users are only required to provide the `pseudo_label_list` |
|
|
|
and override the `logic_forward` function. Additionally, users should provide |
|
|
|
the `GKB_len_list`. After that, other operations (e.g. auto-construction of |
|
|
|
GKB, and how to perform abductive reasoning) will be automatically set up. |
|
|
|
Users can also inherit from this class to build their own knowledge base. Similar |
|
|
|
to `KBBase`, users are only required to provide the `pseudo_label_list` and override |
|
|
|
the `logic_forward` function. Additionally, users should provide the `GKB_len_list`. |
|
|
|
After that, other operations (e.g. auto-construction of GKB, and how to perform |
|
|
|
abductive reasoning) will be automatically set up. |
|
|
|
""" |
|
|
|
def __init__(self, pseudo_label_list, GKB_len_list, max_err=0): |
|
|
|
super().__init__(pseudo_label_list, max_err) |
|
|
@@ -272,32 +271,46 @@ class ground_KB(KBBase): |
|
|
|
else: |
|
|
|
potential_candidates = self.GKB[len(pred_pseudo_label)] |
|
|
|
key_list = list(potential_candidates.keys()) |
|
|
|
key_idx = bisect.bisect_left(key_list, y) |
|
|
|
|
|
|
|
all_candidates = [] |
|
|
|
for idx in range(key_idx - 1, 0, -1): |
|
|
|
k = key_list[idx] |
|
|
|
if abs(k - y) <= self.max_err: |
|
|
|
all_candidates.extend(potential_candidates[k]) |
|
|
|
else: |
|
|
|
break |
|
|
|
|
|
|
|
for idx in range(key_idx, len(key_list)): |
|
|
|
k = key_list[idx] |
|
|
|
if abs(k - y) <= self.max_err: |
|
|
|
all_candidates.extend(potential_candidates[k]) |
|
|
|
else: |
|
|
|
break |
|
|
|
low_key = bisect.bisect_left(key_list, y - self.max_err) |
|
|
|
high_key = bisect.bisect_right(key_list, y + self.max_err) |
|
|
|
|
|
|
|
all_candidates = [candidate |
|
|
|
for key in key_list[low_key:high_key] |
|
|
|
for candidate in potential_candidates[key]] |
|
|
|
return all_candidates |
|
|
|
|
|
|
|
|
|
|
|
class prolog_KB(KBBase): |
|
|
|
def __init__(self, pseudo_label_list, pl_file, max_err=0): |
|
|
|
super().__init__(pseudo_label_list, max_err) |
|
|
|
""" |
|
|
|
Knowledge base given by a prolog (pl) file. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
pseudo_label_list : list |
|
|
|
Refer to class `KBBase`. |
|
|
|
pl_file : |
|
|
|
Prolog file containing the KB. |
|
|
|
max_err : float, optional |
|
|
|
Refer to class `KBBase`. |
|
|
|
|
|
|
|
Notes |
|
|
|
----- |
|
|
|
Users can also inherit from this class to build their own knowledge base. When using |
|
|
|
this class, users are only required to provide the `pl_file`. |
|
|
|
""" |
|
|
|
def __init__(self, pseudo_label_list, pl_file): |
|
|
|
super().__init__(pseudo_label_list) |
|
|
|
self.prolog = pyswip.Prolog() |
|
|
|
self.prolog.consult(pl_file) |
|
|
|
|
|
|
|
def logic_forward(self, pseudo_labels): |
|
|
|
""" |
|
|
|
Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the |
|
|
|
returned `Res` as the logical results. To use this default function, there must be |
|
|
|
a Prolog `log_forward` method in the pl file to perform logical. reasoning. Otherwise, |
|
|
|
users would override this function. |
|
|
|
""" |
|
|
|
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] |
|
|
|
if result == 'true': |
|
|
|
return True |
|
|
@@ -314,11 +327,16 @@ class prolog_KB(KBBase): |
|
|
|
revision_pred_pseudo_label[idx] = 'P' + str(idx) |
|
|
|
revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label) |
|
|
|
|
|
|
|
# TODO:不知道有没有更简洁的方法 |
|
|
|
regex = r"'P\d+'" |
|
|
|
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label)) |
|
|
|
|
|
|
|
def get_query_string(self, pred_pseudo_label, y, revision_idx): |
|
|
|
""" |
|
|
|
Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set |
|
|
|
the returned `Revise_labels` together with the kept labels as the candidates. This is |
|
|
|
a default fuction for demo, users would override this function to adapt to their own |
|
|
|
Prolog file. |
|
|
|
""" |
|
|
|
query_string = "logic_forward(" |
|
|
|
query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx) |
|
|
|
key_is_none_flag = y is None or (type(y) == list and y[0] is None) |
|
|
@@ -326,6 +344,11 @@ class prolog_KB(KBBase): |
|
|
|
return query_string |
|
|
|
|
|
|
|
def revise_at_idx(self, pred_pseudo_label, y, revision_idx): |
|
|
|
""" |
|
|
|
Revise the predicted pseudo label at specified index positions by querying Prolog. |
|
|
|
This is an overridden function. For more information about the parameters, refer to |
|
|
|
the function of the same name in class `KBBase`. |
|
|
|
""" |
|
|
|
candidates = [] |
|
|
|
query_string = self.get_query_string(pred_pseudo_label, y, revision_idx) |
|
|
|
save_pred_pseudo_label = pred_pseudo_label |
|
|
|