@@ -5,4 +5,11 @@ show_missing = True | |||
disable_warnings = include-ignored | |||
include = */abl/* | |||
omit = | |||
*/abl/__init__.py | |||
*/abl/__init__.py | |||
abl/bridge/__init__.py | |||
abl/dataset/__init__.py | |||
abl/evaluation/__init__.py | |||
abl/learning/__init__.py | |||
abl/reasoning/__init__.py | |||
abl/structures/__init__.py | |||
abl/utils/__init__.py |
@@ -24,7 +24,7 @@ jobs: | |||
- name: Install package dependencies | |||
run: | | |||
python -m pip install --upgrade pip | |||
pip install -r ./requirements.txt | |||
pip install -r build_tools/requirements.txt | |||
- name: Run tests | |||
run: | | |||
pytest --cov-config=.coveragerc --cov-report=xml --cov=abl ./tests | |||
@@ -20,5 +20,5 @@ jobs: | |||
- name: flake8 Lint | |||
uses: py-actions/flake8@v2 | |||
with: | |||
max-line-length: "110" | |||
plugins: "flake8-bugbear flake8-black" | |||
max-line-length: "100" | |||
args: --ignore=E203,W503 |
@@ -1,2 +1,11 @@ | |||
from .learning import abl_model, basic_nn | |||
from .reasoning import reasoner, kb | |||
from . import bridge, dataset, evaluation, learning, reasoning, structures, utils | |||
__all__ = [ | |||
"bridge", | |||
"dataset", | |||
"evaluation", | |||
"learning", | |||
"reasoning", | |||
"structures", | |||
"utils", | |||
] |
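With the package root now re-exporting its subpackages through an explicit `__all__`, a downstream import looks like the following. A minimal sketch, assuming only that the `abl` package shown in this hunk is installed:

```python
# Minimal sketch: relies only on the __init__.py shown above.
import abl
from abl import learning, reasoning

print(abl.__all__)  # the seven subpackage names listed above
```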
@@ -1,3 +1,3 @@ | |||
VERSION = (0, 0, 1) | |||
__version__ = ".".join(map(str, VERSION)) | |||
__version__ = ".".join(map(str, VERSION)) |
@@ -1,2 +1,4 @@ | |||
from .base_bridge import BaseBridge | |||
from .simple_bridge import SimpleBridge | |||
from .simple_bridge import SimpleBridge | |||
__all__ = ["BaseBridge", "SimpleBridge"] |
@@ -12,53 +12,40 @@ class BaseBridge(metaclass=ABCMeta): | |||
def __init__(self, model: ABLModel, reasoner: Reasoner) -> None: | |||
if not isinstance(model, ABLModel): | |||
raise TypeError( | |||
"Expected an instance of ABLModel, but received type: {}".format( | |||
type(model) | |||
) | |||
"Expected an instance of ABLModel, but received type: {}".format(type(model)) | |||
) | |||
if not isinstance(reasoner, Reasoner): | |||
raise TypeError( | |||
"Expected an instance of Reasoner, but received type: {}".format( | |||
type(reasoner) | |||
) | |||
"Expected an instance of Reasoner, but received type: {}".format(type(reasoner)) | |||
) | |||
self.model = model | |||
self.reasoner = reasoner | |||
@abstractmethod | |||
def predict( | |||
self, data_samples: ListData | |||
) -> Tuple[List[List[Any]], List[List[Any]]]: | |||
def predict(self, data_samples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]: | |||
"""Placeholder for predict labels from input.""" | |||
pass | |||
@abstractmethod | |||
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
"""Placeholder for abduce pseudo labels.""" | |||
pass | |||
@abstractmethod | |||
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
"""Placeholder for map label space to symbol space.""" | |||
pass | |||
@abstractmethod | |||
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]: | |||
"""Placeholder for map symbol space to label space.""" | |||
pass | |||
@abstractmethod | |||
def train(self, train_data: Union[ListData, DataSet]): | |||
"""Placeholder for train loop of ABductive Learning.""" | |||
pass | |||
@abstractmethod | |||
def valid(self, valid_data: Union[ListData, DataSet]) -> None: | |||
"""Placeholder for model test.""" | |||
pass | |||
@abstractmethod | |||
def test(self, test_data: Union[ListData, DataSet]) -> None: | |||
"""Placeholder for model validation.""" | |||
pass |
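The abstract interface above pins down what a bridge must implement. Below is a minimal, hypothetical subclass sketch; only the constructor signature and the abstract method names come from this hunk, and the bodies (including the `pred_idx`/`pred_pseudo_label` fields borrowed from the SimpleBridge hunk further down) are illustrative placeholders rather than the library's real logic:

```python
from abl.bridge import BaseBridge

class PassThroughBridge(BaseBridge):
    """Toy bridge whose methods do the minimum the interface requires."""

    def predict(self, data_samples):
        # A real bridge would call the wrapped ABLModel here.
        return data_samples.pred_idx, data_samples.pred_idx

    def abduce_pseudo_label(self, data_samples):
        # A real bridge would delegate to self.reasoner; here we pass through.
        return data_samples.pred_pseudo_label

    def idx_to_pseudo_label(self, data_samples):
        return data_samples.pred_idx

    def pseudo_label_to_idx(self, data_samples):
        return data_samples.pred_pseudo_label

    def train(self, train_data):
        pass

    def valid(self, valid_data):
        pass

    def test(self, test_data):
        pass
```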
@@ -1,5 +1,5 @@ | |||
import os.path as osp | |||
from typing import Any, Dict, List, Optional, Tuple, Union | |||
from typing import Any, List, Optional, Tuple, Union | |||
from numpy import ndarray | |||
@@ -32,8 +32,7 @@ class SimpleBridge(BaseBridge): | |||
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
pred_idx = data_samples.pred_idx | |||
data_samples.pred_pseudo_label = [ | |||
[self.reasoner.mapping[_idx] for _idx in sub_list] | |||
for sub_list in pred_idx | |||
[self.reasoner.mapping[_idx] for _idx in sub_list] for sub_list in pred_idx | |||
] | |||
return data_samples.pred_pseudo_label | |||
@@ -81,7 +80,9 @@ class SimpleBridge(BaseBridge): | |||
loss = self.model.train(sub_data_samples) | |||
print_log( | |||
f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}", | |||
f"loop(train) [{loop + 1}/{loops}] segment(train) \ | |||
[{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] \ | |||
model loss is {loss:.5f}", | |||
logger="current", | |||
) | |||
@@ -2,3 +2,10 @@ from .bridge_dataset import BridgeDataset | |||
from .classification_dataset import ClassificationDataset | |||
from .prediction_dataset import PredictionDataset | |||
from .regression_dataset import RegressionDataset | |||
__all__ = [ | |||
"BridgeDataset", | |||
"ClassificationDataset", | |||
"PredictionDataset", | |||
"RegressionDataset", | |||
] |
@@ -13,11 +13,15 @@ class BridgeDataset(Dataset): | |||
gt_pseudo_label : List[List[Any]], optional | |||
A list of objects representing the ground truth label of each element in ``X``. | |||
Y : List[Any] | |||
A list of objects representing the ground truth of the reasoning result of each instance in ``X``. | |||
A list of objects representing the ground truth of the reasoning result of | |||
each instance in ``X``. | |||
""" | |||
def __init__( | |||
self, X: List[List[Any]], gt_pseudo_label: Optional[List[List[Any]]], Y: List[Any] | |||
self, | |||
X: List[List[Any]], | |||
gt_pseudo_label: Optional[List[List[Any]]], | |||
Y: List[Any], | |||
): | |||
if (not isinstance(X, list)) or (not isinstance(Y, list)): | |||
raise ValueError("X and Y should be of type list.") | |||
@@ -15,16 +15,16 @@ class ClassificationDataset(Dataset): | |||
Y : List[int] | |||
The target data. | |||
transform : Callable[..., Any], optional | |||
A function/transform that takes in an object and returns a transformed version. Defaults to None. | |||
A function/transform that takes in an object and returns a transformed version. | |||
Defaults to None. | |||
""" | |||
def __init__( | |||
self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None | |||
): | |||
def __init__(self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None): | |||
if (not isinstance(X, list)) or (not isinstance(Y, list)): | |||
raise ValueError("X and Y should be of type list.") | |||
if len(X) != len(Y): | |||
raise ValueError("Length of X and Y must be equal.") | |||
self.X = X | |||
self.Y = torch.LongTensor(Y) | |||
self.transform = transform | |||
@@ -13,8 +13,10 @@ class PredictionDataset(Dataset): | |||
X : List[Any] | |||
The input data. | |||
transform : Callable[..., Any], optional | |||
A function/transform that takes in an object and returns a transformed version. Defaults to None. | |||
A function/transform that takes in an object and returns a transformed version. | |||
Defaults to None. | |||
""" | |||
def __init__(self, X: List[Any], transform: Callable[..., Any] = None): | |||
if not isinstance(X, list): | |||
raise ValueError("X should be of type list.") | |||
@@ -1,6 +1,5 @@ | |||
from typing import Any, List, Tuple | |||
import torch | |||
from torch.utils.data import Dataset | |||
@@ -15,12 +14,13 @@ class RegressionDataset(Dataset): | |||
Y : List[Any] | |||
A list of objects representing the output data. | |||
""" | |||
def __init__(self, X: List[Any], Y: List[Any]): | |||
if (not isinstance(X, list)) or (not isinstance(Y, list)): | |||
raise ValueError("X and Y should be of type list.") | |||
if len(X) != len(Y): | |||
raise ValueError("Length of X and Y must be equal.") | |||
self.X = X | |||
self.Y = Y | |||
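The three dataset wrappers above take plain Python lists. A small usage sketch with toy data, assuming only the constructor signatures visible in these hunks:

```python
import torch
from abl.dataset import BridgeDataset, ClassificationDataset, RegressionDataset

# Two instances, each a sequence of three elements, plus their reasoning results.
bridge_data = BridgeDataset(
    X=[[1, 2, 3], [4, 5, 6]],
    gt_pseudo_label=[[1, 2, 3], [4, 5, 6]],
    Y=[6, 15],
)

# ClassificationDataset stores Y as a LongTensor internally.
clf_data = ClassificationDataset(
    X=[torch.randn(3) for _ in range(4)],
    Y=[0, 1, 0, 1],
)

reg_data = RegressionDataset(X=[0.0, 1.0, 2.0], Y=[0.0, 2.0, 4.0])
```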
@@ -1,3 +1,5 @@ | |||
from .base_metric import BaseMetric | |||
from .semantics_metric import SemanticsMetric | |||
from .symbol_metric import SymbolMetric | |||
__all__ = ["BaseMetric", "SemanticsMetric", "SymbolMetric"] |
@@ -20,8 +20,10 @@ class BaseMetric(metaclass=ABCMeta): | |||
will be used instead. Default: None | |||
""" | |||
def __init__(self, | |||
prefix: Optional[str] = None,) -> None: | |||
def __init__( | |||
self, | |||
prefix: Optional[str] = None, | |||
) -> None: | |||
self.results: List[Any] = [] | |||
self.prefix = prefix or self.default_prefix | |||
@@ -65,20 +67,18 @@ class BaseMetric(metaclass=ABCMeta): | |||
""" | |||
if len(self.results) == 0: | |||
print_log( | |||
f'{self.__class__.__name__} got empty `self.results`. Please ' | |||
'ensure that the processed results are properly added into ' | |||
'`self.results` in `process` method.', | |||
logger='current', | |||
level=logging.WARNING) | |||
f"{self.__class__.__name__} got empty `self.results`. Please " | |||
"ensure that the processed results are properly added into " | |||
"`self.results` in `process` method.", | |||
logger="current", | |||
level=logging.WARNING, | |||
) | |||
metrics = self.compute_metrics(self.results) | |||
# Add prefix to metric names | |||
if self.prefix: | |||
metrics = { | |||
'/'.join((self.prefix, k)): v | |||
for k, v in metrics.items() | |||
} | |||
metrics = {"/".join((self.prefix, k)): v for k, v in metrics.items()} | |||
# reset the results list | |||
self.results.clear() | |||
return metrics | |||
return metrics |
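`BaseMetric` above defines a collect-then-reduce contract: `process` appends per-sample results to `self.results`, and `evaluate` calls `compute_metrics` on them and prefixes the metric names. A hedged sketch of a custom metric following that contract (the `process` argument and the `loss` attribute are assumptions made for illustration):

```python
from abl.evaluation import BaseMetric

class MeanLossMetric(BaseMetric):
    default_prefix = "toy"  # fallback prefix used when none is passed

    def process(self, data_samples) -> None:
        # Assumed: each sample exposes a scalar `loss` attribute.
        self.results.extend(float(sample.loss) for sample in data_samples)

    def compute_metrics(self, results: list) -> dict:
        return {"mean_loss": sum(results) / len(results)}

metric = MeanLossMetric(prefix="val")
# After process() has filled metric.results, the evaluate method shown above
# returns {"val/mean_loss": ...} and clears the results list.
```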
@@ -14,15 +14,15 @@ class SymbolMetric(BaseMetric): | |||
if not len(pred_pseudo_label) == len(gt_pseudo_label): | |||
raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal") | |||
for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label): | |||
correct_num = 0 | |||
for pred_symbol, symbol in zip(pred_z, z): | |||
if pred_symbol == symbol: | |||
correct_num += 1 | |||
self.results.append(correct_num / len(z)) | |||
def compute_metrics(self, results: list) -> dict: | |||
metrics = dict() | |||
metrics["character_accuracy"] = sum(results) / len(results) | |||
return metrics | |||
return metrics |
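For a single sample, the score appended in `process` above is just the fraction of positions where the predicted symbol matches the ground truth; in plain Python:

```python
pred_z = [1, 7, 3, 5]  # predicted pseudo labels for one sample
gt_z = [1, 2, 3, 5]    # ground-truth pseudo labels

correct_num = sum(p == g for p, g in zip(pred_z, gt_z))
accuracy = correct_num / len(gt_z)  # 3 / 4 = 0.75
```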
@@ -1,2 +1,4 @@ | |||
from .abl_model import ABLModel | |||
from .basic_nn import BasicNN | |||
from .basic_nn import BasicNN | |||
__all__ = ["ABLModel", "BasicNN"] |
@@ -58,7 +58,8 @@ class ABLModel: | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`. | |||
A batch of data to train on, which typically contains the data, `X`, and the | |||
corresponding labels, `abduced_idx`. | |||
Returns | |||
------- | |||
@@ -68,7 +69,7 @@ class ABLModel: | |||
data_X = data_samples.flatten("X") | |||
data_y = data_samples.flatten("abduced_idx") | |||
return self.base_model.fit(X=data_X, y=data_y) | |||
def valid(self, data_samples: ListData) -> float: | |||
""" | |||
Validate the model on the given data. | |||
@@ -76,7 +77,8 @@ class ABLModel: | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`. | |||
A batch of data to validate on, which typically contains the data, `X`, | |||
and the corresponding labels, `abduced_idx`. | |||
Returns | |||
------- | |||
@@ -94,7 +96,7 @@ class ABLModel: | |||
method = getattr(model, operation) | |||
method(*args, **kwargs) | |||
else: | |||
if not f"{operation}_path" in kwargs.keys(): | |||
if f"{operation}_path" not in kwargs.keys(): | |||
raise ValueError(f"'{operation}_path' should not be None") | |||
else: | |||
try: | |||
@@ -104,9 +106,10 @@ class ABLModel: | |||
elif operation == "load": | |||
with open(kwargs["load_path"], "rb") as file: | |||
self.base_model = pickle.load(file) | |||
except: | |||
except (OSError, pickle.PickleError): | |||
raise NotImplementedError( | |||
f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed." | |||
f"{type(model).__name__} object doesn't have the {operation} method \ | |||
and the default pickle-based {operation} method failed." | |||
) | |||
def save(self, *args, **kwargs) -> None: | |||
@@ -1,5 +1,5 @@ | |||
import os | |||
import logging | |||
import os | |||
from typing import Any, Callable, List, Optional, T, Tuple | |||
import numpy | |||
@@ -23,7 +23,8 @@ class BasicNN: | |||
optimizer : torch.optim.Optimizer | |||
The optimizer used for training. | |||
device : torch.device, optional | |||
The device on which the model will be trained or used for prediction, by default torch.device("cpu"). | |||
The device on which the model will be trained or used for prediction, | |||
by default torch.device("cpu"). | |||
batch_size : int, optional | |||
The batch size used for training, by default 32. | |||
num_epochs : int, optional | |||
@@ -37,9 +38,11 @@ class BasicNN: | |||
save_dir : Optional[str], optional | |||
The directory in which to save the model during training, by default None. | |||
train_transform : Callable[..., Any], optional | |||
A function/transform that takes in an object and returns a transformed version used in the `fit` and `train_epoch` methods, by default None. | |||
A function/transform that takes in an object and returns a transformed version used | |||
in the `fit` and `train_epoch` methods, by default None. | |||
test_transform : Callable[..., Any], optional | |||
A function/transform that takes in an object and returns a transformed version in the `predict`, `predict_proba` and `score` methods, , by default None. | |||
A function/transform that takes in an object and returns a transformed version in the | |||
`predict`, `predict_proba` and `score` methods, by default None. | |||
collate_fn : Callable[[List[T]], Any], optional | |||
The function used to collate data, by default None. | |||
""" | |||
@@ -1,2 +1,4 @@ | |||
from .kb import KBBase, GroundKB, PrologKB | |||
from .reasoner import Reasoner | |||
from .kb import GroundKB, KBBase, PrologKB | |||
from .reasoner import Reasoner | |||
__all__ = ["KBBase", "GroundKB", "PrologKB", "Reasoner"] |
@@ -1,16 +1,15 @@ | |||
from abc import ABC, abstractmethod | |||
import bisect | |||
import os | |||
from abc import ABC, abstractmethod | |||
from collections import defaultdict | |||
from itertools import product, combinations | |||
from itertools import combinations, product | |||
from multiprocessing import Pool | |||
from functools import lru_cache | |||
import numpy as np | |||
import pyswip | |||
from ..utils.utils import flatten, reform_list, hamming_dist, to_hashable | |||
from ..utils.cache import abl_cache | |||
from ..utils.utils import flatten, hamming_dist, reform_list, to_hashable | |||
class KBBase(ABC): | |||
@@ -20,19 +19,19 @@ class KBBase(ABC): | |||
Parameters | |||
---------- | |||
pseudo_label_list : list | |||
List of possible pseudo labels. It's recommended to arrange the pseudo labels in this | |||
list so that each aligns with its corresponding index in the base model: the first with | |||
List of possible pseudo labels. It's recommended to arrange the pseudo labels in this | |||
list so that each aligns with its corresponding index in the base model: the first with | |||
the 0th index, the second with the 1st, and so forth. | |||
max_err : float, optional | |||
The upper tolerance limit when comparing the similarity between a pseudo label sample's reasoning | |||
result and the ground truth. This is only applicable when the reasoning result is of a numerical type. | |||
This is particularly relevant for regression problems where exact matches might not be | |||
feasible. Defaults to 1e-10. | |||
The upper tolerance limit when comparing the similarity between a pseudo label sample's | |||
reasoning result and the ground truth. This is only applicable when the reasoning | |||
result is of a numerical type. This is particularly relevant for regression problems where | |||
exact matches might not be feasible. Defaults to 1e-10. | |||
use_cache : bool, optional | |||
Whether to use abl_cache for previously abduced candidates to speed up subsequent | |||
operations. Defaults to True. | |||
key_func : func, optional | |||
A function employed for hashing in abl_cache. This is only operational when use_cache | |||
A function employed for hashing in abl_cache. This is only operational when use_cache | |||
is set to True. Defaults to to_hashable. | |||
cache_size: int, optional | |||
The cache size in abl_cache. This is only operational when use_cache is set to | |||
@@ -75,7 +74,6 @@ class KBBase(ABC): | |||
pseudo_label : List[Any] | |||
Pseudo label sample. | |||
""" | |||
pass | |||
def abduce_candidates(self, pseudo_label, y, max_revision_num, require_more_revision): | |||
""" | |||
@@ -104,7 +102,7 @@ class KBBase(ABC): | |||
""" | |||
Check whether the reasoning result of a pseudo label sample is equal to the ground truth | |||
(or, within the maximum error allowed for numerical results). | |||
Returns | |||
------- | |||
bool | |||
@@ -130,7 +128,7 @@ class KBBase(ABC): | |||
Ground truth of the reasoning result for the sample. | |||
revision_idx : array-like | |||
Indices of where revisions should be made to the pseudo label sample. | |||
Returns | |||
------- | |||
List[List[Any]] | |||
@@ -149,8 +147,8 @@ class KBBase(ABC): | |||
def _revision(self, revision_num, pseudo_label, y): | |||
""" | |||
For a specified number of labels in a pseudo label sample to revise, iterate through all possible | |||
indices to find any candidates that are compatible with the knowledge base. | |||
For a specified number of labels in a pseudo label sample to revise, iterate through | |||
all possible indices to find any candidates that are compatible with the knowledge base. | |||
""" | |||
new_candidates = [] | |||
revision_idx_list = combinations(range(len(pseudo_label)), revision_num) | |||
@@ -164,8 +162,8 @@ class KBBase(ABC): | |||
def _abduce_by_search(self, pseudo_label, y, max_revision_num, require_more_revision): | |||
""" | |||
Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and | |||
continuously increase the number of labels in a pseudo label sample to revise, until candidates | |||
that are compatible with the knowledge base are found. | |||
continuously increase the number of labels in a pseudo label sample to revise, until | |||
candidates that are compatible with the knowledge base are found. | |||
Parameters | |||
---------- | |||
@@ -177,8 +175,8 @@ class KBBase(ABC): | |||
The upper limit on the number of revisions. | |||
require_more_revision : int | |||
If larger than 0, then after having found any candidates compatible with the | |||
knowledge base, continue to increase the number of labels in a pseudo label sample to revise to | |||
get more possible compatible candidates. | |||
knowledge base, continue to increase the number of labels in a pseudo label sample to | |||
revise to get more possible compatible candidates. | |||
Returns | |||
------- | |||
@@ -286,7 +284,7 @@ class GroundKB(KBBase): | |||
Perform abductive reasoning by directly retrieving compatible candidates from | |||
the prebuilt GKB. In this way, the time-consuming exhaustive search can be | |||
avoided. | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
@@ -347,7 +345,7 @@ class GroundKB(KBBase): | |||
num_candidates = len(self.GKB[i]) if i in self.GKB else 0 | |||
GKB_info_parts.append(f"{num_candidates} candidates of length {i}") | |||
GKB_info = ", ".join(GKB_info_parts) | |||
return ( | |||
f"{self.__class__.__name__} is a KB with " | |||
f"pseudo_label_list={self.pseudo_label_list!r}, " | |||
@@ -400,7 +398,7 @@ class PrologKB(KBBase): | |||
returned `Res` as the reasoning results. To use this default function, there must be | |||
a `logic_forward` method in the pl file to perform reasoning. | |||
Otherwise, users would override this function. | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
@@ -429,9 +427,10 @@ class PrologKB(KBBase): | |||
def get_query_string(self, pseudo_label, y, revision_idx): | |||
""" | |||
Get the query to be used for consulting Prolog. | |||
This is a default function for demo, users would override this function to adapt to their own | |||
Prolog file. In this demo function, return query `logic_forward([kept_labels, Revise_labels], Res).`. | |||
This is a default function for demo, users would override this function to adapt to | |||
their own Prolog file. In this demo function, return query | |||
`logic_forward([kept_labels, Revise_labels], Res).`. | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
@@ -440,7 +439,7 @@ class PrologKB(KBBase): | |||
Ground truth of the reasoning result for the sample. | |||
revision_idx : array-like | |||
Indices of where revisions should be made to the pseudo label sample. | |||
Returns | |||
------- | |||
str | |||
@@ -448,14 +447,14 @@ class PrologKB(KBBase): | |||
""" | |||
query_string = "logic_forward(" | |||
query_string += self._revision_pseudo_label(pseudo_label, revision_idx) | |||
key_is_none_flag = y is None or (type(y) == list and y[0] is None) | |||
key_is_none_flag = y is None or (isinstance(y, list) and y[0] is None) | |||
query_string += ",%s)." % y if not key_is_none_flag else ")." | |||
return query_string | |||
def revise_at_idx(self, pseudo_label, y, revision_idx): | |||
""" | |||
Revise the pseudo label sample at specified index positions by querying Prolog. | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
@@ -464,7 +463,7 @@ class PrologKB(KBBase): | |||
Ground truth of the reasoning result for the sample. | |||
revision_idx : array-like | |||
Indices of where revisions should be made to the pseudo label sample. | |||
Returns | |||
------- | |||
List[List[Any]] | |||
@@ -1,11 +1,7 @@ | |||
import numpy as np | |||
from zoopt import Dimension, Objective, Parameter, Opt | |||
from ..utils.utils import ( | |||
confidence_dist, | |||
flatten, | |||
reform_list, | |||
hamming_dist, | |||
) | |||
from zoopt import Dimension, Objective, Opt, Parameter | |||
from ..utils.utils import confidence_dist, hamming_dist | |||
class Reasoner: | |||
@@ -124,7 +120,7 @@ class Reasoner: | |||
def zoopt_get_solution(self, symbol_num, data_sample, max_revision_num): | |||
""" | |||
Get the optimal solution using ZOOpt library. The solution is a list of | |||
Get the optimal solution using ZOOpt library. The solution is a list of | |||
boolean values, where '1' (True) indicates the indices chosen to be revised. | |||
Parameters | |||
@@ -148,7 +144,7 @@ class Reasoner: | |||
def zoopt_revision_score(self, symbol_num, data_sample, sol): | |||
""" | |||
Get the revision score for a solution. A lower score suggests that ZOOpt library | |||
Get the revision score for a solution. A lower score suggests that ZOOpt library | |||
has a higher preference for this solution. | |||
""" | |||
revision_idx = np.where(sol.get_x() != 0)[0] | |||
@@ -198,7 +194,7 @@ class Reasoner: | |||
Returns | |||
------- | |||
List[Any] | |||
A revised pseudo label sample through abductive reasoning, which is compatible | |||
A revised pseudo label sample through abductive reasoning, which is compatible | |||
with the knowledge base. | |||
""" | |||
symbol_num = data_sample.elements_num("pred_pseudo_label") | |||
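The knowledge-base docstrings above describe the abduction workflow: a `KBBase` subclass supplies `pseudo_label_list` and a `logic_forward` that maps a pseudo label sample to its reasoning result, and a `Reasoner` searches for minimal revisions that make the two consistent. A toy sketch for an addition-style task; the `Reasoner` keyword arguments mirror the HED example elsewhere in this PR and are otherwise an assumption:

```python
from abl.reasoning import KBBase, Reasoner

class AddKB(KBBase):
    def __init__(self):
        super().__init__(pseudo_label_list=list(range(10)))

    def logic_forward(self, pseudo_label):
        # The reasoning result of a sample is simply the sum of its digits.
        return sum(pseudo_label)

kb = AddKB()
reasoner = Reasoner(kb, dist_func="hamming", max_revision=3)
```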
@@ -1,2 +1,4 @@ | |||
from .base_data_element import BaseDataElement | |||
from .list_data import ListData | |||
from .list_data import ListData | |||
__all__ = ["BaseDataElement", "ListData"] |
@@ -224,9 +224,7 @@ class BaseDataElement: | |||
metainfo (dict): A dict contains the meta information | |||
of image, such as ``img_shape``, ``scale_factor``, etc. | |||
""" | |||
assert isinstance( | |||
metainfo, dict | |||
), f"metainfo should be a ``dict`` but got {type(metainfo)}" | |||
assert isinstance(metainfo, dict), f"metainfo should be a ``dict`` but got {type(metainfo)}" | |||
meta = copy.deepcopy(metainfo) | |||
for k, v in meta.items(): | |||
self.set_field(name=k, value=v, field_type="metainfo", dtype=None) | |||
@@ -388,8 +386,7 @@ class BaseDataElement: | |||
super().__setattr__(name, value) | |||
else: | |||
raise AttributeError( | |||
f"{name} has been used as a " | |||
"private attribute, which is immutable." | |||
f"{name} has been used as a " "private attribute, which is immutable." | |||
) | |||
else: | |||
self.set_field(name=name, value=value, field_type="data", dtype=None) | |||
@@ -458,9 +455,7 @@ class BaseDataElement: | |||
functions.""" | |||
assert field_type in ["metainfo", "data"] | |||
if dtype is not None: | |||
assert isinstance( | |||
value, dtype | |||
), f"{value} should be a {dtype} but got {type(value)}" | |||
assert isinstance(value, dtype), f"{value} should be a {dtype} but got {type(value)}" | |||
if field_type == "metainfo": | |||
if name in self._data_fields: | |||
@@ -571,8 +566,7 @@ class BaseDataElement: | |||
def to_dict(self) -> dict: | |||
"""Convert BaseDataElement to dict.""" | |||
return { | |||
k: v.to_dict() if isinstance(v, BaseDataElement) else v | |||
for k, v in self.all_items() | |||
k: v.to_dict() if isinstance(v, BaseDataElement) else v for k, v in self.all_items() | |||
} | |||
def __repr__(self) -> str: | |||
@@ -1,7 +1,6 @@ | |||
# Copyright (c) OpenMMLab. All rights reserved. | |||
import itertools | |||
from collections.abc import Sized | |||
from typing import Any, List, Union | |||
from typing import List, Union | |||
import numpy as np | |||
import torch | |||
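`BaseDataElement` stores fields through plain attribute assignment (see the `__setattr__`/`set_field` hunks above), and `ListData` is the list-like container the bridge code passes around. A hedged usage sketch; the field names are the ones other hunks in this PR read from `data_samples`:

```python
from abl.structures import ListData

data_samples = ListData()
data_samples.X = [[1, 2], [3, 4]]                # raw inputs
data_samples.gt_pseudo_label = [[1, 2], [3, 4]]  # ground-truth symbols
data_samples.Y = [3, 7]                          # reasoning results

# Attribute-set fields are registered as data fields and appear in to_dict().
print(list(data_samples.to_dict().keys()))
```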
@@ -1,3 +1,23 @@ | |||
from .cache import Cache, abl_cache | |||
from .logger import ABLLogger, print_log | |||
from .utils import * | |||
from .utils import ( | |||
calculate_revision_num, | |||
confidence_dist, | |||
flatten, | |||
hamming_dist, | |||
reform_list, | |||
to_hashable, | |||
) | |||
__all__ = [ | |||
"Cache", | |||
"ABLLogger", | |||
"print_log", | |||
"calculate_revision_num", | |||
"confidence_dist", | |||
"flatten", | |||
"hamming_dist", | |||
"reform_list", | |||
"to_hashable", | |||
"abl_cache", | |||
] |
@@ -1,10 +1,5 @@ | |||
import pickle | |||
import os | |||
import os.path as osp | |||
from typing import Callable, Generic, TypeVar | |||
from .logger import print_log, ABLLogger | |||
K = TypeVar("K") | |||
T = TypeVar("T") | |||
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields | |||
@@ -73,7 +68,6 @@ class Cache(Generic[K, T]): | |||
# Empty the oldest link and make it the new root. | |||
self.root = oldroot[NEXT] | |||
oldkey = self.root[KEY] | |||
oldresult = self.root[RESULT] | |||
self.root[KEY] = self.root[RESULT] = None | |||
# Now update the cache dictionary. | |||
del self.cache_dict[oldkey] | |||
@@ -15,7 +15,8 @@ class FilterDuplicateWarning(logging.Filter): | |||
""" | |||
Filter for eliminating repeated warning messages in logging. | |||
This filter checks for duplicate warning messages and allows only the first occurrence of each message to be logged, filtering out subsequent duplicates. | |||
This filter checks for duplicate warning messages and allows only the first occurrence of | |||
each message to be logged, filtering out subsequent duplicates. | |||
Parameters | |||
---------- | |||
@@ -145,7 +146,8 @@ class ABLLogger(Logger, ManagerMixin): | |||
`ABLLogger` provides a formatted logger that can log messages with different | |||
log levels. It allows the creation of logger instances in a similar manner to `ManagerMixin`. | |||
The logger has features like distributed log storage and colored terminal output for different log levels. | |||
The logger has features like distributed log storage and colored terminal output for different | |||
log levels. | |||
Parameters | |||
---------- | |||
@@ -154,7 +156,8 @@ class ABLLogger(Logger, ManagerMixin): | |||
logger_name : str, optional | |||
`name` attribute of `logging.Logger` instance. Defaults to 'abl'. | |||
log_file : str, optional | |||
The log filename. If specified, a `FileHandler` will be added to the logger. Defaults to None. | |||
The log filename. If specified, a `FileHandler` will be added to the logger. | |||
Defaults to None. | |||
log_level : Union[int, str] | |||
The log level of the handler. Defaults to 'INFO'. | |||
If log level is 'DEBUG', distributed logs will be saved during distributed training. | |||
@@ -287,20 +290,25 @@ def print_log(msg, logger: Optional[Union[Logger, str]] = None, level=logging.IN | |||
""" | |||
Print a log message using the specified logger or a default method. | |||
This function logs a message with a given logger, if provided, or prints it using the standard `print` function. It supports special logger types such as 'silent' and 'current'. | |||
This function logs a message with a given logger, if provided, or prints it using | |||
the standard `print` function. It supports special logger types such as 'silent' and 'current'. | |||
Parameters | |||
---------- | |||
msg : str | |||
The message to be logged. | |||
logger : Optional[Union[Logger, str]], optional | |||
The logger to use for logging the message. It can be a `logging.Logger` instance, a string specifying the logger name, 'silent', 'current', or None. If None, the `print` method is used. | |||
The logger to use for logging the message. It can be a `logging.Logger` instance, a string | |||
specifying the logger name, 'silent', 'current', or None. If None, the `print` | |||
method is used. | |||
- 'silent': No message will be printed. | |||
- 'current': Use the latest created logger to log the message. | |||
- other str: The instance name of the logger. A `ValueError` is raised if the logger has not been created. | |||
- other str: The instance name of the logger. A `ValueError` is raised if the logger has not | |||
been created. | |||
- None: The `print()` method is used for logging. | |||
level : int, optional | |||
The logging level. This is only applicable when `logger` is a Logger object, 'current', or a named logger instance. The default is `logging.INFO`. | |||
The logging level. This is only applicable when `logger` is a Logger object, 'current', | |||
or a named logger instance. The default is `logging.INFO`. | |||
""" | |||
if logger is None: | |||
print(msg) | |||
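A short usage sketch of `print_log`, exercising the behaviours documented above (plain `print` fallback when `logger` is None, and routing to an explicit `logging.Logger`):

```python
import logging
from abl.utils import print_log

print_log("falls back to print() when no logger is given")

py_logger = logging.getLogger("demo")
print_log("sent to a standard Logger", logger=py_logger, level=logging.WARNING)
```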
@@ -6,7 +6,7 @@ from collections import OrderedDict | |||
from typing import Type, TypeVar | |||
_lock = threading.RLock() | |||
T = TypeVar('T') | |||
T = TypeVar("T") | |||
def _accquire_lock() -> None: | |||
@@ -47,7 +47,7 @@ class ManagerMeta(type): | |||
cls._instance_dict = OrderedDict() | |||
params = inspect.getfullargspec(cls) | |||
params_names = params[0] if params[0] else [] | |||
assert 'name' in params_names, f'{cls} must have the `name` argument' | |||
assert "name" in params_names, f"{cls} must have the `name` argument" | |||
super().__init__(*args) | |||
@@ -72,9 +72,8 @@ class ManagerMixin(metaclass=ManagerMeta): | |||
name (str): Name of the instance. Defaults to ''. | |||
""" | |||
def __init__(self, name: str = '', **kwargs): | |||
assert isinstance(name, str) and name, \ | |||
'name argument must be an non-empty string.' | |||
def __init__(self, name: str = "", **kwargs): | |||
assert isinstance(name, str) and name, "name argument must be an non-empty string." | |||
self._instance_name = name | |||
@classmethod | |||
@@ -102,8 +101,7 @@ class ManagerMixin(metaclass=ManagerMeta): | |||
instance. | |||
""" | |||
_accquire_lock() | |||
assert isinstance(name, str), \ | |||
f'type of name should be str, but got {type(cls)}' | |||
assert isinstance(name, str), f"type of name should be str, but got {type(cls)}" | |||
instance_dict = cls._instance_dict # type: ignore | |||
# Get the instance by name. | |||
if name not in instance_dict: | |||
@@ -111,9 +109,10 @@ class ManagerMixin(metaclass=ManagerMeta): | |||
instance_dict[name] = instance # type: ignore | |||
elif kwargs: | |||
warnings.warn( | |||
f'{cls} instance named of {name} has been created, ' | |||
'the method `get_instance` should not accept any other ' | |||
'arguments') | |||
f"{cls} instance named of {name} has been created, " | |||
"the method `get_instance` should not accept any other " | |||
"arguments" | |||
) | |||
# Get latest instantiated instance or root instance. | |||
_release_lock() | |||
return instance_dict[name] | |||
@@ -141,8 +140,9 @@ class ManagerMixin(metaclass=ManagerMeta): | |||
_accquire_lock() | |||
if not cls._instance_dict: | |||
raise RuntimeError( | |||
f'Before calling {cls.__name__}.get_current_instance(), you ' | |||
'should call get_instance(name=xxx) at least once.') | |||
f"Before calling {cls.__name__}.get_current_instance(), you " | |||
"should call get_instance(name=xxx) at least once." | |||
) | |||
name = next(iter(reversed(cls._instance_dict))) | |||
_release_lock() | |||
return cls._instance_dict[name] | |||
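`ABLLogger` inherits this `ManagerMixin` behaviour (see the `class ABLLogger(Logger, ManagerMixin)` hunk above), so named logger instances are created and retrieved as below; a small sketch using only methods shown in these hunks:

```python
from abl.utils import ABLLogger

# get_instance creates the named instance on first use and returns it afterwards.
logger_a = ABLLogger.get_instance("run_a")

# get_current_instance returns the most recently created instance.
current = ABLLogger.get_current_instance()
```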
@@ -221,60 +221,3 @@ def calculate_revision_num(parameter, total_length): | |||
if parameter < 0: | |||
raise ValueError("If parameter is an int, it must be non-negative.") | |||
return parameter | |||
if __name__ == "__main__": | |||
A = np.array( | |||
[ | |||
[ | |||
0.18401675, | |||
0.06797526, | |||
0.06797541, | |||
0.06801736, | |||
0.06797528, | |||
0.06797526, | |||
0.06818808, | |||
0.06797527, | |||
0.06800033, | |||
0.06797526, | |||
0.06797526, | |||
0.06797526, | |||
0.06797526, | |||
], | |||
[ | |||
0.07223161, | |||
0.0685229, | |||
0.06852708, | |||
0.17227574, | |||
0.06852163, | |||
0.07018146, | |||
0.06860291, | |||
0.06852849, | |||
0.06852163, | |||
0.0685216, | |||
0.0685216, | |||
0.06852174, | |||
0.0685216, | |||
], | |||
[ | |||
0.06794382, | |||
0.0679436, | |||
0.06794395, | |||
0.06794346, | |||
0.06794346, | |||
0.18467231, | |||
0.06794345, | |||
0.06794871, | |||
0.06794345, | |||
0.06794345, | |||
0.06794345, | |||
0.06794345, | |||
0.06794345, | |||
], | |||
], | |||
dtype=np.float32, | |||
) | |||
B = [[0, 9, 3], [0, 11, 4]] | |||
print(ori_confidence_dist(A, B)) | |||
print(confidence_dist(A, B)) |
@@ -0,0 +1,3 @@ | |||
-r ../requirements.txt | |||
pytest | |||
pytest-cov |
@@ -3,7 +3,7 @@ MNIST Addition | |||
MNIST Addition was first introduced in [1]. The inputs of this task are pairs of MNIST images, and the outputs are their sums. The dataset looks like this: | |||
.. image:: ../img/image_1.jpg | |||
.. image:: ../img/Datasets_1.png | |||
:width: 350px | |||
:align: center | |||
@@ -11,9 +11,4 @@ MNIST Addition was first introduced in [1] and the inputs of this task are pairs | |||
The ``gt_pseudo_label`` is only used to test the performance of the machine learning model and is not used in the training phase. | |||
In the Abductive Learning framework, the inference process is as follows: | |||
.. image:: ../img/image_2.jpg | |||
:width: 700px | |||
[1] Robin Manhaeve, Sebastijan Dumancic, Angelika Kimmig, Thomas Demeester, and Luc De Raedt. Deepproblog: Neural probabilistic logic programming. In Advances in Neural Information Processing Systems 31 (NeurIPS), pages 3749-3759.2018. |
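Concretely, a single MNIST Addition instance has the shape described above; a hedged illustration with placeholder arrays standing in for the images:

```python
import numpy as np

# Placeholders standing in for two 28x28 MNIST images (digits 3 and 5).
image_a = np.zeros((28, 28), dtype=np.float32)
image_b = np.zeros((28, 28), dtype=np.float32)

X = [image_a, image_b]    # the model only ever sees the images
gt_pseudo_label = [3, 5]  # used solely to evaluate the learned classifier
Y = 8                     # the supervision signal: the digits' sum
```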
@@ -1,14 +1,12 @@ | |||
import sys | |||
import os | |||
import re | |||
import sys | |||
if not "READTHEDOCS" in os.environ: | |||
if "READTHEDOCS" not in os.environ: | |||
sys.path.insert(0, os.path.abspath("..")) | |||
sys.path.append(os.path.abspath("./ABL/")) | |||
# from sphinx.locale import _ | |||
from sphinx_rtd_theme import __version__ | |||
project = "ABL" | |||
slug = re.sub(r"\W+", "-", project.lower()) | |||
@@ -48,8 +46,8 @@ pygments_style = "default" | |||
html_theme = "sphinx_rtd_theme" | |||
html_theme_options = {"display_version": True} | |||
html_static_path = ['_static'] | |||
html_css_files = ['custom.css'] | |||
html_static_path = ["_static"] | |||
html_css_files = ["custom.css"] | |||
# html_theme_path = ["../.."] | |||
# html_logo = "demo/static/logo-wordmark-light.svg" | |||
# html_show_sourcelink = True | |||
@@ -1,11 +1,11 @@ | |||
import os | |||
import os.path as osp | |||
import cv2 | |||
import pickle | |||
import numpy as np | |||
import random | |||
from collections import defaultdict | |||
import cv2 | |||
import numpy as np | |||
from torchvision.transforms import transforms | |||
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
@@ -1,388 +0,0 @@ | |||
# coding: utf-8 | |||
# ================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :framework.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/06/07 | |||
# Description : | |||
# | |||
# ================================================================# | |||
import torch | |||
import torch.nn as nn | |||
import numpy as np | |||
import os | |||
from abl.utils.plog import INFO | |||
from abl.utils.utils import flatten, reform_idx | |||
from abl.learning.basic_nn import BasicNN, BasicDataset | |||
from utils import gen_mappings, mapping_res, remapping_res | |||
from models.nn import SymbolNetAutoencoder | |||
from torch.utils.data import RandomSampler | |||
from datasets.get_hed import get_pretrain_data | |||
def hed_pretrain(kb, cls, recorder): | |||
cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list)) | |||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
if not os.path.exists("./weights/pretrain_weights.pth"): | |||
INFO("Pretrain Start") | |||
pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"]) | |||
pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y) | |||
pretrain_data_loader = torch.utils.data.DataLoader( | |||
pretrain_data, batch_size=64, shuffle=True | |||
) | |||
criterion = nn.MSELoss() | |||
optimizer = torch.optim.RMSprop( | |||
cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6 | |||
) | |||
pretrain_model = BasicNN( | |||
cls_autoencoder, | |||
criterion, | |||
optimizer, | |||
device, | |||
save_interval=1, | |||
save_dir=recorder.save_dir, | |||
num_epochs=10, | |||
recorder=recorder, | |||
) | |||
pretrain_model.fit(pretrain_data_loader) | |||
torch.save( | |||
cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth" | |||
) | |||
cls.load_state_dict(cls_autoencoder.base_model.state_dict()) | |||
else: | |||
cls.load_state_dict(torch.load("./weights/pretrain_weights.pth")) | |||
def _get_char_acc(model, X, consistent_pred_res, mapping): | |||
original_pred_res = model.predict(X)["label"] | |||
pred_res = flatten(mapping_res(original_pred_res, mapping)) | |||
INFO("Current model's output: ", pred_res) | |||
INFO("Abduced labels: ", flatten(consistent_pred_res)) | |||
assert len(pred_res) == len(flatten(consistent_pred_res)) | |||
return sum( | |||
[ | |||
pred_res[idx] == flatten(consistent_pred_res)[idx] | |||
for idx in range(len(pred_res)) | |||
] | |||
) / len(pred_res) | |||
def abduce_and_train(model, abducer, mapping, train_X_true, select_num): | |||
select_idx = RandomSampler(train_X_true, num_samples=select_num,replacement=False) | |||
X = [train_X_true[idx] for idx in select_idx] | |||
# original_pred_res = model.predict(X)['label'] | |||
pred_label = model.predict(X)["label"] | |||
if mapping == None: | |||
mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1]) | |||
else: | |||
mappings = [mapping] | |||
consistent_idx = [] | |||
consistent_pred_res = [] | |||
for m in mappings: | |||
pred_pseudo_label = mapping_res(pred_label, m) | |||
max_revision_num = 20 | |||
solution = abducer.zoopt_get_solution( | |||
pred_label, | |||
pred_pseudo_label, | |||
[None] * len(pred_label), | |||
[None] * len(pred_label), | |||
max_revision_num, | |||
) | |||
all_address_flag = reform_idx(solution, pred_label) | |||
consistent_idx_tmp = [] | |||
consistent_pred_res_tmp = [] | |||
for idx in range(len(pred_label)): | |||
address_idx = [ | |||
i for i, flag in enumerate(all_address_flag[idx]) if flag != 0 | |||
] | |||
candidate = abducer.revise_by_idx([pred_pseudo_label[idx]], None, address_idx) | |||
if len(candidate) > 0: | |||
consistent_idx_tmp.append(idx) | |||
consistent_pred_res_tmp.append(candidate[0][0]) | |||
if len(consistent_idx_tmp) > len(consistent_idx): | |||
consistent_idx = consistent_idx_tmp | |||
consistent_pred_res = consistent_pred_res_tmp | |||
if len(mappings) > 1: | |||
mapping = m | |||
if len(consistent_idx) == 0: | |||
return 0, 0, None | |||
INFO("Train pool size is:", len(flatten(consistent_pred_res))) | |||
INFO("Start to use abduced pseudo label to train model...") | |||
model.train( | |||
[X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping) | |||
) | |||
consistent_acc = len(consistent_idx) / select_num | |||
char_acc = _get_char_acc( | |||
model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping | |||
) | |||
INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc)) | |||
return consistent_acc, char_acc, mapping | |||
# def abduce_and_train(model, abducer, mapping, train_X_true, select_num): | |||
# select_idx = np.random.randint(len(train_X_true), size=select_num) | |||
# X = [] | |||
# for idx in select_idx: | |||
# X.append(train_X_true[idx]) | |||
# original_pred_res = model.predict(X)['label'] | |||
# if mapping == None: | |||
# mappings = gen_mappings([0, 1, 2, 3],['+', '=', 0, 1]) | |||
# else: | |||
# mappings = [mapping] | |||
# consistent_idx = [] | |||
# consistent_pred_res = [] | |||
# for m in mappings: | |||
# pred_res = mapping_res(original_pred_res, m) | |||
# max_abduce_num = 20 | |||
# solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), [None] * len(pred_res), max_abduce_num) | |||
# all_address_flag = reform_idx(solution, pred_res) | |||
# consistent_idx_tmp = [] | |||
# consistent_pred_res_tmp = [] | |||
# for idx in range(len(pred_res)): | |||
# address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] | |||
# candidate = abducer.revise_by_idx([pred_res[idx]], None, address_idx) | |||
# if len(candidate) > 0: | |||
# consistent_idx_tmp.append(idx) | |||
# consistent_pred_res_tmp.append(candidate[0][0]) | |||
# if len(consistent_idx_tmp) > len(consistent_idx): | |||
# consistent_idx = consistent_idx_tmp | |||
# consistent_pred_res = consistent_pred_res_tmp | |||
# if len(mappings) > 1: | |||
# mapping = m | |||
# if len(consistent_idx) == 0: | |||
# return 0, 0, None | |||
# INFO('Train pool size is:', len(flatten(consistent_pred_res))) | |||
# INFO("Start to use abduced pseudo label to train model...") | |||
# model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping)) | |||
# consistent_acc = len(consistent_idx) / select_num | |||
# char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping) | |||
# INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc)) | |||
# return consistent_acc, char_acc, mapping | |||
def _remove_duplicate_rule(rule_dict): | |||
add_nums_dict = {} | |||
for r in list(rule_dict): | |||
add_nums = str(r.split("]")[0].split("[")[1]) + str( | |||
r.split("]")[1].split("[")[1] | |||
) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10' | |||
if add_nums in add_nums_dict: | |||
old_r = add_nums_dict[add_nums] | |||
if rule_dict[r] >= rule_dict[old_r]: | |||
rule_dict.pop(old_r) | |||
add_nums_dict[add_nums] = r | |||
else: | |||
rule_dict.pop(r) | |||
else: | |||
add_nums_dict[add_nums] = r | |||
return list(rule_dict) | |||
def get_rules_from_data( | |||
model, abducer, mapping, train_X_true, samples_per_rule, samples_num | |||
): | |||
rules = [] | |||
for _ in range(samples_num): | |||
while True: | |||
select_idx = np.random.randint(len(train_X_true), size=samples_per_rule) | |||
X = [] | |||
for idx in select_idx: | |||
X.append(train_X_true[idx]) | |||
original_pred_res = model.predict(X)["label"] | |||
pred_res = mapping_res(original_pred_res, mapping) | |||
consistent_idx = [] | |||
consistent_pred_res = [] | |||
for idx in range(len(pred_res)): | |||
if abducer.kb.logic_forward([pred_res[idx]]): | |||
consistent_idx.append(idx) | |||
consistent_pred_res.append(pred_res[idx]) | |||
if len(consistent_pred_res) != 0: | |||
rule = abducer.abduce_rules(consistent_pred_res) | |||
if rule != None: | |||
break | |||
rules.append(rule) | |||
all_rule_dict = {} | |||
for rule in rules: | |||
for r in rule: | |||
all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1 | |||
rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5} | |||
rules = _remove_duplicate_rule(rule_dict) | |||
return rules | |||
def _get_consist_rule_acc(model, abducer, mapping, rules, X): | |||
cnt = 0 | |||
for x in X: | |||
original_pred_res = model.predict([x])["label"] | |||
pred_res = flatten(mapping_res(original_pred_res, mapping)) | |||
if abducer.kb.consist_rule(pred_res, rules): | |||
cnt += 1 | |||
return cnt / len(X) | |||
def train_with_rule( | |||
model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8 | |||
): | |||
train_X = train_data | |||
val_X = val_data | |||
samples_num = 50 | |||
samples_per_rule = 3 | |||
# Start training / for each length of equations | |||
for equation_len in range(min_len, max_len): | |||
INFO( | |||
"============== equation_len: %d-%d ================" | |||
% (equation_len, equation_len + 1) | |||
) | |||
train_X_true = train_X[1][equation_len] | |||
train_X_false = train_X[0][equation_len] | |||
val_X_true = val_X[1][equation_len] | |||
val_X_false = val_X[0][equation_len] | |||
train_X_true.extend(train_X[1][equation_len + 1]) | |||
train_X_false.extend(train_X[0][equation_len + 1]) | |||
val_X_true.extend(val_X[1][equation_len + 1]) | |||
val_X_false.extend(val_X[0][equation_len + 1]) | |||
condition_cnt = 0 | |||
while True: | |||
if equation_len == min_len: | |||
mapping = None | |||
# Abduce and train NN | |||
consistent_acc, char_acc, mapping = abduce_and_train( | |||
model, abducer, mapping, train_X_true, select_num | |||
) | |||
if consistent_acc == 0: | |||
continue | |||
# Test if we can use mlp to evaluate | |||
if consistent_acc >= 0.9 and char_acc >= 0.9: | |||
condition_cnt += 1 | |||
else: | |||
condition_cnt = 0 | |||
# The condition has been satisfied continuously five times | |||
if condition_cnt >= 5: | |||
INFO("Now checking if we can go to next course") | |||
rules = get_rules_from_data( | |||
model, abducer, mapping, train_X_true, samples_per_rule, samples_num | |||
) | |||
INFO("Learned rules from data:", rules) | |||
true_consist_rule_acc = _get_consist_rule_acc( | |||
model, abducer, mapping, rules, val_X_true | |||
) | |||
false_consist_rule_acc = _get_consist_rule_acc( | |||
model, abducer, mapping, rules, val_X_false | |||
) | |||
INFO( | |||
"consist_rule_acc is %f, %f\n" | |||
% (true_consist_rule_acc, false_consist_rule_acc) | |||
) | |||
# decide next course or restart | |||
if true_consist_rule_acc > 0.95 and false_consist_rule_acc < 0.1: | |||
torch.save( | |||
model.classifier_list[0].model.state_dict(), | |||
"./weights/weights_%d.pth" % equation_len, | |||
) | |||
break | |||
else: | |||
if equation_len == min_len: | |||
INFO("Final mapping is: ", mapping) | |||
model.classifier_list[0].model.load_state_dict( | |||
torch.load("./weights/pretrain_weights.pth") | |||
) | |||
else: | |||
model.classifier_list[0].model.load_state_dict( | |||
torch.load("./weights/weights_%d.pth" % (equation_len - 1)) | |||
) | |||
condition_cnt = 0 | |||
INFO("Reload Model and retrain") | |||
return model, mapping | |||
def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8): | |||
train_X = train_data | |||
test_X = test_data | |||
# Calcualte how many equations should be selected in each length | |||
# for each length, there are equation_samples_num[equation_len] rules | |||
print("Now begin to train final mlp model") | |||
equation_samples_num = [] | |||
len_cnt = max_len - min_len + 1 | |||
samples_num = 50 | |||
equation_samples_num += [0] * min_len | |||
if samples_num % len_cnt == 0: | |||
equation_samples_num += [samples_num // len_cnt] * len_cnt | |||
else: | |||
equation_samples_num += [samples_num // len_cnt] * len_cnt | |||
equation_samples_num[-1] += samples_num % len_cnt | |||
assert sum(equation_samples_num) == samples_num | |||
# Abduce rules | |||
rules = [] | |||
samples_per_rule = 3 | |||
for equation_len in range(min_len, max_len + 1): | |||
equation_rules = get_rules_from_data( | |||
model, | |||
abducer, | |||
mapping, | |||
train_X[1][equation_len], | |||
samples_per_rule, | |||
equation_samples_num[equation_len], | |||
) | |||
rules.extend(equation_rules) | |||
rules = list(set(rules)) | |||
INFO("Learned rules from data:", rules) | |||
for equation_len in range(5, 27): | |||
true_consist_rule_acc = _get_consist_rule_acc( | |||
model, abducer, mapping, rules, test_X[1][equation_len] | |||
) | |||
false_consist_rule_acc = _get_consist_rule_acc( | |||
model, abducer, mapping, rules, test_X[0][equation_len] | |||
) | |||
INFO( | |||
"consist_rule_acc of testing length %d equations are %f, %f" | |||
% (equation_len, true_consist_rule_acc, false_consist_rule_acc) | |||
) | |||
if __name__ == "__main__": | |||
pass |
@@ -1,20 +1,18 @@ | |||
import os | |||
from collections import defaultdict | |||
from typing import Any, List | |||
import torch | |||
from torch.utils.data import DataLoader | |||
from abl.reasoning import ReasonerBase | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.bridge import SimpleBridge | |||
from abl.dataset import RegressionDataset | |||
from abl.evaluation import BaseMetric | |||
from abl.dataset import BridgeDataset, RegressionDataset | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import ReasonerBase | |||
from abl.structures import ListData | |||
from abl.utils import print_log | |||
from examples.hed.utils import gen_mappings, InfiniteSampler | |||
from examples.models.nn import SymbolNetAutoencoder | |||
from examples.hed.datasets.get_hed import get_pretrain_data | |||
from examples.hed.utils import InfiniteSampler, gen_mappings | |||
from examples.models.nn import SymbolNetAutoencoder | |||
class HEDBridge(SimpleBridge): | |||
@@ -95,7 +93,8 @@ class HEDBridge(SimpleBridge): | |||
character_accuracy = self.model.valid(filtered_data_samples) | |||
revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X) | |||
print_log( | |||
f"Revisible ratio is {revisible_ratio:.3f}, Character accuracy is {character_accuracy:.3f}", | |||
f"Revisible ratio is {revisible_ratio:.3f}, Character \ | |||
accuracy is {character_accuracy:.3f}", | |||
logger="current", | |||
) | |||
@@ -111,7 +110,8 @@ class HEDBridge(SimpleBridge): | |||
false_ratio = self.calc_consistent_ratio(val_X_false, rule) | |||
print_log( | |||
f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio is {1 - false_ratio:.3f}", | |||
f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio \ | |||
is {1 - false_ratio:.3f}", | |||
logger="current", | |||
) | |||
@@ -143,7 +143,7 @@ class HEDBridge(SimpleBridge): | |||
if len(consistent_instance) != 0: | |||
rule = self.reasoner.abduce_rules(consistent_instance) | |||
if rule != None: | |||
if rule is not None: | |||
rules.append(rule) | |||
break | |||
@@ -214,7 +214,8 @@ class HEDBridge(SimpleBridge): | |||
loss = self.model.train(filtered_sub_data_samples) | |||
print_log( | |||
f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] model loss is {loss:.5f}", | |||
f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] \ | |||
model loss is {loss:.5f}", | |||
logger="current", | |||
) | |||
@@ -224,11 +225,11 @@ class HEDBridge(SimpleBridge): | |||
condition_num = 0 | |||
if condition_num >= 5: | |||
print_log(f"Now checking if we can go to next course", logger="current") | |||
print_log("Now checking if we can go to next course", logger="current") | |||
rules = self.get_rules_from_data( | |||
data_samples, samples_per_rule=3, samples_num=50 | |||
) | |||
print_log(f"Learned rules from data: " + str(rules), logger="current") | |||
print_log("Learned rules from data: " + str(rules), logger="current") | |||
seems_good = self.check_rule_quality(rules, val_data, equation_len) | |||
if seems_good: | |||
@@ -66,7 +66,8 @@ | |||
" prolog_rules = prolog_result[0][\"X\"]\n", | |||
" rules = [rule.value for rule in prolog_rules]\n", | |||
" return rules\n", | |||
" \n", | |||
"\n", | |||
"\n", | |||
"class HedReasoner(ReasonerBase):\n", | |||
" def revise_at_idx(self, data_sample):\n", | |||
" revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n", | |||
@@ -76,7 +77,9 @@ | |||
" return candidate\n", | |||
"\n", | |||
" def zoopt_revision_score(self, symbol_num, data_sample, sol):\n", | |||
" revision_flag = reform_list(list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label)\n", | |||
" revision_flag = reform_list(\n", | |||
" list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label\n", | |||
" )\n", | |||
" data_sample.revision_flag = revision_flag\n", | |||
"\n", | |||
" lefted_idxs = [i for i in range(len(data_sample.pred_idx))]\n", | |||
@@ -108,7 +111,7 @@ | |||
" for i in range(0, len(candidate_size)):\n", | |||
" score -= math.exp(-i) * candidate_size[i]\n", | |||
" return score\n", | |||
" \n", | |||
"\n", | |||
" def abduce(self, data_sample):\n", | |||
" symbol_num = data_sample.elements_num(\"pred_pseudo_label\")\n", | |||
" max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n", | |||
@@ -134,9 +137,8 @@ | |||
" def abduce_rules(self, pred_res):\n", | |||
" return self.kb.abduce_rules(pred_res)\n", | |||
"\n", | |||
"kb = HedKB(\n", | |||
" pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\"\n", | |||
")\n", | |||
"\n", | |||
"kb = HedKB(pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\")\n", | |||
"reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=20)" | |||
] | |||
}, | |||
@@ -1,69 +0,0 @@ | |||
# coding: utf-8 | |||
# ================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :share_example.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/06/07 | |||
# Description : | |||
# | |||
# ================================================================# | |||
import sys | |||
sys.path.append("../") | |||
from abl.utils.plog import logger, INFO | |||
from abl.utils.utils import reduce_dimension | |||
import torch.nn as nn | |||
import torch | |||
from abl.models.nn import LeNet5, SymbolNet | |||
from abl.models.basic_model import BasicModel, BasicDataset | |||
from abl.models.wabl_models import DecisionTree, WABLBasicModel | |||
from sklearn.neighbors import KNeighborsClassifier | |||
from abl.abducer.abducer_base import AbducerBase | |||
from abl.abducer.kb import add_KB, HWF_KB, prolog_KB | |||
from datasets.mnist_add.get_mnist_add import get_mnist_add | |||
from datasets.hwf.get_hwf import get_hwf | |||
from datasets.hed.get_hed import get_hed, split_equation | |||
from abl import framework_hed_knn | |||
def run_test(): | |||
# kb = add_KB(True) | |||
# kb = HWF_KB(True) | |||
# abducer = AbducerBase(kb) | |||
kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') | |||
abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True) | |||
recorder = logger() | |||
total_train_data = get_hed(train=True) | |||
train_data, val_data = split_equation(total_train_data, 3, 1) | |||
test_data = get_hed(train=False) | |||
# ========================= KNN model ============================ # | |||
reduce_dimension(train_data) | |||
reduce_dimension(val_data) | |||
reduce_dimension(test_data) | |||
base_model = KNeighborsClassifier(n_neighbors=3) | |||
pretrain_data_X, pretrain_data_Y = framework_hed_knn.hed_pretrain(base_model) | |||
model = WABLBasicModel(base_model, kb.pseudo_label_list) | |||
model, mapping = framework_hed_knn.train_with_rule( | |||
model, abducer, train_data, val_data, (pretrain_data_X, pretrain_data_Y), select_num=10, min_len=5, max_len=8 | |||
) | |||
framework_hed_knn.hed_test( | |||
model, abducer, mapping, train_data, test_data, min_len=5, max_len=8 | |||
) | |||
# ============================ End =============================== # | |||
recorder.dump() | |||
return True | |||
if __name__ == "__main__": | |||
run_test() |
@@ -1,159 +0,0 @@ | |||
import os.path as osp | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from abl.evaluation import SemanticsMetric, SymbolMetric | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import PrologKB, ReasonerBase | |||
from abl.utils import ABLLogger, print_log, reform_list | |||
from examples.hed.datasets.get_hed import get_hed, split_equation | |||
from examples.hed.hed_bridge import HEDBridge | |||
from examples.models.nn import SymbolNet | |||
# Build logger | |||
print_log("Abductive Learning on the HED example.", logger="current") | |||
# Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
log_dir = ABLLogger.get_current_instance().log_dir | |||
weights_dir = osp.join(log_dir, "weights") | |||
### Logic Part | |||
# Initialize knowledge base and reasoner
class HedKB(PrologKB): | |||
def __init__(self, pseudo_label_list, pl_file): | |||
super().__init__(pseudo_label_list, pl_file) | |||
def consist_rule(self, exs, rules): | |||
rules = str(rules).replace("'", "") | |||
return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 | |||
def abduce_rules(self, pred_res): | |||
prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) | |||
if len(prolog_result) == 0: | |||
return None | |||
prolog_rules = prolog_result[0]["X"] | |||
rules = [rule.value for rule in prolog_rules] | |||
return rules | |||
class HedReasoner(ReasonerBase): | |||
def revise_at_idx(self, data_sample): | |||
revision_idx = np.where(np.array(data_sample.flatten("revision_flag")) != 0)[0] | |||
candidate = self.kb.revise_at_idx( | |||
data_sample.pred_pseudo_label, data_sample.Y, revision_idx | |||
) | |||
return candidate | |||
def zoopt_revision_score(self, symbol_num, data_sample, sol): | |||
revision_flag = reform_list(list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label) | |||
data_sample.revision_flag = revision_flag | |||
lefted_idxs = [i for i in range(len(data_sample.pred_idx))] | |||
candidate_size = [] | |||
while lefted_idxs: | |||
idxs = [] | |||
idxs.append(lefted_idxs.pop(0)) | |||
max_candidate_idxs = [] | |||
found = False | |||
for idx in range(-1, len(data_sample.pred_idx)): | |||
if (not idx in idxs) and (idx >= 0): | |||
idxs.append(idx) | |||
candidate = self.revise_at_idx(data_sample[idxs]) | |||
if len(candidate) == 0: | |||
if len(idxs) > 1: | |||
idxs.pop() | |||
else: | |||
if len(idxs) > len(max_candidate_idxs): | |||
found = True | |||
max_candidate_idxs = idxs.copy() | |||
removed = [i for i in lefted_idxs if i in max_candidate_idxs] | |||
if found: | |||
candidate_size.append(len(removed) + 1) | |||
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] | |||
candidate_size.sort() | |||
score = 0 | |||
import math | |||
for i in range(0, len(candidate_size)): | |||
score -= math.exp(-i) * candidate_size[i] | |||
return score | |||
def abduce(self, data_sample): | |||
symbol_num = data_sample.elements_num("pred_pseudo_label") | |||
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) | |||
solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num) | |||
data_sample.revision_flag = reform_list( | |||
solution.astype(np.int32), data_sample.pred_pseudo_label | |||
) | |||
abduced_pseudo_label = [] | |||
for single_instance in data_sample: | |||
single_instance.pred_pseudo_label = [single_instance.pred_pseudo_label] | |||
candidates = self.revise_at_idx(single_instance) | |||
if len(candidates) == 0: | |||
abduced_pseudo_label.append([]) | |||
else: | |||
abduced_pseudo_label.append(candidates[0][0]) | |||
data_sample.abduced_pseudo_label = abduced_pseudo_label | |||
return abduced_pseudo_label | |||
def abduce_rules(self, pred_res): | |||
return self.kb.abduce_rules(pred_res) | |||
import os | |||
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
kb = HedKB( | |||
pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "./datasets/learn_add.pl") | |||
) | |||
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=20) | |||
### Machine Learning Part | |||
# Build necessary components for BasicNN | |||
cls = SymbolNet(num_classes=4) | |||
criterion = nn.CrossEntropyLoss() | |||
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001) | |||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
# Build BasicNN | |||
# BasicNN wraps a PyTorch model so that it exposes an sklearn-estimator-style interface
base_model = BasicNN( | |||
cls, | |||
criterion, | |||
optimizer, | |||
device, | |||
batch_size=32, | |||
num_epochs=1, | |||
save_interval=1, | |||
save_dir=weights_dir, | |||
) | |||
# Build ABLModel | |||
# The main function of the ABL model is to serialize data and | |||
# provide a unified interface for different machine learning models | |||
model = ABLModel(base_model) | |||
### Metric | |||
# Set up metrics | |||
metric_list = [SymbolMetric(prefix="hed"), SemanticsMetric(prefix="hed")] | |||
### Bridge Machine Learning and Logic Reasoning | |||
bridge = HEDBridge(model, reasoner, metric_list) | |||
### Dataset | |||
total_train_data = get_hed(train=True) | |||
train_data, val_data = split_equation(total_train_data, 3, 1) | |||
test_data = get_hed(train=False) | |||
### Train and Test | |||
bridge.pretrain("examples/hed/weights") | |||
bridge.train(train_data, val_data) |
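HedReasoner above turns ZOOpt's flat 0/1 decision vector into per-symbol revision flags that line up with the nested pred_pseudo_label. A sketch of that reshaping step, using a hypothetical stand-in for abl.utils.reform_list (only its call sites appear in this diff):

def reform_list_sketch(flat, reference):
    # Hypothetical stand-in for abl.utils.reform_list: regroup a flat list so that it
    # mirrors the nested shape of `reference`.
    out, pos = [], 0
    for group in reference:
        out.append(list(flat[pos:pos + len(group)]))
        pos += len(group)
    return out

pred_pseudo_label = [[1, 1, "+", 0, "=", 1, 0], [1, "+", 1, "=", 1, 0]]   # two toy equations
flat_flags = [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]                     # one bit per symbol
print(reform_list_sketch(flat_flags, pred_pseudo_label))
# [[0, 1, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0]] -- a 1 marks a symbol proposed for revision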
@@ -1,6 +1,6 @@ | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import numpy as np | |||
import torch.utils.data.sampler as sampler | |||
@@ -13,7 +13,7 @@ class InfiniteSampler(sampler.Sampler): | |||
while True: | |||
order = np.random.permutation(self.num_samples) | |||
for i in range(self.num_samples): | |||
yield order[i: i + self.batch_size] | |||
yield order[i : i + self.batch_size] | |||
i += self.batch_size | |||
def __len__(self): | |||
@@ -58,7 +58,6 @@ def reduce_dimension(data): | |||
for equation_len in range(5, 27): | |||
equations = data[truth_value][equation_len] | |||
reduced_equations = [ | |||
[extract_feature(symbol_img) for symbol_img in equation] | |||
for equation in equations | |||
[extract_feature(symbol_img) for symbol_img in equation] for equation in equations | |||
] | |||
data[truth_value][equation_len] = reduced_equations | |||
data[truth_value][equation_len] = reduced_equations |
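reduce_dimension flattens every symbol image with extract_feature before the data reaches an sklearn estimator such as the KNN in the removed share_example.py. A short sketch of why that step is needed, with a hypothetical flattening function standing in for the project's extract_feature helper:

import numpy as np

def extract_feature_sketch(symbol_img):
    # Hypothetical stand-in: sklearn estimators expect fixed-length 1-D feature vectors.
    return np.asarray(symbol_img).reshape(-1)

equation = [np.zeros((28, 28)), np.ones((28, 28))]           # toy two-symbol "equation"
features = [extract_feature_sketch(img) for img in equation]
print(features[0].shape)                                     # (784,)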
@@ -1,14 +1,12 @@ | |||
import os | |||
import json | |||
import os | |||
from PIL import Image | |||
from torchvision.transforms import transforms | |||
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
img_transform = transforms.Compose( | |||
[transforms.ToTensor(), transforms.Normalize((0.5,), (1,))] | |||
) | |||
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))]) | |||
def get_data(file, get_pseudo_label): | |||
@@ -51,14 +51,13 @@ | |||
"source": [ | |||
"# Initialize knowledge base and reasoner\n", | |||
"class HWF_KB(KBBase):\n", | |||
"\n", | |||
" def _valid_candidate(self, formula):\n", | |||
" if len(formula) % 2 == 0:\n", | |||
" return False\n", | |||
" for i in range(len(formula)):\n", | |||
" if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:\n", | |||
" if i % 2 == 0 and formula[i] not in [\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\"]:\n", | |||
" return False\n", | |||
" if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:\n", | |||
" if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"times\", \"div\"]:\n", | |||
" return False\n", | |||
" return True\n", | |||
"\n", | |||
@@ -66,12 +65,17 @@ | |||
" if not self._valid_candidate(formula):\n", | |||
" return np.inf\n", | |||
" mapping = {str(i): str(i) for i in range(1, 10)}\n", | |||
" mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})\n", | |||
" mapping.update({\"+\": \"+\", \"-\": \"-\", \"times\": \"*\", \"div\": \"/\"})\n", | |||
" formula = [mapping[f] for f in formula]\n", | |||
" return eval(''.join(formula))\n", | |||
" return eval(\"\".join(formula))\n", | |||
"\n", | |||
"\n", | |||
"kb = HWF_KB(pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], max_err=1e-10, use_cache=False)\n", | |||
"reasoner = ReasonerBase(kb, dist_func='confidence')" | |||
"kb = HWF_KB(\n", | |||
" pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"times\", \"div\"],\n", | |||
" max_err=1e-10,\n", | |||
" use_cache=False,\n", | |||
")\n", | |||
"reasoner = ReasonerBase(kb, dist_func=\"confidence\")" | |||
] | |||
}, | |||
{ | |||
@@ -122,7 +126,7 @@ | |||
"outputs": [], | |||
"source": [ | |||
"# Initialize ABL model\n", | |||
"# The main function of the ABL model is to serialize data and \n", | |||
"# The main function of the ABL model is to serialize data and\n", | |||
"# provide a unified interface for different machine learning models\n", | |||
"model = ABLModel(base_model)" | |||
] | |||
@@ -80,7 +80,7 @@ | |||
"outputs": [], | |||
"source": [ | |||
"# Build ABLModel\n", | |||
"# The main function of the ABL model is to serialize data and \n", | |||
"# The main function of the ABL model is to serialize data and\n", | |||
"# provide a unified interface for different machine learning models\n", | |||
"model = ABLModel(base_model)" | |||
] | |||
@@ -11,8 +11,8 @@ | |||
# ================================================================# | |||
import torch | |||
import numpy as np | |||
import torch | |||
from torch import nn | |||
@@ -84,9 +84,7 @@ class SymbolNetAutoencoder(nn.Module): | |||
self.base_model = SymbolNet(num_classes, image_size) | |||
self.softmax = nn.Softmax(dim=1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) | |||
self.fc2 = nn.Sequential( | |||
nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU() | |||
) | |||
self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()) | |||
def forward(self, x): | |||
x = self.base_model(x) | |||
@@ -1,4 +1,5 @@ | |||
import os | |||
from setuptools import find_packages, setup | |||
@@ -27,7 +28,13 @@ here = os.path.abspath(os.path.dirname(__file__)) | |||
try: | |||
with open(os.path.join(here, "requirements.txt"), encoding="utf-8") as f: | |||
REQUIRED = f.read().split("\n") | |||
except: | |||
except FileNotFoundError: | |||
# Handle the case where the file does not exist | |||
print("requirements.txt file not found.") | |||
REQUIRED = [] | |||
except Exception as e: | |||
# Handle other possible exceptions | |||
print(f"An error occurred: {e}") | |||
REQUIRED = [] | |||
EXTRAS = { | |||
@@ -64,7 +71,7 @@ if __name__ == "__main__": | |||
install_requires=REQUIRED, | |||
extras_require=EXTRAS, | |||
classifiers=[ | |||
'Development Status :: 3 - Alpha', | |||
"Development Status :: 3 - Alpha", | |||
"Intended Audience :: Science/Research", | |||
"Intended Audience :: Developers", | |||
"Programming Language :: Python", | |||
@@ -74,4 +81,3 @@ if __name__ == "__main__": | |||
"Programming Language :: Python :: 3.8", | |||
], | |||
) | |||
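For context on the narrowed exception handling above: a bare except also intercepts KeyboardInterrupt and SystemExit, while catching FileNotFoundError (and, as a fallback, Exception) lets those control-flow exceptions propagate. A self-contained sketch of the same pattern, not taken from the repository:

def read_requirements(path="requirements.txt"):
    """Return the dependency list, or [] if the file is missing or unreadable."""
    try:
        with open(path, encoding="utf-8") as f:
            return [line for line in f.read().split("\n") if line]
    except FileNotFoundError:
        print(f"{path} not found.")
        return []
    except Exception as e:          # unexpected I/O or decoding problems
        print(f"An error occurred: {e}")
        return []

print(read_requirements())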
@@ -4,10 +4,11 @@ import torch.nn as nn | |||
import torch.optim as optim | |||
from abl.learning import BasicNN | |||
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner | |||
from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner | |||
from abl.structures import ListData | |||
from examples.models.nn import LeNet5 | |||
# Fixture for BasicNN instance | |||
@pytest.fixture | |||
def basic_nn_instance(): | |||
@@ -16,6 +17,7 @@ def basic_nn_instance(): | |||
optimizer = optim.Adam(model.parameters()) | |||
return BasicNN(model, criterion, optimizer) | |||
# Fixture for base_model instance | |||
@pytest.fixture | |||
def base_model_instance(): | |||
@@ -24,6 +26,7 @@ def base_model_instance(): | |||
optimizer = optim.Adam(model.parameters()) | |||
return BasicNN(model, criterion, optimizer) | |||
# Fixture for ListData instance | |||
@pytest.fixture | |||
def list_data_instance(): | |||
@@ -37,47 +40,71 @@ def list_data_instance(): | |||
@pytest.fixture | |||
def data_samples_add(): | |||
# favor class 1 in the first position
prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
prob1 = [ | |||
[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
] | |||
# favor class 7 in the first position
prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
prob2 = [ | |||
[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
] | |||
data_samples_add = ListData() | |||
data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] | |||
data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] | |||
data_samples_add.Y = [8, 8, 17, 10] | |||
return data_samples_add | |||
@pytest.fixture | |||
def data_samples_hwf(): | |||
data_samples_hwf = ListData() | |||
data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]] | |||
data_samples_hwf.pred_pseudo_label = [ | |||
["5", "+", "2"], | |||
["5", "+", "9"], | |||
["5", "+", "9"], | |||
["5", "-", "8", "8", "8"], | |||
] | |||
data_samples_hwf.pred_prob = [None, None, None, None] | |||
data_samples_hwf.Y = [3, 64, 65, 3.17] | |||
return data_samples_hwf | |||
class AddKB(KBBase): | |||
def __init__(self, pseudo_label_list=list(range(10)), | |||
use_cache=False): | |||
def __init__(self, pseudo_label_list=list(range(10)), use_cache=False): | |||
super().__init__(pseudo_label_list, use_cache=use_cache) | |||
def logic_forward(self, nums): | |||
return sum(nums) | |||
class AddGroundKB(GroundKB): | |||
def __init__(self, pseudo_label_list=list(range(10)), | |||
GKB_len_list=[2]): | |||
def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): | |||
super().__init__(pseudo_label_list, GKB_len_list) | |||
def logic_forward(self, nums): | |||
return sum(nums) | |||
class HwfKB(KBBase): | |||
def __init__( | |||
self, | |||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||
"+", "-", "times", "div"], | |||
pseudo_label_list=[ | |||
"1", | |||
"2", | |||
"3", | |||
"4", | |||
"5", | |||
"6", | |||
"7", | |||
"8", | |||
"9", | |||
"+", | |||
"-", | |||
"times", | |||
"div", | |||
], | |||
max_err=1e-3, | |||
use_cache=False, | |||
): | |||
@@ -87,7 +114,17 @@ class HwfKB(KBBase): | |||
if len(formula) % 2 == 0: | |||
return False | |||
for i in range(len(formula)): | |||
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
if i % 2 == 0 and formula[i] not in [ | |||
"1", | |||
"2", | |||
"3", | |||
"4", | |||
"5", | |||
"6", | |||
"7", | |||
"8", | |||
"9", | |||
]: | |||
return False | |||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
return False | |||
@@ -100,7 +137,8 @@ class HwfKB(KBBase): | |||
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
formula = [mapping[f] for f in formula] | |||
return eval("".join(formula)) | |||
class HedKB(PrologKB): | |||
def __init__(self, pseudo_label_list, pl_file): | |||
super().__init__(pseudo_label_list, pl_file) | |||
@@ -110,24 +148,28 @@ class HedKB(PrologKB): | |||
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) | |||
return len(list(self.prolog.query(pl_query))) != 0 | |||
@pytest.fixture | |||
def kb_add(): | |||
return AddKB() | |||
@pytest.fixture | |||
def kb_add_cache(): | |||
return AddKB(use_cache=True) | |||
return AddKB(use_cache=True) | |||
@pytest.fixture | |||
def kb_add_ground(): | |||
return AddGroundKB() | |||
@pytest.fixture | |||
def kb_add_prolog(): | |||
kb = PrologKB(pseudo_label_list=list(range(10)), | |||
pl_file="examples/mnist_add/datasets/add.pl") | |||
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl") | |||
return kb | |||
@pytest.fixture | |||
def kb_hed(): | |||
kb = HedKB( | |||
@@ -136,6 +178,7 @@ def kb_hed(): | |||
) | |||
return kb | |||
@pytest.fixture | |||
def reasoner_instance(kb_add): | |||
return Reasoner(kb_add, "confidence") | |||
return Reasoner(kb_add, "confidence") |
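As a worked illustration of what the add fixtures encode (a sketch that assumes the abl package is importable and mirrors the AddKB definition above):

from abl.reasoning import KBBase   # as imported at the top of this fixture module

class AddKB(KBBase):
    def __init__(self, pseudo_label_list=list(range(10)), use_cache=False):
        super().__init__(pseudo_label_list, use_cache=use_cache)

    def logic_forward(self, nums):
        return sum(nums)            # a pair of digits is consistent with Y iff it sums to Y

kb = AddKB()
# Starting from [1, 2] (sum 3) with target 1, the only single-position revision that
# restores consistency is setting the second digit to 0, so the candidate list is [[1, 0]].
print(kb.abduce_candidates([1, 2], 1, max_revision_num=2, require_more_revision=0))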
@@ -1,8 +1,9 @@ | |||
from unittest.mock import Mock, create_autospec | |||
import numpy as np | |||
import pytest | |||
from abl.learning import ABLModel | |||
from unittest.mock import Mock, create_autospec | |||
class TestABLModel(object): | |||
@@ -1,63 +1,66 @@ | |||
import pytest | |||
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner | |||
from abl.reasoning import PrologKB, Reasoner | |||
class TestKBBase(object): | |||
def test_init(self, kb_add): | |||
assert kb_add.pseudo_label_list == list(range(10)) | |||
def test_init_cache(self, kb_add_cache): | |||
assert kb_add_cache.pseudo_label_list == list(range(10)) | |||
assert kb_add_cache.use_cache == True | |||
assert kb_add_cache.use_cache is True | |||
def test_logic_forward(self, kb_add): | |||
result = kb_add.logic_forward([1, 2]) | |||
assert result == 3 | |||
def test_revise_at_idx(self, kb_add): | |||
result = kb_add.revise_at_idx([1, 2], 2, [0]) | |||
assert result == [[0, 2]] | |||
def test_abduce_candidates(self, kb_add): | |||
result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2, | |||
require_more_revision=0) | |||
result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2, require_more_revision=0) | |||
assert result == [[1, 0]] | |||
class TestGroundKB(object): | |||
def test_init(self, kb_add_ground): | |||
assert kb_add_ground.pseudo_label_list == list(range(10)) | |||
assert kb_add_ground.GKB_len_list == [2] | |||
assert kb_add_ground.GKB | |||
def test_logic_forward_ground(self, kb_add_ground): | |||
result = kb_add_ground.logic_forward([1, 2]) | |||
assert result == 3 | |||
def test_abduce_candidates_ground(self, kb_add_ground): | |||
result = kb_add_ground.abduce_candidates([1, 2], 1, max_revision_num=2, | |||
require_more_revision=0) | |||
result = kb_add_ground.abduce_candidates( | |||
[1, 2], 1, max_revision_num=2, require_more_revision=0 | |||
) | |||
assert result == [(1, 0)] | |||
class TestPrologKB(object): | |||
class TestPrologKB(object): | |||
def test_init_pl1(self, kb_add_prolog): | |||
assert kb_add_prolog.pseudo_label_list == list(range(10)) | |||
assert kb_add_prolog.pl_file == "examples/mnist_add/datasets/add.pl" | |||
def test_init_pl2(self, kb_hed): | |||
assert kb_hed.pseudo_label_list == [1, 0, "+", "="] | |||
assert kb_hed.pl_file == "examples/hed/datasets/learn_add.pl" | |||
def test_prolog_file_not_exist(self): | |||
pseudo_label_list = [1, 2] | |||
non_existing_file = "path/to/non_existing_file.pl" | |||
with pytest.raises(FileNotFoundError) as excinfo: | |||
PrologKB(pseudo_label_list=pseudo_label_list, | |||
pl_file=non_existing_file) | |||
PrologKB(pseudo_label_list=pseudo_label_list, pl_file=non_existing_file) | |||
assert non_existing_file in str(excinfo.value) | |||
def test_logic_forward_pl1(self, kb_add_prolog): | |||
result = kb_add_prolog.logic_forward([1, 2]) | |||
assert result == 3 | |||
def test_logic_forward_pl2(self, kb_hed): | |||
consist_exs = [ | |||
[1, 1, "+", 0, "=", 1, 1], | |||
@@ -70,21 +73,24 @@ class TestPrologKB(object): | |||
[0, "+", 0, "=", 0], | |||
[0, "+", 0, "=", 1], | |||
] | |||
assert kb_hed.logic_forward(consist_exs) == True | |||
assert kb_hed.logic_forward(inconsist_exs) == False | |||
assert kb_hed.logic_forward(consist_exs) is True | |||
assert kb_hed.logic_forward(inconsist_exs) is False | |||
def test_revise_at_idx(self, kb_add_prolog): | |||
result = kb_add_prolog.revise_at_idx([1, 2], 2, [0]) | |||
assert result == [[0, 2]] | |||
class TestReasoner(object):
def test_reasoner_init(self, reasoner_instance): | |||
assert reasoner_instance.dist_func == "confidence" | |||
def test_invalid_dist_func(kb_add):
with pytest.raises(NotImplementedError) as excinfo: | |||
Reasoner(kb_add, "invalid_dist_func") | |||
assert "Valid options for dist_func include \"hamming\" and \"confidence\"" in str(excinfo.value) | |||
assert 'Valid options for dist_func include "hamming" and "confidence"' in str( | |||
excinfo.value | |||
) | |||
class test_batch_abduce(object): | |||
@@ -95,8 +101,18 @@ class test_batch_abduce(object): | |||
reasoner4 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=1) | |||
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]] | |||
assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]] | |||
assert reasoner3.batch_abduce(data_samples_add) == [ | |||
[1, 7], | |||
[7, 1], | |||
[8, 9], | |||
[1, 9], | |||
] | |||
assert reasoner4.batch_abduce(data_samples_add) == [ | |||
[1, 7], | |||
[7, 1], | |||
[8, 9], | |||
[7, 3], | |||
] | |||
def test_batch_abduce_ground(self, kb_add_ground, data_samples_add): | |||
reasoner1 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=0) | |||
@@ -105,8 +121,18 @@ class test_batch_abduce(object): | |||
reasoner4 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=1) | |||
assert reasoner1.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)] | |||
assert reasoner2.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)] | |||
assert reasoner3.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (1, 9)] | |||
assert reasoner4.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (7, 3)] | |||
assert reasoner3.batch_abduce(data_samples_add) == [ | |||
(1, 7), | |||
(7, 1), | |||
(8, 9), | |||
(1, 9), | |||
] | |||
assert reasoner4.batch_abduce(data_samples_add) == [ | |||
(1, 7), | |||
(7, 1), | |||
(8, 9), | |||
(7, 3), | |||
] | |||
def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add): | |||
reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0) | |||
@@ -115,35 +141,73 @@ class test_batch_abduce(object): | |||
reasoner4 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=1) | |||
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]] | |||
assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]] | |||
assert reasoner3.batch_abduce(data_samples_add) == [ | |||
[1, 7], | |||
[7, 1], | |||
[8, 9], | |||
[1, 9], | |||
] | |||
assert reasoner4.batch_abduce(data_samples_add) == [ | |||
[1, 7], | |||
[7, 1], | |||
[8, 9], | |||
[7, 3], | |||
] | |||
def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add): | |||
reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1) | |||
reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2) | |||
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]] | |||
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner2.batch_abduce(data_samples_add) == [ | |||
[1, 7], | |||
[7, 1], | |||
[8, 9], | |||
[7, 3], | |||
] | |||
def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf): | |||
reasoner1 = Reasoner(kb_hwf1, "hamming", max_revision=3, require_more_revision=0) | |||
reasoner2 = Reasoner(kb_hwf1, "hamming", max_revision=0.5, require_more_revision=0) | |||
reasoner3 = Reasoner(kb_hwf1, "hamming", max_revision=0.9, require_more_revision=0) | |||
res = reasoner1.batch_abduce(data_samples_hwf) | |||
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']] | |||
assert res == [ | |||
["1", "+", "2"], | |||
["8", "times", "8"], | |||
[], | |||
["4", "-", "6", "div", "8"], | |||
] | |||
res = reasoner2.batch_abduce(data_samples_hwf) | |||
assert res == [['1', '+', '2'], [], [], []] | |||
assert res == [["1", "+", "2"], [], [], []] | |||
res = reasoner3.batch_abduce(data_samples_hwf) | |||
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']] | |||
assert res == [ | |||
["1", "+", "2"], | |||
["8", "times", "8"], | |||
[], | |||
["4", "-", "6", "div", "8"], | |||
] | |||
def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf): | |||
reasoner1 = Reasoner(kb_hwf2, "hamming", max_revision=3, require_more_revision=0) | |||
reasoner2 = Reasoner(kb_hwf2, "hamming", max_revision=0.5, require_more_revision=0) | |||
reasoner3 = Reasoner(kb_hwf2, "hamming", max_revision=0.9, require_more_revision=0) | |||
res = reasoner1.batch_abduce(data_samples_hwf) | |||
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']] | |||
assert res == [ | |||
["1", "+", "2"], | |||
["7", "times", "9"], | |||
["8", "times", "8"], | |||
["5", "-", "8", "div", "8"], | |||
] | |||
res = reasoner2.batch_abduce(data_samples_hwf) | |||
assert res == [['1', '+', '2'], ['7', 'times', '9'], [], ['5', '-', '8', 'div', '8']] | |||
assert res == [ | |||
["1", "+", "2"], | |||
["7", "times", "9"], | |||
[], | |||
["5", "-", "8", "div", "8"], | |||
] | |||
res = reasoner3.batch_abduce(data_samples_hwf) | |||
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']] | |||
assert res == [ | |||
["1", "+", "2"], | |||
["7", "times", "9"], | |||
["8", "times", "8"], | |||
["5", "-", "8", "div", "8"], | |||
] |
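To make the expected values in the add cases easier to verify by eye, here is a small self-contained enumeration of the single-revision candidates for the first sample (plain Python, independent of the abl package):

pred, Y = [1, 1], 8                     # pred_pseudo_label and ground-truth sum from the fixture
candidates = []
for idx in range(len(pred)):            # revise exactly one position
    for digit in range(10):
        trial = pred.copy()
        trial[idx] = digit
        if trial != pred and sum(trial) == Y:
            candidates.append(trial)
print(candidates)   # [[7, 1], [1, 7]]
# With dist_func="confidence", the reasoner keeps the position the model is most sure of:
# prob1 puts 0.99 on class 1 at position 0, giving [1, 7]; prob2 favours 7 there, giving [7, 1].
# Y = 17 is unreachable with a single revision, which is why max_revision=1 yields [].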