Browse Source

[FIX] enable github actions running successfully

pull/1/head
Gao Enhao 1 year ago
parent
commit
9c888183f8
48 changed files with 426 additions and 944 deletions
  1. +8
    -1
      .coveragerc
  2. +1
    -1
      .github/workflows/build-and-test.yaml
  3. +2
    -2
      .github/workflows/lint.yaml
  4. +11
    -2
      abl/__init__.py
  5. +1
    -1
      abl/__version__.py
  6. +3
    -1
      abl/bridge/__init__.py
  7. +3
    -16
      abl/bridge/base_bridge.py
  8. +5
    -4
      abl/bridge/simple_bridge.py
  9. +7
    -0
      abl/dataset/__init__.py
  10. +6
    -2
      abl/dataset/bridge_dataset.py
  11. +5
    -5
      abl/dataset/classification_dataset.py
  12. +3
    -1
      abl/dataset/prediction_dataset.py
  13. +2
    -2
      abl/dataset/regression_dataset.py
  14. +2
    -0
      abl/evaluation/__init__.py
  15. +12
    -12
      abl/evaluation/base_metric.py
  16. +3
    -3
      abl/evaluation/symbol_metric.py
  17. +3
    -1
      abl/learning/__init__.py
  18. +9
    -6
      abl/learning/abl_model.py
  19. +7
    -4
      abl/learning/basic_nn.py
  20. +4
    -2
      abl/reasoning/__init__.py
  21. +29
    -30
      abl/reasoning/kb.py
  22. +6
    -10
      abl/reasoning/reasoner.py
  23. +3
    -1
      abl/structures/__init__.py
  24. +4
    -10
      abl/structures/base_data_element.py
  25. +1
    -2
      abl/structures/list_data.py
  26. +21
    -1
      abl/utils/__init__.py
  27. +0
    -6
      abl/utils/cache.py
  28. +15
    -7
      abl/utils/logger.py
  29. +12
    -12
      abl/utils/manager.py
  30. +0
    -57
      abl/utils/utils.py
  31. +3
    -0
      build_tools/requirements.txt
  32. +1
    -6
      docs/Examples/MNISTAdd.rst
  33. +5
    -7
      docs/conf.py
  34. +3
    -3
      examples/hed/datasets/get_hed.py
  35. +0
    -388
      examples/hed/framework_hed.py
  36. +15
    -14
      examples/hed/hed_bridge.py
  37. +8
    -6
      examples/hed/hed_example.ipynb
  38. +0
    -69
      examples/hed/hed_knn_example.py
  39. +0
    -159
      examples/hed/hed_tmp.py
  40. +4
    -5
      examples/hed/utils.py
  41. +2
    -4
      examples/hwf/datasets/get_hwf.py
  42. +12
    -8
      examples/hwf/hwf_example.ipynb
  43. +1
    -1
      examples/mnist_add/mnist_add_example.ipynb
  44. +2
    -4
      examples/models/nn.py
  45. +9
    -3
      setup.py
  46. +65
    -22
      tests/conftest.py
  47. +2
    -1
      tests/test_abl_model.py
  48. +106
    -42
      tests/test_reasoning.py

+ 8
- 1
.coveragerc View File

@@ -5,4 +5,11 @@ show_missing = True
disable_warnings = include-ignored
include = */abl/*
omit =
*/abl/__init__.py
*/abl/__init__.py
abl/bridge/__init__.py
abl/dataset/__init__.py
abl/evaluation/__init__.py
abl/learning/__init__.py
abl/reasoning/__init__.py
abl/structures/__init__.py
abl/utils/__init__.py

+ 1
- 1
.github/workflows/build-and-test.yaml View File

@@ -24,7 +24,7 @@ jobs:
- name: Install package dependencies
run: |
python -m pip install --upgrade pip
pip install -r ./requirements.txt
pip install -r build_tools/requirements.txt
- name: Run tests
run: |
pytest --cov-config=.coveragerc --cov-report=xml --cov=abl ./tests


+ 2
- 2
.github/workflows/lint.yaml View File

@@ -20,5 +20,5 @@ jobs:
- name: flake8 Lint
uses: py-actions/flake8@v2
with:
max-line-length: "110"
plugins: "flake8-bugbear flake8-black"
max-line-length: "100"
args: --ignore=E203,W503

+ 11
- 2
abl/__init__.py View File

@@ -1,2 +1,11 @@
from .learning import abl_model, basic_nn
from .reasoning import reasoner, kb
from . import bridge, dataset, evaluation, learning, reasoning, structures, utils

__all__ = [
"bridge",
"dataset",
"evaluation",
"learning",
"reasoning",
"structures",
"utils",
]

+ 1
- 1
abl/__version__.py View File

@@ -1,3 +1,3 @@
VERSION = (0, 0, 1)

__version__ = ".".join(map(str, VERSION))
__version__ = ".".join(map(str, VERSION))

+ 3
- 1
abl/bridge/__init__.py View File

@@ -1,2 +1,4 @@
from .base_bridge import BaseBridge
from .simple_bridge import SimpleBridge
from .simple_bridge import SimpleBridge

__all__ = ["BaseBridge", "SimpleBridge"]

+ 3
- 16
abl/bridge/base_bridge.py View File

@@ -12,53 +12,40 @@ class BaseBridge(metaclass=ABCMeta):
def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
if not isinstance(model, ABLModel):
raise TypeError(
"Expected an instance of ABLModel, but received type: {}".format(
type(model)
)
"Expected an instance of ABLModel, but received type: {}".format(type(model))
)
if not isinstance(reasoner, Reasoner):
raise TypeError(
"Expected an instance of Reasoner, but received type: {}".format(
type(reasoner)
)
"Expected an instance of Reasoner, but received type: {}".format(type(reasoner))
)

self.model = model
self.reasoner = reasoner

@abstractmethod
def predict(
self, data_samples: ListData
) -> Tuple[List[List[Any]], List[List[Any]]]:
def predict(self, data_samples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
"""Placeholder for predict labels from input."""
pass

@abstractmethod
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for abduce pseudo labels."""
pass

@abstractmethod
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for map label space to symbol space."""
pass

@abstractmethod
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for map symbol space to label space."""
pass

@abstractmethod
def train(self, train_data: Union[ListData, DataSet]):
"""Placeholder for train loop of ABductive Learning."""
pass

@abstractmethod
def valid(self, valid_data: Union[ListData, DataSet]) -> None:
"""Placeholder for model test."""
pass

@abstractmethod
def test(self, test_data: Union[ListData, DataSet]) -> None:
"""Placeholder for model validation."""
pass

+ 5
- 4
abl/bridge/simple_bridge.py View File

@@ -1,5 +1,5 @@
import os.path as osp
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

from numpy import ndarray

@@ -32,8 +32,7 @@ class SimpleBridge(BaseBridge):
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
pred_idx = data_samples.pred_idx
data_samples.pred_pseudo_label = [
[self.reasoner.mapping[_idx] for _idx in sub_list]
for sub_list in pred_idx
[self.reasoner.mapping[_idx] for _idx in sub_list] for sub_list in pred_idx
]
return data_samples.pred_pseudo_label

@@ -81,7 +80,9 @@ class SimpleBridge(BaseBridge):
loss = self.model.train(sub_data_samples)

print_log(
f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}",
f"loop(train) [{loop + 1}/{loops}] segment(train) \
[{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] \
model loss is {loss:.5f}",
logger="current",
)



+ 7
- 0
abl/dataset/__init__.py View File

@@ -2,3 +2,10 @@ from .bridge_dataset import BridgeDataset
from .classification_dataset import ClassificationDataset
from .prediction_dataset import PredictionDataset
from .regression_dataset import RegressionDataset

__all__ = [
"BridgeDataset",
"ClassificationDataset",
"PredictionDataset",
"RegressionDataset",
]

+ 6
- 2
abl/dataset/bridge_dataset.py View File

@@ -13,11 +13,15 @@ class BridgeDataset(Dataset):
gt_pseudo_label : List[List[Any]], optional
A list of objects representing the ground truth label of each element in ``X``.
Y : List[Any]
A list of objects representing the ground truth of the reasoning result of each instance in ``X``.
A list of objects representing the ground truth of the reasoning result of
each instance in ``X``.
"""

def __init__(
self, X: List[List[Any]], gt_pseudo_label: Optional[List[List[Any]]], Y: List[Any]
self,
X: List[List[Any]],
gt_pseudo_label: Optional[List[List[Any]]],
Y: List[Any],
):
if (not isinstance(X, list)) or (not isinstance(Y, list)):
raise ValueError("X and Y should be of type list.")


+ 5
- 5
abl/dataset/classification_dataset.py View File

@@ -15,16 +15,16 @@ class ClassificationDataset(Dataset):
Y : List[int]
The target data.
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
A function/transform that takes in an object and returns a transformed version.
Defaults to None.
"""
def __init__(
self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None
):

def __init__(self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None):
if (not isinstance(X, list)) or (not isinstance(Y, list)):
raise ValueError("X and Y should be of type list.")
if len(X) != len(Y):
raise ValueError("Length of X and Y must be equal.")
self.X = X
self.Y = torch.LongTensor(Y)
self.transform = transform


+ 3
- 1
abl/dataset/prediction_dataset.py View File

@@ -13,8 +13,10 @@ class PredictionDataset(Dataset):
X : List[Any]
The input data.
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
A function/transform that takes in an object and returns a transformed version.
Defaults to None.
"""

def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
if not isinstance(X, list):
raise ValueError("X should be of type list.")


+ 2
- 2
abl/dataset/regression_dataset.py View File

@@ -1,6 +1,5 @@
from typing import Any, List, Tuple

import torch
from torch.utils.data import Dataset


@@ -15,12 +14,13 @@ class RegressionDataset(Dataset):
Y : List[Any]
A list of objects representing the output data.
"""

def __init__(self, X: List[Any], Y: List[Any]):
if (not isinstance(X, list)) or (not isinstance(Y, list)):
raise ValueError("X and Y should be of type list.")
if len(X) != len(Y):
raise ValueError("Length of X and Y must be equal.")
self.X = X
self.Y = Y



+ 2
- 0
abl/evaluation/__init__.py View File

@@ -1,3 +1,5 @@
from .base_metric import BaseMetric
from .semantics_metric import SemanticsMetric
from .symbol_metric import SymbolMetric

__all__ = ["BaseMetric", "SemanticsMetric", "SymbolMetric"]

+ 12
- 12
abl/evaluation/base_metric.py View File

@@ -20,8 +20,10 @@ class BaseMetric(metaclass=ABCMeta):
will be used instead. Default: None
"""

def __init__(self,
prefix: Optional[str] = None,) -> None:
def __init__(
self,
prefix: Optional[str] = None,
) -> None:
self.results: List[Any] = []
self.prefix = prefix or self.default_prefix

@@ -65,20 +67,18 @@ class BaseMetric(metaclass=ABCMeta):
"""
if len(self.results) == 0:
print_log(
f'{self.__class__.__name__} got empty `self.results`. Please '
'ensure that the processed results are properly added into '
'`self.results` in `process` method.',
logger='current',
level=logging.WARNING)
f"{self.__class__.__name__} got empty `self.results`. Please "
"ensure that the processed results are properly added into "
"`self.results` in `process` method.",
logger="current",
level=logging.WARNING,
)

metrics = self.compute_metrics(self.results)
# Add prefix to metric names
if self.prefix:
metrics = {
'/'.join((self.prefix, k)): v
for k, v in metrics.items()
}
metrics = {"/".join((self.prefix, k)): v for k, v in metrics.items()}

# reset the results list
self.results.clear()
return metrics
return metrics

+ 3
- 3
abl/evaluation/symbol_metric.py View File

@@ -14,15 +14,15 @@ class SymbolMetric(BaseMetric):

if not len(pred_pseudo_label) == len(gt_pseudo_label):
raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal")
for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
correct_num = 0
for pred_symbol, symbol in zip(pred_z, z):
if pred_symbol == symbol:
correct_num += 1
self.results.append(correct_num / len(z))
def compute_metrics(self, results: list) -> dict:
metrics = dict()
metrics["character_accuracy"] = sum(results) / len(results)
return metrics
return metrics

+ 3
- 1
abl/learning/__init__.py View File

@@ -1,2 +1,4 @@
from .abl_model import ABLModel
from .basic_nn import BasicNN
from .basic_nn import BasicNN

__all__ = ["ABLModel", "BasicNN"]

+ 9
- 6
abl/learning/abl_model.py View File

@@ -58,7 +58,8 @@ class ABLModel:
Parameters
----------
data_samples : ListData
A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`.
A batch of data to train on, which typically contains the data, `X`, and the
corresponding labels, `abduced_idx`.

Returns
-------
@@ -68,7 +69,7 @@ class ABLModel:
data_X = data_samples.flatten("X")
data_y = data_samples.flatten("abduced_idx")
return self.base_model.fit(X=data_X, y=data_y)
def valid(self, data_samples: ListData) -> float:
"""
Validate the model on the given data.
@@ -76,7 +77,8 @@ class ABLModel:
Parameters
----------
data_samples : ListData
A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`.
A batch of data to train on, which typically contains the data, `X`,
and the corresponding labels, `abduced_idx`.

Returns
-------
@@ -94,7 +96,7 @@ class ABLModel:
method = getattr(model, operation)
method(*args, **kwargs)
else:
if not f"{operation}_path" in kwargs.keys():
if f"{operation}_path" not in kwargs.keys():
raise ValueError(f"'{operation}_path' should not be None")
else:
try:
@@ -104,9 +106,10 @@ class ABLModel:
elif operation == "load":
with open(kwargs["load_path"], "rb") as file:
self.base_model = pickle.load(file)
except:
except (OSError, pickle.PickleError):
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed."
f"{type(model).__name__} object doesn't have the {operation} method \
and the default pickle-based {operation} method failed."
)

def save(self, *args, **kwargs) -> None:


+ 7
- 4
abl/learning/basic_nn.py View File

@@ -1,5 +1,5 @@
import os
import logging
import os
from typing import Any, Callable, List, Optional, T, Tuple

import numpy
@@ -23,7 +23,8 @@ class BasicNN:
optimizer : torch.optim.Optimizer
The optimizer used for training.
device : torch.device, optional
The device on which the model will be trained or used for prediction, by default torch.device("cpu").
The device on which the model will be trained or used for prediction,
by default torch.device("cpu").
batch_size : int, optional
The batch size used for training, by default 32.
num_epochs : int, optional
@@ -37,9 +38,11 @@ class BasicNN:
save_dir : Optional[str], optional
The directory in which to save the model during training, by default None.
train_transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version used in the `fit` and `train_epoch` methods, by default None.
A function/transform that takes in an object and returns a transformed version used
in the `fit` and `train_epoch` methods, by default None.
test_transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version in the `predict`, `predict_proba` and `score` methods, , by default None.
A function/transform that takes in an object and returns a transformed version in the
`predict`, `predict_proba` and `score` methods, , by default None.
collate_fn : Callable[[List[T]], Any], optional
The function used to collate data, by default None.
"""


+ 4
- 2
abl/reasoning/__init__.py View File

@@ -1,2 +1,4 @@
from .kb import KBBase, GroundKB, PrologKB
from .reasoner import Reasoner
from .kb import GroundKB, KBBase, PrologKB
from .reasoner import Reasoner

__all__ = ["KBBase", "GroundKB", "PrologKB", "Reasoner"]

+ 29
- 30
abl/reasoning/kb.py View File

@@ -1,16 +1,15 @@
from abc import ABC, abstractmethod
import bisect
import os
from abc import ABC, abstractmethod
from collections import defaultdict
from itertools import product, combinations
from itertools import combinations, product
from multiprocessing import Pool
from functools import lru_cache

import numpy as np
import pyswip

from ..utils.utils import flatten, reform_list, hamming_dist, to_hashable
from ..utils.cache import abl_cache
from ..utils.utils import flatten, hamming_dist, reform_list, to_hashable


class KBBase(ABC):
@@ -20,19 +19,19 @@ class KBBase(ABC):
Parameters
----------
pseudo_label_list : list
List of possible pseudo labels. It's recommended to arrange the pseudo labels in this
list so that each aligns with its corresponding index in the base model: the first with
List of possible pseudo labels. It's recommended to arrange the pseudo labels in this
list so that each aligns with its corresponding index in the base model: the first with
the 0th index, the second with the 1st, and so forth.
max_err : float, optional
The upper tolerance limit when comparing the similarity between a pseudo label sample's reasoning
result and the ground truth. This is only applicable when the reasoning result is of a numerical type.
This is particularly relevant for regression problems where exact matches might not be
feasible. Defaults to 1e-10.
The upper tolerance limit when comparing the similarity between a pseudo label sample's
reasoning result and the ground truth. This is only applicable when the reasoning
result is of a numerical type. This is particularly relevant for regression problems where
exact matches might not be feasible. Defaults to 1e-10.
use_cache : bool, optional
Whether to use abl_cache for previously abduced candidates to speed up subsequent
operations. Defaults to True.
key_func : func, optional
A function employed for hashing in abl_cache. This is only operational when use_cache
A function employed for hashing in abl_cache. This is only operational when use_cache
is set to True. Defaults to to_hashable.
cache_size: int, optional
The cache size in abl_cache. This is only operational when use_cache is set to
@@ -75,7 +74,6 @@ class KBBase(ABC):
pseudo_label : List[Any]
Pseudo label sample.
"""
pass

def abduce_candidates(self, pseudo_label, y, max_revision_num, require_more_revision):
"""
@@ -104,7 +102,7 @@ class KBBase(ABC):
"""
Check whether the reasoning result of a pseduo label sample is equal to the ground truth
(or, within the maximum error allowed for numerical results).
Returns
-------
bool
@@ -130,7 +128,7 @@ class KBBase(ABC):
Ground truth of the reasoning result for the sample.
revision_idx : array-like
Indices of where revisions should be made to the pseudo label sample.
Returns
-------
List[List[Any]]
@@ -149,8 +147,8 @@ class KBBase(ABC):

def _revision(self, revision_num, pseudo_label, y):
"""
For a specified number of labels in a pseudo label sample to revise, iterate through all possible
indices to find any candidates that are compatible with the knowledge base.
For a specified number of labels in a pseudo label sample to revise, iterate through
all possible indices to find any candidates that are compatible with the knowledge base.
"""
new_candidates = []
revision_idx_list = combinations(range(len(pseudo_label)), revision_num)
@@ -164,8 +162,8 @@ class KBBase(ABC):
def _abduce_by_search(self, pseudo_label, y, max_revision_num, require_more_revision):
"""
Perform abductive reasoning by exhastive search. Specifically, begin with 0 and
continuously increase the number of labels in a pseudo label sample to revise, until candidates
that are compatible with the knowledge base are found.
continuously increase the number of labels in a pseudo label sample to revise, until
candidates that are compatible with the knowledge base are found.

Parameters
----------
@@ -177,8 +175,8 @@ class KBBase(ABC):
The upper limit on the number of revisions.
require_more_revision : int
If larger than 0, then after having found any candidates compatible with the
knowledge base, continue to increase the number of labels in a pseudo label sample to revise to
get more possible compatible candidates.
knowledge base, continue to increase the number of labels in a pseudo label sample to
revise to get more possible compatible candidates.

Returns
-------
@@ -286,7 +284,7 @@ class GroundKB(KBBase):
Perform abductive reasoning by directly retrieving compatible candidates from
the prebuilt GKB. In this way, the time-consuming exhaustive search can be
avoided.
Parameters
----------
pseudo_label : List[Any]
@@ -347,7 +345,7 @@ class GroundKB(KBBase):
num_candidates = len(self.GKB[i]) if i in self.GKB else 0
GKB_info_parts.append(f"{num_candidates} candidates of length {i}")
GKB_info = ", ".join(GKB_info_parts)
return (
f"{self.__class__.__name__} is a KB with "
f"pseudo_label_list={self.pseudo_label_list!r}, "
@@ -400,7 +398,7 @@ class PrologKB(KBBase):
returned `Res` as the reasoning results. To use this default function, there must be
a `logic_forward` method in the pl file to perform reasoning.
Otherwise, users would override this function.
Parameters
----------
pseudo_label : List[Any]
@@ -429,9 +427,10 @@ class PrologKB(KBBase):
def get_query_string(self, pseudo_label, y, revision_idx):
"""
Get the query to be used for consulting Prolog.
This is a default function for demo, users would override this function to adapt to their own
Prolog file. In this demo function, return query `logic_forward([kept_labels, Revise_labels], Res).`.
This is a default function for demo, users would override this function to adapt to
their own Prolog file. In this demo function, return query
`logic_forward([kept_labels, Revise_labels], Res).`.

Parameters
----------
pseudo_label : List[Any]
@@ -440,7 +439,7 @@ class PrologKB(KBBase):
Ground truth of the reasoning result for the sample.
revision_idx : array-like
Indices of where revisions should be made to the pseudo label sample.
Returns
-------
str
@@ -448,14 +447,14 @@ class PrologKB(KBBase):
"""
query_string = "logic_forward("
query_string += self._revision_pseudo_label(pseudo_label, revision_idx)
key_is_none_flag = y is None or (type(y) == list and y[0] is None)
key_is_none_flag = y is None or (isinstance(y, list) and y[0] is None)
query_string += ",%s)." % y if not key_is_none_flag else ")."
return query_string

def revise_at_idx(self, pseudo_label, y, revision_idx):
"""
Revise the pseudo label sample at specified index positions by querying Prolog.
Parameters
----------
pseudo_label : List[Any]
@@ -464,7 +463,7 @@ class PrologKB(KBBase):
Ground truth of the reasoning result for the sample.
revision_idx : array-like
Indices of where revisions should be made to the pseudo label sample.
Returns
-------
List[List[Any]]


+ 6
- 10
abl/reasoning/reasoner.py View File

@@ -1,11 +1,7 @@
import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import (
confidence_dist,
flatten,
reform_list,
hamming_dist,
)
from zoopt import Dimension, Objective, Opt, Parameter

from ..utils.utils import confidence_dist, hamming_dist


class Reasoner:
@@ -124,7 +120,7 @@ class Reasoner:

def zoopt_get_solution(self, symbol_num, data_sample, max_revision_num):
"""
Get the optimal solution using ZOOpt library. The solution is a list of
Get the optimal solution using ZOOpt library. The solution is a list of
boolean values, where '1' (True) indicates the indices chosen to be revised.

Parameters
@@ -148,7 +144,7 @@ class Reasoner:

def zoopt_revision_score(self, symbol_num, data_sample, sol):
"""
Get the revision score for a solution. A lower score suggests that ZOOpt library
Get the revision score for a solution. A lower score suggests that ZOOpt library
has a higher preference for this solution.
"""
revision_idx = np.where(sol.get_x() != 0)[0]
@@ -198,7 +194,7 @@ class Reasoner:
Returns
-------
List[Any]
A revised pseudo label sample through abductive reasoning, which is compatible
A revised pseudo label sample through abductive reasoning, which is compatible
with the knowledge base.
"""
symbol_num = data_sample.elements_num("pred_pseudo_label")


+ 3
- 1
abl/structures/__init__.py View File

@@ -1,2 +1,4 @@
from .base_data_element import BaseDataElement
from .list_data import ListData
from .list_data import ListData

__all__ = ["BaseDataElement", "ListData"]

+ 4
- 10
abl/structures/base_data_element.py View File

@@ -224,9 +224,7 @@ class BaseDataElement:
metainfo (dict): A dict contains the meta information
of image, such as ``img_shape``, ``scale_factor``, etc.
"""
assert isinstance(
metainfo, dict
), f"metainfo should be a ``dict`` but got {type(metainfo)}"
assert isinstance(metainfo, dict), f"metainfo should be a ``dict`` but got {type(metainfo)}"
meta = copy.deepcopy(metainfo)
for k, v in meta.items():
self.set_field(name=k, value=v, field_type="metainfo", dtype=None)
@@ -388,8 +386,7 @@ class BaseDataElement:
super().__setattr__(name, value)
else:
raise AttributeError(
f"{name} has been used as a "
"private attribute, which is immutable."
f"{name} has been used as a " "private attribute, which is immutable."
)
else:
self.set_field(name=name, value=value, field_type="data", dtype=None)
@@ -458,9 +455,7 @@ class BaseDataElement:
functions."""
assert field_type in ["metainfo", "data"]
if dtype is not None:
assert isinstance(
value, dtype
), f"{value} should be a {dtype} but got {type(value)}"
assert isinstance(value, dtype), f"{value} should be a {dtype} but got {type(value)}"

if field_type == "metainfo":
if name in self._data_fields:
@@ -571,8 +566,7 @@ class BaseDataElement:
def to_dict(self) -> dict:
"""Convert BaseDataElement to dict."""
return {
k: v.to_dict() if isinstance(v, BaseDataElement) else v
for k, v in self.all_items()
k: v.to_dict() if isinstance(v, BaseDataElement) else v for k, v in self.all_items()
}

def __repr__(self) -> str:


+ 1
- 2
abl/structures/list_data.py View File

@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from collections.abc import Sized
from typing import Any, List, Union
from typing import List, Union

import numpy as np
import torch


+ 21
- 1
abl/utils/__init__.py View File

@@ -1,3 +1,23 @@
from .cache import Cache, abl_cache
from .logger import ABLLogger, print_log
from .utils import *
from .utils import (
calculate_revision_num,
confidence_dist,
flatten,
hamming_dist,
reform_list,
to_hashable,
)

__all__ = [
"Cache",
"ABLLogger",
"print_log",
"calculate_revision_num",
"confidence_dist",
"flatten",
"hamming_dist",
"reform_list",
"to_hashable",
"abl_cache",
]

+ 0
- 6
abl/utils/cache.py View File

@@ -1,10 +1,5 @@
import pickle
import os
import os.path as osp
from typing import Callable, Generic, TypeVar

from .logger import print_log, ABLLogger

K = TypeVar("K")
T = TypeVar("T")
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
@@ -73,7 +68,6 @@ class Cache(Generic[K, T]):
# Empty the oldest link and make it the new root.
self.root = oldroot[NEXT]
oldkey = self.root[KEY]
oldresult = self.root[RESULT]
self.root[KEY] = self.root[RESULT] = None
# Now update the cache dictionary.
del self.cache_dict[oldkey]


+ 15
- 7
abl/utils/logger.py View File

@@ -15,7 +15,8 @@ class FilterDuplicateWarning(logging.Filter):
"""
Filter for eliminating repeated warning messages in logging.

This filter checks for duplicate warning messages and allows only the first occurrence of each message to be logged, filtering out subsequent duplicates.
This filter checks for duplicate warning messages and allows only the first occurrence of
each message to be logged, filtering out subsequent duplicates.

Parameters
----------
@@ -145,7 +146,8 @@ class ABLLogger(Logger, ManagerMixin):

`ABLLogger` provides a formatted logger that can log messages with different
log levels. It allows the creation of logger instances in a similar manner to `ManagerMixin`.
The logger has features like distributed log storage and colored terminal output for different log levels.
The logger has features like distributed log storage and colored terminal output for different
log levels.

Parameters
----------
@@ -154,7 +156,8 @@ class ABLLogger(Logger, ManagerMixin):
logger_name : str, optional
`name` attribute of `logging.Logger` instance. Defaults to 'abl'.
log_file : str, optional
The log filename. If specified, a `FileHandler` will be added to the logger. Defaults to None.
The log filename. If specified, a `FileHandler` will be added to the logger.
Defaults to None.
log_level : Union[int, str]
The log level of the handler. Defaults to 'INFO'.
If log level is 'DEBUG', distributed logs will be saved during distributed training.
@@ -287,20 +290,25 @@ def print_log(msg, logger: Optional[Union[Logger, str]] = None, level=logging.IN
"""
Print a log message using the specified logger or a default method.

This function logs a message with a given logger, if provided, or prints it using the standard `print` function. It supports special logger types such as 'silent' and 'current'.
This function logs a message with a given logger, if provided, or prints it using
the standard `print` function. It supports special logger types such as 'silent' and 'current'.

Parameters
----------
msg : str
The message to be logged.
logger : Optional[Union[Logger, str]], optional
The logger to use for logging the message. It can be a `logging.Logger` instance, a string specifying the logger name, 'silent', 'current', or None. If None, the `print` method is used.
The logger to use for logging the message. It can be a `logging.Logger` instance, a string
specifying the logger name, 'silent', 'current', or None. If None, the `print`
method is used.
- 'silent': No message will be printed.
- 'current': Use the latest created logger to log the message.
- other str: The instance name of the logger. A `ValueError` is raised if the logger has not been created.
- other str: The instance name of the logger. A `ValueError` is raised if the logger has not
been created.
- None: The `print()` method is used for logging.
level : int, optional
The logging level. This is only applicable when `logger` is a Logger object, 'current', or a named logger instance. The default is `logging.INFO`.
The logging level. This is only applicable when `logger` is a Logger object, 'current',
or a named logger instance. The default is `logging.INFO`.
"""
if logger is None:
print(msg)


+ 12
- 12
abl/utils/manager.py View File

@@ -6,7 +6,7 @@ from collections import OrderedDict
from typing import Type, TypeVar

_lock = threading.RLock()
T = TypeVar('T')
T = TypeVar("T")


def _accquire_lock() -> None:
@@ -47,7 +47,7 @@ class ManagerMeta(type):
cls._instance_dict = OrderedDict()
params = inspect.getfullargspec(cls)
params_names = params[0] if params[0] else []
assert 'name' in params_names, f'{cls} must have the `name` argument'
assert "name" in params_names, f"{cls} must have the `name` argument"
super().__init__(*args)


@@ -72,9 +72,8 @@ class ManagerMixin(metaclass=ManagerMeta):
name (str): Name of the instance. Defaults to ''.
"""

def __init__(self, name: str = '', **kwargs):
assert isinstance(name, str) and name, \
'name argument must be an non-empty string.'
def __init__(self, name: str = "", **kwargs):
assert isinstance(name, str) and name, "name argument must be an non-empty string."
self._instance_name = name

@classmethod
@@ -102,8 +101,7 @@ class ManagerMixin(metaclass=ManagerMeta):
instance.
"""
_accquire_lock()
assert isinstance(name, str), \
f'type of name should be str, but got {type(cls)}'
assert isinstance(name, str), f"type of name should be str, but got {type(cls)}"
instance_dict = cls._instance_dict # type: ignore
# Get the instance by name.
if name not in instance_dict:
@@ -111,9 +109,10 @@ class ManagerMixin(metaclass=ManagerMeta):
instance_dict[name] = instance # type: ignore
elif kwargs:
warnings.warn(
f'{cls} instance named of {name} has been created, '
'the method `get_instance` should not accept any other '
'arguments')
f"{cls} instance named of {name} has been created, "
"the method `get_instance` should not accept any other "
"arguments"
)
# Get latest instantiated instance or root instance.
_release_lock()
return instance_dict[name]
@@ -141,8 +140,9 @@ class ManagerMixin(metaclass=ManagerMeta):
_accquire_lock()
if not cls._instance_dict:
raise RuntimeError(
f'Before calling {cls.__name__}.get_current_instance(), you '
'should call get_instance(name=xxx) at least once.')
f"Before calling {cls.__name__}.get_current_instance(), you "
"should call get_instance(name=xxx) at least once."
)
name = next(iter(reversed(cls._instance_dict)))
_release_lock()
return cls._instance_dict[name]


+ 0
- 57
abl/utils/utils.py View File

@@ -221,60 +221,3 @@ def calculate_revision_num(parameter, total_length):
if parameter < 0:
raise ValueError("If parameter is an int, it must be non-negative.")
return parameter


if __name__ == "__main__":
A = np.array(
[
[
0.18401675,
0.06797526,
0.06797541,
0.06801736,
0.06797528,
0.06797526,
0.06818808,
0.06797527,
0.06800033,
0.06797526,
0.06797526,
0.06797526,
0.06797526,
],
[
0.07223161,
0.0685229,
0.06852708,
0.17227574,
0.06852163,
0.07018146,
0.06860291,
0.06852849,
0.06852163,
0.0685216,
0.0685216,
0.06852174,
0.0685216,
],
[
0.06794382,
0.0679436,
0.06794395,
0.06794346,
0.06794346,
0.18467231,
0.06794345,
0.06794871,
0.06794345,
0.06794345,
0.06794345,
0.06794345,
0.06794345,
],
],
dtype=np.float32,
)
B = [[0, 9, 3], [0, 11, 4]]

print(ori_confidence_dist(A, B))
print(confidence_dist(A, B))

+ 3
- 0
build_tools/requirements.txt View File

@@ -0,0 +1,3 @@
-r ../requirements.txt
pytest
pytest-cov

+ 1
- 6
docs/Examples/MNISTAdd.rst View File

@@ -3,7 +3,7 @@ MNIST Addition

MNIST Addition was first introduced in [1] and the inputs of this task are pairs of MNIST images and the outputs are their sums. The dataset looks like this:

.. image:: ../img/image_1.jpg
.. image:: ../img/Datasets_1.png
:width: 350px
:align: center

@@ -11,9 +11,4 @@ MNIST Addition was first introduced in [1] and the inputs of this task are pairs

The ``gt_pseudo_label`` is only used to test the performance of the machine learning model and is not used in the training phase.

In the Abductive Learning framework, the inference process is as follows:

.. image:: ../img/image_2.jpg
:width: 700px

[1] Robin Manhaeve, Sebastijan Dumancic, Angelika Kimmig, Thomas Demeester, and Luc De Raedt. Deepproblog: Neural probabilistic logic programming. In Advances in Neural Information Processing Systems 31 (NeurIPS), pages 3749-3759.2018.

+ 5
- 7
docs/conf.py View File

@@ -1,14 +1,12 @@
import sys
import os
import re
import sys

if not "READTHEDOCS" in os.environ:

if "READTHEDOCS" not in os.environ:
sys.path.insert(0, os.path.abspath(".."))
sys.path.append(os.path.abspath("./ABL/"))

# from sphinx.locale import _
from sphinx_rtd_theme import __version__


project = "ABL"
slug = re.sub(r"\W+", "-", project.lower())
@@ -48,8 +46,8 @@ pygments_style = "default"

html_theme = "sphinx_rtd_theme"
html_theme_options = {"display_version": True}
html_static_path = ['_static']
html_css_files = ['custom.css']
html_static_path = ["_static"]
html_css_files = ["custom.css"]
# html_theme_path = ["../.."]
# html_logo = "demo/static/logo-wordmark-light.svg"
# html_show_sourcelink = True


+ 3
- 3
examples/hed/datasets/get_hed.py View File

@@ -1,11 +1,11 @@
import os
import os.path as osp
import cv2
import pickle
import numpy as np
import random

from collections import defaultdict

import cv2
import numpy as np
from torchvision.transforms import transforms

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))


+ 0
- 388
examples/hed/framework_hed.py View File

@@ -1,388 +0,0 @@
# coding: utf-8
# ================================================================#
# Copyright (C) 2021 Freecss All rights reserved.
#
# File Name :framework.py
# Author :freecss
# Email :karlfreecss@gmail.com
# Created Date :2021/06/07
# Description :
#
# ================================================================#

import torch
import torch.nn as nn
import numpy as np
import os

from abl.utils.plog import INFO
from abl.utils.utils import flatten, reform_idx
from abl.learning.basic_nn import BasicNN, BasicDataset

from utils import gen_mappings, mapping_res, remapping_res
from models.nn import SymbolNetAutoencoder
from torch.utils.data import RandomSampler
from datasets.get_hed import get_pretrain_data


def hed_pretrain(kb, cls, recorder):
cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if not os.path.exists("./weights/pretrain_weights.pth"):
INFO("Pretrain Start")
pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"])
pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y)
pretrain_data_loader = torch.utils.data.DataLoader(
pretrain_data, batch_size=64, shuffle=True
)

criterion = nn.MSELoss()
optimizer = torch.optim.RMSprop(
cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6
)

pretrain_model = BasicNN(
cls_autoencoder,
criterion,
optimizer,
device,
save_interval=1,
save_dir=recorder.save_dir,
num_epochs=10,
recorder=recorder,
)
pretrain_model.fit(pretrain_data_loader)
torch.save(
cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth"
)
cls.load_state_dict(cls_autoencoder.base_model.state_dict())

else:
cls.load_state_dict(torch.load("./weights/pretrain_weights.pth"))


def _get_char_acc(model, X, consistent_pred_res, mapping):
original_pred_res = model.predict(X)["label"]
pred_res = flatten(mapping_res(original_pred_res, mapping))
INFO("Current model's output: ", pred_res)
INFO("Abduced labels: ", flatten(consistent_pred_res))
assert len(pred_res) == len(flatten(consistent_pred_res))
return sum(
[
pred_res[idx] == flatten(consistent_pred_res)[idx]
for idx in range(len(pred_res))
]
) / len(pred_res)


def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
select_idx = RandomSampler(train_X_true, num_samples=select_num,replacement=False)
X = [train_X_true[idx] for idx in select_idx]

# original_pred_res = model.predict(X)['label']
pred_label = model.predict(X)["label"]

if mapping == None:
mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1])
else:
mappings = [mapping]

consistent_idx = []
consistent_pred_res = []

for m in mappings:
pred_pseudo_label = mapping_res(pred_label, m)
max_revision_num = 20
solution = abducer.zoopt_get_solution(
pred_label,
pred_pseudo_label,
[None] * len(pred_label),
[None] * len(pred_label),
max_revision_num,
)
all_address_flag = reform_idx(solution, pred_label)

consistent_idx_tmp = []
consistent_pred_res_tmp = []

for idx in range(len(pred_label)):
address_idx = [
i for i, flag in enumerate(all_address_flag[idx]) if flag != 0
]
candidate = abducer.revise_by_idx([pred_pseudo_label[idx]], None, address_idx)
if len(candidate) > 0:
consistent_idx_tmp.append(idx)
consistent_pred_res_tmp.append(candidate[0][0])

if len(consistent_idx_tmp) > len(consistent_idx):
consistent_idx = consistent_idx_tmp
consistent_pred_res = consistent_pred_res_tmp
if len(mappings) > 1:
mapping = m

if len(consistent_idx) == 0:
return 0, 0, None

INFO("Train pool size is:", len(flatten(consistent_pred_res)))
INFO("Start to use abduced pseudo label to train model...")
model.train(
[X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping)
)

consistent_acc = len(consistent_idx) / select_num
char_acc = _get_char_acc(
model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping
)
INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc))
return consistent_acc, char_acc, mapping


# def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
# select_idx = np.random.randint(len(train_X_true), size=select_num)
# X = []
# for idx in select_idx:
# X.append(train_X_true[idx])

# original_pred_res = model.predict(X)['label']

# if mapping == None:
# mappings = gen_mappings([0, 1, 2, 3],['+', '=', 0, 1])
# else:
# mappings = [mapping]

# consistent_idx = []
# consistent_pred_res = []

# for m in mappings:
# pred_res = mapping_res(original_pred_res, m)
# max_abduce_num = 20
# solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), [None] * len(pred_res), max_abduce_num)
# all_address_flag = reform_idx(solution, pred_res)

# consistent_idx_tmp = []
# consistent_pred_res_tmp = []

# for idx in range(len(pred_res)):
# address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
# candidate = abducer.revise_by_idx([pred_res[idx]], None, address_idx)
# if len(candidate) > 0:
# consistent_idx_tmp.append(idx)
# consistent_pred_res_tmp.append(candidate[0][0])

# if len(consistent_idx_tmp) > len(consistent_idx):
# consistent_idx = consistent_idx_tmp
# consistent_pred_res = consistent_pred_res_tmp
# if len(mappings) > 1:
# mapping = m

# if len(consistent_idx) == 0:
# return 0, 0, None

# INFO('Train pool size is:', len(flatten(consistent_pred_res)))
# INFO("Start to use abduced pseudo label to train model...")
# model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping))

# consistent_acc = len(consistent_idx) / select_num
# char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping)
# INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc))
# return consistent_acc, char_acc, mapping


def _remove_duplicate_rule(rule_dict):
add_nums_dict = {}
for r in list(rule_dict):
add_nums = str(r.split("]")[0].split("[")[1]) + str(
r.split("]")[1].split("[")[1]
) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10'
if add_nums in add_nums_dict:
old_r = add_nums_dict[add_nums]
if rule_dict[r] >= rule_dict[old_r]:
rule_dict.pop(old_r)
add_nums_dict[add_nums] = r
else:
rule_dict.pop(r)
else:
add_nums_dict[add_nums] = r
return list(rule_dict)


def get_rules_from_data(
model, abducer, mapping, train_X_true, samples_per_rule, samples_num
):
rules = []
for _ in range(samples_num):
while True:
select_idx = np.random.randint(len(train_X_true), size=samples_per_rule)
X = []
for idx in select_idx:
X.append(train_X_true[idx])
original_pred_res = model.predict(X)["label"]
pred_res = mapping_res(original_pred_res, mapping)

consistent_idx = []
consistent_pred_res = []
for idx in range(len(pred_res)):
if abducer.kb.logic_forward([pred_res[idx]]):
consistent_idx.append(idx)
consistent_pred_res.append(pred_res[idx])

if len(consistent_pred_res) != 0:
rule = abducer.abduce_rules(consistent_pred_res)
if rule != None:
break
rules.append(rule)

all_rule_dict = {}
for rule in rules:
for r in rule:
all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1
rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5}
rules = _remove_duplicate_rule(rule_dict)

return rules


def _get_consist_rule_acc(model, abducer, mapping, rules, X):
cnt = 0
for x in X:
original_pred_res = model.predict([x])["label"]
pred_res = flatten(mapping_res(original_pred_res, mapping))
if abducer.kb.consist_rule(pred_res, rules):
cnt += 1
return cnt / len(X)


def train_with_rule(
model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8
):
train_X = train_data
val_X = val_data

samples_num = 50
samples_per_rule = 3

# Start training / for each length of equations
for equation_len in range(min_len, max_len):
INFO(
"============== equation_len: %d-%d ================"
% (equation_len, equation_len + 1)
)
train_X_true = train_X[1][equation_len]
train_X_false = train_X[0][equation_len]
val_X_true = val_X[1][equation_len]
val_X_false = val_X[0][equation_len]

train_X_true.extend(train_X[1][equation_len + 1])
train_X_false.extend(train_X[0][equation_len + 1])
val_X_true.extend(val_X[1][equation_len + 1])
val_X_false.extend(val_X[0][equation_len + 1])

condition_cnt = 0
while True:
if equation_len == min_len:
mapping = None

# Abduce and train NN
consistent_acc, char_acc, mapping = abduce_and_train(
model, abducer, mapping, train_X_true, select_num
)
if consistent_acc == 0:
continue

# Test if we can use mlp to evaluate
if consistent_acc >= 0.9 and char_acc >= 0.9:
condition_cnt += 1
else:
condition_cnt = 0

# The condition has been satisfied continuously five times
if condition_cnt >= 5:
INFO("Now checking if we can go to next course")
rules = get_rules_from_data(
model, abducer, mapping, train_X_true, samples_per_rule, samples_num
)
INFO("Learned rules from data:", rules)

true_consist_rule_acc = _get_consist_rule_acc(
model, abducer, mapping, rules, val_X_true
)
false_consist_rule_acc = _get_consist_rule_acc(
model, abducer, mapping, rules, val_X_false
)

INFO(
"consist_rule_acc is %f, %f\n"
% (true_consist_rule_acc, false_consist_rule_acc)
)
# decide next course or restart
if true_consist_rule_acc > 0.95 and false_consist_rule_acc < 0.1:
torch.save(
model.classifier_list[0].model.state_dict(),
"./weights/weights_%d.pth" % equation_len,
)
break
else:
if equation_len == min_len:
INFO("Final mapping is: ", mapping)
model.classifier_list[0].model.load_state_dict(
torch.load("./weights/pretrain_weights.pth")
)
else:
model.classifier_list[0].model.load_state_dict(
torch.load("./weights/weights_%d.pth" % (equation_len - 1))
)
condition_cnt = 0
INFO("Reload Model and retrain")

return model, mapping


def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8):
train_X = train_data
test_X = test_data

# Calcualte how many equations should be selected in each length
# for each length, there are equation_samples_num[equation_len] rules
print("Now begin to train final mlp model")
equation_samples_num = []
len_cnt = max_len - min_len + 1
samples_num = 50
equation_samples_num += [0] * min_len
if samples_num % len_cnt == 0:
equation_samples_num += [samples_num // len_cnt] * len_cnt
else:
equation_samples_num += [samples_num // len_cnt] * len_cnt
equation_samples_num[-1] += samples_num % len_cnt
assert sum(equation_samples_num) == samples_num

# Abduce rules
rules = []
samples_per_rule = 3
for equation_len in range(min_len, max_len + 1):
equation_rules = get_rules_from_data(
model,
abducer,
mapping,
train_X[1][equation_len],
samples_per_rule,
equation_samples_num[equation_len],
)
rules.extend(equation_rules)
rules = list(set(rules))
INFO("Learned rules from data:", rules)

for equation_len in range(5, 27):
true_consist_rule_acc = _get_consist_rule_acc(
model, abducer, mapping, rules, test_X[1][equation_len]
)
false_consist_rule_acc = _get_consist_rule_acc(
model, abducer, mapping, rules, test_X[0][equation_len]
)
INFO(
"consist_rule_acc of testing length %d equations are %f, %f"
% (equation_len, true_consist_rule_acc, false_consist_rule_acc)
)


if __name__ == "__main__":
pass

+ 15
- 14
examples/hed/hed_bridge.py View File

@@ -1,20 +1,18 @@
import os
from collections import defaultdict
from typing import Any, List
import torch
from torch.utils.data import DataLoader

from abl.reasoning import ReasonerBase
from abl.learning import ABLModel, BasicNN
from abl.bridge import SimpleBridge
from abl.dataset import RegressionDataset
from abl.evaluation import BaseMetric
from abl.dataset import BridgeDataset, RegressionDataset
from abl.learning import ABLModel, BasicNN
from abl.reasoning import ReasonerBase
from abl.structures import ListData
from abl.utils import print_log

from examples.hed.utils import gen_mappings, InfiniteSampler
from examples.models.nn import SymbolNetAutoencoder
from examples.hed.datasets.get_hed import get_pretrain_data
from examples.hed.utils import InfiniteSampler, gen_mappings
from examples.models.nn import SymbolNetAutoencoder


class HEDBridge(SimpleBridge):
@@ -95,7 +93,8 @@ class HEDBridge(SimpleBridge):
character_accuracy = self.model.valid(filtered_data_samples)
revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X)
print_log(
f"Revisible ratio is {revisible_ratio:.3f}, Character accuracy is {character_accuracy:.3f}",
f"Revisible ratio is {revisible_ratio:.3f}, Character \
accuracy is {character_accuracy:.3f}",
logger="current",
)

@@ -111,7 +110,8 @@ class HEDBridge(SimpleBridge):
false_ratio = self.calc_consistent_ratio(val_X_false, rule)

print_log(
f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio is {1 - false_ratio:.3f}",
f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio \
is {1 - false_ratio:.3f}",
logger="current",
)

@@ -143,7 +143,7 @@ class HEDBridge(SimpleBridge):

if len(consistent_instance) != 0:
rule = self.reasoner.abduce_rules(consistent_instance)
if rule != None:
if rule is not None:
rules.append(rule)
break

@@ -214,7 +214,8 @@ class HEDBridge(SimpleBridge):
loss = self.model.train(filtered_sub_data_samples)

print_log(
f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] model loss is {loss:.5f}",
f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] \
model loss is {loss:.5f}",
logger="current",
)

@@ -224,11 +225,11 @@ class HEDBridge(SimpleBridge):
condition_num = 0

if condition_num >= 5:
print_log(f"Now checking if we can go to next course", logger="current")
print_log("Now checking if we can go to next course", logger="current")
rules = self.get_rules_from_data(
data_samples, samples_per_rule=3, samples_num=50
)
print_log(f"Learned rules from data: " + str(rules), logger="current")
print_log("Learned rules from data: " + str(rules), logger="current")

seems_good = self.check_rule_quality(rules, val_data, equation_len)
if seems_good:


+ 8
- 6
examples/hed/hed_example.ipynb View File

@@ -66,7 +66,8 @@
" prolog_rules = prolog_result[0][\"X\"]\n",
" rules = [rule.value for rule in prolog_rules]\n",
" return rules\n",
" \n",
"\n",
"\n",
"class HedReasoner(ReasonerBase):\n",
" def revise_at_idx(self, data_sample):\n",
" revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n",
@@ -76,7 +77,9 @@
" return candidate\n",
"\n",
" def zoopt_revision_score(self, symbol_num, data_sample, sol):\n",
" revision_flag = reform_list(list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label)\n",
" revision_flag = reform_list(\n",
" list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label\n",
" )\n",
" data_sample.revision_flag = revision_flag\n",
"\n",
" lefted_idxs = [i for i in range(len(data_sample.pred_idx))]\n",
@@ -108,7 +111,7 @@
" for i in range(0, len(candidate_size)):\n",
" score -= math.exp(-i) * candidate_size[i]\n",
" return score\n",
" \n",
"\n",
" def abduce(self, data_sample):\n",
" symbol_num = data_sample.elements_num(\"pred_pseudo_label\")\n",
" max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n",
@@ -134,9 +137,8 @@
" def abduce_rules(self, pred_res):\n",
" return self.kb.abduce_rules(pred_res)\n",
"\n",
"kb = HedKB(\n",
" pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\"\n",
")\n",
"\n",
"kb = HedKB(pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\")\n",
"reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=20)"
]
},


+ 0
- 69
examples/hed/hed_knn_example.py View File

@@ -1,69 +0,0 @@
# coding: utf-8
# ================================================================#
# Copyright (C) 2021 Freecss All rights reserved.
#
# File Name :share_example.py
# Author :freecss
# Email :karlfreecss@gmail.com
# Created Date :2021/06/07
# Description :
#
# ================================================================#

import sys
sys.path.append("../")

from abl.utils.plog import logger, INFO
from abl.utils.utils import reduce_dimension
import torch.nn as nn
import torch

from abl.models.nn import LeNet5, SymbolNet
from abl.models.basic_model import BasicModel, BasicDataset
from abl.models.wabl_models import DecisionTree, WABLBasicModel
from sklearn.neighbors import KNeighborsClassifier

from abl.abducer.abducer_base import AbducerBase
from abl.abducer.kb import add_KB, HWF_KB, prolog_KB
from datasets.mnist_add.get_mnist_add import get_mnist_add
from datasets.hwf.get_hwf import get_hwf
from datasets.hed.get_hed import get_hed, split_equation
from abl import framework_hed_knn


def run_test():

# kb = add_KB(True)
# kb = HWF_KB(True)
# abducer = AbducerBase(kb)

kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True)

recorder = logger()

total_train_data = get_hed(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_hed(train=False)

# ========================= KNN model ============================ #
reduce_dimension(train_data)
reduce_dimension(val_data)
reduce_dimension(test_data)
base_model = KNeighborsClassifier(n_neighbors=3)
pretrain_data_X, pretrain_data_Y = framework_hed_knn.hed_pretrain(base_model)
model = WABLBasicModel(base_model, kb.pseudo_label_list)
model, mapping = framework_hed_knn.train_with_rule(
model, abducer, train_data, val_data, (pretrain_data_X, pretrain_data_Y), select_num=10, min_len=5, max_len=8
)
framework_hed_knn.hed_test(
model, abducer, mapping, train_data, test_data, min_len=5, max_len=8
)
# ============================ End =============================== #

recorder.dump()
return True


if __name__ == "__main__":
run_test()

+ 0
- 159
examples/hed/hed_tmp.py View File

@@ -1,159 +0,0 @@
import os.path as osp

import numpy as np
import torch
import torch.nn as nn

from abl.evaluation import SemanticsMetric, SymbolMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import PrologKB, ReasonerBase
from abl.utils import ABLLogger, print_log, reform_list
from examples.hed.datasets.get_hed import get_hed, split_equation
from examples.hed.hed_bridge import HEDBridge
from examples.models.nn import SymbolNet

# Build logger
print_log("Abductive Learning on the HED example.", logger="current")

# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")


### Logic Part
# Initialize knowledge base and abducer
class HedKB(PrologKB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)

def consist_rule(self, exs, rules):
rules = str(rules).replace("'", "")
return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0

def abduce_rules(self, pred_res):
prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))
if len(prolog_result) == 0:
return None
prolog_rules = prolog_result[0]["X"]
rules = [rule.value for rule in prolog_rules]
return rules


class HedReasoner(ReasonerBase):
def revise_at_idx(self, data_sample):
revision_idx = np.where(np.array(data_sample.flatten("revision_flag")) != 0)[0]
candidate = self.kb.revise_at_idx(
data_sample.pred_pseudo_label, data_sample.Y, revision_idx
)
return candidate

def zoopt_revision_score(self, symbol_num, data_sample, sol):
revision_flag = reform_list(list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label)
data_sample.revision_flag = revision_flag

lefted_idxs = [i for i in range(len(data_sample.pred_idx))]
candidate_size = []
while lefted_idxs:
idxs = []
idxs.append(lefted_idxs.pop(0))
max_candidate_idxs = []
found = False
for idx in range(-1, len(data_sample.pred_idx)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidate = self.revise_at_idx(data_sample[idxs])
if len(candidate) == 0:
if len(idxs) > 1:
idxs.pop()
else:
if len(idxs) > len(max_candidate_idxs):
found = True
max_candidate_idxs = idxs.copy()
removed = [i for i in lefted_idxs if i in max_candidate_idxs]
if found:
candidate_size.append(len(removed) + 1)
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
candidate_size.sort()
score = 0
import math

for i in range(0, len(candidate_size)):
score -= math.exp(-i) * candidate_size[i]
return score

def abduce(self, data_sample):
symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)

solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num)

data_sample.revision_flag = reform_list(
solution.astype(np.int32), data_sample.pred_pseudo_label
)

abduced_pseudo_label = []

for single_instance in data_sample:
single_instance.pred_pseudo_label = [single_instance.pred_pseudo_label]
candidates = self.revise_at_idx(single_instance)
if len(candidates) == 0:
abduced_pseudo_label.append([])
else:
abduced_pseudo_label.append(candidates[0][0])
data_sample.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label

def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)


import os

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

kb = HedKB(
pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "./datasets/learn_add.pl")
)
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=20)

### Machine Learning Part
# Build necessary components for BasicNN
cls = SymbolNet(num_classes=4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Build BasicNN
# The function of BasicNN is to wrap NN models into the form of an sklearn estimator
base_model = BasicNN(
cls,
criterion,
optimizer,
device,
batch_size=32,
num_epochs=1,
save_interval=1,
save_dir=weights_dir,
)

# Build ABLModel
# The main function of the ABL model is to serialize data and
# provide a unified interface for different machine learning models
model = ABLModel(base_model)

### Metric
# Set up metrics
metric_list = [SymbolMetric(prefix="hed"), SemanticsMetric(prefix="hed")]

### Bridge Machine Learning and Logic Reasoning
bridge = HEDBridge(model, reasoner, metric_list)

### Dataset
total_train_data = get_hed(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_hed(train=False)

### Train and Test
bridge.pretrain("examples/hed/weights")
bridge.train(train_data, val_data)

+ 4
- 5
examples/hed/utils.py View File

@@ -1,6 +1,6 @@
import numpy as np
import torch
import torch.nn as nn
import numpy as np
import torch.utils.data.sampler as sampler


@@ -13,7 +13,7 @@ class InfiniteSampler(sampler.Sampler):
while True:
order = np.random.permutation(self.num_samples)
for i in range(self.num_samples):
yield order[i: i + self.batch_size]
yield order[i : i + self.batch_size]
i += self.batch_size

def __len__(self):
@@ -58,7 +58,6 @@ def reduce_dimension(data):
for equation_len in range(5, 27):
equations = data[truth_value][equation_len]
reduced_equations = [
[extract_feature(symbol_img) for symbol_img in equation]
for equation in equations
[extract_feature(symbol_img) for symbol_img in equation] for equation in equations
]
data[truth_value][equation_len] = reduced_equations
data[truth_value][equation_len] = reduced_equations

+ 2
- 4
examples/hwf/datasets/get_hwf.py View File

@@ -1,14 +1,12 @@
import os
import json
import os

from PIL import Image
from torchvision.transforms import transforms

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

img_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (1,))]
)
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])


def get_data(file, get_pseudo_label):


+ 12
- 8
examples/hwf/hwf_example.ipynb View File

@@ -51,14 +51,13 @@
"source": [
"# Initialize knowledge base and reasoner\n",
"class HWF_KB(KBBase):\n",
"\n",
" def _valid_candidate(self, formula):\n",
" if len(formula) % 2 == 0:\n",
" return False\n",
" for i in range(len(formula)):\n",
" if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:\n",
" if i % 2 == 0 and formula[i] not in [\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\"]:\n",
" return False\n",
" if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:\n",
" if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"times\", \"div\"]:\n",
" return False\n",
" return True\n",
"\n",
@@ -66,12 +65,17 @@
" if not self._valid_candidate(formula):\n",
" return np.inf\n",
" mapping = {str(i): str(i) for i in range(1, 10)}\n",
" mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})\n",
" mapping.update({\"+\": \"+\", \"-\": \"-\", \"times\": \"*\", \"div\": \"/\"})\n",
" formula = [mapping[f] for f in formula]\n",
" return eval(''.join(formula))\n",
" return eval(\"\".join(formula))\n",
"\n",
"\n",
"kb = HWF_KB(pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], max_err=1e-10, use_cache=False)\n",
"reasoner = ReasonerBase(kb, dist_func='confidence')"
"kb = HWF_KB(\n",
" pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"times\", \"div\"],\n",
" max_err=1e-10,\n",
" use_cache=False,\n",
")\n",
"reasoner = ReasonerBase(kb, dist_func=\"confidence\")"
]
},
{
@@ -122,7 +126,7 @@
"outputs": [],
"source": [
"# Initialize ABL model\n",
"# The main function of the ABL model is to serialize data and \n",
"# The main function of the ABL model is to serialize data and\n",
"# provide a unified interface for different machine learning models\n",
"model = ABLModel(base_model)"
]


+ 1
- 1
examples/mnist_add/mnist_add_example.ipynb View File

@@ -80,7 +80,7 @@
"outputs": [],
"source": [
"# Build ABLModel\n",
"# The main function of the ABL model is to serialize data and \n",
"# The main function of the ABL model is to serialize data and\n",
"# provide a unified interface for different machine learning models\n",
"model = ABLModel(base_model)"
]


+ 2
- 4
examples/models/nn.py View File

@@ -11,8 +11,8 @@
# ================================================================#


import torch
import numpy as np
import torch
from torch import nn


@@ -84,9 +84,7 @@ class SymbolNetAutoencoder(nn.Module):
self.base_model = SymbolNet(num_classes, image_size)
self.softmax = nn.Softmax(dim=1)
self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU())
self.fc2 = nn.Sequential(
nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()
)
self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU())

def forward(self, x):
x = self.base_model(x)


+ 9
- 3
setup.py View File

@@ -1,4 +1,5 @@
import os

from setuptools import find_packages, setup


@@ -27,7 +28,13 @@ here = os.path.abspath(os.path.dirname(__file__))
try:
with open(os.path.join(here, "requirements.txt"), encoding="utf-8") as f:
REQUIRED = f.read().split("\n")
except:
except FileNotFoundError:
# Handle the case where the file does not exist
print("requirements.txt file not found.")
REQUIRED = []
except Exception as e:
# Handle other possible exceptions
print(f"An error occurred: {e}")
REQUIRED = []

EXTRAS = {
@@ -64,7 +71,7 @@ if __name__ == "__main__":
install_requires=REQUIRED,
extras_require=EXTRAS,
classifiers=[
'Development Status :: 3 - Alpha',
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"Programming Language :: Python",
@@ -74,4 +81,3 @@ if __name__ == "__main__":
"Programming Language :: Python :: 3.8",
],
)

+ 65
- 22
tests/conftest.py View File

@@ -4,10 +4,11 @@ import torch.nn as nn
import torch.optim as optim

from abl.learning import BasicNN
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner
from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner
from abl.structures import ListData
from examples.models.nn import LeNet5


# Fixture for BasicNN instance
@pytest.fixture
def basic_nn_instance():
@@ -16,6 +17,7 @@ def basic_nn_instance():
optimizer = optim.Adam(model.parameters())
return BasicNN(model, criterion, optimizer)


# Fixture for base_model instance
@pytest.fixture
def base_model_instance():
@@ -24,6 +26,7 @@ def base_model_instance():
optimizer = optim.Adam(model.parameters())
return BasicNN(model, criterion, optimizer)


# Fixture for ListData instance
@pytest.fixture
def list_data_instance():
@@ -37,47 +40,71 @@ def list_data_instance():
@pytest.fixture
def data_samples_add():
# favor 1 in first one
prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
prob1 = [
[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
# favor 7 in first one
prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
prob2 = [
[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]

data_samples_add = ListData()
data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]]
data_samples_add.pred_prob = [prob1, prob2, prob1, prob2]
data_samples_add.Y = [8, 8, 17, 10]
return data_samples_add


@pytest.fixture
def data_samples_hwf():
data_samples_hwf = ListData()
data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]]
data_samples_hwf.pred_pseudo_label = [
["5", "+", "2"],
["5", "+", "9"],
["5", "+", "9"],
["5", "-", "8", "8", "8"],
]
data_samples_hwf.pred_prob = [None, None, None, None]
data_samples_hwf.Y = [3, 64, 65, 3.17]
return data_samples_hwf


class AddKB(KBBase):
def __init__(self, pseudo_label_list=list(range(10)),
use_cache=False):
def __init__(self, pseudo_label_list=list(range(10)), use_cache=False):
super().__init__(pseudo_label_list, use_cache=use_cache)

def logic_forward(self, nums):
return sum(nums)


class AddGroundKB(GroundKB):
def __init__(self, pseudo_label_list=list(range(10)),
GKB_len_list=[2]):
def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list)
def logic_forward(self, nums):
return sum(nums)


class HwfKB(KBBase):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
"+", "-", "times", "div"],
pseudo_label_list=[
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"+",
"-",
"times",
"div",
],
max_err=1e-3,
use_cache=False,
):
@@ -87,7 +114,17 @@ class HwfKB(KBBase):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
if i % 2 == 0 and formula[i] not in [
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
@@ -100,7 +137,8 @@ class HwfKB(KBBase):
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))


class HedKB(PrologKB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)
@@ -110,24 +148,28 @@ class HedKB(PrologKB):
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules)
return len(list(self.prolog.query(pl_query))) != 0


@pytest.fixture
def kb_add():
return AddKB()


@pytest.fixture
def kb_add_cache():
return AddKB(use_cache=True)
return AddKB(use_cache=True)


@pytest.fixture
def kb_add_ground():
return AddGroundKB()


@pytest.fixture
def kb_add_prolog():
kb = PrologKB(pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl")
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl")
return kb


@pytest.fixture
def kb_hed():
kb = HedKB(
@@ -136,6 +178,7 @@ def kb_hed():
)
return kb


@pytest.fixture
def reasoner_instance(kb_add):
return Reasoner(kb_add, "confidence")
return Reasoner(kb_add, "confidence")

+ 2
- 1
tests/test_abl_model.py View File

@@ -1,8 +1,9 @@
from unittest.mock import Mock, create_autospec

import numpy as np
import pytest

from abl.learning import ABLModel
from unittest.mock import Mock, create_autospec


class TestABLModel(object):


+ 106
- 42
tests/test_reasoning.py View File

@@ -1,63 +1,66 @@
import pytest
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner

from abl.reasoning import PrologKB, Reasoner


class TestKBBase(object):
def test_init(self, kb_add):
assert kb_add.pseudo_label_list == list(range(10))
def test_init_cache(self, kb_add_cache):
assert kb_add_cache.pseudo_label_list == list(range(10))
assert kb_add_cache.use_cache == True
assert kb_add_cache.use_cache is True
def test_logic_forward(self, kb_add):
result = kb_add.logic_forward([1, 2])
assert result == 3
def test_revise_at_idx(self, kb_add):
result = kb_add.revise_at_idx([1, 2], 2, [0])
assert result == [[0, 2]]
def test_abduce_candidates(self, kb_add):
result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2,
require_more_revision=0)
result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2, require_more_revision=0)
assert result == [[1, 0]]


class TestGroundKB(object):
def test_init(self, kb_add_ground):
assert kb_add_ground.pseudo_label_list == list(range(10))
assert kb_add_ground.GKB_len_list == [2]
assert kb_add_ground.GKB
def test_logic_forward_ground(self, kb_add_ground):
result = kb_add_ground.logic_forward([1, 2])
assert result == 3
def test_abduce_candidates_ground(self, kb_add_ground):
result = kb_add_ground.abduce_candidates([1, 2], 1, max_revision_num=2,
require_more_revision=0)
result = kb_add_ground.abduce_candidates(
[1, 2], 1, max_revision_num=2, require_more_revision=0
)
assert result == [(1, 0)]
class TestPrologKB(object):


class TestPrologKB(object):
def test_init_pl1(self, kb_add_prolog):
assert kb_add_prolog.pseudo_label_list == list(range(10))
assert kb_add_prolog.pl_file == "examples/mnist_add/datasets/add.pl"
def test_init_pl2(self, kb_hed):
assert kb_hed.pseudo_label_list == [1, 0, "+", "="]
assert kb_hed.pl_file == "examples/hed/datasets/learn_add.pl"
def test_prolog_file_not_exist(self):
pseudo_label_list = [1, 2]
non_existing_file = "path/to/non_existing_file.pl"
with pytest.raises(FileNotFoundError) as excinfo:
PrologKB(pseudo_label_list=pseudo_label_list,
pl_file=non_existing_file)
PrologKB(pseudo_label_list=pseudo_label_list, pl_file=non_existing_file)
assert non_existing_file in str(excinfo.value)
def test_logic_forward_pl1(self, kb_add_prolog):
result = kb_add_prolog.logic_forward([1, 2])
assert result == 3
def test_logic_forward_pl2(self, kb_hed):
consist_exs = [
[1, 1, "+", 0, "=", 1, 1],
@@ -70,21 +73,24 @@ class TestPrologKB(object):
[0, "+", 0, "=", 0],
[0, "+", 0, "=", 1],
]
assert kb_hed.logic_forward(consist_exs) == True
assert kb_hed.logic_forward(inconsist_exs) == False
assert kb_hed.logic_forward(consist_exs) is True
assert kb_hed.logic_forward(inconsist_exs) is False

def test_revise_at_idx(self, kb_add_prolog):
result = kb_add_prolog.revise_at_idx([1, 2], 2, [0])
assert result == [[0, 2]]


class TestReaonser(object):
def test_reasoner_init(self, reasoner_instance):
assert reasoner_instance.dist_func == "confidence"
def test_invalid_dist_funce(kb_add):
with pytest.raises(NotImplementedError) as excinfo:
Reasoner(kb_add, "invalid_dist_func")
assert "Valid options for dist_func include \"hamming\" and \"confidence\"" in str(excinfo.value)
assert 'Valid options for dist_func include "hamming" and "confidence"' in str(
excinfo.value
)


class test_batch_abduce(object):
@@ -95,8 +101,18 @@ class test_batch_abduce(object):
reasoner4 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=1)
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]]
assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]]
assert reasoner3.batch_abduce(data_samples_add) == [
[1, 7],
[7, 1],
[8, 9],
[1, 9],
]
assert reasoner4.batch_abduce(data_samples_add) == [
[1, 7],
[7, 1],
[8, 9],
[7, 3],
]

def test_batch_abduce_ground(self, kb_add_ground, data_samples_add):
reasoner1 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=0)
@@ -105,8 +121,18 @@ class test_batch_abduce(object):
reasoner4 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=1)
assert reasoner1.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)]
assert reasoner2.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)]
assert reasoner3.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (1, 9)]
assert reasoner4.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (7, 3)]
assert reasoner3.batch_abduce(data_samples_add) == [
(1, 7),
(7, 1),
(8, 9),
(1, 9),
]
assert reasoner4.batch_abduce(data_samples_add) == [
(1, 7),
(7, 1),
(8, 9),
(7, 3),
]

def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add):
reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0)
@@ -115,35 +141,73 @@ class test_batch_abduce(object):
reasoner4 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=1)
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]]
assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]]
assert reasoner3.batch_abduce(data_samples_add) == [
[1, 7],
[7, 1],
[8, 9],
[1, 9],
]
assert reasoner4.batch_abduce(data_samples_add) == [
[1, 7],
[7, 1],
[8, 9],
[7, 3],
]

def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add):
reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1)
reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2)
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]]
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_samples_add) == [
[1, 7],
[7, 1],
[8, 9],
[7, 3],
]

def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf):
reasoner1 = Reasoner(kb_hwf1, "hamming", max_revision=3, require_more_revision=0)
reasoner2 = Reasoner(kb_hwf1, "hamming", max_revision=0.5, require_more_revision=0)
reasoner3 = Reasoner(kb_hwf1, "hamming", max_revision=0.9, require_more_revision=0)
res = reasoner1.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']]
assert res == [
["1", "+", "2"],
["8", "times", "8"],
[],
["4", "-", "6", "div", "8"],
]
res = reasoner2.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], [], [], []]
assert res == [["1", "+", "2"], [], [], []]
res = reasoner3.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']]
assert res == [
["1", "+", "2"],
["8", "times", "8"],
[],
["4", "-", "6", "div", "8"],
]

def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf):
reasoner1 = Reasoner(kb_hwf2, "hamming", max_revision=3, require_more_revision=0)
reasoner2 = Reasoner(kb_hwf2, "hamming", max_revision=0.5, require_more_revision=0)
reasoner3 = Reasoner(kb_hwf2, "hamming", max_revision=0.9, require_more_revision=0)
res = reasoner1.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']]
assert res == [
["1", "+", "2"],
["7", "times", "9"],
["8", "times", "8"],
["5", "-", "8", "div", "8"],
]
res = reasoner2.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], ['7', 'times', '9'], [], ['5', '-', '8', 'div', '8']]
assert res == [
["1", "+", "2"],
["7", "times", "9"],
[],
["5", "-", "8", "div", "8"],
]
res = reasoner3.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']]

assert res == [
["1", "+", "2"],
["7", "times", "9"],
["8", "times", "8"],
["5", "-", "8", "div", "8"],
]

Loading…
Cancel
Save