Browse Source

[ENH] Run hed_example.ipynb after reformatting the examples

pull/3/head
Gao Enhao 2 years ago
parent
commit
6bd0ff66d6
3 changed files with 17 additions and 100 deletions
  1. +3
    -3
      examples/hed/datasets/get_hed.py
  2. +6
    -81
      examples/hed/framework_hed.py
  3. +8
    -16
      examples/hed/hed_example.ipynb

+ 3
- 3
examples/hed/datasets/get_hed.py View File

@@ -41,7 +41,7 @@ def get_pretrain_data(labels, image_size=(28, 28, 1)):
X = []
for label in labels:
label_path = os.path.join(
"./datasets/hed/mnist_images", label
"./datasets/mnist_images", label
)
img_path_list = os.listdir(label_path)
for img_path in img_path_list:
@@ -107,13 +107,13 @@ def get_hed(dataset="mnist", train=True):

if dataset == "mnist":
with open(
"./datasets/hed/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk",
"./datasets/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk",
"rb",
) as f:
img_dataset = pickle.load(f)
elif dataset == "random":
with open(
"./datasets/hed/random_equation_data_train_len_26_test_len_26_sys_2_.pk",
"./datasets/random_equation_data_train_len_26_test_len_26_sys_2_.pk",
"rb",
) as f:
img_dataset = pickle.load(f)


+ 6
- 81
examples/hed/framework_hed.py View File

@@ -10,93 +10,18 @@
#
# ================================================================#

import pickle as pk
import torch
import torch.nn as nn
import numpy as np
import os

from .utils.plog import INFO, DEBUG, clocker
from .utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res
from abl.utils.plog import INFO
from abl.utils.utils import flatten, reform_idx
from abl.models.basic_model import BasicModel, BasicDataset

from .models.nn import SymbolNetAutoencoder
from .models.basic_model import BasicModel, BasicDataset

import sys
sys.path.append("..")
from examples.datasets.hed.get_hed import get_pretrain_data

def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
    """Summarize prediction quality as a dict of accuracy metrics.

    Args:
        pred_Z: predicted pseudo-label sequences, one per example.
        Z: ground-truth pseudo-label sequences (may be placeholders when
            character-level accuracy is not evaluated).
        Y: ground-truth final labels, one per example.
        logic_forward: callable mapping a pseudo-label sequence to a final label.
        char_acc_flag: truthy to also report character-level accuracy.

    Returns:
        dict with "ABL accuracy" and, when requested, "Character level accuracy".
    """
    stats = {}

    if char_acc_flag:
        total_chars = 0
        matched_chars = 0
        for pred_z, z in zip(pred_Z, Z):
            total_chars += len(z)
            # Index by position so a too-short prediction still raises,
            # exactly as the original element-wise comparison would.
            matched_chars += sum(pred_z[i] == z[i] for i in range(len(z)))
        stats["Character level accuracy"] = matched_chars / total_chars

    correct = sum(logic_forward(pred_z) == y for pred_z, y in zip(pred_Z, Y))
    stats["ABL accuracy"] = correct / len(Y)

    return stats


def filter_data(X, abduced_Z):
    """Drop every sample whose abduction produced no result.

    Args:
        X: input samples.
        abduced_Z: abduced pseudo-label sequence per sample; an empty
            sequence marks a failed abduction.

    Returns:
        (finetune_X, finetune_Z): the surviving samples and their
        abduced labels, in the original order.
    """
    # Pair inputs with their abductions, keeping only non-empty ones.
    kept = [(x, z) for x, z in zip(X, abduced_Z) if len(z) > 0]
    finetune_X = [pair[0] for pair in kept]
    finetune_Z = [pair[1] for pair in kept]
    return finetune_X, finetune_Z



def train(model, abducer, train_data, test_data, loop_num=50, sample_num=-1, verbose=-1):
    """Run the abductive-learning training loop.

    Repeatedly: predict pseudo-labels, abduce consistent labels with the
    knowledge base, filter out failed abductions, and fine-tune the model.

    Args:
        model: learner exposing ``predict`` and ``train``.
        abducer: reasoner exposing ``batch_abduce`` and ``kb.logic_forward``.
        train_data: tuple ``(train_X, train_Z, train_Y)``; ``train_Z`` may be
            None when ground-truth pseudo-labels are unavailable.
        test_data: tuple ``(test_X, test_Z, test_Y)``. Unpacked for interface
            symmetry; not used inside the loop.
        loop_num: number of training loops (default 50).
        sample_num: samples drawn per loop; -1 means use the whole train set.
        verbose: report statistics every ``verbose`` loops; values < 1 mean
            report only on the final loop.

    Returns:
        dict of statistics from the last reporting step (see
        ``result_statistics``).
    """
    train_X, train_Z, train_Y = train_data
    test_X, test_Z, test_Y = test_data  # kept so malformed test_data fails fast

    # Resolve sentinel defaults.
    if sample_num == -1:
        sample_num = len(train_X)
    if verbose < 1:
        verbose = loop_num

    # Character-level accuracy only makes sense when ground-truth
    # pseudo-labels exist. Fix: identity comparison with None (PEP 8),
    # not equality — equality can be overridden (e.g. numpy arrays).
    char_acc_flag = 1
    if train_Z is None:
        char_acc_flag = 0
        train_Z = [None] * len(train_X)

    # Wrap the hot calls with timing instrumentation.
    predict_func = clocker(model.predict)
    train_func = clocker(model.train)
    abduce_func = clocker(abducer.batch_abduce)

    res = {}  # defensive: defined even if loop_num < 1
    for loop_idx in range(loop_num):
        X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, loop_idx)
        preds_res = predict_func(X)
        abduced_Z = abduce_func(preds_res, Y)

        # Report on every `verbose`-th loop and always on the last one.
        if ((loop_idx + 1) % verbose == 0) or (loop_idx == loop_num - 1):
            res = result_statistics(preds_res['cls'], Z, Y, abducer.kb.logic_forward, char_acc_flag)
            INFO('loop: ', loop_idx + 1, ' ', res)

        finetune_X, finetune_Z = filter_data(X, abduced_Z)
        if len(finetune_X) > 0:
            train_func(finetune_X, finetune_Z)
        else:
            # Every abduction failed this loop; skip fine-tuning.
            INFO("lack of data, all abduced failed", len(finetune_X))

    return res
from utils import gen_mappings, mapping_res, remapping_res
from models.nn import SymbolNetAutoencoder
from datasets.get_hed import get_pretrain_data


def hed_pretrain(kb, cls, recorder):


+ 8
- 16
examples/hed/hed_example.ipynb View File

@@ -2,13 +2,13 @@
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"../\")\n",
"sys.path.append(\"../../\")\n",
"\n",
"import torch.nn as nn\n",
"import torch\n",
@@ -21,13 +21,13 @@
"from abl.models.wabl_models import WABLBasicModel\n",
"\n",
"from models.nn import SymbolNet\n",
"from datasets.hed.get_hed import get_hed, split_equation\n",
"from abl import framework_hed"
"from datasets.get_hed import get_hed, split_equation\n",
"import framework_hed"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -45,20 +45,12 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"ERROR: /home/gaoeh/ABL-Package/examples/datasets/hed/learn_add.pl:67:9: Syntax error: Operator expected\n"
]
}
],
"outputs": [],
"source": [
"# Initialize knowledge base and abducer\n",
"kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/hed/learn_add.pl')\n",
"kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/learn_add.pl')\n",
"abducer = HED_Abducer(kb)"
]
},


Loading…
Cancel
Save