|
|
@@ -71,7 +71,7 @@ |
|
|
|
" def __init__(self, kb, dist_func='hamming'):\n", |
|
|
|
" super().__init__(kb, dist_func, zoopt=True)\n", |
|
|
|
" \n", |
|
|
|
" def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):\n", |
|
|
|
" def _revise_by_idxs(self, pred_res, key, all_address_flag, idxs):\n", |
|
|
|
" pred = []\n", |
|
|
|
" k = []\n", |
|
|
|
" address_flag = []\n", |
|
|
@@ -80,10 +80,10 @@ |
|
|
|
" k.append(key[idx])\n", |
|
|
|
" address_flag += list(all_address_flag[idx])\n", |
|
|
|
" address_idx = np.where(np.array(address_flag) != 0)[0] \n", |
|
|
|
" candidate = self.address_by_idx(pred, k, address_idx)\n", |
|
|
|
" candidate = self.revise_by_idx(pred, k, address_idx)\n", |
|
|
|
" return candidate\n", |
|
|
|
" \n", |
|
|
|
" def zoopt_address_score(self, pred_res, pred_res_prob, key, sol): \n", |
|
|
|
" def zoopt_revision_score(self, pred_res, pred_res_prob, key, sol): \n", |
|
|
|
" all_address_flag = reform_idx(sol.get_x(), pred_res)\n", |
|
|
|
" lefted_idxs = [i for i in range(len(pred_res))]\n", |
|
|
|
" candidate_size = [] \n", |
|
|
@@ -95,7 +95,7 @@ |
|
|
|
" for idx in range(-1, len(pred_res)):\n", |
|
|
|
" if (not idx in idxs) and (idx >= 0):\n", |
|
|
|
" idxs.append(idx)\n", |
|
|
|
" candidate = self._address_by_idxs(pred_res, key, all_address_flag, idxs)\n", |
|
|
|
" candidate = self._revise_by_idxs(pred_res, key, all_address_flag, idxs)\n", |
|
|
|
" if len(candidate) == 0:\n", |
|
|
|
" if len(idxs) > 1:\n", |
|
|
|
" idxs.pop()\n", |
|
|
|