Browse Source

Fix bugs for non-zoopt option

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
01ac0b8dd1
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 3 deletions
  1. +6
    -3
      abducer/kb.py

+ 6
- 3
abducer/kb.py View File

@@ -36,7 +36,10 @@ class KBBase(ABC):

def address(self, address_num, pred_res, key, multiple_predictions = False):
new_candidates = []
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
if not multiple_predictions:
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
else:
address_idx_list = list(combinations(list(range(len(self.flatten(pred_res)))), address_num))
for address_idx in address_idx_list:
candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
@@ -44,10 +47,10 @@ class KBBase(ABC):
return new_candidates
def correct_result(self, pred_res, key):
if type(key) == int:
if type(key) != bool:
return abs(self.logic_forward(pred_res) - key) <= 1e-3
else:
return self.logic_forward(pred_res) == key
return self.logic_forward(pred_res)
def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions = False):
candidates = []


Loading…
Cancel
Save