
[ENH] Add notebook for tabular dataset

pull/1/head
troyyyyy 1 year ago
parent
commit
8f63f3797c
3 changed files with 302 additions and 198 deletions
  1. examples/zoo/zoo.py  +0 -197
  2. examples/zoo/zoo_example.ipynb  +299 -0
  3. requirements.txt  +3 -1

+ 0
- 197
examples/zoo/zoo.py

@@ -1,197 +0,0 @@
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from z3 import Solver, Int, If, Not, Implies, Sum, sat
import openml
from abl.learning import ABLModel
from abl.reasoning import KBBase, Reasoner
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.bridge import SimpleBridge
from abl.utils.utils import confidence_dist

class ZooKB(KBBase):
    def __init__(self):
        super().__init__(pseudo_label_list=list(range(7)), use_cache=False)
        # Use z3 solver
        self.solver = Solver()

        # Load information of the Zoo dataset
        dataset = openml.datasets.get_dataset(dataset_id=62, download_data=False, download_qualities=False, download_features_meta_data=False)
        X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
        self.attribute_names = attribute_names
        self.target_names = y.cat.categories.tolist()
        print("Attribute names are: ", self.attribute_names)
        print("Target names are: ", self.target_names)
        # self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"]
        # self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"]

        # Define a z3 variable for every attribute name and target name
        for name in self.attribute_names + self.target_names:
            exec(f"globals()['{name}'] = Int('{name}')")  # alternatively, keep the variables in a dict and reference them in the rules below
        # Define rules
        rules = [
            Implies(milk == 1, mammal == 1),
            Implies(mammal == 1, milk == 1),
            Implies(mammal == 1, backbone == 1),
            Implies(mammal == 1, breathes == 1),
            Implies(feathers == 1, bird == 1),
            Implies(bird == 1, feathers == 1),
            Implies(bird == 1, eggs == 1),
            Implies(bird == 1, backbone == 1),
            Implies(bird == 1, breathes == 1),
            Implies(bird == 1, legs == 2),
            Implies(bird == 1, tail == 1),
            Implies(reptile == 1, backbone == 1),
            Implies(reptile == 1, breathes == 1),
            Implies(reptile == 1, tail == 1),
            Implies(fish == 1, aquatic == 1),
            Implies(fish == 1, toothed == 1),
            Implies(fish == 1, backbone == 1),
            Implies(fish == 1, Not(breathes == 1)),
            Implies(fish == 1, fins == 1),
            Implies(fish == 1, legs == 0),
            Implies(fish == 1, tail == 1),
            Implies(amphibian == 1, eggs == 1),
            Implies(amphibian == 1, aquatic == 1),
            Implies(amphibian == 1, backbone == 1),
            Implies(amphibian == 1, breathes == 1),
            Implies(amphibian == 1, legs == 4),
            Implies(insect == 1, eggs == 1),
            Implies(insect == 1, Not(backbone == 1)),
            Implies(insect == 1, legs == 6),
            Implies(invertebrate == 1, Not(backbone == 1))
        ]
        # Assign each rule a weight and define the sum of violated weights
        self.weights = {rule: 1 for rule in rules}
        self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights])

    def logic_forward(self, pseudo_label, data_point):
        attribute_names, target_names = self.attribute_names, self.target_names
        solver = self.solver
        total_violation_weight = self.total_violation_weight
        pseudo_label, data_point = pseudo_label[0], data_point[0]
        self.solver.reset()
        for name, value in zip(attribute_names, data_point):
            solver.add(eval(f"{name} == {value}"))
        for cate, name in zip(self.pseudo_label_list, target_names):
            value = 1 if (cate == pseudo_label) else 0
            solver.add(eval(f"{name} == {value}"))
        if solver.check() == sat:
            model = solver.model()
            total_weight = model.evaluate(total_violation_weight)
            # violated_rules = [str(rule) for rule in self.weights if model.evaluate(Not(rule))]
            # print("Total violation weight for the given data point:", total_weight)
            # print("Violated rules:", violated_rules)
            return total_weight.as_long()
        else:
            # No solution found
            return 1e10
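
# A minimal, self-contained sanity check of the weighted-violation pattern used
# above (illustrative only; the variables a and b are hypothetical, not part of
# the Zoo knowledge base):
def _demo_violation_count():
    s = Solver()
    a, b = Int("a"), Int("b")
    rules = [Implies(a == 1, b == 1), Implies(b == 1, a == 1)]
    violation = Sum([If(Not(r), 1, 0) for r in rules])
    s.add(a == 1, b == 0)  # violates the first rule; the second holds vacuously
    assert s.check() == sat
    return s.model().evaluate(violation).as_long()  # -> 1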

def consistency(data_sample, candidates, candidate_idxs, reasoning_results):
    pred_prob = data_sample.pred_prob
    model_scores = confidence_dist(pred_prob, candidate_idxs)
    rule_scores = np.array(reasoning_results)
    scores = model_scores + rule_scores
    return scores
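
# Illustrative sketch (not part of the pipeline): how the two score sources
# combine. Suppose confidence_dist returns distances [0.2, 0.9] for two
# candidate pseudo-labels and reasoning reports violation weights [0, 3];
# the combined scores are [0.2, 3.9], and the candidate with the smaller
# score (here, the first) is preferred, assuming the reasoner minimizes
# the distance function.
def _demo_consistency_combination():
    model_scores = np.array([0.2, 0.9])  # hypothetical confidence distances
    rule_scores = np.array([0, 3])       # hypothetical violated-rule weights
    return model_scores + rule_scores    # -> array([0.2, 3.9])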

# Function to load and preprocess the dataset
def load_and_preprocess_dataset(dataset_id):
    dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)
    X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
    # Convert data types
    for col in X.select_dtypes(include='bool').columns:
        X[col] = X[col].astype(int)
    y = y.cat.codes.astype(int)
    X, y = X.to_numpy(), y.to_numpy()
    return X, y

# Function to split data (one shot)
def split_dataset(X, y, test_size=0.3):
    # For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1)
    label_indices, unlabel_indices, test_indices = [], [], []
    for class_label in np.unique(y):
        idxs = np.where(y == class_label)[0]
        np.random.shuffle(idxs)
        n_train_unlabel = int((1 - test_size) * (len(idxs) - 1))
        label_indices.append(idxs[0])
        unlabel_indices.extend(idxs[1:1 + n_train_unlabel])
        test_indices.extend(idxs[1 + n_train_unlabel:])
    X_label, y_label = X[label_indices], y[label_indices]
    X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]
    X_test, y_test = X[test_indices], y[test_indices]
    return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test
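
# Worked example of the split above (illustrative): for a class with 11
# samples and test_size=0.3, this yields 1 labeled sample,
# int(0.7 * 10) = 7 unlabeled samples, and the remaining 3 test samples.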

if __name__ == "__main__":
    '''
    Working with data
    '''
    # Load and preprocess the Zoo dataset
    X, y = load_and_preprocess_dataset(dataset_id=62)
    print("Shape of X and y:", X.shape, y.shape)
    print("First five elements of X:")
    print(X[:5])
    print("First five elements of y:")
    print(y[:5])
    # Split data into labeled/unlabeled/test data
    X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)
    # Transform tabular data to the format required by ABL, which is a tuple of
    # (X, ground truth of X, reasoning results).
    # For tabular data in ABL, each sample contains a single instance (a row from the dataset).
    # For these tabular data samples, the reasoning result is expected to be 0,
    # indicating that no rules are violated.
    def transform_tab_data(X, y):
        return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))
    label_data = transform_tab_data(X_label, y_label)
    test_data = transform_tab_data(X_test, y_test)
    train_data = transform_tab_data(X_unlabel, y_unlabel)
    '''
    Building the learning part
    '''
    rf = RandomForestClassifier()
    # Pre-train the machine learning model
    rf.fit(X_label, y_label)
    # # Test the initial model
    # y_test_pred = rf.predict(X_test)
    # labeled_test_acc = accuracy_score(y_test, y_test_pred)
    # print(labeled_test_acc)
    model = ABLModel(rf)
    '''
    Building the reasoning part
    '''
    # Create the knowledge base for Zoo
    kb = ZooKB()
    # # Test ZooKB
    # pseudo_label = [0]
    # data_point = [np.array([1,0,0,1,0,0,1,1,1,1,0,0,4,0,0,1,1])]
    # print(kb.logic_forward(pseudo_label, data_point))
    # for x, y_item in zip(X, y):
    #     print(x, y_item)
    #     print(kb.logic_forward([y_item], [x]))
    reasoner = Reasoner(kb, dist_func=consistency)
    '''
    Building evaluation metrics
    '''
    metric_list = [SymbolMetric(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")]
    '''
    Bridging learning and reasoning
    '''
    bridge = SimpleBridge(model, reasoner, metric_list)
    # Test the initial model
    print("------- Test the initial model -----------")
    bridge.test(test_data)
    print("------- Use ABL to train the model -----------")
    # Use ABL to train the model
    bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel))
    print("------- Test the final model -----------")
    # Test the final model
    bridge.test(test_data)

+ 299
- 0
examples/zoo/zoo_example.ipynb

@@ -0,0 +1,299 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os.path as osp\n",
"\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import accuracy_score\n",
"from z3 import Solver, Int, If, Not, Implies, Sum, sat\n",
"import openml\n",
"\n",
"from abl.learning import ABLModel\n",
"from abl.reasoning import KBBase, Reasoner\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.bridge import SimpleBridge\n",
"from abl.utils.utils import confidence_dist\n",
"from abl.utils import ABLLogger, print_log"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build logger\n",
"print_log(\"Abductive Learning on the Zoo example.\", logger=\"current\")\n",
"\n",
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n",
"log_dir = ABLLogger.get_current_instance().log_dir\n",
"weights_dir = osp.join(log_dir, \"weights\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Learning Part"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rf = RandomForestClassifier()\n",
"model = ABLModel(rf)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Logic Part"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ZooKB(KBBase):\n",
"    def __init__(self):\n",
"        super().__init__(pseudo_label_list=list(range(7)), use_cache=False)\n",
"\n",
"        # Use z3 solver\n",
"        self.solver = Solver()\n",
"\n",
"        # Load information of the Zoo dataset\n",
"        dataset = openml.datasets.get_dataset(dataset_id=62, download_data=False, download_qualities=False, download_features_meta_data=False)\n",
"        X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
"        self.attribute_names = attribute_names\n",
"        self.target_names = y.cat.categories.tolist()\n",
"        print(\"Attribute names are: \", self.attribute_names)\n",
"        print(\"Target names are: \", self.target_names)\n",
"        # self.attribute_names = [\"hair\", \"feathers\", \"eggs\", \"milk\", \"airborne\", \"aquatic\", \"predator\", \"toothed\", \"backbone\", \"breathes\", \"venomous\", \"fins\", \"legs\", \"tail\", \"domestic\", \"catsize\"]\n",
"        # self.target_names = [\"mammal\", \"bird\", \"reptile\", \"fish\", \"amphibian\", \"insect\", \"invertebrate\"]\n",
"\n",
"        # Define a z3 variable for every attribute name and target name\n",
"        for name in self.attribute_names + self.target_names:\n",
"            exec(f\"globals()['{name}'] = Int('{name}')\")  # alternatively, keep the variables in a dict and reference them in the rules below\n",
"        # Define rules\n",
"        rules = [\n",
"            Implies(milk == 1, mammal == 1),\n",
"            Implies(mammal == 1, milk == 1),\n",
"            Implies(mammal == 1, backbone == 1),\n",
"            Implies(mammal == 1, breathes == 1),\n",
"            Implies(feathers == 1, bird == 1),\n",
"            Implies(bird == 1, feathers == 1),\n",
"            Implies(bird == 1, eggs == 1),\n",
"            Implies(bird == 1, backbone == 1),\n",
"            Implies(bird == 1, breathes == 1),\n",
"            Implies(bird == 1, legs == 2),\n",
"            Implies(bird == 1, tail == 1),\n",
"            Implies(reptile == 1, backbone == 1),\n",
"            Implies(reptile == 1, breathes == 1),\n",
"            Implies(reptile == 1, tail == 1),\n",
"            Implies(fish == 1, aquatic == 1),\n",
"            Implies(fish == 1, toothed == 1),\n",
"            Implies(fish == 1, backbone == 1),\n",
"            Implies(fish == 1, Not(breathes == 1)),\n",
"            Implies(fish == 1, fins == 1),\n",
"            Implies(fish == 1, legs == 0),\n",
"            Implies(fish == 1, tail == 1),\n",
"            Implies(amphibian == 1, eggs == 1),\n",
"            Implies(amphibian == 1, aquatic == 1),\n",
"            Implies(amphibian == 1, backbone == 1),\n",
"            Implies(amphibian == 1, breathes == 1),\n",
"            Implies(amphibian == 1, legs == 4),\n",
"            Implies(insect == 1, eggs == 1),\n",
"            Implies(insect == 1, Not(backbone == 1)),\n",
"            Implies(insect == 1, legs == 6),\n",
"            Implies(invertebrate == 1, Not(backbone == 1))\n",
"        ]\n",
"        # Assign each rule a weight and define the sum of violated weights\n",
"        self.weights = {rule: 1 for rule in rules}\n",
"        self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights])\n",
"\n",
"    def logic_forward(self, pseudo_label, data_point):\n",
"        attribute_names, target_names = self.attribute_names, self.target_names\n",
"        solver = self.solver\n",
"        total_violation_weight = self.total_violation_weight\n",
"        pseudo_label, data_point = pseudo_label[0], data_point[0]\n",
"\n",
"        self.solver.reset()\n",
"        for name, value in zip(attribute_names, data_point):\n",
"            solver.add(eval(f\"{name} == {value}\"))\n",
"        for cate, name in zip(self.pseudo_label_list, target_names):\n",
"            value = 1 if (cate == pseudo_label) else 0\n",
"            solver.add(eval(f\"{name} == {value}\"))\n",
"\n",
"        if solver.check() == sat:\n",
"            model = solver.model()\n",
"            total_weight = model.evaluate(total_violation_weight)\n",
"            # violated_rules = [str(rule) for rule in self.weights if model.evaluate(Not(rule))]\n",
"            # print(\"Total violation weight for the given data point:\", total_weight)\n",
"            # print(\"Violated rules:\", violated_rules)\n",
"            return total_weight.as_long()\n",
"        else:\n",
"            # No solution found\n",
"            return 1e10\n",
"\n",
"def consistency(data_sample, candidates, candidate_idxs, reasoning_results):\n",
"    pred_prob = data_sample.pred_prob\n",
"    model_scores = confidence_dist(pred_prob, candidate_idxs)\n",
"    rule_scores = np.array(reasoning_results)\n",
"    scores = model_scores + rule_scores\n",
"    return scores\n",
"\n",
"kb = ZooKB()\n",
"reasoner = Reasoner(kb, dist_func=consistency)"
]
},
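{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (mirrors the test commented out in the original zoo.py;\n",
"# the data point is hand-written, not drawn from the dataset). The printed value\n",
"# is the total weight of violated rules, where 0 means fully consistent.\n",
"pseudo_label = [0]\n",
"data_point = [np.array([1,0,0,1,0,0,1,1,1,1,0,0,4,0,0,1,1])]\n",
"print(kb.logic_forward(pseudo_label, data_point))"
]
},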
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Datasets and Evaluation Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Function to load and preprocess the dataset\n",
"def load_and_preprocess_dataset(dataset_id):\n",
"    dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)\n",
"    X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
"    # Convert data types\n",
"    for col in X.select_dtypes(include='bool').columns:\n",
"        X[col] = X[col].astype(int)\n",
"    y = y.cat.codes.astype(int)\n",
"    X, y = X.to_numpy(), y.to_numpy()\n",
"    return X, y\n",
"\n",
"# Function to split data (one shot)\n",
"def split_dataset(X, y, test_size=0.3):\n",
"    # For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1)\n",
"    label_indices, unlabel_indices, test_indices = [], [], []\n",
"    for class_label in np.unique(y):\n",
"        idxs = np.where(y == class_label)[0]\n",
"        np.random.shuffle(idxs)\n",
"        n_train_unlabel = int((1 - test_size) * (len(idxs) - 1))\n",
"        label_indices.append(idxs[0])\n",
"        unlabel_indices.extend(idxs[1:1 + n_train_unlabel])\n",
"        test_indices.extend(idxs[1 + n_train_unlabel:])\n",
"    X_label, y_label = X[label_indices], y[label_indices]\n",
"    X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]\n",
"    X_test, y_test = X[test_indices], y[test_indices]\n",
"    return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test\n",
"\n",
"# Load and preprocess the Zoo dataset\n",
"X, y = load_and_preprocess_dataset(dataset_id=62)\n",
"\n",
"# Split data into labeled/unlabeled/test data\n",
"X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)\n",
"\n",
"# Transform tabular data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results)\n",
"# For tabular data in ABL, each sample contains a single instance (a row from the dataset).\n",
"# For these tabular data samples, the reasoning result is expected to be 0, indicating that no rules are violated.\n",
"def transform_tab_data(X, y):\n",
"    return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n",
"\n",
"label_data = transform_tab_data(X_label, y_label)\n",
"test_data = transform_tab_data(X_test, y_test)\n",
"train_data = transform_tab_data(X_unlabel, y_unlabel)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set up metrics\n",
"metric_list = [SymbolMetric(prefix=\"zoo\"), ReasoningMetric(kb=kb, prefix=\"zoo\")]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Bridge Machine Learning and Logic Reasoning"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bridge = SimpleBridge(model, reasoner, metric_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train and Test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pre-train the machine learning model\n",
"rf.fit(X_label, y_label)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test the initial model\n",
"print(\"------- Test the initial model -----------\")\n",
"bridge.test(test_data)\n",
"print(\"------- Use ABL to train the model -----------\")\n",
"# Use ABL to train the model\n",
"bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)\n",
"print(\"------- Test the final model -----------\")\n",
"# Test the final model\n",
"bridge.test(test_data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "abl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

+ 3
- 1
requirements.txt

@@ -4,4 +4,6 @@ torch
torchvision
torchaudio
zoopt
termcolor
termcolor
openml
z3-solver
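
A quick way to confirm the two new dependencies resolve after installation (a minimal sketch; the version calls are standard, but the exact output will vary by environment):

import openml
import z3

print("openml:", openml.__version__)
print("z3-solver:", z3.get_version_string())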
