@@ -2,7 +2,7 @@ import inspect | |||||
from typing import Callable, Any, List, Optional, Union | from typing import Callable, Any, List, Optional, Union | ||||
import numpy as np | import numpy as np | ||||
from zoopt import Dimension, Objective, Opt, Parameter | |||||
from zoopt import Dimension, Objective, Opt, Parameter, Solution | |||||
from ..reasoning import KBBase | from ..reasoning import KBBase | ||||
from ..structures import ListData | from ..structures import ListData | ||||
@@ -175,9 +175,9 @@ class Reasoner: | |||||
symbol_num: int, | symbol_num: int, | ||||
data_example: ListData, | data_example: ListData, | ||||
max_revision_num: int, | max_revision_num: int, | ||||
) -> List[bool]: | |||||
) -> Solution: | |||||
""" | """ | ||||
Get the optimal solution using ZOOpt library. The solution is a list of | |||||
Get the optimal solution using ZOOpt library. From the solution, we can get a list of | |||||
boolean values, where '1' (True) indicates the indices chosen to be revised. | boolean values, where '1' (True) indicates the indices chosen to be revised. | ||||
Parameters | Parameters | ||||
@@ -191,7 +191,7 @@ class Reasoner: | |||||
Returns | Returns | ||||
------- | ------- | ||||
List[bool] | |||||
Solution | |||||
The solution for ZOOpt library. | The solution for ZOOpt library. | ||||
""" | """ | ||||
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) | dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) | ||||
@@ -201,14 +201,14 @@ class Reasoner: | |||||
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | ||||
) | ) | ||||
parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | ||||
solution = Opt.min(objective, parameter).get_x() | |||||
solution = Opt.min(objective, parameter) | |||||
return solution | return solution | ||||
def zoopt_revision_score( | def zoopt_revision_score( | ||||
self, | self, | ||||
symbol_num: int, | symbol_num: int, | ||||
data_example: ListData, | data_example: ListData, | ||||
sol: List[bool], | |||||
sol: Solution, | |||||
) -> int: | ) -> int: | ||||
""" | """ | ||||
Get the revision score for a solution. A lower score suggests that ZOOpt library | Get the revision score for a solution. A lower score suggests that ZOOpt library | ||||
@@ -220,7 +220,7 @@ class Reasoner: | |||||
Number of total symbols. | Number of total symbols. | ||||
data_example : ListData | data_example : ListData | ||||
Data example. | Data example. | ||||
sol: List[bool] | |||||
sol: Solution | |||||
The solution for ZOOpt library. | The solution for ZOOpt library. | ||||
Returns | Returns | ||||
@@ -237,7 +237,7 @@ class Reasoner: | |||||
else: | else: | ||||
return symbol_num | return symbol_num | ||||
def _constrain_revision_num(self, solution: List[bool], max_revision_num: int) -> int: | |||||
def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int: | |||||
""" | """ | ||||
Constrain that the total number of revisions chosen by the solution does not exceed | Constrain that the total number of revisions chosen by the solution does not exceed | ||||
maximum number of revisions allowed. | maximum number of revisions allowed. | ||||
@@ -287,7 +287,7 @@ class Reasoner: | |||||
if self.use_zoopt: | if self.use_zoopt: | ||||
solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num) | solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num) | ||||
revision_idx = np.where(solution != 0)[0] | |||||
revision_idx = np.where(solution.get_x() != 0)[0] | |||||
candidates, reasoning_results = self.kb.revise_at_idx( | candidates, reasoning_results = self.kb.revise_at_idx( | ||||
pseudo_label=data_example.pred_pseudo_label, | pseudo_label=data_example.pred_pseudo_label, | ||||
y=data_example.Y, | y=data_example.Y, | ||||
@@ -10,7 +10,7 @@ from abl.learning import ABLModel, BasicNN | |||||
from abl.reasoning import Reasoner | from abl.reasoning import Reasoner | ||||
from abl.structures import ListData | from abl.structures import ListData | ||||
from abl.utils import print_log | from abl.utils import print_log | ||||
from examples.hed.datasets.get_hed import get_pretrain_data | |||||
from examples.hed.datasets.get_dataset import get_pretrain_data | |||||
from examples.hed.utils import InfiniteSampler, gen_mappings | from examples.hed.utils import InfiniteSampler, gen_mappings | ||||
from examples.models.nn import SymbolNetAutoencoder | from examples.models.nn import SymbolNetAutoencoder | ||||
@@ -0,0 +1,4 @@ | |||||
from .get_dataset import get_dataset, split_equation | |||||
__all__ = ["get_dataset", "split_equation"] |
@@ -1,173 +0,0 @@ | |||||
import os | |||||
import itertools | |||||
import random | |||||
import numpy as np | |||||
from PIL import Image | |||||
import pickle | |||||
def get_sign_path_list(data_dir, sign_names): | |||||
sign_num = len(sign_names) | |||||
index_dict = dict(zip(sign_names, list(range(sign_num)))) | |||||
ret = [[] for _ in range(sign_num)] | |||||
for path in os.listdir(data_dir): | |||||
if (path in sign_names): | |||||
index = index_dict[path] | |||||
sign_path = os.path.join(data_dir, path) | |||||
for p in os.listdir(sign_path): | |||||
ret[index].append(os.path.join(sign_path, p)) | |||||
return ret | |||||
def split_pool_by_rate(pools, rate, seed = None): | |||||
if seed is not None: | |||||
random.seed(seed) | |||||
ret1 = [] | |||||
ret2 = [] | |||||
for pool in pools: | |||||
random.shuffle(pool) | |||||
num = int(len(pool) * rate) | |||||
ret1.append(pool[:num]) | |||||
ret2.append(pool[num:]) | |||||
return ret1, ret2 | |||||
def int_to_system_form(num, system_num): | |||||
if num == 0: | |||||
return "0" | |||||
ret = "" | |||||
while (num > 0): | |||||
ret += str(num % system_num) | |||||
num //= system_num | |||||
return ret[::-1] | |||||
def generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type): | |||||
expr_len = left_opt_len + right_opt_len | |||||
num_list = "".join([str(i) for i in range(system_num)]) | |||||
ret = [] | |||||
if generate_type == "all": | |||||
candidates = itertools.product(num_list, repeat = expr_len) | |||||
else: | |||||
candidates = [''.join(random.sample(['0', '1'] * expr_len, expr_len))] | |||||
random.shuffle(candidates) | |||||
for nums in candidates: | |||||
left_num = "".join(nums[:left_opt_len]) | |||||
right_num = "".join(nums[left_opt_len:]) | |||||
left_value = int(left_num, system_num) | |||||
right_value = int(right_num, system_num) | |||||
result_value = left_value + right_value | |||||
if (label == 'negative'): | |||||
result_value += random.randint(-result_value, result_value) | |||||
if (left_value + right_value == result_value): | |||||
continue | |||||
result_num = int_to_system_form(result_value, system_num) | |||||
#leading zeros | |||||
if (res_opt_len != len(result_num)): | |||||
continue | |||||
if ((left_opt_len > 1 and left_num[0] == '0') or (right_opt_len > 1 and right_num[0] == '0')): | |||||
continue | |||||
#add leading zeros | |||||
if (res_opt_len < len(result_num)): | |||||
continue | |||||
while (len(result_num) < res_opt_len): | |||||
result_num = '0' + result_num | |||||
#continue | |||||
ret.append(left_num + '+' + right_num + '=' + result_num) # current only consider '+' and '=' | |||||
#print(ret[-1]) | |||||
return ret | |||||
def generator_equation_by_len(equation_len, system_num = 2, label = 0, require_num = 1): | |||||
generate_type = "one" | |||||
ret = [] | |||||
equation_sign_num = 2 # '+' and '=' | |||||
while len(ret) < require_num: | |||||
left_opt_len = random.randint(1, equation_len - 1 - equation_sign_num) | |||||
right_opt_len = random.randint(1, equation_len - left_opt_len - equation_sign_num) | |||||
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num | |||||
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type)) | |||||
return ret | |||||
def generator_equations_by_len(equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all"): | |||||
ret = [] | |||||
equation_sign_num = 2 # '+' and '=' | |||||
for left_opt_len in range(1, equation_len - (2 + equation_sign_num) + 1): | |||||
for right_opt_len in range(1, equation_len - left_opt_len - (1 + equation_sign_num) + 1): | |||||
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num | |||||
for i in range(repeat_times): #generate more equations | |||||
if random.random() > keep ** (equation_len): | |||||
continue | |||||
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type)) | |||||
return ret | |||||
def generator_equations_by_max_len(max_equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all", num_per_len = None): | |||||
ret = [] | |||||
equation_sign_num = 2 # '+' and '=' | |||||
for equation_len in range(3 + equation_sign_num, max_equation_len + 1): | |||||
if (num_per_len is None): | |||||
ret.extend(generator_equations_by_len(equation_len, system_num, label, repeat_times, keep, generate_type)) | |||||
else: | |||||
ret.extend(generator_equation_by_len(equation_len, system_num, label, require_num = num_per_len)) | |||||
return ret | |||||
def generator_equation_images(image_pools, equations, signs, shape, seed, is_color): | |||||
if (seed is not None): | |||||
random.seed(seed) | |||||
ret = [] | |||||
sign_num = len(signs) | |||||
sign_index_dict = dict(zip(signs, list(range(sign_num)))) | |||||
for equation in equations: | |||||
data = [] | |||||
for sign in equation: | |||||
index = sign_index_dict[sign] | |||||
pick = random.randint(0, len(image_pools[index]) - 1) | |||||
if is_color: | |||||
image = Image.open(image_pools[index][pick]).convert('RGB').resize(shape) | |||||
else: | |||||
image = Image.open(image_pools[index][pick]).convert('I').resize(shape) | |||||
image_array = np.array(image) | |||||
image_array = (image_array-127)*(1./128) | |||||
data.append(image_array) | |||||
ret.append(np.array(data)) | |||||
return ret | |||||
def get_equation_std_data(data_dir, sign_dir_lists, sign_output_lists, shape = (28, 28), train_max_equation_len = 10, test_max_equation_len = 10, system_num = 2, tmp_file_prev = | |||||
None, seed = None, train_num_per_len = 10, test_num_per_len = 10, is_color = False): | |||||
tmp_file = "" | |||||
if (tmp_file_prev is not None): | |||||
tmp_file = "%s_train_len_%d_test_len_%d_sys_%d_.pk" % (tmp_file_prev, train_max_equation_len, test_max_equation_len, system_num) | |||||
if (os.path.exists(tmp_file)): | |||||
return pickle.load(open(tmp_file, "rb")) | |||||
image_pools = get_sign_path_list(data_dir, sign_dir_lists) | |||||
train_pool, test_pool = split_pool_by_rate(image_pools, 0.8, seed) | |||||
ret = {} | |||||
for label in ["positive", "negative"]: | |||||
print("Generating equations.") | |||||
train_equations = generator_equations_by_max_len(train_max_equation_len, system_num, label, num_per_len = train_num_per_len) | |||||
test_equations = generator_equations_by_max_len(test_max_equation_len, system_num, label, num_per_len = test_num_per_len) | |||||
print(train_equations) | |||||
print(test_equations) | |||||
print("Generated equations.") | |||||
print("Generating equation image data.") | |||||
ret["train:%s" % (label)] = generator_equation_images(train_pool, train_equations, sign_output_lists, shape, seed, is_color) | |||||
ret["test:%s" % (label)] = generator_equation_images(test_pool, test_equations, sign_output_lists, shape, seed, is_color) | |||||
print("Generated equation image data.") | |||||
if (tmp_file_prev is not None): | |||||
pickle.dump(ret, open(tmp_file, "wb")) | |||||
return ret | |||||
if __name__ == "__main__": | |||||
data_dirs = ["./dataset/hed/mnist_images", "./dataset/hed/random_images"] #, "../dataset/cifar10_images"] | |||||
tmp_file_prevs = ["mnist_equation_data", "random_equation_data"] #, "cifar10_equation_data"] | |||||
for data_dir, tmp_file_prev in zip(data_dirs, tmp_file_prevs): | |||||
data = get_equation_std_data(data_dir = data_dir,\ | |||||
sign_dir_lists = ['0', '1', '10', '11'],\ | |||||
sign_output_lists = ['0', '1', '+', '='],\ | |||||
shape = (28, 28),\ | |||||
train_max_equation_len = 26, \ | |||||
test_max_equation_len = 26, \ | |||||
system_num = 2, \ | |||||
tmp_file_prev = tmp_file_prev, \ | |||||
train_num_per_len = 300, \ | |||||
test_num_per_len = 300, \ | |||||
is_color = False) |
@@ -81,7 +81,7 @@ def split_equation(equations_by_len, prop_train, prop_val): | |||||
return train_equations_by_len, val_equations_by_len | return train_equations_by_len, val_equations_by_len | ||||
def get_hed(dataset="mnist", train=True): | |||||
def get_dataset(dataset="mnist", train=True): | |||||
if dataset == "mnist": | if dataset == "mnist": | ||||
file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk") | file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk") | ||||
elif dataset == "random": | elif dataset == "random": |
@@ -1,307 +0,0 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"import os.path as osp\n", | |||||
"\n", | |||||
"import numpy as np\n", | |||||
"import torch\n", | |||||
"import torch.nn as nn\n", | |||||
"from zoopt import Dimension, Objective, Opt, Parameter\n", | |||||
"\n", | |||||
"from abl.evaluation import ReasoningMetric, SymbolMetric\n", | |||||
"from abl.learning import ABLModel, BasicNN\n", | |||||
"from abl.reasoning import PrologKB, Reasoner\n", | |||||
"from abl.utils import ABLLogger, print_log, reform_list\n", | |||||
"from examples.hed.datasets.get_hed import get_hed, split_equation\n", | |||||
"from examples.hed.hed_bridge import HEDBridge\n", | |||||
"from examples.models.nn import SymbolNet" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Build logger\n", | |||||
"print_log(\"Abductive Learning on the HED example.\", logger=\"current\")\n", | |||||
"\n", | |||||
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", | |||||
"log_dir = ABLLogger.get_current_instance().log_dir\n", | |||||
"weights_dir = osp.join(log_dir, \"weights\")" | |||||
] | |||||
}, | |||||
{ | |||||
"attachments": {}, | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Logic Part" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Initialize knowledge base and abducer\n", | |||||
"class HedKB(PrologKB):\n", | |||||
" def __init__(self, pseudo_label_list, pl_file):\n", | |||||
" super().__init__(pseudo_label_list, pl_file)\n", | |||||
"\n", | |||||
" def consist_rule(self, exs, rules):\n", | |||||
" rules = str(rules).replace(\"'\", \"\")\n", | |||||
" return len(list(self.prolog.query(\"eval_inst_feature(%s, %s).\" % (exs, rules)))) != 0\n", | |||||
"\n", | |||||
" def abduce_rules(self, pred_res):\n", | |||||
" prolog_result = list(self.prolog.query(\"consistent_inst_feature(%s, X).\" % pred_res))\n", | |||||
" if len(prolog_result) == 0:\n", | |||||
" return None\n", | |||||
" prolog_rules = prolog_result[0][\"X\"]\n", | |||||
" rules = [rule.value for rule in prolog_rules]\n", | |||||
" return rules\n", | |||||
"\n", | |||||
"\n", | |||||
"class HedReasoner(Reasoner):\n", | |||||
" def revise_at_idx(self, data_example):\n", | |||||
" revision_idx = np.where(np.array(data_example.flatten(\"revision_flag\")) != 0)[0]\n", | |||||
" candidate = self.kb.revise_at_idx(\n", | |||||
" data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx\n", | |||||
" )\n", | |||||
" return candidate\n", | |||||
"\n", | |||||
" def zoopt_revision_score(self, symbol_num, data_example, sol):\n", | |||||
" revision_flag = reform_list(\n", | |||||
" list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label\n", | |||||
" )\n", | |||||
" data_example.revision_flag = revision_flag\n", | |||||
"\n", | |||||
" lefted_idxs = [i for i in range(len(data_example.pred_idx))]\n", | |||||
" candidate_size = []\n", | |||||
" max_consistent_idxs = []\n", | |||||
" while lefted_idxs:\n", | |||||
" idxs = []\n", | |||||
" idxs.append(lefted_idxs.pop(0))\n", | |||||
" max_candidate_idxs = []\n", | |||||
" found = False\n", | |||||
" for idx in range(-1, len(data_example.pred_idx)):\n", | |||||
" if (not idx in idxs) and (idx >= 0):\n", | |||||
" idxs.append(idx)\n", | |||||
" candidates, _ = self.revise_at_idx(data_example[idxs])\n", | |||||
" if len(candidates) == 0:\n", | |||||
" if len(idxs) > 1:\n", | |||||
" idxs.pop()\n", | |||||
" else:\n", | |||||
" if len(idxs) > len(max_candidate_idxs):\n", | |||||
" found = True\n", | |||||
" max_candidate_idxs = idxs.copy()\n", | |||||
" removed = [i for i in lefted_idxs if i in max_candidate_idxs]\n", | |||||
" if found:\n", | |||||
" removed.insert(0, idxs[0])\n", | |||||
" candidate_size.append(len(removed))\n", | |||||
" max_consistent_idxs = max_candidate_idxs.copy()\n", | |||||
" lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]\n", | |||||
" candidate_size.sort()\n", | |||||
" score = 0\n", | |||||
" import math\n", | |||||
"\n", | |||||
" for i in range(0, len(candidate_size)):\n", | |||||
" score -= math.exp(-i) * candidate_size[i]\n", | |||||
" return score, max_consistent_idxs\n", | |||||
" \n", | |||||
" def _zoopt_get_solution(self, symbol_num, data_example, max_revision_num):\n", | |||||
" dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)\n", | |||||
" objective = Objective(\n", | |||||
" lambda sol: self.zoopt_revision_score(symbol_num, data_example, sol)[0],\n", | |||||
" dim=dimension,\n", | |||||
" constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),\n", | |||||
" )\n", | |||||
" parameter = Parameter(budget=200, intermediate_result=False, autoset=True)\n", | |||||
" solution = Opt.min(objective, parameter)\n", | |||||
" return solution\n", | |||||
"\n", | |||||
" def abduce(self, data_example):\n", | |||||
" symbol_num = data_example.elements_num(\"pred_pseudo_label\")\n", | |||||
" max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n", | |||||
"\n", | |||||
" solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)\n", | |||||
" _, max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_example, solution)\n", | |||||
"\n", | |||||
" abduced_pseudo_label = [[] for _ in range(len(data_example))]\n", | |||||
"\n", | |||||
" if len(max_candidate_idxs) > 0:\n", | |||||
" candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs])\n", | |||||
" for i, idx in enumerate(max_candidate_idxs):\n", | |||||
" abduced_pseudo_label[idx] = candidates[0][i]\n", | |||||
" data_example.abduced_pseudo_label = abduced_pseudo_label\n", | |||||
" return abduced_pseudo_label\n", | |||||
"\n", | |||||
" def abduce_rules(self, pred_res):\n", | |||||
" return self.kb.abduce_rules(pred_res)\n", | |||||
"\n", | |||||
"\n", | |||||
"kb = HedKB(pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\")\n", | |||||
"reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=10)" | |||||
] | |||||
}, | |||||
{ | |||||
"attachments": {}, | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Machine Learning Part" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Build necessary components for BasicNN\n", | |||||
"cls = SymbolNet(num_classes=4)\n", | |||||
"loss_fn = nn.CrossEntropyLoss()\n", | |||||
"optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4)\n", | |||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Build BasicNN\n", | |||||
"# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n", | |||||
"base_model = BasicNN(\n", | |||||
" cls,\n", | |||||
" loss_fn,\n", | |||||
" optimizer,\n", | |||||
" device,\n", | |||||
" batch_size=32,\n", | |||||
" num_epochs=1,\n", | |||||
" save_interval=1,\n", | |||||
" stop_loss=None,\n", | |||||
" save_dir=weights_dir,\n", | |||||
")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Build ABLModel\n", | |||||
"# The main function of the ABL model is to serialize data and\n", | |||||
"# provide a unified interface for different machine learning models\n", | |||||
"model = ABLModel(base_model)" | |||||
] | |||||
}, | |||||
{ | |||||
"attachments": {}, | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Metric" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Set up metrics\n", | |||||
"metric_list = [SymbolMetric(prefix=\"hed\"), ReasoningMetric(kb=kb, prefix=\"hed\")]" | |||||
] | |||||
}, | |||||
{ | |||||
"attachments": {}, | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Bridge Machine Learning and Logic Reasoning" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"bridge = HEDBridge(model, reasoner, metric_list)" | |||||
] | |||||
}, | |||||
{ | |||||
"attachments": {}, | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Dataset" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"total_train_data = get_hed(train=True)\n", | |||||
"train_data, val_data = split_equation(total_train_data, 3, 1)\n", | |||||
"test_data = get_hed(train=False)" | |||||
] | |||||
}, | |||||
{ | |||||
"attachments": {}, | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Train and Test" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"bridge.pretrain(\"./weights\")\n", | |||||
"bridge.train(train_data, val_data)" | |||||
] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "ABL", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.18" | |||||
}, | |||||
"orig_nbformat": 4, | |||||
"vscode": { | |||||
"interpreter": { | |||||
"hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58" | |||||
} | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,3 @@ | |||||
from .reasoning import HedKB, HedReasoner | |||||
__all__ = ["HedKB", "HedReasoner"] |
@@ -0,0 +1,93 @@ | |||||
import os | |||||
import numpy as np | |||||
import math | |||||
from zoopt import Dimension, Objective, Opt, Parameter | |||||
from abl.reasoning import PrologKB, Reasoner | |||||
from abl.utils import reform_list | |||||
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||||
class HedKB(PrologKB): | |||||
def __init__(self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")): | |||||
super().__init__(pseudo_label_list, pl_file) | |||||
def consist_rule(self, exs, rules): | |||||
rules = str(rules).replace("'", "") | |||||
return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 | |||||
def abduce_rules(self, pred_res): | |||||
prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) | |||||
if len(prolog_result) == 0: | |||||
return None | |||||
prolog_rules = prolog_result[0]["X"] | |||||
rules = [rule.value for rule in prolog_rules] | |||||
return rules | |||||
class HedReasoner(Reasoner): | |||||
def revise_at_idx(self, data_example): | |||||
revision_idx = np.where(np.array(data_example.flatten("revision_flag")) != 0)[0] | |||||
candidate = self.kb.revise_at_idx( | |||||
data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx | |||||
) | |||||
return candidate | |||||
def zoopt_revision_score(self, symbol_num, data_example, sol, get_score=True): | |||||
revision_flag = reform_list( | |||||
list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label | |||||
) | |||||
data_example.revision_flag = revision_flag | |||||
lefted_idxs = [i for i in range(len(data_example.pred_idx))] | |||||
candidate_size = [] | |||||
max_consistent_idxs = [] | |||||
while lefted_idxs: | |||||
idxs = [] | |||||
idxs.append(lefted_idxs.pop(0)) | |||||
max_candidate_idxs = [] | |||||
found = False | |||||
for idx in range(-1, len(data_example.pred_idx)): | |||||
if (not idx in idxs) and (idx >= 0): | |||||
idxs.append(idx) | |||||
candidates, _ = self.revise_at_idx(data_example[idxs]) | |||||
if len(candidates) == 0: | |||||
if len(idxs) > 1: | |||||
idxs.pop() | |||||
else: | |||||
if len(idxs) > len(max_candidate_idxs): | |||||
found = True | |||||
max_candidate_idxs = idxs.copy() | |||||
removed = [i for i in lefted_idxs if i in max_candidate_idxs] | |||||
if found: | |||||
removed.insert(0, idxs[0]) | |||||
candidate_size.append(len(removed)) | |||||
max_consistent_idxs = max_candidate_idxs.copy() | |||||
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] | |||||
candidate_size.sort() | |||||
score = 0 | |||||
for i in range(0, len(candidate_size)): | |||||
score -= math.exp(-i) * candidate_size[i] | |||||
if get_score: | |||||
return score | |||||
else: | |||||
return max_consistent_idxs | |||||
def abduce(self, data_example): | |||||
symbol_num = data_example.elements_num("pred_pseudo_label") | |||||
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) | |||||
solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num) | |||||
max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_example, solution, get_score=False) | |||||
abduced_pseudo_label = [[] for _ in range(len(data_example))] | |||||
if len(max_candidate_idxs) > 0: | |||||
candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs]) | |||||
for i, idx in enumerate(max_candidate_idxs): | |||||
abduced_pseudo_label[idx] = candidates[0][i] | |||||
data_example.abduced_pseudo_label = abduced_pseudo_label | |||||
return abduced_pseudo_label | |||||
def abduce_rules(self, pred_res): | |||||
return self.kb.abduce_rules(pred_res) |
@@ -0,0 +1 @@ | |||||
abl |