
[ENH] transfer structures and evaluation to data

pull/1/head
Gao Enhao · 1 year ago · commit 9b3a683e33
33 changed files with 286 additions and 192 deletions
  1.  +1  -2   .coveragerc
  2.  +2  -4   abl/__init__.py
  3.  +1  -1   abl/bridge/base_bridge.py
  4.  +2  -2   abl/bridge/simple_bridge.py
  5.  +2  -0   abl/data/__init__.py
  6.  +0  -0   abl/data/evaluation/__init__.py
  7.  +1  -1   abl/data/evaluation/base_metric.py
  8.  +1  -1   abl/data/evaluation/reasoning_metric.py
  9.  +0  -0   abl/data/evaluation/symbol_metric.py
  10. +0  -0   abl/data/structures/__init__.py
  11. +3  -2   abl/data/structures/base_data_element.py
  12. +6  -7   abl/data/structures/list_data.py
  13. +1  -1   abl/learning/abl_model.py
  14. +25 -21  abl/reasoning/reasoner.py
  15. +18 -0   docs/API/abl.data.rst
  16. +0  -7   docs/API/abl.evaluation.rst
  17. +0  -7   docs/API/abl.structures.rst
  18. +1  -1   docs/Examples/HED.rst
  19. +2  -2   docs/Examples/HWF.rst
  20. +2  -2   docs/Examples/MNISTAdd.rst
  21. +3  -3   docs/Intro/Datasets.rst
  22. +1  -1   docs/Intro/Evaluation.rst
  23. +1  -1   docs/Intro/Quick-Start.rst
  24. +1  -2   docs/index.rst
  25. +2  -2   examples/hed/bridge.py
  26. +1  -1   examples/hed/hed.ipynb
  27. +3  -3   examples/hwf/hwf.ipynb
  28. +75 -41  examples/hwf/main.py
  29. +60 -34  examples/mnist_add/main.py
  30. +16 -16  examples/mnist_add/mnist_add.ipynb
  31. +50 -24  examples/zoo/main.py
  32. +2  -2   examples/zoo/zoo.ipynb
  33. +3  -1   tests/conftest.py

+1 -2  .coveragerc

@@ -8,8 +8,7 @@ omit =
*/abl/__init__.py
abl/bridge/__init__.py
abl/dataset/__init__.py
abl/evaluation/__init__.py
abl/data/__init__.py
abl/learning/__init__.py
abl/reasoning/__init__.py
abl/structures/__init__.py
abl/utils/__init__.py

+2 -4  abl/__init__.py

@@ -1,11 +1,9 @@
from . import bridge, dataset, evaluation, learning, reasoning, structures, utils
from . import bridge, data, learning, reasoning, utils

__all__ = [
"bridge",
"dataset",
"evaluation",
"data",
"learning",
"reasoning",
"structures",
"utils",
]

+1 -1  abl/bridge/base_bridge.py

@@ -3,7 +3,7 @@ from typing import Any, List, Optional, Tuple, Union

from ..learning import ABLModel
from ..reasoning import Reasoner
from ..structures import ListData
from ..data.structures import ListData


class BaseBridge(metaclass=ABCMeta):


+2 -2  abl/bridge/simple_bridge.py

@@ -3,10 +3,10 @@ from typing import Any, List, Optional, Tuple, Union

from numpy import ndarray

from ..evaluation import BaseMetric
from ..data.evaluation import BaseMetric
from ..learning import ABLModel
from ..reasoning import Reasoner
from ..structures import ListData
from ..data.structures import ListData
from ..utils import print_log
from .base_bridge import BaseBridge



+2 -0  abl/data/__init__.py

@@ -0,0 +1,2 @@
from .evaluation import *
from .structures import *
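
The new package-level __init__ simply re-exports the two moved subpackages. A hedged sketch of the import surface after this commit, assuming the star imports above expose the public names such as ListData and BaseMetric:

# Canonical new locations after the move:
from abl.data.structures import ListData
from abl.data.evaluation import BaseMetric
# Also reachable one level up via the star re-exports (assumed):
from abl.data import ListData as ListDataAlias
assert ListDataAlias is ListData  # both names bind the same class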

abl/evaluation/__init__.py → abl/data/evaluation/__init__.py


abl/evaluation/base_metric.py → abl/data/evaluation/base_metric.py

@@ -3,7 +3,7 @@ from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional

from ..structures import ListData
from ..utils import print_log
from ...utils import print_log


class BaseMetric(metaclass=ABCMeta):

abl/evaluation/reasoning_metric.py → abl/data/evaluation/reasoning_metric.py

@@ -1,6 +1,6 @@
from typing import Optional

from ..reasoning import KBBase
from ...reasoning import KBBase
from ..structures import ListData
from .base_metric import BaseMetric


abl/evaluation/symbol_metric.py → abl/data/evaluation/symbol_metric.py


abl/structures/__init__.py → abl/data/structures/__init__.py


abl/structures/base_data_element.py → abl/data/structures/base_data_element.py

@@ -5,8 +5,9 @@ from typing import Any, Iterator, Optional, Tuple, Type, Union
import numpy as np
import torch


# Modified from
# https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py
class BaseDataElement:
"""A base data interface that supports Tensor-like and dict-like
operations.
@@ -73,7 +74,7 @@ class BaseDataElement:

Examples:
>>> import torch
>>> from mmengine.structures import BaseDataElement
>>> from abl.data.structures import BaseDataElement
>>> gt_instances = BaseDataElement()
>>> bboxes = torch.rand((5, 4))
>>> scores = torch.rand((5,))
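
A hedged, self-contained sketch of the dict-like and Tensor-like interface described in the docstring above, assuming it mirrors the mmengine class this file is modified from:

import torch
from abl.data.structures import BaseDataElement

gt_instances = BaseDataElement()
gt_instances.bboxes = torch.rand((5, 4))  # set data fields by attribute
gt_instances.scores = torch.rand((5,))
print("bboxes" in gt_instances)           # dict-like membership test -> True
print(gt_instances.scores.shape)          # field access -> torch.Size([5])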

abl/structures/list_data.py → abl/data/structures/list_data.py

@@ -1,12 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import List, Union

import numpy as np
import torch

from ..utils import flatten as flatten_list
from ..utils import to_hashable
from ...utils import flatten as flatten_list
from ...utils import to_hashable
from .base_data_element import BaseDataElement

BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
@@ -16,19 +15,19 @@ IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndar


# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
class ListData(BaseDataElement):
"""
Data structure for example-level data.

Subclass of :class:`BaseDataElement`. All values in `data_fields`
should have the same length. This design refers to
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py

ListData supports `index` and `slice` for data fields. The type of a value in a data field can be either `None` or a `list` of base data structures such as `torch.Tensor`, `numpy.ndarray`, `list`, `str` and `tuple`.

Examples:
>>> from abl.structures import ListData
>>> from abl.data.structures import ListData
>>> import numpy as np
>>> import torch
>>> data_examples = ListData()
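
A hedged sketch of the index/slice behavior described in the docstring above; the field names (X, pred_pseudo_label) are illustrative, and the exact return types are assumed to follow the instance-data design referenced there:

import torch
from abl.data.structures import ListData

data_examples = ListData()
data_examples.X = [torch.randn(2) for _ in range(3)]        # one entry per example
data_examples.pred_pseudo_label = [[1, 2], [3, 4], [5, 6]]  # same length as X

subset = data_examples[0:2]           # slicing keeps all data fields aligned
print(len(subset.pred_pseudo_label))  # -> 2 (a ListData is returned, assumed)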

+1 -1  abl/learning/abl_model.py

@@ -1,7 +1,7 @@
import pickle
from typing import Any, Dict

from ..structures import ListData
from ..data.structures import ListData
from ..utils import reform_list




+25 -21  abl/reasoning/reasoner.py

@@ -5,7 +5,7 @@ import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter, Solution

from ..reasoning import KBBase
from ..structures import ListData
from ..data.structures import ListData
from ..utils.utils import confidence_dist, hamming_dist


@@ -19,18 +19,18 @@ class Reasoner:
The knowledge base to be used for reasoning.
dist_func : Union[str, Callable], optional
The distance function used to determine the cost list between each
candidate and the given prediction. The cost is also referred to as a consistency
measure, wherein the candidate with lowest cost is selected as the final
abduced label. It can be either a string representing a predefined distance
function or a callable function. The available predefined distance functions:
'hamming' | 'confidence'. 'hamming': directly calculates the Hamming
distance between the predicted pseudo-label in the data example and each
candidate, 'confidence': calculates the distance between the prediction
and each candidate based on confidence derived from the predicted probability
in the data example. The callable function should have the signature
dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element
in this cost list should be a numerical value representing the cost for each
candidate, and the list should have the same length as candidates.
Defaults to 'confidence'.
idx_to_label : Optional[dict], optional
A mapping from index in the base model to label. If not provided, a default
@@ -64,7 +64,9 @@ class Reasoner:
self.require_more_revision = require_more_revision

if idx_to_label is None:
self.idx_to_label = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
self.idx_to_label = {
index: label for index, label in enumerate(self.kb.pseudo_label_list)
}
else:
self._check_valid_idx_to_label(idx_to_label)
self.idx_to_label = idx_to_label
@@ -80,7 +82,9 @@ class Reasoner:
elif callable(dist_func):
params = inspect.signature(dist_func).parameters.values()
if len(params) != 4:
raise ValueError(f"User-defined dist_func must have exactly four parameters, but got {len(params)}.")
raise ValueError(
f"User-defined dist_func must have exactly four parameters, but got {len(params)}."
)
return
else:
raise TypeError(
@@ -289,18 +293,18 @@ class Reasoner:
solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)
revision_idx = np.where(solution.get_x() != 0)[0]
candidates, reasoning_results = self.kb.revise_at_idx(
pseudo_label=data_example.pred_pseudo_label,
y=data_example.Y,
x=data_example.X,
revision_idx=revision_idx
pseudo_label=data_example.pred_pseudo_label,
y=data_example.Y,
x=data_example.X,
revision_idx=revision_idx,
)
else:
candidates, reasoning_results = self.kb.abduce_candidates(
pseudo_label=data_example.pred_pseudo_label,
y=data_example.Y,
y=data_example.Y,
x=data_example.X,
max_revision_num=max_revision_num,
require_more_revision=self.require_more_revision
require_more_revision=self.require_more_revision,
)

candidate = self._get_one_candidate(data_example, candidates, reasoning_results)
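
The docstring above pins down the callable dist_func contract: exactly four parameters, returning one numerical cost per candidate. A hedged sketch of a user-defined distance function honoring that contract (hamming_dist is the helper imported in this file; the call pattern follows the built-in "hamming" option):

from abl.utils.utils import hamming_dist

def my_dist(data_example, candidates, candidate_idxs, reasoning_results):
    # one cost per candidate: Hamming distance between the predicted
    # pseudo-label and each candidate
    return hamming_dist(data_example.pred_pseudo_label, candidates)

# reasoner = Reasoner(kb, dist_func=my_dist)  # kb: any KBBase instance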


+18 -0  docs/API/abl.data.rst

@@ -0,0 +1,18 @@
abl.data
===================

Data Structure
--------------

.. autoclass:: abl.data.structures.ListData
:members:
:undoc-members:
:show-inheritance:

Evaluation Metric
-----------------

.. automodule:: abl.data.evaluation
:members:
:undoc-members:
:show-inheritance:

+0 -7  docs/API/abl.evaluation.rst

@@ -1,7 +0,0 @@
abl.evaluation
==================

.. automodule:: abl.evaluation
:members:
:undoc-members:
:show-inheritance:

+0 -7  docs/API/abl.structures.rst

@@ -1,7 +0,0 @@
abl.structures
==================

.. autoclass:: abl.structures.ListData
:members:
:undoc-members:
:show-inheritance:

+1 -1  docs/Examples/HED.rst

@@ -32,7 +32,7 @@ model.
from examples.models.nn import SymbolNet
from abl.learning import ABLModel, BasicNN
from examples.hed.reasoning import HedKB, HedReasoner
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.data.evaluation import ReasoningMetric, SymbolMetric
from abl.utils import ABLLogger, print_log
from examples.hed.bridge import HedBridge



+2 -2  docs/Examples/HWF.rst

@@ -30,7 +30,7 @@ machine learning model.
from examples.models.nn import SymbolNet
from abl.learning import ABLModel, BasicNN
from abl.reasoning import KBBase, Reasoner
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.data.evaluation import ReasoningMetric, SymbolMetric
from abl.utils import ABLLogger, print_log
from abl.bridge import SimpleBridge

@@ -225,7 +225,7 @@ examples.

.. code:: ipython3

from abl.structures import ListData
from abl.data.structures import ListData
# ListData is a data structure provided by ABL-Package that can be used to organize data examples
data_examples = ListData()
# We use the 1001st and 3001st data examples in the training set as an illustration


+2 -2  docs/Examples/MNISTAdd.rst

@@ -28,7 +28,7 @@ machine learning model.
from examples.models.nn import LeNet5
from abl.learning import ABLModel, BasicNN
from abl.reasoning import KBBase, Reasoner
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.data.evaluation import ReasoningMetric, SymbolMetric
from abl.utils import ABLLogger, print_log
from abl.bridge import SimpleBridge

@@ -191,7 +191,7 @@ examples.

.. code:: ipython3

from abl.structures import ListData
from abl.data.structures import ListData
# ListData is a data structure provided by ABL-Package that can be used to organize data examples
data_examples = ListData()
# We use the first 100 data examples in the training set as an illustration


+3 -3  docs/Intro/Datasets.rst

@@ -16,7 +16,7 @@ In this section, we will look at the datasets and data structures in ABL-Package

# Import necessary libraries and modules
import torch
from abl.structures import ListData
from abl.data.structures import ListData

Dataset
-------
@@ -53,11 +53,11 @@ As an illustration, in the MNIST Addition example, the data used for training ar
Data Structure
--------------

Besides the user-provided dataset, various forms of data are utilized and dynamically generated throughout the training and testing process of the Abductive Learning framework. Examples include raw data, predicted pseudo-labels, abduced pseudo-labels, pseudo-label indices, and so on. To manage this diversity and ensure a stable, versatile interface, ABL-Package employs `abstract data interfaces <../API/abl.structures.html>`_ to encapsulate the different forms of data used throughout the learning process.
Besides the user-provided dataset, various forms of data are utilized and dynamically generated throughout the training and testing process of the Abductive Learning framework. Examples include raw data, predicted pseudo-labels, abduced pseudo-labels, pseudo-label indices, and so on. To manage this diversity and ensure a stable, versatile interface, ABL-Package employs `abstract data interfaces <../API/abl.data.html>`_ to encapsulate the different forms of data used throughout the learning process.

``BaseDataElement`` is the base class for all abstract data interfaces. Inherited from ``BaseDataElement``, ``ListData`` is the most commonly used abstract data interface in ABL-Package. As the fundamental data structure, ``ListData`` implements commonly used data manipulation methods and is responsible for transferring data between various components of ABL, ensuring that stages such as prediction, training, and abductive reasoning can utilize ``ListData`` as a unified input format.

Before proceeding to other stages, user-provided datasets are first converted into ``ListData``. For flexibility, ABL-Package also allows users to directly supply data in ``ListData`` format, which similarly requires the inclusion of three attributes: ``X``, ``gt_pseudo_label``, and ``Y``. The following code shows the basic usage of ``ListData``. More information can be found in the `API documentation <../API/abl.structures.html>`_.
Before proceeding to other stages, user-provided datasets are first converted into ``ListData``. For flexibility, ABL-Package also allows users to directly supply data in ``ListData`` format, which similarly requires the inclusion of three attributes: ``X``, ``gt_pseudo_label``, and ``Y``. The following code shows the basic usage of ``ListData``. More information can be found in the `API documentation <../API/abl.data.html>`_.

.. code-block:: python



+1 -1  docs/Intro/Evaluation.rst

@@ -15,7 +15,7 @@ In this section, we will look at how to build evaluation metrics.
.. code:: python

# Import necessary modules
from abl.evaluation import BaseMetric, SymbolMetric, ReasoningMetric
from abl.data.evaluation import BaseMetric, SymbolMetric, ReasoningMetric

ABL-Package separates the evaluation process from model training and testing as an independent class, ``BaseMetric``. The training and testing processes are implemented in the ``BaseBridge`` class, so metrics are used by this class and its sub-classes. After building a ``bridge`` with a list of ``BaseMetric`` instances, these metrics will be used by the ``bridge.valid`` method to evaluate model performance during training and testing.
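
A hedged sketch of a custom metric under the interface this section describes, assuming the mmengine-style contract (process() accumulates per-example results into self.results and compute_metrics() reduces them; the class and field names are illustrative):

from abl.data.evaluation import BaseMetric

class ExampleAccuracy(BaseMetric):  # hypothetical metric, for illustration
    def process(self, data_examples):
        # record 1 for every example whose pseudo-label sequence is exact
        for pred, gt in zip(data_examples.pred_pseudo_label,
                            data_examples.gt_pseudo_label):
            self.results.append(int(pred == gt))

    def compute_metrics(self):
        return {"example_accuracy": sum(self.results) / len(self.results)}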



+1 -1  docs/Intro/Quick-Start.rst

@@ -111,7 +111,7 @@ ABL-Package provides two basic metrics, namely ``SymbolMetric`` and ``ReasoningM

.. code:: python

from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.data.evaluation import ReasoningMetric, SymbolMetric

metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]



+1 -2  docs/index.rst

@@ -35,9 +35,8 @@
API/abl.dataset
API/abl.learning
API/abl.reasoning
API/abl.evaluation
API/abl.bridge
API/abl.structures
API/abl.data
API/abl.utils

.. toctree::


+2 -2  examples/hed/bridge.py

@@ -5,10 +5,10 @@ import torch

from abl.bridge import SimpleBridge
from abl.dataset import RegressionDataset
from abl.evaluation import BaseMetric
from abl.data.evaluation import BaseMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import Reasoner
from abl.structures import ListData
from abl.data.structures import ListData
from abl.utils import print_log
from examples.hed.datasets import get_pretrain_data
from examples.hed.utils import InfiniteSampler, gen_mappings


+1 -1  examples/hed/hed.ipynb

@@ -26,7 +26,7 @@
"from examples.models.nn import SymbolNet\n",
"from abl.learning import ABLModel, BasicNN\n",
"from examples.hed.reasoning import HedKB, HedReasoner\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.data.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.utils import ABLLogger, print_log\n",
"from examples.hed.bridge import HedBridge"
]


+3 -3  examples/hwf/hwf.ipynb

@@ -27,7 +27,7 @@
"from examples.models.nn import SymbolNet\n",
"from abl.learning import ABLModel, BasicNN\n",
"from abl.reasoning import KBBase, Reasoner\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.data.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.utils import ABLLogger, print_log\n",
"from abl.bridge import SimpleBridge"
]
@@ -55,7 +55,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Both `train_data` and `test_data` have the same structures: tuples with three components: X (list where each element is a list of images), gt_pseudo_label (list where each element is a list of symbols, i.e., pseudo-labels) and Y (list where each element is the computed result). The length and structures of datasets are illustrated as follows.\n",
"Both `train_data` and `test_data` have the same data.structures: tuples with three components: X (list where each element is a list of images), gt_pseudo_label (list where each element is a list of symbols, i.e., pseudo-labels) and Y (list where each element is the computed result). The length and data.structures of datasets are illustrated as follows.\n",
"\n",
"Note: ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model."
]
@@ -299,7 +299,7 @@
}
],
"source": [
"from abl.structures import ListData\n",
"from abl.data.structures import ListData\n",
"# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n",
"data_examples = ListData()\n",
"# We use the first 1001st and 3001st data examples in the training set as an illustration\n",


+75 -41  examples/hwf/main.py

@@ -10,15 +10,17 @@ from examples.hwf.datasets import get_dataset
from examples.models.nn import SymbolNet
from abl.learning import ABLModel, BasicNN
from abl.reasoning import KBBase, GroundKB, Reasoner
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.data.evaluation import ReasoningMetric, SymbolMetric
from abl.utils import ABLLogger, print_log
from abl.bridge import SimpleBridge


class HwfKB(KBBase):
def __init__(self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
max_err=1e-10,
):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
max_err=1e-10,
):
super().__init__(pseudo_label_list, max_err)

def _valid_candidate(self, formula):
@@ -30,19 +32,21 @@ class HwfKB(KBBase):
if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
return False
return True
# Implement the deduction function
def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
return eval("".join(formula))


class HwfGroundKB(GroundKB):
def __init__(self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
GKB_len_list=[1,3,5,7],
max_err=1e-10,
):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
GKB_len_list=[1, 3, 5, 7],
max_err=1e-10,
):
super().__init__(pseudo_label_list, GKB_len_list, max_err)

def _valid_candidate(self, formula):
@@ -54,40 +58,62 @@ class HwfGroundKB(GroundKB):
if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
return False
return True
# Implement the deduction function
def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
return eval("".join(formula))


def main():
parser = argparse.ArgumentParser(description='MNIST Addition example')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--epochs', type=int, default=3,
help='number of epochs in each learning loop iteration (default : 3)')
parser.add_argument('--lr', type=float, default=1e-3,
help='base model learning rate (default : 0.001)')
parser.add_argument('--batch-size', type=int, default=128,
help='base model batch size (default : 128)')
parser.add_argument('--loops', type=int, default=5,
help='number of loop iterations (default : 5)')
parser.add_argument('--segment_size', type=int or float, default=1000,
help='segment size (default : 1000)')
parser.add_argument('--save_interval', type=int, default=1,
help='save interval (default : 1)')
parser.add_argument('--max-revision', type=int or float, default=-1,
help='maximum revision in reasoner (default : -1)')
parser.add_argument('--require-more-revision', type=int, default=5,
help='require more revision in reasoner (default : 0)')
parser.add_argument("--ground", action="store_true", default=False,
help='use GroundKB (default: False)')
parser.add_argument("--max-err", type=float, default=1e-10,
help='max tolerance during abductive reasoning (default : 1e-10)')
parser = argparse.ArgumentParser(description="Handwritten Formula example")
parser.add_argument(
"--no-cuda", action="store_true", default=False, help="disables CUDA training"
)
parser.add_argument(
"--epochs",
type=int,
default=3,
help="number of epochs in each learning loop iteration (default : 3)",
)
parser.add_argument(
"--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
)
parser.add_argument(
"--batch-size", type=int, default=128, help="base model batch size (default : 128)"
)
parser.add_argument(
"--loops", type=int, default=5, help="number of loop iterations (default : 5)"
)
parser.add_argument(
"--segment_size", type=int or float, default=1000, help="segment size (default : 1000)"
)
parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
parser.add_argument(
"--max-revision",
type=int or float,
default=-1,
help="maximum revision in reasoner (default : -1)",
)
parser.add_argument(
"--require-more-revision",
type=int,
default=5,
help="require more revision in reasoner (default : 0)",
)
parser.add_argument(
"--ground", action="store_true", default=False, help="use GroundKB (default: False)"
)
parser.add_argument(
"--max-err",
type=float,
default=1e-10,
help="max tolerance during abductive reasoning (default : 1e-10)",
)

args = parser.parse_args()
### Working with Data
train_data = get_dataset(train=True, get_pseudo_label=True)
test_data = get_dataset(train=False, get_pseudo_label=True)
@@ -112,16 +138,18 @@ def main():

# Build ABLModel
model = ABLModel(base_model)
### Building the Reasoning Part
# Build knowledge base
if args.ground:
kb = HwfGroundKB()
else:
kb = HwfKB()
# Create reasoner
reasoner = Reasoner(kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision)
reasoner = Reasoner(
kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
)

### Building Evaluation Metrics
metric_list = [SymbolMetric(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]
@@ -135,9 +163,15 @@ def main():
# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
# Train and Test
bridge.train(train_data, loops=args.loops, segment_size=args.segment_size, save_interval=args.save_interval, save_dir=weights_dir)
bridge.train(
train_data,
loops=args.loops,
segment_size=args.segment_size,
save_interval=args.save_interval,
save_dir=weights_dir,
)
bridge.test(test_data)
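
A hedged usage sketch of the knowledge base defined above: logic_forward evaluates a well-formed formula and returns np.inf for candidates rejected by _valid_candidate (importing from examples.hwf.main is assumed safe, since main() only runs under the __main__ guard):

from examples.hwf.main import HwfKB

kb = HwfKB()
print(kb.logic_forward(["1", "+", "2"]))  # -> 3, via eval("1+2")
print(kb.logic_forward(["1", "1", "2"]))  # -> inf, position 1 must be an operator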




+60 -34  examples/mnist_add/main.py

@@ -9,10 +9,11 @@ from examples.mnist_add.datasets import get_dataset
from examples.models.nn import LeNet5
from abl.learning import ABLModel, BasicNN
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.data.evaluation import ReasoningMetric, SymbolMetric
from abl.utils import ABLLogger, print_log
from abl.bridge import SimpleBridge


class AddKB(KBBase):
def __init__(self, pseudo_label_list=list(range(10))):
super().__init__(pseudo_label_list)
@@ -20,6 +21,7 @@ class AddKB(KBBase):
def logic_forward(self, nums):
return sum(nums)


class AddGroundKB(GroundKB):
def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list)
@@ -27,36 +29,54 @@ class AddGroundKB(GroundKB):
def logic_forward(self, nums):
return sum(nums)


def main():
parser = argparse.ArgumentParser(description='MNIST Addition example')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--epochs', type=int, default=1,
help='number of epochs in each learning loop iteration (default : 1)')
parser.add_argument('--lr', type=float, default=1e-3,
help='base model learning rate (default : 0.001)')
parser.add_argument('--alpha', type=float, default=0.9,
help='alpha in RMSprop (default : 0.9)')
parser.add_argument('--batch-size', type=int, default=32,
help='base model batch size (default : 32)')
parser.add_argument('--loops', type=int, default=5,
help='number of loop iterations (default : 5)')
parser.add_argument('--segment_size', type=int or float, default=1/3,
help='segment size (default : 1/3)')
parser.add_argument('--save_interval', type=int, default=1,
help='save interval (default : 1)')
parser.add_argument('--max-revision', type=int or float, default=-1,
help='maximum revision in reasoner (default : -1)')
parser.add_argument('--require-more-revision', type=int, default=5,
help='require more revision in reasoner (default : 0)')
parser = argparse.ArgumentParser(description="MNIST Addition example")
parser.add_argument(
"--no-cuda", action="store_true", default=False, help="disables CUDA training"
)
parser.add_argument(
"--epochs",
type=int,
default=1,
help="number of epochs in each learning loop iteration (default : 1)",
)
parser.add_argument(
"--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
)
parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default : 0.9)")
parser.add_argument(
"--batch-size", type=int, default=32, help="base model batch size (default : 32)"
)
parser.add_argument(
"--loops", type=int, default=5, help="number of loop iterations (default : 5)"
)
parser.add_argument(
"--segment_size", type=int or float, default=1 / 3, help="segment size (default : 1/3)"
)
parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
parser.add_argument(
"--max-revision",
type=int or float,
default=-1,
help="maximum revision in reasoner (default : -1)",
)
parser.add_argument(
"--require-more-revision",
type=int,
default=5,
help="require more revision in reasoner (default : 0)",
)
kb_type = parser.add_mutually_exclusive_group()
kb_type.add_argument("--prolog", action="store_true", default=False,
help='use PrologKB (default: False)')
kb_type.add_argument("--ground", action="store_true", default=False,
help='use GroundKB (default: False)')
kb_type.add_argument(
"--prolog", action="store_true", default=False, help="use PrologKB (default: False)"
)
kb_type.add_argument(
"--ground", action="store_true", default=False, help="use GroundKB (default: False)"
)

args = parser.parse_args()
### Working with Data
train_data = get_dataset(train=True, get_pseudo_label=True)
test_data = get_dataset(train=False, get_pseudo_label=True)
@@ -81,7 +101,7 @@ def main():

# Build ABLModel
model = ABLModel(base_model)
### Building the Reasoning Part
# Build knowledge base
if args.prolog:
@@ -90,9 +110,11 @@ def main():
kb = AddGroundKB()
else:
kb = AddKB()
# Create reasoner
reasoner = Reasoner(kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision)
reasoner = Reasoner(
kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
)

### Building Evaluation Metrics
metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]
@@ -106,13 +128,17 @@ def main():
# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
# Train and Test
bridge.train(train_data, loops=args.loops, segment_size=args.segment_size, save_interval=args.save_interval, save_dir=weights_dir)
bridge.train(
train_data,
loops=args.loops,
segment_size=args.segment_size,
save_interval=args.save_interval,
save_dir=weights_dir,
)
bridge.test(test_data)



if __name__ == "__main__":
main()

+16 -16  examples/mnist_add/mnist_add.ipynb

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -26,7 +26,7 @@
"from examples.models.nn import LeNet5\n",
"from abl.learning import ABLModel, BasicNN\n",
"from abl.reasoning import KBBase, Reasoner\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.data.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.utils import ABLLogger, print_log\n",
"from abl.bridge import SimpleBridge"
]
@@ -42,7 +42,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -61,7 +61,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -110,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -170,7 +170,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -198,7 +198,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -229,7 +229,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -245,7 +245,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -261,7 +261,7 @@
}
],
"source": [
"from abl.structures import ListData\n",
"from abl.data.structures import ListData\n",
"# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n",
"data_examples = ListData()\n",
"# We use the first 100 data examples in the training set as an illustration\n",
@@ -295,7 +295,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -319,7 +319,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -352,7 +352,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -385,7 +385,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -404,7 +404,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -457,7 +457,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.18"
},
"orig_nbformat": 4,
"vscode": {


+50 -24  examples/zoo/main.py

@@ -8,7 +8,7 @@ import openml

from abl.learning import ABLModel
from abl.reasoning import KBBase, Reasoner
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.data.evaluation import ReasoningMetric, SymbolMetric
from abl.bridge import SimpleBridge
from abl.utils.utils import confidence_dist
from abl.utils import ABLLogger, print_log
@@ -27,23 +27,33 @@ model = ABLModel(rf)
# %% [markdown]
# ### Logic Part


# %%
class ZooKB(KBBase):
def __init__(self):
super().__init__(pseudo_label_list=list(range(7)), use_cache=False)
# Use z3 solver
self.solver = Solver()

# Load information of Zoo dataset
dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False)
X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
dataset = openml.datasets.get_dataset(
dataset_id=62,
download_data=False,
download_qualities=False,
download_features_meta_data=False,
)
X, y, categorical_indicator, attribute_names = dataset.get_data(
target=dataset.default_target_attribute
)
self.attribute_names = attribute_names
self.target_names = y.cat.categories.tolist()
# Define variables
for name in self.attribute_names+self.target_names:
exec(f"globals()['{name}'] = Int('{name}')") ## or use dict to create var and modify rules
for name in self.attribute_names + self.target_names:
exec(
f"globals()['{name}'] = Int('{name}')"
) ## or use dict to create var and modify rules
# Define rules
rules = [
Implies(milk == 1, mammal == 1),
@@ -75,25 +85,27 @@ class ZooKB(KBBase):
Implies(insect == 1, eggs == 1),
Implies(insect == 1, Not(backbone == 1)),
Implies(insect == 1, legs == 6),
Implies(invertebrate == 1, Not(backbone == 1))
Implies(invertebrate == 1, Not(backbone == 1)),
]
# Define weights and sum of violated weights
self.weights = {rule: 1 for rule in rules}
self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights])
self.total_violation_weight = Sum(
[If(Not(rule), self.weights[rule], 0) for rule in self.weights]
)

def logic_forward(self, pseudo_label, data_point):
attribute_names, target_names = self.attribute_names, self.target_names
solver = self.solver
total_violation_weight = self.total_violation_weight
pseudo_label, data_point = pseudo_label[0], data_point[0]
self.solver.reset()
for name, value in zip(attribute_names, data_point):
solver.add(eval(f"{name} == {value}"))
for cate, name in zip(self.pseudo_label_list,target_names):
for cate, name in zip(self.pseudo_label_list, target_names):
value = 1 if (cate == pseudo_label) else 0
solver.add(eval(f"{name} == {value}"))
if solver.check() == sat:
model = solver.model()
total_weight = model.evaluate(total_violation_weight)
@@ -101,7 +113,8 @@ class ZooKB(KBBase):
else:
# No solution found
return 1e10


def consistency(data_example, candidates, candidate_idxs, reasoning_results):
pred_prob = data_example.pred_prob
model_scores = confidence_dist(pred_prob, candidate_idxs)
@@ -109,51 +122,60 @@ def consistency(data_example, candidates, candidate_idxs, reasoning_results):
scores = model_scores + rule_scores
return scores


kb = ZooKB()
reasoner = Reasoner(kb, dist_func=consistency)

# %% [markdown]
# ### Datasets and Evaluation Metrics


# %%
# Function to load and preprocess the dataset
def load_and_preprocess_dataset(dataset_id):
dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)
dataset = openml.datasets.get_dataset(
dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False
)
X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
# Convert data types
for col in X.select_dtypes(include='bool').columns:
for col in X.select_dtypes(include="bool").columns:
X[col] = X[col].astype(int)
y = y.cat.codes.astype(int)
X, y = X.to_numpy(), y.to_numpy()
return X, y


# Function to split data (one shot)
def split_dataset(X, y, test_size = 0.3):
def split_dataset(X, y, test_size=0.3):
# For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1)
label_indices, unlabel_indices, test_indices = [], [], []
for class_label in np.unique(y):
idxs = np.where(y == class_label)[0]
np.random.shuffle(idxs)
n_train_unlabel = int((1-test_size)*(len(idxs)-1))
n_train_unlabel = int((1 - test_size) * (len(idxs) - 1))
label_indices.append(idxs[0])
unlabel_indices.extend(idxs[1:1+n_train_unlabel])
test_indices.extend(idxs[1+n_train_unlabel:])
unlabel_indices.extend(idxs[1 : 1 + n_train_unlabel])
test_indices.extend(idxs[1 + n_train_unlabel :])
X_label, y_label = X[label_indices], y[label_indices]
X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]
X_test, y_test = X[test_indices], y[test_indices]
return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test


# Load and preprocess the Zoo dataset
X, y = load_and_preprocess_dataset(dataset_id=62)

# Split data into labeled/unlabeled/test data
X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)


# Transform tabular data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results)
# For tabular data in abl, each example contains a single instance (a row from the dataset).
# For these tabular data examples, the reasoning results are expected to be 0, indicating no rules are violated.
def transform_tab_data(X, y):
return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))


label_data = transform_tab_data(X_label, y_label)
test_data = transform_tab_data(X_test, y_test)
train_data = transform_tab_data(X_unlabel, y_unlabel)
@@ -181,9 +203,13 @@ print("------- Test the initial model -----------")
bridge.test(test_data)
print("------- Use ABL to train the model -----------")
# Use ABL to train the model
bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)
bridge.train(
train_data=train_data,
label_data=label_data,
loops=3,
segment_size=len(X_unlabel),
save_dir=weights_dir,
)
print("------- Test the final model -----------")
# Test the final model
bridge.test(test_data)
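
ZooKB's scoring above hinges on one z3 idiom: wrap each rule in If(Not(rule), weight, 0) and Sum the results, so a satisfying model reports the total weight of violated rules. A minimal standalone sketch of just that idiom (variable names are illustrative):

from z3 import If, Implies, Int, Not, Solver, Sum, sat

milk, mammal = Int("milk"), Int("mammal")
rules = [Implies(milk == 1, mammal == 1)]
weights = {rule: 1 for rule in rules}
total_violation_weight = Sum([If(Not(rule), weights[rule], 0) for rule in weights])

solver = Solver()
solver.add(milk == 1, mammal == 0)  # an assignment that breaks the rule
if solver.check() == sat:
    print(solver.model().evaluate(total_violation_weight))  # -> 1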



+2 -2  examples/zoo/zoo.ipynb

@@ -25,7 +25,7 @@
"from abl.learning import ABLModel\n",
"from examples.zoo.kb import ZooKB\n",
"from abl.reasoning import Reasoner\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.data.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.utils import ABLLogger, print_log, confidence_dist\n",
"from abl.bridge import SimpleBridge"
]
@@ -56,7 +56,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"`train_data` and `test_data` share identical structures: tuples with three components: X (list where each element is a list of two images), gt_pseudo_label (list where each element is a list of two digits, i.e., pseudo-labels) and Y (list where each element is the sum of the two digits). The length and structures of datasets are illustrated as follows.\n",
"`train_data` and `test_data` share identical data.structures: tuples with three components: X (list where each element is a list of two images), gt_pseudo_label (list where each element is a list of two digits, i.e., pseudo-labels) and Y (list where each element is the sum of the two digits). The length and data.structures of datasets are illustrated as follows.\n",
"\n",
"Note: ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model."
]


+3 -1  tests/conftest.py

@@ -6,7 +6,7 @@ import torch.optim as optim

from abl.learning import BasicNN
from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner
from abl.structures import ListData
from abl.data.structures import ListData


class LeNet5(nn.Module):
@@ -202,10 +202,12 @@ def kb_add_prolog():
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/add.pl")
return kb


@pytest.fixture
def kb_hwf1():
return HwfKB(max_err=0.1)


@pytest.fixture
def kb_hwf2():
return HwfKB(max_err=1)

