|
|
@@ -3,8 +3,7 @@ from typing import Any, List, Mapping, Optional |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from ..structures import ListData |
|
|
|
from ..utils import (Cache, calculate_revision_num, confidence_dist, |
|
|
|
hamming_dist) |
|
|
|
from ..utils import Cache, calculate_revision_num, confidence_dist, hamming_dist |
|
|
|
from .base_kb import BaseKB |
|
|
|
from .search_engine import BFS, BaseSearchEngine |
|
|
|
|
|
|
@@ -81,70 +80,53 @@ class ReasonerBase: |
|
|
|
else: |
|
|
|
key_func = lambda x: x |
|
|
|
self.cache = Cache[ListData, List[List[Any]]]( |
|
|
|
func=self.abduce, |
|
|
|
func=self.abduce_candidates, |
|
|
|
cache=self.use_cache, |
|
|
|
cache_file=self.cache_file, |
|
|
|
key_func=key_func, |
|
|
|
max_size=cache_size, |
|
|
|
) |
|
|
|
|
|
|
|
def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]): |
|
|
|
def abduce( |
|
|
|
self, |
|
|
|
data_sample: ListData, |
|
|
|
max_revision: int = -1, |
|
|
|
require_more_revision: int = 0, |
|
|
|
): |
|
|
|
""" |
|
|
|
Get the list of costs between each pseudo label and candidate. |
|
|
|
Perform revision by abduction on the given data. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
pred_pseudo_label : list |
|
|
|
The pseudo label to be used for computing costs of candidates. |
|
|
|
pred_prob : list |
|
|
|
Probabilities of the predictions. Used when distance function is "confidence". |
|
|
|
candidates : list |
|
|
|
List of candidate abduction result. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
numpy.ndarray |
|
|
|
Array of computed costs for each candidate. |
|
|
|
""" |
|
|
|
if self.dist_func == "hamming": |
|
|
|
return hamming_dist(data_sample["pred_pseudo_label"][0], candidates) |
|
|
|
|
|
|
|
elif self.dist_func == "confidence": |
|
|
|
candidates = [[self.remapping[x] for x in c] for c in candidates] |
|
|
|
return confidence_dist(data_sample["pred_prob"][0], candidates) |
|
|
|
|
|
|
|
def select(self, data_sample: ListData, candidates: List[List[Any]]): |
|
|
|
""" |
|
|
|
Get one candidate. If multiple candidates exist, return the one with minimum cost. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
List of probabilities for predicted results. |
|
|
|
pred_pseudo_label : list |
|
|
|
The pseudo label to be used for selecting a candidate. |
|
|
|
pred_prob : list |
|
|
|
Probabilities of the predictions. |
|
|
|
candidates : list |
|
|
|
List of candidate abduction result. |
|
|
|
List of predicted pseudo labels. |
|
|
|
y : any |
|
|
|
Ground truth for the predicted results. |
|
|
|
max_revision : int or float, optional |
|
|
|
Maximum number of revisions to use. If float, represents the fraction of total revisions to use. |
|
|
|
If -1, any revisions are allowed. Defaults to -1. |
|
|
|
require_more_revision : int, optional |
|
|
|
Number of additional revisions to require. Defaults to 0. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
list |
|
|
|
The chosen candidate based on minimum cost. |
|
|
|
If no candidates, an empty list is returned. |
|
|
|
The abduced revisions. |
|
|
|
""" |
|
|
|
if len(candidates) == 0: |
|
|
|
return [] |
|
|
|
elif len(candidates) == 1: |
|
|
|
return candidates[0] |
|
|
|
else: |
|
|
|
cost_array = self._get_dist_list(data_sample, candidates) |
|
|
|
candidate = candidates[np.argmin(cost_array)] |
|
|
|
return candidate |
|
|
|
symbol_num = data_sample.elements_num("pred_pseudo_label") |
|
|
|
max_revision_num = calculate_revision_num(max_revision, symbol_num) |
|
|
|
data_sample.set_metainfo(dict(symbol_num=symbol_num)) |
|
|
|
|
|
|
|
def abduce( |
|
|
|
candidates = self.cache.get(data_sample, max_revision_num, require_more_revision) |
|
|
|
candidate = self.select_one_candidate(data_sample, candidates) |
|
|
|
return candidate |
|
|
|
|
|
|
|
def abduce_candidates( |
|
|
|
self, |
|
|
|
data_sample: ListData, |
|
|
|
max_revision: int = -1, |
|
|
|
max_revision_num: int = -1, |
|
|
|
require_more_revision: int = 0, |
|
|
|
): |
|
|
|
""" |
|
|
@@ -169,9 +151,6 @@ class ReasonerBase: |
|
|
|
list |
|
|
|
The abduced revisions. |
|
|
|
""" |
|
|
|
symbol_num = data_sample.elements_num("pred_pseudo_label") |
|
|
|
max_revision_num = calculate_revision_num(max_revision, symbol_num) |
|
|
|
data_sample.set_metainfo(dict(symbol_num=symbol_num)) |
|
|
|
|
|
|
|
if hasattr(self.kb, "abduce_candidates"): |
|
|
|
candidates = self.kb.abduce_candidates( |
|
|
@@ -198,9 +177,60 @@ class ReasonerBase: |
|
|
|
raise NotImplementedError( |
|
|
|
"The kb should either implement abduce_candidates or revise_at_idx." |
|
|
|
) |
|
|
|
return candidates |
|
|
|
|
|
|
|
candidate = self.select(data_sample, candidates) |
|
|
|
return candidate |
|
|
|
def select_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]): |
|
|
|
""" |
|
|
|
Get one candidate. If multiple candidates exist, return the one with minimum cost. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
pred_pseudo_label : list |
|
|
|
The pseudo label to be used for selecting a candidate. |
|
|
|
pred_prob : list |
|
|
|
Probabilities of the predictions. |
|
|
|
candidates : list |
|
|
|
List of candidate abduction result. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
list |
|
|
|
The chosen candidate based on minimum cost. |
|
|
|
If no candidates, an empty list is returned. |
|
|
|
""" |
|
|
|
if len(candidates) == 0: |
|
|
|
return [] |
|
|
|
elif len(candidates) == 1: |
|
|
|
return candidates[0] |
|
|
|
else: |
|
|
|
cost_array = self._get_dist_list(data_sample, candidates) |
|
|
|
candidate = candidates[np.argmin(cost_array)] |
|
|
|
return candidate |
|
|
|
|
|
|
|
def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]): |
|
|
|
""" |
|
|
|
Get the list of costs between each pseudo label and candidate. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
pred_pseudo_label : list |
|
|
|
The pseudo label to be used for computing costs of candidates. |
|
|
|
pred_prob : list |
|
|
|
Probabilities of the predictions. Used when distance function is "confidence". |
|
|
|
candidates : list |
|
|
|
List of candidate abduction result. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
numpy.ndarray |
|
|
|
Array of computed costs for each candidate. |
|
|
|
""" |
|
|
|
if self.dist_func == "hamming": |
|
|
|
return hamming_dist(data_sample["pred_pseudo_label"][0], candidates) |
|
|
|
|
|
|
|
elif self.dist_func == "confidence": |
|
|
|
candidates = [[self.remapping[x] for x in c] for c in candidates] |
|
|
|
return confidence_dist(data_sample["pred_prob"][0], candidates) |
|
|
|
|
|
|
|
def batch_abduce( |
|
|
|
self, |
|
|
@@ -231,7 +261,7 @@ class ReasonerBase: |
|
|
|
The abduced revisions in batches. |
|
|
|
""" |
|
|
|
abduced_pseudo_label = [ |
|
|
|
self.cache.get( |
|
|
|
self.abduce( |
|
|
|
data_sample, |
|
|
|
max_revision=max_revision, |
|
|
|
require_more_revision=require_more_revision, |
|
|
|