|
|
@@ -45,7 +45,10 @@ class AbducerBase(abc.ABC): |
|
|
|
dist_func = confidence_dist |
|
|
|
self.dist_func = dist_func |
|
|
|
if pred_res_parse is None: |
|
|
|
pred_res_parse = lambda x : x["cls"] |
|
|
|
if(dist_func == "hamming"): |
|
|
|
pred_res_parse = lambda x : x["cls"] |
|
|
|
elif dist_func == "confidence": |
|
|
|
pred_res_parse = lambda x : x[" "] |
|
|
|
self.pred_res_parse = pred_res_parse |
|
|
|
|
|
|
|
self.cache = cache |
|
|
@@ -75,20 +78,18 @@ class AbducerBase(abc.ABC): |
|
|
|
|
|
|
|
else: |
|
|
|
candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) |
|
|
|
cost_list = self.dist_func(pred_res, candidates) |
|
|
|
|
|
|
|
if(self.cache): |
|
|
|
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num |
|
|
|
self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates |
|
|
|
|
|
|
|
return candidates |
|
|
|
|
|
|
|
|
|
|
|
cost_list = self.dist_func(pred_res, candidates) |
|
|
|
min_address_num = np.min(cost_list) |
|
|
|
idxs = np.where(cost_list == min_address_num)[0] |
|
|
|
candidates = [candidates[idx] for idx in idxs] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# if len(idxs) > 1: |
|
|
|
# return None |
|
|
|
# return [candidates[idx] for idx in idxs] |
|
|
|
return candidates[0] |
|
|
|
|
|
|
|
def address(self, address_num, pred_res, key): |
|
|
|
new_candidates = [] |
|
|
@@ -99,7 +100,7 @@ class AbducerBase(abc.ABC): |
|
|
|
pred_res_array = np.array(pred_res) |
|
|
|
if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num): |
|
|
|
pred_res_array[np.array(address_idx)] = c |
|
|
|
if(self.kb.logic_forward(pred_res_array) == key): |
|
|
|
if(abs(self.kb.logic_forward(pred_res_array) - key) <= 1e-3): |
|
|
|
new_candidates.append(pred_res_array) |
|
|
|
return new_candidates, address_num |
|
|
|
|
|
|
@@ -113,7 +114,7 @@ class AbducerBase(abc.ABC): |
|
|
|
return None, None, None |
|
|
|
|
|
|
|
if(address_num == 0): |
|
|
|
if(self.kb.logic_forward(pred_res) == key): |
|
|
|
if(abs(self.kb.logic_forward(pred_res) - key) <= 1e-3): |
|
|
|
candidates.append(pred_res) |
|
|
|
else: |
|
|
|
new_candidates, address_num = self.address(address_num, pred_res, key) |
|
|
@@ -148,7 +149,7 @@ if __name__ == "__main__": |
|
|
|
pseudo_label_list = list(range(10)) |
|
|
|
kb = add_KB(pseudo_label_list) |
|
|
|
abd = AbducerBase(kb) |
|
|
|
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1) |
|
|
|
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0) |
|
|
|
print(res) |
|
|
|
print() |
|
|
|
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1) |
|
|
|