
[FIX] enable GitHub Actions to run successfully

Gao Enhao committed 1 year ago · parent commit 9c888183f8 · pull/1/head
48 changed files with 426 additions and 944 deletions
  1. .coveragerc (+8 -1)
  2. .github/workflows/build-and-test.yaml (+1 -1)
  3. .github/workflows/lint.yaml (+2 -2)
  4. abl/__init__.py (+11 -2)
  5. abl/__version__.py (+1 -1)
  6. abl/bridge/__init__.py (+3 -1)
  7. abl/bridge/base_bridge.py (+3 -16)
  8. abl/bridge/simple_bridge.py (+5 -4)
  9. abl/dataset/__init__.py (+7 -0)
 10. abl/dataset/bridge_dataset.py (+6 -2)
 11. abl/dataset/classification_dataset.py (+5 -5)
 12. abl/dataset/prediction_dataset.py (+3 -1)
 13. abl/dataset/regression_dataset.py (+2 -2)
 14. abl/evaluation/__init__.py (+2 -0)
 15. abl/evaluation/base_metric.py (+12 -12)
 16. abl/evaluation/symbol_metric.py (+3 -3)
 17. abl/learning/__init__.py (+3 -1)
 18. abl/learning/abl_model.py (+9 -6)
 19. abl/learning/basic_nn.py (+7 -4)
 20. abl/reasoning/__init__.py (+4 -2)
 21. abl/reasoning/kb.py (+29 -30)
 22. abl/reasoning/reasoner.py (+6 -10)
 23. abl/structures/__init__.py (+3 -1)
 24. abl/structures/base_data_element.py (+4 -10)
 25. abl/structures/list_data.py (+1 -2)
 26. abl/utils/__init__.py (+21 -1)
 27. abl/utils/cache.py (+0 -6)
 28. abl/utils/logger.py (+15 -7)
 29. abl/utils/manager.py (+12 -12)
 30. abl/utils/utils.py (+0 -57)
 31. build_tools/requirements.txt (+3 -0)
 32. docs/Examples/MNISTAdd.rst (+1 -6)
 33. docs/conf.py (+5 -7)
 34. examples/hed/datasets/get_hed.py (+3 -3)
 35. examples/hed/framework_hed.py (+0 -388)
 36. examples/hed/hed_bridge.py (+15 -14)
 37. examples/hed/hed_example.ipynb (+8 -6)
 38. examples/hed/hed_knn_example.py (+0 -69)
 39. examples/hed/hed_tmp.py (+0 -159)
 40. examples/hed/utils.py (+4 -5)
 41. examples/hwf/datasets/get_hwf.py (+2 -4)
 42. examples/hwf/hwf_example.ipynb (+12 -8)
 43. examples/mnist_add/mnist_add_example.ipynb (+1 -1)
 44. examples/models/nn.py (+2 -4)
 45. setup.py (+9 -3)
 46. tests/conftest.py (+65 -22)
 47. tests/test_abl_model.py (+2 -1)
 48. tests/test_reasoning.py (+106 -42)

.coveragerc (+8 -1)

@@ -5,4 +5,11 @@ show_missing = True
 disable_warnings = include-ignored
 include = */abl/*
 omit =
-    */abl/__init__.py
+    */abl/__init__.py
+    abl/bridge/__init__.py
+    abl/dataset/__init__.py
+    abl/evaluation/__init__.py
+    abl/learning/__init__.py
+    abl/reasoning/__init__.py
+    abl/structures/__init__.py
+    abl/utils/__init__.py

.github/workflows/build-and-test.yaml (+1 -1)

@@ -24,7 +24,7 @@ jobs:
       - name: Install package dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -r ./requirements.txt
+          pip install -r build_tools/requirements.txt
       - name: Run tests
         run: |
           pytest --cov-config=.coveragerc --cov-report=xml --cov=abl ./tests
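This points at the new build_tools/requirements.txt added later in this commit, which pulls in the base requirements plus pytest and pytest-cov, so the test job has everything it needs installed.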


.github/workflows/lint.yaml (+2 -2)

@@ -20,5 +20,5 @@ jobs:
       - name: flake8 Lint
         uses: py-actions/flake8@v2
         with:
-          max-line-length: "110"
-          plugins: "flake8-bugbear flake8-black"
+          max-line-length: "100"
+          args: --ignore=E203,W503
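For context: E203 (whitespace before ':') and W503 (line break before binary operator) are the two pycodestyle checks that disagree with Black's formatting style, so Black-formatted projects conventionally ignore them instead of running the flake8-black plugin in CI.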

abl/__init__.py (+11 -2)

@@ -1,2 +1,11 @@
-from .learning import abl_model, basic_nn
-from .reasoning import reasoner, kb
+from . import bridge, dataset, evaluation, learning, reasoning, structures, utils
+
+__all__ = [
+    "bridge",
+    "dataset",
+    "evaluation",
+    "learning",
+    "reasoning",
+    "structures",
+    "utils",
+]

abl/__version__.py (+1 -1)

@@ -1,3 +1,3 @@
 VERSION = (0, 0, 1)

-__version__ = ".".join(map(str, VERSION))
+__version__ = ".".join(map(str, VERSION))

abl/bridge/__init__.py (+3 -1)

@@ -1,2 +1,4 @@
 from .base_bridge import BaseBridge
-from .simple_bridge import SimpleBridge
+from .simple_bridge import SimpleBridge
+
+__all__ = ["BaseBridge", "SimpleBridge"]

abl/bridge/base_bridge.py (+3 -16)

@@ -12,53 +12,40 @@ class BaseBridge(metaclass=ABCMeta):
     def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
         if not isinstance(model, ABLModel):
             raise TypeError(
-                "Expected an instance of ABLModel, but received type: {}".format(
-                    type(model)
-                )
+                "Expected an instance of ABLModel, but received type: {}".format(type(model))
             )
         if not isinstance(reasoner, Reasoner):
             raise TypeError(
-                "Expected an instance of Reasoner, but received type: {}".format(
-                    type(reasoner)
-                )
+                "Expected an instance of Reasoner, but received type: {}".format(type(reasoner))
             )

         self.model = model
         self.reasoner = reasoner

     @abstractmethod
-    def predict(
-        self, data_samples: ListData
-    ) -> Tuple[List[List[Any]], List[List[Any]]]:
+    def predict(self, data_samples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
         """Placeholder for predict labels from input."""
-        pass

     @abstractmethod
     def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
         """Placeholder for abduce pseudo labels."""
-        pass

     @abstractmethod
     def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
         """Placeholder for map label space to symbol space."""
-        pass

     @abstractmethod
     def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
         """Placeholder for map symbol space to label space."""
-        pass

     @abstractmethod
     def train(self, train_data: Union[ListData, DataSet]):
         """Placeholder for train loop of ABductive Learning."""
-        pass

     @abstractmethod
     def valid(self, valid_data: Union[ListData, DataSet]) -> None:
         """Placeholder for model test."""
-        pass

     @abstractmethod
     def test(self, test_data: Union[ListData, DataSet]) -> None:
         """Placeholder for model validation."""
-        pass

abl/bridge/simple_bridge.py (+5 -4)

@@ -1,5 +1,5 @@
 import os.path as osp
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

 from numpy import ndarray

@@ -32,8 +32,7 @@ class SimpleBridge(BaseBridge):
     def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
         pred_idx = data_samples.pred_idx
         data_samples.pred_pseudo_label = [
-            [self.reasoner.mapping[_idx] for _idx in sub_list]
-            for sub_list in pred_idx
+            [self.reasoner.mapping[_idx] for _idx in sub_list] for sub_list in pred_idx
         ]
         return data_samples.pred_pseudo_label

@@ -81,7 +80,9 @@ class SimpleBridge(BaseBridge):
             loss = self.model.train(sub_data_samples)

             print_log(
-                f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}",
+                f"loop(train) [{loop + 1}/{loops}] segment(train) \
+[{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] \
+model loss is {loss:.5f}",
                 logger="current",
             )



abl/dataset/__init__.py (+7 -0)

@@ -2,3 +2,10 @@ from .bridge_dataset import BridgeDataset
 from .classification_dataset import ClassificationDataset
 from .prediction_dataset import PredictionDataset
 from .regression_dataset import RegressionDataset
+
+__all__ = [
+    "BridgeDataset",
+    "ClassificationDataset",
+    "PredictionDataset",
+    "RegressionDataset",
+]

abl/dataset/bridge_dataset.py (+6 -2)

@@ -13,11 +13,15 @@ class BridgeDataset(Dataset):
     gt_pseudo_label : List[List[Any]], optional
         A list of objects representing the ground truth label of each element in ``X``.
     Y : List[Any]
-        A list of objects representing the ground truth of the reasoning result of each instance in ``X``.
+        A list of objects representing the ground truth of the reasoning result of
+        each instance in ``X``.
     """

     def __init__(
-        self, X: List[List[Any]], gt_pseudo_label: Optional[List[List[Any]]], Y: List[Any]
+        self,
+        X: List[List[Any]],
+        gt_pseudo_label: Optional[List[List[Any]]],
+        Y: List[Any],
     ):
         if (not isinstance(X, list)) or (not isinstance(Y, list)):
             raise ValueError("X and Y should be of type list.")

abl/dataset/classification_dataset.py (+5 -5)

@@ -15,16 +15,16 @@ class ClassificationDataset(Dataset):
     Y : List[int]
         The target data.
     transform : Callable[..., Any], optional
-        A function/transform that takes in an object and returns a transformed version. Defaults to None.
+        A function/transform that takes in an object and returns a transformed version.
+        Defaults to None.
     """
-    def __init__(
-        self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None
-    ):
+
+    def __init__(self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None):
         if (not isinstance(X, list)) or (not isinstance(Y, list)):
             raise ValueError("X and Y should be of type list.")
         if len(X) != len(Y):
             raise ValueError("Length of X and Y must be equal.")
         self.X = X
         self.Y = torch.LongTensor(Y)
         self.transform = transform

abl/dataset/prediction_dataset.py (+3 -1)

@@ -13,8 +13,10 @@ class PredictionDataset(Dataset):
     X : List[Any]
         The input data.
     transform : Callable[..., Any], optional
-        A function/transform that takes in an object and returns a transformed version. Defaults to None.
+        A function/transform that takes in an object and returns a transformed version.
+        Defaults to None.
     """
+
     def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
         if not isinstance(X, list):
             raise ValueError("X should be of type list.")


abl/dataset/regression_dataset.py (+2 -2)

@@ -1,6 +1,5 @@
 from typing import Any, List, Tuple

-import torch
 from torch.utils.data import Dataset

@@ -15,12 +14,13 @@ class RegressionDataset(Dataset):
     Y : List[Any]
         A list of objects representing the output data.
     """
+
     def __init__(self, X: List[Any], Y: List[Any]):
         if (not isinstance(X, list)) or (not isinstance(Y, list)):
             raise ValueError("X and Y should be of type list.")
         if len(X) != len(Y):
             raise ValueError("Length of X and Y must be equal.")
         self.X = X
         self.Y = Y



abl/evaluation/__init__.py (+2 -0)

@@ -1,3 +1,5 @@
 from .base_metric import BaseMetric
 from .semantics_metric import SemanticsMetric
 from .symbol_metric import SymbolMetric
+
+__all__ = ["BaseMetric", "SemanticsMetric", "SymbolMetric"]

abl/evaluation/base_metric.py (+12 -12)

@@ -20,8 +20,10 @@ class BaseMetric(metaclass=ABCMeta):
         will be used instead. Default: None
     """

-    def __init__(self,
-                 prefix: Optional[str] = None,) -> None:
+    def __init__(
+        self,
+        prefix: Optional[str] = None,
+    ) -> None:
         self.results: List[Any] = []
         self.prefix = prefix or self.default_prefix

@@ -65,20 +67,18 @@ class BaseMetric(metaclass=ABCMeta):
         """
         if len(self.results) == 0:
             print_log(
-                f'{self.__class__.__name__} got empty `self.results`. Please '
-                'ensure that the processed results are properly added into '
-                '`self.results` in `process` method.',
-                logger='current',
-                level=logging.WARNING)
+                f"{self.__class__.__name__} got empty `self.results`. Please "
+                "ensure that the processed results are properly added into "
+                "`self.results` in `process` method.",
+                logger="current",
+                level=logging.WARNING,
+            )

         metrics = self.compute_metrics(self.results)
         # Add prefix to metric names
         if self.prefix:
-            metrics = {
-                '/'.join((self.prefix, k)): v
-                for k, v in metrics.items()
-            }
+            metrics = {"/".join((self.prefix, k)): v for k, v in metrics.items()}

         # reset the results list
         self.results.clear()
-        return metrics
+        return metrics

abl/evaluation/symbol_metric.py (+3 -3)

@@ -14,15 +14,15 @@ class SymbolMetric(BaseMetric):

         if not len(pred_pseudo_label) == len(gt_pseudo_label):
             raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal")
         for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
             correct_num = 0
             for pred_symbol, symbol in zip(pred_z, z):
                 if pred_symbol == symbol:
                     correct_num += 1
             self.results.append(correct_num / len(z))

     def compute_metrics(self, results: list) -> dict:
         metrics = dict()
         metrics["character_accuracy"] = sum(results) / len(results)
-        return metrics
+        return metrics
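Taken together, the two evaluation files above define the metric contract: `process` appends a per-sample score to `self.results`, and `compute_metrics` reduces the collected scores into a dict. A minimal sketch of a custom metric against that contract (the `process` signature here is assumed from the usage shown in this diff, not copied from the repository):

    from abl.evaluation import BaseMetric

    class ExactMatchMetric(BaseMetric):
        default_prefix = "exact"  # used when no `prefix` is passed to __init__

        def process(self, data_samples) -> None:
            # Score 1.0 only when the whole predicted pseudo label sequence matches.
            for pred_z, z in zip(data_samples.pred_pseudo_label, data_samples.gt_pseudo_label):
                self.results.append(1.0 if list(pred_z) == list(z) else 0.0)

        def compute_metrics(self, results: list) -> dict:
            return {"match_accuracy": sum(results) / len(results)}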

abl/learning/__init__.py (+3 -1)

@@ -1,2 +1,4 @@
 from .abl_model import ABLModel
-from .basic_nn import BasicNN
+from .basic_nn import BasicNN
+
+__all__ = ["ABLModel", "BasicNN"]

abl/learning/abl_model.py (+9 -6)

@@ -58,7 +58,8 @@ class ABLModel:
         Parameters
         ----------
         data_samples : ListData
-            A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`.
+            A batch of data to train on, which typically contains the data, `X`, and the
+            corresponding labels, `abduced_idx`.

         Returns
         -------
@@ -68,7 +69,7 @@ class ABLModel:
         data_X = data_samples.flatten("X")
         data_y = data_samples.flatten("abduced_idx")
         return self.base_model.fit(X=data_X, y=data_y)

     def valid(self, data_samples: ListData) -> float:
         """
         Validate the model on the given data.
@@ -76,7 +77,8 @@ class ABLModel:
         Parameters
         ----------
         data_samples : ListData
-            A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`.
+            A batch of data to train on, which typically contains the data, `X`,
+            and the corresponding labels, `abduced_idx`.

         Returns
         -------
@@ -94,7 +96,7 @@ class ABLModel:
             method = getattr(model, operation)
             method(*args, **kwargs)
         else:
-            if not f"{operation}_path" in kwargs.keys():
+            if f"{operation}_path" not in kwargs.keys():
                 raise ValueError(f"'{operation}_path' should not be None")
             else:
                 try:
@@ -104,9 +106,10 @@ class ABLModel:
                     elif operation == "load":
                         with open(kwargs["load_path"], "rb") as file:
                             self.base_model = pickle.load(file)
-                except:
+                except (OSError, pickle.PickleError):
                     raise NotImplementedError(
-                        f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed."
+                        f"{type(model).__name__} object doesn't have the {operation} method \
+and the default pickle-based {operation} method failed."
                     )

     def save(self, *args, **kwargs) -> None:
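The last hunk replaces a bare `except:` with `except (OSError, pickle.PickleError)`, so that only file I/O and pickling failures are converted into the NotImplementedError; a bare `except:` would also swallow `KeyboardInterrupt` and genuine bugs. A minimal sketch of the save/load dispatch pattern this code follows (a reconstruction for illustration with a hypothetical helper name, not the repository's exact code):

    import pickle

    def run_model_operation(model, operation: str, **kwargs):
        """Hypothetical helper mirroring ABLModel's save/load dispatch."""
        if hasattr(model, operation):
            # Prefer the model's own save()/load() method when it exists.
            getattr(model, operation)(**kwargs)
            return model
        path = kwargs.get(f"{operation}_path")
        if path is None:
            raise ValueError(f"'{operation}_path' should not be None")
        try:
            if operation == "save":
                with open(path, "wb") as f:
                    pickle.dump(model, f)
            elif operation == "load":
                with open(path, "rb") as f:
                    model = pickle.load(f)
        except (OSError, pickle.PickleError) as exc:
            # Only I/O and pickling failures land here; KeyboardInterrupt etc. propagate.
            raise NotImplementedError(
                f"{type(model).__name__} has no {operation} method and the "
                f"default pickle-based {operation} failed."
            ) from exc
        return model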


abl/learning/basic_nn.py (+7 -4)

@@ -1,5 +1,5 @@
-import os
 import logging
+import os
 from typing import Any, Callable, List, Optional, T, Tuple

 import numpy

@@ -23,7 +23,8 @@ class BasicNN:
     optimizer : torch.optim.Optimizer
         The optimizer used for training.
     device : torch.device, optional
-        The device on which the model will be trained or used for prediction, by default torch.device("cpu").
+        The device on which the model will be trained or used for prediction,
+        by default torch.device("cpu").
     batch_size : int, optional
         The batch size used for training, by default 32.
     num_epochs : int, optional
@@ -37,9 +38,11 @@ class BasicNN:
     save_dir : Optional[str], optional
         The directory in which to save the model during training, by default None.
     train_transform : Callable[..., Any], optional
-        A function/transform that takes in an object and returns a transformed version used in the `fit` and `train_epoch` methods, by default None.
+        A function/transform that takes in an object and returns a transformed version used
+        in the `fit` and `train_epoch` methods, by default None.
     test_transform : Callable[..., Any], optional
-        A function/transform that takes in an object and returns a transformed version in the `predict`, `predict_proba` and `score` methods, , by default None.
+        A function/transform that takes in an object and returns a transformed version in the
+        `predict`, `predict_proba` and `score` methods, , by default None.
     collate_fn : Callable[[List[T]], Any], optional
         The function used to collate data, by default None.
     """


abl/reasoning/__init__.py (+4 -2)

@@ -1,2 +1,4 @@
-from .kb import KBBase, GroundKB, PrologKB
-from .reasoner import Reasoner
+from .kb import GroundKB, KBBase, PrologKB
+from .reasoner import Reasoner
+
+__all__ = ["KBBase", "GroundKB", "PrologKB", "Reasoner"]

abl/reasoning/kb.py (+29 -30)

@@ -1,16 +1,15 @@
-from abc import ABC, abstractmethod
 import bisect
 import os
+from abc import ABC, abstractmethod
 from collections import defaultdict
-from itertools import product, combinations
+from itertools import combinations, product
 from multiprocessing import Pool
-from functools import lru_cache

 import numpy as np
 import pyswip

-from ..utils.utils import flatten, reform_list, hamming_dist, to_hashable
 from ..utils.cache import abl_cache
+from ..utils.utils import flatten, hamming_dist, reform_list, to_hashable


 class KBBase(ABC):
@@ -20,19 +19,19 @@ class KBBase(ABC):
     Parameters
     ----------
     pseudo_label_list : list
         List of possible pseudo labels. It's recommended to arrange the pseudo labels in this
         list so that each aligns with its corresponding index in the base model: the first with
         the 0th index, the second with the 1st, and so forth.
     max_err : float, optional
-        The upper tolerance limit when comparing the similarity between a pseudo label sample's reasoning
-        result and the ground truth. This is only applicable when the reasoning result is of a numerical type.
-        This is particularly relevant for regression problems where exact matches might not be
-        feasible. Defaults to 1e-10.
+        The upper tolerance limit when comparing the similarity between a pseudo label sample's
+        reasoning result and the ground truth. This is only applicable when the reasoning
+        result is of a numerical type. This is particularly relevant for regression problems where
+        exact matches might not be feasible. Defaults to 1e-10.
     use_cache : bool, optional
         Whether to use abl_cache for previously abduced candidates to speed up subsequent
         operations. Defaults to True.
     key_func : func, optional
         A function employed for hashing in abl_cache. This is only operational when use_cache
         is set to True. Defaults to to_hashable.
     cache_size: int, optional
         The cache size in abl_cache. This is only operational when use_cache is set to
@@ -75,7 +74,6 @@ class KBBase(ABC):
         pseudo_label : List[Any]
             Pseudo label sample.
         """
-        pass

     def abduce_candidates(self, pseudo_label, y, max_revision_num, require_more_revision):
         """
@@ -104,7 +102,7 @@ class KBBase(ABC):
         """
         Check whether the reasoning result of a pseduo label sample is equal to the ground truth
         (or, within the maximum error allowed for numerical results).
-
+
         Returns
         -------
         bool
@@ -130,7 +128,7 @@ class KBBase(ABC):
             Ground truth of the reasoning result for the sample.
         revision_idx : array-like
             Indices of where revisions should be made to the pseudo label sample.
-
+
         Returns
         -------
         List[List[Any]]
@@ -149,8 +147,8 @@ class KBBase(ABC):

     def _revision(self, revision_num, pseudo_label, y):
         """
-        For a specified number of labels in a pseudo label sample to revise, iterate through all possible
-        indices to find any candidates that are compatible with the knowledge base.
+        For a specified number of labels in a pseudo label sample to revise, iterate through
+        all possible indices to find any candidates that are compatible with the knowledge base.
         """
         new_candidates = []
         revision_idx_list = combinations(range(len(pseudo_label)), revision_num)
@@ -164,8 +162,8 @@ class KBBase(ABC):
     def _abduce_by_search(self, pseudo_label, y, max_revision_num, require_more_revision):
         """
         Perform abductive reasoning by exhastive search. Specifically, begin with 0 and
-        continuously increase the number of labels in a pseudo label sample to revise, until candidates
-        that are compatible with the knowledge base are found.
+        continuously increase the number of labels in a pseudo label sample to revise, until
+        candidates that are compatible with the knowledge base are found.

         Parameters
         ----------
@@ -177,8 +175,8 @@ class KBBase(ABC):
             The upper limit on the number of revisions.
         require_more_revision : int
             If larger than 0, then after having found any candidates compatible with the
-            knowledge base, continue to increase the number of labels in a pseudo label sample to revise to
-            get more possible compatible candidates.
+            knowledge base, continue to increase the number of labels in a pseudo label sample to
+            revise to get more possible compatible candidates.

         Returns
         -------
@@ -286,7 +284,7 @@ class GroundKB(KBBase):
         Perform abductive reasoning by directly retrieving compatible candidates from
         the prebuilt GKB. In this way, the time-consuming exhaustive search can be
         avoided.
-
+
         Parameters
         ----------
         pseudo_label : List[Any]
@@ -347,7 +345,7 @@ class GroundKB(KBBase):
             num_candidates = len(self.GKB[i]) if i in self.GKB else 0
             GKB_info_parts.append(f"{num_candidates} candidates of length {i}")
         GKB_info = ", ".join(GKB_info_parts)
-
+
         return (
             f"{self.__class__.__name__} is a KB with "
             f"pseudo_label_list={self.pseudo_label_list!r}, "
@@ -400,7 +398,7 @@ class PrologKB(KBBase):
         returned `Res` as the reasoning results. To use this default function, there must be
         a `logic_forward` method in the pl file to perform reasoning.
         Otherwise, users would override this function.
-
+
         Parameters
         ----------
         pseudo_label : List[Any]
@@ -429,9 +427,10 @@ class PrologKB(KBBase):
     def get_query_string(self, pseudo_label, y, revision_idx):
         """
         Get the query to be used for consulting Prolog.
-        This is a default function for demo, users would override this function to adapt to their own
-        Prolog file. In this demo function, return query `logic_forward([kept_labels, Revise_labels], Res).`.
+        This is a default function for demo, users would override this function to adapt to
+        their own Prolog file. In this demo function, return query
+        `logic_forward([kept_labels, Revise_labels], Res).`.

         Parameters
         ----------
         pseudo_label : List[Any]
@@ -440,7 +439,7 @@ class PrologKB(KBBase):
             Ground truth of the reasoning result for the sample.
         revision_idx : array-like
             Indices of where revisions should be made to the pseudo label sample.
-
+
         Returns
         -------
         str
@@ -448,14 +447,14 @@ class PrologKB(KBBase):
         """
         query_string = "logic_forward("
         query_string += self._revision_pseudo_label(pseudo_label, revision_idx)
-        key_is_none_flag = y is None or (type(y) == list and y[0] is None)
+        key_is_none_flag = y is None or (isinstance(y, list) and y[0] is None)
         query_string += ",%s)." % y if not key_is_none_flag else ")."
         return query_string

     def revise_at_idx(self, pseudo_label, y, revision_idx):
         """
         Revise the pseudo label sample at specified index positions by querying Prolog.
-
+
         Parameters
         ----------
         pseudo_label : List[Any]
@@ -464,7 +463,7 @@ class PrologKB(KBBase):
             Ground truth of the reasoning result for the sample.
         revision_idx : array-like
             Indices of where revisions should be made to the pseudo label sample.
-
+
         Returns
         -------
         List[List[Any]]
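For orientation, the methods touched above (`logic_forward`, `abduce_candidates`, `revise_at_idx`) form the knowledge-base interface: a user subclasses `KBBase`, supplies `pseudo_label_list`, and implements `logic_forward`. A toy sketch, assuming the constructor and `abduce_candidates` signatures documented in this diff:

    from abl.reasoning import KBBase

    class AddKB(KBBase):
        """Toy KB: the reasoning result of a pseudo label sample [a, b] is a + b."""

        def __init__(self):
            # Each pseudo label aligns with its index in the base model: 0 -> 0, 1 -> 1, ...
            super().__init__(pseudo_label_list=list(range(10)))

        def logic_forward(self, pseudo_label):
            return sum(pseudo_label)

    kb = AddKB()
    # Revise at most one of the two labels so the sum becomes 8,
    # e.g. [3, 7] -> [1, 7], [3, 5], ...
    candidates = kb.abduce_candidates([3, 7], y=8, max_revision_num=1, require_more_revision=0)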


abl/reasoning/reasoner.py (+6 -10)

@@ -1,11 +1,7 @@
 import numpy as np
-from zoopt import Dimension, Objective, Parameter, Opt
-from ..utils.utils import (
-    confidence_dist,
-    flatten,
-    reform_list,
-    hamming_dist,
-)
+from zoopt import Dimension, Objective, Opt, Parameter
+
+from ..utils.utils import confidence_dist, hamming_dist


 class Reasoner:
@@ -124,7 +120,7 @@ class Reasoner:

     def zoopt_get_solution(self, symbol_num, data_sample, max_revision_num):
         """
-        Get the optimal solution using ZOOpt library. The solution is a list of
+        Get the optimal solution using ZOOpt library. The solution is a list of
         boolean values, where '1' (True) indicates the indices chosen to be revised.

         Parameters
@@ -148,7 +144,7 @@ class Reasoner:

     def zoopt_revision_score(self, symbol_num, data_sample, sol):
         """
-        Get the revision score for a solution. A lower score suggests that ZOOpt library
+        Get the revision score for a solution. A lower score suggests that ZOOpt library
         has a higher preference for this solution.
         """
         revision_idx = np.where(sol.get_x() != 0)[0]
@@ -198,7 +194,7 @@ class Reasoner:
         Returns
         -------
         List[Any]
-            A revised pseudo label sample through abductive reasoning, which is compatible
+            A revised pseudo label sample through abductive reasoning, which is compatible
             with the knowledge base.
         """
         symbol_num = data_sample.elements_num("pred_pseudo_label")
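The retained `zoopt` imports (`Dimension`, `Objective`, `Opt`, `Parameter`) are what `zoopt_get_solution` builds on: it searches over a boolean mask whose nonzero entries are the indices to revise. A standalone sketch of that setup, with a toy objective standing in for `zoopt_revision_score` (assumed usage for illustration, not the repository's exact code):

    import numpy as np
    from zoopt import Dimension, Objective, Opt, Parameter

    symbol_num = 5  # length of the pseudo label sample

    def revision_score(solution):
        # Toy stand-in for zoopt_revision_score: prefer revising fewer positions.
        return float(np.sum(solution.get_x()))

    # Each coordinate is a discrete 0/1 flag ("revise this index or not").
    dim = Dimension(symbol_num, [[0, 1]] * symbol_num, [False] * symbol_num)
    solution = Opt.min(Objective(revision_score, dim), Parameter(budget=100))
    revision_idx = np.where(np.array(solution.get_x()) != 0)[0]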


abl/structures/__init__.py (+3 -1)

@@ -1,2 +1,4 @@
 from .base_data_element import BaseDataElement
-from .list_data import ListData
+from .list_data import ListData
+
+__all__ = ["BaseDataElement", "ListData"]

abl/structures/base_data_element.py (+4 -10)

@@ -224,9 +224,7 @@ class BaseDataElement:
             metainfo (dict): A dict contains the meta information
                 of image, such as ``img_shape``, ``scale_factor``, etc.
         """
-        assert isinstance(
-            metainfo, dict
-        ), f"metainfo should be a ``dict`` but got {type(metainfo)}"
+        assert isinstance(metainfo, dict), f"metainfo should be a ``dict`` but got {type(metainfo)}"
         meta = copy.deepcopy(metainfo)
         for k, v in meta.items():
             self.set_field(name=k, value=v, field_type="metainfo", dtype=None)
@@ -388,8 +386,7 @@ class BaseDataElement:
                 super().__setattr__(name, value)
             else:
                 raise AttributeError(
-                    f"{name} has been used as a "
-                    "private attribute, which is immutable."
+                    f"{name} has been used as a " "private attribute, which is immutable."
                 )
         else:
             self.set_field(name=name, value=value, field_type="data", dtype=None)
@@ -458,9 +455,7 @@ class BaseDataElement:
         functions."""
         assert field_type in ["metainfo", "data"]
         if dtype is not None:
-            assert isinstance(
-                value, dtype
-            ), f"{value} should be a {dtype} but got {type(value)}"
+            assert isinstance(value, dtype), f"{value} should be a {dtype} but got {type(value)}"

         if field_type == "metainfo":
             if name in self._data_fields:
@@ -571,8 +566,7 @@ class BaseDataElement:
     def to_dict(self) -> dict:
         """Convert BaseDataElement to dict."""
         return {
-            k: v.to_dict() if isinstance(v, BaseDataElement) else v
-            for k, v in self.all_items()
+            k: v.to_dict() if isinstance(v, BaseDataElement) else v for k, v in self.all_items()
         }

     def __repr__(self) -> str:


abl/structures/list_data.py (+1 -2)

@@ -1,7 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import itertools
-from collections.abc import Sized
-from typing import Any, List, Union
+from typing import List, Union

 import numpy as np
 import torch


abl/utils/__init__.py (+21 -1)

@@ -1,3 +1,23 @@
 from .cache import Cache, abl_cache
 from .logger import ABLLogger, print_log
-from .utils import *
+from .utils import (
+    calculate_revision_num,
+    confidence_dist,
+    flatten,
+    hamming_dist,
+    reform_list,
+    to_hashable,
+)
+
+__all__ = [
+    "Cache",
+    "ABLLogger",
+    "print_log",
+    "calculate_revision_num",
+    "confidence_dist",
+    "flatten",
+    "hamming_dist",
+    "reform_list",
+    "to_hashable",
+    "abl_cache",
+]
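Swapping `from .utils import *` for an explicit import list plus `__all__` makes the package's public surface visible to flake8 (no more F403/F405 star-import warnings), in line with the lint fixes elsewhere in this commit.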

abl/utils/cache.py (+0 -6)

@@ -1,10 +1,5 @@
-import pickle
-import os
-import os.path as osp
 from typing import Callable, Generic, TypeVar

-from .logger import print_log, ABLLogger
-
 K = TypeVar("K")
 T = TypeVar("T")
 PREV, NEXT, KEY, RESULT = 0, 1, 2, 3  # names for the link fields

@@ -73,7 +68,6 @@ class Cache(Generic[K, T]):
         # Empty the oldest link and make it the new root.
         self.root = oldroot[NEXT]
         oldkey = self.root[KEY]
-        oldresult = self.root[RESULT]
         self.root[KEY] = self.root[RESULT] = None
         # Now update the cache dictionary.
         del self.cache_dict[oldkey]


abl/utils/logger.py (+15 -7)

@@ -15,7 +15,8 @@ class FilterDuplicateWarning(logging.Filter):
     """
     Filter for eliminating repeated warning messages in logging.

-    This filter checks for duplicate warning messages and allows only the first occurrence of each message to be logged, filtering out subsequent duplicates.
+    This filter checks for duplicate warning messages and allows only the first occurrence of
+    each message to be logged, filtering out subsequent duplicates.

     Parameters
     ----------
@@ -145,7 +146,8 @@ class ABLLogger(Logger, ManagerMixin):

     `ABLLogger` provides a formatted logger that can log messages with different
     log levels. It allows the creation of logger instances in a similar manner to `ManagerMixin`.
-    The logger has features like distributed log storage and colored terminal output for different log levels.
+    The logger has features like distributed log storage and colored terminal output for different
+    log levels.

     Parameters
     ----------
@@ -154,7 +156,8 @@ class ABLLogger(Logger, ManagerMixin):
     logger_name : str, optional
         `name` attribute of `logging.Logger` instance. Defaults to 'abl'.
     log_file : str, optional
-        The log filename. If specified, a `FileHandler` will be added to the logger. Defaults to None.
+        The log filename. If specified, a `FileHandler` will be added to the logger.
+        Defaults to None.
     log_level : Union[int, str]
         The log level of the handler. Defaults to 'INFO'.
         If log level is 'DEBUG', distributed logs will be saved during distributed training.
@@ -287,20 +290,25 @@ def print_log(msg, logger: Optional[Union[Logger, str]] = None, level=logging.INFO):
     """
     Print a log message using the specified logger or a default method.

-    This function logs a message with a given logger, if provided, or prints it using the standard `print` function. It supports special logger types such as 'silent' and 'current'.
+    This function logs a message with a given logger, if provided, or prints it using
+    the standard `print` function. It supports special logger types such as 'silent' and 'current'.

     Parameters
     ----------
     msg : str
         The message to be logged.
     logger : Optional[Union[Logger, str]], optional
-        The logger to use for logging the message. It can be a `logging.Logger` instance, a string specifying the logger name, 'silent', 'current', or None. If None, the `print` method is used.
+        The logger to use for logging the message. It can be a `logging.Logger` instance, a string
+        specifying the logger name, 'silent', 'current', or None. If None, the `print`
+        method is used.
         - 'silent': No message will be printed.
         - 'current': Use the latest created logger to log the message.
-        - other str: The instance name of the logger. A `ValueError` is raised if the logger has not been created.
+        - other str: The instance name of the logger. A `ValueError` is raised if the logger has not
+          been created.
         - None: The `print()` method is used for logging.
     level : int, optional
-        The logging level. This is only applicable when `logger` is a Logger object, 'current', or a named logger instance. The default is `logging.INFO`.
+        The logging level. This is only applicable when `logger` is a Logger object, 'current',
+        or a named logger instance. The default is `logging.INFO`.
     """
     if logger is None:
         print(msg)
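A quick usage sketch of the `print_log` behaviors the reflowed docstring describes:

    from abl.utils import print_log

    print_log("plain message")                       # logger=None -> falls back to print()
    print_log("suppressed", logger="silent")         # nothing is emitted
    print_log("to latest logger", logger="current")  # most recently created ABLLogger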


abl/utils/manager.py (+12 -12)

@@ -6,7 +6,7 @@ from collections import OrderedDict
 from typing import Type, TypeVar

 _lock = threading.RLock()
-T = TypeVar('T')
+T = TypeVar("T")


 def _accquire_lock() -> None:
@@ -47,7 +47,7 @@ class ManagerMeta(type):
         cls._instance_dict = OrderedDict()
         params = inspect.getfullargspec(cls)
         params_names = params[0] if params[0] else []
-        assert 'name' in params_names, f'{cls} must have the `name` argument'
+        assert "name" in params_names, f"{cls} must have the `name` argument"
         super().__init__(*args)


@@ -72,9 +72,8 @@ class ManagerMixin(metaclass=ManagerMeta):
         name (str): Name of the instance. Defaults to ''.
     """

-    def __init__(self, name: str = '', **kwargs):
-        assert isinstance(name, str) and name, \
-            'name argument must be an non-empty string.'
+    def __init__(self, name: str = "", **kwargs):
+        assert isinstance(name, str) and name, "name argument must be an non-empty string."
         self._instance_name = name

     @classmethod
@@ -102,8 +101,7 @@ class ManagerMixin(metaclass=ManagerMeta):
             instance.
         """
         _accquire_lock()
-        assert isinstance(name, str), \
-            f'type of name should be str, but got {type(cls)}'
+        assert isinstance(name, str), f"type of name should be str, but got {type(cls)}"
         instance_dict = cls._instance_dict  # type: ignore
         # Get the instance by name.
         if name not in instance_dict:
@@ -111,9 +109,10 @@ class ManagerMixin(metaclass=ManagerMeta):
             instance_dict[name] = instance  # type: ignore
         elif kwargs:
             warnings.warn(
-                f'{cls} instance named of {name} has been created, '
-                'the method `get_instance` should not accept any other '
-                'arguments')
+                f"{cls} instance named of {name} has been created, "
+                "the method `get_instance` should not accept any other "
+                "arguments"
+            )
         # Get latest instantiated instance or root instance.
         _release_lock()
         return instance_dict[name]
@@ -141,8 +140,9 @@ class ManagerMixin(metaclass=ManagerMeta):
         _accquire_lock()
         if not cls._instance_dict:
             raise RuntimeError(
-                f'Before calling {cls.__name__}.get_current_instance(), you '
-                'should call get_instance(name=xxx) at least once.')
+                f"Before calling {cls.__name__}.get_current_instance(), you "
+                "should call get_instance(name=xxx) at least once."
+            )
         name = next(iter(reversed(cls._instance_dict)))
         _release_lock()
         return cls._instance_dict[name]
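The assertions reformatted above guard the `ManagerMixin` registry contract, which (per the logger diff earlier) `ABLLogger` relies on. Roughly, under the assumptions shown in this diff:

    from abl.utils import ABLLogger

    logger = ABLLogger.get_instance("abl")             # creates and registers the "abl" instance
    assert ABLLogger.get_current_instance() is logger  # latest instantiated instance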


abl/utils/utils.py (+0 -57)

@@ -221,60 +221,3 @@ def calculate_revision_num(parameter, total_length):
     if parameter < 0:
         raise ValueError("If parameter is an int, it must be non-negative.")
     return parameter
-
-
-if __name__ == "__main__":
-    A = np.array(
-        [
-            [
-                0.18401675,
-                0.06797526,
-                0.06797541,
-                0.06801736,
-                0.06797528,
-                0.06797526,
-                0.06818808,
-                0.06797527,
-                0.06800033,
-                0.06797526,
-                0.06797526,
-                0.06797526,
-                0.06797526,
-            ],
-            [
-                0.07223161,
-                0.0685229,
-                0.06852708,
-                0.17227574,
-                0.06852163,
-                0.07018146,
-                0.06860291,
-                0.06852849,
-                0.06852163,
-                0.0685216,
-                0.0685216,
-                0.06852174,
-                0.0685216,
-            ],
-            [
-                0.06794382,
-                0.0679436,
-                0.06794395,
-                0.06794346,
-                0.06794346,
-                0.18467231,
-                0.06794345,
-                0.06794871,
-                0.06794345,
-                0.06794345,
-                0.06794345,
-                0.06794345,
-                0.06794345,
-            ],
-        ],
-        dtype=np.float32,
-    )
-    B = [[0, 9, 3], [0, 11, 4]]
-
-    print(ori_confidence_dist(A, B))
-    print(confidence_dist(A, B))

build_tools/requirements.txt (+3 -0, new file)

@@ -0,0 +1,3 @@
+-r ../requirements.txt
+pytest
+pytest-cov

docs/Examples/MNISTAdd.rst (+1 -6)

@@ -3,7 +3,7 @@ MNIST Addition

 MNIST Addition was first introduced in [1] and the inputs of this task are pairs of MNIST images and the outputs are their sums. The dataset looks like this:

-.. image:: ../img/image_1.jpg
+.. image:: ../img/Datasets_1.png
    :width: 350px
    :align: center

@@ -11,9 +11,4 @@ MNIST Addition was first introduced in [1] and the inputs of this task are pairs

 The ``gt_pseudo_label`` is only used to test the performance of the machine learning model and is not used in the training phase.

-In the Abductive Learning framework, the inference process is as follows:
-
-.. image:: ../img/image_2.jpg
-   :width: 700px
-
 [1] Robin Manhaeve, Sebastijan Dumancic, Angelika Kimmig, Thomas Demeester, and Luc De Raedt. Deepproblog: Neural probabilistic logic programming. In Advances in Neural Information Processing Systems 31 (NeurIPS), pages 3749-3759. 2018.

docs/conf.py (+5 -7)

@@ -1,14 +1,12 @@
-import sys
 import os
 import re
+import sys

-if not "READTHEDOCS" in os.environ:
+if "READTHEDOCS" not in os.environ:
     sys.path.insert(0, os.path.abspath(".."))
     sys.path.append(os.path.abspath("./ABL/"))

-# from sphinx.locale import _
-from sphinx_rtd_theme import __version__
-

 project = "ABL"
 slug = re.sub(r"\W+", "-", project.lower())

@@ -48,8 +46,8 @@ pygments_style = "default"

 html_theme = "sphinx_rtd_theme"
 html_theme_options = {"display_version": True}
-html_static_path = ['_static']
-html_css_files = ['custom.css']
+html_static_path = ["_static"]
+html_css_files = ["custom.css"]
 # html_theme_path = ["../.."]
 # html_logo = "demo/static/logo-wordmark-light.svg"
 # html_show_sourcelink = True

examples/hed/datasets/get_hed.py (+3 -3)

@@ -1,11 +1,11 @@
 import os
 import os.path as osp
-import cv2
 import pickle
-import numpy as np
 import random
-
 from collections import defaultdict

+import cv2
+import numpy as np
 from torchvision.transforms import transforms

 CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))


examples/hed/framework_hed.py (+0 -388, file deleted; its former contents follow)

@@ -1,388 +0,0 @@
# coding: utf-8
# ================================================================#
# Copyright (C) 2021 Freecss All rights reserved.
#
# File Name :framework.py
# Author :freecss
# Email :karlfreecss@gmail.com
# Created Date :2021/06/07
# Description :
#
# ================================================================#

import torch
import torch.nn as nn
import numpy as np
import os

from abl.utils.plog import INFO
from abl.utils.utils import flatten, reform_idx
from abl.learning.basic_nn import BasicNN, BasicDataset

from utils import gen_mappings, mapping_res, remapping_res
from models.nn import SymbolNetAutoencoder
from torch.utils.data import RandomSampler
from datasets.get_hed import get_pretrain_data


def hed_pretrain(kb, cls, recorder):
cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if not os.path.exists("./weights/pretrain_weights.pth"):
INFO("Pretrain Start")
pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"])
pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y)
pretrain_data_loader = torch.utils.data.DataLoader(
pretrain_data, batch_size=64, shuffle=True
)

criterion = nn.MSELoss()
optimizer = torch.optim.RMSprop(
cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6
)

pretrain_model = BasicNN(
cls_autoencoder,
criterion,
optimizer,
device,
save_interval=1,
save_dir=recorder.save_dir,
num_epochs=10,
recorder=recorder,
)
pretrain_model.fit(pretrain_data_loader)
torch.save(
cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth"
)
cls.load_state_dict(cls_autoencoder.base_model.state_dict())

else:
cls.load_state_dict(torch.load("./weights/pretrain_weights.pth"))


def _get_char_acc(model, X, consistent_pred_res, mapping):
original_pred_res = model.predict(X)["label"]
pred_res = flatten(mapping_res(original_pred_res, mapping))
INFO("Current model's output: ", pred_res)
INFO("Abduced labels: ", flatten(consistent_pred_res))
assert len(pred_res) == len(flatten(consistent_pred_res))
return sum(
[
pred_res[idx] == flatten(consistent_pred_res)[idx]
for idx in range(len(pred_res))
]
) / len(pred_res)


def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
select_idx = RandomSampler(train_X_true, num_samples=select_num,replacement=False)
X = [train_X_true[idx] for idx in select_idx]

# original_pred_res = model.predict(X)['label']
pred_label = model.predict(X)["label"]

if mapping == None:
mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1])
else:
mappings = [mapping]

consistent_idx = []
consistent_pred_res = []

for m in mappings:
pred_pseudo_label = mapping_res(pred_label, m)
max_revision_num = 20
solution = abducer.zoopt_get_solution(
pred_label,
pred_pseudo_label,
[None] * len(pred_label),
[None] * len(pred_label),
max_revision_num,
)
all_address_flag = reform_idx(solution, pred_label)

consistent_idx_tmp = []
consistent_pred_res_tmp = []

for idx in range(len(pred_label)):
address_idx = [
i for i, flag in enumerate(all_address_flag[idx]) if flag != 0
]
candidate = abducer.revise_by_idx([pred_pseudo_label[idx]], None, address_idx)
if len(candidate) > 0:
consistent_idx_tmp.append(idx)
consistent_pred_res_tmp.append(candidate[0][0])

if len(consistent_idx_tmp) > len(consistent_idx):
consistent_idx = consistent_idx_tmp
consistent_pred_res = consistent_pred_res_tmp
if len(mappings) > 1:
mapping = m

if len(consistent_idx) == 0:
return 0, 0, None

INFO("Train pool size is:", len(flatten(consistent_pred_res)))
INFO("Start to use abduced pseudo label to train model...")
model.train(
[X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping)
)

consistent_acc = len(consistent_idx) / select_num
char_acc = _get_char_acc(
model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping
)
INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc))
return consistent_acc, char_acc, mapping


# def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
# select_idx = np.random.randint(len(train_X_true), size=select_num)
# X = []
# for idx in select_idx:
# X.append(train_X_true[idx])

# original_pred_res = model.predict(X)['label']

# if mapping == None:
# mappings = gen_mappings([0, 1, 2, 3],['+', '=', 0, 1])
# else:
# mappings = [mapping]

# consistent_idx = []
# consistent_pred_res = []

# for m in mappings:
# pred_res = mapping_res(original_pred_res, m)
# max_abduce_num = 20
# solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), [None] * len(pred_res), max_abduce_num)
# all_address_flag = reform_idx(solution, pred_res)

# consistent_idx_tmp = []
# consistent_pred_res_tmp = []

# for idx in range(len(pred_res)):
# address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
# candidate = abducer.revise_by_idx([pred_res[idx]], None, address_idx)
# if len(candidate) > 0:
# consistent_idx_tmp.append(idx)
# consistent_pred_res_tmp.append(candidate[0][0])

# if len(consistent_idx_tmp) > len(consistent_idx):
# consistent_idx = consistent_idx_tmp
# consistent_pred_res = consistent_pred_res_tmp
# if len(mappings) > 1:
# mapping = m

# if len(consistent_idx) == 0:
# return 0, 0, None

# INFO('Train pool size is:', len(flatten(consistent_pred_res)))
# INFO("Start to use abduced pseudo label to train model...")
# model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping))

# consistent_acc = len(consistent_idx) / select_num
# char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping)
# INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc))
# return consistent_acc, char_acc, mapping


def _remove_duplicate_rule(rule_dict):
add_nums_dict = {}
for r in list(rule_dict):
add_nums = str(r.split("]")[0].split("[")[1]) + str(
r.split("]")[1].split("[")[1]
) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10'
if add_nums in add_nums_dict:
old_r = add_nums_dict[add_nums]
if rule_dict[r] >= rule_dict[old_r]:
rule_dict.pop(old_r)
add_nums_dict[add_nums] = r
else:
rule_dict.pop(r)
else:
add_nums_dict[add_nums] = r
return list(rule_dict)


def get_rules_from_data(
model, abducer, mapping, train_X_true, samples_per_rule, samples_num
):
rules = []
for _ in range(samples_num):
while True:
select_idx = np.random.randint(len(train_X_true), size=samples_per_rule)
X = []
for idx in select_idx:
X.append(train_X_true[idx])
original_pred_res = model.predict(X)["label"]
pred_res = mapping_res(original_pred_res, mapping)

consistent_idx = []
consistent_pred_res = []
for idx in range(len(pred_res)):
if abducer.kb.logic_forward([pred_res[idx]]):
consistent_idx.append(idx)
consistent_pred_res.append(pred_res[idx])

if len(consistent_pred_res) != 0:
rule = abducer.abduce_rules(consistent_pred_res)
if rule != None:
break
rules.append(rule)

all_rule_dict = {}
for rule in rules:
for r in rule:
all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1
rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5}
rules = _remove_duplicate_rule(rule_dict)

return rules


def _get_consist_rule_acc(model, abducer, mapping, rules, X):
cnt = 0
for x in X:
original_pred_res = model.predict([x])["label"]
pred_res = flatten(mapping_res(original_pred_res, mapping))
if abducer.kb.consist_rule(pred_res, rules):
cnt += 1
return cnt / len(X)


def train_with_rule(
    model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8
):
    train_X = train_data
    val_X = val_data

    samples_num = 50
    samples_per_rule = 3

    # Start training: iterate over each equation length
    for equation_len in range(min_len, max_len):
        INFO(
            "============== equation_len: %d-%d ================"
            % (equation_len, equation_len + 1)
        )
        train_X_true = train_X[1][equation_len]
        train_X_false = train_X[0][equation_len]
        val_X_true = val_X[1][equation_len]
        val_X_false = val_X[0][equation_len]

        train_X_true.extend(train_X[1][equation_len + 1])
        train_X_false.extend(train_X[0][equation_len + 1])
        val_X_true.extend(val_X[1][equation_len + 1])
        val_X_false.extend(val_X[0][equation_len + 1])

        condition_cnt = 0
        while True:
            if equation_len == min_len:
                mapping = None

            # Abduce and train NN
            consistent_acc, char_acc, mapping = abduce_and_train(
                model, abducer, mapping, train_X_true, select_num
            )
            if consistent_acc == 0:
                continue

            # Check whether the MLP is now accurate enough to evaluate rules
            if consistent_acc >= 0.9 and char_acc >= 0.9:
                condition_cnt += 1
            else:
                condition_cnt = 0

            # The condition has been satisfied five times in a row
            if condition_cnt >= 5:
                INFO("Now checking if we can go to next course")
                rules = get_rules_from_data(
                    model, abducer, mapping, train_X_true, samples_per_rule, samples_num
                )
                INFO("Learned rules from data:", rules)

                true_consist_rule_acc = _get_consist_rule_acc(
                    model, abducer, mapping, rules, val_X_true
                )
                false_consist_rule_acc = _get_consist_rule_acc(
                    model, abducer, mapping, rules, val_X_false
                )

                INFO(
                    "consist_rule_acc is %f, %f\n"
                    % (true_consist_rule_acc, false_consist_rule_acc)
                )
                # Decide whether to move on to the next course or restart
                if true_consist_rule_acc > 0.95 and false_consist_rule_acc < 0.1:
                    torch.save(
                        model.classifier_list[0].model.state_dict(),
                        "./weights/weights_%d.pth" % equation_len,
                    )
                    break
                else:
                    if equation_len == min_len:
                        INFO("Final mapping is: ", mapping)
                        model.classifier_list[0].model.load_state_dict(
                            torch.load("./weights/pretrain_weights.pth")
                        )
                    else:
                        model.classifier_list[0].model.load_state_dict(
                            torch.load("./weights/weights_%d.pth" % (equation_len - 1))
                        )
                    condition_cnt = 0
                    INFO("Reload Model and retrain")

    return model, mapping


def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8):
    train_X = train_data
    test_X = test_data

    # Calculate how many equations should be selected for each length:
    # for each length, equation_samples_num[equation_len] rules are sampled
    print("Now begin to train final mlp model")
    equation_samples_num = []
    len_cnt = max_len - min_len + 1
    samples_num = 50
    equation_samples_num += [0] * min_len
    equation_samples_num += [samples_num // len_cnt] * len_cnt
    if samples_num % len_cnt != 0:
        equation_samples_num[-1] += samples_num % len_cnt
    assert sum(equation_samples_num) == samples_num
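
    # Worked example (added for clarity): with min_len=5, max_len=8 and
    # samples_num=50, len_cnt = 4 and equation_samples_num becomes
    # [0, 0, 0, 0, 0, 12, 12, 12, 14]: the remainder 50 % 4 = 2 goes to the
    # last length, so the counts still sum to 50.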

    # Abduce rules
    rules = []
    samples_per_rule = 3
    for equation_len in range(min_len, max_len + 1):
        equation_rules = get_rules_from_data(
            model,
            abducer,
            mapping,
            train_X[1][equation_len],
            samples_per_rule,
            equation_samples_num[equation_len],
        )
        rules.extend(equation_rules)
    rules = list(set(rules))
    INFO("Learned rules from data:", rules)

    for equation_len in range(5, 27):
        true_consist_rule_acc = _get_consist_rule_acc(
            model, abducer, mapping, rules, test_X[1][equation_len]
        )
        false_consist_rule_acc = _get_consist_rule_acc(
            model, abducer, mapping, rules, test_X[0][equation_len]
        )
        INFO(
            "consist_rule_acc of testing length %d equations are %f, %f"
            % (equation_len, true_consist_rule_acc, false_consist_rule_acc)
        )


if __name__ == "__main__":
    pass

+ 15
- 14
examples/hed/hed_bridge.py View File

@@ -1,20 +1,18 @@
 import os
 from collections import defaultdict
-from typing import Any, List
 import torch
-from torch.utils.data import DataLoader

-from abl.reasoning import ReasonerBase
-from abl.learning import ABLModel, BasicNN
 from abl.bridge import SimpleBridge
-from abl.dataset import RegressionDataset
 from abl.evaluation import BaseMetric
+from abl.dataset import BridgeDataset, RegressionDataset
+from abl.learning import ABLModel, BasicNN
+from abl.reasoning import ReasonerBase
 from abl.structures import ListData
 from abl.utils import print_log

-from examples.hed.utils import gen_mappings, InfiniteSampler
-from examples.models.nn import SymbolNetAutoencoder
 from examples.hed.datasets.get_hed import get_pretrain_data
+from examples.hed.utils import InfiniteSampler, gen_mappings
+from examples.models.nn import SymbolNetAutoencoder


 class HEDBridge(SimpleBridge):
@@ -95,7 +93,8 @@ class HEDBridge(SimpleBridge):
        character_accuracy = self.model.valid(filtered_data_samples)
        revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X)
        print_log(
-           f"Revisible ratio is {revisible_ratio:.3f}, Character accuracy is {character_accuracy:.3f}",
+           f"Revisible ratio is {revisible_ratio:.3f}, Character \
+accuracy is {character_accuracy:.3f}",
            logger="current",
        )


@@ -111,7 +110,8 @@ class HEDBridge(SimpleBridge):
        false_ratio = self.calc_consistent_ratio(val_X_false, rule)

        print_log(
-           f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio is {1 - false_ratio:.3f}",
+           f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio \
+is {1 - false_ratio:.3f}",
            logger="current",
        )


@@ -143,7 +143,7 @@ class HEDBridge(SimpleBridge):

            if len(consistent_instance) != 0:
                rule = self.reasoner.abduce_rules(consistent_instance)
-               if rule != None:
+               if rule is not None:
                    rules.append(rule)
                    break


@@ -214,7 +214,8 @@ class HEDBridge(SimpleBridge):
                loss = self.model.train(filtered_sub_data_samples)

                print_log(
-                   f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] model loss is {loss:.5f}",
+                   f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] \
+model loss is {loss:.5f}",
                    logger="current",
                )


@@ -224,11 +225,11 @@ class HEDBridge(SimpleBridge):
                condition_num = 0

            if condition_num >= 5:
-               print_log(f"Now checking if we can go to next course", logger="current")
+               print_log("Now checking if we can go to next course", logger="current")
                rules = self.get_rules_from_data(
                    data_samples, samples_per_rule=3, samples_num=50
                )
-               print_log(f"Learned rules from data: " + str(rules), logger="current")
+               print_log("Learned rules from data: " + str(rules), logger="current")

                seems_good = self.check_rule_quality(rules, val_data, equation_len)
                if seems_good:


+ 8
- 6
examples/hed/hed_example.ipynb View File

@@ -66,7 +66,8 @@
 "        prolog_rules = prolog_result[0][\"X\"]\n",
 "        rules = [rule.value for rule in prolog_rules]\n",
 "        return rules\n",
-"    \n",
+"\n",
+"\n",
 "class HedReasoner(ReasonerBase):\n",
 "    def revise_at_idx(self, data_sample):\n",
 "        revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n",
@@ -76,7 +77,9 @@
 "        return candidate\n",
 "\n",
 "    def zoopt_revision_score(self, symbol_num, data_sample, sol):\n",
-"        revision_flag = reform_list(list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label)\n",
+"        revision_flag = reform_list(\n",
+"            list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label\n",
+"        )\n",
 "        data_sample.revision_flag = revision_flag\n",
 "\n",
 "        lefted_idxs = [i for i in range(len(data_sample.pred_idx))]\n",
@@ -108,7 +111,7 @@
 "        for i in range(0, len(candidate_size)):\n",
 "            score -= math.exp(-i) * candidate_size[i]\n",
 "        return score\n",
-"    \n",
+"\n",
 "    def abduce(self, data_sample):\n",
 "        symbol_num = data_sample.elements_num(\"pred_pseudo_label\")\n",
 "        max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n",
@@ -134,9 +137,8 @@
 "    def abduce_rules(self, pred_res):\n",
 "        return self.kb.abduce_rules(pred_res)\n",
 "\n",
-"kb = HedKB(\n",
-"    pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\"\n",
-")\n",
+"\n",
+"kb = HedKB(pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\")\n",
 "reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=20)"
 ]
 },


+ 0
- 69
examples/hed/hed_knn_example.py View File

@@ -1,69 +0,0 @@
# coding: utf-8
# ================================================================#
# Copyright (C) 2021 Freecss All rights reserved.
#
# File Name   :share_example.py
# Author      :freecss
# Email       :karlfreecss@gmail.com
# Created Date:2021/06/07
# Description :
#
# ================================================================#

import sys

sys.path.append("../")

from abl.utils.plog import logger, INFO
from abl.utils.utils import reduce_dimension
import torch.nn as nn
import torch

from abl.models.nn import LeNet5, SymbolNet
from abl.models.basic_model import BasicModel, BasicDataset
from abl.models.wabl_models import DecisionTree, WABLBasicModel
from sklearn.neighbors import KNeighborsClassifier

from abl.abducer.abducer_base import AbducerBase
from abl.abducer.kb import add_KB, HWF_KB, prolog_KB
from datasets.mnist_add.get_mnist_add import get_mnist_add
from datasets.hwf.get_hwf import get_hwf
from datasets.hed.get_hed import get_hed, split_equation
from abl import framework_hed_knn


def run_test():

    # kb = add_KB(True)
    # kb = HWF_KB(True)
    # abducer = AbducerBase(kb)

    kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
    abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True)

    recorder = logger()

    total_train_data = get_hed(train=True)
    train_data, val_data = split_equation(total_train_data, 3, 1)
    test_data = get_hed(train=False)

    # ========================= KNN model ============================ #
    reduce_dimension(train_data)
    reduce_dimension(val_data)
    reduce_dimension(test_data)
    base_model = KNeighborsClassifier(n_neighbors=3)
    pretrain_data_X, pretrain_data_Y = framework_hed_knn.hed_pretrain(base_model)
    model = WABLBasicModel(base_model, kb.pseudo_label_list)
    model, mapping = framework_hed_knn.train_with_rule(
        model, abducer, train_data, val_data, (pretrain_data_X, pretrain_data_Y), select_num=10, min_len=5, max_len=8
    )
    framework_hed_knn.hed_test(
        model, abducer, mapping, train_data, test_data, min_len=5, max_len=8
    )
    # ============================ End =============================== #

    recorder.dump()
    return True


if __name__ == "__main__":
    run_test()

+ 0
- 159
examples/hed/hed_tmp.py View File

@@ -1,159 +0,0 @@
import os.path as osp

import numpy as np
import torch
import torch.nn as nn

from abl.evaluation import SemanticsMetric, SymbolMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import PrologKB, ReasonerBase
from abl.utils import ABLLogger, print_log, reform_list
from examples.hed.datasets.get_hed import get_hed, split_equation
from examples.hed.hed_bridge import HEDBridge
from examples.models.nn import SymbolNet

# Build logger
print_log("Abductive Learning on the HED example.", logger="current")

# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")


### Logic Part
# Initialize knowledge base and abducer
class HedKB(PrologKB):
    def __init__(self, pseudo_label_list, pl_file):
        super().__init__(pseudo_label_list, pl_file)

    def consist_rule(self, exs, rules):
        rules = str(rules).replace("'", "")
        return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0

    def abduce_rules(self, pred_res):
        prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))
        if len(prolog_result) == 0:
            return None
        prolog_rules = prolog_result[0]["X"]
        rules = [rule.value for rule in prolog_rules]
        return rules
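
    # Illustration (added; the example rule is hypothetical): with
    # exs = [1, 1, '+', 0, '=', 1, 1] and rules = ["my_op([0], [0], [0])"],
    # str(rules).replace("'", "") turns the Python list into a Prolog term
    # list, so consist_rule issues the query
    #     eval_inst_feature([1, 1, '+', 0, '=', 1, 1], [my_op([0], [0], [0])]).
    # and abduce_rules asks consistent_inst_feature(pred_res, X), returning
    # the bound rule list X or None when no binding exists.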


class HedReasoner(ReasonerBase):
    def revise_at_idx(self, data_sample):
        revision_idx = np.where(np.array(data_sample.flatten("revision_flag")) != 0)[0]
        candidate = self.kb.revise_at_idx(
            data_sample.pred_pseudo_label, data_sample.Y, revision_idx
        )
        return candidate

    def zoopt_revision_score(self, symbol_num, data_sample, sol):
        # Interpret the zoopt solution vector as per-symbol revision flags
        revision_flag = reform_list(
            list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label
        )
        data_sample.revision_flag = revision_flag

        # Greedily group example indices whose revisions can be made
        # consistent together, recording the size of each group found
        lefted_idxs = [i for i in range(len(data_sample.pred_idx))]
        candidate_size = []
        while lefted_idxs:
            idxs = []
            idxs.append(lefted_idxs.pop(0))
            max_candidate_idxs = []
            found = False
            for idx in range(-1, len(data_sample.pred_idx)):
                if (idx not in idxs) and (idx >= 0):
                    idxs.append(idx)
                candidate = self.revise_at_idx(data_sample[idxs])
                if len(candidate) == 0:
                    if len(idxs) > 1:
                        idxs.pop()
                else:
                    if len(idxs) > len(max_candidate_idxs):
                        found = True
                        max_candidate_idxs = idxs.copy()
            removed = [i for i in lefted_idxs if i in max_candidate_idxs]
            if found:
                candidate_size.append(len(removed) + 1)
                lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
        candidate_size.sort()
        score = 0
        import math

        # Minimizing this score favors revision flags that yield large
        # consistent groups
        for i in range(0, len(candidate_size)):
            score -= math.exp(-i) * candidate_size[i]
        return score

    def abduce(self, data_sample):
        symbol_num = data_sample.elements_num("pred_pseudo_label")
        max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)

        solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num)

        data_sample.revision_flag = reform_list(
            solution.astype(np.int32), data_sample.pred_pseudo_label
        )

        abduced_pseudo_label = []

        for single_instance in data_sample:
            single_instance.pred_pseudo_label = [single_instance.pred_pseudo_label]
            candidates = self.revise_at_idx(single_instance)
            if len(candidates) == 0:
                abduced_pseudo_label.append([])
            else:
                abduced_pseudo_label.append(candidates[0][0])
        data_sample.abduced_pseudo_label = abduced_pseudo_label
        return abduced_pseudo_label

    def abduce_rules(self, pred_res):
        return self.kb.abduce_rules(pred_res)


import os

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

kb = HedKB(
    pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "./datasets/learn_add.pl")
)
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=20)

### Machine Learning Part
# Build necessary components for BasicNN
cls = SymbolNet(num_classes=4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Build BasicNN
# The function of BasicNN is to wrap NN models into the form of an sklearn estimator
base_model = BasicNN(
    cls,
    criterion,
    optimizer,
    device,
    batch_size=32,
    num_epochs=1,
    save_interval=1,
    save_dir=weights_dir,
)

# Build ABLModel
# The main function of the ABL model is to serialize data and
# provide a unified interface for different machine learning models
model = ABLModel(base_model)

### Metric
# Set up metrics
metric_list = [SymbolMetric(prefix="hed"), SemanticsMetric(prefix="hed")]

### Bridge Machine Learning and Logic Reasoning
bridge = HEDBridge(model, reasoner, metric_list)

### Dataset
total_train_data = get_hed(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_hed(train=False)

### Train and Test
bridge.pretrain("examples/hed/weights")
bridge.train(train_data, val_data)

+ 4
- 5
examples/hed/utils.py View File

@@ -1,6 +1,6 @@
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 import torch.utils.data.sampler as sampler




@@ -13,7 +13,7 @@ class InfiniteSampler(sampler.Sampler):
        while True:
            order = np.random.permutation(self.num_samples)
            for i in range(self.num_samples):
-               yield order[i: i + self.batch_size]
+               yield order[i : i + self.batch_size]
                i += self.batch_size


def __len__(self): def __len__(self):
@@ -58,7 +58,6 @@ def reduce_dimension(data):
        for equation_len in range(5, 27):
            equations = data[truth_value][equation_len]
            reduced_equations = [
-               [extract_feature(symbol_img) for symbol_img in equation]
-               for equation in equations
+               [extract_feature(symbol_img) for symbol_img in equation] for equation in equations
            ]
            data[truth_value][equation_len] = reduced_equations

+ 2
- 4
examples/hwf/datasets/get_hwf.py View File

@@ -1,14 +1,12 @@
-import os
 import json
+import os

 from PIL import Image
 from torchvision.transforms import transforms

 CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

-img_transform = transforms.Compose(
-    [transforms.ToTensor(), transforms.Normalize((0.5,), (1,))]
-)
+img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])


 def get_data(file, get_pseudo_label):


+ 12
- 8
examples/hwf/hwf_example.ipynb View File

@@ -51,14 +51,13 @@
 "source": [
 "# Initialize knowledge base and reasoner\n",
 "class HWF_KB(KBBase):\n",
-"\n",
 "    def _valid_candidate(self, formula):\n",
 "        if len(formula) % 2 == 0:\n",
 "            return False\n",
 "        for i in range(len(formula)):\n",
-"            if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:\n",
+"            if i % 2 == 0 and formula[i] not in [\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\"]:\n",
 "                return False\n",
-"            if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:\n",
+"            if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"times\", \"div\"]:\n",
 "                return False\n",
 "        return True\n",
 "\n",
@@ -66,12 +65,17 @@
 "        if not self._valid_candidate(formula):\n",
 "            return np.inf\n",
 "        mapping = {str(i): str(i) for i in range(1, 10)}\n",
-"        mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})\n",
+"        mapping.update({\"+\": \"+\", \"-\": \"-\", \"times\": \"*\", \"div\": \"/\"})\n",
 "        formula = [mapping[f] for f in formula]\n",
-"        return eval(''.join(formula))\n",
+"        return eval(\"\".join(formula))\n",
+"\n",
 "\n",
-"kb = HWF_KB(pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], max_err=1e-10, use_cache=False)\n",
-"reasoner = ReasonerBase(kb, dist_func='confidence')"
+"kb = HWF_KB(\n",
+"    pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"times\", \"div\"],\n",
+"    max_err=1e-10,\n",
+"    use_cache=False,\n",
+")\n",
+"reasoner = ReasonerBase(kb, dist_func=\"confidence\")"
 ]
 },
 {
@@ -122,7 +126,7 @@
 "outputs": [],
 "source": [
 "# Initialize ABL model\n",
-"# The main function of the ABL model is to serialize data and \n",
+"# The main function of the ABL model is to serialize data and\n",
 "# provide a unified interface for different machine learning models\n",
 "model = ABLModel(base_model)"
 ]


+ 1
- 1
examples/mnist_add/mnist_add_example.ipynb View File

@@ -80,7 +80,7 @@
 "outputs": [],
 "source": [
 "# Build ABLModel\n",
-"# The main function of the ABL model is to serialize data and \n",
+"# The main function of the ABL model is to serialize data and\n",
 "# provide a unified interface for different machine learning models\n",
 "model = ABLModel(base_model)"
 ]


+ 2
- 4
examples/models/nn.py View File

@@ -11,8 +11,8 @@
 # ================================================================#


-import torch
 import numpy as np
+import torch
 from torch import nn




@@ -84,9 +84,7 @@ class SymbolNetAutoencoder(nn.Module):
         self.base_model = SymbolNet(num_classes, image_size)
         self.softmax = nn.Softmax(dim=1)
         self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU())
-        self.fc2 = nn.Sequential(
-            nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()
-        )
+        self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU())

     def forward(self, x):
         x = self.base_model(x)


+ 9
- 3
setup.py View File

@@ -1,4 +1,5 @@
 import os
+
 from setuptools import find_packages, setup




@@ -27,7 +28,13 @@ here = os.path.abspath(os.path.dirname(__file__))
 try:
     with open(os.path.join(here, "requirements.txt"), encoding="utf-8") as f:
         REQUIRED = f.read().split("\n")
-except:
+except FileNotFoundError:
+    # Handle the case where the file does not exist
+    print("requirements.txt file not found.")
+    REQUIRED = []
+except Exception as e:
+    # Handle other possible exceptions
+    print(f"An error occurred: {e}")
     REQUIRED = []

 EXTRAS = {
@@ -64,7 +71,7 @@ if __name__ == "__main__":
         install_requires=REQUIRED,
         extras_require=EXTRAS,
         classifiers=[
-            'Development Status :: 3 - Alpha',
+            "Development Status :: 3 - Alpha",
             "Intended Audience :: Science/Research",
             "Intended Audience :: Developers",
             "Programming Language :: Python",
@@ -74,4 +81,3 @@ if __name__ == "__main__":
             "Programming Language :: Python :: 3.8",
         ],
     )
-

+ 65
- 22
tests/conftest.py View File

@@ -4,10 +4,11 @@ import torch.nn as nn
 import torch.optim as optim

 from abl.learning import BasicNN
-from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner
+from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner
 from abl.structures import ListData
 from examples.models.nn import LeNet5

+
 # Fixture for BasicNN instance
 @pytest.fixture
 def basic_nn_instance():
@@ -16,6 +17,7 @@ def basic_nn_instance():
     optimizer = optim.Adam(model.parameters())
     return BasicNN(model, criterion, optimizer)

+
 # Fixture for base_model instance
 @pytest.fixture
 def base_model_instance():
@@ -24,6 +26,7 @@ def base_model_instance():
     optimizer = optim.Adam(model.parameters())
     return BasicNN(model, criterion, optimizer)

+
 # Fixture for ListData instance
 @pytest.fixture
 def list_data_instance():
@@ -37,47 +40,71 @@
 @pytest.fixture
 def data_samples_add():
     # favor 1 in first one
-    prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0],
-             [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
+    prob1 = [
+        [0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0],
+        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
+    ]
     # favor 7 in first one
-    prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0],
-             [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
+    prob2 = [
+        [0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0],
+        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
+    ]

     data_samples_add = ListData()
     data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]]
     data_samples_add.pred_prob = [prob1, prob2, prob1, prob2]
     data_samples_add.Y = [8, 8, 17, 10]
     return data_samples_add

+
 @pytest.fixture
 def data_samples_hwf():
     data_samples_hwf = ListData()
-    data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]]
+    data_samples_hwf.pred_pseudo_label = [
+        ["5", "+", "2"],
+        ["5", "+", "9"],
+        ["5", "+", "9"],
+        ["5", "-", "8", "8", "8"],
+    ]
     data_samples_hwf.pred_prob = [None, None, None, None]
     data_samples_hwf.Y = [3, 64, 65, 3.17]
     return data_samples_hwf

+
 class AddKB(KBBase):
-    def __init__(self, pseudo_label_list=list(range(10)),
-                 use_cache=False):
+    def __init__(self, pseudo_label_list=list(range(10)), use_cache=False):
         super().__init__(pseudo_label_list, use_cache=use_cache)

     def logic_forward(self, nums):
         return sum(nums)

+
 class AddGroundKB(GroundKB):
-    def __init__(self, pseudo_label_list=list(range(10)),
-                 GKB_len_list=[2]):
+    def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
         super().__init__(pseudo_label_list, GKB_len_list)
     def logic_forward(self, nums):
         return sum(nums)

+
 class HwfKB(KBBase):
     def __init__(
         self,
-        pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
-                           "+", "-", "times", "div"],
+        pseudo_label_list=[
+            "1",
+            "2",
+            "3",
+            "4",
+            "5",
+            "6",
+            "7",
+            "8",
+            "9",
+            "+",
+            "-",
+            "times",
+            "div",
+        ],
         max_err=1e-3,
         use_cache=False,
     ):
@@ -87,7 +114,17 @@ class HwfKB(KBBase):
         if len(formula) % 2 == 0:
             return False
         for i in range(len(formula)):
-            if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
+            if i % 2 == 0 and formula[i] not in [
+                "1",
+                "2",
+                "3",
+                "4",
+                "5",
+                "6",
+                "7",
+                "8",
+                "9",
+            ]:
                 return False
             if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
                 return False
@@ -100,7 +137,8 @@ class HwfKB(KBBase):
         mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
         formula = [mapping[f] for f in formula]
         return eval("".join(formula))

+
 class HedKB(PrologKB):
     def __init__(self, pseudo_label_list, pl_file):
         super().__init__(pseudo_label_list, pl_file)
@@ -110,24 +148,28 @@ class HedKB(PrologKB):
         pl_query = "eval_inst_feature(%s, %s)." % (exs, rules)
         return len(list(self.prolog.query(pl_query))) != 0

+
 @pytest.fixture
 def kb_add():
     return AddKB()

+
 @pytest.fixture
 def kb_add_cache():
     return AddKB(use_cache=True)

+
 @pytest.fixture
 def kb_add_ground():
     return AddGroundKB()

+
 @pytest.fixture
 def kb_add_prolog():
-    kb = PrologKB(pseudo_label_list=list(range(10)),
-                  pl_file="examples/mnist_add/datasets/add.pl")
+    kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl")
     return kb

+
 @pytest.fixture
 def kb_hed():
     kb = HedKB(
@@ -136,6 +178,7 @@ def kb_hed():
     )
     return kb

+
 @pytest.fixture
 def reasoner_instance(kb_add):
     return Reasoner(kb_add, "confidence")

+ 2
- 1
tests/test_abl_model.py View File

@@ -1,8 +1,9 @@
+from unittest.mock import Mock, create_autospec
+
 import numpy as np
 import pytest

 from abl.learning import ABLModel
-from unittest.mock import Mock, create_autospec


 class TestABLModel(object):


+ 106
- 42
tests/test_reasoning.py View File

@@ -1,63 +1,66 @@
 import pytest
-from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner

+from abl.reasoning import PrologKB, Reasoner


 class TestKBBase(object):
     def test_init(self, kb_add):
         assert kb_add.pseudo_label_list == list(range(10))
     def test_init_cache(self, kb_add_cache):
         assert kb_add_cache.pseudo_label_list == list(range(10))
-        assert kb_add_cache.use_cache == True
+        assert kb_add_cache.use_cache is True
     def test_logic_forward(self, kb_add):
         result = kb_add.logic_forward([1, 2])
         assert result == 3
     def test_revise_at_idx(self, kb_add):
         result = kb_add.revise_at_idx([1, 2], 2, [0])
         assert result == [[0, 2]]
     def test_abduce_candidates(self, kb_add):
-        result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2,
-                                          require_more_revision=0)
+        result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2, require_more_revision=0)
         assert result == [[1, 0]]


 class TestGroundKB(object):
     def test_init(self, kb_add_ground):
         assert kb_add_ground.pseudo_label_list == list(range(10))
         assert kb_add_ground.GKB_len_list == [2]
         assert kb_add_ground.GKB
     def test_logic_forward_ground(self, kb_add_ground):
         result = kb_add_ground.logic_forward([1, 2])
         assert result == 3
     def test_abduce_candidates_ground(self, kb_add_ground):
-        result = kb_add_ground.abduce_candidates([1, 2], 1, max_revision_num=2,
-                                                 require_more_revision=0)
+        result = kb_add_ground.abduce_candidates(
+            [1, 2], 1, max_revision_num=2, require_more_revision=0
+        )
         assert result == [(1, 0)]
-class TestPrologKB(object):


+class TestPrologKB(object):
     def test_init_pl1(self, kb_add_prolog):
         assert kb_add_prolog.pseudo_label_list == list(range(10))
         assert kb_add_prolog.pl_file == "examples/mnist_add/datasets/add.pl"
     def test_init_pl2(self, kb_hed):
         assert kb_hed.pseudo_label_list == [1, 0, "+", "="]
         assert kb_hed.pl_file == "examples/hed/datasets/learn_add.pl"
     def test_prolog_file_not_exist(self):
         pseudo_label_list = [1, 2]
         non_existing_file = "path/to/non_existing_file.pl"
         with pytest.raises(FileNotFoundError) as excinfo:
-            PrologKB(pseudo_label_list=pseudo_label_list,
-                     pl_file=non_existing_file)
+            PrologKB(pseudo_label_list=pseudo_label_list, pl_file=non_existing_file)
         assert non_existing_file in str(excinfo.value)
     def test_logic_forward_pl1(self, kb_add_prolog):
         result = kb_add_prolog.logic_forward([1, 2])
         assert result == 3
     def test_logic_forward_pl2(self, kb_hed):
         consist_exs = [
             [1, 1, "+", 0, "=", 1, 1],
@@ -70,21 +73,24 @@ class TestPrologKB(object):
             [0, "+", 0, "=", 0],
             [0, "+", 0, "=", 1],
         ]
-        assert kb_hed.logic_forward(consist_exs) == True
-        assert kb_hed.logic_forward(inconsist_exs) == False
+        assert kb_hed.logic_forward(consist_exs) is True
+        assert kb_hed.logic_forward(inconsist_exs) is False

     def test_revise_at_idx(self, kb_add_prolog):
         result = kb_add_prolog.revise_at_idx([1, 2], 2, [0])
         assert result == [[0, 2]]


 class TestReaonser(object):
     def test_reasoner_init(self, reasoner_instance):
         assert reasoner_instance.dist_func == "confidence"
     def test_invalid_dist_funce(kb_add):
         with pytest.raises(NotImplementedError) as excinfo:
             Reasoner(kb_add, "invalid_dist_func")
-        assert "Valid options for dist_func include \"hamming\" and \"confidence\"" in str(excinfo.value)
+        assert 'Valid options for dist_func include "hamming" and "confidence"' in str(
+            excinfo.value
+        )


 class test_batch_abduce(object):
@@ -95,8 +101,18 @@ class test_batch_abduce(object):
         reasoner4 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=1)
         assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
         assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
-        assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]]
-        assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]]
+        assert reasoner3.batch_abduce(data_samples_add) == [
+            [1, 7],
+            [7, 1],
+            [8, 9],
+            [1, 9],
+        ]
+        assert reasoner4.batch_abduce(data_samples_add) == [
+            [1, 7],
+            [7, 1],
+            [8, 9],
+            [7, 3],
+        ]

     def test_batch_abduce_ground(self, kb_add_ground, data_samples_add):
         reasoner1 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=0)
@@ -105,8 +121,18 @@ class test_batch_abduce(object):
         reasoner4 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=1)
         assert reasoner1.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)]
         assert reasoner2.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)]
-        assert reasoner3.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (1, 9)]
-        assert reasoner4.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (7, 3)]
+        assert reasoner3.batch_abduce(data_samples_add) == [
+            (1, 7),
+            (7, 1),
+            (8, 9),
+            (1, 9),
+        ]
+        assert reasoner4.batch_abduce(data_samples_add) == [
+            (1, 7),
+            (7, 1),
+            (8, 9),
+            (7, 3),
+        ]

     def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add):
         reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0)
@@ -115,35 +141,73 @@ class test_batch_abduce(object):
         reasoner4 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=1)
         assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
         assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
-        assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]]
-        assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]]
+        assert reasoner3.batch_abduce(data_samples_add) == [
+            [1, 7],
+            [7, 1],
+            [8, 9],
+            [1, 9],
+        ]
+        assert reasoner4.batch_abduce(data_samples_add) == [
+            [1, 7],
+            [7, 1],
+            [8, 9],
+            [7, 3],
+        ]

     def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add):
         reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1)
         reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2)
         assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
-        assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]]
+        assert reasoner2.batch_abduce(data_samples_add) == [
+            [1, 7],
+            [7, 1],
+            [8, 9],
+            [7, 3],
+        ]

     def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf):
         reasoner1 = Reasoner(kb_hwf1, "hamming", max_revision=3, require_more_revision=0)
         reasoner2 = Reasoner(kb_hwf1, "hamming", max_revision=0.5, require_more_revision=0)
         reasoner3 = Reasoner(kb_hwf1, "hamming", max_revision=0.9, require_more_revision=0)
         res = reasoner1.batch_abduce(data_samples_hwf)
-        assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']]
+        assert res == [
+            ["1", "+", "2"],
+            ["8", "times", "8"],
+            [],
+            ["4", "-", "6", "div", "8"],
+        ]
         res = reasoner2.batch_abduce(data_samples_hwf)
-        assert res == [['1', '+', '2'], [], [], []]
+        assert res == [["1", "+", "2"], [], [], []]
         res = reasoner3.batch_abduce(data_samples_hwf)
-        assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']]
+        assert res == [
+            ["1", "+", "2"],
+            ["8", "times", "8"],
+            [],
+            ["4", "-", "6", "div", "8"],
+        ]

     def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf):
         reasoner1 = Reasoner(kb_hwf2, "hamming", max_revision=3, require_more_revision=0)
         reasoner2 = Reasoner(kb_hwf2, "hamming", max_revision=0.5, require_more_revision=0)
         reasoner3 = Reasoner(kb_hwf2, "hamming", max_revision=0.9, require_more_revision=0)
         res = reasoner1.batch_abduce(data_samples_hwf)
-        assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']]
+        assert res == [
+            ["1", "+", "2"],
+            ["7", "times", "9"],
+            ["8", "times", "8"],
+            ["5", "-", "8", "div", "8"],
+        ]
         res = reasoner2.batch_abduce(data_samples_hwf)
-        assert res == [['1', '+', '2'], ['7', 'times', '9'], [], ['5', '-', '8', 'div', '8']]
+        assert res == [
+            ["1", "+", "2"],
+            ["7", "times", "9"],
+            [],
+            ["5", "-", "8", "div", "8"],
+        ]
         res = reasoner3.batch_abduce(data_samples_hwf)
-        assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']]
+        assert res == [
+            ["1", "+", "2"],
+            ["7", "times", "9"],
+            ["8", "times", "8"],
+            ["5", "-", "8", "div", "8"],
+        ]