|
|
@@ -1,6 +1,6 @@ |
|
|
|
import numpy as np |
|
|
|
from zoopt import Dimension, Objective, Parameter, Opt |
|
|
|
from ..utils.utils import ( |
|
|
|
from abl.utils.utils import ( |
|
|
|
confidence_dist, |
|
|
|
flatten, |
|
|
|
reform_list, |
|
|
@@ -50,7 +50,7 @@ class ReasonerBase: |
|
|
|
self.mapping = mapping |
|
|
|
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) |
|
|
|
|
|
|
|
def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): |
|
|
|
def _get_one_candidate(self, data_sample, candidates): |
|
|
|
""" |
|
|
|
Due to the nondeterminism of abductive reasoning, there could be multiple candidates |
|
|
|
satisfying the knowledge base. When this happens, return one candidate that has the |
|
|
@@ -71,11 +71,11 @@ class ReasonerBase: |
|
|
|
elif len(candidates) == 1: |
|
|
|
return candidates[0] |
|
|
|
else: |
|
|
|
cost_array = self._get_cost_list(pred_pseudo_label, pred_prob, candidates) |
|
|
|
cost_array = self._get_cost_list(data_sample, candidates) |
|
|
|
candidate = candidates[np.argmin(cost_array)] |
|
|
|
return candidate |
|
|
|
|
|
|
|
def _get_cost_list(self, pred_pseudo_label, pred_prob, candidates): |
|
|
|
def _get_cost_list(self, data_sample, candidates): |
|
|
|
""" |
|
|
|
Get the list of costs between each candidate and the given prediction. The list is |
|
|
|
calculated based on one of the following distance functions: |
|
|
@@ -95,15 +95,15 @@ class ReasonerBase: |
|
|
|
Multiple consistent candidates. |
|
|
|
""" |
|
|
|
if self.dist_func == "hamming": |
|
|
|
return hamming_dist(pred_pseudo_label, candidates) |
|
|
|
return hamming_dist(data_sample.pred_pseudo_label, candidates) |
|
|
|
|
|
|
|
elif self.dist_func == "confidence": |
|
|
|
candidates = [[self.remapping[x] for x in c] for c in candidates] |
|
|
|
return confidence_dist(pred_prob, candidates) |
|
|
|
return confidence_dist(data_sample.pred_prob, candidates) |
|
|
|
|
|
|
|
|
|
|
|
def zoopt_get_solution( |
|
|
|
self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num |
|
|
|
self, symbol_num, data_sample, max_revision_num |
|
|
|
): |
|
|
|
""" |
|
|
|
Get the optimal solution using the Zoopt library. The solution is a list of |
|
|
@@ -113,13 +113,8 @@ class ReasonerBase: |
|
|
|
---------- |
|
|
|
symbol_num : int |
|
|
|
Number of total symbols. |
|
|
|
pred_pseudo_label : List[Any] |
|
|
|
Predicted pseudo label. |
|
|
|
pred_prob : List[List[Any]] |
|
|
|
Predicted probabilities of the prediction (Each sublist contains the probability |
|
|
|
distribution over all pseudo labels). |
|
|
|
y : Any |
|
|
|
Ground truth for the logical result. |
|
|
|
data_sample : ListData |
|
|
|
|
|
|
|
max_revision_num : int |
|
|
|
Specifies the maximum number of revisions allowed. |
|
|
|
""" |
|
|
@@ -128,7 +123,7 @@ class ReasonerBase: |
|
|
|
) |
|
|
|
objective = Objective( |
|
|
|
lambda sol: self.zoopt_revision_score( |
|
|
|
symbol_num, pred_pseudo_label, pred_prob, y, sol |
|
|
|
symbol_num, data_sample, sol |
|
|
|
), |
|
|
|
dim=dimension, |
|
|
|
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), |
|
|
@@ -137,15 +132,15 @@ class ReasonerBase: |
|
|
|
solution = Opt.min(objective, parameter).get_x() |
|
|
|
return solution |
|
|
|
|
|
|
|
def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, y, sol): |
|
|
|
def zoopt_revision_score(self, symbol_num, data_sample, sol): |
|
|
|
""" |
|
|
|
Get the revision score for a solution. A lower score suggests that the Zoopt library |
|
|
|
has a higher preference for this solution. |
|
|
|
""" |
|
|
|
revision_idx = np.where(sol.get_x() != 0)[0] |
|
|
|
candidates = self.revise_at_idx(pred_pseudo_label, y, revision_idx) |
|
|
|
candidates = self.revise_at_idx(data_sample, revision_idx) |
|
|
|
if len(candidates) > 0: |
|
|
|
return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates)) |
|
|
|
return np.min(self._get_cost_list(data_sample, candidates)) |
|
|
|
else: |
|
|
|
return symbol_num |
|
|
|
|
|
|
@@ -157,7 +152,7 @@ class ReasonerBase: |
|
|
|
x = solution.get_x() |
|
|
|
return max_revision_num - x.sum() |
|
|
|
|
|
|
|
def revise_at_idx(self, pred_pseudo_label, y, revision_idx): |
|
|
|
def revise_at_idx(self, data_sample, revision_idx): |
|
|
|
""" |
|
|
|
Revise the predicted pseudo label at specified index positions. |
|
|
|
|
|
|
@@ -170,7 +165,15 @@ class ReasonerBase: |
|
|
|
revision_idx : array-like |
|
|
|
Indices of where revisions should be made to the predicted pseudo label. |
|
|
|
""" |
|
|
|
return self.kb.revise_at_idx(pred_pseudo_label, y, revision_idx) |
|
|
|
return self.kb.revise_at_idx(data_sample.pred_pseudo_label, |
|
|
|
data_sample.Y, |
|
|
|
revision_idx) |
|
|
|
|
|
|
|
def abduce_candidates(self, data_sample, max_revision_num, require_more_revision): |
|
|
|
return self.kb.abduce_candidates(data_sample.pred_pseudo_label, |
|
|
|
data_sample.Y, |
|
|
|
max_revision_num, |
|
|
|
require_more_revision) |
|
|
|
|
|
|
|
def _get_max_revision_num(self, max_revision, symbol_num): |
|
|
|
""" |
|
|
@@ -222,22 +225,18 @@ class ReasonerBase: |
|
|
|
symbol_num = data_sample.elements_num("pred_pseudo_label") |
|
|
|
max_revision_num = self._get_max_revision_num(max_revision, symbol_num) |
|
|
|
|
|
|
|
pred_pseudo_label = data_sample.pred_pseudo_label |
|
|
|
pred_prob = data_sample.pred_prob |
|
|
|
y = data_sample.Y |
|
|
|
|
|
|
|
if self.use_zoopt: |
|
|
|
solution = self.zoopt_get_solution( |
|
|
|
symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num |
|
|
|
symbol_num, data_sample, max_revision_num |
|
|
|
) |
|
|
|
revision_idx = np.where(solution != 0)[0] |
|
|
|
candidates = self.revise_at_idx(pred_pseudo_label, y, revision_idx) |
|
|
|
candidates = self.revise_at_idx(data_sample, revision_idx) |
|
|
|
else: |
|
|
|
candidates = self.kb.abduce_candidates( |
|
|
|
pred_pseudo_label, y, max_revision_num, require_more_revision |
|
|
|
candidates = self.abduce_candidates( |
|
|
|
data_sample, max_revision_num, require_more_revision |
|
|
|
) |
|
|
|
|
|
|
|
candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) |
|
|
|
candidate = self._get_one_candidate(data_sample, candidates) |
|
|
|
return candidate |
|
|
|
|
|
|
|
def batch_abduce( |
|
|
@@ -254,15 +253,6 @@ class ReasonerBase: |
|
|
|
data_samples.abduced_pseudo_label = abduced_pseudo_label |
|
|
|
return abduced_pseudo_label |
|
|
|
|
|
|
|
# def _batch_abduce_helper(self, args): |
|
|
|
# z, prob, y, max_revision, require_more_revision = args |
|
|
|
# return self.abduce((z, prob, y), max_revision, require_more_revision) |
|
|
|
|
|
|
|
# def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0): |
|
|
|
# with Pool(processes=os.cpu_count()) as pool: |
|
|
|
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) |
|
|
|
# return results |
|
|
|
|
|
|
|
def __call__( |
|
|
|
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 |
|
|
|
): |
|
|
@@ -486,8 +476,8 @@ if __name__ == "__main__": |
|
|
|
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) |
|
|
|
return len(list(self.prolog.query(pl_query))) != 0 |
|
|
|
|
|
|
|
def abduce_rules(self, pred_res): |
|
|
|
pl_query = "consistent_inst_feature(%s, X)." % pred_res |
|
|
|
def abduce_rules(self, pseudo_labels): |
|
|
|
pl_query = "consistent_inst_feature(%s, X)." % pseudo_labels |
|
|
|
prolog_result = list(self.prolog.query(pl_query)) |
|
|
|
if len(prolog_result) == 0: |
|
|
|
return None |
|
|
@@ -499,32 +489,36 @@ if __name__ == "__main__": |
|
|
|
def __init__(self, kb, dist_func="hamming"): |
|
|
|
super().__init__(kb, dist_func, use_zoopt=True) |
|
|
|
|
|
|
|
def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs): |
|
|
|
pred = [] |
|
|
|
k = [] |
|
|
|
def _revise_at_idxs(self, pseudo_labels, ys, all_revision_flag, idxs): |
|
|
|
data_sample = ListData() |
|
|
|
data_sample.pred_pseudo_label = [] |
|
|
|
data_sample.Y = [] |
|
|
|
revision_flag = [] |
|
|
|
for idx in idxs: |
|
|
|
pred.append(pred_res[idx]) |
|
|
|
k.append(y[idx]) |
|
|
|
data_sample.pred_pseudo_label.append(pseudo_labels[idx]) |
|
|
|
data_sample.Y.append(ys[idx]) |
|
|
|
revision_flag += list(all_revision_flag[idx]) |
|
|
|
revision_idx = np.where(np.array(revision_flag) != 0)[0] |
|
|
|
candidate = self.revise_at_idx(pred, k, revision_idx) |
|
|
|
candidate = self.revise_at_idx(data_sample, revision_idx) |
|
|
|
return candidate |
|
|
|
|
|
|
|
def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): |
|
|
|
all_revision_flag = reform_list(sol.get_x(), pred_res) |
|
|
|
lefted_idxs = [i for i in range(len(pred_res))] |
|
|
|
def zoopt_revision_score(self, symbol_num, data_sample, sol): |
|
|
|
pseudo_labels = data_sample.pred_pseudo_label |
|
|
|
ys = data_sample.Y |
|
|
|
|
|
|
|
all_revision_flag = reform_list(sol.get_x(), pseudo_labels) |
|
|
|
lefted_idxs = [i for i in range(len(pseudo_labels))] |
|
|
|
candidate_size = [] |
|
|
|
while lefted_idxs: |
|
|
|
idxs = [] |
|
|
|
idxs.append(lefted_idxs.pop(0)) |
|
|
|
max_candidate_idxs = [] |
|
|
|
found = False |
|
|
|
for idx in range(-1, len(pred_res)): |
|
|
|
for idx in range(-1, len(pseudo_labels)): |
|
|
|
if (not idx in idxs) and (idx >= 0): |
|
|
|
idxs.append(idx) |
|
|
|
candidate = self._revise_at_idxs( |
|
|
|
pred_res, y, all_revision_flag, idxs |
|
|
|
pseudo_labels, ys, all_revision_flag, idxs |
|
|
|
) |
|
|
|
if len(candidate) == 0: |
|
|
|
if len(idxs) > 1: |
|
|
@@ -547,8 +541,8 @@ if __name__ == "__main__": |
|
|
|
score -= math.exp(-i) * candidate_size[i] |
|
|
|
return score |
|
|
|
|
|
|
|
def abduce_rules(self, pred_res): |
|
|
|
return self.kb.abduce_rules(pred_res) |
|
|
|
def abduce_rules(self, pseudo_labels): |
|
|
|
return self.kb.abduce_rules(pseudo_labels) |
|
|
|
|
|
|
|
kb = HedKB( |
|
|
|
pseudo_label_list=[1, 0, "+", "="], |
|
|
|