|
|
@@ -17,7 +17,7 @@ class Reasoner: |
|
|
|
---------- |
|
|
|
kb : class KBBase |
|
|
|
The knowledge base to be used for reasoning. |
|
|
|
dist_func : str or Callable, optional |
|
|
|
dist_func : Union[str, Callable], optional |
|
|
|
The distance function used to determine the cost list between each |
|
|
|
candidate and the given prediction. The cost is also referred to as a consistency |
|
|
|
measure, wherein the candidate with lowest cost is selected as the final |
|
|
@@ -35,7 +35,7 @@ class Reasoner: |
|
|
|
mapping : Optional[dict], optional |
|
|
|
A mapping from index in the base model to label. If not provided, a default |
|
|
|
order-based mapping is created. Defaults to None. |
|
|
|
max_revision : int or float, optional |
|
|
|
max_revision : Union[int, float], optional |
|
|
|
The upper limit on the number of revisions for each data sample when |
|
|
|
performing abductive reasoning. If float, denotes the fraction of the total |
|
|
|
length that can be revised. A value of -1 implies no restriction on the |
|
|
@@ -137,7 +137,7 @@ class Reasoner: |
|
|
|
data_sample: ListData, |
|
|
|
candidates: List[List[Any]], |
|
|
|
reasoning_results: List[Any], |
|
|
|
) -> np.ndarray: |
|
|
|
) -> Union[List[Union[int, float]], np.ndarray]: |
|
|
|
""" |
|
|
|
Get the list of costs between each candidate and the given data sample. |
|
|
|
|
|
|
@@ -152,8 +152,8 @@ class Reasoner: |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
np.ndarray |
|
|
|
A Numpy array representing list of costs. |
|
|
|
Union[List[Union[int, float]], np.ndarray] |
|
|
|
The list of costs. |
|
|
|
""" |
|
|
|
if self.dist_func == "hamming": |
|
|
|
return hamming_dist(data_sample.pred_pseudo_label, candidates) |
|
|
@@ -244,7 +244,7 @@ class Reasoner: |
|
|
|
x = solution.get_x() |
|
|
|
return max_revision_num - x.sum() |
|
|
|
|
|
|
|
def _get_max_revision_num(self, max_revision: int or float, symbol_num: int) -> int: |
|
|
|
def _get_max_revision_num(self, max_revision: Union[int, float], symbol_num: int) -> int: |
|
|
|
""" |
|
|
|
Get the maximum revision number according to input `max_revision`. |
|
|
|
""" |
|
|
|