@@ -8,7 +8,7 @@ from ..structures import ListData | |||||
class BaseBridge(metaclass=ABCMeta): | class BaseBridge(metaclass=ABCMeta): | ||||
""" | """ | ||||
A base class for bridging machine learning and reasoning parts. | |||||
A base class for bridging learning and reasoning parts. | |||||
This class provides necessary methods that need to be overridden in subclasses | This class provides necessary methods that need to be overridden in subclasses | ||||
to construct a typical pipeline of Abductive learning (corresponding to ``train``), | to construct a typical pipeline of Abductive learning (corresponding to ``train``), | ||||
@@ -10,7 +10,7 @@ class ReasoningMetric(BaseMetric): | |||||
A metrics class for evaluating the model performance on tasks need reasoning. | A metrics class for evaluating the model performance on tasks need reasoning. | ||||
This class is designed to calculate the accuracy of the reasoing results. Reasoning | This class is designed to calculate the accuracy of the reasoing results. Reasoning | ||||
results are generated by first using the machine learning part to predict pseudo labels | |||||
results are generated by first using the learning part to predict pseudo labels | |||||
and then using a knowledge base (KB) to perform logical reasoning. The reasoning results | and then using a knowledge base (KB) to perform logical reasoning. The reasoning results | ||||
are then compared with the ground truth to calculate the accuracy. | are then compared with the ground truth to calculate the accuracy. | ||||
@@ -75,7 +75,7 @@ To implement this process, the following five steps are necessary: | |||||
2. Build the learning part | 2. Build the learning part | ||||
Build a model that can predict inputs to pseudo labels. | |||||
Build a machine learning model that can predict inputs to pseudo labels. | |||||
Then, use ``ABLModel`` to encapsulate the model. | Then, use ``ABLModel`` to encapsulate the model. | ||||
3. Build the reasoning part | 3. Build the reasoning part | ||||
@@ -89,7 +89,7 @@ To implement this process, the following five steps are necessary: | |||||
Define the metrics for measuring accuracy by inheriting from ``BaseMetric``. | Define the metrics for measuring accuracy by inheriting from ``BaseMetric``. | ||||
5. Bridge machine learning and reasoning | |||||
5. Bridge learning and reasoning | |||||
Use ``SimpleBridge`` to bridge the learning and reasoning part | Use ``SimpleBridge`` to bridge the learning and reasoning part | ||||
for integrated training and testing. | for integrated training and testing. |
@@ -10,30 +10,33 @@ | |||||
Bridge | Bridge | ||||
====== | ====== | ||||
Bridging machine learning and reasoning to train the model is the fundamental idea of Abductive Learning, ABL-Package implements a set of `bridge class <../API/abl.bridge.html>`_ to achieve this. | |||||
Bridging learning and reasoning part to train the model is the fundamental idea of Abductive Learning. ABL-Package implements a set of bridge classes to achieve this. | |||||
``BaseBridge`` is an abstract class with the following initialization parameters: | ``BaseBridge`` is an abstract class with the following initialization parameters: | ||||
- ``model``: an object of type ``ABLModel``. Machine Learning part are wrapped in this object. | |||||
- ``reasoner``: a object of type ``Reasoner``. Reasoning part are wrapped in this object. | |||||
- ``model`` is an object of type ``ABLModel``. Learning part are wrapped in this object. | |||||
- ``reasoner`` is a object of type ``Reasoner``. Reasoning part are wrapped in this object. | |||||
``BaseBridge`` has the following important methods that need to be overridden in subclasses: | ``BaseBridge`` has the following important methods that need to be overridden in subclasses: | ||||
+-----------------------------------+--------------------------------------------------------------------------------------+ | |||||
| Method Signature | Description | | |||||
+===================================+======================================================================================+ | |||||
| predict(data_samples) | Predicts class probabilities and indices for the given data samples. | | |||||
+-----------------------------------+--------------------------------------------------------------------------------------+ | |||||
| abduce_pseudo_label(data_samples) | Abduces pseudo labels for the given data samples. | | |||||
+-----------------------------------+--------------------------------------------------------------------------------------+ | |||||
| idx_to_pseudo_label(data_samples) | Converts indices to pseudo labels using the provided or default mapping. | | |||||
+-----------------------------------+--------------------------------------------------------------------------------------+ | |||||
| pseudo_label_to_idx(data_samples) | Converts pseudo labels to indices using the provided or default remapping. | | |||||
+-----------------------------------+--------------------------------------------------------------------------------------+ | |||||
| train(train_data) | Train the model. | | |||||
+-----------------------------------+--------------------------------------------------------------------------------------+ | |||||
| test(test_data) | Test the model. | | |||||
+-----------------------------------+--------------------------------------------------------------------------------------+ | |||||
+---------------------------------------+----------------------------------------------------+ | |||||
| Method Signature | Description | | |||||
+=======================================+====================================================+ | |||||
| ``predict(data_samples)`` | Predicts class probabilities and indices | | |||||
| | for the given data samples. | | |||||
+---------------------------------------+----------------------------------------------------+ | |||||
| ``abduce_pseudo_label(data_samples)`` | Abduces pseudo labels for the given data samples. | | |||||
+---------------------------------------+----------------------------------------------------+ | |||||
| ``idx_to_pseudo_label(data_samples)`` | Converts indices to pseudo labels using | | |||||
| | the provided or default mapping. | | |||||
+---------------------------------------+----------------------------------------------------+ | |||||
| ``pseudo_label_to_idx(data_samples)`` | Converts pseudo labels to indices | | |||||
| | using the provided or default remapping. | | |||||
+---------------------------------------+----------------------------------------------------+ | |||||
| ``train(train_data)`` | Train the model. | | |||||
+---------------------------------------+----------------------------------------------------+ | |||||
| ``test(test_data)`` | Test the model. | | |||||
+---------------------------------------+----------------------------------------------------+ | |||||
where ``train_data`` and ``test_data`` are both in the form of ``(X, gt_pseudo_label, Y)``. They will be used to construct ``ListData`` instances which are referred to as ``data_samples`` in the ``train`` and ``test`` methods respectively. More details can be found in `preparing datasets <Datasets.html>`_. | where ``train_data`` and ``test_data`` are both in the form of ``(X, gt_pseudo_label, Y)``. They will be used to construct ``ListData`` instances which are referred to as ``data_samples`` in the ``train`` and ``test`` methods respectively. More details can be found in `preparing datasets <Datasets.html>`_. | ||||
@@ -16,13 +16,16 @@ Dataset | |||||
ABL-Package assumes user data to be structured as a tuple, comprising the following three components: | ABL-Package assumes user data to be structured as a tuple, comprising the following three components: | ||||
- ``X``: List[List[Any]] | - ``X``: List[List[Any]] | ||||
A list of instances representing the input data. We refer to each List in ``X`` as an instance and one instance may contain several elements. | |||||
A list of sublists representing the input data. We refer to each sublist in ``X`` as an instance and each instance may contain several elements. | |||||
- ``gt_pseudo_label``: List[List[Any]], optional | - ``gt_pseudo_label``: List[List[Any]], optional | ||||
A list of objects representing the ground truth label of each element in ``X``. It should have the same shape as ``X``. This component is only used to evaluate the performance of the machine learning part but not to train the model. If elements are unlabeled, this component can be ``None``. | |||||
A list of sublists with each sublist representing ground truth pseudo labels for each **element** within an instance of ``X``. | |||||
- ``Y``: List[Any] | - ``Y``: List[Any] | ||||
A list of objects representing the ground truth label of the reasoning result of each instance in ``X``. | |||||
A list representing the ground truth reasoning result for each **instance** in ``X``. | |||||
In the MNIST Addition example, the data used for training looks like: | |||||
.. warning:: | |||||
Each sublist in ``gt_pseudo_label`` should have the same length as the sublist in ``X``. ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model. If the pseudo label of the elements in the datasets are unlabeled, ``gt_pseudo_label`` can be ``None``. | |||||
As an illustration, in the MNIST Addition example, the data used for training are organized as follows: | |||||
.. image:: ../img/Datasets_1.png | .. image:: ../img/Datasets_1.png | ||||
:width: 350px | :width: 350px | ||||
@@ -31,7 +34,7 @@ In the MNIST Addition example, the data used for training looks like: | |||||
Data Structure | Data Structure | ||||
-------------- | -------------- | ||||
In Abductive Learning, there are various types of data in the training and testing process, such as raw data, pseudo label, index of the pseudo label, abduced pseudo label, etc. To make the interface stable and possessing good versatility, ABL-Package uses `abstract data interfaces <../API/abl.structures.html>`_ to encapsulate various data during the implementation of the model. | |||||
In Abductive Learning, there are various types of data in the training and testing process, such as raw data, pseudo label, index of the pseudo label, abduced pseudo label, etc. To enhance the stability and versatility, ABL-Package uses `abstract data interfaces <../API/abl.structures.html>`_ to encapsulate various data during the implementation of the model. | |||||
One of the most commonly used abstract data interface is ``ListData``. Besides orginizing data into tuple, we can also prepare data to be in the form of this data interface. | One of the most commonly used abstract data interface is ``ListData``. Besides orginizing data into tuple, we can also prepare data to be in the form of this data interface. | ||||
@@ -10,11 +10,15 @@ | |||||
Learning Part | Learning Part | ||||
============= | ============= | ||||
Learnig part is constructed by first defining a base machine learning model and then wrap it into an instance of ``ABLModel`` class. | |||||
Learning part is constructed by first defining a base machine learning model and then wrap it into an instance of ``ABLModel`` class. | |||||
The flexibility of ABL package allows the base model to be any machine learning model conforming to the scikit-learn style, which requires implementing the ``fit`` and ``predict`` methods, or a PyTorch-based neural network, provided it has defined the architecture and implemented the ``forward`` method. | |||||
For base model, ABL package allows it to be one of the following forms: | |||||
Typically, base models are trained and make predictions on instance-level data, e.g. single images in the MNIST dataset, and therefore can not directly utilize sample-level data to train and predict, which is not suitable for most neural-symbolic tasks. ABL-Package provides the ``ABLModel`` to solve this problem. This class serves as a unified wrapper for all base models, which enables the learning part to train, test, and predict on sample-level data. The following two parts shows how to construct an ``ABLModel`` from a scikit-learn model and a PyTorch-based neural network, respectively. | |||||
1. Any machine learning model conforming to the scikit-learn style, i.e., models which has implemented the ``fit`` and ``predict`` methods; | |||||
2. A PyTorch-based neural network, provided it has defined the architecture and implemented the ``forward`` method. | |||||
However, base models are typically trained and make predictions on instance-level data, e.g. single images in the MNIST dataset, and therefore can not directly utilize sample-level data to train and predict, which is not suitable for most neural-symbolic tasks. ABL-Package provides the ``ABLModel`` to solve this problem. This class serves as a unified wrapper for all base models, which enables the learning part to train, test, and predict on sample-level data. The following two parts shows how to construct an ``ABLModel`` from a scikit-learn model and a PyTorch-based neural network, respectively. | |||||
For a scikit-learn model, we can directly use the model to create an instance of ``ABLModel``. For example, we can customize our machine learning model by | For a scikit-learn model, we can directly use the model to create an instance of ``ABLModel``. For example, we can customize our machine learning model by | ||||
@@ -43,18 +47,18 @@ For a PyTorch-based neural network, we first need to encapsulate it within a ``B | |||||
Besides ``fit`` and ``predict``, ``BasicNN`` also implements the following methods: | Besides ``fit`` and ``predict``, ``BasicNN`` also implements the following methods: | ||||
+---------------------------+----------------------------------------+ | |||||
| Method | Function | | |||||
+===========================+========================================+ | |||||
| train_epoch(data_loader) | Train the neural network for one epoch.| | |||||
+---------------------------+----------------------------------------+ | |||||
| predict_proba(X) | Predict the class probabilities of X. | | |||||
+---------------------------+----------------------------------------+ | |||||
| score(X, y) | Calculate the accuracy of the model on | | |||||
| | test data. | | |||||
+---------------------------+----------------------------------------+ | |||||
| save(epoch_id, save_path) | Save the model. | | |||||
+---------------------------+----------------------------------------+ | |||||
| load(load_path) | Load the model. | | |||||
+---------------------------+----------------------------------------+ | |||||
+-------------------------------+------------------------------------------+ | |||||
| Method | Function | | |||||
+===============================+==========================================+ | |||||
| ``train_epoch(data_loader)`` | Train the neural network for one epoch. | | |||||
+-------------------------------+------------------------------------------+ | |||||
| ``predict_proba(X)`` | Predict the class probabilities of ``X``.| | |||||
+-------------------------------+------------------------------------------+ | |||||
| ``score(X, y)`` | Calculate the accuracy of the model on | | |||||
| | test data. | | |||||
+-------------------------------+------------------------------------------+ | |||||
| ``save(epoch_id, save_path)`` | Save the model. | | |||||
+-------------------------------+------------------------------------------+ | |||||
| ``load(load_path)`` | Load the model. | | |||||
+-------------------------------+------------------------------------------+ | |||||
@@ -15,7 +15,7 @@ Working with Data | |||||
----------------- | ----------------- | ||||
ABL-Package assumes data to be in the form of ``(X, gt_pseudo_label, Y)`` where ``X`` is the input of the machine learning model, | ABL-Package assumes data to be in the form of ``(X, gt_pseudo_label, Y)`` where ``X`` is the input of the machine learning model, | ||||
``gt_pseudo_label`` is the ground truth label of each element in ``X`` and ``Y`` is the ground truth reasoning result of each instance in ``X``. Note that ``gt_pseudo_label`` is only used to evaluate the performance of the machine learning part but not to train the model. If elements in ``X`` are unlabeled, ``gt_pseudo_label`` can be ``None``. | |||||
``gt_pseudo_label`` is the ground truth label of each element in ``X`` and ``Y`` is the ground truth reasoning result of each instance in ``X``. Note that ``gt_pseudo_label`` is only used to evaluate the performance of the machine learning model but not to train it. If elements in ``X`` are unlabeled, ``gt_pseudo_label`` can be ``None``. | |||||
In the MNIST Addition task, the data loading looks like | In the MNIST Addition task, the data loading looks like | ||||
@@ -23,51 +23,11 @@ In the MNIST Addition task, the data loading looks like | |||||
from examples.mnist_add.datasets.get_mnist_add import get_mnist_add | from examples.mnist_add.datasets.get_mnist_add import get_mnist_add | ||||
# train_data and test_data are all tuples consist of X, gt_pseudo_label and Y. | |||||
# If get_pseudo_label is False, gt_pseudo_label will be None | |||||
# train_data and test_data both consists of multiple (X, gt_pseudo_label, Y) tuples. | |||||
# If get_pseudo_label is False, gt_pseudo_label in each tuple will be None. | |||||
train_data = get_mnist_add(train=True, get_pseudo_label=True) | train_data = get_mnist_add(train=True, get_pseudo_label=True) | ||||
test_data = get_mnist_add(train=False, get_pseudo_label=True) | test_data = get_mnist_add(train=False, get_pseudo_label=True) | ||||
ABL-Package assumes ``X`` to be of type ``List[List[Any]]``, ``gt_pseudo_label`` can be ``None`` or of the type ``List[List[Any]]`` and ``Y`` should be of type ``List[Any]``. The following code shows the structure of the dataset used in MNIST Addition. | |||||
.. code:: python | |||||
def describe_structure(lst): | |||||
if not isinstance(lst, list): | |||||
return type(lst).__name__ | |||||
return [describe_structure(item) for item in lst] | |||||
X, gt_pseudo_label, Y = train_data | |||||
print(f"Length of X List[List[Any]]: {len(X)}") | |||||
print(f"Length of gt_pseudo_label List[List[Any]]: {len(gt_pseudo_label)}") | |||||
print(f"Length of Y List[Any]: {len(Y)}\n") | |||||
structure_X = describe_structure(X[:3]) | |||||
print(f"Structure of X: {structure_X}") | |||||
structure_gt_pseudo_label = describe_structure(gt_pseudo_label[:3]) | |||||
print(f"Structure of gt_pseudo_label: {structure_gt_pseudo_label}") | |||||
structure_Y = describe_structure(Y[:3]) | |||||
print(f"Structure of Y: {structure_Y}\n") | |||||
print(f"Shape of X [C, H, W]: {X[0][0].shape}") | |||||
Out: | |||||
.. code-block:: none | |||||
:class: code-out | |||||
Length of X List[List[Any]]: 30000 | |||||
Length of gt_pseudo_label List[List[Any]]: 30000 | |||||
Length of Y List[Any]: 30000 | |||||
Structure of X: [['Tensor', 'Tensor'], ['Tensor', 'Tensor'], ['Tensor', 'Tensor']] | |||||
Structure of gt_pseudo_label: [['int', 'int'], ['int', 'int'], ['int', 'int']] | |||||
Structure of Y: ['int', 'int', 'int'] | |||||
Shape of X [C, H, W]: torch.Size([1, 28, 28]) | |||||
Read more about `preparing datasets <Datasets.html>`_. | Read more about `preparing datasets <Datasets.html>`_. | ||||
Building the Learning Part | Building the Learning Part | ||||
@@ -96,21 +56,6 @@ To facilitate uniform processing, ABL-Package provides the ``BasicNN`` class to | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
base_model = BasicNN(cls, loss_fn, optimizer, device) | base_model = BasicNN(cls, loss_fn, optimizer, device) | ||||
.. code:: python | |||||
pred_idx = base_model.predict(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)]) | |||||
print(f"Shape of pred_idx : {pred_idx.shape}") | |||||
pred_prob = base_model.predict_proba(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)]) | |||||
print(f"Shape of pred_prob : {pred_prob.shape}") | |||||
Out: | |||||
.. code-block:: none | |||||
:class: code-out | |||||
Shape of pred_idx : (32,) | |||||
Shape of pred_prob : (32, 10) | |||||
Afterward, we wrap the scikit-learn style model, ``base_model``, into an instance of ``ABLModel``. This class serves as a unified wrapper for all base models, facilitating the learning part to train, test, and predict on sample-level data - such as equations in the MNIST Addition task. | Afterward, we wrap the scikit-learn style model, ``base_model``, into an instance of ``ABLModel``. This class serves as a unified wrapper for all base models, facilitating the learning part to train, test, and predict on sample-level data - such as equations in the MNIST Addition task. | ||||
.. code:: python | .. code:: python | ||||
@@ -156,7 +101,6 @@ knowledge base and the prediction from the learning part. | |||||
Read more about `building the reasoning part <Reasoning.html>`_. | Read more about `building the reasoning part <Reasoning.html>`_. | ||||
Building Evaluation Metrics | Building Evaluation Metrics | ||||
--------------------------- | --------------------------- | ||||
@@ -188,43 +132,4 @@ Finally, we proceed with training and testing. | |||||
bridge.train(train_data, loops=5, segment_size=1/3) | bridge.train(train_data, loops=5, segment_size=1/3) | ||||
bridge.test(test_data) | bridge.test(test_data) | ||||
Training log would be similar to this: | |||||
.. code-block:: none | |||||
:class: code-out | |||||
abl - INFO - Abductive Learning on the MNIST Add example. | |||||
abl - INFO - loop(train) [1/5] segment(train) [1/3] | |||||
abl - INFO - model loss: 1.91761 | |||||
abl - INFO - loop(train) [1/5] segment(train) [2/3] | |||||
abl - INFO - model loss: 1.59485 | |||||
abl - INFO - loop(train) [1/5] segment(train) [3/3] | |||||
abl - INFO - model loss: 1.33183 | |||||
abl - INFO - Evaluation start: loop(val) [1] | |||||
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.450 mnist_add/reasoning_accuracy: 0.237 | |||||
abl - INFO - Saving model: loop(save) [1] | |||||
abl - INFO - Checkpoints will be saved to results/work_dir/weights/model_checkpoint_loop_1.pth | |||||
abl - INFO - loop(train) [2/5] segment(train) [1/3] | |||||
abl - INFO - model loss: 1.00664 | |||||
abl - INFO - loop(train) [2/5] segment(train) [2/3] | |||||
abl - INFO - model loss: 0.52233 | |||||
abl - INFO - loop(train) [2/5] segment(train) [3/3] | |||||
abl - INFO - model loss: 0.11282 | |||||
abl - INFO - Evaluation start: loop(val) [2] | |||||
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.976 mnist_add/reasoning_accuracy: 0.954 | |||||
abl - INFO - Saving model: loop(save) [2] | |||||
abl - INFO - Checkpoints will be saved to results/work_dir/weights/model_checkpoint_loop_2.pth | |||||
... | |||||
abl - INFO - loop(train) [5/5] segment(train) [1/3] | |||||
abl - INFO - model loss: 0.04030 | |||||
abl - INFO - loop(train) [5/5] segment(train) [2/3] | |||||
abl - INFO - model loss: 0.03859 | |||||
abl - INFO - loop(train) [5/5] segment(train) [3/3] | |||||
abl - INFO - model loss: 0.03423 | |||||
abl - INFO - Evaluation start: loop(val) [5] | |||||
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.992 mnist_add/reasoning_accuracy: 0.984 | |||||
abl - INFO - Saving model: loop(save) [5] | |||||
abl - INFO - Checkpoints will be saved to results/work_dir/weights/model_checkpoint_loop_5.pth | |||||
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.975 | |||||
Read more about `bridging machine learning and reasoning <Bridge.html>`_. | Read more about `bridging machine learning and reasoning <Bridge.html>`_. |
@@ -39,9 +39,9 @@ For the user-built KB from `KBBase` (an inherited subclass), it's only | |||||
required to pass the ``pseudo_label_list`` parameter in the ``__init__`` function | required to pass the ``pseudo_label_list`` parameter in the ``__init__`` function | ||||
and override the ``logic_forward`` function: | and override the ``logic_forward`` function: | ||||
- **pseudo_label_list** is the list of possible pseudo labels (also, | |||||
- ``pseudo_label_list`` is the list of possible pseudo labels (also, | |||||
the output of the machine learning model). | the output of the machine learning model). | ||||
- **logic_forward** defines how to perform (deductive) reasoning, | |||||
- ``logic_forward`` defines how to perform (deductive) reasoning, | |||||
i.e. matching each pseudo label sample (often consisting of multiple | i.e. matching each pseudo label sample (often consisting of multiple | ||||
pseudo labels) to its reasoning result. | pseudo labels) to its reasoning result. | ||||
@@ -76,11 +76,10 @@ and (deductive) reasoning in ``add_kb`` would be: | |||||
print(f"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.") | print(f"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.") | ||||
Out: | Out: | ||||
.. code:: none | |||||
:class: code-out | |||||
.. code:: none | |||||
:class: code-out | |||||
Reasoning result of pseudo label sample [1, 2] is 3 | |||||
Reasoning result of pseudo label sample [1, 2] is 3 | |||||
.. _other-par: | .. _other-par: | ||||
@@ -90,17 +89,17 @@ Other optional parameters | |||||
We can also pass the following parameters in the ``__init__`` function when building our | We can also pass the following parameters in the ``__init__`` function when building our | ||||
knowledge base: | knowledge base: | ||||
- **max_err** (float, optional), specifying the upper tolerance limit | |||||
- ``max_err`` (float, optional), specifying the upper tolerance limit | |||||
when comparing the similarity between a pseudo label sample's reasoning result | when comparing the similarity between a pseudo label sample's reasoning result | ||||
and the ground truth during abductive reasoning. This is only | and the ground truth during abductive reasoning. This is only | ||||
applicable when the reasoning result is of a numerical type. This is | applicable when the reasoning result is of a numerical type. This is | ||||
particularly relevant for regression problems where exact matches | particularly relevant for regression problems where exact matches | ||||
might not be feasible. Defaults to 1e-10. See :ref:`an example <kb-abd-2>`. | might not be feasible. Defaults to 1e-10. See :ref:`an example <kb-abd-2>`. | ||||
- **use_cache** (bool, optional), indicating whether to use cache to store | |||||
- ``use_cache`` (bool, optional), indicating whether to use cache to store | |||||
previous candidates (pseudo label samples generated from abductive reasoning) | previous candidates (pseudo label samples generated from abductive reasoning) | ||||
to speed up subsequent abductive reasoning operations. Defaults to True. | to speed up subsequent abductive reasoning operations. Defaults to True. | ||||
For more information of abductive reasoning, please refer to :ref:`this <kb-abd>`. | For more information of abductive reasoning, please refer to :ref:`this <kb-abd>`. | ||||
- **cache_size** (int, optional), specifying the maximum cache | |||||
- ``cache_size`` (int, optional), specifying the maximum cache | |||||
size. This is only operational when ``use_cache`` is set to True. | size. This is only operational when ``use_cache`` is set to True. | ||||
Defaults to 4096. | Defaults to 4096. | ||||
@@ -173,7 +172,7 @@ override the ``logic_forward`` function, and are allowed to pass other | |||||
:ref:`optional parameters <other-par>`. Additionally, we are required pass the | :ref:`optional parameters <other-par>`. Additionally, we are required pass the | ||||
``GKB_len_list`` parameter in the ``__init__`` function. | ``GKB_len_list`` parameter in the ``__init__`` function. | ||||
- **GKB_len_list** is the list of possible lengths for a pseudo label sample. | |||||
- ``GKB_len_list`` is the list of possible lengths for a pseudo label sample. | |||||
After that, other operations, including auto-construction of GKB, and | After that, other operations, including auto-construction of GKB, and | ||||
how to perform abductive reasoning, will be **automatically** set up. | how to perform abductive reasoning, will be **automatically** set up. | ||||
@@ -216,13 +215,13 @@ previously identified pseudo label sample. | |||||
``abduce_candidates(pseudo_label, y, max_revision_num, require_more_revision)`` | ``abduce_candidates(pseudo_label, y, max_revision_num, require_more_revision)`` | ||||
for performing abductive reasoning, where the parameters are: | for performing abductive reasoning, where the parameters are: | ||||
- **pseudo_label**, the pseudo label sample to be revised by abductive | |||||
- ``pseudo_label``, the pseudo label sample to be revised by abductive | |||||
reasoning, usually generated by the learning part. | reasoning, usually generated by the learning part. | ||||
- **y**, the ground truth of the reasoning result for the sample. The | |||||
- ``y``, the ground truth of the reasoning result for the sample. The | |||||
returned candidates should be compatible with it. | returned candidates should be compatible with it. | ||||
- **max_revision_num**, an int value specifying the upper limit on the | |||||
- ``max_revision_num``, an int value specifying the upper limit on the | |||||
number of revised labels for each sample. | number of revised labels for each sample. | ||||
- **require_more_revision**, an int value specifying additional number | |||||
- ``require_more_revision``, an int value specifying additional number | |||||
of revisions permitted beyond the minimum required. (e.g., If we set | of revisions permitted beyond the minimum required. (e.g., If we set | ||||
it to 0, even if ``max_revision_num`` is set to a high value, the | it to 0, even if ``max_revision_num`` is set to a high value, the | ||||
method will only output candidates with the minimum possible | method will only output candidates with the minimum possible | ||||
@@ -291,19 +290,19 @@ example for MNIST Addition, the reasoner definition would be: | |||||
When instantiating, besides the required knowledge base, we may also | When instantiating, besides the required knowledge base, we may also | ||||
specify: | specify: | ||||
- **max_revision** (int or float, optional), specifies the upper limit | |||||
- ``max_revision`` (int or float, optional), specifies the upper limit | |||||
on the number of revisions for each sample when performing | on the number of revisions for each sample when performing | ||||
:ref:`abductive reasoning in the knowledge base <kb-abd>`. If float, denotes the | :ref:`abductive reasoning in the knowledge base <kb-abd>`. If float, denotes the | ||||
fraction of the total length that can be revised. A value of -1 | fraction of the total length that can be revised. A value of -1 | ||||
implies no restriction on the number of revisions. Defaults to -1. | implies no restriction on the number of revisions. Defaults to -1. | ||||
- **require_more_revision** (int, optional), Specifies additional | |||||
- ``require_more_revision`` (int, optional), Specifies additional | |||||
number of revisions permitted beyond the minimum required when | number of revisions permitted beyond the minimum required when | ||||
performing :ref:`abductive reasoning in the knowledge base <kb-abd>`. Defaults to | performing :ref:`abductive reasoning in the knowledge base <kb-abd>`. Defaults to | ||||
0. | 0. | ||||
- **use_zoopt** (bool, optional), indicating whether to use the `ZOOpt library <https://github.com/polixir/ZOOpt>`_, | |||||
- ``use_zoopt`` (bool, optional), indicating whether to use the `ZOOpt library <https://github.com/polixir/ZOOpt>`_, | |||||
which is a library for zeroth-order optimization that can be used to | which is a library for zeroth-order optimization that can be used to | ||||
accelerate consistency minimization. Defaults to False. | accelerate consistency minimization. Defaults to False. | ||||
- **dist_func** (str, optional), specifying the distance function to be | |||||
- ``dist_func`` (str, optional), specifying the distance function to be | |||||
used when determining consistency between your prediction and | used when determining consistency between your prediction and | ||||
candidate returned from knowledge base. Valid options include | candidate returned from knowledge base. Valid options include | ||||
“confidence” (default) and “hamming”. For “confidence”, it calculates | “confidence” (default) and “hamming”. For “confidence”, it calculates | ||||
@@ -354,11 +353,10 @@ the output would differ for each sample: | |||||
print(f"The outputs for sample1 and sample2 are {candidate1} and {candidate2}, respectively.") | print(f"The outputs for sample1 and sample2 are {candidate1} and {candidate2}, respectively.") | ||||
Out: | Out: | ||||
.. code:: none | |||||
:class: code-out | |||||
.. code:: none | |||||
:class: code-out | |||||
The outputs for sample1 and sample2 are [1,7] and [7,1], respectively. | |||||
The outputs for sample1 and sample2 are [1,7] and [7,1], respectively. | |||||
Specifically, as mentioned before, ``confidence`` calculates the distance between the data | Specifically, as mentioned before, ``confidence`` calculates the distance between the data | ||||
sample and candidates based on the confidence derived from the predicted probability. | sample and candidates based on the confidence derived from the predicted probability. | ||||