Browse Source

[MNT] resolve comments in reasoner.py

ab_data
Gao Enhao 1 year ago
parent
commit
41d52ef6c4
1 changed files with 16 additions and 6 deletions
  1. +16
    -6
      abl/reasoning/reasoner.py

+ 16
- 6
abl/reasoning/reasoner.py View File

@@ -4,8 +4,7 @@ import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter, Solution from zoopt import Dimension, Objective, Opt, Parameter, Solution


from ..structures import ListData from ..structures import ListData
from ..utils.utils import (calculate_revision_num, confidence_dist,
hamming_dist, reform_idx)
from ..utils.utils import calculate_revision_num, confidence_dist, hamming_dist, reform_idx
from .base_kb import BaseKB from .base_kb import BaseKB




@@ -13,7 +12,7 @@ class ReasonerBase:
def __init__( def __init__(
self, self,
kb: BaseKB, kb: BaseKB,
dist_func: str = "hamming",
dist_func: str = "confidence",
mapping: Mapping = None, mapping: Mapping = None,
use_zoopt: bool = False, use_zoopt: bool = False,
): ):
@@ -25,7 +24,7 @@ class ReasonerBase:
kb : BaseKB kb : BaseKB
The knowledge base to be used for reasoning. The knowledge base to be used for reasoning.
dist_func : str, optional dist_func : str, optional
The distance function to be used. Can be "hamming" or "confidence". Default is "hamming".
The distance function to be used. Can be "hamming" or "confidence". Default is "confidence".
mapping : dict, optional mapping : dict, optional
A mapping of indices to labels. If None, a default mapping is generated. A mapping of indices to labels. If None, a default mapping is generated.
use_zoopt : bool, optional use_zoopt : bool, optional
@@ -37,8 +36,8 @@ class ReasonerBase:
If the specified distance function is neither "hamming" nor "confidence". If the specified distance function is neither "hamming" nor "confidence".
""" """


if not (dist_func == "hamming" or dist_func == "confidence"):
raise NotImplementedError # Only hamming or confidence distance is available.
if dist_func not in ["hamming", "confidence"]:
raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.")


self.kb = kb self.kb = kb
self.dist_func = dist_func self.dist_func = dist_func
@@ -46,7 +45,18 @@ class ReasonerBase:
if mapping is None: if mapping is None:
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
else: else:
if not isinstance(mapping, dict):
raise ValueError("mapping must be of type dict")

for key, value in mapping.items():
if not isinstance(key, int):
raise ValueError("All keys in the mapping must be integers")

if value not in self.kb.pseudo_label_list:
raise ValueError("All values in the mapping must be in the pseudo_label_list")

self.mapping = mapping self.mapping = mapping

self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))


def _get_cost_list(self, data_sample: ListData, candidates: List[List[Any]]): def _get_cost_list(self, data_sample: ListData, candidates: List[List[Any]]):


Loading…
Cancel
Save