|
|
@@ -11,6 +11,7 @@ |
|
|
|
"import numpy as np\n", |
|
|
|
"import torch\n", |
|
|
|
"import torch.nn as nn\n", |
|
|
|
"from zoopt import Dimension, Objective, Opt, Parameter\n", |
|
|
|
"\n", |
|
|
|
"from abl.evaluation import ReasoningMetric, SymbolMetric\n", |
|
|
|
"from abl.learning import ABLModel, BasicNN\n", |
|
|
@@ -71,7 +72,7 @@ |
|
|
|
" def revise_at_idx(self, data_sample):\n", |
|
|
|
" revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n", |
|
|
|
" candidate = self.kb.revise_at_idx(\n", |
|
|
|
" data_sample.pred_pseudo_label, data_sample.Y, revision_idx\n", |
|
|
|
" data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx\n", |
|
|
|
" )\n", |
|
|
|
" return candidate\n", |
|
|
|
"\n", |
|
|
@@ -83,6 +84,7 @@ |
|
|
|
"\n", |
|
|
|
" lefted_idxs = [i for i in range(len(data_sample.pred_idx))]\n", |
|
|
|
" candidate_size = []\n", |
|
|
|
" max_consistent_idxs = []\n", |
|
|
|
" while lefted_idxs:\n", |
|
|
|
" idxs = []\n", |
|
|
|
" idxs.append(lefted_idxs.pop(0))\n", |
|
|
@@ -91,8 +93,8 @@ |
|
|
|
" for idx in range(-1, len(data_sample.pred_idx)):\n", |
|
|
|
" if (not idx in idxs) and (idx >= 0):\n", |
|
|
|
" idxs.append(idx)\n", |
|
|
|
" candidate = self.revise_at_idx(data_sample[idxs])\n", |
|
|
|
" if len(candidate) == 0:\n", |
|
|
|
" candidates, _ = self.revise_at_idx(data_sample[idxs])\n", |
|
|
|
" if len(candidates) == 0:\n", |
|
|
|
" if len(idxs) > 1:\n", |
|
|
|
" idxs.pop()\n", |
|
|
|
" else:\n", |
|
|
@@ -101,7 +103,9 @@ |
|
|
|
" max_candidate_idxs = idxs.copy()\n", |
|
|
|
" removed = [i for i in lefted_idxs if i in max_candidate_idxs]\n", |
|
|
|
" if found:\n", |
|
|
|
" candidate_size.append(len(removed) + 1)\n", |
|
|
|
" removed.insert(0, idxs[0])\n", |
|
|
|
" candidate_size.append(len(removed))\n", |
|
|
|
" max_consistent_idxs = max_candidate_idxs.copy()\n", |
|
|
|
" lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]\n", |
|
|
|
" candidate_size.sort()\n", |
|
|
|
" score = 0\n", |
|
|
@@ -109,27 +113,32 @@ |
|
|
|
"\n", |
|
|
|
" for i in range(0, len(candidate_size)):\n", |
|
|
|
" score -= math.exp(-i) * candidate_size[i]\n", |
|
|
|
" return score\n", |
|
|
|
" return score, max_consistent_idxs\n", |
|
|
|
" \n", |
|
|
|
" def _zoopt_get_solution(self, symbol_num, data_sample, max_revision_num):\n", |
|
|
|
" dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)\n", |
|
|
|
" objective = Objective(\n", |
|
|
|
" lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol)[0],\n", |
|
|
|
" dim=dimension,\n", |
|
|
|
" constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),\n", |
|
|
|
" )\n", |
|
|
|
" parameter = Parameter(budget=100, intermediate_result=False, autoset=True)\n", |
|
|
|
" solution = Opt.min(objective, parameter)\n", |
|
|
|
" return solution\n", |
|
|
|
"\n", |
|
|
|
" def abduce(self, data_sample):\n", |
|
|
|
" symbol_num = data_sample.elements_num(\"pred_pseudo_label\")\n", |
|
|
|
" max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n", |
|
|
|
"\n", |
|
|
|
" solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num)\n", |
|
|
|
"\n", |
|
|
|
" data_sample.revision_flag = reform_list(\n", |
|
|
|
" solution.astype(np.int32), data_sample.pred_pseudo_label\n", |
|
|
|
" )\n", |
|
|
|
" solution = self._zoopt_get_solution(symbol_num, data_sample, max_revision_num)\n", |
|
|
|
" _, max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_sample, solution)\n", |
|
|
|
"\n", |
|
|
|
" abduced_pseudo_label = []\n", |
|
|
|
" abduced_pseudo_label = [[] for _ in range(len(data_sample))]\n", |
|
|
|
"\n", |
|
|
|
" for single_instance in data_sample:\n", |
|
|
|
" single_instance.pred_pseudo_label = [single_instance.pred_pseudo_label]\n", |
|
|
|
" candidates = self.revise_at_idx(single_instance)\n", |
|
|
|
" if len(candidates) == 0:\n", |
|
|
|
" abduced_pseudo_label.append([])\n", |
|
|
|
" else:\n", |
|
|
|
" abduced_pseudo_label.append(candidates[0][0])\n", |
|
|
|
" if len(max_candidate_idxs) > 0:\n", |
|
|
|
" candidates, _ = self.revise_at_idx(data_sample[max_candidate_idxs])\n", |
|
|
|
" for i, idx in enumerate(max_candidate_idxs):\n", |
|
|
|
" abduced_pseudo_label[idx] = candidates[0][i]\n", |
|
|
|
" data_sample.abduced_pseudo_label = abduced_pseudo_label\n", |
|
|
|
" return abduced_pseudo_label\n", |
|
|
|
"\n", |
|
|
@@ -138,7 +147,7 @@ |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"kb = HedKB(pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\")\n", |
|
|
|
"reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=20)" |
|
|
|
"reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=10)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
@@ -158,8 +167,7 @@ |
|
|
|
"# Build necessary components for BasicNN\n", |
|
|
|
"cls = SymbolNet(num_classes=4)\n", |
|
|
|
"loss_fn = nn.CrossEntropyLoss()\n", |
|
|
|
"optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)\n", |
|
|
|
"# optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))\n", |
|
|
|
"optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4)\n", |
|
|
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" |
|
|
|
] |
|
|
|
}, |
|
|
@@ -179,6 +187,7 @@ |
|
|
|
" batch_size=32,\n", |
|
|
|
" num_epochs=1,\n", |
|
|
|
" save_interval=1,\n", |
|
|
|
" stop_loss=None,\n", |
|
|
|
" save_dir=weights_dir,\n", |
|
|
|
")" |
|
|
|
] |
|
|
@@ -210,7 +219,7 @@ |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Set up metrics\n", |
|
|
|
"metric_list = [SymbolMetric(prefix=\"hed\"), ReasoningMetric(prefix=\"hed\")]" |
|
|
|
"metric_list = [SymbolMetric(prefix=\"hed\"), ReasoningMetric(kb=kb, prefix=\"hed\")]" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|