Browse Source

[ENH] add abstract data interface to reasoner

ab_data
Gao Enhao 1 year ago
parent
commit
3a3b0ee575
1 changed file with 75 additions and 90 deletions
  1. +75
    -90
      abl/reasoning/reasoner.py

+ 75
- 90
abl/reasoning/reasoner.py View File

@@ -1,18 +1,28 @@
from typing import Any, List, Mapping, Tuple, Union

import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter
from zoopt import Dimension, Objective, Opt, Parameter, Solution

from ..utils.utils import (calculate_revision_num, confidence_dist, flatten,
from ..structures import ListData
from ..utils.utils import (calculate_revision_num, confidence_dist,
hamming_dist, reform_idx)
from .base_kb import BaseKB


class ReasonerBase:
def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False):
def __init__(
self,
kb: BaseKB,
dist_func: str = "hamming",
mapping: Mapping = None,
use_zoopt: bool = False,
):
"""
Base class for all reasoners in the ABL system.

Parameters
----------
kb : KBBase
kb : BaseKB
The knowledge base to be used for reasoning.
dist_func : str, optional
The distance function to be used. Can be "hamming" or "confidence". Default is "hamming".
@@ -34,14 +44,12 @@ class ReasonerBase:
self.dist_func = dist_func
self.use_zoopt = use_zoopt
if mapping is None:
self.mapping = {
index: label for index, label in enumerate(self.kb.pseudo_label_list)
}
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
else:
self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

def _get_cost_list(self, pred_pseudo_label, pred_prob, candidates):
def _get_cost_list(self, data_sample: ListData, candidates: List[List[Any]]):
"""
Get the list of costs between each pseudo label and candidate.

@@ -60,13 +68,13 @@ class ReasonerBase:
Array of computed costs for each candidate.
"""
if self.dist_func == "hamming":
return hamming_dist(pred_pseudo_label, candidates)
return hamming_dist(data_sample["pred_pseudo_label"][0], candidates)

elif self.dist_func == "confidence":
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(pred_prob, candidates)
return confidence_dist(data_sample["pred_prob"][0], candidates)

def _get_one_candidate(self, data_sample, candidates):
def _get_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]):
"""
Get one candidate. If multiple candidates exist, return the one with minimum cost.

@@ -90,27 +98,23 @@ class ReasonerBase:
elif len(candidates) == 1:
return candidates[0]
else:
cost_array = self._get_cost_list(
data_sample["pred_pseudo_label"][0], data_sample["pred_prob"][0], candidates
)
cost_array = self._get_cost_list(data_sample, candidates)
candidate = candidates[np.argmin(cost_array)]
return candidate

def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, y, sol):
def zoopt_revision_score(self, data_sample: ListData, solution: Solution):
"""
Get the revision score for a single solution.

Parameters
----------
symbol_num : int
Number of total symbols.
pred_pseudo_label : list
List of predicted pseudo labels.
pred_prob : list
List of probabilities for predicted results.
y : any
Ground truth for the predicted results.
sol : array-like
solution : array-like
Solution to evaluate.

Returns
@@ -118,26 +122,22 @@ class ReasonerBase:
float
The revision score for the given solution.
"""
revision_idx = np.where(sol.get_x() != 0)[0]
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx)
revision_idx = np.where(solution.get_x() != 0)[0]
candidates = self.revise_at_idx(data_sample, revision_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates))
return np.min(self._get_cost_list(data_sample, candidates))
else:
return symbol_num
return data_sample["symbol_num"]

def _constrain_revision_num(self, solution, max_revision_num):
def _constrain_revision_num(self, solution: Solution, max_revision_num: int):
x = solution.get_x()
return max_revision_num - x.sum()

def zoopt_get_solution(
self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
):
def zoopt_get_solution(self, data_sample: ListData, max_revision_num: int):
"""Get the optimal solution using the Zoopt library.

Parameters
----------
symbol_num : int
Number of total symbols.
pred_pseudo_label : list
List of predicted pseudo labels.
pred_prob : list
@@ -152,21 +152,18 @@ class ReasonerBase:
array-like
The optimal solution, i.e., where to revise the predicted pseudo labels.
"""
dimension = Dimension(
size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num
)
symbol_num = data_sample["symbol_num"]
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
objective = Objective(
lambda sol: self.zoopt_revision_score(
symbol_num, pred_pseudo_label, pred_prob, y, sol
),
lambda solution: self.zoopt_revision_score(data_sample, solution),
dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
constraint=lambda solution: self._constrain_revision_num(solution, max_revision_num),
)
parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
solution = Opt.min(objective, parameter).get_x()
return solution

def revise_by_idx(self, pred_pseudo_label, y, revision_idx):
def revise_at_idx(self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray]):
"""
Revise the pseudo label according to the given indices.

@@ -184,9 +181,14 @@ class ReasonerBase:
list
The revisions according to the given indices.
"""
return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx)
return self.kb.revise_at_idx(data_sample, revision_idx)

def abduce(self, data_sample, max_revision=-1, require_more_revision=0):
def abduce(
self,
data_sample: ListData,
max_revision: int = -1,
require_more_revision: int = 0,
):
"""
Perform revision by abduction on the given data.

@@ -209,15 +211,15 @@ class ReasonerBase:
list
The abduced revisions.
"""
symbol_num = len(flatten(data_sample.pred_pseudo_label))
symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = calculate_revision_num(max_revision, symbol_num)

data_sample.set_metainfo(dict(symbol_num=symbol_num))

if self.use_zoopt:
solution = self.zoopt_get_solution(
symbol_num, data_sample, max_revision_num
)
solution = self.zoopt_get_solution(data_sample, max_revision_num)
revision_idx = np.where(solution != 0)[0]
candidates = self.revise_by_idx(data_sample, revision_idx)
candidates = self.revise_at_idx(data_sample, revision_idx)
else:
candidates = self.kb.abduce_candidates(
data_sample, max_revision_num, require_more_revision
@@ -226,7 +228,12 @@ class ReasonerBase:
candidate = self._get_one_candidate(data_sample, candidates)
return candidate

def batch_abduce(self, data_samples, max_revision=-1, require_more_revision=0):
def batch_abduce(
self,
data_samples: ListData,
max_revision: int = -1,
require_more_revision: int = 0,
):
"""
Perform abduction on the given data in batches.

@@ -258,6 +265,7 @@ class ReasonerBase:
for data_sample in data_samples
]
data_samples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label

# def _batch_abduce_helper(self, args):
# z, prob, y, max_revision, require_more_revision = args
@@ -268,16 +276,9 @@ class ReasonerBase:
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# return results

def __call__(
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0
):
return self.batch_abduce(
pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision
)


if __name__ == "__main__":
from kb import KBBase, ground_KB, prolog_KB
from abl.reasoning.base_kb import BaseKB, GroundKB, PrologBasedKB

prob1 = [
[
@@ -293,14 +294,14 @@ if __name__ == "__main__":
]
]

class add_KB(KBBase):
class add_KB(BaseKB):
def __init__(self, pseudo_label_list=list(range(10)), use_cache=True):
super().__init__(pseudo_label_list, use_cache=use_cache)

def logic_forward(self, nums):
return sum(nums)

class add_ground_KB(ground_KB):
class add_GroundKB(GroundKB):
def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list)

@@ -308,30 +309,20 @@ if __name__ == "__main__":
return sum(nums)

def test_add(reasoner):
res = reasoner.batch_abduce(
prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(
prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(
prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0
)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(
prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0
)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(
prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0
)
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
print(res)
print()

print("add_KB with GKB:")
kb = add_ground_KB()
kb = add_GroundKB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

@@ -345,15 +336,15 @@ if __name__ == "__main__":
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("prolog_KB with add.pl:")
kb = prolog_KB(
print("PrologBasedKB with add.pl:")
kb = PrologBasedKB(
pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl"
)
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("prolog_KB with add.pl using zoopt:")
kb = prolog_KB(
print("PrologBasedKB with add.pl using zoopt:")
kb = PrologBasedKB(
pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl",
)
@@ -392,7 +383,7 @@ if __name__ == "__main__":
print(res)
print()

class HWF_KB(KBBase):
class HWF_KB(BaseKB):
def __init__(
self,
pseudo_label_list=[
@@ -442,7 +433,7 @@ if __name__ == "__main__":
formula = [mapping[f] for f in formula]
return eval("".join(formula))

class HWF_ground_KB(ground_KB):
class HWF_GroundKB(GroundKB):
def __init__(
self,
pseudo_label_list=[
@@ -548,7 +539,7 @@ if __name__ == "__main__":
print()

print("HWF_KB with GKB, max_err=0.1")
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=0.1)
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

@@ -558,7 +549,7 @@ if __name__ == "__main__":
test_hwf(reasoner)

print("HWF_KB with GKB, max_err=1")
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=1)
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

@@ -575,7 +566,7 @@ if __name__ == "__main__":
print("max_revision is float")
test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9])

class HED_prolog_KB(prolog_KB):
class HED_prolog_KB(PrologBasedKB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)

@@ -597,7 +588,7 @@ if __name__ == "__main__":
def __init__(self, kb, dist_func="hamming"):
super().__init__(kb, dist_func, use_zoopt=True)

def _revise_by_idxs(self, pred_res, y, all_revision_flag, idxs):
def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs):
pred = []
k = []
revision_flag = []
@@ -606,7 +597,7 @@ if __name__ == "__main__":
k.append(y[idx])
revision_flag += list(all_revision_flag[idx])
revision_idx = np.where(np.array(revision_flag) != 0)[0]
candidate = self.revise_by_idx(pred, k, revision_idx)
candidate = self.revise_at_idx(pred, k, revision_idx)
return candidate

def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol):
@@ -621,9 +612,7 @@ if __name__ == "__main__":
for idx in range(-1, len(pred_res)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidate = self._revise_by_idxs(
pred_res, y, all_revision_flag, idxs
)
candidate = self._revise_at_idxs(pred_res, y, all_revision_flag, idxs)
if len(candidate) == 0:
if len(idxs) > 1:
idxs.pop()
@@ -634,9 +623,7 @@ if __name__ == "__main__":
removed = [i for i in lefted_idxs if i in max_candidate_idxs]
if found:
candidate_size.append(len(removed) + 1)
lefted_idxs = [
i for i in lefted_idxs if i not in max_candidate_idxs
]
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
candidate_size.sort()
score = 0
import math
@@ -677,9 +664,7 @@ if __name__ == "__main__":
print()

print("HED_Reasoner abduce")
res = reasoner.abduce(
[[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs)
)
res = reasoner.abduce([[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs))
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1)


Loading…
Cancel
Save