Browse Source

[ENH] change parameters passing in reasoning

pull/1/head
troyyyyy 1 year ago
parent
commit
bf04dd9c95
1 changed files with 47 additions and 53 deletions
  1. +47
    -53
      abl/reasoning/reasoner.py

+ 47
- 53
abl/reasoning/reasoner.py View File

@@ -1,6 +1,6 @@
import numpy as np import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import (
from abl.utils.utils import (
confidence_dist, confidence_dist,
flatten, flatten,
reform_list, reform_list,
@@ -50,7 +50,7 @@ class ReasonerBase:
self.mapping = mapping self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))


def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates):
def _get_one_candidate(self, data_sample, candidates):
""" """
Due to the nondeterminism of abductive reasoning, there could be multiple candidates Due to the nondeterminism of abductive reasoning, there could be multiple candidates
satisfying the knowledge base. When this happens, return one candidate that has the satisfying the knowledge base. When this happens, return one candidate that has the
@@ -71,11 +71,11 @@ class ReasonerBase:
elif len(candidates) == 1: elif len(candidates) == 1:
return candidates[0] return candidates[0]
else: else:
cost_array = self._get_cost_list(pred_pseudo_label, pred_prob, candidates)
cost_array = self._get_cost_list(data_sample, candidates)
candidate = candidates[np.argmin(cost_array)] candidate = candidates[np.argmin(cost_array)]
return candidate return candidate
def _get_cost_list(self, pred_pseudo_label, pred_prob, candidates):
def _get_cost_list(self, data_sample, candidates):
""" """
Get the list of costs between each candidate and the given prediction. The list is Get the list of costs between each candidate and the given prediction. The list is
calculated based on one of the following distance functions: calculated based on one of the following distance functions:
@@ -95,15 +95,15 @@ class ReasonerBase:
Multiple consistent candidates. Multiple consistent candidates.
""" """
if self.dist_func == "hamming": if self.dist_func == "hamming":
return hamming_dist(pred_pseudo_label, candidates)
return hamming_dist(data_sample.pred_pseudo_label, candidates)


elif self.dist_func == "confidence": elif self.dist_func == "confidence":
candidates = [[self.remapping[x] for x in c] for c in candidates] candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(pred_prob, candidates)
return confidence_dist(data_sample.pred_prob, candidates)




def zoopt_get_solution( def zoopt_get_solution(
self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
self, symbol_num, data_sample, max_revision_num
): ):
""" """
Get the optimal solution using the Zoopt library. The solution is a list of Get the optimal solution using the Zoopt library. The solution is a list of
@@ -113,13 +113,8 @@ class ReasonerBase:
---------- ----------
symbol_num : int symbol_num : int
Number of total symbols. Number of total symbols.
pred_pseudo_label : List[Any]
Predicted pseudo label.
pred_prob : List[List[Any]]
Predicted probabilities of the prediction (Each sublist contains the probability
distribution over all pseudo labels).
y : Any
Ground truth for the logical result.
data_sample : ListData
max_revision_num : int max_revision_num : int
Specifies the maximum number of revisions allowed. Specifies the maximum number of revisions allowed.
""" """
@@ -128,7 +123,7 @@ class ReasonerBase:
) )
objective = Objective( objective = Objective(
lambda sol: self.zoopt_revision_score( lambda sol: self.zoopt_revision_score(
symbol_num, pred_pseudo_label, pred_prob, y, sol
symbol_num, data_sample, sol
), ),
dim=dimension, dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
@@ -137,15 +132,15 @@ class ReasonerBase:
solution = Opt.min(objective, parameter).get_x() solution = Opt.min(objective, parameter).get_x()
return solution return solution
def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, y, sol):
def zoopt_revision_score(self, symbol_num, data_sample, sol):
""" """
Get the revision score for a solution. A lower score suggests that the Zoopt library Get the revision score for a solution. A lower score suggests that the Zoopt library
has a higher preference for this solution. has a higher preference for this solution.
""" """
revision_idx = np.where(sol.get_x() != 0)[0] revision_idx = np.where(sol.get_x() != 0)[0]
candidates = self.revise_at_idx(pred_pseudo_label, y, revision_idx)
candidates = self.revise_at_idx(data_sample, revision_idx)
if len(candidates) > 0: if len(candidates) > 0:
return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates))
return np.min(self._get_cost_list(data_sample, candidates))
else: else:
return symbol_num return symbol_num


@@ -157,7 +152,7 @@ class ReasonerBase:
x = solution.get_x() x = solution.get_x()
return max_revision_num - x.sum() return max_revision_num - x.sum()
def revise_at_idx(self, pred_pseudo_label, y, revision_idx):
def revise_at_idx(self, data_sample, revision_idx):
""" """
Revise the predicted pseudo label at specified index positions. Revise the predicted pseudo label at specified index positions.


@@ -170,7 +165,15 @@ class ReasonerBase:
revision_idx : array-like revision_idx : array-like
Indices of where revisions should be made to the predicted pseudo label. Indices of where revisions should be made to the predicted pseudo label.
""" """
return self.kb.revise_at_idx(pred_pseudo_label, y, revision_idx)
return self.kb.revise_at_idx(data_sample.pred_pseudo_label,
data_sample.Y,
revision_idx)
def abduce_candidates(self, data_sample, max_revision_num, require_more_revision):
return self.kb.abduce_candidates(data_sample.pred_pseudo_label,
data_sample.Y,
max_revision_num,
require_more_revision)


def _get_max_revision_num(self, max_revision, symbol_num): def _get_max_revision_num(self, max_revision, symbol_num):
""" """
@@ -222,22 +225,18 @@ class ReasonerBase:
symbol_num = data_sample.elements_num("pred_pseudo_label") symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(max_revision, symbol_num) max_revision_num = self._get_max_revision_num(max_revision, symbol_num)
pred_pseudo_label = data_sample.pred_pseudo_label
pred_prob = data_sample.pred_prob
y = data_sample.Y
if self.use_zoopt: if self.use_zoopt:
solution = self.zoopt_get_solution( solution = self.zoopt_get_solution(
symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
symbol_num, data_sample, max_revision_num
) )
revision_idx = np.where(solution != 0)[0] revision_idx = np.where(solution != 0)[0]
candidates = self.revise_at_idx(pred_pseudo_label, y, revision_idx)
candidates = self.revise_at_idx(data_sample, revision_idx)
else: else:
candidates = self.kb.abduce_candidates(
pred_pseudo_label, y, max_revision_num, require_more_revision
candidates = self.abduce_candidates(
data_sample, max_revision_num, require_more_revision
) )


candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates)
candidate = self._get_one_candidate(data_sample, candidates)
return candidate return candidate


def batch_abduce( def batch_abduce(
@@ -254,15 +253,6 @@ class ReasonerBase:
data_samples.abduced_pseudo_label = abduced_pseudo_label data_samples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label return abduced_pseudo_label


# def _batch_abduce_helper(self, args):
# z, prob, y, max_revision, require_more_revision = args
# return self.abduce((z, prob, y), max_revision, require_more_revision)

# def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0):
# with Pool(processes=os.cpu_count()) as pool:
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# return results

def __call__( def __call__(
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0
): ):
@@ -486,8 +476,8 @@ if __name__ == "__main__":
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) pl_query = "eval_inst_feature(%s, %s)." % (exs, rules)
return len(list(self.prolog.query(pl_query))) != 0 return len(list(self.prolog.query(pl_query))) != 0


def abduce_rules(self, pred_res):
pl_query = "consistent_inst_feature(%s, X)." % pred_res
def abduce_rules(self, pseudo_labels):
pl_query = "consistent_inst_feature(%s, X)." % pseudo_labels
prolog_result = list(self.prolog.query(pl_query)) prolog_result = list(self.prolog.query(pl_query))
if len(prolog_result) == 0: if len(prolog_result) == 0:
return None return None
@@ -499,32 +489,36 @@ if __name__ == "__main__":
def __init__(self, kb, dist_func="hamming"): def __init__(self, kb, dist_func="hamming"):
super().__init__(kb, dist_func, use_zoopt=True) super().__init__(kb, dist_func, use_zoopt=True)


def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs):
pred = []
k = []
def _revise_at_idxs(self, pseudo_labels, ys, all_revision_flag, idxs):
data_sample = ListData()
data_sample.pred_pseudo_label = []
data_sample.Y = []
revision_flag = [] revision_flag = []
for idx in idxs: for idx in idxs:
pred.append(pred_res[idx])
k.append(y[idx])
data_sample.pred_pseudo_label.append(pseudo_labels[idx])
data_sample.Y.append(ys[idx])
revision_flag += list(all_revision_flag[idx]) revision_flag += list(all_revision_flag[idx])
revision_idx = np.where(np.array(revision_flag) != 0)[0] revision_idx = np.where(np.array(revision_flag) != 0)[0]
candidate = self.revise_at_idx(pred, k, revision_idx)
candidate = self.revise_at_idx(data_sample, revision_idx)
return candidate return candidate


def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol):
all_revision_flag = reform_list(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))]
def zoopt_revision_score(self, symbol_num, data_sample, sol):
pseudo_labels = data_sample.pred_pseudo_label
ys = data_sample.Y
all_revision_flag = reform_list(sol.get_x(), pseudo_labels)
lefted_idxs = [i for i in range(len(pseudo_labels))]
candidate_size = [] candidate_size = []
while lefted_idxs: while lefted_idxs:
idxs = [] idxs = []
idxs.append(lefted_idxs.pop(0)) idxs.append(lefted_idxs.pop(0))
max_candidate_idxs = [] max_candidate_idxs = []
found = False found = False
for idx in range(-1, len(pred_res)):
for idx in range(-1, len(pseudo_labels)):
if (not idx in idxs) and (idx >= 0): if (not idx in idxs) and (idx >= 0):
idxs.append(idx) idxs.append(idx)
candidate = self._revise_at_idxs( candidate = self._revise_at_idxs(
pred_res, y, all_revision_flag, idxs
pseudo_labels, ys, all_revision_flag, idxs
) )
if len(candidate) == 0: if len(candidate) == 0:
if len(idxs) > 1: if len(idxs) > 1:
@@ -547,8 +541,8 @@ if __name__ == "__main__":
score -= math.exp(-i) * candidate_size[i] score -= math.exp(-i) * candidate_size[i]
return score return score


def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)
def abduce_rules(self, pseudo_labels):
return self.kb.abduce_rules(pseudo_labels)


kb = HedKB( kb = HedKB(
pseudo_label_list=[1, 0, "+", "="], pseudo_label_list=[1, 0, "+", "="],


Loading…
Cancel
Save