Browse Source

[ENH] add hed example

pull/1/head
troyyyyy 1 year ago
parent
commit
3154d85952
31 changed files with 1286 additions and 1101 deletions
  1. +296
    -3
      docs/Examples/HED.rst
  2. +1
    -1
      docs/Examples/HWF.rst
  3. +1
    -1
      docs/Examples/MNISTAdd.rst
  4. +2
    -0
      docs/Examples/ZOO.rst
  5. BIN
      docs/img/hed_dataset1.png
  6. BIN
      docs/img/hed_dataset2.png
  7. BIN
      docs/img/hed_dataset3.png
  8. BIN
      docs/img/hed_dataset4.png
  9. +1
    -0
      docs/index.rst
  10. +2
    -2
      examples/hed/bridge.py
  11. +0
    -4
      examples/hed/datasets/README.md
  12. +2
    -2
      examples/hed/datasets/__init__.py
  13. +173
    -0
      examples/hed/datasets/equation_generator.py
  14. +46
    -28
      examples/hed/datasets/get_dataset.py
  15. +245
    -731
      examples/hed/hed.ipynb
  16. +0
    -1
      examples/hed/reasoning/reasoning.py
  17. +2
    -1
      examples/hed/requirements.txt
  18. +2
    -4
      examples/hwf/README.md
  19. +2
    -2
      examples/hwf/datasets/get_dataset.py
  20. +2
    -2
      examples/hwf/hwf.ipynb
  21. +2
    -4
      examples/hwf/main.py
  22. +5
    -6
      examples/mnist_add/README.md
  23. +5
    -5
      examples/mnist_add/main.py
  24. +12
    -10
      examples/mnist_add/mnist_add.ipynb
  25. +29
    -0
      examples/zoo/get_dataset.py
  26. +80
    -0
      examples/zoo/kb.py
  27. +4
    -0
      examples/zoo/requirements.txt
  28. +370
    -0
      examples/zoo/zoo.ipynb
  29. +0
    -292
      examples/zoo/zoo_example.ipynb
  30. +1
    -1
      tests/conftest.py
  31. +1
    -1
      tests/test_reasoning.py

+ 296
- 3
docs/Examples/HED.rst View File

@@ -1,5 +1,298 @@
Handwritten Equation Deciphering (HED)
======================================
Handwritten Equation Decipherment (HED)
=======================================

.. contents:: Table of Contents
Below shows an implementation of `Handwritten Equation
Decipherment <https://proceedings.neurips.cc/paper_files/paper/2019/file/9c19a2aa1d84e04b0bd4bc888792bd1e-Paper.pdf>`__.
In this task, the handwritten equations are given, which consist of
sequential pictures of characters. The equations are generated with
unknown operation rules from images of symbols (‘0’, ‘1’, ‘+’ and ‘=’),
and each equation is associated with a label indicating whether the
equation is correct (i.e., positive) or not (i.e., negative). Also, we
are given a knowledge base which involves the structure of the equations
and a recursive definition of bit-wise operations. The task is to learn
from a training set of above mentioned equations and then to predict
labels of unseen equations.

Intuitively, we first use a machine learning model (learning part) to
obtain the pseudo-labels (‘0’, ‘1’, ‘+’ and ‘=’) for the observed
pictures. We then use the knowledge base (reasoning part) to perform
abductive reasoning so as to yield ground hypotheses as possible
explanations to the observed facts, suggesting some pseudo-labels to be
revised. This process enables us to further update the machine learning
model.

.. code:: ipython3

# Import necessary libraries and modules
import os.path as osp
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from examples.hed.datasets import get_dataset, split_equation
from examples.models.nn import SymbolNet
from abl.learning import ABLModel, BasicNN
from examples.hed.reasoning import HedKB, HedReasoner
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.utils import ABLLogger, print_log
from examples.hed.bridge import HedBridge

Working with Data
-----------------

First, we get the datasets of handwritten equations:

.. code:: ipython3

total_train_data = get_dataset(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_dataset(train=False)

The dataset are shown below:

.. code:: ipython3

true_train_equation = train_data[1]
false_train_equation = train_data[0]
print(f"Equations in the dataset is organized by equation length, " +
f"from {min(train_data[0].keys())} to {max(train_data[0].keys())}")
print()
true_train_equation_with_length_5 = true_train_equation[5]
false_train_equation_with_length_5 = false_train_equation[5]
print(f"For each euqation length, there are {len(true_train_equation_with_length_5)} " +
f"true equation and {len(false_train_equation_with_length_5)} false equation " +
f"in the training set")
true_val_equation = val_data[1]
false_val_equation = val_data[0]
true_val_equation_with_length_5 = true_val_equation[5]
false_val_equation_with_length_5 = false_val_equation[5]
print(f"For each euqation length, there are {len(true_val_equation_with_length_5)} " +
f"true equation and {len(false_val_equation_with_length_5)} false equation " +
f"in the validation set")
true_test_equation = test_data[1]
false_test_equation = test_data[0]
true_test_equation_with_length_5 = true_test_equation[5]
false_test_equation_with_length_5 = false_test_equation[5]
print(f"For each euqation length, there are {len(true_test_equation_with_length_5)} " +
f"true equation and {len(false_test_equation_with_length_5)} false equation " +
f"in the test set")


Out:
.. code:: none
:class: code-out

Equations in the dataset is organized by equation length, from 5 to 26
For each euqation length, there are 225 true equation and 225 false equation in the training set
For each euqation length, there are 75 true equation and 75 false equation in the validation set
For each euqation length, there are 300 true equation and 300 false equation in the test set

As illustrations, we show four equations in the training dataset:

.. code:: ipython3

true_train_equation_with_length_5 = true_train_equation[5]
true_train_equation_with_length_8 = true_train_equation[8]
print(f"First true equation with length 5 in the training dataset:")
for i, x in enumerate(true_train_equation_with_length_5[0]):
plt.subplot(1, 5, i+1)
plt.axis('off')
plt.imshow(x.transpose(1, 2, 0))
plt.show()
print(f"First true equation with length 8 in the training dataset:")
for i, x in enumerate(true_train_equation_with_length_8[0]):
plt.subplot(1, 8, i+1)
plt.axis('off')
plt.imshow(x.transpose(1, 2, 0))
plt.show()
false_train_equation_with_length_5 = false_train_equation[5]
false_train_equation_with_length_8 = false_train_equation[8]
print(f"First false equation with length 5 in the training dataset:")
for i, x in enumerate(false_train_equation_with_length_5[0]):
plt.subplot(1, 5, i+1)
plt.axis('off')
plt.imshow(x.transpose(1, 2, 0))
plt.show()
print(f"First false equation with length 8 in the training dataset:")
for i, x in enumerate(false_train_equation_with_length_8[0]):
plt.subplot(1, 8, i+1)
plt.axis('off')
plt.imshow(x.transpose(1, 2, 0))
plt.show()


Out:
.. code:: none
:class: code-out

First true equation with length 5 in the training dataset:
.. image:: ../img/hed_dataset1.png
:width: 300px


Out:
.. code:: none
:class: code-out

First true equation with length 8 in the training dataset:
.. image:: ../img/hed_dataset2.png
:width: 480px


Out:
.. code:: none
:class: code-out

First false equation with length 5 in the training dataset:
.. image:: ../img/hed_dataset3.png
:width: 300px


Out:
.. code:: none
:class: code-out

First false equation with length 8 in the training dataset:
.. image:: ../img/hed_dataset4.png
:width: 480px


Building the Learning Part
--------------------------

To build the learning part, we need to first build a machine learning
base model. We use SymbolNet, and encapsulate it within a ``BasicNN``
object to create the base model. ``BasicNN`` is a class that
encapsulates a PyTorch model, transforming it into a base model with an
sklearn-style interface.

.. code:: ipython3

# class of symbol may be one of ['0', '1', '+', '='], total of 4 classes
cls = SymbolNet(num_classes=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = BasicNN(
cls,
loss_fn,
optimizer,
device,
batch_size=32,
num_epochs=1,
stop_loss=None,
)

However, the base model built above deals with instance-level data
(i.e., individual images), and can not directly deal with example-level
data (i.e., a list of images comprising the equation). Therefore, we
wrap the base model into ``ABLModel``, which enables the learning part
to train, test, and predict on example-level data.

.. code:: ipython3

model = ABLModel(base_model)

Building the Reasoning Part
---------------------------

In the reasoning part, we first build a knowledge base. As mentioned
before, the knowledge base in this task involves the structure of the
equations and a recursive definition of bit-wise operations. The
knowledge base is already defined in ``HedKB``, which is derived from
``PrologKB``, and is built upon Prolog file ``reasoning/BK.pl`` and
``reasoning/learn_add.pl``.

Specifically, the knowledge about the structure of equations (in
``reasoning/BK.pl``) is a set of DCG (definite clause grammar) rules
recursively define that a digit is a sequence of ‘0’ and ‘1’, and
equations share the structure of X+Y=Z, though the length of X, Y and Z
can be varied. The knowledge about bit-wise operations (in
``reasoning/learn_add.pl``) is a recursive logic program, which
reversely calculates X+Y, i.e., it operates on X and Y digit-by-digit
and from the last digit to the first.

Note: Please notice that, the specific rules for calculating the
operations are undefined in the knowledge base, i.e., results of ‘0+0’,
‘0+1’ and ‘1+1’ could be ‘0’, ‘1’, ‘00’, ‘01’ or even ‘10’. The missing
calculation rules are required to be learned from the data. Therefore,
``HedKB`` incorporates methods for abducing rules from data. Users
interested can refer to the specific implementation of ``HedKB`` in
``reasoning/reasoning.py``

.. code:: ipython3

kb = HedKB()

Then, we create a reasoner. Due to the indeterminism of abductive
reasoning, there could be multiple candidates compatible to the
knowledge base. When this happens, reasoner can minimize inconsistencies
between the knowledge base and pseudo-labels predicted by the learning
part, and then return only one candidate that has the highest
consistency.

In this task, we create the reasoner by instantiating the class
``HedReasoner``, which is a reasoner derived from ``Reasoner`` and
tailored specifically for this task. ``HedReasoner`` leverages `ZOOpt
library <https://github.com/polixir/ZOOpt>`__ for acceleration, and has
designed a specific strategy to better harness ZOOpt’s capabilities.
Additionally, methods for abducing rules from data have been
incorporated. Users interested can refer to the specific implementation
of ``HedReasoner`` in ``reasoning/reasoning.py``.

.. code:: ipython3

reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=10)

Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolMetric`` and ``ReasoningMetric``, which are
used to evaluate the accuracy of the machine learning model’s
predictions and the accuracy of the final reasoning results,
respectively.

.. code:: ipython3

# Set up metrics
metric_list = [SymbolMetric(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")]

Bridge Learning and Reasoning
-----------------------------

Now, the last step is to bridge the learning and reasoning part. We
proceed this step by creating an instance of ``HedBridge``, which is
derived from ``SimpleBridge`` and tailored specific for this task.

.. code:: ipython3

bridge = HedBridge(model, reasoner, metric_list)

Perform training and testing.

**[TODO]** give a detailed introduction about training in HedBridge.

.. code:: ipython3

# Build logger
print_log("Abductive Learning on the HED example.", logger="current")
# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
bridge.pretrain("./weights")
bridge.train(train_data, val_data)
bridge.test(test_data)

+ 1
- 1
docs/Examples/HWF.rst View File

@@ -2,7 +2,7 @@ Handwritten Formula (HWF)
=========================

Below shows an implementation of `Handwritten
Formula <https://arxiv.org/abs/2006.06649>`__. In this task. In this
Formula <https://arxiv.org/abs/2006.06649>`__. In this
task, handwritten images of decimal formulas and their computed results
are given, alongwith a domain knowledge base containing information on
how to compute the decimal formula. The task is to recognize the symbols


+ 1
- 1
docs/Examples/MNISTAdd.rst View File

@@ -140,7 +140,7 @@ model with an sklearn-style interface.

cls = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
base_model = BasicNN(


+ 2
- 0
docs/Examples/ZOO.rst View File

@@ -0,0 +1,2 @@
ZOO
===

BIN
docs/img/hed_dataset1.png View File

Before After
Width: 516  |  Height: 105  |  Size: 3.5 kB

BIN
docs/img/hed_dataset2.png View File

Before After
Width: 516  |  Height: 72  |  Size: 12 kB

BIN
docs/img/hed_dataset3.png View File

Before After
Width: 516  |  Height: 105  |  Size: 4.2 kB

BIN
docs/img/hed_dataset4.png View File

Before After
Width: 516  |  Height: 72  |  Size: 9.8 kB

+ 1
- 0
docs/index.rst View File

@@ -26,6 +26,7 @@
Examples/MNISTAdd
Examples/HWF
Examples/HED
Examples/ZOO

.. toctree::
:maxdepth: 1


+ 2
- 2
examples/hed/bridge.py View File

@@ -10,12 +10,12 @@ from abl.learning import ABLModel, BasicNN
from abl.reasoning import Reasoner
from abl.structures import ListData
from abl.utils import print_log
from examples.hed.datasets.get_dataset import get_pretrain_data
from examples.hed.datasets import get_pretrain_data
from examples.hed.utils import InfiniteSampler, gen_mappings
from examples.models.nn import SymbolNetAutoencoder


class HEDBridge(SimpleBridge):
class HedBridge(SimpleBridge):
def __init__(
self,
model: ABLModel,


+ 0
- 4
examples/hed/datasets/README.md View File

@@ -1,4 +0,0 @@
Download the Handwritten Equation Decipherment dataset from [NJU Box](https://box.nju.edu.cn/f/391c2d48c32b436cb833/) to this folder and unzip it:
```
unzip HED.zip
```

+ 2
- 2
examples/hed/datasets/__init__.py View File

@@ -1,4 +1,4 @@
from .get_dataset import get_dataset, split_equation
from .get_dataset import get_dataset, get_pretrain_data, split_equation


__all__ = ["get_dataset", "split_equation"]
__all__ = ["get_dataset", "get_pretrain_data", "split_equation"]

+ 173
- 0
examples/hed/datasets/equation_generator.py View File

@@ -0,0 +1,173 @@
import os
import itertools
import random
import numpy as np
from PIL import Image
import pickle

def get_sign_path_list(data_dir, sign_names):
sign_num = len(sign_names)
index_dict = dict(zip(sign_names, list(range(sign_num))))
ret = [[] for _ in range(sign_num)]
for path in os.listdir(data_dir):
if (path in sign_names):
index = index_dict[path]
sign_path = os.path.join(data_dir, path)
for p in os.listdir(sign_path):
ret[index].append(os.path.join(sign_path, p))
return ret

def split_pool_by_rate(pools, rate, seed = None):
if seed is not None:
random.seed(seed)
ret1 = []
ret2 = []
for pool in pools:
random.shuffle(pool)
num = int(len(pool) * rate)
ret1.append(pool[:num])
ret2.append(pool[num:])
return ret1, ret2

def int_to_system_form(num, system_num):
if num == 0:
return "0"
ret = ""
while (num > 0):
ret += str(num % system_num)
num //= system_num
return ret[::-1]

def generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type):
expr_len = left_opt_len + right_opt_len
num_list = "".join([str(i) for i in range(system_num)])
ret = []
if generate_type == "all":
candidates = itertools.product(num_list, repeat = expr_len)
else:
candidates = [''.join(random.sample(['0', '1'] * expr_len, expr_len))]
random.shuffle(candidates)
for nums in candidates:
left_num = "".join(nums[:left_opt_len])
right_num = "".join(nums[left_opt_len:])
left_value = int(left_num, system_num)
right_value = int(right_num, system_num)
result_value = left_value + right_value
if (label == 'negative'):
result_value += random.randint(-result_value, result_value)
if (left_value + right_value == result_value):
continue
result_num = int_to_system_form(result_value, system_num)
#leading zeros
if (res_opt_len != len(result_num)):
continue
if ((left_opt_len > 1 and left_num[0] == '0') or (right_opt_len > 1 and right_num[0] == '0')):
continue

#add leading zeros
if (res_opt_len < len(result_num)):
continue
while (len(result_num) < res_opt_len):
result_num = '0' + result_num
#continue
ret.append(left_num + '+' + right_num + '=' + result_num) # current only consider '+' and '='
#print(ret[-1])
return ret

def generator_equation_by_len(equation_len, system_num = 2, label = 0, require_num = 1):
generate_type = "one"
ret = []
equation_sign_num = 2 # '+' and '='
while len(ret) < require_num:
left_opt_len = random.randint(1, equation_len - 1 - equation_sign_num)
right_opt_len = random.randint(1, equation_len - left_opt_len - equation_sign_num)
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type))
return ret

def generator_equations_by_len(equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all"):
ret = []
equation_sign_num = 2 # '+' and '='
for left_opt_len in range(1, equation_len - (2 + equation_sign_num) + 1):
for right_opt_len in range(1, equation_len - left_opt_len - (1 + equation_sign_num) + 1):
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num
for i in range(repeat_times): #generate more equations
if random.random() > keep ** (equation_len):
continue
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type))
return ret

def generator_equations_by_max_len(max_equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all", num_per_len = None):
ret = []
equation_sign_num = 2 # '+' and '='
for equation_len in range(3 + equation_sign_num, max_equation_len + 1):
if (num_per_len is None):
ret.extend(generator_equations_by_len(equation_len, system_num, label, repeat_times, keep, generate_type))
else:
ret.extend(generator_equation_by_len(equation_len, system_num, label, require_num = num_per_len))
return ret

def generator_equation_images(image_pools, equations, signs, shape, seed, is_color):
if (seed is not None):
random.seed(seed)
ret = []
sign_num = len(signs)
sign_index_dict = dict(zip(signs, list(range(sign_num))))
for equation in equations:
data = []
for sign in equation:
index = sign_index_dict[sign]
pick = random.randint(0, len(image_pools[index]) - 1)
if is_color:
image = Image.open(image_pools[index][pick]).convert('RGB').resize(shape)
else:
image = Image.open(image_pools[index][pick]).convert('I').resize(shape)
image_array = np.array(image)
image_array = (image_array-127)*(1./128)
data.append(image_array)
ret.append(np.array(data))
return ret

def get_equation_std_data(data_dir, sign_dir_lists, sign_output_lists, shape = (28, 28), train_max_equation_len = 10, test_max_equation_len = 10, system_num = 2, tmp_file_prev =
None, seed = None, train_num_per_len = 10, test_num_per_len = 10, is_color = False):
tmp_file = ""
if (tmp_file_prev is not None):
tmp_file = "%s_train_len_%d_test_len_%d_sys_%d_.pk" % (tmp_file_prev, train_max_equation_len, test_max_equation_len, system_num)
if (os.path.exists(tmp_file)):
return pickle.load(open(tmp_file, "rb"))

image_pools = get_sign_path_list(data_dir, sign_dir_lists)
train_pool, test_pool = split_pool_by_rate(image_pools, 0.8, seed)

ret = {}
for label in ["positive", "negative"]:
print("Generating equations.")
train_equations = generator_equations_by_max_len(train_max_equation_len, system_num, label, num_per_len = train_num_per_len)
test_equations = generator_equations_by_max_len(test_max_equation_len, system_num, label, num_per_len = test_num_per_len)
print(train_equations)
print(test_equations)
print("Generated equations.")
print("Generating equation image data.")
ret["train:%s" % (label)] = generator_equation_images(train_pool, train_equations, sign_output_lists, shape, seed, is_color)
ret["test:%s" % (label)] = generator_equation_images(test_pool, test_equations, sign_output_lists, shape, seed, is_color)
print("Generated equation image data.")

if (tmp_file_prev is not None):
pickle.dump(ret, open(tmp_file, "wb"))
return ret

if __name__ == "__main__":
data_dirs = ["./dataset/hed/mnist_images", "./dataset/hed/random_images"] #, "../dataset/cifar10_images"]
tmp_file_prevs = ["mnist_equation_data", "random_equation_data"] #, "cifar10_equation_data"]
for data_dir, tmp_file_prev in zip(data_dirs, tmp_file_prevs):
data = get_equation_std_data(data_dir = data_dir,\
sign_dir_lists = ['0', '1', '10', '11'],\
sign_output_lists = ['0', '1', '+', '='],\
shape = (28, 28),\
train_max_equation_len = 26, \
test_max_equation_len = 26, \
system_num = 2, \
tmp_file_prev = tmp_file_prev, \
train_num_per_len = 300, \
test_num_per_len = 300, \
is_color = False)

+ 46
- 28
examples/hed/datasets/get_dataset.py View File

@@ -2,6 +2,8 @@ import os
import os.path as osp
import pickle
import random
import gdown
import zipfile
from collections import defaultdict

import cv2
@@ -10,29 +12,16 @@ from torchvision.transforms import transforms

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))


def get_data(img_dataset, train):
X, Y = [], []
if train:
positive = img_dataset["train:positive"]
negative = img_dataset["train:negative"]
else:
positive = img_dataset["test:positive"]
negative = img_dataset["test:negative"]

for equation in positive:
equation = equation.astype(np.float32)
img_list = np.vsplit(equation, equation.shape[0])
X.append(img_list)
Y.append(1)

for equation in negative:
equation = equation.astype(np.float32)
img_list = np.vsplit(equation, equation.shape[0])
X.append(img_list)
Y.append(0)

return X, None, Y
def download_and_unzip(url, zip_file_name):
try:
gdown.download(url, zip_file_name)
with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
zip_ref.extractall(CURRENT_DIR)
os.remove(zip_file_name)
except Exception as e:
if os.path.exists(zip_file_name):
os.remove(zip_file_name)
raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in 'examples/hed/datasets' folder")


def get_pretrain_data(labels, image_size=(28, 28, 1)):
@@ -82,6 +71,19 @@ def split_equation(equations_by_len, prop_train, prop_val):


def get_dataset(dataset="mnist", train=True):
data_dir = CURRENT_DIR + '/mnist_images'
if not os.path.exists(data_dir):
print("Dataset not exist, downloading it...")
url = 'https://drive.google.com/u/0/uc?id=1XoJDjO3cNUdytqVgXUKOBe9dOcUBobom&export=download'
download_and_unzip(url, os.path.join(CURRENT_DIR, "HED.zip"))
print("Download and extraction complete.")
if train:
file = os.path.join(data_dir, "expr_train.json")
else:
file = os.path.join(data_dir, "expr_test.json")
if dataset == "mnist":
file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk")
elif dataset == "random":
@@ -91,11 +93,27 @@ def get_dataset(dataset="mnist", train=True):

with open(file, "rb") as f:
img_dataset = pickle.load(f)
X, _, Y = get_data(img_dataset, train)
equations_by_len = divide_equations_by_len(X, Y)
X, Y = [], []
if train:
positive = img_dataset["train:positive"]
negative = img_dataset["train:negative"]
else:
positive = img_dataset["test:positive"]
negative = img_dataset["test:negative"]

return equations_by_len
for equation in positive:
equation = equation.astype(np.float32)
img_list = np.vsplit(equation, equation.shape[0])
X.append(img_list)
Y.append(1)

for equation in negative:
equation = equation.astype(np.float32)
img_list = np.vsplit(equation, equation.shape[0])
X.append(img_list)
Y.append(0)

equations_by_len = divide_equations_by_len(X, Y)
return equations_by_len

if __name__ == "__main__":
get_hed()

+ 245
- 731
examples/hed/hed.ipynb
File diff suppressed because it is too large
View File


+ 0
- 1
examples/hed/reasoning/reasoning.py View File

@@ -1,7 +1,6 @@
import os
import numpy as np
import math
from zoopt import Dimension, Objective, Opt, Parameter
from abl.reasoning import PrologKB, Reasoner
from abl.utils import reform_list



+ 2
- 1
examples/hed/requirements.txt View File

@@ -1 +1,2 @@
abl
abl
gdown

+ 2
- 4
examples/hwf/README.md View File

@@ -26,11 +26,9 @@ optional arguments:
--no-cuda disables CUDA training
--epochs EPOCHS number of epochs in each learning loop iteration
(default : 1)
--lr LR base learning rate (default : 0.001)
--weight-decay WEIGHT_DECAY
weight decay value (default : 0.03)
--lr LR base model learning rate (default : 0.001)
--batch-size BATCH_SIZE
batch size (default : 32)
base model batch size (default : 32)
--loops LOOPS number of loop iterations (default : 5)
--segment_size SEGMENT_SIZE
segment size (default : 1/3)


+ 2
- 2
examples/hwf/datasets/get_dataset.py View File

@@ -13,13 +13,13 @@ img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(
def download_and_unzip(url, zip_file_name):
try:
gdown.download(url, zip_file_name)
with zipfile.pseudo_labelipFile(zip_file_name, 'r') as zip_ref:
with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
zip_ref.extractall(CURRENT_DIR)
os.remove(zip_file_name)
except Exception as e:
if os.path.exists(zip_file_name):
os.remove(zip_file_name)
raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in './datasets' folder")
raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in 'examples/hwf/datasets' folder")

def get_dataset(train=True, get_pseudo_label=False):
data_dir = CURRENT_DIR + '/data'


+ 2
- 2
examples/hwf/hwf.ipynb View File

@@ -6,7 +6,7 @@
"source": [
"# Handwritten Formula (HWF)\n",
"\n",
"This notebook shows an implementation of [Handwritten Formula](https://arxiv.org/abs/2006.06649). In this task. In this task, handwritten images of decimal formulas and their computed results are given, alongwith a domain knowledge base containing information on how to compute the decimal formula. The task is to recognize the symbols (which can be digits or operators '+', '-', '×', '÷') of handwritten images and accurately determine their results.\n",
"This notebook shows an implementation of [Handwritten Formula](https://arxiv.org/abs/2006.06649). In this task, handwritten images of decimal formulas and their computed results are given, alongwith a domain knowledge base containing information on how to compute the decimal formula. The task is to recognize the symbols (which can be digits or operators '+', '-', '×', '÷') of handwritten images and accurately determine their results.\n",
"\n",
"Intuitively, we first use a machine learning model (learning part) to convert the input images to symbols (we call them pseudo-labels), and then use the knowledge base (reasoning part) to calculate the results of these symbols. Since we do not have ground-truth of the symbols, in Abductive Learning, the reasoning part will leverage domain knowledge and revise the initial symbols yielded by the learning part through abductive reasoning. This process enables us to further update the machine learning model."
]
@@ -214,7 +214,7 @@
"# class of symbol may be one of ['0', '1', ..., '9', '+', '-', '*', '/'], total of 14 classes\n",
"cls = SymbolNet(num_classes=14, image_size=(45, 45, 1))\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))\n",
"optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"base_model = BasicNN(\n",


+ 2
- 4
examples/hwf/main.py View File

@@ -68,11 +68,9 @@ def main():
parser.add_argument('--epochs', type=int, default=3,
help='number of epochs in each learning loop iteration (default : 3)')
parser.add_argument('--lr', type=float, default=1e-3,
help='base learning rate (default : 0.001)')
parser.add_argument('--weight-decay', type=int, default=3e-2,
help='weight decay value (default : 0.03)')
help='base model learning rate (default : 0.001)')
parser.add_argument('--batch-size', type=int, default=128,
help='batch size (default : 128)')
help='base model batch size (default : 128)')
parser.add_argument('--loops', type=int, default=5,
help='number of loop iterations (default : 5)')
parser.add_argument('--segment_size', type=int or float, default=1000,


+ 5
- 6
examples/mnist_add/README.md View File

@@ -12,8 +12,8 @@ python main.py
## Usage

```bash
usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR]
[--weight-decay WEIGHT_DECAY] [--batch-size BATCH_SIZE]
usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR]
[--alpha ALPHA] [--batch-size BATCH_SIZE]
[--loops LOOPS] [--segment_size SEGMENT_SIZE]
[--save_interval SAVE_INTERVAL] [--max-revision MAX_REVISION]
[--require-more-revision REQUIRE_MORE_REVISION]
@@ -26,11 +26,10 @@ optional arguments:
--no-cuda disables CUDA training
--epochs EPOCHS number of epochs in each learning loop iteration
(default : 1)
--lr LR base learning rate (default : 0.001)
--weight-decay WEIGHT_DECAY
weight decay value (default : 0.03)
--lr LR base model learning rate (default : 0.001)
--alpha ALPHA alpha in RMSprop (default : 0.9)
--batch-size BATCH_SIZE
batch size (default : 32)
base model batch size (default : 32)
--loops LOOPS number of loop iterations (default : 5)
--segment_size SEGMENT_SIZE
segment size (default : 1/3)


+ 5
- 5
examples/mnist_add/main.py View File

@@ -34,11 +34,11 @@ def main():
parser.add_argument('--epochs', type=int, default=1,
help='number of epochs in each learning loop iteration (default : 1)')
parser.add_argument('--lr', type=float, default=1e-3,
help='base learning rate (default : 0.001)')
parser.add_argument('--weight-decay', type=int, default=3e-2,
help='weight decay value (default : 0.03)')
help='base model learning rate (default : 0.001)')
parser.add_argument('--alpha', type=float, default=0.9,
help='alpha in RMSprop (default : 0.9)')
parser.add_argument('--batch-size', type=int, default=32,
help='batch size (default : 32)')
help='base model batch size (default : 32)')
parser.add_argument('--loops', type=int, default=5,
help='number of loop iterations (default : 5)')
parser.add_argument('--segment_size', type=int or float, default=1/3,
@@ -65,7 +65,7 @@ def main():
# Build necessary components for BasicNN
cls = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr)
optimizer = torch.optim.RMSprop(cls.parameters(), lr=args.lr, alpha=args.alpha)
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")



+ 12
- 10
examples/mnist_add/mnist_add.ipynb View File

@@ -80,11 +80,6 @@
}
],
"source": [
"def describe_structure(lst):\n",
" if not isinstance(lst, list):\n",
" return type(lst).__name__ \n",
" return [describe_structure(item) for item in lst]\n",
"\n",
"print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n",
"print()\n",
"train_X, train_gt_pseudo_label, train_Y = train_data\n",
@@ -357,7 +352,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -390,7 +385,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -402,14 +397,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bridge Learning and Reasoning\n",
"## Bridging Learning and Reasoning\n",
"\n",
"Now, the last step is to bridge the learning and reasoning part. We proceed this step by creating an instance of `SimpleBridge`."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -437,6 +432,13 @@
"bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)\n",
"bridge.test(test_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -455,7 +457,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {


+ 29
- 0
examples/zoo/get_dataset.py View File

@@ -0,0 +1,29 @@
import numpy as np
import openml

# Function to load and preprocess the dataset
def load_and_preprocess_dataset(dataset_id):
dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)
X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
# Convert data types
for col in X.select_dtypes(include='bool').columns:
X[col] = X[col].astype(int)
y = y.cat.codes.astype(int)
X, y = X.to_numpy(), y.to_numpy()
return X, y

# Function to split data (one shot)
def split_dataset(X, y, test_size = 0.3):
# For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1)
label_indices, unlabel_indices, test_indices = [], [], []
for class_label in np.unique(y):
idxs = np.where(y == class_label)[0]
np.random.shuffle(idxs)
n_train_unlabel = int((1-test_size)*(len(idxs)-1))
label_indices.append(idxs[0])
unlabel_indices.extend(idxs[1:1+n_train_unlabel])
test_indices.extend(idxs[1+n_train_unlabel:])
X_label, y_label = X[label_indices], y[label_indices]
X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]
X_test, y_test = X[test_indices], y[test_indices]
return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test

+ 80
- 0
examples/zoo/kb.py View File

@@ -0,0 +1,80 @@
from z3 import Solver, Int, If, Not, Implies, Sum, sat
import openml
from abl.reasoning import KBBase

class ZooKB(KBBase):
def __init__(self):
super().__init__(pseudo_label_list=list(range(7)), use_cache=False)

self.solver = Solver()

# Load information of Zoo dataset
dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False)
X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
self.attribute_names = attribute_names
self.target_names = y.cat.categories.tolist()
print("Attribute names are: ", self.attribute_names)
print("Target names are: ", self.target_names)
# self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"]
# self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"]

# Define variables
for name in self.attribute_names+self.target_names:
exec(f"globals()['{name}'] = Int('{name}')") ## or use dict to create var and modify rules
# Define rules
rules = [
Implies(milk == 1, mammal == 1),
Implies(mammal == 1, milk == 1),
Implies(mammal == 1, backbone == 1),
Implies(mammal == 1, breathes == 1),
Implies(feathers == 1, bird == 1),
Implies(bird == 1, feathers == 1),
Implies(bird == 1, eggs == 1),
Implies(bird == 1, backbone == 1),
Implies(bird == 1, breathes == 1),
Implies(bird == 1, legs == 2),
Implies(bird == 1, tail == 1),
Implies(reptile == 1, backbone == 1),
Implies(reptile == 1, breathes == 1),
Implies(reptile == 1, tail == 1),
Implies(fish == 1, aquatic == 1),
Implies(fish == 1, toothed == 1),
Implies(fish == 1, backbone == 1),
Implies(fish == 1, Not(breathes == 1)),
Implies(fish == 1, fins == 1),
Implies(fish == 1, legs == 0),
Implies(fish == 1, tail == 1),
Implies(amphibian == 1, eggs == 1),
Implies(amphibian == 1, aquatic == 1),
Implies(amphibian == 1, backbone == 1),
Implies(amphibian == 1, breathes == 1),
Implies(amphibian == 1, legs == 4),
Implies(insect == 1, eggs == 1),
Implies(insect == 1, Not(backbone == 1)),
Implies(insect == 1, legs == 6),
Implies(invertebrate == 1, Not(backbone == 1))
]
# Define weights and sum of violated weights
self.weights = {rule: 1 for rule in rules}
self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights])

def logic_forward(self, pseudo_label, data_point):
attribute_names, target_names = self.attribute_names, self.target_names
solver = self.solver
total_violation_weight = self.total_violation_weight
pseudo_label, data_point = pseudo_label[0], data_point[0]

self.solver.reset()
for name, value in zip(attribute_names, data_point):
solver.add(eval(f"{name} == {value}"))
for cate, name in zip(self.pseudo_label_list,target_names):
value = 1 if (cate == pseudo_label) else 0
solver.add(eval(f"{name} == {value}"))

if solver.check() == sat:
model = solver.model()
total_weight = model.evaluate(total_violation_weight)
return total_weight.as_long()
else:
# No solution found
return 1e10

+ 4
- 0
examples/zoo/requirements.txt View File

@@ -0,0 +1,4 @@
abl
z3-solver
openml
scikit-learn

+ 370
- 0
examples/zoo/zoo.ipynb View File

@@ -0,0 +1,370 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ZOO\n",
"\n",
"This notebook shows an implementation of [MNIST Addition](https://arxiv.org/abs/1805.10872). In this task, pairs of MNIST handwritten images and their sums are given, alongwith a domain knowledge base containing information on how to perform addition operations. The task is to recognize the digits of handwritten images and accurately determine their sum.\n",
"\n",
"Intuitively, we first use a machine learning model (learning part) to convert the input images to digits (we call them pseudo-labels), and then use the knowledge base (reasoning part) to calculate the sum of these digits. Since we do not have ground-truth of the digits, in Abductive Learning, the reasoning part will leverage domain knowledge and revise the initial digits yielded by the learning part through abductive reasoning. This process enables us to further update the machine learning model."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import necessary libraries and modules\n",
"import os.path as osp\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from examples.zoo.get_dataset import load_and_preprocess_dataset, split_dataset\n",
"from abl.learning import ABLModel\n",
"from examples.zoo.kb import ZooKB\n",
"from abl.reasoning import Reasoner\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.utils import ABLLogger, print_log, confidence_dist\n",
"from abl.bridge import SimpleBridge"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Working with Data\n",
"\n",
"First, we get the training and testing datasets:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Load and preprocess the Zoo dataset\n",
"X, y = load_and_preprocess_dataset(dataset_id=62)\n",
"\n",
"# Split data into labeled/unlabeled/test data\n",
"X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`train_data` and `test_data` share identical structures: tuples with three components: X (list where each element is a list of two images), gt_pseudo_label (list where each element is a list of two digits, i.e., pseudo-labels) and Y (list where each element is the sum of the two digits). The length and structures of datasets are illustrated as follows.\n",
"\n",
"Note: ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shape of X and y: (101, 16) (101,)\n",
"First five elements of X:\n",
"[[True False False True False False True True True True False False 4\n",
" False False True]\n",
" [True False False True False False False True True True False False 4\n",
" True False True]\n",
" [False False True False False True True True True False False True 0\n",
" True False False]\n",
" [True False False True False False True True True True False False 4\n",
" False False True]\n",
" [True False False True False False True True True True False False 4\n",
" True False True]]\n",
"First five elements of y:\n",
"[0 0 3 0 0]\n"
]
}
],
"source": [
"print(\"Shape of X and y:\", X.shape, y.shape)\n",
"print(\"First five elements of X:\")\n",
"print(X[:5])\n",
"print(\"First five elements of y:\")\n",
"print(y[:5])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Transform tabluar data to the format required by ABL-Package, which is a tuple of (X, gt_pseudo_label, Y)\n",
"\n",
"For tabular data in abl, each example contains a single instance (a row from the dataset).\n",
"\n",
"For these tabular data samples, the reasoning results are expected to be 0, indicating no rules are violated."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def transform_tab_data(X, y):\n",
" return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n",
"label_data = transform_tab_data(X_label, y_label)\n",
"test_data = transform_tab_data(X_test, y_test)\n",
"train_data = transform_tab_data(X_unlabel, y_unlabel)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building the Learning Part"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build the learning part, we need to first build a machine learning base model. We use a [Random Forest](https://en.wikipedia.org/wiki/Random_forest) as the base model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestClassifier()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">RandomForestClassifier</label><div class=\"sk-toggleable__content\"><pre>RandomForestClassifier()</pre></div></div></div></div></div>"
],
"text/plain": [
"RandomForestClassifier()"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"base_model = RandomForestClassifier()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, the base model built above deals with instance-level data, and can not directly deal with example-level data. Therefore, we wrap the base model into `ABLModel`, which enables the learning part to train, test, and predict on example-level data."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model = ABLModel(base_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building the Reasoning Part"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the reasoning part, we first build a knowledge base which contain information on how to perform addition operations. We build it by creating a subclass of `KBBase`. In the derived subclass, we initialize the `pseudo_label_list` parameter specifying list of possible pseudo-labels, and override the `logic_forward` function defining how to perform (deductive) reasoning."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Attribute names are: ['hair', 'feathers', 'eggs', 'milk', 'airborne', 'aquatic', 'predator', 'toothed', 'backbone', 'breathes', 'venomous', 'fins', 'legs', 'tail', 'domestic', 'catsize']\n",
"Target names are: ['mammal', 'bird', 'reptile', 'fish', 'amphibian', 'insect', 'invertebrate']\n"
]
}
],
"source": [
"kb = ZooKB()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The knowledge base can perform logical reasoning (both deductive reasoning and abductive reasoning). Below is an example of performing (deductive) reasoning, and users can refer to [Documentation]() for details of abductive reasoning."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reasoning result of pseudo-label example [1, 2] is 3.\n"
]
}
],
"source": [
"pseudo_label = [0]\n",
"data_point = [np.array([1,0,0,1,0,0,1,1,1,1,0,0,4,0,0,1,1])]\n",
"print(kb.logic_forward(pseudo_label, data_point))\n",
"for x, y_item in zip(X, y):\n",
" print(x,y_item)\n",
" print(kb.logic_forward([y_item], [x]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: In addition to building a knowledge base based on `KBBase`, we can also establish a knowledge base with a ground KB using `GroundKB`, or a knowledge base implemented based on Prolog files using `PrologKB`. The corresponding code for these implementations can be found in the `main.py` file. Those interested are encouraged to examine it for further insights."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we create a reasoner by instantiating the class ``Reasoner``. Due to the indeterminism of abductive reasoning, there could be multiple candidates compatible to the knowledge base. When this happens, reasoner can minimize inconsistencies between the knowledge base and pseudo-labels predicted by the learning part, and then return only one candidate that has the highest consistency."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def consitency(data_example, candidates, candidate_idxs, reasoning_results):\n",
" pred_prob = data_example.pred_prob\n",
" model_scores = confidence_dist(pred_prob, candidate_idxs)\n",
" rule_scores = np.array(reasoning_results)\n",
" scores = model_scores + rule_scores\n",
" return scores\n",
"\n",
"reasoner = Reasoner(kb, dist_func=consitency)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building Evaluation Metrics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing. Specifically, we use `SymbolMetric` and `ReasoningMetric`, which are used to evaluate the accuracy of the machine learning model’s predictions and the accuracy of the final reasoning results, respectively."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"metric_list = [SymbolMetric(prefix=\"zoo\"), ReasoningMetric(kb=kb, prefix=\"zoo\")]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bridging Learning and Reasoning\n",
"\n",
"Now, the last step is to bridge the learning and reasoning part. We proceed this step by creating an instance of `SimpleBridge`."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"bridge = SimpleBridge(model, reasoner, metric_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Perform training and testing by invoking the `train` and `test` methods of `SimpleBridge`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build logger\n",
"print_log(\"Abductive Learning on the ZOO example.\", logger=\"current\")\n",
"log_dir = ABLLogger.get_current_instance().log_dir\n",
"weights_dir = osp.join(log_dir, \"weights\")\n",
"\n",
"# Pre-train the machine learning model\n",
"base_model.fit(X_label, y_label)\n",
"\n",
"# Test the initial model\n",
"print(\"------- Test the initial model -----------\")\n",
"bridge.test(test_data)\n",
"print(\"------- Use ABL to train the model -----------\")\n",
"# Use ABL to train the model\n",
"bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)\n",
"print(\"------- Test the final model -----------\")\n",
"# Test the final model\n",
"bridge.test(test_data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "abl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "9c8d454494e49869a4ee4046edcac9a39ff683f7d38abf0769f648402670238e"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

+ 0
- 292
examples/zoo/zoo_example.ipynb View File

@@ -1,292 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os.path as osp\n",
"\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import accuracy_score\n",
"from z3 import Solver, Int, If, Not, Implies, Sum, sat\n",
"import openml\n",
"\n",
"from abl.learning import ABLModel\n",
"from abl.reasoning import KBBase, Reasoner\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.bridge import SimpleBridge\n",
"from abl.utils.utils import confidence_dist\n",
"from abl.utils import ABLLogger, print_log"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build logger\n",
"print_log(\"Abductive Learning on the Zoo example.\", logger=\"current\")\n",
"\n",
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n",
"log_dir = ABLLogger.get_current_instance().log_dir\n",
"weights_dir = osp.join(log_dir, \"weights\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Learning Part"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rf = RandomForestClassifier()\n",
"model = ABLModel(rf)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Logic Part"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ZooKB(KBBase):\n",
" def __init__(self):\n",
" super().__init__(pseudo_label_list=list(range(7)), use_cache=False)\n",
" \n",
" # Use z3 solver \n",
" self.solver = Solver()\n",
"\n",
" # Load information of Zoo dataset\n",
" dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False)\n",
" X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
" self.attribute_names = attribute_names\n",
" self.target_names = y.cat.categories.tolist()\n",
" \n",
" # Define variables\n",
" for name in self.attribute_names+self.target_names:\n",
" exec(f\"globals()['{name}'] = Int('{name}')\") ## or use dict to create var and modify rules\n",
" # Define rules\n",
" rules = [\n",
" Implies(milk == 1, mammal == 1),\n",
" Implies(mammal == 1, milk == 1),\n",
" Implies(mammal == 1, backbone == 1),\n",
" Implies(mammal == 1, breathes == 1),\n",
" Implies(feathers == 1, bird == 1),\n",
" Implies(bird == 1, feathers == 1),\n",
" Implies(bird == 1, eggs == 1),\n",
" Implies(bird == 1, backbone == 1),\n",
" Implies(bird == 1, breathes == 1),\n",
" Implies(bird == 1, legs == 2),\n",
" Implies(bird == 1, tail == 1),\n",
" Implies(reptile == 1, backbone == 1),\n",
" Implies(reptile == 1, breathes == 1),\n",
" Implies(reptile == 1, tail == 1),\n",
" Implies(fish == 1, aquatic == 1),\n",
" Implies(fish == 1, toothed == 1),\n",
" Implies(fish == 1, backbone == 1),\n",
" Implies(fish == 1, Not(breathes == 1)),\n",
" Implies(fish == 1, fins == 1),\n",
" Implies(fish == 1, legs == 0),\n",
" Implies(fish == 1, tail == 1),\n",
" Implies(amphibian == 1, eggs == 1),\n",
" Implies(amphibian == 1, aquatic == 1),\n",
" Implies(amphibian == 1, backbone == 1),\n",
" Implies(amphibian == 1, breathes == 1),\n",
" Implies(amphibian == 1, legs == 4),\n",
" Implies(insect == 1, eggs == 1),\n",
" Implies(insect == 1, Not(backbone == 1)),\n",
" Implies(insect == 1, legs == 6),\n",
" Implies(invertebrate == 1, Not(backbone == 1))\n",
" ]\n",
" # Define weights and sum of violated weights\n",
" self.weights = {rule: 1 for rule in rules}\n",
" self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights])\n",
" \n",
" def logic_forward(self, pseudo_label, data_point):\n",
" attribute_names, target_names = self.attribute_names, self.target_names\n",
" solver = self.solver\n",
" total_violation_weight = self.total_violation_weight\n",
" pseudo_label, data_point = pseudo_label[0], data_point[0]\n",
" \n",
" self.solver.reset()\n",
" for name, value in zip(attribute_names, data_point):\n",
" solver.add(eval(f\"{name} == {value}\"))\n",
" for cate, name in zip(self.pseudo_label_list,target_names):\n",
" value = 1 if (cate == pseudo_label) else 0\n",
" solver.add(eval(f\"{name} == {value}\"))\n",
" \n",
" if solver.check() == sat:\n",
" model = solver.model()\n",
" total_weight = model.evaluate(total_violation_weight)\n",
" return total_weight.as_long()\n",
" else:\n",
" # No solution found\n",
" return 1e10\n",
" \n",
"def consitency(data_example, candidates, candidate_idxs, reasoning_results):\n",
" pred_prob = data_example.pred_prob\n",
" model_scores = confidence_dist(pred_prob, candidate_idxs)\n",
" rule_scores = np.array(reasoning_results)\n",
" scores = model_scores + rule_scores\n",
" return scores\n",
"\n",
"kb = ZooKB()\n",
"reasoner = Reasoner(kb, dist_func=consitency)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Datasets and Evaluation Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Function to load and preprocess the dataset\n",
"def load_and_preprocess_dataset(dataset_id):\n",
" dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)\n",
" X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
" # Convert data types\n",
" for col in X.select_dtypes(include='bool').columns:\n",
" X[col] = X[col].astype(int)\n",
" y = y.cat.codes.astype(int)\n",
" X, y = X.to_numpy(), y.to_numpy()\n",
" return X, y\n",
"\n",
"# Function to split data (one shot)\n",
"def split_dataset(X, y, test_size = 0.3):\n",
" # For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1)\n",
" label_indices, unlabel_indices, test_indices = [], [], []\n",
" for class_label in np.unique(y):\n",
" idxs = np.where(y == class_label)[0]\n",
" np.random.shuffle(idxs)\n",
" n_train_unlabel = int((1-test_size)*(len(idxs)-1))\n",
" label_indices.append(idxs[0])\n",
" unlabel_indices.extend(idxs[1:1+n_train_unlabel])\n",
" test_indices.extend(idxs[1+n_train_unlabel:])\n",
" X_label, y_label = X[label_indices], y[label_indices]\n",
" X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]\n",
" X_test, y_test = X[test_indices], y[test_indices]\n",
" return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test\n",
"\n",
"# Load and preprocess the Zoo dataset\n",
"X, y = load_and_preprocess_dataset(dataset_id=62)\n",
"\n",
"# Split data into labeled/unlabeled/test data\n",
"X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)\n",
"\n",
"# Transform tabluar data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results)\n",
"# For tabular data in abl, each example contains a single instance (a row from the dataset).\n",
"# For these tabular data examples, the reasoning results are expected to be 0, indicating no rules are violated.\n",
"def transform_tab_data(X, y):\n",
" return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n",
"label_data = transform_tab_data(X_label, y_label)\n",
"test_data = transform_tab_data(X_test, y_test)\n",
"train_data = transform_tab_data(X_unlabel, y_unlabel)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set up metrics\n",
"metric_list = [SymbolMetric(prefix=\"zoo\"), ReasoningMetric(kb=kb, prefix=\"zoo\")]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Bridge Machine Learning and Logic Reasoning"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bridge = SimpleBridge(model, reasoner, metric_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train and Test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pre-train the machine learning model\n",
"rf.fit(X_label, y_label)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test the initial model\n",
"print(\"------- Test the initial model -----------\")\n",
"bridge.test(test_data)\n",
"print(\"------- Use ABL to train the model -----------\")\n",
"# Use ABL to train the model\n",
"bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)\n",
"print(\"------- Test the final model -----------\")\n",
"# Test the final model\n",
"bridge.test(test_data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "abl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

+ 1
- 1
tests/conftest.py View File

@@ -215,7 +215,7 @@ def kb_hwf2():
def kb_hed():
kb = HedKB(
pseudo_label_list=[1, 0, "+", "="],
pl_file="examples/hed/datasets/learn_add.pl",
pl_file="examples/hed/reasoning/learn_add.pl",
)
return kb



+ 1
- 1
tests/test_reasoning.py View File

@@ -57,7 +57,7 @@ class TestPrologKB(object):

def test_init_pl2(self, kb_hed):
assert kb_hed.pseudo_label_list == [1, 0, "+", "="]
assert kb_hed.pl_file == "examples/hed/datasets/learn_add.pl"
assert kb_hed.pl_file == "examples/hed/reasoning/learn_add.pl"

def test_prolog_file_not_exist(self):
pseudo_label_list = [1, 2]


Loading…
Cancel
Save