Browse Source

[MNT]add user-defined dist func, add detailed parameters

pull/1/head
troyyyyy 1 year ago
parent
commit
d9d0ad9f6e
2 changed files with 185 additions and 76 deletions
  1. +98
    -37
      abl/reasoning/kb.py
  2. +87
    -39
      abl/reasoning/reasoner.py

+ 98
- 37
abl/reasoning/kb.py View File

@@ -1,11 +1,12 @@
import bisect
import os
import inspect
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from itertools import combinations, product
from multiprocessing import Pool
import inspect
import logging
from typing import Callable, Any, List, Optional

import numpy as np
import pyswip
@@ -33,7 +34,7 @@ class KBBase(ABC):
use_cache : bool, optional
Whether to use abl_cache for previously abduced candidates to speed up subsequent
operations. Defaults to True.
key_func : func, optional
key_func : Callable, optional
A function employed for hashing in abl_cache. This is only operational when use_cache
is set to True. Defaults to to_hashable.
cache_size: int, optional
@@ -51,11 +52,11 @@ class KBBase(ABC):

def __init__(
self,
pseudo_label_list,
max_err=1e-10,
use_cache=True,
key_func=to_hashable,
cache_size=4096,
pseudo_label_list: list,
max_err: float = 1e-10,
use_cache: bool = True,
key_func: Callable = to_hashable,
cache_size: int = 4096,
):
if not isinstance(pseudo_label_list, list):
raise TypeError(f"pseudo_label_list should be list, got {type(pseudo_label_list)}")
@@ -79,7 +80,7 @@ class KBBase(ABC):
# TODO 添加consistency measure+max_err容忍错误

@abstractmethod
def logic_forward(self, pseudo_label, x = None):
def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any:
"""
How to perform (deductive) logical reasoning, i.e. matching each pseudo label sample to
their reasoning result. Users are required to provide this.
@@ -88,9 +89,25 @@ class KBBase(ABC):
----------
pseudo_label : List[Any]
Pseudo label sample.
x : Optional[List[Any]]
The corresponding input sample. If deductive logical reasoning does not require any
information from the input, the overridden function provided by the user can omit
this parameter.
Returns
-------
Any
The reasoning result.
"""

def abduce_candidates(self, pseudo_label, y, x, max_revision_num, require_more_revision):
def abduce_candidates(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
max_revision_num: int,
require_more_revision: int,
) -> List[List[Any]]:
"""
Perform abductive reasoning to get a candidate compatible with the knowledge base.

@@ -98,7 +115,7 @@ class KBBase(ABC):
----------
pseudo_label : List[Any]
Pseudo label sample (to be revised by abductive reasoning).
y : any
y : Any
Ground truth of the reasoning result for the sample.
x : List[Any]
The corresponding input sample.
@@ -115,7 +132,7 @@ class KBBase(ABC):
"""
return self._abduce_by_search(pseudo_label, y, x, max_revision_num, require_more_revision)

def _check_equal(self, logic_result, y):
def _check_equal(self, reasoning_result: Any, y: Any) -> bool:
"""
Check whether the reasoning result of a pseduo label sample is equal to the ground truth
(or, within the maximum error allowed for numerical results).
@@ -125,15 +142,21 @@ class KBBase(ABC):
bool
The result of the check.
"""
if logic_result is None:
if reasoning_result is None:
return False

if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)):
return abs(logic_result - y) <= self.max_err
if isinstance(reasoning_result, (int, float)) and isinstance(y, (int, float)):
return abs(reasoning_result - y) <= self.max_err
else:
return logic_result == y

def revise_at_idx(self, pseudo_label, y, x, revision_idx):
return reasoning_result == y

def revise_at_idx(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
revision_idx: List[int],
) -> List[List[Any]]:
"""
Revise the pseudo label sample at specified index positions.

@@ -145,8 +168,8 @@ class KBBase(ABC):
Ground truth of the reasoning result for the sample.
x : List[Any]
The corresponding input sample.
revision_idx : array-like
Indices of where revisions should be made to the pseudo label sample.
revision_idx : List[int]
A list specifying indices of where revisions should be made to the pseudo label sample.

Returns
-------
@@ -164,7 +187,13 @@ class KBBase(ABC):
candidates.append(candidate)
return candidates

def _revision(self, revision_num, pseudo_label, y, x):
def _revision(
self,
revision_num: int,
pseudo_label: List[Any],
y: Any,
x: List[Any],
) -> List[List[Any]]:
"""
For a specified number of labels in a pseudo label sample to revise, iterate through
all possible indices to find any candidates that are compatible with the knowledge base.
@@ -178,7 +207,14 @@ class KBBase(ABC):
return new_candidates

@abl_cache()
def _abduce_by_search(self, pseudo_label, y, x, max_revision_num, require_more_revision):
def _abduce_by_search(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
max_revision_num: int,
require_more_revision: int,
) -> List[List[Any]]:
"""
Perform abductive reasoning by exhastive search. Specifically, begin with 0 and
continuously increase the number of labels in a pseudo label sample to revise, until
@@ -302,7 +338,14 @@ class GroundKB(KBBase):
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
return X, Y

def abduce_candidates(self, pseudo_label, y, x, max_revision_num, require_more_revision):
def abduce_candidates(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
max_revision_num: int,
require_more_revision: int,
) -> List[List[Any]]:
"""
Perform abductive reasoning by directly retrieving compatible candidates from
the prebuilt GKB. In this way, the time-consuming exhaustive search can be
@@ -312,13 +355,13 @@ class GroundKB(KBBase):
----------
pseudo_label : List[Any]
Pseudo label sample (to be revised by abductive reasoning).
y : any
y : Any
Ground truth of the reasoning result for the sample.
x : List[Any]
The corresponding input sample (unused in GroundKB).
max_revision_num : int
The upper limit on the number of revised labels for each sample.
require_more_revision : int, optional
require_more_revision : int
Specifies additional number of revisions permitted beyond the minimum required.

Returns
@@ -341,7 +384,7 @@ class GroundKB(KBBase):
candidates = [all_candidates[idx] for idx in idxs]
return candidates

def _find_candidate_GKB(self, pseudo_label, y):
def _find_candidate_GKB(self, pseudo_label: List[Any], y: Any) -> List[List[Any]]:
"""
Retrieve compatible candidates from the prebuilt GKB. For numerical reasoning results,
return all candidates whose reasoning results fall within the
@@ -408,7 +451,7 @@ class PrologKB(KBBase):
class. Users are also welcome to override related functions for more flexible support.
"""

def __init__(self, pseudo_label_list, pl_file):
def __init__(self, pseudo_label_list: List[Any], pl_file: str):
super().__init__(pseudo_label_list)
self.pl_file = pl_file
self.prolog = pyswip.Prolog()
@@ -417,7 +460,7 @@ class PrologKB(KBBase):
raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.")
self.prolog.consult(self.pl_file)

def logic_forward(self, pseudo_labels):
def logic_forward(self, pseudo_label: List[Any]) -> Any:
"""
Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the
returned `Res` as the reasoning results. To use this default function, there must be
@@ -429,14 +472,18 @@ class PrologKB(KBBase):
pseudo_label : List[Any]
Pseudo label sample.
"""
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]["Res"]
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_label))[0]["Res"]
if result == "true":
return True
elif result == "false":
return False
return result

def _revision_pseudo_label(self, pseudo_label, revision_idx):
def _revision_pseudo_label(
self,
pseudo_label: List[Any],
revision_idx: List[int],
) -> List[Any]:
import re

revision_pseudo_label = pseudo_label.copy()
@@ -449,7 +496,13 @@ class PrologKB(KBBase):
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pseudo_label))

def get_query_string(self, pseudo_label, y, revision_idx):
def get_query_string(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
revision_idx: List[int],
) -> str:
"""
Get the query to be used for consulting Prolog.
This is a default function for demo, users would override this function to adapt to
@@ -460,10 +513,12 @@ class PrologKB(KBBase):
----------
pseudo_label : List[Any]
Pseudo label sample (to be revised by abductive reasoning).
y : any
y : Any
Ground truth of the reasoning result for the sample.
revision_idx : array-like
Indices of where revisions should be made to the pseudo label sample.
x : List[Any]
The corresponding input sample.
revision_idx : List[int]
A list specifying indices of where revisions should be made to the pseudo label sample.

Returns
-------
@@ -476,7 +531,13 @@ class PrologKB(KBBase):
query_string += ",%s)." % y if not key_is_none_flag else ")."
return query_string

def revise_at_idx(self, pseudo_label, y, x, revision_idx):
def revise_at_idx(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
revision_idx: List[int],
) -> List[List[Any]]:
"""
Revise the pseudo label sample at specified index positions by querying Prolog.

@@ -488,8 +549,8 @@ class PrologKB(KBBase):
Ground truth of the reasoning result for the sample.
x : List[Any]
The corresponding input sample.
revision_idx : array-like
Indices of where revisions should be made to the pseudo label sample.
revision_idx : List[int]
A list specifying indices of where revisions should be made to the pseudo label sample.

Returns
-------


+ 87
- 39
abl/reasoning/reasoner.py View File

@@ -1,6 +1,9 @@
import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter
from typing import Callable, Any, List, Optional

from kb import KBBase
from ..structures import ListData
from ..utils.utils import confidence_dist, hamming_dist


@@ -12,16 +15,12 @@ class Reasoner:
----------
kb : class KBBase
The knowledge base to be used for reasoning.
dist_func : str, optional
dist_func : str or Callable, optional
The distance function to be used when determining the cost list between each
candidate and the given prediction. Valid options include: "confidence" (default) |
"hamming". For "confidence", it calculates the distance between the prediction
and the candidate based on confidence derived from the predicted probabilities in the
data sample. For "hamming", it directly calculates the Hamming distance between
the predicted pseudo label sample and the candidate.
mapping : dict, optional
candidate and the given prediction. Defaults to "confidence".
mapping : Optional[dict], optional
A mapping from index in the base model to label. If not provided, a default
order-based mapping is created.
order-based mapping is created. Defaults to None.
max_revision : int or float, optional
The upper limit on the number of revisions for each data sample when
performing abductive reasoning. If float, denotes the fraction of the total
@@ -36,18 +35,13 @@ class Reasoner:

def __init__(
self,
kb,
dist_func="confidence",
mapping=None,
max_revision=-1,
require_more_revision=0,
use_zoopt=False,
kb: KBBase,
dist_func: str or Callable = "confidence",
mapping: Optional[dict] = None,
max_revision: int or float = -1,
require_more_revision: int = 0,
use_zoopt: bool = False,
):
if dist_func not in ["hamming", "confidence"]:
raise NotImplementedError(
'Valid options for dist_func include "hamming" and "confidence"'
)

self.kb = kb
self.dist_func = dist_func
self.use_zoopt = use_zoopt
@@ -57,17 +51,24 @@ class Reasoner:
if mapping is None:
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
else:
if not isinstance(mapping, dict):
raise TypeError(f"mapping should be dict, got {type(mapping)}")
for key, value in mapping.items():
if not isinstance(key, int):
raise ValueError(f"All keys in the mapping must be integers, got {key}")
if value not in self.kb.pseudo_label_list:
raise ValueError(f"All values in the mapping must be in the pseudo_label_list, got {value}")
self._check_valid_mapping(mapping)
self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

def _get_one_candidate(self, data_sample, candidates):
def _check_valid_mapping(self, mapping):
if not isinstance(mapping, dict):
raise TypeError(f"mapping should be dict, got {type(mapping)}")
for key, value in mapping.items():
if not isinstance(key, int):
raise ValueError(f"All keys in the mapping must be integers, got {key}")
if value not in self.kb.pseudo_label_list:
raise ValueError(f"All values in the mapping must be in the pseudo_label_list, got {value}")
def _get_one_candidate(
self,
data_sample: ListData,
candidates: List[List[Any]],
) -> List[Any]:
"""
Due to the nondeterminism of abductive reasoning, there could be multiple candidates
satisfying the knowledge base. When this happens, return one candidate that has the
@@ -90,13 +91,19 @@ class Reasoner:
elif len(candidates) == 1:
return candidates[0]
else:
cost_array = self._get_cost_list(data_sample, candidates)
cost_array = self.get_cost_list(data_sample, candidates)
candidate = candidates[np.argmin(cost_array)]
return candidate

def _get_cost_list(self, data_sample, candidates):
def get_cost_list(
self,
data_sample: ListData,
candidates: List[List[Any]],
) -> np.ndarray:
"""
Get the list of costs between each candidate and the given data sample. The list is
Get the list of costs between each candidate and the given data sample.
The list is
calculated based on one of the following distance functions:
- "hamming": Directly calculates the Hamming distance between the predicted pseudo
label in the data sample and candidate.
@@ -110,6 +117,11 @@ class Reasoner:
Data sample.
candidates : List[List[Any]]
Multiple compatible candidates.
Returns
-------
np.ndarray
A Numpy array representing list of costs.
"""
if self.dist_func == "hamming":
return hamming_dist(data_sample.pred_pseudo_label, candidates)
@@ -117,8 +129,20 @@ class Reasoner:
elif self.dist_func == "confidence":
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(data_sample.pred_prob, candidates)
elif callable(self.dist_func):
return self.dist_func(data_sample, candidates)

def zoopt_get_solution(self, symbol_num, data_sample, max_revision_num):
else:
raise ValueError("dist_func must be either a string or a callable function")


def _zoopt_get_solution(
self,
symbol_num: int,
data_sample: ListData,
max_revision_num: int,
) -> List[bool]:
"""
Get the optimal solution using ZOOpt library. The solution is a list of
boolean values, where '1' (True) indicates the indices chosen to be revised.
@@ -131,6 +155,11 @@ class Reasoner:
Data sample.
max_revision_num : int
Specifies the maximum number of revisions allowed.
Returns
-------
List[bool]
The solution for ZOOpt library.
"""
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
objective = Objective(
@@ -142,21 +171,40 @@ class Reasoner:
solution = Opt.min(objective, parameter).get_x()
return solution

def zoopt_revision_score(self, symbol_num, data_sample, sol):
def zoopt_revision_score(
self,
symbol_num: int,
data_sample: ListData,
sol: List[bool],
) -> int:
"""
Get the revision score for a solution. A lower score suggests that ZOOpt library
has a higher preference for this solution.
Parameters
----------
symbol_num : int
Number of total symbols.
data_sample : ListData
Data sample.
sol: List[bool]
The solution for ZOOpt library.
Returns
-------
int
The revision score for the solution.
"""
revision_idx = np.where(sol.get_x() != 0)[0]
candidates = self.kb.revise_at_idx(
data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx
)
if len(candidates) > 0:
return np.min(self._get_cost_list(data_sample, candidates))
return np.min(self.get_cost_list(data_sample, candidates))
else:
return symbol_num

def _constrain_revision_num(self, solution, max_revision_num):
def _constrain_revision_num(self, solution: List[bool], max_revision_num: int) -> int:
"""
Constrain that the total number of revisions chosen by the solution does not exceed
maximum number of revisions allowed.
@@ -164,7 +212,7 @@ class Reasoner:
x = solution.get_x()
return max_revision_num - x.sum()

def _get_max_revision_num(self, max_revision, symbol_num):
def _get_max_revision_num(self, max_revision: int or float, symbol_num: int) -> int:
"""
Get the maximum revision number according to input `max_revision`.
"""
@@ -182,7 +230,7 @@ class Reasoner:
raise ValueError(f"If max_revision is an int, it must be non-negative, but got {max_revision}")
return max_revision

def abduce(self, data_sample):
def abduce(self, data_sample: ListData) -> List[Any]:
"""
Perform abductive reasoning on the given data sample.

@@ -201,7 +249,7 @@ class Reasoner:
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)

if self.use_zoopt:
solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num)
solution = self._zoopt_get_solution(symbol_num, data_sample, max_revision_num)
revision_idx = np.where(solution != 0)[0]
candidates = self.kb.revise_at_idx(
data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx
@@ -218,7 +266,7 @@ class Reasoner:
candidate = self._get_one_candidate(data_sample, candidates)
return candidate

def batch_abduce(self, data_samples):
def batch_abduce(self, data_samples: ListData) -> List[List[Any]]:
"""
Perform abductive reasoning on the given prediction data samples.
For detailed information, refer to `abduce`.
@@ -227,5 +275,5 @@ class Reasoner:
data_samples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label

def __call__(self, data_samples):
def __call__(self, data_samples: ListData) -> List[List[Any]]:
return self.batch_abduce(data_samples)

Loading…
Cancel
Save