@@ -14,22 +14,48 @@ import abc | |||
from kb import add_KB | |||
import numpy as np | |||
from itertools import product, combinations | |||
def hamming_dist(A, B):
    """Count the positions at which ``A`` differs from ``B``.

    ``B`` may be one label sequence or a collection of equally long label
    sequences; NumPy broadcasting compares ``A`` against every row without
    copying it.

    Args:
        A: Sequence of labels.
        B: Either a single sequence with the same length as ``A`` or a
            sequence of such sequences (one candidate per row).

    Returns:
        A scalar count when ``B`` is one-dimensional, otherwise a 1-D integer
        array with one count per row of ``B``. When ``B`` is empty (and ``A``
        is not) an empty integer array is returned.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # No candidates to compare against: nothing to count, and the shapes
    # would not broadcast anyway.
    if B.size == 0 and A.size > 0:
        return np.zeros(0, dtype=int)
    # Summing over the last axis yields a scalar for a single sequence and
    # one value per row for a stack of candidates.
    return np.sum(A != B, axis=-1)
def confidence_dist(A, B):
    """Return ``1 - prod(A[i, B[..., i]])`` for every candidate in ``B``.

    The indexing pattern requires ``A`` to be two-dimensional (first axis:
    position, second axis: label index) and ``B`` to hold integer label
    indices, one per position.
    # NOTE(review): presumably ``A`` holds per-position class probabilities
    # and the result is "one minus the joint confidence" - confirm with caller.

    Args:
        A: 2-D array-like of scores; values are clipped to ``[1e-9, 1]`` so a
            zero score cannot zero out the whole product.
        B: Either one sequence of integer label indices or a sequence of such
            sequences (one candidate per row).

    Returns:
        A scalar when ``B`` is one-dimensional, otherwise a 1-D float array
        with one value per row of ``B``. An empty ``B`` gives an empty array.
    """
    scores = np.clip(A, 1e-9, 1)
    labels = np.asarray(B)
    if labels.size == 0:
        return np.zeros(0)
    # Broadcasting pairs position i with labels[..., i]; there is no need to
    # replicate ``scores`` once per candidate or to build explicit index grids.
    positions = np.arange(labels.shape[-1])
    return 1 - np.prod(scores[positions, labels], axis=-1)
class AbducerBase(abc.ABC): | |||
def __init__(self, kb, dist_func = "hamming", pred_res_parse = None, cache = True):
    """Store the knowledge base, the distance function and the cache state.

    Args:
        kb: Knowledge-base object; stored as ``self.kb``.
        dist_func: ``"hamming"``, ``"confidence"`` or a callable that is
            stored unchanged as ``self.dist_func``.
        pred_res_parse: Callable applied to a prediction result. When
            ``None`` a default is chosen from the distance name.
        cache: Stored as ``self.cache``.

    Raises:
        ValueError: If ``dist_func`` is a string other than the two names
            above.
    """
    self.kb = kb

    # Keep the requested name: ``dist_func`` is rebound to a callable below,
    # so comparing it with a string afterwards could never match.
    dist_name = dist_func
    if dist_func == "hamming":
        dist_func = hamming_dist
    elif dist_func == "confidence":
        dist_func = confidence_dist
    elif isinstance(dist_func, str):
        raise ValueError("unknown dist_func: %r" % (dist_func,))
    self.dist_func = dist_func

    if pred_res_parse is None:
        if dist_name == "confidence":
            # NOTE(review): the key is a single space in the source; this
            # looks like a placeholder - confirm the intended key.
            pred_res_parse = lambda x : x[" "]
        else:
            # Default for "hamming" and for custom callables.
            pred_res_parse = lambda x : x["cls"]
    self.pred_res_parse = pred_res_parse

    self.cache = cache
    # Both caches start empty.
    self.cache_min_address_num = {}
    self.cache_candidates = {}
def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1): | |||
pred_res, ans = data | |||
@@ -42,30 +68,72 @@ class AbducerBase(abc.ABC): | |||
print('cached') | |||
return self.cache_candidates[(tuple(pred_res), ans, address_num)] | |||
candidates, min_address_num, address_num = self.kb.get_abduce_candidates(pred_res, ans, length, self.dist_func, max_address_num, require_more_address) | |||
if(self.kb.base != {}): | |||
all_candidates = self.kb.get_candidates(ans, length) | |||
cost_list = self.dist_func(pred_res, all_candidates) | |||
min_address_num = np.min(cost_list) | |||
address_num = min(max_address_num, min_address_num + require_more_address) | |||
idxs = np.where(cost_list <= address_num)[0] | |||
candidates = [all_candidates[idx] for idx in idxs] | |||
else: | |||
candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) | |||
cost_list = self.dist_func(pred_res, candidates) | |||
if(self.cache): | |||
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num | |||
self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates | |||
return candidates | |||
# candidates = self.kb.get_candidates(ans, length) | |||
cost_list = self.dist_func(pred_res, candidates) | |||
min_address_num = np.min(cost_list) | |||
idxs = np.where(cost_list == min_address_num)[0] | |||
candidates = [candidates[idx] for idx in idxs] | |||
return candidates[0] | |||
def address(self, address_num, pred_res, key):
    """Enumerate revisions of ``pred_res`` that change exactly ``address_num`` labels.

    Every choice of ``address_num`` positions is combined with every tuple of
    replacement labels drawn from ``self.kb.pseudo_label_list``. A revision is
    kept when all chosen positions really change and
    ``self.kb.logic_forward`` of the revised sequence is within ``1e-3`` of
    ``key``.

    Args:
        address_num: Number of positions to change (0 checks ``pred_res``
            itself).
        pred_res: Sequence of predicted labels.
        key: Target value for ``self.kb.logic_forward``.

    Returns:
        ``(new_candidates, address_num)`` where ``new_candidates`` is a list
        of NumPy arrays, ordered by position combination and then by
        replacement tuple.
    """
    original = np.array(pred_res)
    # The replacement tuples do not depend on the chosen positions.
    replacements = [
        np.array(c) for c in product(self.kb.pseudo_label_list, repeat = address_num)
    ]

    new_candidates = []
    for address_idx in combinations(range(len(pred_res)), address_num):
        # An explicit integer dtype keeps the empty index valid when
        # address_num == 0.
        idx = np.array(address_idx, dtype = int)
        current = original[idx]
        for c in replacements:
            # Skip tuples that leave any chosen position unchanged.
            if np.count_nonzero(c != current) != address_num:
                continue
            revised = original.copy()
            revised[idx] = c
            if abs(self.kb.logic_forward(revised) - key) <= 1e-3:
                new_candidates.append(revised)
    return new_candidates, address_num
def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address):
    """Search for revisions of ``pred_res`` consistent with ``key``.

    The number of changed positions grows from 0 until at least one
    consistent candidate is found; then up to ``require_more_address``
    further levels are added, never exceeding ``max_address_num``.

    Args:
        pred_res: Sequence of predicted labels.
        key: Target value for ``self.kb.logic_forward``.
        max_address_num: Largest number of positions that may be changed.
        require_more_address: Extra levels to collect beyond the minimum.

    Returns:
        ``(candidates, min_address_num, address_num)``, or
        ``(None, None, None)`` when no candidate exists within the allowed
        number of changes.
    """
    candidates = []
    min_address_num = None
    for address_num in range(len(pred_res) + 1):
        if address_num > max_address_num:
            break
        if address_num == 0:
            # Zero changes: the prediction itself is the only option.
            if abs(self.kb.logic_forward(pred_res) - key) <= 1e-3:
                candidates.append(pred_res)
        else:
            new_candidates, address_num = self.address(address_num, pred_res, key)
            candidates += new_candidates
        if len(candidates) > 0:
            min_address_num = address_num
            break

    # Covers both "budget exceeded" and "every level tried without a match";
    # the latter previously fell through with min_address_num unbound.
    if min_address_num is None:
        print('No candidates found')
        return None, None, None

    address_num = min_address_num
    for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
        if address_num > max_address_num:
            # Report the last level that was actually searched.
            return candidates, min_address_num, address_num - 1
        new_candidates, address_num = self.address(address_num, pred_res, key)
        candidates += new_candidates
    return candidates, min_address_num, address_num
def batch_abduce(self, Y, C, max_address_num = 3, require_more_address = 0): | |||
return [ | |||
self.abduce((y, c), max_address_num, require_more_address)\ | |||
@@ -83,17 +151,19 @@ if __name__ == "__main__": | |||
abd = AbducerBase(kb) | |||
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0) | |||
print(res) | |||
print() | |||
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1) | |||
print(res) | |||
print() | |||
res = abd.abduce(([1, 1, 1], 4), max_address_num = 1, require_more_address = 1) | |||
print(res) | |||
print() | |||
print('Test cache') | |||
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0) | |||
print(res) | |||
res = abd.abduce(([1, 1, 1], 4), max_address_num = 20, require_more_address = 1) | |||
print() | |||
res = abd.abduce(([1, 1, 1], 5), max_address_num = 2, require_more_address = 1) | |||
print(res) | |||
# res = abd.abduce(([0, 2, 0], 0.99), 1, 0) | |||
# print(res) | |||
@@ -16,7 +16,6 @@ import copy | |||
import numpy as np | |||
from collections import defaultdict | |||
from itertools import product | |||
class KBBase(ABC): | |||
@@ -46,16 +45,17 @@ class KBBase(ABC): | |||
pass | |||
class add_KB(KBBase): | |||
def __init__(self, pseudo_label_list, kb_max_len = -1):
    """Create the KB and, if requested, pre-compute its ground entries.

    Args:
        pseudo_label_list: Labels from which sequences are built.
        kb_max_len: When positive, ``self.get_X`` is asked for sequences up
            to this length and each one is stored in ``self.base`` under
            ``base[len(x)][logic_forward(x)]``. When not positive,
            ``self.base`` stays empty.
    """
    super().__init__()
    self.pseudo_label_list = pseudo_label_list
    self.base = {}
    self.kb_max_len = kb_max_len

    if self.kb_max_len > 0:
        X = self.get_X(self.pseudo_label_list, self.kb_max_len)
        Y = self.get_Y(X, self.logic_forward)
        # Populate exactly once; a second pass would store every entry twice.
        for x, y in zip(X, Y):
            self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
def logic_forward(self, nums):
    """Return the sum of the values in ``nums``."""
    total = sum(nums)
    return total
@@ -71,50 +71,73 @@ class add_KB(KBBase): | |||
return [logic_forward(nums) for nums in X] | |||
def get_candidates(self, key, length = None):
    """Look up stored sequences whose ``logic_forward`` value is ``key``.

    Args:
        key: Value to look up; ``None`` returns every stored sequence via
            ``self.get_all_candidates()``.
        length: Passed through ``self._length`` before use.
            # NOTE(review): ``_length`` is defined outside this block; it is
            # assumed to return an iterable of integer lengths - confirm.

    Returns:
        A new list of matching sequences, in the order of the requested
        lengths. Empty when the KB is empty, when the smallest requested
        length exceeds ``self.kb_max_len``, or when nothing matches.
    """
    if not self.base:
        return []
    if key is None:
        return self.get_all_candidates()

    length = self._length(length)
    if len(length) == 0 or self.kb_max_len < min(length):
        return []

    found = []
    for l in length:
        # ``.get`` avoids a KeyError for lengths that are not stored and does
        # not insert an empty entry for an unseen key.
        found.extend(self.base.get(l, {}).get(key, []))
    return found
def get_all_candidates(self):
    """Return every sequence stored in ``self.base`` as one flat list.

    ``self.base`` is walked as a two-level mapping and the innermost lists
    are concatenated in iteration order. A single comprehension is used
    instead of repeated list addition, which copies the accumulated result
    on every step.
    """
    return [
        candidate
        for by_key in self.base.values()
        for group in by_key.values()
        for candidate in group
    ]
def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address):
    """Brute-force search for label sequences consistent with ``key``.

    Every sequence of ``len(pred_res)`` labels drawn from
    ``self.pseudo_label_list`` whose ``logic_forward`` equals ``key`` is
    scored once with ``dist_func(candidate, pred_res)``. Candidates at the
    smallest distance are returned, plus up to ``require_more_address``
    further distance levels, never beyond ``max_address_num``.

    Args:
        pred_res: Sequence of predicted labels.
        key: Target value; ``None`` short-circuits (see note below).
        length: Largest distance to try in the first phase. ``None`` or a
            negative value means ``len(pred_res)``.
        dist_func: Callable ``(candidate, pred_res)`` whose result is
            compared with an integer distance using ``==``.
        max_address_num: Largest distance that may be searched.
        require_more_address: Extra distance levels beyond the minimum.

    Returns:
        ``(candidates, min_address_num, address_num)``, or
        ``(None, None, None)`` when nothing is found within the limits.
    """
    if key is None:
        # NOTE(review): this returns a plain list, not the 3-tuple produced
        # by every other path - kept as-is; confirm what callers expect.
        return self.get_all_candidates()

    # A negative/None length previously skipped the search entirely and then
    # failed on an unbound variable.
    if length is None or length < 0:
        length = len(pred_res)

    # Score each consistent sequence once instead of re-scanning the whole
    # product for every distance level.
    scored = [
        (dist_func(c, pred_res), c)
        for c in product(self.pseudo_label_list, repeat = len(pred_res))
        if self.logic_forward(c) == key
    ]

    def at_distance(address_num):
        # Consistent sequences at exactly this distance, in product order.
        return [c for dist, c in scored if dist == address_num]

    candidates = []
    min_address_num = None
    for address_num in range(length + 1):
        if address_num > max_address_num:
            break
        candidates = at_distance(address_num)
        if len(candidates) > 0:
            min_address_num = address_num
            break

    if min_address_num is None:
        print('No candidates found')
        return None, None, None

    address_num = min_address_num
    for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
        if address_num > max_address_num:
            # Report the last level that was actually searched.
            return candidates, min_address_num, address_num - 1
        candidates += at_distance(address_num)
    return candidates, min_address_num, address_num
def _dict_len(self, dic): | |||
return sum(len(c) for c in dic.values()) | |||
def __len__(self): | |||
return sum(self._dict_len(v) for v in self.base.values()) | |||
# class hwf_KB(KBBase): | |||
# def __init__(self, pseudo_label_list, kb_max_len = -1): | |||
# super().__init__() | |||
# self.pseudo_label_list = pseudo_label_list | |||
# self.base = {} | |||
# self.kb_max_len = kb_max_len | |||
# if(self.kb_max_len > 0): | |||
# X = self.get_X(self.pseudo_label_list, self.kb_max_len) | |||
# Y = self.get_Y(X, self.logic_forward) | |||
# for x, y in zip(X, Y): | |||
# self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
# def logic_forward(self, nums): | |||
# return sum(nums) | |||
# def get_X(self, pseudo_label_list, max_len): | |||
# res = [] | |||
# assert(max_len >= 2) | |||
# for len in range(2, max_len + 1): | |||
# res += list(product(pseudo_label_list, repeat = len)) | |||
# return res | |||
# def get_Y(self, X, logic_forward): | |||
# return [logic_forward(nums) for nums in X] | |||
# def get_candidates(self, key, length = None): | |||
# if(self.base == {}): | |||
# return [] | |||
# if key is None: | |||
# return self.get_all_candidates() | |||
# length = self._length(length) | |||
# if(self.kb_max_len < min(length)): | |||
# return [] | |||
# return sum([self.base[l][key] for l in length], []) | |||
# def get_all_candidates(self): | |||
# return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
# def _dict_len(self, dic): | |||
# return sum(len(c) for c in dic.values()) | |||
# def __len__(self): | |||
# return sum(self._dict_len(v) for v in self.base.values()) | |||
class cls_KB(KBBase): | |||
def __init__(self, X, Y = None): | |||
super().__init__() | |||
@@ -197,24 +220,44 @@ class reg_KB(KBBase): | |||
return sum([sum(len(x) for x in D[0]) for D in self.base.values()]) | |||
if __name__ == "__main__": | |||
# With ground KB | |||
pseudo_label_list = list(range(10)) | |||
kb = add_KB(pseudo_label_list, max_len = 5) | |||
kb = add_KB(pseudo_label_list, kb_max_len = 5) | |||
print('len(kb):', len(kb)) | |||
print() | |||
res = kb.get_candidates(0) | |||
print(res) | |||
print() | |||
res = kb.get_candidates(18, length = 2) | |||
print(res) | |||
res = kb.get_candidates(18, length = 8) | |||
print(res) | |||
res = kb.get_candidates(7, length = 3) | |||
print(res) | |||
print() | |||
# Without ground KB | |||
pseudo_label_list = list(range(10)) | |||
kb = add_KB(pseudo_label_list) | |||
print('len(kb):', len(kb)) | |||
res = kb.get_candidates(0) | |||
print(res) | |||
res = kb.get_candidates(18, length = 2) | |||
print(res) | |||
res = kb.get_candidates(18, length = 8) | |||
print(res) | |||
res = kb.get_candidates(7, length = 3) | |||
print(res) | |||
print() | |||
# pseudo_label_list = list(range(10)) + ['+', '-', '*', '/'] | |||
# kb = hwf_KB(pseudo_label_list, max_len = 5) | |||
# print('len(kb):', len(kb)) | |||
# print() | |||
X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||
Y = [2, 1, 1, 2, 2] | |||
kb = cls_KB(X, Y) | |||
print(len(kb)) | |||
print('len(kb):', len(kb)) | |||
res = kb.get_candidates(2, 5) | |||
print(res) | |||
res = kb.get_candidates(2, 3) | |||
@@ -226,7 +269,7 @@ if __name__ == "__main__": | |||
X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] | |||
Y = [2, 1, 1, 2, 1.5, 1.5] | |||
kb = reg_KB(X, Y) | |||
print(len(kb)) | |||
print('len(kb):', len(kb)) | |||
res = kb.get_candidates(1.6) | |||
print(res) | |||
res = kb.get_candidates(1.6, length = 9) | |||