|
|
@@ -0,0 +1,299 @@ |
|
|
|
{ |
|
|
|
"cells": [ |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import os.path as osp\n", |
|
|
|
"\n", |
|
|
|
"import numpy as np\n", |
|
|
|
"from sklearn.ensemble import RandomForestClassifier\n", |
|
|
|
"from sklearn.metrics import accuracy_score\n", |
|
|
|
"from z3 import Solver, Int, If, Not, Implies, Sum, sat\n", |
|
|
|
"import openml\n", |
|
|
|
"\n", |
|
|
|
"from abl.learning import ABLModel\n", |
|
|
|
"from abl.reasoning import KBBase, Reasoner\n", |
|
|
|
"from abl.evaluation import ReasoningMetric, SymbolMetric\n", |
|
|
|
"from abl.bridge import SimpleBridge\n", |
|
|
|
"from abl.utils.utils import confidence_dist\n", |
|
|
|
"from abl.utils import ABLLogger, print_log" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Build logger\n", |
|
|
|
"print_log(\"Abductive Learning on the Zoo example.\", logger=\"current\")\n", |
|
|
|
"\n", |
|
|
|
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", |
|
|
|
"log_dir = ABLLogger.get_current_instance().log_dir\n", |
|
|
|
"weights_dir = osp.join(log_dir, \"weights\")" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### Learning Part" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"rf = RandomForestClassifier()\n", |
|
|
|
"model = ABLModel(rf)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### Logic Part" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"class ZooKB(KBBase):\n", |
|
|
|
" def __init__(self):\n", |
|
|
|
" super().__init__(pseudo_label_list=list(range(7)), use_cache=False)\n", |
|
|
|
" \n", |
|
|
|
" # Use z3 solver \n", |
|
|
|
" self.solver = Solver()\n", |
|
|
|
"\n", |
|
|
|
" # Load information of Zoo dataset\n", |
|
|
|
" dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False)\n", |
|
|
|
" X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n", |
|
|
|
" self.attribute_names = attribute_names\n", |
|
|
|
" self.target_names = y.cat.categories.tolist()\n", |
|
|
|
" print(\"Attribute names are: \", self.attribute_names)\n", |
|
|
|
" print(\"Target names are: \", self.target_names)\n", |
|
|
|
" # self.attribute_names = [\"hair\", \"feathers\", \"eggs\", \"milk\", \"airborne\", \"aquatic\", \"predator\", \"toothed\", \"backbone\", \"breathes\", \"venomous\", \"fins\", \"legs\", \"tail\", \"domestic\", \"catsize\"]\n", |
|
|
|
" # self.target_names = [\"mammal\", \"bird\", \"reptile\", \"fish\", \"amphibian\", \"insect\", \"invertebrate\"]\n", |
|
|
|
"\n", |
|
|
|
" # Define variables\n", |
|
|
|
" for name in self.attribute_names+self.target_names:\n", |
|
|
|
" exec(f\"globals()['{name}'] = Int('{name}')\") ## or use dict to create var and modify rules\n", |
|
|
|
" # Define rules\n", |
|
|
|
" rules = [\n", |
|
|
|
" Implies(milk == 1, mammal == 1),\n", |
|
|
|
" Implies(mammal == 1, milk == 1),\n", |
|
|
|
" Implies(mammal == 1, backbone == 1),\n", |
|
|
|
" Implies(mammal == 1, breathes == 1),\n", |
|
|
|
" Implies(feathers == 1, bird == 1),\n", |
|
|
|
" Implies(bird == 1, feathers == 1),\n", |
|
|
|
" Implies(bird == 1, eggs == 1),\n", |
|
|
|
" Implies(bird == 1, backbone == 1),\n", |
|
|
|
" Implies(bird == 1, breathes == 1),\n", |
|
|
|
" Implies(bird == 1, legs == 2),\n", |
|
|
|
" Implies(bird == 1, tail == 1),\n", |
|
|
|
" Implies(reptile == 1, backbone == 1),\n", |
|
|
|
" Implies(reptile == 1, breathes == 1),\n", |
|
|
|
" Implies(reptile == 1, tail == 1),\n", |
|
|
|
" Implies(fish == 1, aquatic == 1),\n", |
|
|
|
" Implies(fish == 1, toothed == 1),\n", |
|
|
|
" Implies(fish == 1, backbone == 1),\n", |
|
|
|
" Implies(fish == 1, Not(breathes == 1)),\n", |
|
|
|
" Implies(fish == 1, fins == 1),\n", |
|
|
|
" Implies(fish == 1, legs == 0),\n", |
|
|
|
" Implies(fish == 1, tail == 1),\n", |
|
|
|
" Implies(amphibian == 1, eggs == 1),\n", |
|
|
|
" Implies(amphibian == 1, aquatic == 1),\n", |
|
|
|
" Implies(amphibian == 1, backbone == 1),\n", |
|
|
|
" Implies(amphibian == 1, breathes == 1),\n", |
|
|
|
" Implies(amphibian == 1, legs == 4),\n", |
|
|
|
" Implies(insect == 1, eggs == 1),\n", |
|
|
|
" Implies(insect == 1, Not(backbone == 1)),\n", |
|
|
|
" Implies(insect == 1, legs == 6),\n", |
|
|
|
" Implies(invertebrate == 1, Not(backbone == 1))\n", |
|
|
|
" ]\n", |
|
|
|
" # Define weights and sum of violated weights\n", |
|
|
|
" self.weights = {rule: 1 for rule in rules}\n", |
|
|
|
" self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights])\n", |
|
|
|
" \n", |
|
|
|
" def logic_forward(self, pseudo_label, data_point):\n", |
|
|
|
" attribute_names, target_names = self.attribute_names, self.target_names\n", |
|
|
|
" solver = self.solver\n", |
|
|
|
" total_violation_weight = self.total_violation_weight\n", |
|
|
|
" pseudo_label, data_point = pseudo_label[0], data_point[0]\n", |
|
|
|
" \n", |
|
|
|
" self.solver.reset()\n", |
|
|
|
" for name, value in zip(attribute_names, data_point):\n", |
|
|
|
" solver.add(eval(f\"{name} == {value}\"))\n", |
|
|
|
" for cate, name in zip(self.pseudo_label_list,target_names):\n", |
|
|
|
" value = 1 if (cate == pseudo_label) else 0\n", |
|
|
|
" solver.add(eval(f\"{name} == {value}\"))\n", |
|
|
|
" \n", |
|
|
|
" if solver.check() == sat:\n", |
|
|
|
" model = solver.model()\n", |
|
|
|
" total_weight = model.evaluate(total_violation_weight)\n", |
|
|
|
" # violated_rules = [str(rule) for rule in self.weights if model.evaluate(Not(rule))]\n", |
|
|
|
" # print(\"Total violation weight for the given data point:\", total_weight)\n", |
|
|
|
" # print(\"Violated rules:\", violated_rules)\n", |
|
|
|
" return total_weight.as_long()\n", |
|
|
|
" else:\n", |
|
|
|
" # No solution found\n", |
|
|
|
" return 1e10\n", |
|
|
|
" \n", |
|
|
|
"def consitency(data_sample, candidates, candidate_idxs, reasoning_results):\n", |
|
|
|
" pred_prob = data_sample.pred_prob\n", |
|
|
|
" model_scores = confidence_dist(pred_prob, candidate_idxs)\n", |
|
|
|
" rule_scores = np.array(reasoning_results)\n", |
|
|
|
" scores = model_scores + rule_scores\n", |
|
|
|
" return scores\n", |
|
|
|
"\n", |
|
|
|
"kb = ZooKB()\n", |
|
|
|
"reasoner = Reasoner(kb, dist_func=consitency)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### Datasets and Evaluation Metrics" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Function to load and preprocess the dataset\n", |
|
|
|
"def load_and_preprocess_dataset(dataset_id):\n", |
|
|
|
" dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)\n", |
|
|
|
" X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n", |
|
|
|
" # Convert data types\n", |
|
|
|
" for col in X.select_dtypes(include='bool').columns:\n", |
|
|
|
" X[col] = X[col].astype(int)\n", |
|
|
|
" y = y.cat.codes.astype(int)\n", |
|
|
|
" X, y = X.to_numpy(), y.to_numpy()\n", |
|
|
|
" return X, y\n", |
|
|
|
"\n", |
|
|
|
"# Function to split data (one shot)\n", |
|
|
|
"def split_dataset(X, y, test_size = 0.3):\n", |
|
|
|
" # For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1)\n", |
|
|
|
" label_indices, unlabel_indices, test_indices = [], [], []\n", |
|
|
|
" for class_label in np.unique(y):\n", |
|
|
|
" idxs = np.where(y == class_label)[0]\n", |
|
|
|
" np.random.shuffle(idxs)\n", |
|
|
|
" n_train_unlabel = int((1-test_size)*(len(idxs)-1))\n", |
|
|
|
" label_indices.append(idxs[0])\n", |
|
|
|
" unlabel_indices.extend(idxs[1:1+n_train_unlabel])\n", |
|
|
|
" test_indices.extend(idxs[1+n_train_unlabel:])\n", |
|
|
|
" X_label, y_label = X[label_indices], y[label_indices]\n", |
|
|
|
" X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]\n", |
|
|
|
" X_test, y_test = X[test_indices], y[test_indices]\n", |
|
|
|
" return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test\n", |
|
|
|
"\n", |
|
|
|
"# Load and preprocess the Zoo dataset\n", |
|
|
|
"X, y = load_and_preprocess_dataset(dataset_id=62)\n", |
|
|
|
"\n", |
|
|
|
"# Split data into labeled/unlabeled/test data\n", |
|
|
|
"X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)\n", |
|
|
|
"\n", |
|
|
|
"# Transform tabluar data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results)\n", |
|
|
|
"# For tabular data in abl, each sample contains a single instance (a row from the dataset).\n", |
|
|
|
"# For these tabular data samples, the reasoning results are expected to be 0, indicating no rules are violated.\n", |
|
|
|
"def transform_tab_data(X, y):\n", |
|
|
|
" return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n", |
|
|
|
"label_data = transform_tab_data(X_label, y_label)\n", |
|
|
|
"test_data = transform_tab_data(X_test, y_test)\n", |
|
|
|
"train_data = transform_tab_data(X_unlabel, y_unlabel)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Set up metrics\n", |
|
|
|
"metric_list = [SymbolMetric(prefix=\"zoo\"), ReasoningMetric(kb=kb, prefix=\"zoo\")]" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### Bridge Machine Learning and Logic Reasoning" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"bridge = SimpleBridge(model, reasoner, metric_list)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"### Train and Test" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Pre-train the machine learning model\n", |
|
|
|
"rf.fit(X_label, y_label)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Test the initial model\n", |
|
|
|
"print(\"------- Test the initial model -----------\")\n", |
|
|
|
"bridge.test(test_data)\n", |
|
|
|
"print(\"------- Use ABL to train the model -----------\")\n", |
|
|
|
"# Use ABL to train the model\n", |
|
|
|
"bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)\n", |
|
|
|
"print(\"------- Test the final model -----------\")\n", |
|
|
|
"# Test the final model\n", |
|
|
|
"bridge.test(test_data)" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"metadata": { |
|
|
|
"kernelspec": { |
|
|
|
"display_name": "abl", |
|
|
|
"language": "python", |
|
|
|
"name": "python3" |
|
|
|
}, |
|
|
|
"language_info": { |
|
|
|
"codemirror_mode": { |
|
|
|
"name": "ipython", |
|
|
|
"version": 3 |
|
|
|
}, |
|
|
|
"file_extension": ".py", |
|
|
|
"mimetype": "text/x-python", |
|
|
|
"name": "python", |
|
|
|
"nbconvert_exporter": "python", |
|
|
|
"pygments_lexer": "ipython3", |
|
|
|
"version": "3.8.13" |
|
|
|
} |
|
|
|
}, |
|
|
|
"nbformat": 4, |
|
|
|
"nbformat_minor": 2 |
|
|
|
} |