Browse Source

[ENH] Run the HED example successfully

pull/1/head
Gao Enhao 1 year ago
parent
commit
011f2f312f
3 changed files with 40 additions and 30 deletions
  1. +3
    -0
      abl/utils/utils.py
  2. +6
    -8
      examples/hed/hed_bridge.py
  3. +31
    -22
      examples/hed/hed_example.ipynb

+ 3
- 0
abl/utils/utils.py View File

@@ -25,6 +25,9 @@ def flatten(nested_list):
# if not isinstance(nested_list, list):
# raise TypeError("Input must be of type list.")

if isinstance(nested_list, list) and len(nested_list) == 0:
return nested_list

if not isinstance(nested_list, list) or not isinstance(nested_list[0], (list, tuple)):
return nested_list



+ 6
- 8
examples/hed/hed_bridge.py View File

@@ -53,8 +53,7 @@ class HEDBridge(SimpleBridge):
pretrain_data, batch_size=64, shuffle=True
)

min_loss = pretrain_model.fit(pretrain_data_loader)
print_log(f"min loss is {min_loss}", logger="current")
pretrain_model.fit(pretrain_data_loader)
save_parma_dic = {
"model": cls_autoencoder.base_model.state_dict(),
}
@@ -81,6 +80,7 @@ class HEDBridge(SimpleBridge):
self.reasoner.remapping = dict(
zip(self.reasoner.mapping.values(), self.reasoner.mapping.keys())
)
self.idx_to_pseudo_label(data_samples)
data_samples.abduced_pseudo_label = abduced_pseudo_label_list[return_idx]

return data_samples.abduced_pseudo_label
@@ -202,6 +202,10 @@ class HEDBridge(SimpleBridge):
data_samples = self.data_preprocess(train_data[1], equation_len)
sampler = InfiniteSampler(len(data_samples), batch_size=segment_size)
for seg_idx, select_idx in enumerate(sampler):
print_log(
f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}]",
logger="current",
)
sub_data_samples = data_samples[select_idx]
self.predict(sub_data_samples)
if equation_len == min_len:
@@ -213,12 +217,6 @@ class HEDBridge(SimpleBridge):
self.pseudo_label_to_idx(filtered_sub_data_samples)
loss = self.model.train(filtered_sub_data_samples)

print_log(
f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] \
model loss is {loss:.5f}",
logger="current",
)

if self.check_training_impact(filtered_sub_data_samples, sub_data_samples):
condition_num += 1
else:


+ 31
- 22
examples/hed/hed_example.ipynb View File

@@ -11,6 +11,7 @@
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"from zoopt import Dimension, Objective, Opt, Parameter\n",
"\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.learning import ABLModel, BasicNN\n",
@@ -71,7 +72,7 @@
" def revise_at_idx(self, data_sample):\n",
" revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n",
" candidate = self.kb.revise_at_idx(\n",
" data_sample.pred_pseudo_label, data_sample.Y, revision_idx\n",
" data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx\n",
" )\n",
" return candidate\n",
"\n",
@@ -83,6 +84,7 @@
"\n",
" lefted_idxs = [i for i in range(len(data_sample.pred_idx))]\n",
" candidate_size = []\n",
" max_consistent_idxs = []\n",
" while lefted_idxs:\n",
" idxs = []\n",
" idxs.append(lefted_idxs.pop(0))\n",
@@ -91,8 +93,8 @@
" for idx in range(-1, len(data_sample.pred_idx)):\n",
" if (not idx in idxs) and (idx >= 0):\n",
" idxs.append(idx)\n",
" candidate = self.revise_at_idx(data_sample[idxs])\n",
" if len(candidate) == 0:\n",
" candidates, _ = self.revise_at_idx(data_sample[idxs])\n",
" if len(candidates) == 0:\n",
" if len(idxs) > 1:\n",
" idxs.pop()\n",
" else:\n",
@@ -101,7 +103,9 @@
" max_candidate_idxs = idxs.copy()\n",
" removed = [i for i in lefted_idxs if i in max_candidate_idxs]\n",
" if found:\n",
" candidate_size.append(len(removed) + 1)\n",
" removed.insert(0, idxs[0])\n",
" candidate_size.append(len(removed))\n",
" max_consistent_idxs = max_candidate_idxs.copy()\n",
" lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]\n",
" candidate_size.sort()\n",
" score = 0\n",
@@ -109,27 +113,32 @@
"\n",
" for i in range(0, len(candidate_size)):\n",
" score -= math.exp(-i) * candidate_size[i]\n",
" return score\n",
" return score, max_consistent_idxs\n",
" \n",
" def _zoopt_get_solution(self, symbol_num, data_sample, max_revision_num):\n",
" dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)\n",
" objective = Objective(\n",
" lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol)[0],\n",
" dim=dimension,\n",
" constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),\n",
" )\n",
" parameter = Parameter(budget=100, intermediate_result=False, autoset=True)\n",
" solution = Opt.min(objective, parameter)\n",
" return solution\n",
"\n",
" def abduce(self, data_sample):\n",
" symbol_num = data_sample.elements_num(\"pred_pseudo_label\")\n",
" max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n",
"\n",
" solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num)\n",
"\n",
" data_sample.revision_flag = reform_list(\n",
" solution.astype(np.int32), data_sample.pred_pseudo_label\n",
" )\n",
" solution = self._zoopt_get_solution(symbol_num, data_sample, max_revision_num)\n",
" _, max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_sample, solution)\n",
"\n",
" abduced_pseudo_label = []\n",
" abduced_pseudo_label = [[] for _ in range(len(data_sample))]\n",
"\n",
" for single_instance in data_sample:\n",
" single_instance.pred_pseudo_label = [single_instance.pred_pseudo_label]\n",
" candidates = self.revise_at_idx(single_instance)\n",
" if len(candidates) == 0:\n",
" abduced_pseudo_label.append([])\n",
" else:\n",
" abduced_pseudo_label.append(candidates[0][0])\n",
" if len(max_candidate_idxs) > 0:\n",
" candidates, _ = self.revise_at_idx(data_sample[max_candidate_idxs])\n",
" for i, idx in enumerate(max_candidate_idxs):\n",
" abduced_pseudo_label[idx] = candidates[0][i]\n",
" data_sample.abduced_pseudo_label = abduced_pseudo_label\n",
" return abduced_pseudo_label\n",
"\n",
@@ -138,7 +147,7 @@
"\n",
"\n",
"kb = HedKB(pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\")\n",
"reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=20)"
"reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=10)"
]
},
{
@@ -158,8 +167,7 @@
"# Build necessary components for BasicNN\n",
"cls = SymbolNet(num_classes=4)\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)\n",
"# optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))\n",
"optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
@@ -179,6 +187,7 @@
" batch_size=32,\n",
" num_epochs=1,\n",
" save_interval=1,\n",
" stop_loss=None,\n",
" save_dir=weights_dir,\n",
")"
]
@@ -210,7 +219,7 @@
"outputs": [],
"source": [
"# Set up metrics\n",
"metric_list = [SymbolMetric(prefix=\"hed\"), ReasoningMetric(prefix=\"hed\")]"
"metric_list = [SymbolMetric(prefix=\"hed\"), ReasoningMetric(kb=kb, prefix=\"hed\")]"
]
},
{


Loading…
Cancel
Save