Browse Source

[ENH] change parameters passing in reasoning

pull/1/head
troyyyyy 1 year ago
parent
commit
bf04dd9c95
1 changed files with 47 additions and 53 deletions
  1. +47
    -53
      abl/reasoning/reasoner.py

+ 47
- 53
abl/reasoning/reasoner.py View File

@@ -1,6 +1,6 @@
import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import (
from abl.utils.utils import (
confidence_dist,
flatten,
reform_list,
@@ -50,7 +50,7 @@ class ReasonerBase:
self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates):
def _get_one_candidate(self, data_sample, candidates):
"""
Due to the nondeterminism of abductive reasoning, there could be multiple candidates
satisfying the knowledge base. When this happens, return one candidate that has the
@@ -71,11 +71,11 @@ class ReasonerBase:
elif len(candidates) == 1:
return candidates[0]
else:
cost_array = self._get_cost_list(pred_pseudo_label, pred_prob, candidates)
cost_array = self._get_cost_list(data_sample, candidates)
candidate = candidates[np.argmin(cost_array)]
return candidate
def _get_cost_list(self, pred_pseudo_label, pred_prob, candidates):
def _get_cost_list(self, data_sample, candidates):
"""
Get the list of costs between each candidate and the given prediction. The list is
calculated based on one of the following distance functions:
@@ -95,15 +95,15 @@ class ReasonerBase:
Multiple consistent candidates.
"""
if self.dist_func == "hamming":
return hamming_dist(pred_pseudo_label, candidates)
return hamming_dist(data_sample.pred_pseudo_label, candidates)

elif self.dist_func == "confidence":
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(pred_prob, candidates)
return confidence_dist(data_sample.pred_prob, candidates)


def zoopt_get_solution(
self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
self, symbol_num, data_sample, max_revision_num
):
"""
Get the optimal solution using the Zoopt library. The solution is a list of
@@ -113,13 +113,8 @@ class ReasonerBase:
----------
symbol_num : int
Number of total symbols.
pred_pseudo_label : List[Any]
Predicted pseudo label.
pred_prob : List[List[Any]]
Predicted probabilities of the prediction (Each sublist contains the probability
distribution over all pseudo labels).
y : Any
Ground truth for the logical result.
data_sample : ListData
max_revision_num : int
Specifies the maximum number of revisions allowed.
"""
@@ -128,7 +123,7 @@ class ReasonerBase:
)
objective = Objective(
lambda sol: self.zoopt_revision_score(
symbol_num, pred_pseudo_label, pred_prob, y, sol
symbol_num, data_sample, sol
),
dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
@@ -137,15 +132,15 @@ class ReasonerBase:
solution = Opt.min(objective, parameter).get_x()
return solution
def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, y, sol):
def zoopt_revision_score(self, symbol_num, data_sample, sol):
"""
Get the revision score for a solution. A lower score suggests that the Zoopt library
has a higher preference for this solution.
"""
revision_idx = np.where(sol.get_x() != 0)[0]
candidates = self.revise_at_idx(pred_pseudo_label, y, revision_idx)
candidates = self.revise_at_idx(data_sample, revision_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates))
return np.min(self._get_cost_list(data_sample, candidates))
else:
return symbol_num

@@ -157,7 +152,7 @@ class ReasonerBase:
x = solution.get_x()
return max_revision_num - x.sum()
def revise_at_idx(self, pred_pseudo_label, y, revision_idx):
def revise_at_idx(self, data_sample, revision_idx):
"""
Revise the predicted pseudo label at specified index positions.

@@ -170,7 +165,15 @@ class ReasonerBase:
revision_idx : array-like
Indices of where revisions should be made to the predicted pseudo label.
"""
return self.kb.revise_at_idx(pred_pseudo_label, y, revision_idx)
return self.kb.revise_at_idx(data_sample.pred_pseudo_label,
data_sample.Y,
revision_idx)
def abduce_candidates(self, data_sample, max_revision_num, require_more_revision):
return self.kb.abduce_candidates(data_sample.pred_pseudo_label,
data_sample.Y,
max_revision_num,
require_more_revision)

def _get_max_revision_num(self, max_revision, symbol_num):
"""
@@ -222,22 +225,18 @@ class ReasonerBase:
symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(max_revision, symbol_num)
pred_pseudo_label = data_sample.pred_pseudo_label
pred_prob = data_sample.pred_prob
y = data_sample.Y
if self.use_zoopt:
solution = self.zoopt_get_solution(
symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
symbol_num, data_sample, max_revision_num
)
revision_idx = np.where(solution != 0)[0]
candidates = self.revise_at_idx(pred_pseudo_label, y, revision_idx)
candidates = self.revise_at_idx(data_sample, revision_idx)
else:
candidates = self.kb.abduce_candidates(
pred_pseudo_label, y, max_revision_num, require_more_revision
candidates = self.abduce_candidates(
data_sample, max_revision_num, require_more_revision
)

candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates)
candidate = self._get_one_candidate(data_sample, candidates)
return candidate

def batch_abduce(
@@ -254,15 +253,6 @@ class ReasonerBase:
data_samples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label

# def _batch_abduce_helper(self, args):
# z, prob, y, max_revision, require_more_revision = args
# return self.abduce((z, prob, y), max_revision, require_more_revision)

# def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0):
# with Pool(processes=os.cpu_count()) as pool:
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# return results

def __call__(
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0
):
@@ -486,8 +476,8 @@ if __name__ == "__main__":
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules)
return len(list(self.prolog.query(pl_query))) != 0

def abduce_rules(self, pred_res):
pl_query = "consistent_inst_feature(%s, X)." % pred_res
def abduce_rules(self, pseudo_labels):
pl_query = "consistent_inst_feature(%s, X)." % pseudo_labels
prolog_result = list(self.prolog.query(pl_query))
if len(prolog_result) == 0:
return None
@@ -499,32 +489,36 @@ if __name__ == "__main__":
def __init__(self, kb, dist_func="hamming"):
super().__init__(kb, dist_func, use_zoopt=True)

def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs):
pred = []
k = []
def _revise_at_idxs(self, pseudo_labels, ys, all_revision_flag, idxs):
data_sample = ListData()
data_sample.pred_pseudo_label = []
data_sample.Y = []
revision_flag = []
for idx in idxs:
pred.append(pred_res[idx])
k.append(y[idx])
data_sample.pred_pseudo_label.append(pseudo_labels[idx])
data_sample.Y.append(ys[idx])
revision_flag += list(all_revision_flag[idx])
revision_idx = np.where(np.array(revision_flag) != 0)[0]
candidate = self.revise_at_idx(pred, k, revision_idx)
candidate = self.revise_at_idx(data_sample, revision_idx)
return candidate

def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol):
all_revision_flag = reform_list(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))]
def zoopt_revision_score(self, symbol_num, data_sample, sol):
pseudo_labels = data_sample.pred_pseudo_label
ys = data_sample.Y
all_revision_flag = reform_list(sol.get_x(), pseudo_labels)
lefted_idxs = [i for i in range(len(pseudo_labels))]
candidate_size = []
while lefted_idxs:
idxs = []
idxs.append(lefted_idxs.pop(0))
max_candidate_idxs = []
found = False
for idx in range(-1, len(pred_res)):
for idx in range(-1, len(pseudo_labels)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidate = self._revise_at_idxs(
pred_res, y, all_revision_flag, idxs
pseudo_labels, ys, all_revision_flag, idxs
)
if len(candidate) == 0:
if len(idxs) > 1:
@@ -547,8 +541,8 @@ if __name__ == "__main__":
score -= math.exp(-i) * candidate_size[i]
return score

def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)
def abduce_rules(self, pseudo_labels):
return self.kb.abduce_rules(pseudo_labels)

kb = HedKB(
pseudo_label_list=[1, 0, "+", "="],


Loading…
Cancel
Save