From 2d834333885f01181e663cb995ddb419cf25cf3a Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Tue, 31 Oct 2023 14:03:05 +0800 Subject: [PATCH] [FIX] fix typo --- abl/reasoning/reasoner.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 278bcc8..e221792 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -20,14 +20,13 @@ class ReasonerBase: dist_func : str, optional The distance function to be used when determining the cost list between each candidate and the given prediction. Valid options include: `"hamming"` (default) - | `"confidence"`. Any other options will raise a `NotImplementedError`. - For detailed explanations of these options, refer to `_get_cost_list`. - mapping : dictt, optional - A mapping from label to index. If not provided, a default - order-based mapping is created. + | `"confidence"`. Any other options will raise a `NotImplementedError`. For + detailed explanations of these options, refer to `_get_cost_list`. + mapping : dict, optional + A mapping from label to index. If not provided, a default order-based mapping is + created. use_zoopt : bool, optional - Whether to use the Zoopt library during abductive reasoning. - Default is False. + Whether to use the Zoopt library during abductive reasoning. Default to False. """ def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False): @@ -42,12 +41,14 @@ class ReasonerBase: label: index for index, label in enumerate(self.kb.pseudo_label_list) } else: + if not isinstance(mapping, dict): + raise TypeError("mapping should be dict") self.mapping = mapping def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): """ Due to the nondeterminism of abductive reasoning, there could be multiple candidates - satisfying the knowledge base. If this happens, return one candidate that has the + satisfying the knowledge base. When this happens, return one candidate that has the minimum cost. If no candidates are provided, an empty list is returned. Parameters @@ -58,7 +59,7 @@ class ReasonerBase: Predicted probabilities of the prediction (Each sublist contains the probability values of all pseudo labels). candidates : List[List[Any]] - Several candidate abduction results. + Multiple candidate abduction results. """ if len(candidates) == 0: return [] @@ -86,7 +87,7 @@ class ReasonerBase: Predicted probabilities of the prediction (Each sublist contains the probability values of all pseudo labels). Used when distance function is "confidence". candidates : List[List[Any]] - Several candidate abduction results. + Multiple candidate abduction results. """ if self.dist_func == "hamming": return hamming_dist(pred_pseudo_label, candidates)