From abb7fb4a1eb0927a2e96abcf8d6baa91e08a9ac1 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Thu, 24 Nov 2022 17:42:27 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 403 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 310 insertions(+), 93 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index f26b363..2600f8b 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -1,8 +1,8 @@ # coding: utf-8 #================================================================# -# Copyright (C) 2021 Freecss All rights reserved. +# Copyright (C) 2021 LAMDA All rights reserved. # -# File Name :abducer_base.py +# File Name :kb.py # Author :freecss # Email :karlfreecss@gmail.com # Created Date :2021/06/03 @@ -10,63 +10,27 @@ # #================================================================# -import sys -sys.path.append("..") - -import abc -from abducer.kb import add_KB, hwf_KB, add_prolog_KB +from abc import ABC, abstractmethod +import bisect +import copy import numpy as np +from collections import defaultdict from itertools import product, combinations -import time -class AbducerBase(abc.ABC): - def __init__(self, kb, dist_func = 'confidence', cache = True): - self.kb = kb - assert(dist_func == 'hamming' or dist_func == 'confidence') - self.dist_func = dist_func - self.cache = cache - - if self.cache: - self.cache_min_address_num = {} - self.cache_candidates = {} +import pyswip - def hamming_dist(self, A, B): - B = np.array(B) - A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) - return np.sum(A != B, axis = 1) +class KBBase(ABC): + def __init__(self): + pass - def confidence_dist(self, A, B): - mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) - B = [list(map(lambda x : mapping[x], b)) for b in B] + @abstractmethod + def logic_forward(self): + pass - B = np.array(B) - A = np.clip(A, 1e-9, 1) - A = np.expand_dims(A, axis=0) - A = A.repeat(axis=0, repeats=(len(B))) - rows = np.array(range(len(B))) - rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0])) - cols = np.array(range(len(B[0]))) - cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B)) - return 1 - np.prod(A[rows, cols, B], axis = 1) - - - def get_cost_list(self, pred_res, pred_res_prob, candidates): - if self.dist_func == 'hamming': - return self.hamming_dist(pred_res, candidates) - elif self.dist_func == 'confidence': - return self.confidence_dist(pred_res_prob, candidates) - - def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates): - if len(candidates) == 0: - return [] - elif len(candidates) == 1: - return candidates[0] - else: - cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates) - min_address_num = np.min(cost_list) - idxs = np.where(cost_list == min_address_num)[0] - return [candidates[idx] for idx in idxs][0] + @abstractmethod + def abduce_candidates(self): + pass def filter_all_candidates(self, pred_res, all_candidates, max_address_num, require_more_address): if len(all_candidates) == 0: @@ -80,68 +44,321 @@ class AbducerBase(abc.ABC): idxs = np.where(cost_list <= address_num)[0] candidates = [all_candidates[idx] for idx in idxs] return candidates, min_address_num, address_num + + def hamming_dist(self, A, B): + B = np.array(B) + A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) + return np.sum(A != B, axis = 1) + + def __len__(self): + pass + +class ClsKB(KBBase): + def __init__(self, GKB_flag = False, pseudo_label_list = None, len_list = None): + super().__init__() + self.GKB_flag = GKB_flag + self.pseudo_label_list = pseudo_label_list + self.len_list = len_list + self.prolog_flag = False + + if GKB_flag: + # self.base = np.load('abducer/hwf.npy', allow_pickle=True).item() + self.base = {} + X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) + for x, y in zip(X, Y): + self.base.setdefault(len(x), defaultdict(list))[y].append(x) + else: + self.all_address_candidate_dict = {} + for address_num in range(1, max(self.len_list) + 1): + self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat = address_num)) + + def get_GKB(self, pseudo_label_list, len_list): + all_X = [] + for len in len_list: + all_X += list(product(pseudo_label_list, repeat = len)) + + X = [] + Y = [] + for x in all_X: + y = self.logic_forward(x) + if y != np.inf: + X.append(x) + Y.append(y) + return X, Y + + def logic_forward(self): + pass + + def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0): + if max_address_num == -1: + max_address_num = len(pred_res) + if self.GKB_flag: + all_candidates = self.get_candidates_GKB(key, len(pred_res)) + return self.filter_all_candidates(pred_res, all_candidates, max_address_num, require_more_address) + else: + return self.abduction(pred_res, key, max_address_num, require_more_address) - def abduce(self, data, max_address_num = -1, require_more_address = 0): - pred_res, pred_res_prob, ans = data - if self.cache and (tuple(pred_res), ans) in self.cache_min_address_num: - address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), ans)] + require_more_address) - if (tuple(pred_res), ans, address_num) in self.cache_candidates: - candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)] - candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) - return candidate + + def get_candidates_GKB(self, key, length = None): + if self.base == {}: + return [] - candidates, min_address_num, address_num = self.kb.abduce_candidates(pred_res, ans, max_address_num, require_more_address) + if key is None: + return self.get_all_candidates() + + if length is None: + length = list(self.base.keys()) + elif type(length) is int and length not in self.len_list: + return [] + else: + length = [length] + + return sum([self.base[l][key] for l in length], []) + + def get_all_candidates(self): + if self.base == {}: + return [] + else: + return sum([sum(v.values(), []) for v in self.base.values()], []) + + + + + + def abduction(self, pred_res, key, max_address_num, require_more_address): + candidates = [] + for address_num in range(len(pred_res) + 1): + if address_num == 0: + if abs(self.logic_forward(pred_res) - key) <= 1e-3: + candidates.append(pred_res) + else: + new_candidates = self.address(address_num, pred_res, key) + candidates += new_candidates + + if len(candidates) > 0: + min_address_num = address_num + break + + if address_num >= max_address_num: + return [], 0, 0 - if self.cache: - self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num - self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates + for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): + if address_num > max_address_num: + return candidates, min_address_num, address_num - 1 + new_candidates = self.address(address_num, pred_res, key) + candidates += new_candidates - candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) - return candidate + return candidates, min_address_num, address_num + + def address(self, address_num, pred_res, key): + new_candidates = [] + all_address_candidate = self.all_address_candidate_dict[address_num] + address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + for address_idx in address_idx_list: + for c in all_address_candidate: + address_list = [pred_res[i] for i in address_idx] + if(sum([address_list[i] == c[i] for i in range(address_num)]) == 0): + candidate = pred_res.copy() + for i, idx in enumerate(address_idx): + candidate[idx] = c[i] + if self.logic_forward(candidate) == key: + new_candidates.append(candidate) + return new_candidates + - - def batch_abduce(self, Z, Y, max_address_num = -1, require_more_address = 0): - return [ - self.abduce((z, prob, y), max_address_num, require_more_address)\ - for z, prob, y in zip(Z['cls'], Z['prob'], Y) - ] - def __call__(self, Z, Y, max_address_num = -1, require_more_address = 0): - return self.batch_abduce(Z, Y, max_address_num, require_more_address) + def _dict_len(self, dic): + if not self.GKB_flag: + return 0 + else: + return sum(len(c) for c in dic.values()) + + def __len__(self): + if not self.GKB_flag: + return 0 + else: + return sum(self._dict_len(v) for v in self.base.values()) + + +class add_KB(ClsKB): + def __init__(self, GKB_flag = False, \ + pseudo_label_list = list(range(10)), \ + len_list = [2]): + super().__init__(GKB_flag, pseudo_label_list, len_list) + + def logic_forward(self, nums): + return sum(nums) + +class hwf_KB(ClsKB): + def __init__(self, GKB_flag = False, \ + pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \ + len_list = [1, 3, 5, 7]): + super().__init__(GKB_flag, pseudo_label_list, len_list) + + def valid_candidate(self, formula): + if len(formula) % 2 == 0: + return False + for i in range(len(formula)): + if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: + return False + if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: + return False + return True + + def logic_forward(self, formula): + if not self.valid_candidate(formula): + return np.inf + mapping = {'1':'1', '2':'2', '3':'3', '4':'4', '5':'5', '6':'6', '7':'7', '8':'8', '9':'9', '+':'+', '-':'-', 'times':'*', 'div':'/'} + formula = [mapping[f] for f in formula] + return round(eval(''.join(formula)), 2) + + + + + +class prolog_KB(KBBase): + def __init__(self, pseudo_label_list): + super().__init__() + self.pseudo_label_list = pseudo_label_list + self.prolog = pyswip.Prolog() + for i in self.pseudo_label_list: + self.prolog.assertz("pseudo_label(%s)" % i) + self.prolog_flag = True + + def logic_forward(self): + pass + + def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0): + if max_address_num == -1: + max_address_num = len(pred_res) + all_candidates = self.get_candidates_prolog(key) + return self.filter_all_candidates(pred_res, all_candidates, max_address_num, require_more_address) + + +class add_prolog_KB(prolog_KB): + def __init__(self, pseudo_label_list = list(range(10))): + super().__init__(pseudo_label_list) + self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2") + + def logic_forward(self, nums): + return list(self.prolog.query("addition(%s, %s, Res)." %(nums[0], nums[1])))[0]['Res'] + + def get_candidates_prolog(self, key): + return [(z['Z1'], z['Z2']) for z in list(self.prolog.query("addition(Z1, Z2, %s)." % key))] + + +class RegKB(KBBase): + def __init__(self, GKB_flag = False, X = None, Y = None): + super().__init__() + tmp_dict = {} + for x, y in zip(X, Y): + tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) + + self.base = {} + for l in tmp_dict.keys(): + data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values()))) + X = [x for y, x in data] + Y = [y for y, x in data] + self.base[l] = (X, Y) + + def valid_candidate(self): + pass + + def logic_forward(self): + pass + + def abduce_candidates(self, key, length = None): + if key is None: + return self.get_all_candidates() + + length = self._length(length) + + min_err = 999999 + candidates = [] + for l in length: + X, Y = self.base[l] -if __name__ == '__main__': + idx = bisect.bisect_left(Y, key) + begin = max(0, idx - 1) + end = min(idx + 2, len(X)) + + for idx in range(begin, end): + err = abs(Y[idx] - key) + if abs(err - min_err) < 1e-9: + candidates.extend(X[idx]) + elif err < min_err: + candidates = copy.deepcopy(X[idx]) + min_err = err + return candidates + + def get_all_candidates(self): + return sum([sum(D[0], []) for D in self.base.values()], []) + + def __len__(self): + return sum([sum(len(x) for x in D[0]) for D in self.base.values()]) + +import time +if __name__ == "__main__": + # With ground KB kb = add_KB(GKB_flag = True) - abd = AbducerBase(kb, 'hamming') - res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0) + print('len(kb):', len(kb)) + res = kb.get_candidates_GKB(0) print(res) - res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0) + res = kb.get_candidates_GKB(18) print(res) - res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0) + res = kb.get_candidates_GKB(18) + print(res) + res = kb.get_candidates_GKB(16) print(res) print() - kb = add_prolog_KB() - abd = AbducerBase(kb, 'hamming') - res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0) - print(res) - res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0) - print(res) - res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0) - print(res) + # Without ground KB + kb = add_KB() + print('len(kb):', len(kb)) print() - kb = hwf_KB(len_list = [1, 3, 5]) - abd = AbducerBase(kb, 'hamming') - res = abd.abduce((['5', '+', '2'], None, 3), max_address_num = 2, require_more_address = 0) + # Prolog + kb = add_prolog_KB() + print(kb.logic_forward([3, 4])) + res = kb.get_candidates_prolog(16) print(res) - res = abd.abduce((['5', '+', '2'], None, 64), max_address_num = 3, require_more_address = 0) + + start = time.time() + kb = hwf_KB(GKB_flag = True, len_list = [1, 3, 5]) + print(time.time() - start) + print('len(kb):', len(kb)) + res = kb.get_candidates_GKB(2, length = 1) print(res) - res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num = 3, require_more_address = 0) + res = kb.get_candidates_GKB(1, length = 3) print(res) - res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num = 5, require_more_address = 3) + res = kb.get_candidates_GKB(3.67, length = 5) print(res) print() + + # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] + # Y = [2, 1, 1, 2, 2] + # kb = ClsKB(X, Y) + # print('len(kb):', len(kb)) + # res = kb.get_candidates(2, 5) + # print(res) + # res = kb.get_candidates(2, 3) + # print(res) + # res = kb.get_candidates(None) + # print(res) + # print() + + # X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] + # Y = [2, 1, 1, 2, 1.5, 1.5] + # kb = RegKB(X, Y) + # print('len(kb):', len(kb)) + # res = kb.get_candidates(1.6) + # print(res) + # res = kb.get_candidates(1.6, length = 9) + # print(res) + # res = kb.get_candidates(None) + # print(res) +