@@ -53,7 +53,7 @@ For Linux users: | |||
$ sudo apt-get install swi-prolog | |||
``` | |||
For Windows and Mac users, please refer to the [Swi-Prolog Download Page](https://www.swi-prolog.org/Download.html). | |||
For Windows and Mac users, please refer to the [Swi-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md). | |||
## Examples | |||
@@ -66,9 +66,9 @@ class BasicNN: | |||
num_workers: int = 0, | |||
save_interval: Optional[int] = None, | |||
save_dir: Optional[str] = None, | |||
train_transform: Callable[..., Any] = None, | |||
test_transform: Callable[..., Any] = None, | |||
collate_fn: Callable[[List[Any]], Any] = None, | |||
train_transform: Optional[Callable[..., Any]] = None, | |||
test_transform: Optional[Callable[..., Any]] = None, | |||
collate_fn: Optional[Callable[[List[Any]], Any]] = None, | |||
) -> None: | |||
if not isinstance(model, torch.nn.Module): | |||
raise TypeError("model must be an instance of torch.nn.Module") | |||
@@ -471,7 +471,7 @@ class PrologKB(KBBase): | |||
except (IndexError, ImportError): | |||
print("A Prolog-based knowledge base is in use. Please install Swi-Prolog \ | |||
using the command 'sudo apt-get install swi-prolog' for Linux users, \ | |||
or download it from https://www.swi-prolog.org/Download.html for Windows and Mac users.") | |||
or download it following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md for Windows and Mac users.") | |||
self.prolog = pyswip.Prolog() | |||
self.pl_file = pl_file | |||
@@ -1,6 +1,6 @@ | |||
from .cache import Cache, abl_cache | |||
from .logger import ABLLogger, print_log | |||
from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable | |||
from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable, tab_data_to_tuple | |||
__all__ = [ | |||
"Cache", | |||
@@ -12,4 +12,5 @@ __all__ = [ | |||
"reform_list", | |||
"to_hashable", | |||
"abl_cache", | |||
"tab_data_to_tuple", | |||
] |
@@ -154,4 +154,13 @@ def restore_from_hashable(x): | |||
return [restore_from_hashable(item) for item in x] | |||
return x | |||
def tab_data_to_tuple(X, y, reasoning_result = 0): | |||
''' | |||
Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple contains three elements: data, label, and reasoning result. | |||
If X is None, return None. | |||
''' | |||
if X is None: | |||
return None | |||
if len(X) != len(y): | |||
raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y))) | |||
return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) |
@@ -31,4 +31,4 @@ For Linux users: | |||
$ sudo apt-get install swi-prolog | |||
For Windows and Mac users, please refer to the `Swi-Prolog Download Page <https://www.swi-prolog.org/Download.html>`_. | |||
For Windows and Mac users, please refer to the `Swi-Prolog Install Guide <https://github.com/yuce/pyswip/blob/master/INSTALL.md>`_. |
@@ -51,7 +51,7 @@ For Linux users: | |||
$ sudo apt-get install swi-prolog | |||
For Windows and Mac users, please refer to the `Swi-Prolog Download Page <https://www.swi-prolog.org/Download.html>`_. | |||
For Windows and Mac users, please refer to the `Swi-Prolog Install Guide <https://github.com/yuce/pyswip/blob/master/INSTALL.md>`_. | |||
References | |||
---------- | |||
@@ -46,13 +46,20 @@ def main(): | |||
) | |||
args = parser.parse_args() | |||
# Build logger | |||
print_log("Abductive Learning on the HED example.", logger="current") | |||
### Working with Data | |||
print_log("Working with Data.", logger="current") | |||
total_train_data = get_dataset(train=True) | |||
train_data, val_data = split_equation(total_train_data, 3, 1) | |||
test_data = get_dataset(train=False) | |||
### Building the Learning Part | |||
print_log("Building the Learning Part.", logger="current") | |||
# Build necessary components for BasicNN | |||
cls = SymbolNet(num_classes=4) | |||
loss_fn = nn.CrossEntropyLoss() | |||
@@ -75,6 +82,8 @@ def main(): | |||
model = ABLModel(base_model) | |||
### Building the Reasoning Part | |||
print_log("Building the Reasoning Part.", logger="current") | |||
# Build knowledge base | |||
kb = HedKB() | |||
@@ -82,14 +91,13 @@ def main(): | |||
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=args.max_revision) | |||
### Building Evaluation Metrics | |||
print_log("Building Evaluation Metrics.", logger="current") | |||
metric_list = [ConsistencyMetric(kb=kb)] | |||
### Bridge Learning and Reasoning | |||
print_log("Bridge Learning and Reasoning.", logger="current") | |||
bridge = HedBridge(model, reasoner, metric_list) | |||
# Build logger | |||
print_log("Abductive Learning on the HED example.", logger="current") | |||
# Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
log_dir = ABLLogger.get_current_instance().log_dir | |||
weights_dir = osp.join(log_dir, "weights") | |||
@@ -113,12 +113,19 @@ def main(): | |||
) | |||
args = parser.parse_args() | |||
# Build logger | |||
print_log("Abductive Learning on the HWF example.", logger="current") | |||
### Working with Data | |||
print_log("Working with Data.", logger="current") | |||
train_data = get_dataset(train=True, get_pseudo_label=True) | |||
test_data = get_dataset(train=False, get_pseudo_label=True) | |||
### Building the Learning Part | |||
print_log("Building the Learning Part.", logger="current") | |||
# Build necessary components for BasicNN | |||
cls = SymbolNet(num_classes=13, image_size=(45, 45, 1)) | |||
loss_fn = nn.CrossEntropyLoss() | |||
@@ -140,6 +147,8 @@ def main(): | |||
model = ABLModel(base_model) | |||
### Building the Reasoning Part | |||
print_log("Building the Reasoning Part.", logger="current") | |||
# Build knowledge base | |||
if args.ground: | |||
kb = HwfGroundKB() | |||
@@ -152,14 +161,13 @@ def main(): | |||
) | |||
### Building Evaluation Metrics | |||
print_log("Building Evaluation Metrics.", logger="current") | |||
metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")] | |||
### Bridge Learning and Reasoning | |||
print_log("Bridge Learning and Reasoning.", logger="current") | |||
bridge = SimpleBridge(model, reasoner, metric_list) | |||
# Build logger | |||
print_log("Abductive Learning on the HWF example.", logger="current") | |||
# Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
log_dir = ABLLogger.get_current_instance().log_dir | |||
weights_dir = osp.join(log_dir, "weights") | |||
@@ -78,11 +78,17 @@ def main(): | |||
args = parser.parse_args() | |||
# Build logger | |||
print_log("Abductive Learning on the MNIST Addition example.", logger="current") | |||
### Working with Data | |||
print_log("Working with Data.", logger="current") | |||
train_data = get_dataset(train=True, get_pseudo_label=True) | |||
test_data = get_dataset(train=False, get_pseudo_label=True) | |||
### Building the Learning Part | |||
print_log("Building the Learning Part.", logger="current") | |||
# Build necessary components for BasicNN | |||
cls = LeNet5(num_classes=10) | |||
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2) | |||
@@ -112,6 +118,8 @@ def main(): | |||
model = ABLModel(base_model) | |||
### Building the Reasoning Part | |||
print_log("Building the Reasoning Part.", logger="current") | |||
# Build knowledge base | |||
if args.prolog: | |||
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") | |||
@@ -126,14 +134,13 @@ def main(): | |||
) | |||
### Building Evaluation Metrics | |||
print_log("Building Evaluation Metrics.", logger="current") | |||
metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] | |||
### Bridge Learning and Reasoning | |||
print_log("Bridge Learning and Reasoning.", logger="current") | |||
bridge = SimpleBridge(model, reasoner, metric_list) | |||
# Build logger | |||
print_log("Abductive Learning on the MNIST Addition example.", logger="current") | |||
# Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
log_dir = ABLLogger.get_current_instance().log_dir | |||
weights_dir = osp.join(log_dir, "weights") | |||
@@ -8,14 +8,12 @@ from abl.bridge import SimpleBridge | |||
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
from abl.learning import ABLModel | |||
from abl.reasoning import Reasoner | |||
from abl.utils import ABLLogger, confidence_dist, print_log | |||
from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple | |||
from get_dataset import load_and_preprocess_dataset, split_dataset | |||
from kb import ZooKB | |||
def transform_tab_data(X, y): | |||
return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y)) | |||
def consitency(data_example, candidates, candidate_idxs, reasoning_results): | |||
pred_prob = data_example.pred_prob | |||
@@ -30,21 +28,31 @@ def main(): | |||
"--loops", type=int, default=3, help="number of loop iterations (default : 3)" | |||
) | |||
args = parser.parse_args() | |||
# Build logger | |||
print_log("Abductive Learning on the ZOO example.", logger="current") | |||
### Working with Data | |||
print_log("Working with Data.", logger="current") | |||
X, y = load_and_preprocess_dataset(dataset_id=62) | |||
X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) | |||
label_data = transform_tab_data(X_label, y_label) | |||
test_data = transform_tab_data(X_test, y_test) | |||
train_data = transform_tab_data(X_unlabel, y_unlabel) | |||
label_data = tab_data_to_tuple(X_label, y_label) | |||
test_data = tab_data_to_tuple(X_test, y_test) | |||
train_data = tab_data_to_tuple(X_unlabel, y_unlabel) | |||
### Building the Learning Part | |||
print_log("Building the Learning Part.", logger="current") | |||
# Build base model | |||
base_model = RandomForestClassifier() | |||
# Build ABLModel | |||
model = ABLModel(base_model) | |||
### Building the Reasoning Part | |||
print_log("Building the Reasoning Part.", logger="current") | |||
# Build knowledge base | |||
kb = ZooKB() | |||
@@ -52,16 +60,17 @@ def main(): | |||
reasoner = Reasoner(kb, dist_func=consitency) | |||
### Building Evaluation Metrics | |||
print_log("Building Evaluation Metrics.", logger="current") | |||
metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")] | |||
# Build logger | |||
print_log("Abductive Learning on the ZOO example.", logger="current") | |||
log_dir = ABLLogger.get_current_instance().log_dir | |||
weights_dir = osp.join(log_dir, "weights") | |||
### Bridging learning and reasoning | |||
print_log("Bridge Learning and Reasoning.", logger="current") | |||
bridge = SimpleBridge(model, reasoner, metric_list) | |||
# Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
log_dir = ABLLogger.get_current_instance().log_dir | |||
weights_dir = osp.join(log_dir, "weights") | |||
# Performing training and testing | |||
print_log("------- Use labeled data to pretrain the model -----------", logger="current") | |||
base_model.fit(X_label, y_label) | |||
@@ -27,7 +27,7 @@ | |||
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | |||
"from abl.learning import ABLModel\n", | |||
"from abl.reasoning import Reasoner\n", | |||
"from abl.utils import ABLLogger, confidence_dist, print_log\n", | |||
"from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple\n", | |||
"\n", | |||
"from get_dataset import load_and_preprocess_dataset, split_dataset\n", | |||
"from kb import ZooKB" | |||
@@ -106,11 +106,9 @@ | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"def transform_tab_data(X, y):\n", | |||
" return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n", | |||
"label_data = transform_tab_data(X_label, y_label)\n", | |||
"test_data = transform_tab_data(X_test, y_test)\n", | |||
"train_data = transform_tab_data(X_unlabel, y_unlabel)" | |||
"label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)\n", | |||
"test_data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)\n", | |||
"train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)" | |||
] | |||
}, | |||
{ | |||
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" | |||
name = "abl" | |||
version = "0.1.4" | |||
authors = [ | |||
{ name="LAMDA 2023" }, | |||
{ name="LAMDA 2024" }, | |||
] | |||
description = "Abductive learning package project" | |||
readme = "README.md" | |||
@@ -25,12 +25,12 @@ classifiers = [ | |||
"Programming Language :: Python :: 3.9", | |||
] | |||
dependencies = [ | |||
"numpy", | |||
"pyswip==0.2.9", | |||
"torch", | |||
"torchvision", | |||
"zoopt", | |||
"termcolor" | |||
"numpy>=1.15.0", | |||
"pyswip>=0.2.9", | |||
"torch>=1.11.0", | |||
"torchvision>=0.12.0", | |||
"zoopt>=0.3.0", | |||
"termcolor>=2.3.0" | |||
] | |||
[project.urls] | |||
@@ -1,6 +1,6 @@ | |||
numpy | |||
pyswip==0.2.9 | |||
torch | |||
torchvision | |||
zoopt | |||
termcolor | |||
numpy>=1.15.0, | |||
pyswip>=0.2.9, | |||
torch>=1.11.0, | |||
torchvision>=0.12.0, | |||
zoopt>=0.3.0, | |||
termcolor>=2.3.0 |