Browse Source

[FIX] change return parameter type of get_cost_list

pull/1/head
troyyyyy 1 year ago
parent
commit
e8b3d06517
1 changed files with 6 additions and 6 deletions
  1. +6
    -6
      abl/reasoning/reasoner.py

+ 6
- 6
abl/reasoning/reasoner.py View File

@@ -17,7 +17,7 @@ class Reasoner:
---------- ----------
kb : class KBBase kb : class KBBase
The knowledge base to be used for reasoning. The knowledge base to be used for reasoning.
dist_func : str or Callable, optional
dist_func : Union[str, Callable], optional
The distance function used to determine the cost list between each The distance function used to determine the cost list between each
candidate and the given prediction. The cost is also referred to as a consistency candidate and the given prediction. The cost is also referred to as a consistency
measure, wherein the candidate with lowest cost is selected as the final measure, wherein the candidate with lowest cost is selected as the final
@@ -35,7 +35,7 @@ class Reasoner:
mapping : Optional[dict], optional mapping : Optional[dict], optional
A mapping from index in the base model to label. If not provided, a default A mapping from index in the base model to label. If not provided, a default
order-based mapping is created. Defaults to None. order-based mapping is created. Defaults to None.
max_revision : int or float, optional
max_revision : Union[int, float], optional
The upper limit on the number of revisions for each data sample when The upper limit on the number of revisions for each data sample when
performing abductive reasoning. If float, denotes the fraction of the total performing abductive reasoning. If float, denotes the fraction of the total
length that can be revised. A value of -1 implies no restriction on the length that can be revised. A value of -1 implies no restriction on the
@@ -137,7 +137,7 @@ class Reasoner:
data_sample: ListData, data_sample: ListData,
candidates: List[List[Any]], candidates: List[List[Any]],
reasoning_results: List[Any], reasoning_results: List[Any],
) -> np.ndarray:
) -> Union[List[Union[int, float]], np.ndarray]:
""" """
Get the list of costs between each candidate and the given data sample. Get the list of costs between each candidate and the given data sample.


@@ -152,8 +152,8 @@ class Reasoner:


Returns Returns
------- -------
np.ndarray
A Numpy array representing list of costs.
Union[List[Union[int, float]], np.ndarray]
The list of costs.
""" """
if self.dist_func == "hamming": if self.dist_func == "hamming":
return hamming_dist(data_sample.pred_pseudo_label, candidates) return hamming_dist(data_sample.pred_pseudo_label, candidates)
@@ -244,7 +244,7 @@ class Reasoner:
x = solution.get_x() x = solution.get_x()
return max_revision_num - x.sum() return max_revision_num - x.sum()


def _get_max_revision_num(self, max_revision: int or float, symbol_num: int) -> int:
def _get_max_revision_num(self, max_revision: Union[int, float], symbol_num: int) -> int:
""" """
Get the maximum revision number according to input `max_revision`. Get the maximum revision number according to input `max_revision`.
""" """


Loading…
Cancel
Save