Browse Source

[MNT] resolve comments in reasoner.py

ab_data
Gao Enhao 1 year ago
parent
commit
41d52ef6c4
1 changed files with 16 additions and 6 deletions
  1. +16
    -6
      abl/reasoning/reasoner.py

+ 16
- 6
abl/reasoning/reasoner.py View File

@@ -4,8 +4,7 @@ import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter, Solution

from ..structures import ListData
from ..utils.utils import (calculate_revision_num, confidence_dist,
hamming_dist, reform_idx)
from ..utils.utils import calculate_revision_num, confidence_dist, hamming_dist, reform_idx
from .base_kb import BaseKB


@@ -13,7 +12,7 @@ class ReasonerBase:
def __init__(
self,
kb: BaseKB,
dist_func: str = "hamming",
dist_func: str = "confidence",
mapping: Mapping = None,
use_zoopt: bool = False,
):
@@ -25,7 +24,7 @@ class ReasonerBase:
kb : BaseKB
The knowledge base to be used for reasoning.
dist_func : str, optional
The distance function to be used. Can be "hamming" or "confidence". Default is "hamming".
The distance function to be used. Can be "hamming" or "confidence". Default is "confidence".
mapping : dict, optional
A mapping of indices to labels. If None, a default mapping is generated.
use_zoopt : bool, optional
@@ -37,8 +36,8 @@ class ReasonerBase:
If the specified distance function is neither "hamming" nor "confidence".
"""

if not (dist_func == "hamming" or dist_func == "confidence"):
raise NotImplementedError # Only hamming or confidence distance is available.
if dist_func not in ["hamming", "confidence"]:
raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.")

self.kb = kb
self.dist_func = dist_func
@@ -46,7 +45,18 @@ class ReasonerBase:
if mapping is None:
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
else:
if not isinstance(mapping, dict):
raise ValueError("mapping must be of type dict")

for key, value in mapping.items():
if not isinstance(key, int):
raise ValueError("All keys in the mapping must be integers")

if value not in self.kb.pseudo_label_list:
raise ValueError("All values in the mapping must be in the pseudo_label_list")

self.mapping = mapping

self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

def _get_cost_list(self, data_sample: ListData, candidates: List[List[Any]]):


Loading…
Cancel
Save