@@ -1,5 +1,298 @@ | |||||
Handwritten Equation Deciphering (HED) | |||||
====================================== | |||||
Handwritten Equation Decipherment (HED) | |||||
======================================= | |||||
.. contents:: Table of Contents | |||||
Below shows an implementation of `Handwritten Equation | |||||
Decipherment <https://proceedings.neurips.cc/paper_files/paper/2019/file/9c19a2aa1d84e04b0bd4bc888792bd1e-Paper.pdf>`__. | |||||
In this task, handwritten equations are given, each consisting of a | |||||
sequence of symbol images. The equations are generated with unknown | |||||
operation rules from images of symbols (‘0’, ‘1’, ‘+’ and ‘=’), and | |||||
each equation is associated with a label indicating whether the | |||||
equation is correct (i.e., positive) or not (i.e., negative). We are | |||||
also given a knowledge base which involves the structure of the | |||||
equations and a recursive definition of bit-wise operations. The task | |||||
is to learn from a training set of such equations and then to predict | |||||
the labels of unseen equations. | |||||
Intuitively, we first use a machine learning model (learning part) to | |||||
obtain the pseudo-labels (‘0’, ‘1’, ‘+’ and ‘=’) for the observed | |||||
pictures. We then use the knowledge base (reasoning part) to perform | |||||
abductive reasoning, yielding ground hypotheses as possible | |||||
explanations of the observed facts and suggesting some pseudo-labels to | |||||
be revised. This process enables us to further update the machine | |||||
learning model. | |||||
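Schematically, one iteration of this process can be sketched as follows | |||||
(a minimal illustration only; the names ``predict_pseudo_labels``, | |||||
``abduce`` and ``update`` are placeholders rather than this package’s | |||||
actual API): | |||||
.. code:: python | |||||
    # A conceptual sketch of one Abductive Learning iteration. | |||||
    # All methods here are placeholders, not the package's actual API. | |||||
    def abl_iteration(model, kb, equations): | |||||
        for images, label in equations: | |||||
            # Learning part: predict a pseudo-label for each symbol image | |||||
            pseudo_labels = model.predict_pseudo_labels(images) | |||||
            # Reasoning part: abduce a revised symbol sequence compatible | |||||
            # with the knowledge base and the equation's label | |||||
            revised = kb.abduce(pseudo_labels, label) | |||||
            # Update the learning part with the revised pseudo-labels | |||||
            model.update(images, revised) | |||||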
.. code:: ipython3 | |||||
# Import necessary libraries and modules | |||||
import os.path as osp | |||||
import torch | |||||
import torch.nn as nn | |||||
import matplotlib.pyplot as plt | |||||
from examples.hed.datasets import get_dataset, split_equation | |||||
from examples.models.nn import SymbolNet | |||||
from abl.learning import ABLModel, BasicNN | |||||
from examples.hed.reasoning import HedKB, HedReasoner | |||||
from abl.evaluation import ReasoningMetric, SymbolMetric | |||||
from abl.utils import ABLLogger, print_log | |||||
from examples.hed.bridge import HedBridge | |||||
Working with Data | |||||
----------------- | |||||
First, we get the datasets of handwritten equations: | |||||
.. code:: ipython3 | |||||
total_train_data = get_dataset(train=True) | |||||
train_data, val_data = split_equation(total_train_data, 3, 1) | |||||
test_data = get_dataset(train=False) | |||||
The datasets are shown below: | |||||
.. code:: ipython3 | |||||
true_train_equation = train_data[1] | |||||
false_train_equation = train_data[0] | |||||
print(f"Equations in the dataset is organized by equation length, " + | |||||
f"from {min(train_data[0].keys())} to {max(train_data[0].keys())}") | |||||
print() | |||||
true_train_equation_with_length_5 = true_train_equation[5] | |||||
false_train_equation_with_length_5 = false_train_equation[5] | |||||
print(f"For each euqation length, there are {len(true_train_equation_with_length_5)} " + | |||||
f"true equation and {len(false_train_equation_with_length_5)} false equation " + | |||||
f"in the training set") | |||||
true_val_equation = val_data[1] | |||||
false_val_equation = val_data[0] | |||||
true_val_equation_with_length_5 = true_val_equation[5] | |||||
false_val_equation_with_length_5 = false_val_equation[5] | |||||
print(f"For each euqation length, there are {len(true_val_equation_with_length_5)} " + | |||||
f"true equation and {len(false_val_equation_with_length_5)} false equation " + | |||||
f"in the validation set") | |||||
true_test_equation = test_data[1] | |||||
false_test_equation = test_data[0] | |||||
true_test_equation_with_length_5 = true_test_equation[5] | |||||
false_test_equation_with_length_5 = false_test_equation[5] | |||||
print(f"For each euqation length, there are {len(true_test_equation_with_length_5)} " + | |||||
f"true equation and {len(false_test_equation_with_length_5)} false equation " + | |||||
f"in the test set") | |||||
Out: | |||||
.. code:: none | |||||
:class: code-out | |||||
Equations in the dataset are organized by equation length, from 5 to 26 | |||||
For each equation length, there are 225 true equations and 225 false equations in the training set | |||||
For each equation length, there are 75 true equations and 75 false equations in the validation set | |||||
For each equation length, there are 300 true equations and 300 false equations in the test set | |||||
As illustrations, we show four equations in the training dataset: | |||||
.. code:: ipython3 | |||||
true_train_equation_with_length_5 = true_train_equation[5] | |||||
true_train_equation_with_length_8 = true_train_equation[8] | |||||
print(f"First true equation with length 5 in the training dataset:") | |||||
for i, x in enumerate(true_train_equation_with_length_5[0]): | |||||
plt.subplot(1, 5, i+1) | |||||
plt.axis('off') | |||||
plt.imshow(x.transpose(1, 2, 0)) | |||||
plt.show() | |||||
print(f"First true equation with length 8 in the training dataset:") | |||||
for i, x in enumerate(true_train_equation_with_length_8[0]): | |||||
plt.subplot(1, 8, i+1) | |||||
plt.axis('off') | |||||
plt.imshow(x.transpose(1, 2, 0)) | |||||
plt.show() | |||||
false_train_equation_with_length_5 = false_train_equation[5] | |||||
false_train_equation_with_length_8 = false_train_equation[8] | |||||
print(f"First false equation with length 5 in the training dataset:") | |||||
for i, x in enumerate(false_train_equation_with_length_5[0]): | |||||
plt.subplot(1, 5, i+1) | |||||
plt.axis('off') | |||||
plt.imshow(x.transpose(1, 2, 0)) | |||||
plt.show() | |||||
print(f"First false equation with length 8 in the training dataset:") | |||||
for i, x in enumerate(false_train_equation_with_length_8[0]): | |||||
plt.subplot(1, 8, i+1) | |||||
plt.axis('off') | |||||
plt.imshow(x.transpose(1, 2, 0)) | |||||
plt.show() | |||||
Out: | |||||
.. code:: none | |||||
:class: code-out | |||||
First true equation with length 5 in the training dataset: | |||||
.. image:: ../img/hed_dataset1.png | |||||
:width: 300px | |||||
Out: | |||||
.. code:: none | |||||
:class: code-out | |||||
First true equation with length 8 in the training dataset: | |||||
.. image:: ../img/hed_dataset2.png | |||||
:width: 480px | |||||
Out: | |||||
.. code:: none | |||||
:class: code-out | |||||
First false equation with length 5 in the training dataset: | |||||
.. image:: ../img/hed_dataset3.png | |||||
:width: 300px | |||||
Out: | |||||
.. code:: none | |||||
:class: code-out | |||||
First false equation with length 8 in the training dataset: | |||||
.. image:: ../img/hed_dataset4.png | |||||
:width: 480px | |||||
Building the Learning Part | |||||
-------------------------- | |||||
To build the learning part, we need to first build a machine learning | |||||
base model. We use SymbolNet, and encapsulate it within a ``BasicNN`` | |||||
object to create the base model. ``BasicNN`` is a class that | |||||
encapsulates a PyTorch model, transforming it into a base model with an | |||||
sklearn-style interface. | |||||
.. code:: ipython3 | |||||
# class of symbol may be one of ['0', '1', '+', '='], total of 4 classes | |||||
cls = SymbolNet(num_classes=4) | |||||
loss_fn = nn.CrossEntropyLoss() | |||||
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4) | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
base_model = BasicNN( | |||||
cls, | |||||
loss_fn, | |||||
optimizer, | |||||
device, | |||||
batch_size=32, | |||||
num_epochs=1, | |||||
stop_loss=None, | |||||
) | |||||
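For instance, the wrapped model can now be used much like a | |||||
scikit-learn estimator. The snippet below is only a sketch: it assumes | |||||
``fit`` and ``predict`` accept keyword arguments ``X`` and ``y`` holding | |||||
lists of individual symbol images, and the dummy data is fabricated | |||||
purely for illustration. | |||||
.. code:: python | |||||
    import numpy as np | |||||
    # Hypothetical instance-level usage of the sklearn-style interface | |||||
    dummy_X = [np.random.rand(1, 28, 28).astype(np.float32) for _ in range(8)] | |||||
    dummy_y = [0, 1, 2, 3, 0, 1, 2, 3]  # indices into ['0', '1', '+', '='] | |||||
    base_model.fit(X=dummy_X, y=dummy_y) | |||||
    predictions = base_model.predict(X=dummy_X) | |||||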
However, the base model built above deals with instance-level data | |||||
(i.e., individual images), and cannot directly deal with example-level | |||||
data (i.e., a list of images comprising the equation). Therefore, we | |||||
wrap the base model into ``ABLModel``, which enables the learning part | |||||
to train, test, and predict on example-level data. | |||||
.. code:: ipython3 | |||||
model = ABLModel(base_model) | |||||
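To make the distinction concrete, the following schematic (with strings | |||||
standing in for image arrays) shows what example-level data looks like: | |||||
.. code:: python | |||||
    # Schematic only: each example is one equation, i.e., a list of symbol | |||||
    # images, paired with an equation-level label. | |||||
    X = [["img_1", "img_+", "img_1", "img_=", "img_1"],  # a length-5 equation | |||||
         ["img_1", "img_0", "img_+", "img_1", "img_=", "img_1", "img_1"]] | |||||
    Y = [0, 1]  # 1: the equation is correct, 0: it is not | |||||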
Building the Reasoning Part | |||||
--------------------------- | |||||
In the reasoning part, we first build a knowledge base. As mentioned | |||||
before, the knowledge base in this task involves the structure of the | |||||
equations and a recursive definition of bit-wise operations. The | |||||
knowledge base is already defined in ``HedKB``, which is derived from | |||||
``PrologKB`` and built upon the Prolog files ``reasoning/BK.pl`` and | |||||
``reasoning/learn_add.pl``. | |||||
Specifically, the knowledge about the structure of equations (in | |||||
``reasoning/BK.pl``) is a set of DCG (definite clause grammar) rules, | |||||
which recursively define that a digit is a sequence of ‘0’s and ‘1’s, | |||||
and that equations share the structure X+Y=Z, though the lengths of X, | |||||
Y and Z can vary. The knowledge about bit-wise operations (in | |||||
``reasoning/learn_add.pl``) is a recursive logic program, which | |||||
calculates X+Y in reverse, i.e., it operates on X and Y digit by digit, | |||||
from the last digit to the first. | |||||
Note: the specific rules for calculating the operations are undefined | |||||
in the knowledge base, i.e., the result of ‘0+0’, ‘0+1’ and ‘1+1’ could | |||||
be ‘0’, ‘1’, ‘00’, ‘01’ or even ‘10’. These missing calculation rules | |||||
are required to be learned from the data. Therefore, ``HedKB`` | |||||
incorporates methods for abducing rules from data. Interested users can | |||||
refer to the specific implementation of ``HedKB`` in | |||||
``reasoning/reasoning.py``. | |||||
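To give a feel for what the knowledge base encodes, below is a minimal | |||||
Python sketch (an illustration only, not the actual Prolog program): it | |||||
adds digit strings digit-by-digit from the last digit to the first, and | |||||
is parameterized by a candidate rule table so that different abduced | |||||
rules can be plugged in. | |||||
.. code:: python | |||||
    # A minimal sketch of the recursive bit-wise addition the KB encodes. | |||||
    # RULES is one candidate rule table; abducing rules amounts to searching | |||||
    # for a table under which all training labels are satisfied. | |||||
    RULES = {("0", "0"): "0", ("0", "1"): "1", | |||||
             ("1", "0"): "1", ("1", "1"): "10"} | |||||
    def add(x, y, rules=RULES): | |||||
        """Add two digit strings, digit by digit from last to first.""" | |||||
        x, y = x.rjust(len(y), "0"), y.rjust(len(x), "0") | |||||
        result, carry = [], "" | |||||
        for a, b in zip(reversed(x), reversed(y)): | |||||
            s = rules[(a, b)] | |||||
            if carry:  # fold the pending carry into this position | |||||
                s = add(s, carry, rules) | |||||
            carry, digit = s[:-1], s[-1] | |||||
            result.append(digit) | |||||
        return (carry + "".join(reversed(result))).lstrip("0") or "0" | |||||
    # an equation X+Y=Z is positive exactly when add(X, Y) == Z | |||||
    assert add("110", "1") == "111" and add("11", "1") == "100" | |||||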
.. code:: ipython3 | |||||
kb = HedKB() | |||||
Then, we create a reasoner. Due to the indeterminism of abductive | |||||
reasoning, there could be multiple candidates compatible with the | |||||
knowledge base. When this happens, the reasoner minimizes | |||||
inconsistencies between the knowledge base and the pseudo-labels | |||||
predicted by the learning part, and returns only the candidate with | |||||
the highest consistency. | |||||
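For instance, with ``dist_func="hamming"`` (used below), consistency is | |||||
measured by how few predicted symbols must change. A simplified sketch | |||||
of the idea (illustration only, with made-up candidates): | |||||
.. code:: python | |||||
    # Pick the abduced candidate closest, in Hamming distance, to the | |||||
    # pseudo-labels predicted by the learning part. | |||||
    def hamming(seq1, seq2): | |||||
        return sum(a != b for a, b in zip(seq1, seq2)) | |||||
    predicted = ["1", "+", "1", "=", "1"] | |||||
    candidates = [["1", "+", "1", "=", "0"],  # revises one symbol | |||||
                  ["0", "+", "0", "=", "1"]]  # revises two symbols | |||||
    best = min(candidates, key=lambda c: hamming(predicted, c)) | |||||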
In this task, we create the reasoner by instantiating the class | |||||
``HedReasoner``, a reasoner derived from ``Reasoner`` and tailored | |||||
specifically for this task. ``HedReasoner`` leverages the `ZOOpt | |||||
library <https://github.com/polixir/ZOOpt>`__ for acceleration, and | |||||
implements a specific strategy to better harness ZOOpt’s capabilities. | |||||
Additionally, methods for abducing rules from data have been | |||||
incorporated. Interested users can refer to the specific implementation | |||||
of ``HedReasoner`` in ``reasoning/reasoning.py``. | |||||
.. code:: ipython3 | |||||
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=10) | |||||
Building Evaluation Metrics | |||||
--------------------------- | |||||
Next, we set up evaluation metrics. These metrics will be used to | |||||
evaluate the model performance during training and testing. | |||||
Specifically, we use ``SymbolMetric`` to evaluate the accuracy of the | |||||
machine learning model’s predictions, and ``ReasoningMetric`` to | |||||
evaluate the accuracy of the final reasoning results. | |||||
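Conceptually (a simplification, not the package implementation), | |||||
``SymbolMetric`` boils down to per-symbol accuracy against ground-truth | |||||
symbols where these are available: | |||||
.. code:: python | |||||
    # Schematic per-symbol accuracy, the quantity SymbolMetric reports | |||||
    def symbol_accuracy(pred_symbols, true_symbols): | |||||
        correct = sum(p == t for p, t in zip(pred_symbols, true_symbols)) | |||||
        return correct / len(true_symbols) | |||||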
.. code:: ipython3 | |||||
# Set up metrics | |||||
metric_list = [SymbolMetric(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")] | |||||
Bridge Learning and Reasoning | |||||
----------------------------- | |||||
Now, the last step is to bridge the learning and reasoning parts. We | |||||
do this by creating an instance of ``HedBridge``, which is derived | |||||
from ``SimpleBridge`` and tailored specifically for this task. | |||||
.. code:: ipython3 | |||||
bridge = HedBridge(model, reasoner, metric_list) | |||||
Finally, we perform training and testing. | |||||
**[TODO]** give a detailed introduction about training in HedBridge. | |||||
.. code:: ipython3 | |||||
# Build logger | |||||
print_log("Abductive Learning on the HED example.", logger="current") | |||||
# Retrieve the directory of the Log file and define the directory for saving the model weights. | |||||
log_dir = ABLLogger.get_current_instance().log_dir | |||||
weights_dir = osp.join(log_dir, "weights") | |||||
bridge.pretrain("./weights") | |||||
bridge.train(train_data, val_data) | |||||
bridge.test(test_data) |
@@ -2,7 +2,7 @@ Handwritten Formula (HWF) | |||||
========================= | ========================= | ||||
Below shows an implementation of `Handwritten | Below shows an implementation of `Handwritten | ||||
Formula <https://arxiv.org/abs/2006.06649>`__. In this task. In this | |||||
Formula <https://arxiv.org/abs/2006.06649>`__. In this | |||||
task, handwritten images of decimal formulas and their computed results | task, handwritten images of decimal formulas and their computed results | ||||
are given, along with a domain knowledge base containing information on | are given, along with a domain knowledge base containing information on
how to compute the decimal formula. The task is to recognize the symbols | how to compute the decimal formula. The task is to recognize the symbols | ||||
@@ -140,7 +140,7 @@ model with an sklearn-style interface. | |||||
cls = LeNet5(num_classes=10) | cls = LeNet5(num_classes=10) | ||||
loss_fn = nn.CrossEntropyLoss() | loss_fn = nn.CrossEntropyLoss() | ||||
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001) | |||||
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, alpha=0.9) | |||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||||
base_model = BasicNN( | base_model = BasicNN( | ||||
@@ -0,0 +1,2 @@ | |||||
ZOO | |||||
=== |
@@ -26,6 +26,7 @@ | |||||
Examples/MNISTAdd | Examples/MNISTAdd | ||||
Examples/HWF | Examples/HWF | ||||
Examples/HED | Examples/HED | ||||
Examples/ZOO | |||||
.. toctree:: | .. toctree:: | ||||
:maxdepth: 1 | :maxdepth: 1 | ||||
@@ -10,12 +10,12 @@ from abl.learning import ABLModel, BasicNN | |||||
from abl.reasoning import Reasoner | from abl.reasoning import Reasoner | ||||
from abl.structures import ListData | from abl.structures import ListData | ||||
from abl.utils import print_log | from abl.utils import print_log | ||||
from examples.hed.datasets.get_dataset import get_pretrain_data | |||||
from examples.hed.datasets import get_pretrain_data | |||||
from examples.hed.utils import InfiniteSampler, gen_mappings | from examples.hed.utils import InfiniteSampler, gen_mappings | ||||
from examples.models.nn import SymbolNetAutoencoder | from examples.models.nn import SymbolNetAutoencoder | ||||
class HEDBridge(SimpleBridge): | |||||
class HedBridge(SimpleBridge): | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
model: ABLModel, | model: ABLModel, | ||||
@@ -1,4 +0,0 @@ | |||||
Download the Handwritten Equation Decipherment dataset from [NJU Box](https://box.nju.edu.cn/f/391c2d48c32b436cb833/) to this folder and unzip it: | |||||
``` | |||||
unzip HED.zip | |||||
``` |
@@ -1,4 +1,4 @@ | |||||
from .get_dataset import get_dataset, split_equation | |||||
from .get_dataset import get_dataset, get_pretrain_data, split_equation | |||||
__all__ = ["get_dataset", "split_equation"] | |||||
__all__ = ["get_dataset", "get_pretrain_data", "split_equation"] |
@@ -0,0 +1,173 @@ | |||||
import os | |||||
import itertools | |||||
import random | |||||
import numpy as np | |||||
from PIL import Image | |||||
import pickle | |||||
def get_sign_path_list(data_dir, sign_names): | |||||
    # Collect image file paths for each sign, grouped by sign name | |||||
    sign_num = len(sign_names) | |||||
    index_dict = dict(zip(sign_names, list(range(sign_num)))) | |||||
    ret = [[] for _ in range(sign_num)] | |||||
    for path in os.listdir(data_dir): | |||||
        if path in sign_names: | |||||
            index = index_dict[path] | |||||
            sign_path = os.path.join(data_dir, path) | |||||
            for p in os.listdir(sign_path): | |||||
                ret[index].append(os.path.join(sign_path, p)) | |||||
    return ret | |||||
# Split each sign's image pool into two parts by the given rate | |||||
def split_pool_by_rate(pools, rate, seed=None): | |||||
if seed is not None: | |||||
random.seed(seed) | |||||
ret1 = [] | |||||
ret2 = [] | |||||
for pool in pools: | |||||
random.shuffle(pool) | |||||
num = int(len(pool) * rate) | |||||
ret1.append(pool[:num]) | |||||
ret2.append(pool[num:]) | |||||
return ret1, ret2 | |||||
# Convert a non-negative integer to its base-`system_num` string form | |||||
def int_to_system_form(num, system_num): | |||||
    if num == 0: | |||||
        return "0" | |||||
    ret = "" | |||||
    while num > 0: | |||||
        ret += str(num % system_num) | |||||
        num //= system_num | |||||
    return ret[::-1] | |||||
def generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type): | |||||
    expr_len = left_opt_len + right_opt_len | |||||
    num_list = "".join([str(i) for i in range(system_num)]) | |||||
    ret = [] | |||||
    if generate_type == "all": | |||||
        # Enumerate all digit assignments for the two operands (materialize | |||||
        # the product so it can be shuffled) | |||||
        candidates = list(itertools.product(num_list, repeat=expr_len)) | |||||
    else: | |||||
        candidates = [''.join(random.sample(['0', '1'] * expr_len, expr_len))] | |||||
    random.shuffle(candidates) | |||||
    for nums in candidates: | |||||
        left_num = "".join(nums[:left_opt_len]) | |||||
        right_num = "".join(nums[left_opt_len:]) | |||||
        left_value = int(left_num, system_num) | |||||
        right_value = int(right_num, system_num) | |||||
        result_value = left_value + right_value | |||||
        if label == 'negative': | |||||
            # Perturb the result; skip if it accidentally stays correct | |||||
            result_value += random.randint(-result_value, result_value) | |||||
            if left_value + right_value == result_value: | |||||
                continue | |||||
        result_num = int_to_system_form(result_value, system_num) | |||||
        # The result must have exactly the required length | |||||
        if res_opt_len != len(result_num): | |||||
            continue | |||||
        # Operands longer than one digit must not have leading zeros | |||||
        if (left_opt_len > 1 and left_num[0] == '0') or (right_opt_len > 1 and right_num[0] == '0'): | |||||
            continue | |||||
        ret.append(left_num + '+' + right_num + '=' + result_num)  # currently only '+' and '=' are considered | |||||
    return ret | |||||
def generator_equation_by_len(equation_len, system_num=2, label=0, require_num=1): | |||||
    generate_type = "one" | |||||
    ret = [] | |||||
    equation_sign_num = 2  # '+' and '=' | |||||
    while len(ret) < require_num: | |||||
        left_opt_len = random.randint(1, equation_len - 1 - equation_sign_num) | |||||
        right_opt_len = random.randint(1, equation_len - left_opt_len - equation_sign_num) | |||||
        res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num | |||||
        ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type)) | |||||
    return ret | |||||
def generator_equations_by_len(equation_len, system_num=2, label=0, repeat_times=1, keep=1, generate_type="all"): | |||||
    ret = [] | |||||
    equation_sign_num = 2  # '+' and '=' | |||||
    for left_opt_len in range(1, equation_len - (2 + equation_sign_num) + 1): | |||||
        for right_opt_len in range(1, equation_len - left_opt_len - (1 + equation_sign_num) + 1): | |||||
            res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num | |||||
            for i in range(repeat_times):  # generate more equations | |||||
                if random.random() > keep ** equation_len: | |||||
                    continue | |||||
                ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type)) | |||||
    return ret | |||||
def generator_equations_by_max_len(max_equation_len, system_num=2, label=0, repeat_times=1, keep=1, generate_type="all", num_per_len=None): | |||||
    ret = [] | |||||
    equation_sign_num = 2  # '+' and '=' | |||||
    for equation_len in range(3 + equation_sign_num, max_equation_len + 1): | |||||
        if num_per_len is None: | |||||
            ret.extend(generator_equations_by_len(equation_len, system_num, label, repeat_times, keep, generate_type)) | |||||
        else: | |||||
            ret.extend(generator_equation_by_len(equation_len, system_num, label, require_num=num_per_len)) | |||||
    return ret | |||||
def generator_equation_images(image_pools, equations, signs, shape, seed, is_color): | |||||
    # Render each symbolic equation as a sequence of randomly picked images | |||||
    if seed is not None: | |||||
        random.seed(seed) | |||||
    ret = [] | |||||
    sign_num = len(signs) | |||||
    sign_index_dict = dict(zip(signs, list(range(sign_num)))) | |||||
    for equation in equations: | |||||
        data = [] | |||||
        for sign in equation: | |||||
            index = sign_index_dict[sign] | |||||
            pick = random.randint(0, len(image_pools[index]) - 1) | |||||
            if is_color: | |||||
                image = Image.open(image_pools[index][pick]).convert('RGB').resize(shape) | |||||
            else: | |||||
                image = Image.open(image_pools[index][pick]).convert('I').resize(shape) | |||||
            image_array = np.array(image) | |||||
            # Normalize pixel values to roughly [-1, 1] | |||||
            image_array = (image_array - 127) * (1. / 128) | |||||
            data.append(image_array) | |||||
        ret.append(np.array(data)) | |||||
    return ret | |||||
def get_equation_std_data(data_dir, sign_dir_lists, sign_output_lists, shape=(28, 28), | |||||
                          train_max_equation_len=10, test_max_equation_len=10, system_num=2, | |||||
                          tmp_file_prev=None, seed=None, train_num_per_len=10, | |||||
                          test_num_per_len=10, is_color=False): | |||||
tmp_file = "" | |||||
if (tmp_file_prev is not None): | |||||
tmp_file = "%s_train_len_%d_test_len_%d_sys_%d_.pk" % (tmp_file_prev, train_max_equation_len, test_max_equation_len, system_num) | |||||
if (os.path.exists(tmp_file)): | |||||
return pickle.load(open(tmp_file, "rb")) | |||||
image_pools = get_sign_path_list(data_dir, sign_dir_lists) | |||||
train_pool, test_pool = split_pool_by_rate(image_pools, 0.8, seed) | |||||
ret = {} | |||||
for label in ["positive", "negative"]: | |||||
print("Generating equations.") | |||||
train_equations = generator_equations_by_max_len(train_max_equation_len, system_num, label, num_per_len = train_num_per_len) | |||||
test_equations = generator_equations_by_max_len(test_max_equation_len, system_num, label, num_per_len = test_num_per_len) | |||||
print(train_equations) | |||||
print(test_equations) | |||||
print("Generated equations.") | |||||
print("Generating equation image data.") | |||||
ret["train:%s" % (label)] = generator_equation_images(train_pool, train_equations, sign_output_lists, shape, seed, is_color) | |||||
ret["test:%s" % (label)] = generator_equation_images(test_pool, test_equations, sign_output_lists, shape, seed, is_color) | |||||
print("Generated equation image data.") | |||||
    if tmp_file_prev is not None: | |||||
        with open(tmp_file, "wb") as f: | |||||
            pickle.dump(ret, f) | |||||
    return ret | |||||
if __name__ == "__main__": | |||||
data_dirs = ["./dataset/hed/mnist_images", "./dataset/hed/random_images"] #, "../dataset/cifar10_images"] | |||||
tmp_file_prevs = ["mnist_equation_data", "random_equation_data"] #, "cifar10_equation_data"] | |||||
    for data_dir, tmp_file_prev in zip(data_dirs, tmp_file_prevs): | |||||
        data = get_equation_std_data(data_dir=data_dir, | |||||
                                     sign_dir_lists=['0', '1', '10', '11'], | |||||
                                     sign_output_lists=['0', '1', '+', '='], | |||||
                                     shape=(28, 28), | |||||
                                     train_max_equation_len=26, | |||||
                                     test_max_equation_len=26, | |||||
                                     system_num=2, | |||||
                                     tmp_file_prev=tmp_file_prev, | |||||
                                     train_num_per_len=300, | |||||
                                     test_num_per_len=300, | |||||
                                     is_color=False) | |||||
@@ -2,6 +2,8 @@ import os | |||||
import os.path as osp | import os.path as osp | ||||
import pickle | import pickle | ||||
import random | import random | ||||
import gdown | |||||
import zipfile | |||||
from collections import defaultdict | from collections import defaultdict | ||||
import cv2 | import cv2 | ||||
@@ -10,29 +12,16 @@ from torchvision.transforms import transforms | |||||
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | ||||
def get_data(img_dataset, train): | |||||
X, Y = [], [] | |||||
if train: | |||||
positive = img_dataset["train:positive"] | |||||
negative = img_dataset["train:negative"] | |||||
else: | |||||
positive = img_dataset["test:positive"] | |||||
negative = img_dataset["test:negative"] | |||||
for equation in positive: | |||||
equation = equation.astype(np.float32) | |||||
img_list = np.vsplit(equation, equation.shape[0]) | |||||
X.append(img_list) | |||||
Y.append(1) | |||||
for equation in negative: | |||||
equation = equation.astype(np.float32) | |||||
img_list = np.vsplit(equation, equation.shape[0]) | |||||
X.append(img_list) | |||||
Y.append(0) | |||||
return X, None, Y | |||||
def download_and_unzip(url, zip_file_name): | |||||
try: | |||||
gdown.download(url, zip_file_name) | |||||
with zipfile.ZipFile(zip_file_name, 'r') as zip_ref: | |||||
zip_ref.extractall(CURRENT_DIR) | |||||
os.remove(zip_file_name) | |||||
except Exception as e: | |||||
if os.path.exists(zip_file_name): | |||||
os.remove(zip_file_name) | |||||
raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in 'examples/hed/datasets' folder") | |||||
def get_pretrain_data(labels, image_size=(28, 28, 1)): | def get_pretrain_data(labels, image_size=(28, 28, 1)): | ||||
@@ -82,6 +71,19 @@ def split_equation(equations_by_len, prop_train, prop_val): | |||||
def get_dataset(dataset="mnist", train=True): | def get_dataset(dataset="mnist", train=True): | ||||
    data_dir = CURRENT_DIR + '/mnist_images' | |||||
    if not os.path.exists(data_dir): | |||||
        print("Dataset does not exist, downloading it...") | |||||
        url = 'https://drive.google.com/u/0/uc?id=1XoJDjO3cNUdytqVgXUKOBe9dOcUBobom&export=download' | |||||
        download_and_unzip(url, os.path.join(CURRENT_DIR, "HED.zip")) | |||||
        print("Download and extraction complete.") | |||||
if dataset == "mnist": | if dataset == "mnist": | ||||
file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk") | file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk") | ||||
elif dataset == "random": | elif dataset == "random": | ||||
@@ -91,11 +93,27 @@ def get_dataset(dataset="mnist", train=True): | |||||
with open(file, "rb") as f: | with open(file, "rb") as f: | ||||
img_dataset = pickle.load(f) | img_dataset = pickle.load(f) | ||||
X, _, Y = get_data(img_dataset, train) | |||||
equations_by_len = divide_equations_by_len(X, Y) | |||||
X, Y = [], [] | |||||
if train: | |||||
positive = img_dataset["train:positive"] | |||||
negative = img_dataset["train:negative"] | |||||
else: | |||||
positive = img_dataset["test:positive"] | |||||
negative = img_dataset["test:negative"] | |||||
return equations_by_len | |||||
for equation in positive: | |||||
equation = equation.astype(np.float32) | |||||
img_list = np.vsplit(equation, equation.shape[0]) | |||||
X.append(img_list) | |||||
Y.append(1) | |||||
for equation in negative: | |||||
equation = equation.astype(np.float32) | |||||
img_list = np.vsplit(equation, equation.shape[0]) | |||||
X.append(img_list) | |||||
Y.append(0) | |||||
equations_by_len = divide_equations_by_len(X, Y) | |||||
return equations_by_len | |||||
if __name__ == "__main__": | |||||
get_hed() |
@@ -1,7 +1,6 @@ | |||||
import os | import os | ||||
import numpy as np | import numpy as np | ||||
import math | import math | ||||
from zoopt import Dimension, Objective, Opt, Parameter | |||||
from abl.reasoning import PrologKB, Reasoner | from abl.reasoning import PrologKB, Reasoner | ||||
from abl.utils import reform_list | from abl.utils import reform_list | ||||
@@ -1 +1,2 @@ | |||||
abl | |||||
abl | |||||
gdown |
@@ -26,11 +26,9 @@ optional arguments: | |||||
--no-cuda disables CUDA training | --no-cuda disables CUDA training | ||||
--epochs EPOCHS number of epochs in each learning loop iteration | --epochs EPOCHS number of epochs in each learning loop iteration | ||||
(default : 1) | (default : 1) | ||||
--lr LR base learning rate (default : 0.001) | |||||
--weight-decay WEIGHT_DECAY | |||||
weight decay value (default : 0.03) | |||||
--lr LR base model learning rate (default : 0.001) | |||||
--batch-size BATCH_SIZE | --batch-size BATCH_SIZE | ||||
batch size (default : 32) | |||||
base model batch size (default : 32) | |||||
--loops LOOPS number of loop iterations (default : 5) | --loops LOOPS number of loop iterations (default : 5) | ||||
--segment_size SEGMENT_SIZE | --segment_size SEGMENT_SIZE | ||||
segment size (default : 1/3) | segment size (default : 1/3) | ||||
@@ -13,13 +13,13 @@ img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize( | |||||
def download_and_unzip(url, zip_file_name): | def download_and_unzip(url, zip_file_name): | ||||
try: | try: | ||||
gdown.download(url, zip_file_name) | gdown.download(url, zip_file_name) | ||||
with zipfile.pseudo_labelipFile(zip_file_name, 'r') as zip_ref: | |||||
with zipfile.ZipFile(zip_file_name, 'r') as zip_ref: | |||||
zip_ref.extractall(CURRENT_DIR) | zip_ref.extractall(CURRENT_DIR) | ||||
os.remove(zip_file_name) | os.remove(zip_file_name) | ||||
except Exception as e: | except Exception as e: | ||||
if os.path.exists(zip_file_name): | if os.path.exists(zip_file_name): | ||||
os.remove(zip_file_name) | os.remove(zip_file_name) | ||||
raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in './datasets' folder") | |||||
raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in 'examples/hwf/datasets' folder") | |||||
def get_dataset(train=True, get_pseudo_label=False): | def get_dataset(train=True, get_pseudo_label=False): | ||||
data_dir = CURRENT_DIR + '/data' | data_dir = CURRENT_DIR + '/data' | ||||
@@ -6,7 +6,7 @@ | |||||
"source": [ | "source": [ | ||||
"# Handwritten Formula (HWF)\n", | "# Handwritten Formula (HWF)\n", | ||||
"\n", | "\n", | ||||
"This notebook shows an implementation of [Handwritten Formula](https://arxiv.org/abs/2006.06649). In this task. In this task, handwritten images of decimal formulas and their computed results are given, alongwith a domain knowledge base containing information on how to compute the decimal formula. The task is to recognize the symbols (which can be digits or operators '+', '-', '×', '÷') of handwritten images and accurately determine their results.\n", | |||||
"This notebook shows an implementation of [Handwritten Formula](https://arxiv.org/abs/2006.06649). In this task, handwritten images of decimal formulas and their computed results are given, alongwith a domain knowledge base containing information on how to compute the decimal formula. The task is to recognize the symbols (which can be digits or operators '+', '-', '×', '÷') of handwritten images and accurately determine their results.\n", | |||||
"\n", | "\n", | ||||
"Intuitively, we first use a machine learning model (learning part) to convert the input images to symbols (we call them pseudo-labels), and then use the knowledge base (reasoning part) to calculate the results of these symbols. Since we do not have ground-truth of the symbols, in Abductive Learning, the reasoning part will leverage domain knowledge and revise the initial symbols yielded by the learning part through abductive reasoning. This process enables us to further update the machine learning model." | "Intuitively, we first use a machine learning model (learning part) to convert the input images to symbols (we call them pseudo-labels), and then use the knowledge base (reasoning part) to calculate the results of these symbols. Since we do not have ground-truth of the symbols, in Abductive Learning, the reasoning part will leverage domain knowledge and revise the initial symbols yielded by the learning part through abductive reasoning. This process enables us to further update the machine learning model." | ||||
] | ] | ||||
@@ -214,7 +214,7 @@ | |||||
"# class of symbol may be one of ['0', '1', ..., '9', '+', '-', '*', '/'], total of 14 classes\n", | "# class of symbol may be one of ['0', '1', ..., '9', '+', '-', '*', '/'], total of 14 classes\n", | ||||
"cls = SymbolNet(num_classes=14, image_size=(45, 45, 1))\n", | "cls = SymbolNet(num_classes=14, image_size=(45, 45, 1))\n", | ||||
"loss_fn = nn.CrossEntropyLoss()\n", | "loss_fn = nn.CrossEntropyLoss()\n", | ||||
"optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))\n", | |||||
"optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)\n", | |||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | ||||
"\n", | "\n", | ||||
"base_model = BasicNN(\n", | "base_model = BasicNN(\n", | ||||
@@ -68,11 +68,9 @@ def main(): | |||||
parser.add_argument('--epochs', type=int, default=3, | parser.add_argument('--epochs', type=int, default=3, | ||||
help='number of epochs in each learning loop iteration (default : 3)') | help='number of epochs in each learning loop iteration (default : 3)') | ||||
parser.add_argument('--lr', type=float, default=1e-3, | parser.add_argument('--lr', type=float, default=1e-3, | ||||
help='base learning rate (default : 0.001)') | |||||
parser.add_argument('--weight-decay', type=int, default=3e-2, | |||||
help='weight decay value (default : 0.03)') | |||||
help='base model learning rate (default : 0.001)') | |||||
parser.add_argument('--batch-size', type=int, default=128, | parser.add_argument('--batch-size', type=int, default=128, | ||||
help='batch size (default : 128)') | |||||
help='base model batch size (default : 128)') | |||||
parser.add_argument('--loops', type=int, default=5, | parser.add_argument('--loops', type=int, default=5, | ||||
help='number of loop iterations (default : 5)') | help='number of loop iterations (default : 5)') | ||||
parser.add_argument('--segment_size', type=int or float, default=1000, | parser.add_argument('--segment_size', type=int or float, default=1000, | ||||
@@ -12,8 +12,8 @@ python main.py | |||||
## Usage | ## Usage | ||||
```bash | ```bash | ||||
usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR] | |||||
[--weight-decay WEIGHT_DECAY] [--batch-size BATCH_SIZE] | |||||
usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR] | |||||
[--alpha ALPHA] [--batch-size BATCH_SIZE] | |||||
[--loops LOOPS] [--segment_size SEGMENT_SIZE] | [--loops LOOPS] [--segment_size SEGMENT_SIZE] | ||||
[--save_interval SAVE_INTERVAL] [--max-revision MAX_REVISION] | [--save_interval SAVE_INTERVAL] [--max-revision MAX_REVISION] | ||||
[--require-more-revision REQUIRE_MORE_REVISION] | [--require-more-revision REQUIRE_MORE_REVISION] | ||||
@@ -26,11 +26,10 @@ optional arguments: | |||||
--no-cuda disables CUDA training | --no-cuda disables CUDA training | ||||
--epochs EPOCHS number of epochs in each learning loop iteration | --epochs EPOCHS number of epochs in each learning loop iteration | ||||
(default : 1) | (default : 1) | ||||
--lr LR base learning rate (default : 0.001) | |||||
--weight-decay WEIGHT_DECAY | |||||
weight decay value (default : 0.03) | |||||
--lr LR base model learning rate (default : 0.001) | |||||
--alpha ALPHA alpha in RMSprop (default : 0.9) | |||||
--batch-size BATCH_SIZE | --batch-size BATCH_SIZE | ||||
batch size (default : 32) | |||||
base model batch size (default : 32) | |||||
--loops LOOPS number of loop iterations (default : 5) | --loops LOOPS number of loop iterations (default : 5) | ||||
--segment_size SEGMENT_SIZE | --segment_size SEGMENT_SIZE | ||||
segment size (default : 1/3) | segment size (default : 1/3) | ||||
@@ -34,11 +34,11 @@ def main(): | |||||
parser.add_argument('--epochs', type=int, default=1, | parser.add_argument('--epochs', type=int, default=1, | ||||
help='number of epochs in each learning loop iteration (default : 1)') | help='number of epochs in each learning loop iteration (default : 1)') | ||||
parser.add_argument('--lr', type=float, default=1e-3, | parser.add_argument('--lr', type=float, default=1e-3, | ||||
help='base learning rate (default : 0.001)') | |||||
parser.add_argument('--weight-decay', type=int, default=3e-2, | |||||
help='weight decay value (default : 0.03)') | |||||
help='base model learning rate (default : 0.001)') | |||||
parser.add_argument('--alpha', type=float, default=0.9, | |||||
help='alpha in RMSprop (default : 0.9)') | |||||
parser.add_argument('--batch-size', type=int, default=32, | parser.add_argument('--batch-size', type=int, default=32, | ||||
help='batch size (default : 32)') | |||||
help='base model batch size (default : 32)') | |||||
parser.add_argument('--loops', type=int, default=5, | parser.add_argument('--loops', type=int, default=5, | ||||
help='number of loop iterations (default : 5)') | help='number of loop iterations (default : 5)') | ||||
parser.add_argument('--segment_size', type=int or float, default=1/3, | parser.add_argument('--segment_size', type=int or float, default=1/3, | ||||
@@ -65,7 +65,7 @@ def main(): | |||||
# Build necessary components for BasicNN | # Build necessary components for BasicNN | ||||
cls = LeNet5(num_classes=10) | cls = LeNet5(num_classes=10) | ||||
loss_fn = nn.CrossEntropyLoss() | loss_fn = nn.CrossEntropyLoss() | ||||
optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr) | |||||
optimizer = torch.optim.RMSprop(cls.parameters(), lr=args.lr, alpha=args.alpha) | |||||
use_cuda = not args.no_cuda and torch.cuda.is_available() | use_cuda = not args.no_cuda and torch.cuda.is_available() | ||||
device = torch.device("cuda" if use_cuda else "cpu") | device = torch.device("cuda" if use_cuda else "cpu") | ||||
@@ -80,11 +80,6 @@ | |||||
} | } | ||||
], | ], | ||||
"source": [ | "source": [ | ||||
"def describe_structure(lst):\n", | |||||
" if not isinstance(lst, list):\n", | |||||
" return type(lst).__name__ \n", | |||||
" return [describe_structure(item) for item in lst]\n", | |||||
"\n", | |||||
"print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n", | "print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n", | ||||
"print()\n", | "print()\n", | ||||
"train_X, train_gt_pseudo_label, train_Y = train_data\n", | "train_X, train_gt_pseudo_label, train_Y = train_data\n", | ||||
@@ -357,7 +352,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": null, | |||||
"execution_count": 11, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -390,7 +385,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": null, | |||||
"execution_count": 12, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -402,14 +397,14 @@ | |||||
"cell_type": "markdown", | "cell_type": "markdown", | ||||
"metadata": {}, | "metadata": {}, | ||||
"source": [ | "source": [ | ||||
"## Bridge Learning and Reasoning\n", | |||||
"## Bridging Learning and Reasoning\n", | |||||
"\n", | "\n", | ||||
"Now, the last step is to bridge the learning and reasoning part. We proceed this step by creating an instance of `SimpleBridge`." | "Now, the last step is to bridge the learning and reasoning part. We proceed this step by creating an instance of `SimpleBridge`." | ||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": null, | |||||
"execution_count": 13, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -437,6 +432,13 @@ | |||||
"bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)\n", | "bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)\n", | ||||
"bridge.test(test_data)" | "bridge.test(test_data)" | ||||
] | ] | ||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | } | ||||
], | ], | ||||
"metadata": { | "metadata": { | ||||
@@ -455,7 +457,7 @@ | |||||
"name": "python", | "name": "python", | ||||
"nbconvert_exporter": "python", | "nbconvert_exporter": "python", | ||||
"pygments_lexer": "ipython3", | "pygments_lexer": "ipython3", | ||||
"version": "3.8.18" | |||||
"version": "3.8.13" | |||||
}, | }, | ||||
"orig_nbformat": 4, | "orig_nbformat": 4, | ||||
"vscode": { | "vscode": { | ||||
@@ -0,0 +1,29 @@ | |||||
import numpy as np | |||||
import openml | |||||
# Function to load and preprocess the dataset | |||||
def load_and_preprocess_dataset(dataset_id): | |||||
    dataset = openml.datasets.get_dataset( | |||||
        dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False | |||||
    ) | |||||
    X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute) | |||||
    # Convert boolean attributes to ints and encode the target as class indices | |||||
    for col in X.select_dtypes(include='bool').columns: | |||||
        X[col] = X[col].astype(int) | |||||
    y = y.cat.codes.astype(int) | |||||
    X, y = X.to_numpy(), y.to_numpy() | |||||
    return X, y | |||||
# Function to split the data (one labeled example per class) | |||||
def split_dataset(X, y, test_size=0.3): | |||||
    # For each class: keep 1 labeled example; of the remaining len-1 examples, | |||||
    # (1 - test_size) go to the unlabeled set and test_size go to the test set | |||||
    label_indices, unlabel_indices, test_indices = [], [], [] | |||||
for class_label in np.unique(y): | |||||
idxs = np.where(y == class_label)[0] | |||||
np.random.shuffle(idxs) | |||||
n_train_unlabel = int((1-test_size)*(len(idxs)-1)) | |||||
label_indices.append(idxs[0]) | |||||
unlabel_indices.extend(idxs[1:1+n_train_unlabel]) | |||||
test_indices.extend(idxs[1+n_train_unlabel:]) | |||||
X_label, y_label = X[label_indices], y[label_indices] | |||||
X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices] | |||||
X_test, y_test = X[test_indices], y[test_indices] | |||||
return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test |
@@ -0,0 +1,80 @@ | |||||
from z3 import Solver, Int, If, Not, Implies, Sum, sat | |||||
import openml | |||||
from abl.reasoning import KBBase | |||||
class ZooKB(KBBase): | |||||
def __init__(self): | |||||
super().__init__(pseudo_label_list=list(range(7)), use_cache=False) | |||||
self.solver = Solver() | |||||
        # Load metadata of the Zoo dataset (attribute and class names) | |||||
        dataset = openml.datasets.get_dataset( | |||||
            dataset_id=62, download_data=False, download_qualities=False, download_features_meta_data=False | |||||
        ) | |||||
        X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute) | |||||
self.attribute_names = attribute_names | |||||
self.target_names = y.cat.categories.tolist() | |||||
print("Attribute names are: ", self.attribute_names) | |||||
print("Target names are: ", self.target_names) | |||||
# self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"] | |||||
# self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"] | |||||
        # Define an Int variable for each attribute and each target class | |||||
        # (alternatively, a dict could hold the variables and the rules | |||||
        # below could be written against it) | |||||
        for name in self.attribute_names + self.target_names: | |||||
            exec(f"globals()['{name}'] = Int('{name}')") | |||||
# Define rules | |||||
rules = [ | |||||
Implies(milk == 1, mammal == 1), | |||||
Implies(mammal == 1, milk == 1), | |||||
Implies(mammal == 1, backbone == 1), | |||||
Implies(mammal == 1, breathes == 1), | |||||
Implies(feathers == 1, bird == 1), | |||||
Implies(bird == 1, feathers == 1), | |||||
Implies(bird == 1, eggs == 1), | |||||
Implies(bird == 1, backbone == 1), | |||||
Implies(bird == 1, breathes == 1), | |||||
Implies(bird == 1, legs == 2), | |||||
Implies(bird == 1, tail == 1), | |||||
Implies(reptile == 1, backbone == 1), | |||||
Implies(reptile == 1, breathes == 1), | |||||
Implies(reptile == 1, tail == 1), | |||||
Implies(fish == 1, aquatic == 1), | |||||
Implies(fish == 1, toothed == 1), | |||||
Implies(fish == 1, backbone == 1), | |||||
Implies(fish == 1, Not(breathes == 1)), | |||||
Implies(fish == 1, fins == 1), | |||||
Implies(fish == 1, legs == 0), | |||||
Implies(fish == 1, tail == 1), | |||||
Implies(amphibian == 1, eggs == 1), | |||||
Implies(amphibian == 1, aquatic == 1), | |||||
Implies(amphibian == 1, backbone == 1), | |||||
Implies(amphibian == 1, breathes == 1), | |||||
Implies(amphibian == 1, legs == 4), | |||||
Implies(insect == 1, eggs == 1), | |||||
Implies(insect == 1, Not(backbone == 1)), | |||||
Implies(insect == 1, legs == 6), | |||||
Implies(invertebrate == 1, Not(backbone == 1)) | |||||
] | |||||
# Define weights and sum of violated weights | |||||
self.weights = {rule: 1 for rule in rules} | |||||
self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights]) | |||||
    def logic_forward(self, pseudo_label, data_point): | |||||
        # Deductive reasoning: fix each attribute variable to the data point's | |||||
        # value and each class indicator to the pseudo-label, then return the | |||||
        # total weight of violated rules (0 means fully consistent) | |||||
        attribute_names, target_names = self.attribute_names, self.target_names | |||||
        solver = self.solver | |||||
        total_violation_weight = self.total_violation_weight | |||||
        pseudo_label, data_point = pseudo_label[0], data_point[0] | |||||
self.solver.reset() | |||||
for name, value in zip(attribute_names, data_point): | |||||
solver.add(eval(f"{name} == {value}")) | |||||
for cate, name in zip(self.pseudo_label_list,target_names): | |||||
value = 1 if (cate == pseudo_label) else 0 | |||||
solver.add(eval(f"{name} == {value}")) | |||||
if solver.check() == sat: | |||||
model = solver.model() | |||||
total_weight = model.evaluate(total_violation_weight) | |||||
return total_weight.as_long() | |||||
else: | |||||
# No solution found | |||||
return 1e10 |
@@ -0,0 +1,4 @@ | |||||
abl | |||||
z3-solver | |||||
openml | |||||
scikit-learn |
@@ -0,0 +1,370 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# ZOO\n", | |||||
"\n", | |||||
"This notebook shows an implementation of [MNIST Addition](https://arxiv.org/abs/1805.10872). In this task, pairs of MNIST handwritten images and their sums are given, alongwith a domain knowledge base containing information on how to perform addition operations. The task is to recognize the digits of handwritten images and accurately determine their sum.\n", | |||||
"\n", | |||||
"Intuitively, we first use a machine learning model (learning part) to convert the input images to digits (we call them pseudo-labels), and then use the knowledge base (reasoning part) to calculate the sum of these digits. Since we do not have ground-truth of the digits, in Abductive Learning, the reasoning part will leverage domain knowledge and revise the initial digits yielded by the learning part through abductive reasoning. This process enables us to further update the machine learning model." | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Import necessary libraries and modules\n", | |||||
"import os.path as osp\n", | |||||
"import numpy as np\n", | |||||
"from sklearn.ensemble import RandomForestClassifier\n", | |||||
"from examples.zoo.get_dataset import load_and_preprocess_dataset, split_dataset\n", | |||||
"from abl.learning import ABLModel\n", | |||||
"from examples.zoo.kb import ZooKB\n", | |||||
"from abl.reasoning import Reasoner\n", | |||||
"from abl.evaluation import ReasoningMetric, SymbolMetric\n", | |||||
"from abl.utils import ABLLogger, print_log, confidence_dist\n", | |||||
"from abl.bridge import SimpleBridge" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## Working with Data\n", | |||||
"\n", | |||||
"First, we get the training and testing datasets:" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Load and preprocess the Zoo dataset\n", | |||||
"X, y = load_and_preprocess_dataset(dataset_id=62)\n", | |||||
"\n", | |||||
"# Split data into labeled/unlabeled/test data\n", | |||||
"X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"`train_data` and `test_data` share identical structures: tuples with three components: X (list where each element is a list of two images), gt_pseudo_label (list where each element is a list of two digits, i.e., pseudo-labels) and Y (list where each element is the sum of the two digits). The length and structures of datasets are illustrated as follows.\n", | |||||
"\n", | |||||
"Note: ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model." | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Shape of X and y: (101, 16) (101,)\n", | |||||
"First five elements of X:\n", | |||||
"[[True False False True False False True True True True False False 4\n", | |||||
" False False True]\n", | |||||
" [True False False True False False False True True True False False 4\n", | |||||
" True False True]\n", | |||||
" [False False True False False True True True True False False True 0\n", | |||||
" True False False]\n", | |||||
" [True False False True False False True True True True False False 4\n", | |||||
" False False True]\n", | |||||
" [True False False True False False True True True True False False 4\n", | |||||
" True False True]]\n", | |||||
"First five elements of y:\n", | |||||
"[0 0 3 0 0]\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"print(\"Shape of X and y:\", X.shape, y.shape)\n", | |||||
"print(\"First five elements of X:\")\n", | |||||
"print(X[:5])\n", | |||||
"print(\"First five elements of y:\")\n", | |||||
"print(y[:5])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Transform tabluar data to the format required by ABL-Package, which is a tuple of (X, gt_pseudo_label, Y)\n", | |||||
"\n", | |||||
"For tabular data in abl, each example contains a single instance (a row from the dataset).\n", | |||||
"\n", | |||||
"For these tabular data samples, the reasoning results are expected to be 0, indicating no rules are violated." | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"def transform_tab_data(X, y):\n", | |||||
" return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n", | |||||
"label_data = transform_tab_data(X_label, y_label)\n", | |||||
"test_data = transform_tab_data(X_test, y_test)\n", | |||||
"train_data = transform_tab_data(X_unlabel, y_unlabel)" | |||||
] | |||||
}, | |||||
{ | |||||
"attachments": {}, | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## Building the Learning Part" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"To build the learning part, we need to first build a machine learning base model. We use a [Random Forest](https://en.wikipedia.org/wiki/Random_forest) as the base model" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container 
{/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestClassifier()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">RandomForestClassifier</label><div class=\"sk-toggleable__content\"><pre>RandomForestClassifier()</pre></div></div></div></div></div>" | |||||
], | |||||
"text/plain": [ | |||||
"RandomForestClassifier()" | |||||
] | |||||
}, | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"base_model = RandomForestClassifier()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"However, the base model built above deals with instance-level data, and can not directly deal with example-level data. Therefore, we wrap the base model into `ABLModel`, which enables the learning part to train, test, and predict on example-level data." | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"model = ABLModel(base_model)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## Building the Reasoning Part" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"In the reasoning part, we first build a knowledge base which contain information on how to perform addition operations. We build it by creating a subclass of `KBBase`. In the derived subclass, we initialize the `pseudo_label_list` parameter specifying list of possible pseudo-labels, and override the `logic_forward` function defining how to perform (deductive) reasoning." | |||||
] | |||||
}, | |||||
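{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of what such a subclass looks like (the class below, `SketchKB`, is a hypothetical placeholder with a trivial rule; the actual `ZooKB` instantiated next additionally encodes the Zoo rules with a z3 solver):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from abl.reasoning import KBBase\n",
"\n",
"class SketchKB(KBBase):\n",
"    def __init__(self):\n",
"        # pseudo-labels are the 7 animal classes, encoded as integers 0-6\n",
"        super().__init__(pseudo_label_list=list(range(7)))\n",
"\n",
"    def logic_forward(self, pseudo_label, data_point):\n",
"        # deductive reasoning: map an example's pseudo-labels (and, here,\n",
"        # its attribute values) to a reasoning result; this placeholder\n",
"        # simply reports that no rule is violated\n",
"        return 0"
]
},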
{ | |||||
"cell_type": "code", | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Attribute names are: ['hair', 'feathers', 'eggs', 'milk', 'airborne', 'aquatic', 'predator', 'toothed', 'backbone', 'breathes', 'venomous', 'fins', 'legs', 'tail', 'domestic', 'catsize']\n", | |||||
"Target names are: ['mammal', 'bird', 'reptile', 'fish', 'amphibian', 'insect', 'invertebrate']\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"kb = ZooKB()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"The knowledge base can perform logical reasoning (both deductive reasoning and abductive reasoning). Below is an example of performing (deductive) reasoning, and users can refer to [Documentation]() for details of abductive reasoning." | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 10, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Reasoning result of pseudo-label example [1, 2] is 3.\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"pseudo_label = [0]\n", | |||||
"data_point = [np.array([1,0,0,1,0,0,1,1,1,1,0,0,4,0,0,1,1])]\n", | |||||
"print(kb.logic_forward(pseudo_label, data_point))\n", | |||||
"for x, y_item in zip(X, y):\n", | |||||
" print(x,y_item)\n", | |||||
" print(kb.logic_forward([y_item], [x]))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Note: In addition to building a knowledge base based on `KBBase`, we can also establish a knowledge base with a ground KB using `GroundKB`, or a knowledge base implemented based on Prolog files using `PrologKB`. The corresponding code for these implementations can be found in the `main.py` file. Those interested are encouraged to examine it for further insights." | |||||
] | |||||
}, | |||||
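{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, a Prolog-based knowledge base would be set up roughly as follows (a sketch only: the import path is assumed to mirror that of `KBBase`, and the rule-file name is hypothetical):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A sketch, assuming a Prolog rule file exists at this (hypothetical) path\n",
"from abl.reasoning import PrologKB\n",
"\n",
"prolog_kb = PrologKB(pseudo_label_list=list(range(7)), pl_file=\"zoo_rules.pl\")"
]
},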
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Then, we create a reasoner by instantiating the class ``Reasoner``. Due to the indeterminism of abductive reasoning, there could be multiple candidates compatible to the knowledge base. When this happens, reasoner can minimize inconsistencies between the knowledge base and pseudo-labels predicted by the learning part, and then return only one candidate that has the highest consistency." | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 11, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"def consitency(data_example, candidates, candidate_idxs, reasoning_results):\n", | |||||
" pred_prob = data_example.pred_prob\n", | |||||
" model_scores = confidence_dist(pred_prob, candidate_idxs)\n", | |||||
" rule_scores = np.array(reasoning_results)\n", | |||||
" scores = model_scores + rule_scores\n", | |||||
" return scores\n", | |||||
"\n", | |||||
"reasoner = Reasoner(kb, dist_func=consitency)" | |||||
] | |||||
}, | |||||
{ | |||||
"attachments": {}, | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## Building Evaluation Metrics" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Next, we set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing. Specifically, we use `SymbolMetric` and `ReasoningMetric`, which are used to evaluate the accuracy of the machine learning model’s predictions and the accuracy of the final reasoning results, respectively." | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 12, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"metric_list = [SymbolMetric(prefix=\"zoo\"), ReasoningMetric(kb=kb, prefix=\"zoo\")]" | |||||
] | |||||
}, | |||||
{ | |||||
"attachments": {}, | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## Bridging Learning and Reasoning\n", | |||||
"\n", | |||||
"Now, the last step is to bridge the learning and reasoning part. We proceed this step by creating an instance of `SimpleBridge`." | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 13, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"bridge = SimpleBridge(model, reasoner, metric_list)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Perform training and testing by invoking the `train` and `test` methods of `SimpleBridge`." | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Build logger\n", | |||||
"print_log(\"Abductive Learning on the ZOO example.\", logger=\"current\")\n", | |||||
"log_dir = ABLLogger.get_current_instance().log_dir\n", | |||||
"weights_dir = osp.join(log_dir, \"weights\")\n", | |||||
"\n", | |||||
"# Pre-train the machine learning model\n", | |||||
"base_model.fit(X_label, y_label)\n", | |||||
"\n", | |||||
"# Test the initial model\n", | |||||
"print(\"------- Test the initial model -----------\")\n", | |||||
"bridge.test(test_data)\n", | |||||
"print(\"------- Use ABL to train the model -----------\")\n", | |||||
"# Use ABL to train the model\n", | |||||
"bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)\n", | |||||
"print(\"------- Test the final model -----------\")\n", | |||||
"# Test the final model\n", | |||||
"bridge.test(test_data)" | |||||
] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "abl", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.13" | |||||
}, | |||||
"orig_nbformat": 4, | |||||
"vscode": { | |||||
"interpreter": { | |||||
"hash": "9c8d454494e49869a4ee4046edcac9a39ff683f7d38abf0769f648402670238e" | |||||
} | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -1,292 +0,0 @@
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"import os.path as osp\n", | |||||
"\n", | |||||
"import numpy as np\n", | |||||
"from sklearn.ensemble import RandomForestClassifier\n", | |||||
"from sklearn.metrics import accuracy_score\n", | |||||
"from z3 import Solver, Int, If, Not, Implies, Sum, sat\n", | |||||
"import openml\n", | |||||
"\n", | |||||
"from abl.learning import ABLModel\n", | |||||
"from abl.reasoning import KBBase, Reasoner\n", | |||||
"from abl.evaluation import ReasoningMetric, SymbolMetric\n", | |||||
"from abl.bridge import SimpleBridge\n", | |||||
"from abl.utils.utils import confidence_dist\n", | |||||
"from abl.utils import ABLLogger, print_log" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Build logger\n", | |||||
"print_log(\"Abductive Learning on the Zoo example.\", logger=\"current\")\n", | |||||
"\n", | |||||
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", | |||||
"log_dir = ABLLogger.get_current_instance().log_dir\n", | |||||
"weights_dir = osp.join(log_dir, \"weights\")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Learning Part" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"rf = RandomForestClassifier()\n", | |||||
"model = ABLModel(rf)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Logic Part" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"class ZooKB(KBBase):\n", | |||||
" def __init__(self):\n", | |||||
" super().__init__(pseudo_label_list=list(range(7)), use_cache=False)\n", | |||||
" \n", | |||||
" # Use z3 solver \n", | |||||
" self.solver = Solver()\n", | |||||
"\n", | |||||
" # Load information of Zoo dataset\n", | |||||
" dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False)\n", | |||||
" X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n", | |||||
" self.attribute_names = attribute_names\n", | |||||
" self.target_names = y.cat.categories.tolist()\n", | |||||
" \n", | |||||
" # Define variables\n", | |||||
" for name in self.attribute_names+self.target_names:\n", | |||||
" exec(f\"globals()['{name}'] = Int('{name}')\") ## or use dict to create var and modify rules\n", | |||||
" # Define rules\n", | |||||
" rules = [\n", | |||||
" Implies(milk == 1, mammal == 1),\n", | |||||
" Implies(mammal == 1, milk == 1),\n", | |||||
" Implies(mammal == 1, backbone == 1),\n", | |||||
" Implies(mammal == 1, breathes == 1),\n", | |||||
" Implies(feathers == 1, bird == 1),\n", | |||||
" Implies(bird == 1, feathers == 1),\n", | |||||
" Implies(bird == 1, eggs == 1),\n", | |||||
" Implies(bird == 1, backbone == 1),\n", | |||||
" Implies(bird == 1, breathes == 1),\n", | |||||
" Implies(bird == 1, legs == 2),\n", | |||||
" Implies(bird == 1, tail == 1),\n", | |||||
" Implies(reptile == 1, backbone == 1),\n", | |||||
" Implies(reptile == 1, breathes == 1),\n", | |||||
" Implies(reptile == 1, tail == 1),\n", | |||||
" Implies(fish == 1, aquatic == 1),\n", | |||||
" Implies(fish == 1, toothed == 1),\n", | |||||
" Implies(fish == 1, backbone == 1),\n", | |||||
" Implies(fish == 1, Not(breathes == 1)),\n", | |||||
" Implies(fish == 1, fins == 1),\n", | |||||
" Implies(fish == 1, legs == 0),\n", | |||||
" Implies(fish == 1, tail == 1),\n", | |||||
" Implies(amphibian == 1, eggs == 1),\n", | |||||
" Implies(amphibian == 1, aquatic == 1),\n", | |||||
" Implies(amphibian == 1, backbone == 1),\n", | |||||
" Implies(amphibian == 1, breathes == 1),\n", | |||||
" Implies(amphibian == 1, legs == 4),\n", | |||||
" Implies(insect == 1, eggs == 1),\n", | |||||
" Implies(insect == 1, Not(backbone == 1)),\n", | |||||
" Implies(insect == 1, legs == 6),\n", | |||||
" Implies(invertebrate == 1, Not(backbone == 1))\n", | |||||
" ]\n", | |||||
" # Define weights and sum of violated weights\n", | |||||
" self.weights = {rule: 1 for rule in rules}\n", | |||||
" self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights])\n", | |||||
" \n", | |||||
" def logic_forward(self, pseudo_label, data_point):\n", | |||||
" attribute_names, target_names = self.attribute_names, self.target_names\n", | |||||
" solver = self.solver\n", | |||||
" total_violation_weight = self.total_violation_weight\n", | |||||
" pseudo_label, data_point = pseudo_label[0], data_point[0]\n", | |||||
" \n", | |||||
" self.solver.reset()\n", | |||||
" for name, value in zip(attribute_names, data_point):\n", | |||||
" solver.add(eval(f\"{name} == {value}\"))\n", | |||||
" for cate, name in zip(self.pseudo_label_list,target_names):\n", | |||||
" value = 1 if (cate == pseudo_label) else 0\n", | |||||
" solver.add(eval(f\"{name} == {value}\"))\n", | |||||
" \n", | |||||
" if solver.check() == sat:\n", | |||||
" model = solver.model()\n", | |||||
" total_weight = model.evaluate(total_violation_weight)\n", | |||||
" return total_weight.as_long()\n", | |||||
" else:\n", | |||||
" # No solution found\n", | |||||
" return 1e10\n", | |||||
" \n", | |||||
"def consitency(data_example, candidates, candidate_idxs, reasoning_results):\n", | |||||
" pred_prob = data_example.pred_prob\n", | |||||
" model_scores = confidence_dist(pred_prob, candidate_idxs)\n", | |||||
" rule_scores = np.array(reasoning_results)\n", | |||||
" scores = model_scores + rule_scores\n", | |||||
" return scores\n", | |||||
"\n", | |||||
"kb = ZooKB()\n", | |||||
"reasoner = Reasoner(kb, dist_func=consitency)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Datasets and Evaluation Metrics" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Function to load and preprocess the dataset\n", | |||||
"def load_and_preprocess_dataset(dataset_id):\n", | |||||
" dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)\n", | |||||
" X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n", | |||||
" # Convert data types\n", | |||||
" for col in X.select_dtypes(include='bool').columns:\n", | |||||
" X[col] = X[col].astype(int)\n", | |||||
" y = y.cat.codes.astype(int)\n", | |||||
" X, y = X.to_numpy(), y.to_numpy()\n", | |||||
" return X, y\n", | |||||
"\n", | |||||
"# Function to split data (one shot)\n", | |||||
"def split_dataset(X, y, test_size = 0.3):\n", | |||||
" # For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1)\n", | |||||
" label_indices, unlabel_indices, test_indices = [], [], []\n", | |||||
" for class_label in np.unique(y):\n", | |||||
" idxs = np.where(y == class_label)[0]\n", | |||||
" np.random.shuffle(idxs)\n", | |||||
" n_train_unlabel = int((1-test_size)*(len(idxs)-1))\n", | |||||
" label_indices.append(idxs[0])\n", | |||||
" unlabel_indices.extend(idxs[1:1+n_train_unlabel])\n", | |||||
" test_indices.extend(idxs[1+n_train_unlabel:])\n", | |||||
" X_label, y_label = X[label_indices], y[label_indices]\n", | |||||
" X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]\n", | |||||
" X_test, y_test = X[test_indices], y[test_indices]\n", | |||||
" return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test\n", | |||||
"\n", | |||||
"# Load and preprocess the Zoo dataset\n", | |||||
"X, y = load_and_preprocess_dataset(dataset_id=62)\n", | |||||
"\n", | |||||
"# Split data into labeled/unlabeled/test data\n", | |||||
"X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)\n", | |||||
"\n", | |||||
"# Transform tabluar data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results)\n", | |||||
"# For tabular data in abl, each example contains a single instance (a row from the dataset).\n", | |||||
"# For these tabular data examples, the reasoning results are expected to be 0, indicating no rules are violated.\n", | |||||
"def transform_tab_data(X, y):\n", | |||||
" return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n", | |||||
"label_data = transform_tab_data(X_label, y_label)\n", | |||||
"test_data = transform_tab_data(X_test, y_test)\n", | |||||
"train_data = transform_tab_data(X_unlabel, y_unlabel)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Set up metrics\n", | |||||
"metric_list = [SymbolMetric(prefix=\"zoo\"), ReasoningMetric(kb=kb, prefix=\"zoo\")]" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Bridge Machine Learning and Logic Reasoning" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"bridge = SimpleBridge(model, reasoner, metric_list)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Train and Test" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Pre-train the machine learning model\n", | |||||
"rf.fit(X_label, y_label)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# Test the initial model\n", | |||||
"print(\"------- Test the initial model -----------\")\n", | |||||
"bridge.test(test_data)\n", | |||||
"print(\"------- Use ABL to train the model -----------\")\n", | |||||
"# Use ABL to train the model\n", | |||||
"bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)\n", | |||||
"print(\"------- Test the final model -----------\")\n", | |||||
"# Test the final model\n", | |||||
"bridge.test(test_data)" | |||||
] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "abl", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.13" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -215,7 +215,7 @@ def kb_hwf2():
 def kb_hed():
     kb = HedKB(
         pseudo_label_list=[1, 0, "+", "="],
-        pl_file="examples/hed/datasets/learn_add.pl",
+        pl_file="examples/hed/reasoning/learn_add.pl",
     )
     return kb

@@ -57,7 +57,7 @@ class TestPrologKB(object):
     def test_init_pl2(self, kb_hed):
         assert kb_hed.pseudo_label_list == [1, 0, "+", "="]
-        assert kb_hed.pl_file == "examples/hed/datasets/learn_add.pl"
+        assert kb_hed.pl_file == "examples/hed/reasoning/learn_add.pl"

     def test_prolog_file_not_exist(self):
         pseudo_label_list = [1, 2]