diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 644a24d..54f77d3 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -17,7 +17,7 @@ class Reasoner: ---------- kb : class KBBase The knowledge base to be used for reasoning. - dist_func : str or Callable, optional + dist_func : Union[str, Callable], optional The distance function used to determine the cost list between each candidate and the given prediction. The cost is also referred to as a consistency measure, wherein the candidate with lowest cost is selected as the final @@ -35,7 +35,7 @@ class Reasoner: mapping : Optional[dict], optional A mapping from index in the base model to label. If not provided, a default order-based mapping is created. Defaults to None. - max_revision : int or float, optional + max_revision : Union[int, float], optional The upper limit on the number of revisions for each data sample when performing abductive reasoning. If float, denotes the fraction of the total length that can be revised. A value of -1 implies no restriction on the @@ -137,7 +137,7 @@ class Reasoner: data_sample: ListData, candidates: List[List[Any]], reasoning_results: List[Any], - ) -> np.ndarray: + ) -> Union[List[Union[int, float]], np.ndarray]: """ Get the list of costs between each candidate and the given data sample. @@ -152,8 +152,8 @@ class Reasoner: Returns ------- - np.ndarray - A Numpy array representing list of costs. + Union[List[Union[int, float]], np.ndarray] + The list of costs. """ if self.dist_func == "hamming": return hamming_dist(data_sample.pred_pseudo_label, candidates) @@ -244,7 +244,7 @@ class Reasoner: x = solution.get_x() return max_revision_num - x.sum() - def _get_max_revision_num(self, max_revision: int or float, symbol_num: int) -> int: + def _get_max_revision_num(self, max_revision: Union[int, float], symbol_num: int) -> int: """ Get the maximum revision number according to input `max_revision`. """