Browse Source

Update abducer_base.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
0883e27fc1
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 12 deletions
  1. +13
    -12
      abducer/abducer_base.py

+ 13
- 12
abducer/abducer_base.py View File

@@ -45,7 +45,10 @@ class AbducerBase(abc.ABC):
dist_func = confidence_dist
self.dist_func = dist_func
if pred_res_parse is None:
pred_res_parse = lambda x : x["cls"]
if(dist_func == "hamming"):
pred_res_parse = lambda x : x["cls"]
elif dist_func == "confidence":
pred_res_parse = lambda x : x[" "]
self.pred_res_parse = pred_res_parse
self.cache = cache
@@ -75,20 +78,18 @@ class AbducerBase(abc.ABC):
else:
candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address)
cost_list = self.dist_func(pred_res, candidates)
if(self.cache):
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num
self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates

return candidates
cost_list = self.dist_func(pred_res, candidates)
min_address_num = np.min(cost_list)
idxs = np.where(cost_list == min_address_num)[0]
candidates = [candidates[idx] for idx in idxs]
# if len(idxs) > 1:
# return None
# return [candidates[idx] for idx in idxs]
return candidates[0]
def address(self, address_num, pred_res, key):
new_candidates = []
@@ -99,7 +100,7 @@ class AbducerBase(abc.ABC):
pred_res_array = np.array(pred_res)
if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num):
pred_res_array[np.array(address_idx)] = c
if(self.kb.logic_forward(pred_res_array) == key):
if(abs(self.kb.logic_forward(pred_res_array) - key) <= 1e-3):
new_candidates.append(pred_res_array)
return new_candidates, address_num
@@ -113,7 +114,7 @@ class AbducerBase(abc.ABC):
return None, None, None
if(address_num == 0):
if(self.kb.logic_forward(pred_res) == key):
if(abs(self.kb.logic_forward(pred_res) - key) <= 1e-3):
candidates.append(pred_res)
else:
new_candidates, address_num = self.address(address_num, pred_res, key)
@@ -148,7 +149,7 @@ if __name__ == "__main__":
pseudo_label_list = list(range(10))
kb = add_KB(pseudo_label_list)
abd = AbducerBase(kb)
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1)
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0)
print(res)
print()
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1)


Loading…
Cancel
Save