|
|
@@ -1,11 +1,12 @@ |
|
|
|
import bisect |
|
|
|
import os |
|
|
|
import inspect |
|
|
|
import logging |
|
|
|
from abc import ABC, abstractmethod |
|
|
|
from collections import defaultdict |
|
|
|
from itertools import combinations, product |
|
|
|
from multiprocessing import Pool |
|
|
|
import inspect |
|
|
|
import logging |
|
|
|
from typing import Callable, Any, List, Optional |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import pyswip |
|
|
@@ -33,7 +34,7 @@ class KBBase(ABC): |
|
|
|
use_cache : bool, optional |
|
|
|
Whether to use abl_cache for previously abduced candidates to speed up subsequent |
|
|
|
operations. Defaults to True. |
|
|
|
key_func : func, optional |
|
|
|
key_func : Callable, optional |
|
|
|
A function employed for hashing in abl_cache. This is only operational when use_cache |
|
|
|
is set to True. Defaults to to_hashable. |
|
|
|
cache_size: int, optional |
|
|
@@ -51,11 +52,11 @@ class KBBase(ABC): |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
pseudo_label_list, |
|
|
|
max_err=1e-10, |
|
|
|
use_cache=True, |
|
|
|
key_func=to_hashable, |
|
|
|
cache_size=4096, |
|
|
|
pseudo_label_list: list, |
|
|
|
max_err: float = 1e-10, |
|
|
|
use_cache: bool = True, |
|
|
|
key_func: Callable = to_hashable, |
|
|
|
cache_size: int = 4096, |
|
|
|
): |
|
|
|
if not isinstance(pseudo_label_list, list): |
|
|
|
raise TypeError(f"pseudo_label_list should be list, got {type(pseudo_label_list)}") |
|
|
@@ -79,7 +80,7 @@ class KBBase(ABC): |
|
|
|
# TODO: add a consistency measure and use max_err as the error tolerance |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
def logic_forward(self, pseudo_label, x = None): |
|
|
|
def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any: |
|
|
|
""" |
|
|
|
How to perform (deductive) logical reasoning, i.e. matching each pseudo label sample to |
|
|
|
their reasoning result. Users are required to provide this. |
|
|
@@ -88,9 +89,25 @@ class KBBase(ABC): |
|
|
|
---------- |
|
|
|
pseudo_label : List[Any] |
|
|
|
Pseudo label sample. |
|
|
|
x : Optional[List[Any]] |
|
|
|
The corresponding input sample. If deductive logical reasoning does not require any |
|
|
|
information from the input, the overridden function provided by the user can omit |
|
|
|
this parameter. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
Any |
|
|
|
The reasoning result. |
|
|
|
""" |
|
|
|
|
|
|
|
def abduce_candidates(self, pseudo_label, y, x, max_revision_num, require_more_revision): |
|
|
|
def abduce_candidates( |
|
|
|
self, |
|
|
|
pseudo_label: List[Any], |
|
|
|
y: Any, |
|
|
|
x: List[Any], |
|
|
|
max_revision_num: int, |
|
|
|
require_more_revision: int, |
|
|
|
) -> List[List[Any]]: |
|
|
|
""" |
|
|
|
Perform abductive reasoning to get a candidate compatible with the knowledge base. |
|
|
|
|
|
|
@@ -98,7 +115,7 @@ class KBBase(ABC): |
|
|
|
---------- |
|
|
|
pseudo_label : List[Any] |
|
|
|
Pseudo label sample (to be revised by abductive reasoning). |
|
|
|
y : any |
|
|
|
y : Any |
|
|
|
Ground truth of the reasoning result for the sample. |
|
|
|
x : List[Any] |
|
|
|
The corresponding input sample. |
|
|
@@ -115,7 +132,7 @@ class KBBase(ABC): |
|
|
|
""" |
|
|
|
return self._abduce_by_search(pseudo_label, y, x, max_revision_num, require_more_revision) |
|
|
|
|
|
|
|
def _check_equal(self, logic_result, y): |
|
|
|
def _check_equal(self, reasoning_result: Any, y: Any) -> bool: |
|
|
|
""" |
|
|
|
Check whether the reasoning result of a pseudo label sample is equal to the ground truth |
|
|
|
(or, within the maximum error allowed for numerical results). |
|
|
@@ -125,15 +142,21 @@ class KBBase(ABC): |
|
|
|
bool |
|
|
|
The result of the check. |
|
|
|
""" |
|
|
|
if logic_result is None: |
|
|
|
if reasoning_result is None: |
|
|
|
return False |
|
|
|
|
|
|
|
if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)): |
|
|
|
return abs(logic_result - y) <= self.max_err |
|
|
|
if isinstance(reasoning_result, (int, float)) and isinstance(y, (int, float)): |
|
|
|
return abs(reasoning_result - y) <= self.max_err |
|
|
|
else: |
|
|
|
return logic_result == y |
|
|
|
|
|
|
|
def revise_at_idx(self, pseudo_label, y, x, revision_idx): |
|
|
|
return reasoning_result == y |
|
|
|
|
|
|
|
def revise_at_idx( |
|
|
|
self, |
|
|
|
pseudo_label: List[Any], |
|
|
|
y: Any, |
|
|
|
x: List[Any], |
|
|
|
revision_idx: List[int], |
|
|
|
) -> List[List[Any]]: |
|
|
|
""" |
|
|
|
Revise the pseudo label sample at specified index positions. |
|
|
|
|
|
|
@@ -145,8 +168,8 @@ class KBBase(ABC): |
|
|
|
Ground truth of the reasoning result for the sample. |
|
|
|
x : List[Any] |
|
|
|
The corresponding input sample. |
|
|
|
revision_idx : array-like |
|
|
|
Indices of where revisions should be made to the pseudo label sample. |
|
|
|
revision_idx : List[int] |
|
|
|
A list specifying indices of where revisions should be made to the pseudo label sample. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
@@ -164,7 +187,13 @@ class KBBase(ABC): |
|
|
|
candidates.append(candidate) |
|
|
|
return candidates |
|
|
|
|
|
|
|
def _revision(self, revision_num, pseudo_label, y, x): |
|
|
|
def _revision( |
|
|
|
self, |
|
|
|
revision_num: int, |
|
|
|
pseudo_label: List[Any], |
|
|
|
y: Any, |
|
|
|
x: List[Any], |
|
|
|
) -> List[List[Any]]: |
|
|
|
""" |
|
|
|
For a specified number of labels in a pseudo label sample to revise, iterate through |
|
|
|
all possible indices to find any candidates that are compatible with the knowledge base. |
|
|
@@ -178,7 +207,14 @@ class KBBase(ABC): |
|
|
|
return new_candidates |
|
|
|
|
|
|
|
@abl_cache() |
|
|
|
def _abduce_by_search(self, pseudo_label, y, x, max_revision_num, require_more_revision): |
|
|
|
def _abduce_by_search( |
|
|
|
self, |
|
|
|
pseudo_label: List[Any], |
|
|
|
y: Any, |
|
|
|
x: List[Any], |
|
|
|
max_revision_num: int, |
|
|
|
require_more_revision: int, |
|
|
|
) -> List[List[Any]]: |
|
|
|
""" |
|
|
|
Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and |
|
|
|
continuously increase the number of labels in a pseudo label sample to revise, until |
|
|
@@ -302,7 +338,14 @@ class GroundKB(KBBase): |
|
|
|
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) |
|
|
|
return X, Y |
|
|
|
|
|
|
|
def abduce_candidates(self, pseudo_label, y, x, max_revision_num, require_more_revision): |
|
|
|
def abduce_candidates( |
|
|
|
self, |
|
|
|
pseudo_label: List[Any], |
|
|
|
y: Any, |
|
|
|
x: List[Any], |
|
|
|
max_revision_num: int, |
|
|
|
require_more_revision: int, |
|
|
|
) -> List[List[Any]]: |
|
|
|
""" |
|
|
|
Perform abductive reasoning by directly retrieving compatible candidates from |
|
|
|
the prebuilt GKB. In this way, the time-consuming exhaustive search can be |
|
|
@@ -312,13 +355,13 @@ class GroundKB(KBBase): |
|
|
|
---------- |
|
|
|
pseudo_label : List[Any] |
|
|
|
Pseudo label sample (to be revised by abductive reasoning). |
|
|
|
y : any |
|
|
|
y : Any |
|
|
|
Ground truth of the reasoning result for the sample. |
|
|
|
x : List[Any] |
|
|
|
The corresponding input sample (unused in GroundKB). |
|
|
|
max_revision_num : int |
|
|
|
The upper limit on the number of revised labels for each sample. |
|
|
|
require_more_revision : int, optional |
|
|
|
require_more_revision : int |
|
|
|
Specifies additional number of revisions permitted beyond the minimum required. |
|
|
|
|
|
|
|
Returns |
|
|
@@ -341,7 +384,7 @@ class GroundKB(KBBase): |
|
|
|
candidates = [all_candidates[idx] for idx in idxs] |
|
|
|
return candidates |
|
|
|
|
|
|
|
def _find_candidate_GKB(self, pseudo_label, y): |
|
|
|
def _find_candidate_GKB(self, pseudo_label: List[Any], y: Any) -> List[List[Any]]: |
|
|
|
""" |
|
|
|
Retrieve compatible candidates from the prebuilt GKB. For numerical reasoning results, |
|
|
|
return all candidates whose reasoning results fall within the |
|
|
@@ -408,7 +451,7 @@ class PrologKB(KBBase): |
|
|
|
class. Users are also welcome to override related functions for more flexible support. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, pseudo_label_list, pl_file): |
|
|
|
def __init__(self, pseudo_label_list: List[Any], pl_file: str): |
|
|
|
super().__init__(pseudo_label_list) |
|
|
|
self.pl_file = pl_file |
|
|
|
self.prolog = pyswip.Prolog() |
|
|
@@ -417,7 +460,7 @@ class PrologKB(KBBase): |
|
|
|
raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.") |
|
|
|
self.prolog.consult(self.pl_file) |
|
|
|
|
|
|
|
def logic_forward(self, pseudo_labels): |
|
|
|
def logic_forward(self, pseudo_label: List[Any]) -> Any: |
|
|
|
""" |
|
|
|
Consult prolog with the query `logic_forward(pseudo_label, Res).`, and set the |
|
|
|
returned `Res` as the reasoning results. To use this default function, there must be |
|
|
@@ -429,14 +472,18 @@ class PrologKB(KBBase): |
|
|
|
pseudo_label : List[Any] |
|
|
|
Pseudo label sample. |
|
|
|
""" |
|
|
|
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]["Res"] |
|
|
|
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_label))[0]["Res"] |
|
|
|
if result == "true": |
|
|
|
return True |
|
|
|
elif result == "false": |
|
|
|
return False |
|
|
|
return result |
|
|
|
|
|
|
|
def _revision_pseudo_label(self, pseudo_label, revision_idx): |
|
|
|
def _revision_pseudo_label( |
|
|
|
self, |
|
|
|
pseudo_label: List[Any], |
|
|
|
revision_idx: List[int], |
|
|
|
) -> List[Any]: |
|
|
|
import re |
|
|
|
|
|
|
|
revision_pseudo_label = pseudo_label.copy() |
|
|
@@ -449,7 +496,13 @@ class PrologKB(KBBase): |
|
|
|
regex = r"'P\d+'" |
|
|
|
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pseudo_label)) |
|
|
|
|
|
|
|
def get_query_string(self, pseudo_label, y, revision_idx): |
|
|
|
def get_query_string( |
|
|
|
self, |
|
|
|
pseudo_label: List[Any], |
|
|
|
y: Any, |
|
|
|
x: List[Any], |
|
|
|
revision_idx: List[int], |
|
|
|
) -> str: |
|
|
|
""" |
|
|
|
Get the query to be used for consulting Prolog. |
|
|
|
This is a default function for demo, users would override this function to adapt to |
|
|
@@ -460,10 +513,12 @@ class PrologKB(KBBase): |
|
|
|
---------- |
|
|
|
pseudo_label : List[Any] |
|
|
|
Pseudo label sample (to be revised by abductive reasoning). |
|
|
|
y : any |
|
|
|
y : Any |
|
|
|
Ground truth of the reasoning result for the sample. |
|
|
|
revision_idx : array-like |
|
|
|
Indices of where revisions should be made to the pseudo label sample. |
|
|
|
x : List[Any] |
|
|
|
The corresponding input sample. |
|
|
|
revision_idx : List[int] |
|
|
|
A list specifying indices of where revisions should be made to the pseudo label sample. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
@@ -476,7 +531,13 @@ class PrologKB(KBBase): |
|
|
|
query_string += ",%s)." % y if not key_is_none_flag else ")." |
|
|
|
return query_string |
|
|
|
|
|
|
|
def revise_at_idx(self, pseudo_label, y, x, revision_idx): |
|
|
|
def revise_at_idx( |
|
|
|
self, |
|
|
|
pseudo_label: List[Any], |
|
|
|
y: Any, |
|
|
|
x: List[Any], |
|
|
|
revision_idx: List[int], |
|
|
|
) -> List[List[Any]]: |
|
|
|
""" |
|
|
|
Revise the pseudo label sample at specified index positions by querying Prolog. |
|
|
|
|
|
|
@@ -488,8 +549,8 @@ class PrologKB(KBBase): |
|
|
|
Ground truth of the reasoning result for the sample. |
|
|
|
x : List[Any] |
|
|
|
The corresponding input sample. |
|
|
|
revision_idx : array-like |
|
|
|
Indices of where revisions should be made to the pseudo label sample. |
|
|
|
revision_idx : List[int] |
|
|
|
A list specifying indices of where revisions should be made to the pseudo label sample. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|