From 9b3a683e3377d3b2b57a2b7ea479c59bfbaaa724 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Thu, 21 Dec 2023 16:50:41 +0800 Subject: [PATCH] [ENH] transfer structures and evaluation to data --- .coveragerc | 3 +- abl/__init__.py | 6 +- abl/bridge/base_bridge.py | 2 +- abl/bridge/simple_bridge.py | 4 +- abl/data/__init__.py | 2 + abl/{ => data}/evaluation/__init__.py | 0 abl/{ => data}/evaluation/base_metric.py | 2 +- abl/{ => data}/evaluation/reasoning_metric.py | 2 +- abl/{ => data}/evaluation/symbol_metric.py | 0 abl/{ => data}/structures/__init__.py | 0 .../structures/base_data_element.py | 5 +- abl/{ => data}/structures/list_data.py | 13 +- abl/learning/abl_model.py | 2 +- abl/reasoning/reasoner.py | 46 +++---- docs/API/abl.data.rst | 18 +++ docs/API/abl.evaluation.rst | 7 -- docs/API/abl.structures.rst | 7 -- docs/Examples/HED.rst | 2 +- docs/Examples/HWF.rst | 4 +- docs/Examples/MNISTAdd.rst | 4 +- docs/Intro/Datasets.rst | 6 +- docs/Intro/Evaluation.rst | 2 +- docs/Intro/Quick-Start.rst | 2 +- docs/index.rst | 3 +- examples/hed/bridge.py | 4 +- examples/hed/hed.ipynb | 2 +- examples/hwf/hwf.ipynb | 6 +- examples/hwf/main.py | 116 +++++++++++------- examples/mnist_add/main.py | 94 +++++++++----- examples/mnist_add/mnist_add.ipynb | 32 ++--- examples/zoo/main.py | 74 +++++++---- examples/zoo/zoo.ipynb | 4 +- tests/conftest.py | 4 +- 33 files changed, 286 insertions(+), 192 deletions(-) create mode 100644 abl/data/__init__.py rename abl/{ => data}/evaluation/__init__.py (100%) rename abl/{ => data}/evaluation/base_metric.py (98%) rename abl/{ => data}/evaluation/reasoning_metric.py (98%) rename abl/{ => data}/evaluation/symbol_metric.py (100%) rename abl/{ => data}/structures/__init__.py (100%) rename abl/{ => data}/structures/base_data_element.py (99%) rename abl/{ => data}/structures/list_data.py (96%) create mode 100644 docs/API/abl.data.rst delete mode 100644 docs/API/abl.evaluation.rst delete mode 100644 docs/API/abl.structures.rst diff --git 
a/.coveragerc b/.coveragerc index ccb97bd..30b55ca 100644 --- a/.coveragerc +++ b/.coveragerc @@ -8,8 +8,7 @@ omit = */abl/__init__.py abl/bridge/__init__.py abl/dataset/__init__.py - abl/evaluation/__init__.py + abl/data/__init__.py abl/learning/__init__.py abl/reasoning/__init__.py - abl/structures/__init__.py abl/utils/__init__.py \ No newline at end of file diff --git a/abl/__init__.py b/abl/__init__.py index 979e136..60ab54f 100644 --- a/abl/__init__.py +++ b/abl/__init__.py @@ -1,11 +1,9 @@ -from . import bridge, dataset, evaluation, learning, reasoning, structures, utils +from . import bridge, data, learning, reasoning, utils __all__ = [ "bridge", - "dataset", - "evaluation", + "data", "learning", "reasoning", - "structures", "utils", ] diff --git a/abl/bridge/base_bridge.py b/abl/bridge/base_bridge.py index 57b5cf3..9aec7cb 100644 --- a/abl/bridge/base_bridge.py +++ b/abl/bridge/base_bridge.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional, Tuple, Union from ..learning import ABLModel from ..reasoning import Reasoner -from ..structures import ListData +from ..data.structures import ListData class BaseBridge(metaclass=ABCMeta): diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index bad7a58..d0d39a1 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -3,10 +3,10 @@ from typing import Any, List, Optional, Tuple, Union from numpy import ndarray -from ..evaluation import BaseMetric +from ..data.evaluation import BaseMetric from ..learning import ABLModel from ..reasoning import Reasoner -from ..structures import ListData +from ..data.structures import ListData from ..utils import print_log from .base_bridge import BaseBridge diff --git a/abl/data/__init__.py b/abl/data/__init__.py new file mode 100644 index 0000000..3dc4849 --- /dev/null +++ b/abl/data/__init__.py @@ -0,0 +1,2 @@ +from .evaluation import * +from .structures import * \ No newline at end of file diff --git a/abl/evaluation/__init__.py 
b/abl/data/evaluation/__init__.py similarity index 100% rename from abl/evaluation/__init__.py rename to abl/data/evaluation/__init__.py diff --git a/abl/evaluation/base_metric.py b/abl/data/evaluation/base_metric.py similarity index 98% rename from abl/evaluation/base_metric.py rename to abl/data/evaluation/base_metric.py index cdaff7c..3371190 100644 --- a/abl/evaluation/base_metric.py +++ b/abl/data/evaluation/base_metric.py @@ -3,7 +3,7 @@ from abc import ABCMeta, abstractmethod from typing import Any, List, Optional from ..structures import ListData -from ..utils import print_log +from ...utils import print_log class BaseMetric(metaclass=ABCMeta): diff --git a/abl/evaluation/reasoning_metric.py b/abl/data/evaluation/reasoning_metric.py similarity index 98% rename from abl/evaluation/reasoning_metric.py rename to abl/data/evaluation/reasoning_metric.py index 759f6de..3368bd3 100644 --- a/abl/evaluation/reasoning_metric.py +++ b/abl/data/evaluation/reasoning_metric.py @@ -1,6 +1,6 @@ from typing import Optional -from ..reasoning import KBBase +from ...reasoning import KBBase from ..structures import ListData from .base_metric import BaseMetric diff --git a/abl/evaluation/symbol_metric.py b/abl/data/evaluation/symbol_metric.py similarity index 100% rename from abl/evaluation/symbol_metric.py rename to abl/data/evaluation/symbol_metric.py diff --git a/abl/structures/__init__.py b/abl/data/structures/__init__.py similarity index 100% rename from abl/structures/__init__.py rename to abl/data/structures/__init__.py diff --git a/abl/structures/base_data_element.py b/abl/data/structures/base_data_element.py similarity index 99% rename from abl/structures/base_data_element.py rename to abl/data/structures/base_data_element.py index 79cfa61..8bff0ea 100644 --- a/abl/structures/base_data_element.py +++ b/abl/data/structures/base_data_element.py @@ -5,8 +5,9 @@ from typing import Any, Iterator, Optional, Tuple, Type, Union import numpy as np import torch + # Modified from 
-# https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py +# https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py class BaseDataElement: """A base data interface that supports Tensor-like and dict-like operations. @@ -73,7 +74,7 @@ class BaseDataElement: Examples: >>> import torch - >>> from mmengine.structures import BaseDataElement + >>> from mmengine.structures import BaseDataElement >>> gt_instances = BaseDataElement() >>> bboxes = torch.rand((5, 4)) >>> scores = torch.rand((5,)) diff --git a/abl/structures/list_data.py b/abl/data/structures/list_data.py similarity index 96% rename from abl/structures/list_data.py rename to abl/data/structures/list_data.py index dbc8c2d..61bd208 100644 --- a/abl/structures/list_data.py +++ b/abl/data/structures/list_data.py @@ -1,12 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -import itertools from typing import List, Union import numpy as np import torch -from ..utils import flatten as flatten_list -from ..utils import to_hashable +from ...utils import flatten as flatten_list +from ...utils import to_hashable from .base_data_element import BaseDataElement BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] @@ -16,19 +15,19 @@ IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndar # Modified from -# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa class ListData(BaseDataElement): """ Data structure for example-level data. Subclass of :class:`BaseDataElement`. All value in `data_fields` should have the same length. 
This design refer to - https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py + https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py - ListData supports `index` and `slice` for data field. The type of value in data field can be either `None` or `list` of base data structures such as `torch.Tensor`, `numpy.ndarray`, `list`, `str` and `tuple`. + ListData supports `index` and `slice` for data field. The type of value in data field can be either `None` or `list` of base data structures such as `torch.Tensor`, `numpy.ndarray`, `list`, `str` and `tuple`. Examples: - >>> from abl.structures import ListData + >>> from abl.data.structures import ListData >>> import numpy as np >>> import torch >>> data_examples = ListData() diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index 41e0ff5..65c0452 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -1,7 +1,7 @@ import pickle from typing import Any, Dict -from ..structures import ListData +from ..data.structures import ListData from ..utils import reform_list diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 8c3faa2..1f57a6d 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -5,7 +5,7 @@ import numpy as np from zoopt import Dimension, Objective, Opt, Parameter, Solution from ..reasoning import KBBase -from ..structures import ListData +from ..data.structures import ListData from ..utils.utils import confidence_dist, hamming_dist @@ -19,18 +19,18 @@ class Reasoner: The knowledge base to be used for reasoning. dist_func : Union[str, Callable], optional The distance function used to determine the cost list between each - candidate and the given prediction. The cost is also referred to as a consistency - measure, wherein the candidate with lowest cost is selected as the final - abduced label. 
It can be either a string representing a predefined distance - function or a callable function. The available predefined distance functions: - 'hamming' | 'confidence'. 'hamming': directly calculates the Hamming - distance between the predicted pseudo-label in the data example and each - candidate, 'confidence': calculates the distance between the prediction - and each candidate based on confidence derived from the predicted probability - in the data example. The callable function should have the signature - dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element - in this cost list should be a numerical value representing the cost for each - candidate, and the list should have the same length as candidates. + candidate and the given prediction. The cost is also referred to as a consistency + measure, wherein the candidate with lowest cost is selected as the final + abduced label. It can be either a string representing a predefined distance + function or a callable function. The available predefined distance functions: + 'hamming' | 'confidence'. 'hamming': directly calculates the Hamming + distance between the predicted pseudo-label in the data example and each + candidate, 'confidence': calculates the distance between the prediction + and each candidate based on confidence derived from the predicted probability + in the data example. The callable function should have the signature + dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element + in this cost list should be a numerical value representing the cost for each + candidate, and the list should have the same length as candidates. Defaults to 'confidence'. idx_to_label : Optional[dict], optional A mapping from index in the base model to label. 
If not provided, a default @@ -64,7 +64,9 @@ class Reasoner: self.require_more_revision = require_more_revision if idx_to_label is None: - self.idx_to_label = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} + self.idx_to_label = { + index: label for index, label in enumerate(self.kb.pseudo_label_list) + } else: self._check_valid_idx_to_label(idx_to_label) self.idx_to_label = idx_to_label @@ -80,7 +82,9 @@ class Reasoner: elif callable(dist_func): params = inspect.signature(dist_func).parameters.values() if len(params) != 4: - raise ValueError(f"User-defined dist_func must have exactly four parameters, but got {len(params)}.") + raise ValueError( + f"User-defined dist_func must have exactly four parameters, but got {len(params)}." + ) return else: raise TypeError( @@ -289,18 +293,18 @@ class Reasoner: solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num) revision_idx = np.where(solution.get_x() != 0)[0] candidates, reasoning_results = self.kb.revise_at_idx( - pseudo_label=data_example.pred_pseudo_label, - y=data_example.Y, - x=data_example.X, - revision_idx=revision_idx + pseudo_label=data_example.pred_pseudo_label, + y=data_example.Y, + x=data_example.X, + revision_idx=revision_idx, ) else: candidates, reasoning_results = self.kb.abduce_candidates( pseudo_label=data_example.pred_pseudo_label, - y=data_example.Y, + y=data_example.Y, x=data_example.X, max_revision_num=max_revision_num, - require_more_revision=self.require_more_revision + require_more_revision=self.require_more_revision, ) candidate = self._get_one_candidate(data_example, candidates, reasoning_results) diff --git a/docs/API/abl.data.rst b/docs/API/abl.data.rst new file mode 100644 index 0000000..464f329 --- /dev/null +++ b/docs/API/abl.data.rst @@ -0,0 +1,18 @@ +abl.data +=================== + +Data Structure +-------------- + +.. 
autoclass:: abl.data.structures.ListData + :members: + :undoc-members: + :show-inheritance: + +Evaluation Metric +----------------- + +.. automodule:: abl.data.evaluation + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/API/abl.evaluation.rst b/docs/API/abl.evaluation.rst deleted file mode 100644 index b07808c..0000000 --- a/docs/API/abl.evaluation.rst +++ /dev/null @@ -1,7 +0,0 @@ -abl.evaluation -================== - -.. automodule:: abl.evaluation - :members: - :undoc-members: - :show-inheritance: \ No newline at end of file diff --git a/docs/API/abl.structures.rst b/docs/API/abl.structures.rst deleted file mode 100644 index fee74c4..0000000 --- a/docs/API/abl.structures.rst +++ /dev/null @@ -1,7 +0,0 @@ -abl.structures -================== - -.. autoclass:: abl.structures.ListData - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/Examples/HED.rst b/docs/Examples/HED.rst index cf17f80..9a4c262 100644 --- a/docs/Examples/HED.rst +++ b/docs/Examples/HED.rst @@ -32,7 +32,7 @@ model. from examples.models.nn import SymbolNet from abl.learning import ABLModel, BasicNN from examples.hed.reasoning import HedKB, HedReasoner - from abl.evaluation import ReasoningMetric, SymbolMetric + from abl.data.evaluation import ReasoningMetric, SymbolMetric from abl.utils import ABLLogger, print_log from examples.hed.bridge import HedBridge diff --git a/docs/Examples/HWF.rst b/docs/Examples/HWF.rst index 88f1238..9c59ca4 100644 --- a/docs/Examples/HWF.rst +++ b/docs/Examples/HWF.rst @@ -30,7 +30,7 @@ machine learning model. from examples.models.nn import SymbolNet from abl.learning import ABLModel, BasicNN from abl.reasoning import KBBase, Reasoner - from abl.evaluation import ReasoningMetric, SymbolMetric + from abl.data.evaluation import ReasoningMetric, SymbolMetric from abl.utils import ABLLogger, print_log from abl.bridge import SimpleBridge @@ -225,7 +225,7 @@ examples. .. 
code:: ipython3 - from abl.structures import ListData + from abl.data.structures import ListData # ListData is a data structure provided by ABL-Package that can be used to organize data examples data_examples = ListData() # We use the first 1001st and 3001st data examples in the training set as an illustration diff --git a/docs/Examples/MNISTAdd.rst b/docs/Examples/MNISTAdd.rst index 12b6ee7..2d1f30a 100644 --- a/docs/Examples/MNISTAdd.rst +++ b/docs/Examples/MNISTAdd.rst @@ -28,7 +28,7 @@ machine learning model. from examples.models.nn import LeNet5 from abl.learning import ABLModel, BasicNN from abl.reasoning import KBBase, Reasoner - from abl.evaluation import ReasoningMetric, SymbolMetric + from abl.data.evaluation import ReasoningMetric, SymbolMetric from abl.utils import ABLLogger, print_log from abl.bridge import SimpleBridge @@ -191,7 +191,7 @@ examples. .. code:: ipython3 - from abl.structures import ListData + from abl.data.structures import ListData # ListData is a data structure provided by ABL-Package that can be used to organize data examples data_examples = ListData() # We use the first 100 data examples in the training set as an illustration diff --git a/docs/Intro/Datasets.rst b/docs/Intro/Datasets.rst index 3ee1050..9c11403 100644 --- a/docs/Intro/Datasets.rst +++ b/docs/Intro/Datasets.rst @@ -16,7 +16,7 @@ In this section, we will look at the datasets and data structures in ABL-Package # Import necessary libraries and modules import torch - from abl.structures import ListData + from abl.data.structures import ListData Dataset ------- @@ -53,11 +53,11 @@ As an illustration, in the MNIST Addition example, the data used for training ar Data Structure -------------- -Besides the user-provided dataset, various forms of data are utilized and dynamicly generate throughout the training and testing process of Abductive Learning framework. Examples include raw data, predicted pseudo-label, abduced pseudo-label, pseudo-label indices, and so on. 
To manage this diversity and ensure a stable, versatile interface, ABL-Package employs `abstract data interfaces <../API/abl.structures.html>`_ to encapsulate different forms of data that will be used in the total learning process. +Besides the user-provided dataset, various forms of data are utilized and dynamically generated throughout the training and testing process of the Abductive Learning framework. Examples include raw data, predicted pseudo-label, abduced pseudo-label, pseudo-label indices, and so on. To manage this diversity and ensure a stable, versatile interface, ABL-Package employs `abstract data interfaces <../API/abl.data.html>`_ to encapsulate different forms of data that will be used in the total learning process. ``BaseDataElement`` is the base class for all abstract data interfaces. Inherited from ``BaseDataElement``, ``ListData`` is the most commonly used abstract data interface in ABL-Package. As the fundamental data structure, ``ListData`` implements commonly used data manipulation methods and is responsible for transferring data between various components of ABL, ensuring that stages such as prediction, training, and abductive reasoning can utilize ``ListData`` as a unified input format. -Before proceeding to other stages, user-provided datasets are firstly converted into ``ListData``. For flexibility, ABL-Package also allows user to directly supply data in ``ListData`` format, which similarly requires the inclusion of three attributes: ``X``, ``gt_pseudo_label``, and ``Y``. The following code shows the basic usage of ``ListData``. More information can be found in the `API documentation <../API/abl.structures.html>`_. +Before proceeding to other stages, user-provided datasets are first converted into ``ListData``. For flexibility, ABL-Package also allows users to directly supply data in ``ListData`` format, which similarly requires the inclusion of three attributes: ``X``, ``gt_pseudo_label``, and ``Y``. 
The following code shows the basic usage of ``ListData``. More information can be found in the `API documentation <../API/abl.data.html>`_. .. code-block:: python diff --git a/docs/Intro/Evaluation.rst b/docs/Intro/Evaluation.rst index d4c4688..dfad61d 100644 --- a/docs/Intro/Evaluation.rst +++ b/docs/Intro/Evaluation.rst @@ -15,7 +15,7 @@ In this section, we will look at how to build evaluation metrics. .. code:: python # Import necessary modules - from abl.evaluation import BaseMetric, SymbolMetric, ReasoningMetric + from abl.data.evaluation import BaseMetric, SymbolMetric, ReasoningMetric ABL-Package seperates the evaluation process from model training and testing as an independent class, ``BaseMetric``. The training and testing processes are implemented in the ``BaseBridge`` class, so metrics are used by this class and its sub-classes. After building a ``bridge`` with a list of ``BaseMetric`` instances, these metrics will be used by the ``bridge.valid`` method to evaluate the model performance during training and testing. diff --git a/docs/Intro/Quick-Start.rst b/docs/Intro/Quick-Start.rst index c0d303a..b6e9412 100644 --- a/docs/Intro/Quick-Start.rst +++ b/docs/Intro/Quick-Start.rst @@ -111,7 +111,7 @@ ABL-Package provides two basic metrics, namely ``SymbolMetric`` and ``ReasoningM .. code:: python - from abl.evaluation import ReasoningMetric, SymbolMetric + from abl.data.evaluation import ReasoningMetric, SymbolMetric metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] diff --git a/docs/index.rst b/docs/index.rst index 4c43008..6191a0b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -35,9 +35,8 @@ API/abl.dataset API/abl.learning API/abl.reasoning - API/abl.evaluation API/abl.bridge - API/abl.structures + API/abl.data API/abl.utils .. 
toctree:: diff --git a/examples/hed/bridge.py b/examples/hed/bridge.py index 255f267..b7ca577 100644 --- a/examples/hed/bridge.py +++ b/examples/hed/bridge.py @@ -5,10 +5,10 @@ import torch from abl.bridge import SimpleBridge from abl.dataset import RegressionDataset -from abl.evaluation import BaseMetric +from abl.data.evaluation import BaseMetric from abl.learning import ABLModel, BasicNN from abl.reasoning import Reasoner -from abl.structures import ListData +from abl.data.structures import ListData from abl.utils import print_log from examples.hed.datasets import get_pretrain_data from examples.hed.utils import InfiniteSampler, gen_mappings diff --git a/examples/hed/hed.ipynb b/examples/hed/hed.ipynb index b593a89..6c9489b 100644 --- a/examples/hed/hed.ipynb +++ b/examples/hed/hed.ipynb @@ -26,7 +26,7 @@ "from examples.models.nn import SymbolNet\n", "from abl.learning import ABLModel, BasicNN\n", "from examples.hed.reasoning import HedKB, HedReasoner\n", - "from abl.evaluation import ReasoningMetric, SymbolMetric\n", + "from abl.data.evaluation import ReasoningMetric, SymbolMetric\n", "from abl.utils import ABLLogger, print_log\n", "from examples.hed.bridge import HedBridge" ] diff --git a/examples/hwf/hwf.ipynb b/examples/hwf/hwf.ipynb index 6cdd31f..3434fec 100644 --- a/examples/hwf/hwf.ipynb +++ b/examples/hwf/hwf.ipynb @@ -27,7 +27,7 @@ "from examples.models.nn import SymbolNet\n", "from abl.learning import ABLModel, BasicNN\n", "from abl.reasoning import KBBase, Reasoner\n", - "from abl.evaluation import ReasoningMetric, SymbolMetric\n", + "from abl.data.evaluation import ReasoningMetric, SymbolMetric\n", "from abl.utils import ABLLogger, print_log\n", "from abl.bridge import SimpleBridge" ] @@ -55,7 +55,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Both `train_data` and `test_data` have the same structures: tuples with three components: X (list where each element is a list of images), gt_pseudo_label (list where each element is a list of 
symbols, i.e., pseudo-labels) and Y (list where each element is the computed result). The length and structures of datasets are illustrated as follows.\n", + "Both `train_data` and `test_data` have the same structures: tuples with three components: X (list where each element is a list of images), gt_pseudo_label (list where each element is a list of symbols, i.e., pseudo-labels) and Y (list where each element is the computed result). The length and structures of datasets are illustrated as follows.\n", "\n", "Note: ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model." ] @@ -299,7 +299,7 @@ } ], "source": [ - "from abl.structures import ListData\n", + "from abl.data.structures import ListData\n", "# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n", "data_examples = ListData()\n", "# We use the first 1001st and 3001st data examples in the training set as an illustration\n", diff --git a/examples/hwf/main.py b/examples/hwf/main.py index 75248e4..83161ad 100644 --- a/examples/hwf/main.py +++ b/examples/hwf/main.py @@ -10,15 +10,17 @@ from examples.hwf.datasets import get_dataset from examples.models.nn import SymbolNet from abl.learning import ABLModel, BasicNN from abl.reasoning import KBBase, GroundKB, Reasoner -from abl.evaluation import ReasoningMetric, SymbolMetric +from abl.data.evaluation import ReasoningMetric, SymbolMetric from abl.utils import ABLLogger, print_log from abl.bridge import SimpleBridge + class HwfKB(KBBase): - def __init__(self, - pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], - max_err=1e-10, - ): + def __init__( + self, + pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], + max_err=1e-10, + ): super().__init__(pseudo_label_list, max_err) def _valid_candidate(self, formula): @@ -30,19 +32,21 @@ class HwfKB(KBBase): if i % 2 != 0 and formula[i] not in ["+", 
"-", "*", "/"]: return False return True - + # Implement the deduction function def logic_forward(self, formula): if not self._valid_candidate(formula): return np.inf return eval("".join(formula)) + class HwfGroundKB(GroundKB): - def __init__(self, - pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], - GKB_len_list=[1,3,5,7], - max_err=1e-10, - ): + def __init__( + self, + pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], + GKB_len_list=[1, 3, 5, 7], + max_err=1e-10, + ): super().__init__(pseudo_label_list, GKB_len_list, max_err) def _valid_candidate(self, formula): @@ -54,40 +58,62 @@ class HwfGroundKB(GroundKB): if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]: return False return True - + # Implement the deduction function def logic_forward(self, formula): if not self._valid_candidate(formula): return np.inf return eval("".join(formula)) + def main(): - parser = argparse.ArgumentParser(description='MNIST Addition example') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--epochs', type=int, default=3, - help='number of epochs in each learning loop iteration (default : 3)') - parser.add_argument('--lr', type=float, default=1e-3, - help='base model learning rate (default : 0.001)') - parser.add_argument('--batch-size', type=int, default=128, - help='base model batch size (default : 128)') - parser.add_argument('--loops', type=int, default=5, - help='number of loop iterations (default : 5)') - parser.add_argument('--segment_size', type=int or float, default=1000, - help='segment size (default : 1000)') - parser.add_argument('--save_interval', type=int, default=1, - help='save interval (default : 1)') - parser.add_argument('--max-revision', type=int or float, default=-1, - help='maximum revision in reasoner (default : -1)') - parser.add_argument('--require-more-revision', type=int, default=5, - help='require more 
revision in reasoner (default : 0)') - parser.add_argument("--ground", action="store_true", default=False, - help='use GroundKB (default: False)') - parser.add_argument("--max-err", type=float, default=1e-10, - help='max tolerance during abductive reasoning (default : 1e-10)') + parser = argparse.ArgumentParser(description="MNIST Addition example") + parser.add_argument( + "--no-cuda", action="store_true", default=False, help="disables CUDA training" + ) + parser.add_argument( + "--epochs", + type=int, + default=3, + help="number of epochs in each learning loop iteration (default : 3)", + ) + parser.add_argument( + "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" + ) + parser.add_argument( + "--batch-size", type=int, default=128, help="base model batch size (default : 128)" + ) + parser.add_argument( + "--loops", type=int, default=5, help="number of loop iterations (default : 5)" + ) + parser.add_argument( + "--segment_size", type=int or float, default=1000, help="segment size (default : 1000)" + ) + parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)") + parser.add_argument( + "--max-revision", + type=int or float, + default=-1, + help="maximum revision in reasoner (default : -1)", + ) + parser.add_argument( + "--require-more-revision", + type=int, + default=5, + help="require more revision in reasoner (default : 0)", + ) + parser.add_argument( + "--ground", action="store_true", default=False, help="use GroundKB (default: False)" + ) + parser.add_argument( + "--max-err", + type=float, + default=1e-10, + help="max tolerance during abductive reasoning (default : 1e-10)", + ) args = parser.parse_args() - + ### Working with Data train_data = get_dataset(train=True, get_pseudo_label=True) test_data = get_dataset(train=False, get_pseudo_label=True) @@ -112,16 +138,18 @@ def main(): # Build ABLModel model = ABLModel(base_model) - + ### Building the Reasoning Part # Build knowledge base if 
args.ground: kb = HwfGroundKB() else: kb = HwfKB() - + # Create reasoner - reasoner = Reasoner(kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision) + reasoner = Reasoner( + kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision + ) ### Building Evaluation Metrics metric_list = [SymbolMetric(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")] @@ -135,9 +163,15 @@ def main(): # Retrieve the directory of the Log file and define the directory for saving the model weights. log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") - + # Train and Test - bridge.train(train_data, loops=args.loops, segment_size=args.segment_size, save_interval=args.save_interval, save_dir=weights_dir) + bridge.train( + train_data, + loops=args.loops, + segment_size=args.segment_size, + save_interval=args.save_interval, + save_dir=weights_dir, + ) bridge.test(test_data) diff --git a/examples/mnist_add/main.py b/examples/mnist_add/main.py index 873dae2..e00c7fd 100644 --- a/examples/mnist_add/main.py +++ b/examples/mnist_add/main.py @@ -9,10 +9,11 @@ from examples.mnist_add.datasets import get_dataset from examples.models.nn import LeNet5 from abl.learning import ABLModel, BasicNN from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner -from abl.evaluation import ReasoningMetric, SymbolMetric +from abl.data.evaluation import ReasoningMetric, SymbolMetric from abl.utils import ABLLogger, print_log from abl.bridge import SimpleBridge + class AddKB(KBBase): def __init__(self, pseudo_label_list=list(range(10))): super().__init__(pseudo_label_list) @@ -20,6 +21,7 @@ class AddKB(KBBase): def logic_forward(self, nums): return sum(nums) + class AddGroundKB(GroundKB): def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): super().__init__(pseudo_label_list, GKB_len_list) @@ -27,36 +29,54 @@ class AddGroundKB(GroundKB): def logic_forward(self, nums): return sum(nums) + def 
main(): - parser = argparse.ArgumentParser(description='MNIST Addition example') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--epochs', type=int, default=1, - help='number of epochs in each learning loop iteration (default : 1)') - parser.add_argument('--lr', type=float, default=1e-3, - help='base model learning rate (default : 0.001)') - parser.add_argument('--alpha', type=float, default=0.9, - help='alpha in RMSprop (default : 0.9)') - parser.add_argument('--batch-size', type=int, default=32, - help='base model batch size (default : 32)') - parser.add_argument('--loops', type=int, default=5, - help='number of loop iterations (default : 5)') - parser.add_argument('--segment_size', type=int or float, default=1/3, - help='segment size (default : 1/3)') - parser.add_argument('--save_interval', type=int, default=1, - help='save interval (default : 1)') - parser.add_argument('--max-revision', type=int or float, default=-1, - help='maximum revision in reasoner (default : -1)') - parser.add_argument('--require-more-revision', type=int, default=5, - help='require more revision in reasoner (default : 0)') + parser = argparse.ArgumentParser(description="MNIST Addition example") + parser.add_argument( + "--no-cuda", action="store_true", default=False, help="disables CUDA training" + ) + parser.add_argument( + "--epochs", + type=int, + default=1, + help="number of epochs in each learning loop iteration (default : 1)", + ) + parser.add_argument( + "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" + ) + parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default : 0.9)") + parser.add_argument( + "--batch-size", type=int, default=32, help="base model batch size (default : 32)" + ) + parser.add_argument( + "--loops", type=int, default=5, help="number of loop iterations (default : 5)" + ) + parser.add_argument( + "--segment_size", 
type=int or float, default=1 / 3, help="segment size (default : 1/3)" + ) + parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)") + parser.add_argument( + "--max-revision", + type=int or float, + default=-1, + help="maximum revision in reasoner (default : -1)", + ) + parser.add_argument( + "--require-more-revision", + type=int, + default=5, + help="require more revision in reasoner (default : 0)", + ) kb_type = parser.add_mutually_exclusive_group() - kb_type.add_argument("--prolog", action="store_true", default=False, - help='use PrologKB (default: False)') - kb_type.add_argument("--ground", action="store_true", default=False, - help='use GroundKB (default: False)') + kb_type.add_argument( + "--prolog", action="store_true", default=False, help="use PrologKB (default: False)" + ) + kb_type.add_argument( + "--ground", action="store_true", default=False, help="use GroundKB (default: False)" + ) args = parser.parse_args() - + ### Working with Data train_data = get_dataset(train=True, get_pseudo_label=True) test_data = get_dataset(train=False, get_pseudo_label=True) @@ -81,7 +101,7 @@ def main(): # Build ABLModel model = ABLModel(base_model) - + ### Building the Reasoning Part # Build knowledge base if args.prolog: @@ -90,9 +110,11 @@ def main(): kb = AddGroundKB() else: kb = AddKB() - + # Create reasoner - reasoner = Reasoner(kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision) + reasoner = Reasoner( + kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision + ) ### Building Evaluation Metrics metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] @@ -106,13 +128,17 @@ def main(): # Retrieve the directory of the Log file and define the directory for saving the model weights. 
log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") - + # Train and Test - bridge.train(train_data, loops=args.loops, segment_size=args.segment_size, save_interval=args.save_interval, save_dir=weights_dir) + bridge.train( + train_data, + loops=args.loops, + segment_size=args.segment_size, + save_interval=args.save_interval, + save_dir=weights_dir, + ) bridge.test(test_data) - - if __name__ == "__main__": main() diff --git a/examples/mnist_add/mnist_add.ipynb b/examples/mnist_add/mnist_add.ipynb index 31ed3af..172e305 100644 --- a/examples/mnist_add/mnist_add.ipynb +++ b/examples/mnist_add/mnist_add.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -26,7 +26,7 @@ "from examples.models.nn import LeNet5\n", "from abl.learning import ABLModel, BasicNN\n", "from abl.reasoning import KBBase, Reasoner\n", - "from abl.evaluation import ReasoningMetric, SymbolMetric\n", + "from abl.data.evaluation import ReasoningMetric, SymbolMetric\n", "from abl.utils import ABLLogger, print_log\n", "from abl.bridge import SimpleBridge" ] @@ -42,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -61,7 +61,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -110,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -170,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -198,7 +198,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -229,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -245,7 +245,7 @@ }, { 
"cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -261,7 +261,7 @@ } ], "source": [ - "from abl.structures import ListData\n", + "from abl.data.structures import ListData\n", "# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n", "data_examples = ListData()\n", "# We use the first 100 data examples in the training set as an illustration\n", @@ -295,7 +295,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -319,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -352,7 +352,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -385,7 +385,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -404,7 +404,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -457,7 +457,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.18" }, "orig_nbformat": 4, "vscode": { diff --git a/examples/zoo/main.py b/examples/zoo/main.py index fe28d59..76f39af 100644 --- a/examples/zoo/main.py +++ b/examples/zoo/main.py @@ -8,7 +8,7 @@ import openml from abl.learning import ABLModel from abl.reasoning import KBBase, Reasoner -from abl.evaluation import ReasoningMetric, SymbolMetric +from abl.data.evaluation import ReasoningMetric, SymbolMetric from abl.bridge import SimpleBridge from abl.utils.utils import confidence_dist from abl.utils import ABLLogger, print_log @@ -27,23 +27,33 @@ model = ABLModel(rf) # %% [markdown] # ### Logic Part + # %% class ZooKB(KBBase): def __init__(self): super().__init__(pseudo_label_list=list(range(7)), use_cache=False) 
- - # Use z3 solver + + # Use z3 solver self.solver = Solver() # Load information of Zoo dataset - dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False) - X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute) + dataset = openml.datasets.get_dataset( + dataset_id=62, + download_data=False, + download_qualities=False, + download_features_meta_data=False, + ) + X, y, categorical_indicator, attribute_names = dataset.get_data( + target=dataset.default_target_attribute + ) self.attribute_names = attribute_names self.target_names = y.cat.categories.tolist() - + # Define variables - for name in self.attribute_names+self.target_names: - exec(f"globals()['{name}'] = Int('{name}')") ## or use dict to create var and modify rules + for name in self.attribute_names + self.target_names: + exec( + f"globals()['{name}'] = Int('{name}')" + ) ## or use dict to create var and modify rules # Define rules rules = [ Implies(milk == 1, mammal == 1), @@ -75,25 +85,27 @@ class ZooKB(KBBase): Implies(insect == 1, eggs == 1), Implies(insect == 1, Not(backbone == 1)), Implies(insect == 1, legs == 6), - Implies(invertebrate == 1, Not(backbone == 1)) + Implies(invertebrate == 1, Not(backbone == 1)), ] # Define weights and sum of violated weights self.weights = {rule: 1 for rule in rules} - self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights]) - + self.total_violation_weight = Sum( + [If(Not(rule), self.weights[rule], 0) for rule in self.weights] + ) + def logic_forward(self, pseudo_label, data_point): attribute_names, target_names = self.attribute_names, self.target_names solver = self.solver total_violation_weight = self.total_violation_weight pseudo_label, data_point = pseudo_label[0], data_point[0] - + self.solver.reset() for name, value in zip(attribute_names, data_point): solver.add(eval(f"{name} == {value}")) - 
for cate, name in zip(self.pseudo_label_list,target_names): + for cate, name in zip(self.pseudo_label_list, target_names): value = 1 if (cate == pseudo_label) else 0 solver.add(eval(f"{name} == {value}")) - + if solver.check() == sat: model = solver.model() total_weight = model.evaluate(total_violation_weight) @@ -101,7 +113,8 @@ class ZooKB(KBBase): else: # No solution found return 1e10 - + + def consitency(data_example, candidates, candidate_idxs, reasoning_results): pred_prob = data_example.pred_prob model_scores = confidence_dist(pred_prob, candidate_idxs) @@ -109,51 +122,60 @@ def consitency(data_example, candidates, candidate_idxs, reasoning_results): scores = model_scores + rule_scores return scores + kb = ZooKB() reasoner = Reasoner(kb, dist_func=consitency) # %% [markdown] # ### Datasets and Evaluation Metrics + # %% # Function to load and preprocess the dataset def load_and_preprocess_dataset(dataset_id): - dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False) + dataset = openml.datasets.get_dataset( + dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False + ) X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute) # Convert data types - for col in X.select_dtypes(include='bool').columns: + for col in X.select_dtypes(include="bool").columns: X[col] = X[col].astype(int) y = y.cat.codes.astype(int) X, y = X.to_numpy(), y.to_numpy() return X, y + # Function to split data (one shot) -def split_dataset(X, y, test_size = 0.3): +def split_dataset(X, y, test_size=0.3): # For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1) label_indices, unlabel_indices, test_indices = [], [], [] for class_label in np.unique(y): idxs = np.where(y == class_label)[0] np.random.shuffle(idxs) - n_train_unlabel = int((1-test_size)*(len(idxs)-1)) + n_train_unlabel = int((1 - test_size) * (len(idxs) - 1)) label_indices.append(idxs[0]) - 
unlabel_indices.extend(idxs[1:1+n_train_unlabel]) - test_indices.extend(idxs[1+n_train_unlabel:]) + unlabel_indices.extend(idxs[1 : 1 + n_train_unlabel]) + test_indices.extend(idxs[1 + n_train_unlabel :]) X_label, y_label = X[label_indices], y[label_indices] X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices] X_test, y_test = X[test_indices], y[test_indices] return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test + # Load and preprocess the Zoo dataset X, y = load_and_preprocess_dataset(dataset_id=62) # Split data into labeled/unlabeled/test data X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) + # Transform tabluar data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results) # For tabular data in abl, each example contains a single instance (a row from the dataset). # For these tabular data examples, the reasoning results are expected to be 0, indicating no rules are violated. def transform_tab_data(X, y): return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y)) + + label_data = transform_tab_data(X_label, y_label) test_data = transform_tab_data(X_test, y_test) train_data = transform_tab_data(X_unlabel, y_unlabel) @@ -181,9 +203,13 @@ print("------- Test the initial model -----------") bridge.test(test_data) print("------- Use ABL to train the model -----------") # Use ABL to train the model -bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir) +bridge.train( + train_data=train_data, + label_data=label_data, + loops=3, + segment_size=len(X_unlabel), + save_dir=weights_dir, +) print("------- Test the final model -----------") # Test the final model bridge.test(test_data) - - diff --git a/examples/zoo/zoo.ipynb b/examples/zoo/zoo.ipynb index 2fa570d..cc9b43c 100644 --- a/examples/zoo/zoo.ipynb +++ b/examples/zoo/zoo.ipynb @@ -25,7 +25,7 @@ "from abl.learning import ABLModel\n", "from 
examples.zoo.kb import ZooKB\n", "from abl.reasoning import Reasoner\n", - "from abl.evaluation import ReasoningMetric, SymbolMetric\n", + "from abl.data.evaluation import ReasoningMetric, SymbolMetric\n", "from abl.utils import ABLLogger, print_log, confidence_dist\n", "from abl.bridge import SimpleBridge" ] @@ -56,7 +56,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`train_data` and `test_data` share identical structures: tuples with three components: X (list where each element is a list of two images), gt_pseudo_label (list where each element is a list of two digits, i.e., pseudo-labels) and Y (list where each element is the sum of the two digits). The length and structures of datasets are illustrated as follows.\n", + "`train_data` and `test_data` share identical structures: tuples with three components: X (list where each element is a list of two images), gt_pseudo_label (list where each element is a list of two digits, i.e., pseudo-labels) and Y (list where each element is the sum of the two digits). The length and structures of datasets are illustrated as follows.\n", "\n", "Note: ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model." ] diff --git a/tests/conftest.py b/tests/conftest.py index 67c8024..dc299a8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ import torch.optim as optim from abl.learning import BasicNN from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner -from abl.structures import ListData +from abl.data.structures import ListData class LeNet5(nn.Module): @@ -202,10 +202,12 @@ def kb_add_prolog(): kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/add.pl") return kb + @pytest.fixture def kb_hwf1(): return HwfKB(max_err=0.1) + @pytest.fixture def kb_hwf2(): return HwfKB(max_err=1)