Browse Source

Remove

pull/3/head
troyyyyy 2 years ago
parent
commit
87c3ba1b71
3 changed files with 168 additions and 162 deletions
  1. +130
    -96
      abl/abducer/abducer_base.py
  2. +37
    -65
      abl/abducer/kb.py
  3. +1
    -1
      examples/datasets/hed/learn_add.pl

+ 130
- 96
abl/abducer/abducer_base.py View File

@@ -16,27 +16,19 @@ from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist

class AbducerBase(abc.ABC):
def __init__(self, kb, dist_func='hamming', zoopt=False, multiple_predictions=False):
def __init__(self, kb, dist_func='hamming', zoopt=False):
self.kb = kb
assert dist_func == 'hamming' or dist_func == 'confidence'
self.dist_func = dist_func
self.zoopt = zoopt
self.multiple_predictions = multiple_predictions
if dist_func == 'confidence':
self.mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list)))))

def _get_cost_list(self, pred_res, pred_res_prob, candidates):
if self.dist_func == 'hamming':
if self.multiple_predictions:
pred_res = flatten(pred_res)
candidates = [flatten(c) for c in candidates]
return hamming_dist(pred_res, candidates)
elif self.dist_func == 'confidence':
if self.multiple_predictions:
pred_res_prob = flatten(pred_res_prob)
candidates = [flatten(c) for c in candidates]
candidates = [list(map(lambda x: self.mapping[x], c)) for c in candidates]
return confidence_dist(pred_res_prob, candidates)

@@ -50,26 +42,24 @@ class AbducerBase(abc.ABC):
cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates)
candidate = candidates[np.argmin(cost_list)]
return candidate

def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
if not self.multiple_predictions:
address_idx = np.where(sol.get_x() != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else:
return len(pred_res)
def _zoopt_address_score_single(self, sol_x, pred_res, pred_res_prob, key):
address_idx = np.where(sol_x != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else:
all_address_flag = reform_idx(sol.get_x(), pred_res)
score = 0
for idx in range(len(pred_res)):
address_idx = np.where(all_address_flag[idx] != 0)[0]
candidates = self.address_by_idx([pred_res[idx]], key[idx], address_idx)
if len(candidates) > 0:
score += np.min(self._get_cost_list(pred_res[idx], pred_res_prob[idx], candidates))
else:
score += len(pred_res[idx])
return score
return len(pred_res)
def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
# if not self.multiple_predictions:
return self._zoopt_address_score_single(sol.get_x(), pred_res, pred_res_prob, key)
# else:
# all_address_flag = reform_idx(sol.get_x(), pred_res)
# score = 0
# for idx in range(len(pred_res)):
# score += self._zoopt_address_score_single(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key)
# return score
def _constrain_address_num(self, solution, max_address_num):
x = solution.get_x()
@@ -88,21 +78,29 @@ class AbducerBase(abc.ABC):
return solution
def address_by_idx(self, pred_res, key, address_idx):
return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions)
return self.kb.address_by_idx(pred_res, key, address_idx)

def abduce(self, data, max_address_num=-1, require_more_address=0):
def abduce(self, data, max_address=-1, require_more_address=0):
pred_res, pred_res_prob, key = data
if max_address_num == -1:
max_address_num = len(flatten(pred_res))
# if max_address_num == -1:
# max_address_num = len(flatten(pred_res))
assert(type(max_address) in (int, float))
if max_address == -1:
max_address_num = len(pred_res)
elif type(max_address) == float:
assert(max_address >= 0 and max_address <= 1)
max_address_num = round(len(pred_res) * max_address)
else:
assert(max_address >= 0)
max_address_num = max_address

if self.zoopt:
solution = self.zoopt_get_solution(pred_res, pred_res_prob, key, max_address_num)
address_idx = np.where(solution != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
else:
candidates = self.kb.abduce_candidates(
pred_res, key, max_address_num, require_more_address, self.multiple_predictions
)
candidates = self.kb.abduce_candidates(pred_res, key, max_address_num, require_more_address)

candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates)
return candidate
@@ -110,134 +108,170 @@ class AbducerBase(abc.ABC):
def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)

def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0):
# if self.multiple_predictions:
return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address)
# else:
# return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]
def batch_abduce(self, data, max_address=-1, require_more_address=0):
Z1, Z2, Y = data
return [self.abduce((z, prob, y), max_address, require_more_address) for z, prob, y in zip(Z1, Z2, Y)]

def __call__(self, Z, Y, max_address_num=-1, require_more_address=0):
return self.batch_abduce(Z, Y, max_address_num, require_more_address)

if __name__ == '__main__':
from kb import add_KB, prolog_KB, HWF_KB
from kb import add_KB, prolog_KB, HWF_KB, HED_prolog_KB
prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]

print('add_KB with GKB:')
kb = add_KB(GKB_flag=True)
abd = AbducerBase(kb, 'confidence')
res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [8]), max_address=2, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob2, [8]), max_address=2, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob1, 17), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=2, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob1, 17), max_address_num=1, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=1, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob1, 20), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [20]), max_address=2, require_more_address=0)
print(res)
print()
multiple_prob = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]],
[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
print('add_KB without GKB:')
kb = add_KB()
abd = AbducerBase(kb, 'confidence', multiple_predictions=True)
res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=0)
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce(([[1, 1]], prob1, [8]), max_address=2, require_more_address=0)
print(res)
res = abd.batch_abduce(([[1, 1]], prob2, [8]), max_address=2, require_more_address=0)
print(res)
res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=1)
res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=2, require_more_address=0)
print(res)
res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=1, require_more_address=0)
print(res)
res = abd.batch_abduce(([[1, 1]], prob1, [20]), max_address=2, require_more_address=0)
print(res)
print()

print('prolog_KB with add.pl:')
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl')
abd = AbducerBase(kb, 'confidence')
res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [8]), max_address=2, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob2, [8]), max_address=2, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob1, 17), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=2, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob1, 17), max_address_num=1, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=1, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob1, 20), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [20]), max_address=2, require_more_address=0)
print(res)
print()

print('prolog_KB with add.pl using zoopt:')
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl')
abd = AbducerBase(kb, 'confidence', zoopt=True)
res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [8]), max_address=2, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob2, [8]), max_address=2, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob1, 17), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=2, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob1, 17), max_address_num=1, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=1, require_more_address=0)
print(res)
res = abd.abduce(([1, 1], prob1, 20), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([[1, 1]], prob1, [20]), max_address=2, require_more_address=0)
print(res)
print()

kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0)
print('add_KB with multiple inputs at once:')
multiple_prob = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]],
[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
kb = add_KB()
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address=4, require_more_address=0)
print(res)
res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
res = abd.batch_abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address=4, require_more_address=1)
print(res)
print()
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1)
print('HWF_KB with GKB, max_err=0.1')
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
res = abd.batch_abduce(([['5', '+', '2']], [None], [3]), max_address=2, require_more_address=0)
print(res)
res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0)
res = abd.batch_abduce(([['5', '+', '9']], [None], [65]), max_address=3, require_more_address=0)
print(res)
res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3)
res = abd.batch_abduce(([['5', '8', '8', '8', '8']], [None], [3.17]), max_address=5, require_more_address=3)
print(res)
print()
print('HWF_KB without GKB, max_err=0.1')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
abd = AbducerBase(kb, 'hamming', multiple_predictions=True)
res = abd.abduce(([['5', '+', '2'], ['5', '+', '9']], None, [3, 64]), max_address_num=6, require_more_address=0)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce(([['5', '+', '2']], [None], [3]), max_address=2, require_more_address=0)
print(res)
res = abd.batch_abduce(([['5', '+', '9']], [None], [65]), max_address=3, require_more_address=0)
print(res)
res = abd.batch_abduce(([['5', '8', '8', '8', '8']], [None], [3.17]), max_address=5, require_more_address=3)
print(res)
print()
kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
print('HWF_KB with GKB, max_err=1')
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0)
res = abd.batch_abduce(([['5', '+', '9']], [None], [65]), max_address=3, require_more_address=0)
print(res)
res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
res = abd.batch_abduce(([['5', '+', '2']], [None], [1.67]), max_address=3, require_more_address=0)
print(res)
res = abd.batch_abduce(([['5', '8', '8', '8', '8']], [None], [3.17]), max_address=5, require_more_address=3)
print(res)
print()
print('HWF_KB without GKB, max_err=1')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 1)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
res = abd.batch_abduce(([['5', '+', '9']], [None], [65]), max_address=3, require_more_address=0)
print(res)
res = abd.batch_abduce(([['5', '+', '2']], [None], [1.67]), max_address=3, require_more_address=0)
print(res)
res = abd.batch_abduce(([['5', '8', '8', '8', '8']], [None], [3.17]), max_address=5, require_more_address=3)
print(res)
print()
print('HWF_KB with multiple inputs at once:')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce(([['5', '+', '2'], ['5', '+', '9']], [None, None], [3, 64]), max_address=1, require_more_address=0)
print(res)
res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0)
res = abd.batch_abduce(([['5', '+', '2'], ['5', '+', '9']], [None, None], [3, 64]), max_address=3, require_more_address=0)
print(res)
res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3)
res = abd.batch_abduce(([['5', '+', '2'], ['5', '+', '9']], [None, None], [3, 65]), max_address=3, require_more_address=0)
print(res)
print()
print('max_address is float')
res = abd.batch_abduce(([['5', '+', '2'], ['5', '+', '9']], [None, None], [3, 64]), max_address=0.5, require_more_address=0)
print(res)
res = abd.batch_abduce(([['5', '+', '2'], ['5', '+', '9']], [None, None], [3, 64]), max_address=0.9, require_more_address=0)
print(res)
print()

kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
abd = AbducerBase(kb, zoopt=True, multiple_predictions=True)
kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
abd = AbducerBase(kb, zoopt=True)
consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]]
inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]]
# inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']]
rules = ['my_op([0], [0], [0])', 'my_op([1], [1], [1, 0])']

print(kb._logic_forward(consist_exs, True), kb._logic_forward(inconsist_exs, True))
print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules), kb.consist_rule([1, '+', 1, '=', 1, 1], rules))
print(kb.logic_forward(consist_exs))
print(kb.logic_forward(inconsist_exs))
print()

res = abd.abduce((consist_exs, [None] * len(consist_exs), [None] * len(consist_exs)))
print(res)
res = abd.abduce((inconsist_exs, [None] * len(consist_exs), [None] * len(inconsist_exs)))
print(res)
print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules))
print(kb.consist_rule([1, '+', 1, '=', 1, 1], rules))
print()

abduced_rules = abd.abduce_rules(consist_exs)
print(abduced_rules)
# res = abd.abduce((consist_exs, [None] * len(consist_exs), [None] * len(consist_exs)))
# print(res)
# res = abd.batch_abduce((inconsist_exs, [None] * len(consist_exs), [None] * len(inconsist_exs)))
# print(res)
# print()

# abduced_rules = abd.batch_abduce_rules(consist_exs)
# print(abduced_rules)

+ 37
- 65
abl/abducer/kb.py View File

@@ -17,7 +17,7 @@ import numpy as np

from collections import defaultdict
from itertools import product, combinations
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list
from utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list

from multiprocessing import Pool

@@ -25,11 +25,16 @@ from functools import lru_cache
import pyswip

class KBBase(ABC):
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0):
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0, cache_size=128):
# TODO:添加一下类型检查,比如
# if not isinstance(X, (np.ndarray, spmatrix)):
# raise TypeError("X should be numpy array or sparse matrix")

self.pseudo_label_list = pseudo_label_list
self.len_list = len_list
self.GKB_flag = GKB_flag
self.max_err = max_err
self.cache_size = cache_size

if GKB_flag:
self.base = {}
@@ -73,106 +78,73 @@ class KBBase(ABC):
@abstractmethod
def logic_forward(self, pseudo_labels):
pass
def _logic_forward(self, xs, multiple_predictions=False):
if not multiple_predictions:
return self.logic_forward(xs)
else:
res = [self.logic_forward(x) for x in xs]
return res

def abduce_candidates(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False):
def abduce_candidates(self, pred_res, key, max_address_num, require_more_address=0):
if self.GKB_flag:
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address)
else:
return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address, multiple_predictions)
return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address)
@abstractmethod
def _find_candidate_GKB(self, pred_res, key):
pass
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address):
if self.base == {}:
return []

if not multiple_predictions:
if len(pred_res) not in self.len_list:
return []
all_candidates = self._find_candidate_GKB(pred_res, key)
if len(all_candidates) == 0:
return []
else:
cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates
if len(pred_res) not in self.len_list:
return []
all_candidates = self._find_candidate_GKB(pred_res, key)
if len(all_candidates) == 0:
return []
else:
min_address_num = 0
all_candidates_save = []
cost_list_save = []
for p_res, k in zip(pred_res, key):
if len(p_res) not in self.len_list:
return []
all_candidates = self._find_candidate_GKB(p_res, k)
if len(all_candidates) == 0:
return []
else:
all_candidates_save.append(all_candidates)
cost_list = hamming_dist(p_res, all_candidates)
min_address_num += np.min(cost_list)
cost_list_save.append(cost_list)
multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)]
multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)])
cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(multiple_cost_list <= address_num)[0]
candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates

def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
def address_by_idx(self, pred_res, key, address_idx):
candidates = []
abduce_c = product(self.pseudo_label_list, repeat=len(address_idx))
if multiple_predictions:
save_pred_res = pred_res
pred_res = flatten(pred_res)
# if multiple_predictions:
# save_pred_res = pred_res
# pred_res = flatten(pred_res)

for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
candidate[idx] = c[i]
if multiple_predictions:
candidate = reform_idx(candidate, save_pred_res)
if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err):
# if multiple_predictions:
# candidate = reform_idx(candidate, save_pred_res)
if check_equal(self.logic_forward(candidate), key, self.max_err):
candidates.append(candidate)
return candidates

def _address(self, address_num, pred_res, key, multiple_predictions):
def _address(self, address_num, pred_res, key):
new_candidates = []
if not multiple_predictions:
address_idx_list = combinations(list(range(len(pred_res))), address_num)
else:
address_idx_list = combinations(list(range(len(flatten(pred_res)))), address_num)
address_idx_list = combinations(list(range(len(pred_res))), address_num)

for address_idx in address_idx_list:
candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
candidates = self.address_by_idx(pred_res, key, address_idx)
new_candidates += candidates
return new_candidates

@lru_cache(maxsize=100)
def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
# TODO:在类初始化时应该有一个cache(默认True)的参数,用户可以指定是否用cache(若KB会变,那不能用cache)
@lru_cache(maxsize=None)
def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address):
pred_res = hashable_to_list(pred_res)
key = hashable_to_list(key)
candidates = []
for address_num in range(len(flatten(pred_res)) + 1):
if address_num == 0:
if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err):
if check_equal(self.logic_forward(pred_res), key, self.max_err):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
new_candidates = self._address(address_num, pred_res, key)
candidates += new_candidates
if len(candidates) > 0:
min_address_num = address_num
@@ -183,7 +155,7 @@ class KBBase(ABC):
for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
new_candidates = self._address(address_num, pred_res, key)
candidates += new_candidates
return candidates



+ 1
- 1
examples/datasets/hed/learn_add.pl View File

@@ -32,7 +32,7 @@ abduce_consistent_insts(Exs):-
% (Experimental) Uncomment to use parallel abduction
% abduce_consistent_exs_concurrent(Exs), !.
logic_forward(Exs, X) :- abduce_consistent_insts([Exs]) -> X = true ; X = false.
logic_forward(Exs, X) :- abduce_consistent_insts(Exs) -> X = true ; X = false.
logic_forward(Exs) :- abduce_consistent_insts(Exs).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


Loading…
Cancel
Save