@@ -2,7 +2,7 @@ import inspect | |||
from typing import Callable, Any, List, Optional, Union | |||
import numpy as np | |||
from zoopt import Dimension, Objective, Opt, Parameter | |||
from zoopt import Dimension, Objective, Opt, Parameter, Solution | |||
from ..reasoning import KBBase | |||
from ..structures import ListData | |||
@@ -175,9 +175,9 @@ class Reasoner: | |||
symbol_num: int, | |||
data_example: ListData, | |||
max_revision_num: int, | |||
) -> List[bool]: | |||
) -> Solution: | |||
""" | |||
Get the optimal solution using ZOOpt library. The solution is a list of | |||
Get the optimal solution using ZOOpt library. From the solution, we can get a list of | |||
boolean values, where '1' (True) indicates the indices chosen to be revised. | |||
Parameters | |||
@@ -191,7 +191,7 @@ class Reasoner: | |||
Returns | |||
------- | |||
List[bool] | |||
Solution | |||
The solution for ZOOpt library. | |||
""" | |||
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) | |||
@@ -201,14 +201,14 @@ class Reasoner: | |||
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | |||
) | |||
parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | |||
solution = Opt.min(objective, parameter).get_x() | |||
solution = Opt.min(objective, parameter) | |||
return solution | |||
def zoopt_revision_score( | |||
self, | |||
symbol_num: int, | |||
data_example: ListData, | |||
sol: List[bool], | |||
sol: Solution, | |||
) -> int: | |||
""" | |||
Get the revision score for a solution. A lower score suggests that ZOOpt library | |||
@@ -220,7 +220,7 @@ class Reasoner: | |||
Number of total symbols. | |||
data_example : ListData | |||
Data example. | |||
sol: List[bool] | |||
sol: Solution | |||
The solution for ZOOpt library. | |||
Returns | |||
@@ -237,7 +237,7 @@ class Reasoner: | |||
else: | |||
return symbol_num | |||
def _constrain_revision_num(self, solution: List[bool], max_revision_num: int) -> int: | |||
def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int: | |||
""" | |||
Constrain that the total number of revisions chosen by the solution does not exceed | |||
maximum number of revisions allowed. | |||
@@ -287,7 +287,7 @@ class Reasoner: | |||
if self.use_zoopt: | |||
solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num) | |||
revision_idx = np.where(solution != 0)[0] | |||
revision_idx = np.where(solution.get_x() != 0)[0] | |||
candidates, reasoning_results = self.kb.revise_at_idx( | |||
pseudo_label=data_example.pred_pseudo_label, | |||
y=data_example.Y, | |||
@@ -10,7 +10,7 @@ from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import Reasoner | |||
from abl.structures import ListData | |||
from abl.utils import print_log | |||
from examples.hed.datasets.get_hed import get_pretrain_data | |||
from examples.hed.datasets.get_dataset import get_pretrain_data | |||
from examples.hed.utils import InfiniteSampler, gen_mappings | |||
from examples.models.nn import SymbolNetAutoencoder | |||
@@ -0,0 +1,4 @@ | |||
from .get_dataset import get_dataset, split_equation | |||
__all__ = ["get_dataset", "split_equation"] |
@@ -1,173 +0,0 @@ | |||
import os | |||
import itertools | |||
import random | |||
import numpy as np | |||
from PIL import Image | |||
import pickle | |||
def get_sign_path_list(data_dir, sign_names): | |||
sign_num = len(sign_names) | |||
index_dict = dict(zip(sign_names, list(range(sign_num)))) | |||
ret = [[] for _ in range(sign_num)] | |||
for path in os.listdir(data_dir): | |||
if (path in sign_names): | |||
index = index_dict[path] | |||
sign_path = os.path.join(data_dir, path) | |||
for p in os.listdir(sign_path): | |||
ret[index].append(os.path.join(sign_path, p)) | |||
return ret | |||
def split_pool_by_rate(pools, rate, seed = None): | |||
if seed is not None: | |||
random.seed(seed) | |||
ret1 = [] | |||
ret2 = [] | |||
for pool in pools: | |||
random.shuffle(pool) | |||
num = int(len(pool) * rate) | |||
ret1.append(pool[:num]) | |||
ret2.append(pool[num:]) | |||
return ret1, ret2 | |||
def int_to_system_form(num, system_num): | |||
if num == 0: | |||
return "0" | |||
ret = "" | |||
while (num > 0): | |||
ret += str(num % system_num) | |||
num //= system_num | |||
return ret[::-1] | |||
def generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type): | |||
expr_len = left_opt_len + right_opt_len | |||
num_list = "".join([str(i) for i in range(system_num)]) | |||
ret = [] | |||
if generate_type == "all": | |||
candidates = itertools.product(num_list, repeat = expr_len) | |||
else: | |||
candidates = [''.join(random.sample(['0', '1'] * expr_len, expr_len))] | |||
random.shuffle(candidates) | |||
for nums in candidates: | |||
left_num = "".join(nums[:left_opt_len]) | |||
right_num = "".join(nums[left_opt_len:]) | |||
left_value = int(left_num, system_num) | |||
right_value = int(right_num, system_num) | |||
result_value = left_value + right_value | |||
if (label == 'negative'): | |||
result_value += random.randint(-result_value, result_value) | |||
if (left_value + right_value == result_value): | |||
continue | |||
result_num = int_to_system_form(result_value, system_num) | |||
#leading zeros | |||
if (res_opt_len != len(result_num)): | |||
continue | |||
if ((left_opt_len > 1 and left_num[0] == '0') or (right_opt_len > 1 and right_num[0] == '0')): | |||
continue | |||
#add leading zeros | |||
if (res_opt_len < len(result_num)): | |||
continue | |||
while (len(result_num) < res_opt_len): | |||
result_num = '0' + result_num | |||
#continue | |||
ret.append(left_num + '+' + right_num + '=' + result_num) # current only consider '+' and '=' | |||
#print(ret[-1]) | |||
return ret | |||
def generator_equation_by_len(equation_len, system_num = 2, label = 0, require_num = 1): | |||
generate_type = "one" | |||
ret = [] | |||
equation_sign_num = 2 # '+' and '=' | |||
while len(ret) < require_num: | |||
left_opt_len = random.randint(1, equation_len - 1 - equation_sign_num) | |||
right_opt_len = random.randint(1, equation_len - left_opt_len - equation_sign_num) | |||
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num | |||
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type)) | |||
return ret | |||
def generator_equations_by_len(equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all"): | |||
ret = [] | |||
equation_sign_num = 2 # '+' and '=' | |||
for left_opt_len in range(1, equation_len - (2 + equation_sign_num) + 1): | |||
for right_opt_len in range(1, equation_len - left_opt_len - (1 + equation_sign_num) + 1): | |||
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num | |||
for i in range(repeat_times): #generate more equations | |||
if random.random() > keep ** (equation_len): | |||
continue | |||
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type)) | |||
return ret | |||
def generator_equations_by_max_len(max_equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all", num_per_len = None): | |||
ret = [] | |||
equation_sign_num = 2 # '+' and '=' | |||
for equation_len in range(3 + equation_sign_num, max_equation_len + 1): | |||
if (num_per_len is None): | |||
ret.extend(generator_equations_by_len(equation_len, system_num, label, repeat_times, keep, generate_type)) | |||
else: | |||
ret.extend(generator_equation_by_len(equation_len, system_num, label, require_num = num_per_len)) | |||
return ret | |||
def generator_equation_images(image_pools, equations, signs, shape, seed, is_color): | |||
if (seed is not None): | |||
random.seed(seed) | |||
ret = [] | |||
sign_num = len(signs) | |||
sign_index_dict = dict(zip(signs, list(range(sign_num)))) | |||
for equation in equations: | |||
data = [] | |||
for sign in equation: | |||
index = sign_index_dict[sign] | |||
pick = random.randint(0, len(image_pools[index]) - 1) | |||
if is_color: | |||
image = Image.open(image_pools[index][pick]).convert('RGB').resize(shape) | |||
else: | |||
image = Image.open(image_pools[index][pick]).convert('I').resize(shape) | |||
image_array = np.array(image) | |||
image_array = (image_array-127)*(1./128) | |||
data.append(image_array) | |||
ret.append(np.array(data)) | |||
return ret | |||
def get_equation_std_data(data_dir, sign_dir_lists, sign_output_lists, shape = (28, 28), train_max_equation_len = 10, test_max_equation_len = 10, system_num = 2, tmp_file_prev = | |||
None, seed = None, train_num_per_len = 10, test_num_per_len = 10, is_color = False): | |||
tmp_file = "" | |||
if (tmp_file_prev is not None): | |||
tmp_file = "%s_train_len_%d_test_len_%d_sys_%d_.pk" % (tmp_file_prev, train_max_equation_len, test_max_equation_len, system_num) | |||
if (os.path.exists(tmp_file)): | |||
return pickle.load(open(tmp_file, "rb")) | |||
image_pools = get_sign_path_list(data_dir, sign_dir_lists) | |||
train_pool, test_pool = split_pool_by_rate(image_pools, 0.8, seed) | |||
ret = {} | |||
for label in ["positive", "negative"]: | |||
print("Generating equations.") | |||
train_equations = generator_equations_by_max_len(train_max_equation_len, system_num, label, num_per_len = train_num_per_len) | |||
test_equations = generator_equations_by_max_len(test_max_equation_len, system_num, label, num_per_len = test_num_per_len) | |||
print(train_equations) | |||
print(test_equations) | |||
print("Generated equations.") | |||
print("Generating equation image data.") | |||
ret["train:%s" % (label)] = generator_equation_images(train_pool, train_equations, sign_output_lists, shape, seed, is_color) | |||
ret["test:%s" % (label)] = generator_equation_images(test_pool, test_equations, sign_output_lists, shape, seed, is_color) | |||
print("Generated equation image data.") | |||
if (tmp_file_prev is not None): | |||
pickle.dump(ret, open(tmp_file, "wb")) | |||
return ret | |||
if __name__ == "__main__": | |||
data_dirs = ["./dataset/hed/mnist_images", "./dataset/hed/random_images"] #, "../dataset/cifar10_images"] | |||
tmp_file_prevs = ["mnist_equation_data", "random_equation_data"] #, "cifar10_equation_data"] | |||
for data_dir, tmp_file_prev in zip(data_dirs, tmp_file_prevs): | |||
data = get_equation_std_data(data_dir = data_dir,\ | |||
sign_dir_lists = ['0', '1', '10', '11'],\ | |||
sign_output_lists = ['0', '1', '+', '='],\ | |||
shape = (28, 28),\ | |||
train_max_equation_len = 26, \ | |||
test_max_equation_len = 26, \ | |||
system_num = 2, \ | |||
tmp_file_prev = tmp_file_prev, \ | |||
train_num_per_len = 300, \ | |||
test_num_per_len = 300, \ | |||
is_color = False) |
@@ -81,7 +81,7 @@ def split_equation(equations_by_len, prop_train, prop_val): | |||
return train_equations_by_len, val_equations_by_len | |||
def get_hed(dataset="mnist", train=True): | |||
def get_dataset(dataset="mnist", train=True): | |||
if dataset == "mnist": | |||
file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk") | |||
elif dataset == "random": |
@@ -1,307 +0,0 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"import os.path as osp\n", | |||
"\n", | |||
"import numpy as np\n", | |||
"import torch\n", | |||
"import torch.nn as nn\n", | |||
"from zoopt import Dimension, Objective, Opt, Parameter\n", | |||
"\n", | |||
"from abl.evaluation import ReasoningMetric, SymbolMetric\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from abl.reasoning import PrologKB, Reasoner\n", | |||
"from abl.utils import ABLLogger, print_log, reform_list\n", | |||
"from examples.hed.datasets.get_hed import get_hed, split_equation\n", | |||
"from examples.hed.hed_bridge import HEDBridge\n", | |||
"from examples.models.nn import SymbolNet" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Build logger\n", | |||
"print_log(\"Abductive Learning on the HED example.\", logger=\"current\")\n", | |||
"\n", | |||
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", | |||
"log_dir = ABLLogger.get_current_instance().log_dir\n", | |||
"weights_dir = osp.join(log_dir, \"weights\")" | |||
] | |||
}, | |||
{ | |||
"attachments": {}, | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Logic Part" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Initialize knowledge base and abducer\n", | |||
"class HedKB(PrologKB):\n", | |||
" def __init__(self, pseudo_label_list, pl_file):\n", | |||
" super().__init__(pseudo_label_list, pl_file)\n", | |||
"\n", | |||
" def consist_rule(self, exs, rules):\n", | |||
" rules = str(rules).replace(\"'\", \"\")\n", | |||
" return len(list(self.prolog.query(\"eval_inst_feature(%s, %s).\" % (exs, rules)))) != 0\n", | |||
"\n", | |||
" def abduce_rules(self, pred_res):\n", | |||
" prolog_result = list(self.prolog.query(\"consistent_inst_feature(%s, X).\" % pred_res))\n", | |||
" if len(prolog_result) == 0:\n", | |||
" return None\n", | |||
" prolog_rules = prolog_result[0][\"X\"]\n", | |||
" rules = [rule.value for rule in prolog_rules]\n", | |||
" return rules\n", | |||
"\n", | |||
"\n", | |||
"class HedReasoner(Reasoner):\n", | |||
" def revise_at_idx(self, data_example):\n", | |||
" revision_idx = np.where(np.array(data_example.flatten(\"revision_flag\")) != 0)[0]\n", | |||
" candidate = self.kb.revise_at_idx(\n", | |||
" data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx\n", | |||
" )\n", | |||
" return candidate\n", | |||
"\n", | |||
" def zoopt_revision_score(self, symbol_num, data_example, sol):\n", | |||
" revision_flag = reform_list(\n", | |||
" list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label\n", | |||
" )\n", | |||
" data_example.revision_flag = revision_flag\n", | |||
"\n", | |||
" lefted_idxs = [i for i in range(len(data_example.pred_idx))]\n", | |||
" candidate_size = []\n", | |||
" max_consistent_idxs = []\n", | |||
" while lefted_idxs:\n", | |||
" idxs = []\n", | |||
" idxs.append(lefted_idxs.pop(0))\n", | |||
" max_candidate_idxs = []\n", | |||
" found = False\n", | |||
" for idx in range(-1, len(data_example.pred_idx)):\n", | |||
" if (not idx in idxs) and (idx >= 0):\n", | |||
" idxs.append(idx)\n", | |||
" candidates, _ = self.revise_at_idx(data_example[idxs])\n", | |||
" if len(candidates) == 0:\n", | |||
" if len(idxs) > 1:\n", | |||
" idxs.pop()\n", | |||
" else:\n", | |||
" if len(idxs) > len(max_candidate_idxs):\n", | |||
" found = True\n", | |||
" max_candidate_idxs = idxs.copy()\n", | |||
" removed = [i for i in lefted_idxs if i in max_candidate_idxs]\n", | |||
" if found:\n", | |||
" removed.insert(0, idxs[0])\n", | |||
" candidate_size.append(len(removed))\n", | |||
" max_consistent_idxs = max_candidate_idxs.copy()\n", | |||
" lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]\n", | |||
" candidate_size.sort()\n", | |||
" score = 0\n", | |||
" import math\n", | |||
"\n", | |||
" for i in range(0, len(candidate_size)):\n", | |||
" score -= math.exp(-i) * candidate_size[i]\n", | |||
" return score, max_consistent_idxs\n", | |||
" \n", | |||
" def _zoopt_get_solution(self, symbol_num, data_example, max_revision_num):\n", | |||
" dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)\n", | |||
" objective = Objective(\n", | |||
" lambda sol: self.zoopt_revision_score(symbol_num, data_example, sol)[0],\n", | |||
" dim=dimension,\n", | |||
" constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),\n", | |||
" )\n", | |||
" parameter = Parameter(budget=200, intermediate_result=False, autoset=True)\n", | |||
" solution = Opt.min(objective, parameter)\n", | |||
" return solution\n", | |||
"\n", | |||
" def abduce(self, data_example):\n", | |||
" symbol_num = data_example.elements_num(\"pred_pseudo_label\")\n", | |||
" max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n", | |||
"\n", | |||
" solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)\n", | |||
" _, max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_example, solution)\n", | |||
"\n", | |||
" abduced_pseudo_label = [[] for _ in range(len(data_example))]\n", | |||
"\n", | |||
" if len(max_candidate_idxs) > 0:\n", | |||
" candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs])\n", | |||
" for i, idx in enumerate(max_candidate_idxs):\n", | |||
" abduced_pseudo_label[idx] = candidates[0][i]\n", | |||
" data_example.abduced_pseudo_label = abduced_pseudo_label\n", | |||
" return abduced_pseudo_label\n", | |||
"\n", | |||
" def abduce_rules(self, pred_res):\n", | |||
" return self.kb.abduce_rules(pred_res)\n", | |||
"\n", | |||
"\n", | |||
"kb = HedKB(pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\")\n", | |||
"reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=10)" | |||
] | |||
}, | |||
{ | |||
"attachments": {}, | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Machine Learning Part" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Build necessary components for BasicNN\n", | |||
"cls = SymbolNet(num_classes=4)\n", | |||
"loss_fn = nn.CrossEntropyLoss()\n", | |||
"optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4)\n", | |||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Build BasicNN\n", | |||
"# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n", | |||
"base_model = BasicNN(\n", | |||
" cls,\n", | |||
" loss_fn,\n", | |||
" optimizer,\n", | |||
" device,\n", | |||
" batch_size=32,\n", | |||
" num_epochs=1,\n", | |||
" save_interval=1,\n", | |||
" stop_loss=None,\n", | |||
" save_dir=weights_dir,\n", | |||
")" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Build ABLModel\n", | |||
"# The main function of the ABL model is to serialize data and\n", | |||
"# provide a unified interface for different machine learning models\n", | |||
"model = ABLModel(base_model)" | |||
] | |||
}, | |||
{ | |||
"attachments": {}, | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Metric" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Set up metrics\n", | |||
"metric_list = [SymbolMetric(prefix=\"hed\"), ReasoningMetric(kb=kb, prefix=\"hed\")]" | |||
] | |||
}, | |||
{ | |||
"attachments": {}, | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Bridge Machine Learning and Logic Reasoning" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"bridge = HEDBridge(model, reasoner, metric_list)" | |||
] | |||
}, | |||
{ | |||
"attachments": {}, | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Dataset" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"total_train_data = get_hed(train=True)\n", | |||
"train_data, val_data = split_equation(total_train_data, 3, 1)\n", | |||
"test_data = get_hed(train=False)" | |||
] | |||
}, | |||
{ | |||
"attachments": {}, | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Train and Test" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"bridge.pretrain(\"./weights\")\n", | |||
"bridge.train(train_data, val_data)" | |||
] | |||
} | |||
], | |||
"metadata": { | |||
"kernelspec": { | |||
"display_name": "ABL", | |||
"language": "python", | |||
"name": "python3" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 3 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.8.18" | |||
}, | |||
"orig_nbformat": 4, | |||
"vscode": { | |||
"interpreter": { | |||
"hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58" | |||
} | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 2 | |||
} |
@@ -0,0 +1,3 @@ | |||
from .reasoning import HedKB, HedReasoner | |||
__all__ = ["HedKB", "HedReasoner"] |
@@ -0,0 +1,93 @@ | |||
import os | |||
import numpy as np | |||
import math | |||
from zoopt import Dimension, Objective, Opt, Parameter | |||
from abl.reasoning import PrologKB, Reasoner | |||
from abl.utils import reform_list | |||
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
class HedKB(PrologKB): | |||
def __init__(self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")): | |||
super().__init__(pseudo_label_list, pl_file) | |||
def consist_rule(self, exs, rules): | |||
rules = str(rules).replace("'", "") | |||
return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 | |||
def abduce_rules(self, pred_res): | |||
prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) | |||
if len(prolog_result) == 0: | |||
return None | |||
prolog_rules = prolog_result[0]["X"] | |||
rules = [rule.value for rule in prolog_rules] | |||
return rules | |||
class HedReasoner(Reasoner): | |||
def revise_at_idx(self, data_example): | |||
revision_idx = np.where(np.array(data_example.flatten("revision_flag")) != 0)[0] | |||
candidate = self.kb.revise_at_idx( | |||
data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx | |||
) | |||
return candidate | |||
def zoopt_revision_score(self, symbol_num, data_example, sol, get_score=True): | |||
revision_flag = reform_list( | |||
list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label | |||
) | |||
data_example.revision_flag = revision_flag | |||
lefted_idxs = [i for i in range(len(data_example.pred_idx))] | |||
candidate_size = [] | |||
max_consistent_idxs = [] | |||
while lefted_idxs: | |||
idxs = [] | |||
idxs.append(lefted_idxs.pop(0)) | |||
max_candidate_idxs = [] | |||
found = False | |||
for idx in range(-1, len(data_example.pred_idx)): | |||
if (not idx in idxs) and (idx >= 0): | |||
idxs.append(idx) | |||
candidates, _ = self.revise_at_idx(data_example[idxs]) | |||
if len(candidates) == 0: | |||
if len(idxs) > 1: | |||
idxs.pop() | |||
else: | |||
if len(idxs) > len(max_candidate_idxs): | |||
found = True | |||
max_candidate_idxs = idxs.copy() | |||
removed = [i for i in lefted_idxs if i in max_candidate_idxs] | |||
if found: | |||
removed.insert(0, idxs[0]) | |||
candidate_size.append(len(removed)) | |||
max_consistent_idxs = max_candidate_idxs.copy() | |||
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] | |||
candidate_size.sort() | |||
score = 0 | |||
for i in range(0, len(candidate_size)): | |||
score -= math.exp(-i) * candidate_size[i] | |||
if get_score: | |||
return score | |||
else: | |||
return max_consistent_idxs | |||
def abduce(self, data_example): | |||
symbol_num = data_example.elements_num("pred_pseudo_label") | |||
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) | |||
solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num) | |||
max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_example, solution, get_score=False) | |||
abduced_pseudo_label = [[] for _ in range(len(data_example))] | |||
if len(max_candidate_idxs) > 0: | |||
candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs]) | |||
for i, idx in enumerate(max_candidate_idxs): | |||
abduced_pseudo_label[idx] = candidates[0][i] | |||
data_example.abduced_pseudo_label = abduced_pseudo_label | |||
return abduced_pseudo_label | |||
def abduce_rules(self, pred_res): | |||
return self.kb.abduce_rules(pred_res) |
@@ -0,0 +1 @@ | |||
abl |