From 87c3ba1b71dbc372bf826cb66ada556f8fda8c38 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Wed, 15 Mar 2023 14:13:31 +0800 Subject: [PATCH] Remove --- abl/abducer/abducer_base.py | 226 +++++++++++++++++------------ abl/abducer/kb.py | 102 +++++-------- examples/datasets/hed/learn_add.pl | 2 +- 3 files changed, 168 insertions(+), 162 deletions(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index 54dbec1..b7e6526 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -16,27 +16,19 @@ from zoopt import Dimension, Objective, Parameter, Opt from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist class AbducerBase(abc.ABC): - def __init__(self, kb, dist_func='hamming', zoopt=False, multiple_predictions=False): + def __init__(self, kb, dist_func='hamming', zoopt=False): self.kb = kb assert dist_func == 'hamming' or dist_func == 'confidence' self.dist_func = dist_func self.zoopt = zoopt - self.multiple_predictions = multiple_predictions if dist_func == 'confidence': self.mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) def _get_cost_list(self, pred_res, pred_res_prob, candidates): if self.dist_func == 'hamming': - if self.multiple_predictions: - pred_res = flatten(pred_res) - candidates = [flatten(c) for c in candidates] - return hamming_dist(pred_res, candidates) elif self.dist_func == 'confidence': - if self.multiple_predictions: - pred_res_prob = flatten(pred_res_prob) - candidates = [flatten(c) for c in candidates] candidates = [list(map(lambda x: self.mapping[x], c)) for c in candidates] return confidence_dist(pred_res_prob, candidates) @@ -50,26 +42,24 @@ class AbducerBase(abc.ABC): cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) candidate = candidates[np.argmin(cost_list)] return candidate - - def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol): - if not self.multiple_predictions: - address_idx = np.where(sol.get_x() != 0)[0] - candidates = self.address_by_idx(pred_res, key, address_idx) - if len(candidates) > 0: - return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) - else: - return len(pred_res) + + def _zoopt_address_score_single(self, sol_x, pred_res, pred_res_prob, key): + address_idx = np.where(sol_x != 0)[0] + candidates = self.address_by_idx(pred_res, key, address_idx) + if len(candidates) > 0: + return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) else: - all_address_flag = reform_idx(sol.get_x(), pred_res) - score = 0 - for idx in range(len(pred_res)): - address_idx = np.where(all_address_flag[idx] != 0)[0] - candidates = self.address_by_idx([pred_res[idx]], key[idx], address_idx) - if len(candidates) > 0: - score += np.min(self._get_cost_list(pred_res[idx], pred_res_prob[idx], candidates)) - else: - score += len(pred_res[idx]) - return score + return len(pred_res) + + def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol): + # if not self.multiple_predictions: + return self._zoopt_address_score_single(sol.get_x(), pred_res, pred_res_prob, key) + # else: + # all_address_flag = reform_idx(sol.get_x(), pred_res) + # score = 0 + # for idx in range(len(pred_res)): + # score += self._zoopt_address_score_single(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key) + # return score def _constrain_address_num(self, solution, max_address_num): x = solution.get_x() @@ -88,21 +78,29 @@ class AbducerBase(abc.ABC): return solution def address_by_idx(self, pred_res, key, address_idx): - return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) + return self.kb.address_by_idx(pred_res, key, address_idx) - def abduce(self, data, max_address_num=-1, require_more_address=0): + def abduce(self, data, max_address=-1, require_more_address=0): pred_res, pred_res_prob, key = data - if max_address_num == -1: - max_address_num = len(flatten(pred_res)) + # if max_address_num == -1: + # max_address_num = len(flatten(pred_res)) + + assert(type(max_address) in (int, float)) + if max_address == -1: + max_address_num = len(pred_res) + elif type(max_address) == float: + assert(max_address >= 0 and max_address <= 1) + max_address_num = round(len(pred_res) * max_address) + else: + assert(max_address >= 0) + max_address_num = max_address if self.zoopt: solution = self.zoopt_get_solution(pred_res, pred_res_prob, key, max_address_num) address_idx = np.where(solution != 0)[0] candidates = self.address_by_idx(pred_res, key, address_idx) else: - candidates = self.kb.abduce_candidates( - pred_res, key, max_address_num, require_more_address, self.multiple_predictions - ) + candidates = self.kb.abduce_candidates(pred_res, key, max_address_num, require_more_address) candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) return candidate @@ -110,134 +108,170 @@ class AbducerBase(abc.ABC): def abduce_rules(self, pred_res): return self.kb.abduce_rules(pred_res) - def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): - # if self.multiple_predictions: - return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address) - # else: - # return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] + def batch_abduce(self, data, max_address=-1, require_more_address=0): + Z1, Z2, Y = data + return [self.abduce((z, prob, y), max_address, require_more_address) for z, prob, y in zip(Z1, Z2, Y)] def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): return self.batch_abduce(Z, Y, max_address_num, require_more_address) if __name__ == '__main__': - from kb import add_KB, prolog_KB, HWF_KB + from kb import add_KB, prolog_KB, HWF_KB, HED_prolog_KB - prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] - prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + print('add_KB with GKB:') kb = add_KB(GKB_flag=True) abd = AbducerBase(kb, 'confidence') - res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [8]), max_address=2, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob2, [8]), max_address=2, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob1, 17), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=2, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob1, 17), max_address_num=1, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=1, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob1, 20), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [20]), max_address=2, require_more_address=0) print(res) print() - - multiple_prob = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], - [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] - - + print('add_KB without GKB:') kb = add_KB() - abd = AbducerBase(kb, 'confidence', multiple_predictions=True) - res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=0) + abd = AbducerBase(kb, 'confidence') + res = abd.batch_abduce(([[1, 1]], prob1, [8]), max_address=2, require_more_address=0) + print(res) + res = abd.batch_abduce(([[1, 1]], prob2, [8]), max_address=2, require_more_address=0) print(res) - res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=1) + res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=2, require_more_address=0) + print(res) + res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=1, require_more_address=0) + print(res) + res = abd.batch_abduce(([[1, 1]], prob1, [20]), max_address=2, require_more_address=0) print(res) print() - + print('prolog_KB with add.pl:') kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl') abd = AbducerBase(kb, 'confidence') - res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [8]), max_address=2, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob2, [8]), max_address=2, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob1, 17), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=2, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob1, 17), max_address_num=1, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=1, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob1, 20), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [20]), max_address=2, require_more_address=0) print(res) print() + print('prolog_KB with add.pl using zoopt:') kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl') abd = AbducerBase(kb, 'confidence', zoopt=True) - res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [8]), max_address=2, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob2, [8]), max_address=2, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob1, 17), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=2, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob1, 17), max_address_num=1, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [17]), max_address=1, require_more_address=0) print(res) - res = abd.abduce(([1, 1], prob1, 20), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([[1, 1]], prob1, [20]), max_address=2, require_more_address=0) print(res) print() - - kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1) - abd = AbducerBase(kb, 'hamming') - res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) + + print('add_KB with multiple inputs at once:') + multiple_prob = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], + [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + + kb = add_KB() + abd = AbducerBase(kb, 'confidence') + res = abd.batch_abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address=4, require_more_address=0) print(res) - res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) + res = abd.batch_abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address=4, require_more_address=1) print(res) print() - kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1) + print('HWF_KB with GKB, max_err=0.1') + kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1) abd = AbducerBase(kb, 'hamming') - res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) + res = abd.batch_abduce(([['5', '+', '2']], [None], [3]), max_address=2, require_more_address=0) print(res) - res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0) + res = abd.batch_abduce(([['5', '+', '9']], [None], [65]), max_address=3, require_more_address=0) print(res) - res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3) + res = abd.batch_abduce(([['5', '8', '8', '8', '8']], [None], [3.17]), max_address=5, require_more_address=3) print(res) print() + print('HWF_KB without GKB, max_err=0.1') kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) - abd = AbducerBase(kb, 'hamming', multiple_predictions=True) - res = abd.abduce(([['5', '+', '2'], ['5', '+', '9']], None, [3, 64]), max_address_num=6, require_more_address=0) + abd = AbducerBase(kb, 'hamming') + res = abd.batch_abduce(([['5', '+', '2']], [None], [3]), max_address=2, require_more_address=0) + print(res) + res = abd.batch_abduce(([['5', '+', '9']], [None], [65]), max_address=3, require_more_address=0) + print(res) + res = abd.batch_abduce(([['5', '8', '8', '8', '8']], [None], [3.17]), max_address=5, require_more_address=3) print(res) print() - kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) + print('HWF_KB with GKB, max_err=1') + kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1) abd = AbducerBase(kb, 'hamming') - res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) + res = abd.batch_abduce(([['5', '+', '9']], [None], [65]), max_address=3, require_more_address=0) print(res) - res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) + res = abd.batch_abduce(([['5', '+', '2']], [None], [1.67]), max_address=3, require_more_address=0) print(res) + res = abd.batch_abduce(([['5', '8', '8', '8', '8']], [None], [3.17]), max_address=5, require_more_address=3) + print(res) + print() + print('HWF_KB without GKB, max_err=1') kb = HWF_KB(len_list=[1, 3, 5], max_err = 1) abd = AbducerBase(kb, 'hamming') - res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) + res = abd.batch_abduce(([['5', '+', '9']], [None], [65]), max_address=3, require_more_address=0) + print(res) + res = abd.batch_abduce(([['5', '+', '2']], [None], [1.67]), max_address=3, require_more_address=0) + print(res) + res = abd.batch_abduce(([['5', '8', '8', '8', '8']], [None], [3.17]), max_address=5, require_more_address=3) + print(res) + print() + + print('HWF_KB with multiple inputs at once:') + kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) + abd = AbducerBase(kb, 'hamming') + res = abd.batch_abduce(([['5', '+', '2'], ['5', '+', '9']], [None, None], [3, 64]), max_address=1, require_more_address=0) print(res) - res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0) + res = abd.batch_abduce(([['5', '+', '2'], ['5', '+', '9']], [None, None], [3, 64]), max_address=3, require_more_address=0) print(res) - res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3) + res = abd.batch_abduce(([['5', '+', '2'], ['5', '+', '9']], [None, None], [3, 65]), max_address=3, require_more_address=0) + print(res) + print() + print('max_address is float') + res = abd.batch_abduce(([['5', '+', '2'], ['5', '+', '9']], [None, None], [3, 64]), max_address=0.5, require_more_address=0) + print(res) + res = abd.batch_abduce(([['5', '+', '2'], ['5', '+', '9']], [None, None], [3, 64]), max_address=0.9, require_more_address=0) print(res) print() - kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') - abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) + kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') + abd = AbducerBase(kb, zoopt=True) consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]] inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] - # inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']] rules = ['my_op([0], [0], [0])', 'my_op([1], [1], [1, 0])'] - print(kb._logic_forward(consist_exs, True), kb._logic_forward(inconsist_exs, True)) - print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules), kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) + print(kb.logic_forward(consist_exs)) + print(kb.logic_forward(inconsist_exs)) print() - - res = abd.abduce((consist_exs, [None] * len(consist_exs), [None] * len(consist_exs))) - print(res) - res = abd.abduce((inconsist_exs, [None] * len(consist_exs), [None] * len(inconsist_exs))) - print(res) + print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules)) + print(kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) print() - abduced_rules = abd.abduce_rules(consist_exs) - print(abduced_rules) \ No newline at end of file + # res = abd.abduce((consist_exs, [None] * len(consist_exs), [None] * len(consist_exs))) + # print(res) + # res = abd.batch_abduce((inconsist_exs, [None] * len(consist_exs), [None] * len(inconsist_exs))) + # print(res) + # print() + + # abduced_rules = abd.batch_abduce_rules(consist_exs) + # print(abduced_rules) \ No newline at end of file diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index feba8c4..42a5f6c 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -17,7 +17,7 @@ import numpy as np from collections import defaultdict from itertools import product, combinations -from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list +from utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list from multiprocessing import Pool @@ -25,11 +25,16 @@ from functools import lru_cache import pyswip class KBBase(ABC): - def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0): + def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0, cache_size=128): + # TODO:添加一下类型检查,比如 + # if not isinstance(X, (np.ndarray, spmatrix)): + # raise TypeError("X should be numpy array or sparse matrix") + self.pseudo_label_list = pseudo_label_list self.len_list = len_list self.GKB_flag = GKB_flag self.max_err = max_err + self.cache_size = cache_size if GKB_flag: self.base = {} @@ -73,106 +78,73 @@ class KBBase(ABC): @abstractmethod def logic_forward(self, pseudo_labels): pass - - def _logic_forward(self, xs, multiple_predictions=False): - if not multiple_predictions: - return self.logic_forward(xs) - else: - res = [self.logic_forward(x) for x in xs] - return res - def abduce_candidates(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): + def abduce_candidates(self, pred_res, key, max_address_num, require_more_address=0): if self.GKB_flag: - return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) + return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address) else: - return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address, multiple_predictions) + return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address) @abstractmethod def _find_candidate_GKB(self, pred_res, key): pass - def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): + def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address): if self.base == {}: return [] - - if not multiple_predictions: - if len(pred_res) not in self.len_list: - return [] - all_candidates = self._find_candidate_GKB(pred_res, key) - if len(all_candidates) == 0: - return [] - else: - cost_list = hamming_dist(pred_res, all_candidates) - min_address_num = np.min(cost_list) - address_num = min(max_address_num, min_address_num + require_more_address) - idxs = np.where(cost_list <= address_num)[0] - candidates = [all_candidates[idx] for idx in idxs] - return candidates - + + if len(pred_res) not in self.len_list: + return [] + all_candidates = self._find_candidate_GKB(pred_res, key) + if len(all_candidates) == 0: + return [] else: - min_address_num = 0 - all_candidates_save = [] - cost_list_save = [] - for p_res, k in zip(pred_res, key): - if len(p_res) not in self.len_list: - return [] - all_candidates = self._find_candidate_GKB(p_res, k) - if len(all_candidates) == 0: - return [] - else: - all_candidates_save.append(all_candidates) - cost_list = hamming_dist(p_res, all_candidates) - min_address_num += np.min(cost_list) - cost_list_save.append(cost_list) - - multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)] - multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)]) + cost_list = hamming_dist(pred_res, all_candidates) + min_address_num = np.min(cost_list) address_num = min(max_address_num, min_address_num + require_more_address) - idxs = np.where(multiple_cost_list <= address_num)[0] - candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] + idxs = np.where(cost_list <= address_num)[0] + candidates = [all_candidates[idx] for idx in idxs] return candidates - def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): + def address_by_idx(self, pred_res, key, address_idx): candidates = [] abduce_c = product(self.pseudo_label_list, repeat=len(address_idx)) - if multiple_predictions: - save_pred_res = pred_res - pred_res = flatten(pred_res) + # if multiple_predictions: + # save_pred_res = pred_res + # pred_res = flatten(pred_res) for c in abduce_c: candidate = pred_res.copy() for i, idx in enumerate(address_idx): candidate[idx] = c[i] - if multiple_predictions: - candidate = reform_idx(candidate, save_pred_res) - if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): + # if multiple_predictions: + # candidate = reform_idx(candidate, save_pred_res) + if check_equal(self.logic_forward(candidate), key, self.max_err): candidates.append(candidate) return candidates - def _address(self, address_num, pred_res, key, multiple_predictions): + def _address(self, address_num, pred_res, key): new_candidates = [] - if not multiple_predictions: - address_idx_list = combinations(list(range(len(pred_res))), address_num) - else: - address_idx_list = combinations(list(range(len(flatten(pred_res)))), address_num) + address_idx_list = combinations(list(range(len(pred_res))), address_num) for address_idx in address_idx_list: - candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) + candidates = self.address_by_idx(pred_res, key, address_idx) new_candidates += candidates return new_candidates - @lru_cache(maxsize=100) - def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): + # TODO:在类初始化时应该有一个cache(默认True)的参数,用户可以指定是否用cache(若KB会变,那不能用cache) + @lru_cache(maxsize=None) + def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address): pred_res = hashable_to_list(pred_res) key = hashable_to_list(key) candidates = [] for address_num in range(len(flatten(pred_res)) + 1): if address_num == 0: - if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): + if check_equal(self.logic_forward(pred_res), key, self.max_err): candidates.append(pred_res) else: - new_candidates = self._address(address_num, pred_res, key, multiple_predictions) + new_candidates = self._address(address_num, pred_res, key) candidates += new_candidates if len(candidates) > 0: min_address_num = address_num @@ -183,7 +155,7 @@ class KBBase(ABC): for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): if address_num > max_address_num: return candidates - new_candidates = self._address(address_num, pred_res, key, multiple_predictions) + new_candidates = self._address(address_num, pred_res, key) candidates += new_candidates return candidates diff --git a/examples/datasets/hed/learn_add.pl b/examples/datasets/hed/learn_add.pl index fbf698f..35a71c6 100644 --- a/examples/datasets/hed/learn_add.pl +++ b/examples/datasets/hed/learn_add.pl @@ -32,7 +32,7 @@ abduce_consistent_insts(Exs):- % (Experimental) Uncomment to use parallel abduction % abduce_consistent_exs_concurrent(Exs), !. -logic_forward(Exs, X) :- abduce_consistent_insts([Exs]) -> X = true ; X = false. +logic_forward(Exs, X) :- abduce_consistent_insts(Exs) -> X = true ; X = false. logic_forward(Exs) :- abduce_consistent_insts(Exs). %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%