Browse Source

Merge branch 'Dev' of https://github.com/AbductiveLearning/ABL-Package into Dev

pull/3/head
Gao Enhao 2 years ago
parent
commit
6b3762acc7
2 changed files with 174 additions and 61 deletions
  1. +90
    -20
      abducer/abducer_base.py
  2. +84
    -41
      abducer/kb.py

+ 90
- 20
abducer/abducer_base.py View File

@@ -14,22 +14,48 @@ import abc
from kb import add_KB
import numpy as np

from itertools import product, combinations

def hamming_dist(A, B):
    """Count the positions at which ``A`` differs from the candidate(s) in ``B``.

    ``A`` is a single label sequence.  ``B`` is either one sequence of the
    same length (a scalar count is returned) or a 2-D collection holding one
    candidate per row (an array with one count per row is returned).
    """
    # Broadcasting compares A with every row of B directly, so A does not
    # have to be repeated len(B) times; reducing over the last axis handles
    # both the single-candidate and the batched case.
    return np.sum(np.asarray(A) != np.asarray(B), axis=-1)

def confidence_dist(A, B):
    """Return ``1 - prod_j A[j, B[i, j]]`` for every candidate row ``i`` of ``B``.

    ``A`` is indexed as ``A[position, label]`` and ``B`` holds integer label
    indices, one candidate per row.
    NOTE(review): ``A`` presumably holds per-position class probabilities
    produced by the perception model -- confirm against the caller.
    """
    B = np.asarray(B)
    # Keep every entry inside [1e-9, 1] before multiplying.
    A = np.clip(A, 1e-9, 1)
    # Pairing each position index with the label chosen by each candidate via
    # broadcasting avoids building len(B) copies of A and explicit index grids.
    positions = np.arange(B.shape[-1])
    return 1 - np.prod(A[positions, B], axis=B.ndim - 1)



class AbducerBase(abc.ABC):
def __init__(self, kb, dist_func = "hamming", pred_res_parse = None, cache = True):
    """Set up the abducer.

    kb: knowledge base object used by the other methods of this class.
    dist_func: ``"hamming"``, ``"confidence"``, or a callable that is stored
        unchanged as the distance function.
    pred_res_parse: callable applied to a prediction result; when ``None`` a
        default is chosen from the requested distance.
    cache: whether abduction results are stored in the two cache dicts.
    """
    self.kb = kb

    # Remember what was requested before `dist_func` is rebound to a
    # callable; comparing the callable with a string below would never match
    # and would leave `pred_res_parse` as None.
    requested_dist = dist_func
    if dist_func == "hamming":
        dist_func = hamming_dist
    elif dist_func == "confidence":
        dist_func = confidence_dist
    self.dist_func = dist_func

    if pred_res_parse is None:
        if requested_dist == "confidence":
            # NOTE(review): the key is a single space, which looks like an
            # unfinished placeholder -- confirm the intended key.
            pred_res_parse = lambda x : x[" "]
        else:
            pred_res_parse = lambda x : x["cls"]
    self.pred_res_parse = pred_res_parse

    self.cache = cache
    self.cache_min_address_num = {}
    self.cache_candidates = {}


def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1):
pred_res, ans = data

@@ -42,30 +68,72 @@ class AbducerBase(abc.ABC):
print('cached')
return self.cache_candidates[(tuple(pred_res), ans, address_num)]
candidates, min_address_num, address_num = self.kb.get_abduce_candidates(pred_res, ans, length, self.dist_func, max_address_num, require_more_address)
if(self.kb.base != {}):
all_candidates = self.kb.get_candidates(ans, length)
cost_list = self.dist_func(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
else:
candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address)
cost_list = self.dist_func(pred_res, candidates)
if(self.cache):
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num
self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates

return candidates
# candidates = self.kb.get_candidates(ans, length)
cost_list = self.dist_func(pred_res, candidates)
min_address_num = np.min(cost_list)
idxs = np.where(cost_list == min_address_num)[0]
candidates = [candidates[idx] for idx in idxs]
return candidates[0]
def address(self, address_num, pred_res, key):
    """Enumerate revisions of ``pred_res`` that change exactly ``address_num`` positions.

    Every combination of ``address_num`` positions is paired with every
    assignment of labels from ``self.kb.pseudo_label_list``.  An assignment is
    kept only if it differs from the current label at each chosen position
    and ``self.kb.logic_forward`` of the revised sequence is within 1e-3 of
    ``key``.

    Returns ``(new_candidates, address_num)`` where ``new_candidates`` is a
    list of NumPy arrays.
    """
    original = np.array(pred_res)
    new_candidates = []
    for address_idx in combinations(range(len(pred_res)), address_num):
        # Explicit integer dtype keeps an empty selection usable as an index.
        idx = np.array(address_idx, dtype=int)
        current = original[idx]
        for c in product(self.kb.pseudo_label_list, repeat = address_num):
            # Skip assignments that leave any selected position unchanged;
            # those belong to a smaller revision count.
            if np.count_nonzero(np.array(c) != current) != address_num:
                continue
            revised = original.copy()
            revised[idx] = c
            if abs(self.kb.logic_forward(revised) - key) <= 1e-3:
                new_candidates.append(revised)
    return new_candidates, address_num
def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address):
    """Search for revisions of ``pred_res`` whose ``logic_forward`` matches ``key``.

    Revision counts are tried in increasing order, starting at 0 (``pred_res``
    itself), until at least one candidate is found or ``max_address_num`` is
    exceeded.  Up to ``require_more_address`` further revision counts are then
    explored, still bounded by ``max_address_num``.

    Returns ``(candidates, min_address_num, address_num)``: the collected
    candidates, the smallest revision count that produced one, and the
    largest revision count explored.  Returns ``(None, None, None)`` when no
    candidate exists within the allowed number of revisions.
    """
    candidates = []
    min_address_num = None

    for address_num in range(len(pred_res) + 1):
        if address_num > max_address_num:
            break
        if address_num == 0:
            if abs(self.kb.logic_forward(pred_res) - key) <= 1e-3:
                candidates.append(pred_res)
        else:
            new_candidates, address_num = self.address(address_num, pred_res, key)
            candidates += new_candidates
        if candidates:
            min_address_num = address_num
            break

    # Covers both "budget exceeded" and "every revision count tried without
    # success"; the latter previously hit an unbound `min_address_num`.
    if min_address_num is None:
        print('No candidates found')
        return None, None, None

    reached = min_address_num
    for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
        if address_num > max_address_num:
            break
        new_candidates, reached = self.address(address_num, pred_res, key)
        candidates += new_candidates

    return candidates, min_address_num, reached

def batch_abduce(self, Y, C, max_address_num = 3, require_more_address = 0):
return [
self.abduce((y, c), max_address_num, require_more_address)\
@@ -83,17 +151,19 @@ if __name__ == "__main__":
abd = AbducerBase(kb)
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0)
print(res)
print()
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1)
print(res)
print()
res = abd.abduce(([1, 1, 1], 4), max_address_num = 1, require_more_address = 1)
print(res)
print()
print('Test cache')
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1, 1], 4), max_address_num = 20, require_more_address = 1)
print()
res = abd.abduce(([1, 1, 1], 5), max_address_num = 2, require_more_address = 1)
print(res)
# res = abd.abduce(([0, 2, 0], 0.99), 1, 0)
# print(res)


+ 84
- 41
abducer/kb.py View File

@@ -16,7 +16,6 @@ import copy
import numpy as np

from collections import defaultdict

from itertools import product

class KBBase(ABC):
@@ -46,16 +45,17 @@ class KBBase(ABC):
pass

class add_KB(KBBase):
def __init__(self, pseudo_label_list, kb_max_len = -1, max_len = None):
    """Create an addition knowledge base over ``pseudo_label_list``.

    kb_max_len: when positive, ``self.base`` is pre-built from
        ``get_X(pseudo_label_list, kb_max_len)`` and grouped by sequence
        length and by ``logic_forward`` result.  Otherwise ``self.base``
        stays empty.
    max_len: earlier name of ``kb_max_len``; when given it takes precedence
        so callers that still pass ``max_len=`` keep working.
    """
    super().__init__()
    if max_len is not None:
        kb_max_len = max_len
    self.pseudo_label_list = pseudo_label_list
    self.base = {}
    self.kb_max_len = kb_max_len
    if self.kb_max_len > 0:
        X = self.get_X(self.pseudo_label_list, self.kb_max_len)
        Y = self.get_Y(X, self.logic_forward)

        # base[length][result] -> list of sequences with that length and result
        for x, y in zip(X, Y):
            self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
def logic_forward(self, nums):
    """Return the sum of ``nums``, the result this KB associates with a sequence."""
    return sum(nums)
@@ -71,50 +71,73 @@ class add_KB(KBBase):
return [logic_forward(nums) for nums in X]

def get_candidates(self, key, length = None):
    """Look up the pre-built sequences whose result equals ``key``.

    Returns an empty list when no ground KB was built, or when the smallest
    requested length exceeds ``self.kb_max_len``.  With ``key is None`` every
    stored sequence is returned.  ``length`` is normalised by
    ``self._length`` (defined outside this block) into an iterable of
    lengths.
    """
    if not self.base:
        return []
    if key is None:
        return self.get_all_candidates()

    lengths = self._length(length)
    if self.kb_max_len < min(lengths):
        return []

    found = []
    for l in lengths:
        by_result = self.base.get(l)
        # Membership test instead of indexing: a requested length that was
        # never built is skipped rather than raising KeyError, and the lookup
        # does not insert an empty entry into the defaultdict.
        if by_result is not None and key in by_result:
            found.extend(by_result[key])
    return found
def get_all_candidates(self):
    """Return every sequence stored in ``self.base`` as one flat list."""
    # A single flattening pass; repeated list concatenation via sum(..., [])
    # re-copies the accumulated list at every step.
    return [
        candidate
        for by_result in self.base.values()
        for group in by_result.values()
        for candidate in group
    ]
def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address):
    """Brute-force search over all label sequences of ``len(pred_res)``.

    Sequences built from ``self.pseudo_label_list`` are kept when
    ``dist_func(sequence, pred_res)`` equals the revision count currently
    being tried and ``self.logic_forward(sequence) == key``.  Revision counts
    grow from 0 until a candidate is found or ``max_address_num`` is
    exceeded; up to ``require_more_address`` further counts are then added.

    Returns ``(candidates, min_address_num, address_num)``, or
    ``(None, None, None)`` when nothing is found.  With ``key is None`` the
    result of ``self.get_all_candidates()`` is returned unchanged.
    """
    if key is None:
        return self.get_all_candidates()

    all_candidates = list(product(self.pseudo_label_list, repeat = len(pred_res)))

    def matching(address_num):
        # Sequences at exactly this distance whose forward result hits `key`.
        return [
            c for c in all_candidates
            if dist_func(c, pred_res) == address_num and self.logic_forward(c) == key
        ]

    # A missing or negative `length` would make the search loop empty and
    # leave `min_address_num` unbound; fall back to the sequence length.
    if length is None or length < 0:
        length = len(pred_res)

    candidates = []
    min_address_num = None
    for address_num in range(length + 1):
        if address_num > max_address_num:
            break
        candidates = matching(address_num)
        if candidates:
            min_address_num = address_num
            break

    if min_address_num is None:
        print('No candidates found')
        return None, None, None

    reached = min_address_num
    for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
        if address_num > max_address_num:
            break
        candidates += matching(address_num)
        reached = address_num

    return candidates, min_address_num, reached

def _dict_len(self, dic):
    """Return the total number of items across all values of ``dic``."""
    return sum(len(group) for group in dic.values())

def __len__(self):
    """Return the number of sequences stored in ``self.base`` across all lengths."""
    return sum(self._dict_len(by_result) for by_result in self.base.values())
# class hwf_KB(KBBase):
# def __init__(self, pseudo_label_list, kb_max_len = -1):
# super().__init__()
# self.pseudo_label_list = pseudo_label_list
# self.base = {}
# self.kb_max_len = kb_max_len
# if(self.kb_max_len > 0):
# X = self.get_X(self.pseudo_label_list, self.kb_max_len)
# Y = self.get_Y(X, self.logic_forward)

# for x, y in zip(X, Y):
# self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
# def logic_forward(self, nums):
# return sum(nums)
# def get_X(self, pseudo_label_list, max_len):
# res = []
# assert(max_len >= 2)
# for len in range(2, max_len + 1):
# res += list(product(pseudo_label_list, repeat = len))
# return res

# def get_Y(self, X, logic_forward):
# return [logic_forward(nums) for nums in X]

# def get_candidates(self, key, length = None):
# if(self.base == {}):
# return []
# if key is None:
# return self.get_all_candidates()

# length = self._length(length)
# if(self.kb_max_len < min(length)):
# return []
# return sum([self.base[l][key] for l in length], [])
# def get_all_candidates(self):
# return sum([sum(v.values(), []) for v in self.base.values()], [])

# def _dict_len(self, dic):
# return sum(len(c) for c in dic.values())

# def __len__(self):
# return sum(self._dict_len(v) for v in self.base.values())
class cls_KB(KBBase):
def __init__(self, X, Y = None):
super().__init__()
@@ -197,24 +220,44 @@ class reg_KB(KBBase):
return sum([sum(len(x) for x in D[0]) for D in self.base.values()])

if __name__ == "__main__":
# With ground KB
pseudo_label_list = list(range(10))
kb = add_KB(pseudo_label_list, max_len = 5)
kb = add_KB(pseudo_label_list, kb_max_len = 5)
print('len(kb):', len(kb))
print()
res = kb.get_candidates(0)
print(res)
print()
res = kb.get_candidates(18, length = 2)
print(res)
res = kb.get_candidates(18, length = 8)
print(res)
res = kb.get_candidates(7, length = 3)
print(res)
print()
# Without ground KB
pseudo_label_list = list(range(10))
kb = add_KB(pseudo_label_list)
print('len(kb):', len(kb))
res = kb.get_candidates(0)
print(res)
res = kb.get_candidates(18, length = 2)
print(res)
res = kb.get_candidates(18, length = 8)
print(res)
res = kb.get_candidates(7, length = 3)
print(res)
print()
# pseudo_label_list = list(range(10)) + ['+', '-', '*', '/']
# kb = hwf_KB(pseudo_label_list, max_len = 5)
# print('len(kb):', len(kb))
# print()
X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
Y = [2, 1, 1, 2, 2]
kb = cls_KB(X, Y)
print(len(kb))
print('len(kb):', len(kb))
res = kb.get_candidates(2, 5)
print(res)
res = kb.get_candidates(2, 3)
@@ -226,7 +269,7 @@ if __name__ == "__main__":
X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
Y = [2, 1, 1, 2, 1.5, 1.5]
kb = reg_KB(X, Y)
print(len(kb))
print('len(kb):', len(kb))
res = kb.get_candidates(1.6)
print(res)
res = kb.get_candidates(1.6, length = 9)


Loading…
Cancel
Save