Browse Source

[FIX] add reasoner init parameters

pull/1/head
troyyyyy 1 year ago
parent
commit
cd5c577a50
2 changed files with 46 additions and 78 deletions
  1. +7
    -21
      abl/bridge/simple_bridge.py
  2. +39
    -57
      abl/reasoning/reasoner.py

+ 7
- 21
abl/bridge/simple_bridge.py View File

@@ -21,39 +21,25 @@ class SimpleBridge(BaseBridge):
super().__init__(model, reasoner)
self.metric_list = metric_list

# TODO: add reasoner.mapping to the property of SimpleBridge

def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
self.model.predict(data_samples)
return data_samples.pred_idx, data_samples.pred_prob

def abduce_pseudo_label(
self,
data_samples: ListData,
max_revision: int = -1,
require_more_revision: int = 0,
) -> List[List[Any]]:
self.reasoner.batch_abduce(data_samples, max_revision, require_more_revision)
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
self.reasoner.batch_abduce(data_samples)
return data_samples.abduced_pseudo_label

def idx_to_pseudo_label(
self, data_samples: ListData, mapping: Optional[Dict] = None
) -> List[List[Any]]:
if mapping is None:
mapping = self.reasoner.mapping
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
pred_idx = data_samples.pred_idx
data_samples.pred_pseudo_label = [
[mapping[_idx] for _idx in sub_list] for sub_list in pred_idx
[self.reasoner.mapping[_idx] for _idx in sub_list]
for sub_list in pred_idx
]
return data_samples.pred_pseudo_label

def pseudo_label_to_idx(
self, data_samples: ListData, mapping: Optional[Dict] = None
) -> List[List[Any]]:
if mapping is None:
mapping = self.reasoner.remapping
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
abduced_idx = [
[mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
[self.reasoner.remapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
for sub_list in data_samples.abduced_pseudo_label
]
data_samples.abduced_idx = abduced_idx


+ 39
- 57
abl/reasoning/reasoner.py View File

@@ -24,17 +24,35 @@ class ReasonerBase:
mapping : dict, optional
A mapping from index in the base model to label. If not provided, a default
order-based mapping is created.
max_revision : int or float, optional
The upper limit on the number of revisions for each data sample when
performing abductive reasoning. If float, denotes the fraction of the total
length that can be revised. A value of -1 implies no restriction on the
number of revisions. Defaults to -1.
require_more_revision : int, optional
Specifies additional number of revisions permitted beyond the minimum required
when performing abductive reasoning. Defaults to 0.
use_zoopt : bool, optional
Whether to use the Zoopt library during abductive reasoning. Defaults to False.
"""
def __init__(self, kb, dist_func="confidence", mapping=None, use_zoopt=False):
def __init__(self,
kb,
dist_func="confidence",
mapping=None,
max_revision=-1,
require_more_revision=0,
use_zoopt=False,
):
if dist_func not in ["hamming", "confidence"]:
raise NotImplementedError("Valid options for dist_func include \"hamming\" and \"confidence\"")

self.kb = kb
self.dist_func = dist_func
self.use_zoopt = use_zoopt
self.max_revision = max_revision
self.require_more_revision = require_more_revision
if mapping is None:
self.mapping = {
index: label for index, label in enumerate(self.kb.pseudo_label_list)
@@ -117,9 +135,7 @@ class ReasonerBase:
size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num
)
objective = Objective(
lambda sol: self.zoopt_revision_score(
symbol_num, data_sample, sol
),
lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol),
dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
)
@@ -133,7 +149,9 @@ class ReasonerBase:
has a higher preference for this solution.
"""
revision_idx = np.where(sol.get_x() != 0)[0]
candidates = self.revise_at_idx(data_sample, revision_idx)
candidates = self.kb.revise_at_idx(data_sample.pred_pseudo_label,
data_sample.Y,
revision_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(data_sample, candidates))
else:
@@ -146,27 +164,6 @@ class ReasonerBase:
"""
x = solution.get_x()
return max_revision_num - x.sum()
def revise_at_idx(self, data_sample, revision_idx):
"""
Revise the pseudo label in the data sample at specified index positions.

Parameters
----------
data_sample : ListData
Data sample.
revision_idx : array-like
Indices of where revisions should be made to the predicted pseudo label.
"""
return self.kb.revise_at_idx(data_sample.pred_pseudo_label,
data_sample.Y,
revision_idx)
def abduce_candidates(self, data_sample, max_revision_num, require_more_revision):
return self.kb.abduce_candidates(data_sample.pred_pseudo_label,
data_sample.Y,
max_revision_num,
require_more_revision)

def _get_max_revision_num(self, max_revision, symbol_num):
"""
@@ -186,9 +183,7 @@ class ReasonerBase:
raise ValueError("If max_revision is an int, it must be non-negative.")
return max_revision
def abduce(
self, data_sample, max_revision=-1, require_more_revision=0
):
def abduce(self, data_sample):
"""
Perform abductive reasoning on the given data sample.

@@ -196,14 +191,7 @@ class ReasonerBase:
----------
data_sample : ListData
Data sample.
max_revision : int or float, optional
The upper limit on the number of revisions. If float, denotes the fraction of the
total length that can be revised. A value of -1 implies no restriction on the number
of revisions. Defaults to -1.
require_more_revision : int, optional
Specifies additional number of revisions permitted beyond the minimum required.
Defaults to 0.

Returns
-------
List[Any]
@@ -211,39 +199,33 @@ class ReasonerBase:
knowledge base.
"""
symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(max_revision, symbol_num)
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)
if self.use_zoopt:
solution = self.zoopt_get_solution(
symbol_num, data_sample, max_revision_num
)
solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num)
revision_idx = np.where(solution != 0)[0]
candidates = self.revise_at_idx(data_sample, revision_idx)
candidates = self.self.kb.revise_at_idx(data_sample.pred_pseudo_label,
data_sample.Y,
revision_idx)
else:
candidates = self.abduce_candidates(
data_sample, max_revision_num, require_more_revision
)

candidates = self.kb.abduce_candidates(data_sample.pred_pseudo_label,
data_sample.Y,
max_revision_num,
self.require_more_revision)
candidate = self._get_one_candidate(data_sample, candidates)
return candidate

def batch_abduce(
self, data_samples, max_revision=-1, require_more_revision=0
):
def batch_abduce(self, data_samples):
"""
Perform abductive reasoning on the given prediction data samples.
For detailed information, refer to `abduce`.
"""
abduced_pseudo_label = [
self.abduce(data_sample, max_revision, require_more_revision)
for data_sample in data_samples
self.abduce(data_sample) for data_sample in data_samples
]
data_samples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label

def __call__(
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0
):
return self.batch_abduce(
pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision
)
def __call__(self, data_samples):
return self.batch_abduce(data_samples)

Loading…
Cancel
Save