Browse Source

fix method name bug in abducer_base

pull/3/head
Tony-HYX 2 years ago
parent
commit
176f4bdd20
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      abl/abducer/abducer_base.py

+ 3
- 3
abl/abducer/abducer_base.py View File

@@ -51,7 +51,7 @@ class AbducerBase(abc.ABC):
candidate = candidates[np.argmin(cost_list)]
return candidate
def _get_zoopt_score(self, sol_x, pred_res, pred_res_prob, key):
def _zoopt_address_score_single(self, sol_x, pred_res, pred_res_prob, key):
address_idx = np.where(sol_x != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
if len(candidates) > 0:
@@ -61,12 +61,12 @@ class AbducerBase(abc.ABC):
def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
if not self.multiple_predictions:
return self._get_address_score(sol.get_x(), pred_res, pred_res_prob, key)
return self._zoopt_address_score_single(sol.get_x(), pred_res, pred_res_prob, key)
else:
all_address_flag = reform_idx(sol.get_x(), pred_res)
score = 0
for idx in range(len(pred_res)):
score += self._get_address_score(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key)
score += self._zoopt_address_score_single(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key)
return score
def _constrain_address_num(self, solution, max_address_num):


Loading…
Cancel
Save