Browse Source

[FIX] change return parameter type of get_cost_list

pull/1/head
troyyyyy 1 year ago
parent
commit
e8b3d06517
1 changed files with 6 additions and 6 deletions
  1. +6
    -6
      abl/reasoning/reasoner.py

+ 6
- 6
abl/reasoning/reasoner.py View File

@@ -17,7 +17,7 @@ class Reasoner:
----------
kb : class KBBase
The knowledge base to be used for reasoning.
dist_func : str or Callable, optional
dist_func : Union[str, Callable], optional
The distance function used to determine the cost list between each
candidate and the given prediction. The cost is also referred to as a consistency
measure, wherein the candidate with lowest cost is selected as the final
@@ -35,7 +35,7 @@ class Reasoner:
mapping : Optional[dict], optional
A mapping from index in the base model to label. If not provided, a default
order-based mapping is created. Defaults to None.
max_revision : int or float, optional
max_revision : Union[int, float], optional
The upper limit on the number of revisions for each data sample when
performing abductive reasoning. If float, denotes the fraction of the total
length that can be revised. A value of -1 implies no restriction on the
@@ -137,7 +137,7 @@ class Reasoner:
data_sample: ListData,
candidates: List[List[Any]],
reasoning_results: List[Any],
) -> np.ndarray:
) -> Union[List[Union[int, float]], np.ndarray]:
"""
Get the list of costs between each candidate and the given data sample.

@@ -152,8 +152,8 @@ class Reasoner:

Returns
-------
np.ndarray
A Numpy array representing list of costs.
Union[List[Union[int, float]], np.ndarray]
The list of costs.
"""
if self.dist_func == "hamming":
return hamming_dist(data_sample.pred_pseudo_label, candidates)
@@ -244,7 +244,7 @@ class Reasoner:
x = solution.get_x()
return max_revision_num - x.sum()

def _get_max_revision_num(self, max_revision: int or float, symbol_num: int) -> int:
def _get_max_revision_num(self, max_revision: Union[int, float], symbol_num: int) -> int:
"""
Get the maximum revision number according to input `max_revision`.
"""


Loading…
Cancel
Save