|
|
@@ -32,9 +32,9 @@ class Reasoner: |
|
|
|
in this cost list should be a numerical value representing the cost for each |
|
|
|
candidate, and the list should have the same length as candidates. |
|
|
|
Defaults to 'confidence'. |
|
|
|
mapping : Optional[dict], optional |
|
|
|
idx_to_label : Optional[dict], optional |
|
|
|
A mapping from index in the base model to label. If not provided, a default |
|
|
|
order-based mapping is created. Defaults to None. |
|
|
|
order-based index to label mapping is created. Defaults to None. |
|
|
|
max_revision : Union[int, float], optional |
|
|
|
The upper limit on the number of revisions for each data sample when |
|
|
|
performing abductive reasoning. If float, denotes the fraction of the total |
|
|
@@ -51,7 +51,7 @@ class Reasoner: |
|
|
|
self, |
|
|
|
kb: KBBase, |
|
|
|
dist_func: Union[str, Callable] = "confidence", |
|
|
|
mapping: Optional[dict] = None, |
|
|
|
idx_to_label: Optional[dict] = None, |
|
|
|
max_revision: Union[int, float] = -1, |
|
|
|
require_more_revision: int = 0, |
|
|
|
use_zoopt: bool = False, |
|
|
@@ -63,12 +63,12 @@ class Reasoner: |
|
|
|
self.max_revision = max_revision |
|
|
|
self.require_more_revision = require_more_revision |
|
|
|
|
|
|
|
if mapping is None: |
|
|
|
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} |
|
|
|
if idx_to_label is None: |
|
|
|
self.idx_to_label = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} |
|
|
|
else: |
|
|
|
self._check_valid_mapping(mapping) |
|
|
|
self.mapping = mapping |
|
|
|
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) |
|
|
|
self._check_valid_idx_to_label(idx_to_label) |
|
|
|
self.idx_to_label = idx_to_label |
|
|
|
self.label_to_idx = dict(zip(self.idx_to_label.values(), self.idx_to_label.keys())) |
|
|
|
|
|
|
|
def _check_valid_dist(self, dist_func): |
|
|
|
if isinstance(dist_func, str): |
|
|
@@ -87,15 +87,15 @@ class Reasoner: |
|
|
|
f"dist_func must be a string or a callable function, but got {type(dist_func)}." |
|
|
|
) |
|
|
|
|
|
|
|
def _check_valid_mapping(self, mapping): |
|
|
|
if not isinstance(mapping, dict): |
|
|
|
raise TypeError(f"mapping should be dict, but got {type(mapping)}.") |
|
|
|
for key, value in mapping.items(): |
|
|
|
def _check_valid_idx_to_label(self, idx_to_label):
    """
    Validate a user-supplied index-to-label mapping.

    Parameters
    ----------
    idx_to_label : dict
        Mapping to validate. Every key must be an ``int`` and every value
        must be a member of ``self.kb.pseudo_label_list``.

    Raises
    ------
    TypeError
        If ``idx_to_label`` is not a dict.
    ValueError
        If any key is not an integer, or any value is not contained in
        ``self.kb.pseudo_label_list``.
    """
    if not isinstance(idx_to_label, dict):
        raise TypeError(f"idx_to_label should be dict, but got {type(idx_to_label)}.")

    for key, value in idx_to_label.items():
        # Keys are checked before values, so a bad key is reported first.
        if not isinstance(key, int):
            raise ValueError(f"All keys in the idx_to_label must be integers, but got {key}.")
        if value not in self.kb.pseudo_label_list:
            raise ValueError(
                f"All values in the idx_to_label must be in the pseudo_label_list, but got {value}."
            )
|
|
|
|
|
|
|
def _get_one_candidate( |
|
|
@@ -158,10 +158,10 @@ class Reasoner: |
|
|
|
if self.dist_func == "hamming": |
|
|
|
return hamming_dist(data_sample.pred_pseudo_label, candidates) |
|
|
|
elif self.dist_func == "confidence": |
|
|
|
candidates = [[self.remapping[x] for x in c] for c in candidates] |
|
|
|
candidates = [[self.label_to_idx[x] for x in c] for c in candidates] |
|
|
|
return confidence_dist(data_sample.pred_prob, candidates) |
|
|
|
else: |
|
|
|
candidate_idxs = [[self.remapping[x] for x in c] for c in candidates] |
|
|
|
candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] |
|
|
|
cost_list = self.dist_func(data_sample, candidates, candidate_idxs, reasoning_results) |
|
|
|
if len(cost_list) != len(candidates): |
|
|
|
raise ValueError( |
|
|
|