Browse Source

Update abducer_base.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
41be49c68e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 15 deletions
  1. +15
    -15
      abducer/abducer_base.py

+ 15
- 15
abducer/abducer_base.py View File

@@ -47,15 +47,15 @@ class AbducerBase(abc.ABC):
def get_cost_list(self, pred_res, pred_res_prob, candidates):
if(self.dist_func == 'hamming'):
if self.dist_func == 'hamming':
return self.hamming_dist(pred_res, candidates)
elif(self.dist_func == 'confidence'):
elif self.dist_func == 'confidence':
return self.confidence_dist(pred_res_prob, candidates)

def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates):
if(len(candidates) == 0):
if len(candidates) == 0:
return []
elif(len(candidates) == 1):
elif len(candidates) == 1:
return candidates[0]
else:
cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates)
@@ -68,9 +68,9 @@ class AbducerBase(abc.ABC):
if max_address_num == -1:
max_address_num = len(pred_res)

if(self.cache and (tuple(pred_res), ans) in self.cache_min_address_num):
if self.cache and (tuple(pred_res), ans) in self.cache_min_address_num:
address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), ans)] + require_more_address)
if((tuple(pred_res), ans, address_num) in self.cache_candidates):
if (tuple(pred_res), ans, address_num) in self.cache_candidates:
# print('cached')
candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)]
candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
@@ -78,7 +78,7 @@ class AbducerBase(abc.ABC):
if self.kb.GKB_flag:
all_candidates = self.kb.get_candidates(ans, len(pred_res))
if(len(all_candidates) == 0):
if len(all_candidates) == 0:
return []
else:
cost_list = self.hamming_dist(pred_res, all_candidates)
@@ -90,7 +90,7 @@ class AbducerBase(abc.ABC):
else:
candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address)
if(self.cache):
if self.cache:
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num
self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates

@@ -104,9 +104,9 @@ class AbducerBase(abc.ABC):
for address_idx in address_idx_list:
for c in all_address_candidate:
pred_res_array = np.array(pred_res)
if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num):
if np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num:
pred_res_array[np.array(address_idx)] = c
if(self.kb.logic_forward(pred_res_array) == key):
if self.kb.logic_forward(pred_res_array) == key:
new_candidates.append(pred_res_array)
return new_candidates, address_num
@@ -114,22 +114,22 @@ class AbducerBase(abc.ABC):
candidates = []
for address_num in range(len(pred_res) + 1):
if(address_num > max_address_num):
if address_num > max_address_num:
return [], None, None
if(address_num == 0):
if(abs(self.kb.logic_forward(pred_res) - key) <= 1e-3):
if address_num == 0:
if abs(self.kb.logic_forward(pred_res) - key) <= 1e-3:
candidates.append(pred_res)
else:
new_candidates, address_num = self.address(address_num, pred_res, key)
candidates += new_candidates
if(len(candidates) > 0):
if len(candidates) > 0:
min_address_num = address_num
break
for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if(address_num > max_address_num):
if address_num > max_address_num:
return candidates, min_address_num, address_num - 1
new_candidates, address_num = self.address(address_num, pred_res, key)
candidates += new_candidates


Loading…
Cancel
Save