|
|
@@ -24,17 +24,35 @@ class ReasonerBase: |
|
|
|
mapping : dict, optional |
|
|
|
A mapping from index in the base model to label. If not provided, a default |
|
|
|
order-based mapping is created. |
|
|
|
max_revision : int or float, optional |
|
|
|
The upper limit on the number of revisions for each data sample when |
|
|
|
performing abductive reasoning. If float, denotes the fraction of the total |
|
|
|
length that can be revised. A value of -1 implies no restriction on the |
|
|
|
number of revisions. Defaults to -1. |
|
|
|
require_more_revision : int, optional |
|
|
|
Specifies additional number of revisions permitted beyond the minimum required |
|
|
|
when performing abductive reasoning. Defaults to 0. |
|
|
|
use_zoopt : bool, optional |
|
|
|
Whether to use the Zoopt library during abductive reasoning. Defaults to False. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, kb, dist_func="confidence", mapping=None, use_zoopt=False): |
|
|
|
def __init__(self, |
|
|
|
kb, |
|
|
|
dist_func="confidence", |
|
|
|
mapping=None, |
|
|
|
max_revision=-1, |
|
|
|
require_more_revision=0, |
|
|
|
use_zoopt=False, |
|
|
|
): |
|
|
|
if dist_func not in ["hamming", "confidence"]: |
|
|
|
raise NotImplementedError("Valid options for dist_func include \"hamming\" and \"confidence\"") |
|
|
|
|
|
|
|
self.kb = kb |
|
|
|
self.dist_func = dist_func |
|
|
|
self.use_zoopt = use_zoopt |
|
|
|
self.max_revision = max_revision |
|
|
|
self.require_more_revision = require_more_revision |
|
|
|
|
|
|
|
if mapping is None: |
|
|
|
self.mapping = { |
|
|
|
index: label for index, label in enumerate(self.kb.pseudo_label_list) |
|
|
@@ -117,9 +135,7 @@ class ReasonerBase: |
|
|
|
size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num |
|
|
|
) |
|
|
|
objective = Objective( |
|
|
|
lambda sol: self.zoopt_revision_score( |
|
|
|
symbol_num, data_sample, sol |
|
|
|
), |
|
|
|
lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol), |
|
|
|
dim=dimension, |
|
|
|
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), |
|
|
|
) |
|
|
@@ -133,7 +149,9 @@ class ReasonerBase: |
|
|
|
has a higher preference for this solution. |
|
|
|
""" |
|
|
|
revision_idx = np.where(sol.get_x() != 0)[0] |
|
|
|
candidates = self.revise_at_idx(data_sample, revision_idx) |
|
|
|
candidates = self.kb.revise_at_idx(data_sample.pred_pseudo_label, |
|
|
|
data_sample.Y, |
|
|
|
revision_idx) |
|
|
|
if len(candidates) > 0: |
|
|
|
return np.min(self._get_cost_list(data_sample, candidates)) |
|
|
|
else: |
|
|
@@ -146,27 +164,6 @@ class ReasonerBase: |
|
|
|
""" |
|
|
|
x = solution.get_x() |
|
|
|
return max_revision_num - x.sum() |
|
|
|
|
|
|
|
def revise_at_idx(self, data_sample, revision_idx): |
|
|
|
""" |
|
|
|
Revise the pseudo label in the data sample at specified index positions. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
data_sample : ListData |
|
|
|
Data sample. |
|
|
|
revision_idx : array-like |
|
|
|
Indices of where revisions should be made to the predicted pseudo label. |
|
|
|
""" |
|
|
|
return self.kb.revise_at_idx(data_sample.pred_pseudo_label, |
|
|
|
data_sample.Y, |
|
|
|
revision_idx) |
|
|
|
|
|
|
|
def abduce_candidates(self, data_sample, max_revision_num, require_more_revision): |
|
|
|
return self.kb.abduce_candidates(data_sample.pred_pseudo_label, |
|
|
|
data_sample.Y, |
|
|
|
max_revision_num, |
|
|
|
require_more_revision) |
|
|
|
|
|
|
|
def _get_max_revision_num(self, max_revision, symbol_num): |
|
|
|
""" |
|
|
@@ -186,9 +183,7 @@ class ReasonerBase: |
|
|
|
raise ValueError("If max_revision is an int, it must be non-negative.") |
|
|
|
return max_revision |
|
|
|
|
|
|
|
def abduce( |
|
|
|
self, data_sample, max_revision=-1, require_more_revision=0 |
|
|
|
): |
|
|
|
def abduce(self, data_sample): |
|
|
|
""" |
|
|
|
Perform abductive reasoning on the given data sample. |
|
|
|
|
|
|
@@ -196,14 +191,7 @@ class ReasonerBase: |
|
|
|
---------- |
|
|
|
data_sample : ListData |
|
|
|
Data sample. |
|
|
|
max_revision : int or float, optional |
|
|
|
The upper limit on the number of revisions. If float, denotes the fraction of the |
|
|
|
total length that can be revised. A value of -1 implies no restriction on the number |
|
|
|
of revisions. Defaults to -1. |
|
|
|
require_more_revision : int, optional |
|
|
|
Specifies additional number of revisions permitted beyond the minimum required. |
|
|
|
Defaults to 0. |
|
|
|
|
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
List[Any] |
|
|
@@ -211,39 +199,33 @@ class ReasonerBase: |
|
|
|
knowledge base. |
|
|
|
""" |
|
|
|
symbol_num = data_sample.elements_num("pred_pseudo_label") |
|
|
|
max_revision_num = self._get_max_revision_num(max_revision, symbol_num) |
|
|
|
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) |
|
|
|
|
|
|
|
if self.use_zoopt: |
|
|
|
solution = self.zoopt_get_solution( |
|
|
|
symbol_num, data_sample, max_revision_num |
|
|
|
) |
|
|
|
solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num) |
|
|
|
revision_idx = np.where(solution != 0)[0] |
|
|
|
candidates = self.revise_at_idx(data_sample, revision_idx) |
|
|
|
candidates = self.self.kb.revise_at_idx(data_sample.pred_pseudo_label, |
|
|
|
data_sample.Y, |
|
|
|
revision_idx) |
|
|
|
else: |
|
|
|
candidates = self.abduce_candidates( |
|
|
|
data_sample, max_revision_num, require_more_revision |
|
|
|
) |
|
|
|
|
|
|
|
candidates = self.kb.abduce_candidates(data_sample.pred_pseudo_label, |
|
|
|
data_sample.Y, |
|
|
|
max_revision_num, |
|
|
|
self.require_more_revision) |
|
|
|
|
|
|
|
candidate = self._get_one_candidate(data_sample, candidates) |
|
|
|
return candidate |
|
|
|
|
|
|
|
def batch_abduce( |
|
|
|
self, data_samples, max_revision=-1, require_more_revision=0 |
|
|
|
): |
|
|
|
def batch_abduce(self, data_samples): |
|
|
|
""" |
|
|
|
Perform abductive reasoning on the given prediction data samples. |
|
|
|
For detailed information, refer to `abduce`. |
|
|
|
""" |
|
|
|
abduced_pseudo_label = [ |
|
|
|
self.abduce(data_sample, max_revision, require_more_revision) |
|
|
|
for data_sample in data_samples |
|
|
|
self.abduce(data_sample) for data_sample in data_samples |
|
|
|
] |
|
|
|
data_samples.abduced_pseudo_label = abduced_pseudo_label |
|
|
|
return abduced_pseudo_label |
|
|
|
|
|
|
|
def __call__( |
|
|
|
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 |
|
|
|
): |
|
|
|
return self.batch_abduce( |
|
|
|
pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision |
|
|
|
) |
|
|
|
def __call__(self, data_samples): |
|
|
|
return self.batch_abduce(data_samples) |