Browse Source

[MNT] improve hed performance

pull/1/head
Gao Enhao 1 year ago
parent
commit
e95e7998b6
3 changed files with 17 additions and 186 deletions
  1. +1
    -1
      examples/hed/hed_bridge.py
  2. +14
    -185
      examples/hed/hed_example.ipynb
  3. +2
    -0
      examples/hed/utils.py

+ 1
- 1
examples/hed/hed_bridge.py View File

@@ -98,7 +98,7 @@ class HEDBridge(SimpleBridge):
)
print_log(log_string, logger="current")

if character_accuracy >= 0.9 and revisible_ratio >= 0.9:
if character_accuracy >= 0.95 and revisible_ratio >= 0.95:
return True
return False



+ 14
- 185
examples/hed/hed_example.ipynb View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -24,17 +24,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12/18 09:01:12 - abl - INFO - Abductive Learning on the HED example.\n"
]
}
],
"outputs": [],
"source": [
"# Build logger\n",
"print_log(\"Abductive Learning on the HED example.\", logger=\"current\")\n",
@@ -54,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -130,7 +122,7 @@
" dim=dimension,\n",
" constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),\n",
" )\n",
" parameter = Parameter(budget=100, intermediate_result=False, autoset=True)\n",
" parameter = Parameter(budget=200, intermediate_result=False, autoset=True)\n",
" solution = Opt.min(objective, parameter)\n",
" return solution\n",
"\n",
@@ -168,7 +160,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -181,7 +173,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -202,7 +194,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -222,7 +214,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -240,7 +232,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -257,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -276,176 +268,13 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12/18 09:04:27 - abl - INFO - Pretrain Start\n",
"12/18 09:04:31 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_1.pth\n",
"12/18 09:04:33 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_2.pth\n",
"12/18 09:04:34 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_3.pth\n",
"12/18 09:04:36 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_4.pth\n",
"12/18 09:04:37 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_5.pth\n",
"12/18 09:04:38 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_6.pth\n",
"12/18 09:04:40 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_7.pth\n",
"12/18 09:04:41 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_8.pth\n",
"12/18 09:04:43 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_9.pth\n",
"12/18 09:04:44 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_10.pth\n",
"12/18 09:04:44 - abl - INFO - model loss: 0.78453\n",
"12/18 09:04:44 - abl - INFO - min loss is <abl.learning.basic_nn.BasicNN object at 0x7f6c4f9393d0>\n",
"12/18 09:04:44 - abl - INFO - Loads checkpoint by local backend from path: ./weights/pretrain_weights.pth\n",
"12/18 09:04:44 - abl - INFO - ============== equation_len: 5-6 ================\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 9.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-2.0, 8.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-2.0, 8.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 9.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-2.0, 6.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 9.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 8.0]\n",
"[zoopt] x: array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-2.0, 8.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-2.0, 8.0]\n",
"12/18 09:05:16 - abl - INFO - Checkpoints will be saved to results/20231218_09_01_12/weights/model_checkpoint_epoch_1.pth\n",
"12/18 09:05:16 - abl - INFO - model loss: 0.59495\n"
]
},
{
"ename": "TypeError",
"evalue": "unsupported format string passed to BasicNN.__format__",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [13]\u001b[0m, in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m bridge\u001b[38;5;241m.\u001b[39mpretrain(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./weights\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mbridge\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_data\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/ABL-Package/examples/hed/hed_bridge.py:217\u001b[0m, in \u001b[0;36mHEDBridge.train\u001b[0;34m(self, train_data, val_data, segment_size, min_len, max_len)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mabduce_pseudo_label(sub_data_examples)\n\u001b[1;32m 216\u001b[0m filtered_sub_data_examples \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfilter_empty(sub_data_examples)\n\u001b[0;32m--> 217\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpseudo_label_to_idx(filtered_sub_data_examples)\n\u001b[1;32m 218\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mtrain(filtered_sub_data_examples)\n\u001b[1;32m 220\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheck_training_impact(filtered_sub_data_examples, sub_data_examples):\n",
"\u001b[0;31mTypeError\u001b[0m: unsupported format string passed to BasicNN.__format__"
]
}
],
"outputs": [],
"source": [
"bridge.pretrain(\"./weights\")\n",
"bridge.train(train_data, val_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -464,7 +293,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.18"
},
"orig_nbformat": 4,
"vscode": {


+ 2
- 0
examples/hed/utils.py View File

@@ -32,6 +32,8 @@ def gen_mappings(chars, symbs):
# returned mappings
perms = permutations(symbs)
for p in perms:
if p.index(1) < p.index(0):
continue
mappings.append(dict(zip(chars, list(p))))
return mappings



Loading…
Cancel
Save