@@ -8,8 +8,7 @@ omit = | |||
*/abl/__init__.py | |||
abl/bridge/__init__.py | |||
abl/dataset/__init__.py | |||
abl/evaluation/__init__.py | |||
abl/data/__init__.py | |||
abl/learning/__init__.py | |||
abl/reasoning/__init__.py | |||
abl/structures/__init__.py | |||
abl/utils/__init__.py |
@@ -1,11 +1,9 @@ | |||
from . import bridge, dataset, evaluation, learning, reasoning, structures, utils | |||
from . import bridge, data, learning, reasoning, utils | |||
__all__ = [ | |||
"bridge", | |||
"dataset", | |||
"evaluation", | |||
"data", | |||
"learning", | |||
"reasoning", | |||
"structures", | |||
"utils", | |||
] |
@@ -3,7 +3,7 @@ from typing import Any, List, Optional, Tuple, Union | |||
from ..learning import ABLModel | |||
from ..reasoning import Reasoner | |||
from ..structures import ListData | |||
from ..data.structures import ListData | |||
class BaseBridge(metaclass=ABCMeta): | |||
@@ -3,10 +3,10 @@ from typing import Any, List, Optional, Tuple, Union | |||
from numpy import ndarray | |||
from ..evaluation import BaseMetric | |||
from ..data.evaluation import BaseMetric | |||
from ..learning import ABLModel | |||
from ..reasoning import Reasoner | |||
from ..structures import ListData | |||
from ..data.structures import ListData | |||
from ..utils import print_log | |||
from .base_bridge import BaseBridge | |||
@@ -0,0 +1,2 @@ | |||
from .evaluation import * | |||
from .structures import * |
@@ -3,7 +3,7 @@ from abc import ABCMeta, abstractmethod | |||
from typing import Any, List, Optional | |||
from ..structures import ListData | |||
from ..utils import print_log | |||
from ...utils import print_log | |||
class BaseMetric(metaclass=ABCMeta): |
@@ -1,6 +1,6 @@ | |||
from typing import Optional | |||
from ..reasoning import KBBase | |||
from ...reasoning import KBBase | |||
from ..structures import ListData | |||
from .base_metric import BaseMetric | |||
@@ -5,8 +5,9 @@ from typing import Any, Iterator, Optional, Tuple, Type, Union | |||
import numpy as np | |||
import torch | |||
# Modified from | |||
# https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py | |||
# https://github.com/open-mmlab/mmengine/blob/main/mmengine/data.structures/base_data_element.py | |||
class BaseDataElement: | |||
"""A base data interface that supports Tensor-like and dict-like | |||
operations. | |||
@@ -73,7 +74,7 @@ class BaseDataElement: | |||
Examples: | |||
>>> import torch | |||
>>> from mmengine.structures import BaseDataElement | |||
>>> from mmengine.data.structures import BaseDataElement | |||
>>> gt_instances = BaseDataElement() | |||
>>> bboxes = torch.rand((5, 4)) | |||
>>> scores = torch.rand((5,)) |
@@ -1,12 +1,11 @@ | |||
# Copyright (c) OpenMMLab. All rights reserved. | |||
import itertools | |||
from typing import List, Union | |||
import numpy as np | |||
import torch | |||
from ..utils import flatten as flatten_list | |||
from ..utils import to_hashable | |||
from ...utils import flatten as flatten_list | |||
from ...utils import to_hashable | |||
from .base_data_element import BaseDataElement | |||
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] | |||
@@ -16,19 +15,19 @@ IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndar | |||
# Modified from | |||
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa | |||
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_data.structures/instance_data.py # noqa | |||
class ListData(BaseDataElement): | |||
""" | |||
Data structure for example-level data. | |||
Subclass of :class:`BaseDataElement`. All value in `data_fields` | |||
should have the same length. This design refer to | |||
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py | |||
https://github.com/facebookresearch/detectron2/blob/master/detectron2/data.structures/instances.py | |||
ListData supports `index` and `slice` for data field. The type of value in data field can be either `None` or `list` of base data structures such as `torch.Tensor`, `numpy.ndarray`, `list`, `str` and `tuple`. | |||
ListData supports `index` and `slice` for data field. The type of value in data field can be either `None` or `list` of base data data.structures such as `torch.Tensor`, `numpy.ndarray`, `list`, `str` and `tuple`. | |||
Examples: | |||
>>> from abl.structures import ListData | |||
>>> from abl.data.structures import ListData | |||
>>> import numpy as np | |||
>>> import torch | |||
>>> data_examples = ListData() |
@@ -1,7 +1,7 @@ | |||
import pickle | |||
from typing import Any, Dict | |||
from ..structures import ListData | |||
from ..data.structures import ListData | |||
from ..utils import reform_list | |||
@@ -5,7 +5,7 @@ import numpy as np | |||
from zoopt import Dimension, Objective, Opt, Parameter, Solution | |||
from ..reasoning import KBBase | |||
from ..structures import ListData | |||
from ..data.structures import ListData | |||
from ..utils.utils import confidence_dist, hamming_dist | |||
@@ -19,18 +19,18 @@ class Reasoner: | |||
The knowledge base to be used for reasoning. | |||
dist_func : Union[str, Callable], optional | |||
The distance function used to determine the cost list between each | |||
candidate and the given prediction. The cost is also referred to as a consistency | |||
measure, wherein the candidate with lowest cost is selected as the final | |||
abduced label. It can be either a string representing a predefined distance | |||
function or a callable function. The available predefined distance functions: | |||
'hamming' | 'confidence'. 'hamming': directly calculates the Hamming | |||
distance between the predicted pseudo-label in the data example and each | |||
candidate, 'confidence': calculates the distance between the prediction | |||
and each candidate based on confidence derived from the predicted probability | |||
in the data example. The callable function should have the signature | |||
dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element | |||
in this cost list should be a numerical value representing the cost for each | |||
candidate, and the list should have the same length as candidates. | |||
candidate and the given prediction. The cost is also referred to as a consistency | |||
measure, wherein the candidate with lowest cost is selected as the final | |||
abduced label. It can be either a string representing a predefined distance | |||
function or a callable function. The available predefined distance functions: | |||
'hamming' | 'confidence'. 'hamming': directly calculates the Hamming | |||
distance between the predicted pseudo-label in the data example and each | |||
candidate, 'confidence': calculates the distance between the prediction | |||
and each candidate based on confidence derived from the predicted probability | |||
in the data example. The callable function should have the signature | |||
dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element | |||
in this cost list should be a numerical value representing the cost for each | |||
candidate, and the list should have the same length as candidates. | |||
Defaults to 'confidence'. | |||
idx_to_label : Optional[dict], optional | |||
A mapping from index in the base model to label. If not provided, a default | |||
@@ -64,7 +64,9 @@ class Reasoner: | |||
self.require_more_revision = require_more_revision | |||
if idx_to_label is None: | |||
self.idx_to_label = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} | |||
self.idx_to_label = { | |||
index: label for index, label in enumerate(self.kb.pseudo_label_list) | |||
} | |||
else: | |||
self._check_valid_idx_to_label(idx_to_label) | |||
self.idx_to_label = idx_to_label | |||
@@ -80,7 +82,9 @@ class Reasoner: | |||
elif callable(dist_func): | |||
params = inspect.signature(dist_func).parameters.values() | |||
if len(params) != 4: | |||
raise ValueError(f"User-defined dist_func must have exactly four parameters, but got {len(params)}.") | |||
raise ValueError( | |||
f"User-defined dist_func must have exactly four parameters, but got {len(params)}." | |||
) | |||
return | |||
else: | |||
raise TypeError( | |||
@@ -289,18 +293,18 @@ class Reasoner: | |||
solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num) | |||
revision_idx = np.where(solution.get_x() != 0)[0] | |||
candidates, reasoning_results = self.kb.revise_at_idx( | |||
pseudo_label=data_example.pred_pseudo_label, | |||
y=data_example.Y, | |||
x=data_example.X, | |||
revision_idx=revision_idx | |||
pseudo_label=data_example.pred_pseudo_label, | |||
y=data_example.Y, | |||
x=data_example.X, | |||
revision_idx=revision_idx, | |||
) | |||
else: | |||
candidates, reasoning_results = self.kb.abduce_candidates( | |||
pseudo_label=data_example.pred_pseudo_label, | |||
y=data_example.Y, | |||
y=data_example.Y, | |||
x=data_example.X, | |||
max_revision_num=max_revision_num, | |||
require_more_revision=self.require_more_revision | |||
require_more_revision=self.require_more_revision, | |||
) | |||
candidate = self._get_one_candidate(data_example, candidates, reasoning_results) | |||
@@ -0,0 +1,18 @@ | |||
abl.data | |||
=================== | |||
Data Structure | |||
-------------- | |||
.. autoclass:: abl.data.structures.ListData | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Evaluation Metric | |||
----------------- | |||
.. automodule:: abl.data.evaluation | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,7 +0,0 @@ | |||
abl.evaluation | |||
================== | |||
.. automodule:: abl.evaluation | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,7 +0,0 @@ | |||
abl.structures | |||
================== | |||
.. autoclass:: abl.structures.ListData | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -32,7 +32,7 @@ model. | |||
from examples.models.nn import SymbolNet | |||
from abl.learning import ABLModel, BasicNN | |||
from examples.hed.reasoning import HedKB, HedReasoner | |||
from abl.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.data.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.utils import ABLLogger, print_log | |||
from examples.hed.bridge import HedBridge | |||
@@ -30,7 +30,7 @@ machine learning model. | |||
from examples.models.nn import SymbolNet | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import KBBase, Reasoner | |||
from abl.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.data.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.utils import ABLLogger, print_log | |||
from abl.bridge import SimpleBridge | |||
@@ -225,7 +225,7 @@ examples. | |||
.. code:: ipython3 | |||
from abl.structures import ListData | |||
from abl.data.structures import ListData | |||
# ListData is a data structure provided by ABL-Package that can be used to organize data examples | |||
data_examples = ListData() | |||
# We use the first 1001st and 3001st data examples in the training set as an illustration | |||
@@ -28,7 +28,7 @@ machine learning model. | |||
from examples.models.nn import LeNet5 | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import KBBase, Reasoner | |||
from abl.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.data.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.utils import ABLLogger, print_log | |||
from abl.bridge import SimpleBridge | |||
@@ -191,7 +191,7 @@ examples. | |||
.. code:: ipython3 | |||
from abl.structures import ListData | |||
from abl.data.structures import ListData | |||
# ListData is a data structure provided by ABL-Package that can be used to organize data examples | |||
data_examples = ListData() | |||
# We use the first 100 data examples in the training set as an illustration | |||
@@ -16,7 +16,7 @@ In this section, we will look at the datasets and data structures in ABL-Package | |||
# Import necessary libraries and modules | |||
import torch | |||
from abl.structures import ListData | |||
from abl.data.structures import ListData | |||
Dataset | |||
------- | |||
@@ -53,11 +53,11 @@ As an illustration, in the MNIST Addition example, the data used for training ar | |||
Data Structure | |||
-------------- | |||
Besides the user-provided dataset, various forms of data are utilized and dynamicly generate throughout the training and testing process of Abductive Learning framework. Examples include raw data, predicted pseudo-label, abduced pseudo-label, pseudo-label indices, and so on. To manage this diversity and ensure a stable, versatile interface, ABL-Package employs `abstract data interfaces <../API/abl.structures.html>`_ to encapsulate different forms of data that will be used in the total learning process. | |||
Besides the user-provided dataset, various forms of data are utilized and dynamicly generate throughout the training and testing process of Abductive Learning framework. Examples include raw data, predicted pseudo-label, abduced pseudo-label, pseudo-label indices, and so on. To manage this diversity and ensure a stable, versatile interface, ABL-Package employs `abstract data interfaces <../API/abl.data.html>`_ to encapsulate different forms of data that will be used in the total learning process. | |||
``BaseDataElement`` is the base class for all abstract data interfaces. Inherited from ``BaseDataElement``, ``ListData`` is the most commonly used abstract data interface in ABL-Package. As the fundamental data structure, ``ListData`` implements commonly used data manipulation methods and is responsible for transferring data between various components of ABL, ensuring that stages such as prediction, training, and abductive reasoning can utilize ``ListData`` as a unified input format. | |||
Before proceeding to other stages, user-provided datasets are firstly converted into ``ListData``. For flexibility, ABL-Package also allows user to directly supply data in ``ListData`` format, which similarly requires the inclusion of three attributes: ``X``, ``gt_pseudo_label``, and ``Y``. The following code shows the basic usage of ``ListData``. More information can be found in the `API documentation <../API/abl.structures.html>`_. | |||
Before proceeding to other stages, user-provided datasets are firstly converted into ``ListData``. For flexibility, ABL-Package also allows user to directly supply data in ``ListData`` format, which similarly requires the inclusion of three attributes: ``X``, ``gt_pseudo_label``, and ``Y``. The following code shows the basic usage of ``ListData``. More information can be found in the `API documentation <../API/abl.data.html>`_. | |||
.. code-block:: python | |||
@@ -15,7 +15,7 @@ In this section, we will look at how to build evaluation metrics. | |||
.. code:: python | |||
# Import necessary modules | |||
from abl.evaluation import BaseMetric, SymbolMetric, ReasoningMetric | |||
from abl.data.evaluation import BaseMetric, SymbolMetric, ReasoningMetric | |||
ABL-Package seperates the evaluation process from model training and testing as an independent class, ``BaseMetric``. The training and testing processes are implemented in the ``BaseBridge`` class, so metrics are used by this class and its sub-classes. After building a ``bridge`` with a list of ``BaseMetric`` instances, these metrics will be used by the ``bridge.valid`` method to evaluate the model performance during training and testing. | |||
@@ -111,7 +111,7 @@ ABL-Package provides two basic metrics, namely ``SymbolMetric`` and ``ReasoningM | |||
.. code:: python | |||
from abl.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.data.evaluation import ReasoningMetric, SymbolMetric | |||
metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] | |||
@@ -35,9 +35,8 @@ | |||
API/abl.dataset | |||
API/abl.learning | |||
API/abl.reasoning | |||
API/abl.evaluation | |||
API/abl.bridge | |||
API/abl.structures | |||
API/abl.data | |||
API/abl.utils | |||
.. toctree:: | |||
@@ -5,10 +5,10 @@ import torch | |||
from abl.bridge import SimpleBridge | |||
from abl.dataset import RegressionDataset | |||
from abl.evaluation import BaseMetric | |||
from abl.data.evaluation import BaseMetric | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import Reasoner | |||
from abl.structures import ListData | |||
from abl.data.structures import ListData | |||
from abl.utils import print_log | |||
from examples.hed.datasets import get_pretrain_data | |||
from examples.hed.utils import InfiniteSampler, gen_mappings | |||
@@ -26,7 +26,7 @@ | |||
"from examples.models.nn import SymbolNet\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from examples.hed.reasoning import HedKB, HedReasoner\n", | |||
"from abl.evaluation import ReasoningMetric, SymbolMetric\n", | |||
"from abl.data.evaluation import ReasoningMetric, SymbolMetric\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"from examples.hed.bridge import HedBridge" | |||
] | |||
@@ -27,7 +27,7 @@ | |||
"from examples.models.nn import SymbolNet\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from abl.reasoning import KBBase, Reasoner\n", | |||
"from abl.evaluation import ReasoningMetric, SymbolMetric\n", | |||
"from abl.data.evaluation import ReasoningMetric, SymbolMetric\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"from abl.bridge import SimpleBridge" | |||
] | |||
@@ -55,7 +55,7 @@ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"Both `train_data` and `test_data` have the same structures: tuples with three components: X (list where each element is a list of images), gt_pseudo_label (list where each element is a list of symbols, i.e., pseudo-labels) and Y (list where each element is the computed result). The length and structures of datasets are illustrated as follows.\n", | |||
"Both `train_data` and `test_data` have the same data.structures: tuples with three components: X (list where each element is a list of images), gt_pseudo_label (list where each element is a list of symbols, i.e., pseudo-labels) and Y (list where each element is the computed result). The length and data.structures of datasets are illustrated as follows.\n", | |||
"\n", | |||
"Note: ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model." | |||
] | |||
@@ -299,7 +299,7 @@ | |||
} | |||
], | |||
"source": [ | |||
"from abl.structures import ListData\n", | |||
"from abl.data.structures import ListData\n", | |||
"# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n", | |||
"data_examples = ListData()\n", | |||
"# We use the first 1001st and 3001st data examples in the training set as an illustration\n", | |||
@@ -10,15 +10,17 @@ from examples.hwf.datasets import get_dataset | |||
from examples.models.nn import SymbolNet | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import KBBase, GroundKB, Reasoner | |||
from abl.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.data.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.utils import ABLLogger, print_log | |||
from abl.bridge import SimpleBridge | |||
class HwfKB(KBBase): | |||
def __init__(self, | |||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], | |||
max_err=1e-10, | |||
): | |||
def __init__( | |||
self, | |||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], | |||
max_err=1e-10, | |||
): | |||
super().__init__(pseudo_label_list, max_err) | |||
def _valid_candidate(self, formula): | |||
@@ -30,19 +32,21 @@ class HwfKB(KBBase): | |||
if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]: | |||
return False | |||
return True | |||
# Implement the deduction function | |||
def logic_forward(self, formula): | |||
if not self._valid_candidate(formula): | |||
return np.inf | |||
return eval("".join(formula)) | |||
class HwfGroundKB(GroundKB): | |||
def __init__(self, | |||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], | |||
GKB_len_list=[1,3,5,7], | |||
max_err=1e-10, | |||
): | |||
def __init__( | |||
self, | |||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], | |||
GKB_len_list=[1, 3, 5, 7], | |||
max_err=1e-10, | |||
): | |||
super().__init__(pseudo_label_list, GKB_len_list, max_err) | |||
def _valid_candidate(self, formula): | |||
@@ -54,40 +58,62 @@ class HwfGroundKB(GroundKB): | |||
if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]: | |||
return False | |||
return True | |||
# Implement the deduction function | |||
def logic_forward(self, formula): | |||
if not self._valid_candidate(formula): | |||
return np.inf | |||
return eval("".join(formula)) | |||
def main(): | |||
parser = argparse.ArgumentParser(description='MNIST Addition example') | |||
parser.add_argument('--no-cuda', action='store_true', default=False, | |||
help='disables CUDA training') | |||
parser.add_argument('--epochs', type=int, default=3, | |||
help='number of epochs in each learning loop iteration (default : 3)') | |||
parser.add_argument('--lr', type=float, default=1e-3, | |||
help='base model learning rate (default : 0.001)') | |||
parser.add_argument('--batch-size', type=int, default=128, | |||
help='base model batch size (default : 128)') | |||
parser.add_argument('--loops', type=int, default=5, | |||
help='number of loop iterations (default : 5)') | |||
parser.add_argument('--segment_size', type=int or float, default=1000, | |||
help='segment size (default : 1000)') | |||
parser.add_argument('--save_interval', type=int, default=1, | |||
help='save interval (default : 1)') | |||
parser.add_argument('--max-revision', type=int or float, default=-1, | |||
help='maximum revision in reasoner (default : -1)') | |||
parser.add_argument('--require-more-revision', type=int, default=5, | |||
help='require more revision in reasoner (default : 0)') | |||
parser.add_argument("--ground", action="store_true", default=False, | |||
help='use GroundKB (default: False)') | |||
parser.add_argument("--max-err", type=float, default=1e-10, | |||
help='max tolerance during abductive reasoning (default : 1e-10)') | |||
parser = argparse.ArgumentParser(description="MNIST Addition example") | |||
parser.add_argument( | |||
"--no-cuda", action="store_true", default=False, help="disables CUDA training" | |||
) | |||
parser.add_argument( | |||
"--epochs", | |||
type=int, | |||
default=3, | |||
help="number of epochs in each learning loop iteration (default : 3)", | |||
) | |||
parser.add_argument( | |||
"--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | |||
) | |||
parser.add_argument( | |||
"--batch-size", type=int, default=128, help="base model batch size (default : 128)" | |||
) | |||
parser.add_argument( | |||
"--loops", type=int, default=5, help="number of loop iterations (default : 5)" | |||
) | |||
parser.add_argument( | |||
"--segment_size", type=int or float, default=1000, help="segment size (default : 1000)" | |||
) | |||
parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)") | |||
parser.add_argument( | |||
"--max-revision", | |||
type=int or float, | |||
default=-1, | |||
help="maximum revision in reasoner (default : -1)", | |||
) | |||
parser.add_argument( | |||
"--require-more-revision", | |||
type=int, | |||
default=5, | |||
help="require more revision in reasoner (default : 0)", | |||
) | |||
parser.add_argument( | |||
"--ground", action="store_true", default=False, help="use GroundKB (default: False)" | |||
) | |||
parser.add_argument( | |||
"--max-err", | |||
type=float, | |||
default=1e-10, | |||
help="max tolerance during abductive reasoning (default : 1e-10)", | |||
) | |||
args = parser.parse_args() | |||
### Working with Data | |||
train_data = get_dataset(train=True, get_pseudo_label=True) | |||
test_data = get_dataset(train=False, get_pseudo_label=True) | |||
@@ -112,16 +138,18 @@ def main(): | |||
# Build ABLModel | |||
model = ABLModel(base_model) | |||
### Building the Reasoning Part | |||
# Build knowledge base | |||
if args.ground: | |||
kb = HwfGroundKB() | |||
else: | |||
kb = HwfKB() | |||
# Create reasoner | |||
reasoner = Reasoner(kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision) | |||
reasoner = Reasoner( | |||
kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision | |||
) | |||
### Building Evaluation Metrics | |||
metric_list = [SymbolMetric(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")] | |||
@@ -135,9 +163,15 @@ def main(): | |||
# Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
log_dir = ABLLogger.get_current_instance().log_dir | |||
weights_dir = osp.join(log_dir, "weights") | |||
# Train and Test | |||
bridge.train(train_data, loops=args.loops, segment_size=args.segment_size, save_interval=args.save_interval, save_dir=weights_dir) | |||
bridge.train( | |||
train_data, | |||
loops=args.loops, | |||
segment_size=args.segment_size, | |||
save_interval=args.save_interval, | |||
save_dir=weights_dir, | |||
) | |||
bridge.test(test_data) | |||
@@ -9,10 +9,11 @@ from examples.mnist_add.datasets import get_dataset | |||
from examples.models.nn import LeNet5 | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner | |||
from abl.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.data.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.utils import ABLLogger, print_log | |||
from abl.bridge import SimpleBridge | |||
class AddKB(KBBase): | |||
def __init__(self, pseudo_label_list=list(range(10))): | |||
super().__init__(pseudo_label_list) | |||
@@ -20,6 +21,7 @@ class AddKB(KBBase): | |||
def logic_forward(self, nums): | |||
return sum(nums) | |||
class AddGroundKB(GroundKB): | |||
def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): | |||
super().__init__(pseudo_label_list, GKB_len_list) | |||
@@ -27,36 +29,54 @@ class AddGroundKB(GroundKB): | |||
def logic_forward(self, nums): | |||
return sum(nums) | |||
def main(): | |||
parser = argparse.ArgumentParser(description='MNIST Addition example') | |||
parser.add_argument('--no-cuda', action='store_true', default=False, | |||
help='disables CUDA training') | |||
parser.add_argument('--epochs', type=int, default=1, | |||
help='number of epochs in each learning loop iteration (default : 1)') | |||
parser.add_argument('--lr', type=float, default=1e-3, | |||
help='base model learning rate (default : 0.001)') | |||
parser.add_argument('--alpha', type=float, default=0.9, | |||
help='alpha in RMSprop (default : 0.9)') | |||
parser.add_argument('--batch-size', type=int, default=32, | |||
help='base model batch size (default : 32)') | |||
parser.add_argument('--loops', type=int, default=5, | |||
help='number of loop iterations (default : 5)') | |||
parser.add_argument('--segment_size', type=int or float, default=1/3, | |||
help='segment size (default : 1/3)') | |||
parser.add_argument('--save_interval', type=int, default=1, | |||
help='save interval (default : 1)') | |||
parser.add_argument('--max-revision', type=int or float, default=-1, | |||
help='maximum revision in reasoner (default : -1)') | |||
parser.add_argument('--require-more-revision', type=int, default=5, | |||
help='require more revision in reasoner (default : 0)') | |||
parser = argparse.ArgumentParser(description="MNIST Addition example") | |||
parser.add_argument( | |||
"--no-cuda", action="store_true", default=False, help="disables CUDA training" | |||
) | |||
parser.add_argument( | |||
"--epochs", | |||
type=int, | |||
default=1, | |||
help="number of epochs in each learning loop iteration (default : 1)", | |||
) | |||
parser.add_argument( | |||
"--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | |||
) | |||
parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default : 0.9)") | |||
parser.add_argument( | |||
"--batch-size", type=int, default=32, help="base model batch size (default : 32)" | |||
) | |||
parser.add_argument( | |||
"--loops", type=int, default=5, help="number of loop iterations (default : 5)" | |||
) | |||
parser.add_argument( | |||
"--segment_size", type=int or float, default=1 / 3, help="segment size (default : 1/3)" | |||
) | |||
parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)") | |||
parser.add_argument( | |||
"--max-revision", | |||
type=int or float, | |||
default=-1, | |||
help="maximum revision in reasoner (default : -1)", | |||
) | |||
parser.add_argument( | |||
"--require-more-revision", | |||
type=int, | |||
default=5, | |||
help="require more revision in reasoner (default : 0)", | |||
) | |||
kb_type = parser.add_mutually_exclusive_group() | |||
kb_type.add_argument("--prolog", action="store_true", default=False, | |||
help='use PrologKB (default: False)') | |||
kb_type.add_argument("--ground", action="store_true", default=False, | |||
help='use GroundKB (default: False)') | |||
kb_type.add_argument( | |||
"--prolog", action="store_true", default=False, help="use PrologKB (default: False)" | |||
) | |||
kb_type.add_argument( | |||
"--ground", action="store_true", default=False, help="use GroundKB (default: False)" | |||
) | |||
args = parser.parse_args() | |||
### Working with Data | |||
train_data = get_dataset(train=True, get_pseudo_label=True) | |||
test_data = get_dataset(train=False, get_pseudo_label=True) | |||
@@ -81,7 +101,7 @@ def main(): | |||
# Build ABLModel | |||
model = ABLModel(base_model) | |||
### Building the Reasoning Part | |||
# Build knowledge base | |||
if args.prolog: | |||
@@ -90,9 +110,11 @@ def main(): | |||
kb = AddGroundKB() | |||
else: | |||
kb = AddKB() | |||
# Create reasoner | |||
reasoner = Reasoner(kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision) | |||
reasoner = Reasoner( | |||
kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision | |||
) | |||
### Building Evaluation Metrics | |||
metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] | |||
@@ -106,13 +128,17 @@ def main(): | |||
# Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
log_dir = ABLLogger.get_current_instance().log_dir | |||
weights_dir = osp.join(log_dir, "weights") | |||
# Train and Test | |||
bridge.train(train_data, loops=args.loops, segment_size=args.segment_size, save_interval=args.save_interval, save_dir=weights_dir) | |||
bridge.train( | |||
train_data, | |||
loops=args.loops, | |||
segment_size=args.segment_size, | |||
save_interval=args.save_interval, | |||
save_dir=weights_dir, | |||
) | |||
bridge.test(test_data) | |||
if __name__ == "__main__": | |||
main() |
@@ -13,7 +13,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 1, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -26,7 +26,7 @@ | |||
"from examples.models.nn import LeNet5\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from abl.reasoning import KBBase, Reasoner\n", | |||
"from abl.evaluation import ReasoningMetric, SymbolMetric\n", | |||
"from abl.data.evaluation import ReasoningMetric, SymbolMetric\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"from abl.bridge import SimpleBridge" | |||
] | |||
@@ -42,7 +42,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 2, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -61,7 +61,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 3, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -110,7 +110,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 4, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -170,7 +170,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 5, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -198,7 +198,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 6, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -229,7 +229,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 7, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -245,7 +245,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -261,7 +261,7 @@ | |||
} | |||
], | |||
"source": [ | |||
"from abl.structures import ListData\n", | |||
"from abl.data.structures import ListData\n", | |||
"# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n", | |||
"data_examples = ListData()\n", | |||
"# We use the first 100 data examples in the training set as an illustration\n", | |||
@@ -295,7 +295,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 9, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -319,7 +319,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 10, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -352,7 +352,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 11, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -385,7 +385,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 12, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -404,7 +404,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 13, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -457,7 +457,7 @@ | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.8.13" | |||
"version": "3.8.18" | |||
}, | |||
"orig_nbformat": 4, | |||
"vscode": { | |||
@@ -8,7 +8,7 @@ import openml | |||
from abl.learning import ABLModel | |||
from abl.reasoning import KBBase, Reasoner | |||
from abl.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.data.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.bridge import SimpleBridge | |||
from abl.utils.utils import confidence_dist | |||
from abl.utils import ABLLogger, print_log | |||
@@ -27,23 +27,33 @@ model = ABLModel(rf) | |||
# %% [markdown] | |||
# ### Logic Part | |||
# %% | |||
class ZooKB(KBBase): | |||
def __init__(self): | |||
super().__init__(pseudo_label_list=list(range(7)), use_cache=False) | |||
# Use z3 solver | |||
# Use z3 solver | |||
self.solver = Solver() | |||
# Load information of Zoo dataset | |||
dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False) | |||
X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute) | |||
dataset = openml.datasets.get_dataset( | |||
dataset_id=62, | |||
download_data=False, | |||
download_qualities=False, | |||
download_features_meta_data=False, | |||
) | |||
X, y, categorical_indicator, attribute_names = dataset.get_data( | |||
target=dataset.default_target_attribute | |||
) | |||
self.attribute_names = attribute_names | |||
self.target_names = y.cat.categories.tolist() | |||
# Define variables | |||
for name in self.attribute_names+self.target_names: | |||
exec(f"globals()['{name}'] = Int('{name}')") ## or use dict to create var and modify rules | |||
for name in self.attribute_names + self.target_names: | |||
exec( | |||
f"globals()['{name}'] = Int('{name}')" | |||
) ## or use dict to create var and modify rules | |||
# Define rules | |||
rules = [ | |||
Implies(milk == 1, mammal == 1), | |||
@@ -75,25 +85,27 @@ class ZooKB(KBBase): | |||
Implies(insect == 1, eggs == 1), | |||
Implies(insect == 1, Not(backbone == 1)), | |||
Implies(insect == 1, legs == 6), | |||
Implies(invertebrate == 1, Not(backbone == 1)) | |||
Implies(invertebrate == 1, Not(backbone == 1)), | |||
] | |||
# Define weights and sum of violated weights | |||
self.weights = {rule: 1 for rule in rules} | |||
self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights]) | |||
self.total_violation_weight = Sum( | |||
[If(Not(rule), self.weights[rule], 0) for rule in self.weights] | |||
) | |||
def logic_forward(self, pseudo_label, data_point): | |||
attribute_names, target_names = self.attribute_names, self.target_names | |||
solver = self.solver | |||
total_violation_weight = self.total_violation_weight | |||
pseudo_label, data_point = pseudo_label[0], data_point[0] | |||
self.solver.reset() | |||
for name, value in zip(attribute_names, data_point): | |||
solver.add(eval(f"{name} == {value}")) | |||
for cate, name in zip(self.pseudo_label_list,target_names): | |||
for cate, name in zip(self.pseudo_label_list, target_names): | |||
value = 1 if (cate == pseudo_label) else 0 | |||
solver.add(eval(f"{name} == {value}")) | |||
if solver.check() == sat: | |||
model = solver.model() | |||
total_weight = model.evaluate(total_violation_weight) | |||
@@ -101,7 +113,8 @@ class ZooKB(KBBase): | |||
else: | |||
# No solution found | |||
return 1e10 | |||
def consitency(data_example, candidates, candidate_idxs, reasoning_results): | |||
pred_prob = data_example.pred_prob | |||
model_scores = confidence_dist(pred_prob, candidate_idxs) | |||
@@ -109,51 +122,60 @@ def consitency(data_example, candidates, candidate_idxs, reasoning_results): | |||
scores = model_scores + rule_scores | |||
return scores | |||
kb = ZooKB() | |||
reasoner = Reasoner(kb, dist_func=consitency) | |||
# %% [markdown] | |||
# ### Datasets and Evaluation Metrics | |||
# %% | |||
# Function to load and preprocess the dataset | |||
def load_and_preprocess_dataset(dataset_id): | |||
dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False) | |||
dataset = openml.datasets.get_dataset( | |||
dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False | |||
) | |||
X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute) | |||
# Convert data types | |||
for col in X.select_dtypes(include='bool').columns: | |||
for col in X.select_dtypes(include="bool").columns: | |||
X[col] = X[col].astype(int) | |||
y = y.cat.codes.astype(int) | |||
X, y = X.to_numpy(), y.to_numpy() | |||
return X, y | |||
# Function to split data (one shot) | |||
def split_dataset(X, y, test_size = 0.3): | |||
def split_dataset(X, y, test_size=0.3): | |||
# For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1) | |||
label_indices, unlabel_indices, test_indices = [], [], [] | |||
for class_label in np.unique(y): | |||
idxs = np.where(y == class_label)[0] | |||
np.random.shuffle(idxs) | |||
n_train_unlabel = int((1-test_size)*(len(idxs)-1)) | |||
n_train_unlabel = int((1 - test_size) * (len(idxs) - 1)) | |||
label_indices.append(idxs[0]) | |||
unlabel_indices.extend(idxs[1:1+n_train_unlabel]) | |||
test_indices.extend(idxs[1+n_train_unlabel:]) | |||
unlabel_indices.extend(idxs[1 : 1 + n_train_unlabel]) | |||
test_indices.extend(idxs[1 + n_train_unlabel :]) | |||
X_label, y_label = X[label_indices], y[label_indices] | |||
X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices] | |||
X_test, y_test = X[test_indices], y[test_indices] | |||
return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test | |||
# Load and preprocess the Zoo dataset | |||
X, y = load_and_preprocess_dataset(dataset_id=62) | |||
# Split data into labeled/unlabeled/test data | |||
X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) | |||
# Transform tabluar data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results) | |||
# For tabular data in abl, each example contains a single instance (a row from the dataset). | |||
# For these tabular data examples, the reasoning results are expected to be 0, indicating no rules are violated. | |||
def transform_tab_data(X, y): | |||
return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y)) | |||
label_data = transform_tab_data(X_label, y_label) | |||
test_data = transform_tab_data(X_test, y_test) | |||
train_data = transform_tab_data(X_unlabel, y_unlabel) | |||
@@ -181,9 +203,13 @@ print("------- Test the initial model -----------") | |||
bridge.test(test_data) | |||
print("------- Use ABL to train the model -----------") | |||
# Use ABL to train the model | |||
bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir) | |||
bridge.train( | |||
train_data=train_data, | |||
label_data=label_data, | |||
loops=3, | |||
segment_size=len(X_unlabel), | |||
save_dir=weights_dir, | |||
) | |||
print("------- Test the final model -----------") | |||
# Test the final model | |||
bridge.test(test_data) | |||
@@ -25,7 +25,7 @@ | |||
"from abl.learning import ABLModel\n", | |||
"from examples.zoo.kb import ZooKB\n", | |||
"from abl.reasoning import Reasoner\n", | |||
"from abl.evaluation import ReasoningMetric, SymbolMetric\n", | |||
"from abl.data.evaluation import ReasoningMetric, SymbolMetric\n", | |||
"from abl.utils import ABLLogger, print_log, confidence_dist\n", | |||
"from abl.bridge import SimpleBridge" | |||
] | |||
@@ -56,7 +56,7 @@ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"`train_data` and `test_data` share identical structures: tuples with three components: X (list where each element is a list of two images), gt_pseudo_label (list where each element is a list of two digits, i.e., pseudo-labels) and Y (list where each element is the sum of the two digits). The length and structures of datasets are illustrated as follows.\n", | |||
"`train_data` and `test_data` share identical data.structures: tuples with three components: X (list where each element is a list of two images), gt_pseudo_label (list where each element is a list of two digits, i.e., pseudo-labels) and Y (list where each element is the sum of the two digits). The length and data.structures of datasets are illustrated as follows.\n", | |||
"\n", | |||
"Note: ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model." | |||
] | |||
@@ -6,7 +6,7 @@ import torch.optim as optim | |||
from abl.learning import BasicNN | |||
from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner | |||
from abl.structures import ListData | |||
from abl.data.structures import ListData | |||
class LeNet5(nn.Module): | |||
@@ -202,10 +202,12 @@ def kb_add_prolog(): | |||
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/add.pl") | |||
return kb | |||
@pytest.fixture | |||
def kb_hwf1(): | |||
return HwfKB(max_err=0.1) | |||
@pytest.fixture | |||
def kb_hwf2(): | |||
return HwfKB(max_err=1) | |||