Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
abb7fb4a1e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 310 additions and 93 deletions
  1. +310
    -93
      abducer/kb.py

+ 310
- 93
abducer/kb.py View File

@@ -1,8 +1,8 @@
# coding: utf-8
#================================================================#
# Copyright (C) 2021 Freecss All rights reserved.
# Copyright (C) 2021 LAMDA All rights reserved.
#
# File Name :abducer_base.py
# File Name :kb.py
# Author :freecss
# Email :karlfreecss@gmail.com
# Created Date :2021/06/03
@@ -10,63 +10,27 @@
#
#================================================================#

import sys
sys.path.append("..")

import abc
from abducer.kb import add_KB, hwf_KB, add_prolog_KB
from abc import ABC, abstractmethod
import bisect
import copy
import numpy as np

from collections import defaultdict
from itertools import product, combinations
import time

class AbducerBase(abc.ABC):
def __init__(self, kb, dist_func = 'confidence', cache = True):
self.kb = kb
assert(dist_func == 'hamming' or dist_func == 'confidence')
self.dist_func = dist_func
self.cache = cache
if self.cache:
self.cache_min_address_num = {}
self.cache_candidates = {}
import pyswip

def hamming_dist(self, A, B):
B = np.array(B)
A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B)))
return np.sum(A != B, axis = 1)
class KBBase(ABC):
def __init__(self):
pass

def confidence_dist(self, A, B):
mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list)))))
B = [list(map(lambda x : mapping[x], b)) for b in B]
@abstractmethod
def logic_forward(self):
pass
B = np.array(B)
A = np.clip(A, 1e-9, 1)
A = np.expand_dims(A, axis=0)
A = A.repeat(axis=0, repeats=(len(B)))
rows = np.array(range(len(B)))
rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0]))
cols = np.array(range(len(B[0])))
cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B))
return 1 - np.prod(A[rows, cols, B], axis = 1)
def get_cost_list(self, pred_res, pred_res_prob, candidates):
if self.dist_func == 'hamming':
return self.hamming_dist(pred_res, candidates)
elif self.dist_func == 'confidence':
return self.confidence_dist(pred_res_prob, candidates)

def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates):
if len(candidates) == 0:
return []
elif len(candidates) == 1:
return candidates[0]
else:
cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates)
min_address_num = np.min(cost_list)
idxs = np.where(cost_list == min_address_num)[0]
return [candidates[idx] for idx in idxs][0]
@abstractmethod
def abduce_candidates(self):
pass
def filter_all_candidates(self, pred_res, all_candidates, max_address_num, require_more_address):
if len(all_candidates) == 0:
@@ -80,68 +44,321 @@ class AbducerBase(abc.ABC):
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates, min_address_num, address_num
def hamming_dist(self, A, B):
    """Count, for each candidate in ``B``, the positions where it differs from ``A``.

    Args:
        A: One sequence of labels.
        B: A sequence of candidate label sequences, each as long as ``A``.

    Returns:
        A 1-D integer numpy array holding one mismatch count per candidate;
        an empty array when ``B`` contains no candidates.
    """
    B = np.asarray(B)
    if len(B) == 0:
        # No candidates: there is no row to compare against, so return an
        # empty result instead of failing inside the element-wise comparison.
        return np.zeros(0, dtype=int)
    # Broadcasting compares A against every row of B without tiling A first.
    return np.sum(np.asarray(A) != B, axis=1)
# Size hook for knowledge bases. The body is only `pass`, so it returns None and
# calling len() on an instance that relies on this implementation raises TypeError.
def __len__(self):
pass

class ClsKB(KBBase):
    """Knowledge base over sequences of discrete pseudo-labels.

    The mode is selected by ``GKB_flag``:

    * ``True``: every label sequence whose length is in ``len_list`` is passed
      through ``logic_forward`` once and stored in ``self.base`` as
      ``{length: {result: [sequence, ...]}}``.
    * ``False``: nothing is stored up front; ``abduction`` searches on demand
      using ``self.all_address_candidate_dict``, which maps ``k`` to every
      ``k``-tuple of pseudo-labels.
    """

    def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None):
        """Store the configuration and build the table needed by the chosen mode.

        Args:
            GKB_flag: Build the grounded table (``self.base``) when true,
                otherwise build ``self.all_address_candidate_dict``.
            pseudo_label_list: The labels a sequence may contain; ``None`` is
                treated as no labels.
            len_list: The sequence lengths to consider; ``None`` is treated as
                no lengths.
        """
        super().__init__()
        self.GKB_flag = GKB_flag
        self.pseudo_label_list = pseudo_label_list
        self.len_list = len_list
        self.prolog_flag = False

        # The signature defaults are None; fall back to empty collections so a
        # default-constructed instance does not crash in max()/product().
        labels = [] if pseudo_label_list is None else pseudo_label_list
        lengths = [] if len_list is None else len_list

        if GKB_flag:
            self.base = {}
            X, Y = self.get_GKB(labels, lengths)
            for x, y in zip(X, Y):
                self.base.setdefault(len(x), defaultdict(list))[y].append(x)
        else:
            self.all_address_candidate_dict = {
                address_num: list(product(labels, repeat=address_num))
                for address_num in range(1, max(lengths, default=0) + 1)
            }

    def get_GKB(self, pseudo_label_list, len_list):
        """Enumerate label sequences and keep those with a finite result.

        Args:
            pseudo_label_list: Labels to combine.
            len_list: Sequence lengths to enumerate.

        Returns:
            ``(X, Y)``: the sequences (as tuples) whose ``logic_forward`` value
            is not ``np.inf``, and those values in the same order.
        """
        X, Y = [], []
        for length in len_list:
            for x in product(pseudo_label_list, repeat=length):
                y = self.logic_forward(x)
                if y != np.inf:
                    X.append(x)
                    Y.append(y)
        return X, Y

    def logic_forward(self):
        """Placeholder; subclasses define the mapping from a sequence to its result."""
        pass

    def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0):
        """Find label sequences consistent with ``key``, starting from ``pred_res``.

        Args:
            pred_res: The current label sequence.
            key: The result the candidates must produce.
            max_address_num: Upper bound on revised positions; ``-1`` means
                ``len(pred_res)``.
            require_more_address: Extra revision levels to include beyond the
                smallest one that yields a candidate.

        Returns:
            Whatever ``filter_all_candidates`` (grounded mode) or ``abduction``
            (search mode) returns: ``(candidates, min_address_num, address_num)``.
        """
        if max_address_num == -1:
            max_address_num = len(pred_res)
        if self.GKB_flag:
            all_candidates = self.get_candidates_GKB(key, len(pred_res))
            return self.filter_all_candidates(
                pred_res, all_candidates, max_address_num, require_more_address
            )
        return self.abduction(pred_res, key, max_address_num, require_more_address)

    def get_candidates_GKB(self, key, length=None):
        """Look up grounded sequences whose result is ``key``.

        Args:
            key: The result to look up; ``None`` returns every stored sequence.
            length: Restrict the lookup to this sequence length; ``None``
                searches all stored lengths.

        Returns:
            A new list of matching sequences (empty when nothing matches).
        """
        if not self.base:
            return []
        if key is None:
            return self.get_all_candidates()
        if length is None:
            lengths = list(self.base)
        elif isinstance(length, int) and length not in self.len_list:
            return []
        else:
            lengths = [length]

        candidates = []
        for l in lengths:
            # .get() on both levels: indexing the defaultdict would insert an
            # empty entry for every missed key, and a length without any
            # grounded entry would raise KeyError.
            candidates.extend(self.base.get(l, {}).get(key, []))
        return candidates

    def get_all_candidates(self):
        """Return every grounded sequence, across all lengths and results."""
        return [
            x
            for by_result in self.base.values()
            for sequences in by_result.values()
            for x in sequences
        ]

    def _is_consistent(self, candidate, key):
        """Return True when ``logic_forward(candidate)`` is within 1e-3 of ``key``."""
        return abs(self.logic_forward(candidate) - key) <= 1e-3

    def abduction(self, pred_res, key, max_address_num, require_more_address):
        """Search for consistent sequences by revising as few positions as possible.

        Args:
            pred_res: The current label sequence.
            key: The result the candidates must produce.
            max_address_num: Largest number of positions that may be revised.
            require_more_address: Extra revision levels to add after the first
                level that produced a candidate.

        Returns:
            ``(candidates, min_address_num, address_num)`` where
            ``min_address_num`` is the first level that produced a candidate
            and ``address_num`` the last level searched; ``([], 0, 0)`` when no
            candidate is found.
        """
        candidates = []
        min_address_num = None
        for address_num in range(len(pred_res) + 1):
            if address_num == 0:
                if self._is_consistent(pred_res, key):
                    candidates.append(pred_res)
            else:
                candidates += self.address(address_num, pred_res, key)
            if candidates:
                min_address_num = address_num
                break
            if address_num >= max_address_num:
                return [], 0, 0
        if min_address_num is None:
            # Every position was tried without success while max_address_num
            # was larger than len(pred_res); report "nothing found" instead of
            # falling through with min_address_num unset.
            return [], 0, 0

        for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
            if address_num > max_address_num:
                return candidates, min_address_num, address_num - 1
            candidates += self.address(address_num, pred_res, key)
        return candidates, min_address_num, address_num

    def address(self, address_num, pred_res, key):
        """Return the sequences obtained by changing exactly ``address_num`` positions.

        Args:
            address_num: Number of positions to change.
            pred_res: The sequence to start from; it must provide ``.copy()``.
            key: The result a revised sequence must produce.

        Returns:
            A list of revised copies of ``pred_res`` that are consistent with
            ``key``. Only replacements that differ from the current label at
            every chosen position are considered.
        """
        if address_num > len(pred_res):
            # No way to choose that many positions.
            return []
        new_candidates = []
        replacements = self.all_address_candidate_dict[address_num]
        for address_idx in combinations(range(len(pred_res)), address_num):
            current = [pred_res[i] for i in address_idx]
            for c in replacements:
                if any(old == new for old, new in zip(current, c)):
                    continue
                candidate = pred_res.copy()
                for idx, label in zip(address_idx, c):
                    candidate[idx] = label
                # Same tolerance as the zero-revision check in abduction().
                if self._is_consistent(candidate, key):
                    new_candidates.append(candidate)
        return new_candidates

    def _dict_len(self, dic):
        """Return the number of sequences stored in one ``{result: [sequence]}`` table."""
        if not self.GKB_flag:
            return 0
        return sum(len(c) for c in dic.values())

    def __len__(self):
        """Return the number of grounded sequences (0 when nothing is grounded)."""
        if not self.GKB_flag:
            return 0
        return sum(self._dict_len(v) for v in self.base.values())


class add_KB(ClsKB):
    """Knowledge base whose result for a label sequence is the sum of its labels."""

    def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None):
        """Configure the KB; the defaults are labels ``0..9`` and length ``2``.

        Args:
            GKB_flag: Forwarded to ``ClsKB``.
            pseudo_label_list: Labels to use; ``None`` selects ``list(range(10))``.
            len_list: Sequence lengths to use; ``None`` selects ``[2]``.
        """
        # Build the defaults per call: list defaults in the signature would be
        # one shared object stored on every instance.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if len_list is None:
            len_list = [2]
        super().__init__(GKB_flag, pseudo_label_list, len_list)

    def logic_forward(self, nums):
        """Return the sum of ``nums``."""
        return sum(nums)
class hwf_KB(ClsKB):
    """Knowledge base that evaluates arithmetic formulas given as token sequences.

    A valid formula has odd length and alternates between a digit token
    (``'1'``-``'9'``) and an operator token (``'+'``, ``'-'``, ``'times'``,
    ``'div'``), starting with a digit.
    """

    # Tuples (not sets) so membership tests behave like the list lookups they
    # replace, including for unhashable tokens.
    _DIGITS = ('1', '2', '3', '4', '5', '6', '7', '8', '9')
    _OPERATORS = ('+', '-', 'times', 'div')
    # Only these two tokens differ from their Python spelling.
    _PYTHON_OPERATOR = {'times': '*', 'div': '/'}

    def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None):
        """Configure the KB.

        Args:
            GKB_flag: Forwarded to ``ClsKB``.
            pseudo_label_list: Tokens to use; ``None`` selects the nine digits
                followed by the four operators.
            len_list: Formula lengths to use; ``None`` selects ``[1, 3, 5, 7]``.
        """
        # Defaults are created per call so instances never share one list.
        if pseudo_label_list is None:
            pseudo_label_list = list(self._DIGITS + self._OPERATORS)
        if len_list is None:
            len_list = [1, 3, 5, 7]
        super().__init__(GKB_flag, pseudo_label_list, len_list)

    def valid_candidate(self, formula):
        """Return True when ``formula`` alternates digit and operator tokens."""
        if len(formula) % 2 == 0:
            return False
        for i, token in enumerate(formula):
            allowed = self._DIGITS if i % 2 == 0 else self._OPERATORS
            if token not in allowed:
                return False
        return True

    def logic_forward(self, formula):
        """Evaluate ``formula``.

        Returns:
            The value rounded to two decimals, or ``np.inf`` when the formula
            is not valid.
        """
        if not self.valid_candidate(formula):
            return np.inf
        expression = ''.join(self._PYTHON_OPERATOR.get(token, token) for token in formula)
        # SECURITY: eval() is reached only after valid_candidate() has checked
        # every token against the fixed digit/operator tables above. Keep that
        # check in front of this call if the token set is ever extended.
        return round(eval(expression), 2)





class prolog_KB(KBBase):
    """Knowledge base that keeps its pseudo-labels as facts in a pyswip Prolog object.

    Subclasses are expected to provide ``get_candidates_prolog(key)``, which
    ``abduce_candidates`` calls.
    """

    def __init__(self, pseudo_label_list):
        """Create the Prolog object and assert one ``pseudo_label/1`` fact per label."""
        super().__init__()
        self.pseudo_label_list = pseudo_label_list
        self.prolog = pyswip.Prolog()
        for label in self.pseudo_label_list:
            fact = "pseudo_label(%s)" % label
            self.prolog.assertz(fact)
        self.prolog_flag = True

    def logic_forward(self):
        """Placeholder; subclasses define the actual query."""
        return None

    def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0):
        """Query candidates for ``key`` and filter them against ``pred_res``.

        ``max_address_num == -1`` is replaced by ``len(pred_res)``. The return
        value is whatever ``filter_all_candidates`` returns.
        """
        limit = len(pred_res) if max_address_num == -1 else max_address_num
        found = self.get_candidates_prolog(key)
        return self.filter_all_candidates(pred_res, found, limit, require_more_address)

class add_prolog_KB(prolog_KB):
    """Prolog-backed knowledge base for the sum of two pseudo-labels."""

    def __init__(self, pseudo_label_list=None):
        """Register the labels and the ``addition/3`` rule.

        Args:
            pseudo_label_list: Labels to register; ``None`` selects
                ``list(range(10))``.
        """
        # Build the default per call: a list default in the signature would be
        # one shared object stored on every instance.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        super().__init__(pseudo_label_list)
        self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2")

    def logic_forward(self, nums):
        """Return ``Res`` from the first solution of ``addition(nums[0], nums[1], Res)``."""
        return list(self.prolog.query("addition(%s, %s, Res)." %(nums[0], nums[1])))[0]['Res']

    def get_candidates_prolog(self, key):
        """Return every ``(Z1, Z2)`` pair for which ``addition(Z1, Z2, key)`` holds."""
        return [(z['Z1'], z['Z2']) for z in self.prolog.query("addition(Z1, Z2, %s)." % key)]


class RegKB(KBBase):
    """Knowledge base that retrieves stored inputs by numerically nearest result.

    ``self.base`` maps an input length to ``(X, Y)``: ``Y`` is the sorted list
    of distinct results for that length and ``X[i]`` is the list of inputs
    (as numpy arrays) that produced ``Y[i]``.
    """

    def __init__(self, GKB_flag=False, X=None, Y=None):
        """Group the ``(x, y)`` pairs by ``len(x)`` and sort each group by result.

        Args:
            GKB_flag: Accepted but not used by this class.
            X: Inputs; ``None`` is treated as no inputs.
            Y: Results paired with ``X``; ``None`` is treated as no results.
        """
        super().__init__()
        # The signature defaults are None; zip(None, None) would raise.
        X = [] if X is None else X
        Y = [] if Y is None else Y

        grouped = {}
        for x, y in zip(X, Y):
            grouped.setdefault(len(x), defaultdict(list))[y].append(np.array(x))

        self.base = {}
        for l, by_result in grouped.items():
            # Results are distinct dict keys, so ordering by the key alone
            # gives the same order as sorting the (result, inputs) pairs.
            ordered = sorted(by_result.items(), key=lambda item: item[0])
            self.base[l] = ([xs for _, xs in ordered], [y for y, _ in ordered])

    def valid_candidate(self):
        """Placeholder; not implemented in this class."""
        pass

    def logic_forward(self):
        """Placeholder; not implemented in this class."""
        pass

    def _length(self, length):
        """Normalise ``length`` to the list of stored lengths to search.

        ``None`` selects every stored length; an int or an iterable of ints is
        reduced to the lengths that actually have stored entries.
        """
        if length is None:
            return list(self.base)
        if isinstance(length, int):
            length = [length]
        return [l for l in length if l in self.base]

    def abduce_candidates(self, key, length=None):
        """Return copies of the stored inputs whose result is closest to ``key``.

        Args:
            key: Target result; ``None`` returns every stored input.
            length: An input length (or iterable of lengths) to restrict the
                search to; ``None`` searches all stored lengths.

        Returns:
            A list of inputs at the smallest absolute distance from ``key``;
            results whose distances differ by less than 1e-9 are treated as
            tied. Empty when nothing is stored for the requested lengths.
        """
        if key is None:
            return self.get_all_candidates()

        min_err = float('inf')
        candidates = []
        for l in self._length(length):
            X, Y = self.base[l]
            # Y is sorted, so the nearest result is next to the insertion point.
            pos = bisect.bisect_left(Y, key)
            for idx in range(max(0, pos - 1), min(pos + 2, len(X))):
                err = abs(Y[idx] - key)
                if abs(err - min_err) < 1e-9:
                    # Copy here as well so callers never hold references to
                    # the arrays stored in the KB.
                    candidates.extend(copy.deepcopy(X[idx]))
                elif err < min_err:
                    candidates = copy.deepcopy(X[idx])
                    min_err = err
        return candidates

    def get_all_candidates(self):
        """Return every stored input, across all lengths and results."""
        return [x for X, _ in self.base.values() for xs in X for x in xs]

    def __len__(self):
        """Return the total number of stored inputs."""
        return sum(len(xs) for X, _ in self.base.values() for xs in X)

import time
# Demo script: builds the knowledge bases defined above, runs a few abductions
# and candidate look-ups on them, and prints the results.
if __name__ == "__main__":
# With ground KB
kb = add_KB(GKB_flag = True)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0)
print('len(kb):', len(kb))
res = kb.get_candidates_GKB(0)
print(res)
res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0)
res = kb.get_candidates_GKB(18)
print(res)
res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0)
res = kb.get_candidates_GKB(18)
print(res)
res = kb.get_candidates_GKB(16)
print(res)
print()
# Same three abductions, this time through the Prolog-backed addition KB.
kb = add_prolog_KB()
abd = AbducerBase(kb, 'hamming')
res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0)
print(res)
# Without ground KB
kb = add_KB()
print('len(kb):', len(kb))
print()
kb = hwf_KB(len_list = [1, 3, 5])
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '2'], None, 3), max_address_num = 2, require_more_address = 0)
# Prolog
kb = add_prolog_KB()
print(kb.logic_forward([3, 4]))
res = kb.get_candidates_prolog(16)
print(res)
res = abd.abduce((['5', '+', '2'], None, 64), max_address_num = 3, require_more_address = 0)
# Time how long constructing the grounded formula KB takes.
start = time.time()
kb = hwf_KB(GKB_flag = True, len_list = [1, 3, 5])
print(time.time() - start)
print('len(kb):', len(kb))
res = kb.get_candidates_GKB(2, length = 1)
print(res)
res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num = 3, require_more_address = 0)
res = kb.get_candidates_GKB(1, length = 3)
print(res)
res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num = 5, require_more_address = 3)
res = kb.get_candidates_GKB(3.67, length = 5)
print(res)
print()
# X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
# Y = [2, 1, 1, 2, 2]
# kb = ClsKB(X, Y)
# print('len(kb):', len(kb))
# res = kb.get_candidates(2, 5)
# print(res)
# res = kb.get_candidates(2, 3)
# print(res)
# res = kb.get_candidates(None)
# print(res)
# print()
# X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
# Y = [2, 1, 1, 2, 1.5, 1.5]
# kb = RegKB(X, Y)
# print('len(kb):', len(kb))
# res = kb.get_candidates(1.6)
# print(res)
# res = kb.get_candidates(1.6, length = 9)
# print(res)
# res = kb.get_candidates(None)
# print(res)


Loading…
Cancel
Save