|
|
@@ -47,15 +47,15 @@ class AbducerBase(abc.ABC): |
|
|
|
|
|
|
|
|
|
|
|
def get_cost_list(self, pred_res, pred_res_prob, candidates): |
|
|
|
if(self.dist_func == 'hamming'): |
|
|
|
if self.dist_func == 'hamming': |
|
|
|
return self.hamming_dist(pred_res, candidates) |
|
|
|
elif(self.dist_func == 'confidence'): |
|
|
|
elif self.dist_func == 'confidence': |
|
|
|
return self.confidence_dist(pred_res_prob, candidates) |
|
|
|
|
|
|
|
def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates): |
|
|
|
if(len(candidates) == 0): |
|
|
|
if len(candidates) == 0: |
|
|
|
return [] |
|
|
|
elif(len(candidates) == 1): |
|
|
|
elif len(candidates) == 1: |
|
|
|
return candidates[0] |
|
|
|
else: |
|
|
|
cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates) |
|
|
@@ -68,9 +68,9 @@ class AbducerBase(abc.ABC): |
|
|
|
if max_address_num == -1: |
|
|
|
max_address_num = len(pred_res) |
|
|
|
|
|
|
|
if(self.cache and (tuple(pred_res), ans) in self.cache_min_address_num): |
|
|
|
if self.cache and (tuple(pred_res), ans) in self.cache_min_address_num: |
|
|
|
address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), ans)] + require_more_address) |
|
|
|
if((tuple(pred_res), ans, address_num) in self.cache_candidates): |
|
|
|
if (tuple(pred_res), ans, address_num) in self.cache_candidates: |
|
|
|
# print('cached') |
|
|
|
candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)] |
|
|
|
candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) |
|
|
@@ -78,7 +78,7 @@ class AbducerBase(abc.ABC): |
|
|
|
|
|
|
|
if self.kb.GKB_flag: |
|
|
|
all_candidates = self.kb.get_candidates(ans, len(pred_res)) |
|
|
|
if(len(all_candidates) == 0): |
|
|
|
if len(all_candidates) == 0: |
|
|
|
return [] |
|
|
|
else: |
|
|
|
cost_list = self.hamming_dist(pred_res, all_candidates) |
|
|
@@ -90,7 +90,7 @@ class AbducerBase(abc.ABC): |
|
|
|
else: |
|
|
|
candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) |
|
|
|
|
|
|
|
if(self.cache): |
|
|
|
if self.cache: |
|
|
|
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num |
|
|
|
self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates |
|
|
|
|
|
|
@@ -104,9 +104,9 @@ class AbducerBase(abc.ABC): |
|
|
|
for address_idx in address_idx_list: |
|
|
|
for c in all_address_candidate: |
|
|
|
pred_res_array = np.array(pred_res) |
|
|
|
if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num): |
|
|
|
if np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num: |
|
|
|
pred_res_array[np.array(address_idx)] = c |
|
|
|
if(self.kb.logic_forward(pred_res_array) == key): |
|
|
|
if self.kb.logic_forward(pred_res_array) == key: |
|
|
|
new_candidates.append(pred_res_array) |
|
|
|
return new_candidates, address_num |
|
|
|
|
|
|
@@ -114,22 +114,22 @@ class AbducerBase(abc.ABC): |
|
|
|
|
|
|
|
candidates = [] |
|
|
|
for address_num in range(len(pred_res) + 1): |
|
|
|
if(address_num > max_address_num): |
|
|
|
if address_num > max_address_num: |
|
|
|
return [], None, None |
|
|
|
|
|
|
|
if(address_num == 0): |
|
|
|
if(abs(self.kb.logic_forward(pred_res) - key) <= 1e-3): |
|
|
|
if address_num == 0: |
|
|
|
if abs(self.kb.logic_forward(pred_res) - key) <= 1e-3: |
|
|
|
candidates.append(pred_res) |
|
|
|
else: |
|
|
|
new_candidates, address_num = self.address(address_num, pred_res, key) |
|
|
|
candidates += new_candidates |
|
|
|
|
|
|
|
if(len(candidates) > 0): |
|
|
|
if len(candidates) > 0: |
|
|
|
min_address_num = address_num |
|
|
|
break |
|
|
|
|
|
|
|
for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): |
|
|
|
if(address_num > max_address_num): |
|
|
|
if address_num > max_address_num: |
|
|
|
return candidates, min_address_num, address_num - 1 |
|
|
|
new_candidates, address_num = self.address(address_num, pred_res, key) |
|
|
|
candidates += new_candidates |
|
|
|