|
|
@@ -4,8 +4,7 @@ import numpy as np |
|
|
|
from zoopt import Dimension, Objective, Opt, Parameter, Solution |
|
|
|
|
|
|
|
from ..structures import ListData |
|
|
|
from ..utils.utils import (calculate_revision_num, confidence_dist, |
|
|
|
hamming_dist, reform_idx) |
|
|
|
from ..utils.utils import calculate_revision_num, confidence_dist, hamming_dist, reform_idx |
|
|
|
from .base_kb import BaseKB |
|
|
|
|
|
|
|
|
|
|
@@ -13,7 +12,7 @@ class ReasonerBase: |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
kb: BaseKB, |
|
|
|
dist_func: str = "hamming", |
|
|
|
dist_func: str = "confidence", |
|
|
|
mapping: Mapping = None, |
|
|
|
use_zoopt: bool = False, |
|
|
|
): |
|
|
@@ -25,7 +24,7 @@ class ReasonerBase: |
|
|
|
kb : BaseKB |
|
|
|
The knowledge base to be used for reasoning. |
|
|
|
dist_func : str, optional |
|
|
|
The distance function to be used. Can be "hamming" or "confidence". Default is "hamming". |
|
|
|
The distance function to be used. Can be "hamming" or "confidence". Default is "confidence". |
|
|
|
mapping : dict, optional |
|
|
|
A mapping of indices to labels. If None, a default mapping is generated. |
|
|
|
use_zoopt : bool, optional |
|
|
@@ -37,8 +36,8 @@ class ReasonerBase: |
|
|
|
If the specified distance function is neither "hamming" nor "confidence". |
|
|
|
""" |
|
|
|
|
|
|
|
if not (dist_func == "hamming" or dist_func == "confidence"): |
|
|
|
raise NotImplementedError # Only hamming or confidence distance is available. |
|
|
|
if dist_func not in ["hamming", "confidence"]: |
|
|
|
raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.") |
|
|
|
|
|
|
|
self.kb = kb |
|
|
|
self.dist_func = dist_func |
|
|
@@ -46,7 +45,18 @@ class ReasonerBase: |
|
|
|
if mapping is None: |
|
|
|
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} |
|
|
|
else: |
|
|
|
if not isinstance(mapping, dict): |
|
|
|
raise ValueError("mapping must be of type dict") |
|
|
|
|
|
|
|
for key, value in mapping.items(): |
|
|
|
if not isinstance(key, int): |
|
|
|
raise ValueError("All keys in the mapping must be integers") |
|
|
|
|
|
|
|
if value not in self.kb.pseudo_label_list: |
|
|
|
raise ValueError("All values in the mapping must be in the pseudo_label_list") |
|
|
|
|
|
|
|
self.mapping = mapping |
|
|
|
|
|
|
|
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) |
|
|
|
|
|
|
|
def _get_cost_list(self, data_sample: ListData, candidates: List[List[Any]]): |
|
|
|