Browse Source

[MNT] Change 'address' to 'revise'

pull/3/head
Gao Enhao 2 years ago
parent
commit
8653e6c0be
2 changed files with 5 additions and 5 deletions
  1. +1
    -1
      examples/hed/framework_hed.py
  2. +4
    -4
      examples/hed/hed_example.ipynb

+ 1
- 1
examples/hed/framework_hed.py View File

@@ -81,7 +81,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
for idx in range(len(pred_res)):
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
candidate = abducer.address_by_idx([pred_res[idx]], None, address_idx)
candidate = abducer.revise_by_idx([pred_res[idx]], None, address_idx)
if len(candidate) > 0:
consistent_idx_tmp.append(idx)
consistent_pred_res_tmp.append(candidate[0][0])


+ 4
- 4
examples/hed/hed_example.ipynb View File

@@ -71,7 +71,7 @@
" def __init__(self, kb, dist_func='hamming'):\n",
" super().__init__(kb, dist_func, zoopt=True)\n",
" \n",
" def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):\n",
" def _revise_by_idxs(self, pred_res, key, all_address_flag, idxs):\n",
" pred = []\n",
" k = []\n",
" address_flag = []\n",
@@ -80,10 +80,10 @@
" k.append(key[idx])\n",
" address_flag += list(all_address_flag[idx])\n",
" address_idx = np.where(np.array(address_flag) != 0)[0] \n",
" candidate = self.address_by_idx(pred, k, address_idx)\n",
" candidate = self.revise_by_idx(pred, k, address_idx)\n",
" return candidate\n",
" \n",
" def zoopt_address_score(self, pred_res, pred_res_prob, key, sol): \n",
" def zoopt_revision_score(self, pred_res, pred_res_prob, key, sol): \n",
" all_address_flag = reform_idx(sol.get_x(), pred_res)\n",
" lefted_idxs = [i for i in range(len(pred_res))]\n",
" candidate_size = [] \n",
@@ -95,7 +95,7 @@
" for idx in range(-1, len(pred_res)):\n",
" if (not idx in idxs) and (idx >= 0):\n",
" idxs.append(idx)\n",
" candidate = self._address_by_idxs(pred_res, key, all_address_flag, idxs)\n",
" candidate = self._revise_by_idxs(pred_res, key, all_address_flag, idxs)\n",
" if len(candidate) == 0:\n",
" if len(idxs) > 1:\n",
" idxs.pop()\n",


Loading…
Cancel
Save