diff --git a/README.md b/README.md index e24b6ce..88ba424 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ For Linux users: $ sudo apt-get install swi-prolog ``` -For Windows and Mac users, please refer to the [Swi-Prolog Download Page](https://www.swi-prolog.org/Download.html). +For Windows and Mac users, please refer to the [Swi-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md). ## Examples diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index 9a962bc..017960c 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -66,9 +66,9 @@ class BasicNN: num_workers: int = 0, save_interval: Optional[int] = None, save_dir: Optional[str] = None, - train_transform: Callable[..., Any] = None, - test_transform: Callable[..., Any] = None, - collate_fn: Callable[[List[Any]], Any] = None, + train_transform: Optional[Callable[..., Any]] = None, + test_transform: Optional[Callable[..., Any]] = None, + collate_fn: Optional[Callable[[List[Any]], Any]] = None, ) -> None: if not isinstance(model, torch.nn.Module): raise TypeError("model must be an instance of torch.nn.Module") diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 9efb39c..df54b28 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -471,7 +471,7 @@ class PrologKB(KBBase): except (IndexError, ImportError): print("A Prolog-based knowledge base is in use. Please install Swi-Prolog \ using the command 'sudo apt-get install swi-prolog' for Linux users, \ - or download it from https://www.swi-prolog.org/Download.html for Windows and Mac users.") + or download it following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md for Windows and Mac users.") self.prolog = pyswip.Prolog() self.pl_file = pl_file diff --git a/abl/utils/__init__.py b/abl/utils/__init__.py index d69e09b..9cfd590 100644 --- a/abl/utils/__init__.py +++ b/abl/utils/__init__.py @@ -1,6 +1,6 @@ from .cache import Cache, abl_cache from .logger import ABLLogger, print_log -from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable +from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable, tab_data_to_tuple __all__ = [ "Cache", @@ -12,4 +12,5 @@ __all__ = [ "reform_list", "to_hashable", "abl_cache", + "tab_data_to_tuple", ] diff --git a/abl/utils/utils.py b/abl/utils/utils.py index bbeb58b..1e10c21 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -154,4 +154,13 @@ def restore_from_hashable(x): return [restore_from_hashable(item) for item in x] return x - +def tab_data_to_tuple(X, y, reasoning_result = 0): + ''' + Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple contains three elements: data, label, and reasoning result. + If X is None, return None. + ''' + if X is None: + return None + if len(X) != len(y): + raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y))) + return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) \ No newline at end of file diff --git a/docs/Overview/Installation.rst b/docs/Overview/Installation.rst index 78b664f..ecabd0e 100644 --- a/docs/Overview/Installation.rst +++ b/docs/Overview/Installation.rst @@ -31,4 +31,4 @@ For Linux users: $ sudo apt-get install swi-prolog -For Windows and Mac users, please refer to the `Swi-Prolog Download Page `_. \ No newline at end of file +For Windows and Mac users, please refer to the `Swi-Prolog Install Guide `_. \ No newline at end of file diff --git a/docs/README.rst b/docs/README.rst index 09280d0..c83deae 100644 --- a/docs/README.rst +++ b/docs/README.rst @@ -51,7 +51,7 @@ For Linux users: $ sudo apt-get install swi-prolog -For Windows and Mac users, please refer to the `Swi-Prolog Download Page `_. +For Windows and Mac users, please refer to the `Swi-Prolog Install Guide `_. References ---------- diff --git a/examples/hed/main.py b/examples/hed/main.py index f4e7564..984ff5c 100644 --- a/examples/hed/main.py +++ b/examples/hed/main.py @@ -46,13 +46,20 @@ def main(): ) args = parser.parse_args() + + # Build logger + print_log("Abductive Learning on the HED example.", logger="current") ### Working with Data + print_log("Working with Data.", logger="current") + total_train_data = get_dataset(train=True) train_data, val_data = split_equation(total_train_data, 3, 1) test_data = get_dataset(train=False) ### Building the Learning Part + print_log("Building the Learning Part.", logger="current") + # Build necessary components for BasicNN cls = SymbolNet(num_classes=4) loss_fn = nn.CrossEntropyLoss() @@ -75,6 +82,8 @@ def main(): model = ABLModel(base_model) ### Building the Reasoning Part + print_log("Building the Reasoning Part.", logger="current") + # Build knowledge base kb = HedKB() @@ -82,14 +91,13 @@ def main(): reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=args.max_revision) ### Building Evaluation Metrics + print_log("Building Evaluation Metrics.", logger="current") metric_list = [ConsistencyMetric(kb=kb)] ### Bridge Learning and Reasoning + print_log("Bridge Learning and Reasoning.", logger="current") bridge = HedBridge(model, reasoner, metric_list) - # Build logger - print_log("Abductive Learning on the HED example.", logger="current") - # Retrieve the directory of the Log file and define the directory for saving the model weights. log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") diff --git a/examples/hwf/main.py b/examples/hwf/main.py index 963fa15..83c60e9 100644 --- a/examples/hwf/main.py +++ b/examples/hwf/main.py @@ -113,12 +113,19 @@ def main(): ) args = parser.parse_args() + + # Build logger + print_log("Abductive Learning on the HWF example.", logger="current") ### Working with Data + print_log("Working with Data.", logger="current") + train_data = get_dataset(train=True, get_pseudo_label=True) test_data = get_dataset(train=False, get_pseudo_label=True) ### Building the Learning Part + print_log("Building the Learning Part.", logger="current") + # Build necessary components for BasicNN cls = SymbolNet(num_classes=13, image_size=(45, 45, 1)) loss_fn = nn.CrossEntropyLoss() @@ -140,6 +147,8 @@ def main(): model = ABLModel(base_model) ### Building the Reasoning Part + print_log("Building the Reasoning Part.", logger="current") + # Build knowledge base if args.ground: kb = HwfGroundKB() @@ -152,14 +161,13 @@ def main(): ) ### Building Evaluation Metrics + print_log("Building Evaluation Metrics.", logger="current") metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")] ### Bridge Learning and Reasoning + print_log("Bridge Learning and Reasoning.", logger="current") bridge = SimpleBridge(model, reasoner, metric_list) - # Build logger - print_log("Abductive Learning on the HWF example.", logger="current") - # Retrieve the directory of the Log file and define the directory for saving the model weights. log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") diff --git a/examples/mnist_add/main.py b/examples/mnist_add/main.py index b74b82e..0616fc5 100644 --- a/examples/mnist_add/main.py +++ b/examples/mnist_add/main.py @@ -78,11 +78,17 @@ def main(): args = parser.parse_args() + # Build logger + print_log("Abductive Learning on the MNIST Addition example.", logger="current") + ### Working with Data + print_log("Working with Data.", logger="current") train_data = get_dataset(train=True, get_pseudo_label=True) test_data = get_dataset(train=False, get_pseudo_label=True) ### Building the Learning Part + print_log("Building the Learning Part.", logger="current") + # Build necessary components for BasicNN cls = LeNet5(num_classes=10) loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2) @@ -112,6 +118,8 @@ def main(): model = ABLModel(base_model) ### Building the Reasoning Part + print_log("Building the Reasoning Part.", logger="current") + # Build knowledge base if args.prolog: kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") @@ -126,14 +134,13 @@ def main(): ) ### Building Evaluation Metrics + print_log("Building Evaluation Metrics.", logger="current") metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] ### Bridge Learning and Reasoning + print_log("Bridge Learning and Reasoning.", logger="current") bridge = SimpleBridge(model, reasoner, metric_list) - # Build logger - print_log("Abductive Learning on the MNIST Addition example.", logger="current") - # Retrieve the directory of the Log file and define the directory for saving the model weights. log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") diff --git a/examples/zoo/main.py b/examples/zoo/main.py index 26bdc66..b4da2d1 100644 --- a/examples/zoo/main.py +++ b/examples/zoo/main.py @@ -8,14 +8,12 @@ from abl.bridge import SimpleBridge from abl.data.evaluation import ReasoningMetric, SymbolAccuracy from abl.learning import ABLModel from abl.reasoning import Reasoner -from abl.utils import ABLLogger, confidence_dist, print_log +from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple from get_dataset import load_and_preprocess_dataset, split_dataset from kb import ZooKB -def transform_tab_data(X, y): - return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y)) def consitency(data_example, candidates, candidate_idxs, reasoning_results): pred_prob = data_example.pred_prob @@ -30,21 +28,31 @@ def main(): "--loops", type=int, default=3, help="number of loop iterations (default : 3)" ) args = parser.parse_args() + + # Build logger + print_log("Abductive Learning on the ZOO example.", logger="current") ### Working with Data + print_log("Working with Data.", logger="current") + X, y = load_and_preprocess_dataset(dataset_id=62) X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) - label_data = transform_tab_data(X_label, y_label) - test_data = transform_tab_data(X_test, y_test) - train_data = transform_tab_data(X_unlabel, y_unlabel) + label_data = tab_data_to_tuple(X_label, y_label) + test_data = tab_data_to_tuple(X_test, y_test) + train_data = tab_data_to_tuple(X_unlabel, y_unlabel) ### Building the Learning Part + print_log("Building the Learning Part.", logger="current") + + # Build base model base_model = RandomForestClassifier() # Build ABLModel model = ABLModel(base_model) ### Building the Reasoning Part + print_log("Building the Reasoning Part.", logger="current") + # Build knowledge base kb = ZooKB() @@ -52,16 +60,17 @@ def main(): reasoner = Reasoner(kb, dist_func=consitency) ### Building Evaluation Metrics + print_log("Building Evaluation Metrics.", logger="current") metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")] - # Build logger - print_log("Abductive Learning on the ZOO example.", logger="current") - log_dir = ABLLogger.get_current_instance().log_dir - weights_dir = osp.join(log_dir, "weights") - ### Bridging learning and reasoning + print_log("Bridge Learning and Reasoning.", logger="current") bridge = SimpleBridge(model, reasoner, metric_list) + # Retrieve the directory of the Log file and define the directory for saving the model weights. + log_dir = ABLLogger.get_current_instance().log_dir + weights_dir = osp.join(log_dir, "weights") + # Performing training and testing print_log("------- Use labeled data to pretrain the model -----------", logger="current") base_model.fit(X_label, y_label) diff --git a/examples/zoo/zoo.ipynb b/examples/zoo/zoo.ipynb index 4596a55..bf21f43 100644 --- a/examples/zoo/zoo.ipynb +++ b/examples/zoo/zoo.ipynb @@ -27,7 +27,7 @@ "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", "from abl.learning import ABLModel\n", "from abl.reasoning import Reasoner\n", - "from abl.utils import ABLLogger, confidence_dist, print_log\n", + "from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple\n", "\n", "from get_dataset import load_and_preprocess_dataset, split_dataset\n", "from kb import ZooKB" @@ -106,11 +106,9 @@ "metadata": {}, "outputs": [], "source": [ - "def transform_tab_data(X, y):\n", - " return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n", - "label_data = transform_tab_data(X_label, y_label)\n", - "test_data = transform_tab_data(X_test, y_test)\n", - "train_data = transform_tab_data(X_unlabel, y_unlabel)" + "label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)\n", + "test_data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)\n", + "train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)" ] }, { diff --git a/pyproject.toml b/pyproject.toml index c41b5f0..9355614 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "abl" version = "0.1.4" authors = [ - { name="LAMDA 2023" }, + { name="LAMDA 2024" }, ] description = "Abductive learning package project" readme = "README.md" @@ -25,12 +25,12 @@ classifiers = [ "Programming Language :: Python :: 3.9", ] dependencies = [ - "numpy", - "pyswip==0.2.9", - "torch", - "torchvision", - "zoopt", - "termcolor" + "numpy>=1.15.0", + "pyswip>=0.2.9", + "torch>=1.11.0", + "torchvision>=0.12.0", + "zoopt>=0.3.0", + "termcolor>=2.3.0" ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 66e1a5b..6c27923 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -numpy -pyswip==0.2.9 -torch -torchvision -zoopt -termcolor \ No newline at end of file +numpy>=1.15.0, +pyswip>=0.2.9, +torch>=1.11.0, +torchvision>=0.12.0, +zoopt>=0.3.0, +termcolor>=2.3.0 \ No newline at end of file