diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 724a8e4..7bd4944 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -4,8 +4,7 @@ import numpy as np from zoopt import Dimension, Objective, Opt, Parameter, Solution from ..structures import ListData -from ..utils.utils import (calculate_revision_num, confidence_dist, - hamming_dist, reform_idx) +from ..utils.utils import calculate_revision_num, confidence_dist, hamming_dist, reform_idx from .base_kb import BaseKB @@ -13,7 +12,7 @@ class ReasonerBase: def __init__( self, kb: BaseKB, - dist_func: str = "hamming", + dist_func: str = "confidence", mapping: Mapping = None, use_zoopt: bool = False, ): @@ -25,7 +24,7 @@ class ReasonerBase: kb : BaseKB The knowledge base to be used for reasoning. dist_func : str, optional - The distance function to be used. Can be "hamming" or "confidence". Default is "hamming". + The distance function to be used. Can be "hamming" or "confidence". Default is "confidence". mapping : dict, optional A mapping of indices to labels. If None, a default mapping is generated. use_zoopt : bool, optional @@ -37,8 +36,8 @@ class ReasonerBase: If the specified distance function is neither "hamming" nor "confidence". """ - if not (dist_func == "hamming" or dist_func == "confidence"): - raise NotImplementedError # Only hamming or confidence distance is available. + if dist_func not in ["hamming", "confidence"]: + raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.") self.kb = kb self.dist_func = dist_func @@ -46,7 +45,18 @@ class ReasonerBase: if mapping is None: self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} else: + if not isinstance(mapping, dict): + raise ValueError("mapping must be of type dict") + + for key, value in mapping.items(): + if not isinstance(key, int): + raise ValueError("All keys in the mapping must be integers") + + if value not in self.kb.pseudo_label_list: + raise ValueError("All values in the mapping must be in the pseudo_label_list") + self.mapping = mapping + self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) def _get_cost_list(self, data_sample: ListData, candidates: List[List[Any]]):