|
|
@@ -36,7 +36,10 @@ class KBBase(ABC): |
|
|
|
|
|
|
|
def address(self, address_num, pred_res, key, multiple_predictions = False): |
|
|
|
new_candidates = [] |
|
|
|
address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) |
|
|
|
if not multiple_predictions: |
|
|
|
address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) |
|
|
|
else: |
|
|
|
address_idx_list = list(combinations(list(range(len(self.flatten(pred_res)))), address_num)) |
|
|
|
|
|
|
|
for address_idx in address_idx_list: |
|
|
|
candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) |
|
|
@@ -44,10 +47,10 @@ class KBBase(ABC): |
|
|
|
return new_candidates |
|
|
|
|
|
|
|
def correct_result(self, pred_res, key): |
|
|
|
if type(key) == int: |
|
|
|
if type(key) != bool: |
|
|
|
return abs(self.logic_forward(pred_res) - key) <= 1e-3 |
|
|
|
else: |
|
|
|
return self.logic_forward(pred_res) == key |
|
|
|
return self.logic_forward(pred_res) |
|
|
|
|
|
|
|
def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions = False): |
|
|
|
candidates = [] |
|
|
|