@@ -5,4 +5,11 @@ show_missing = True
 disable_warnings = include-ignored
 include = */abl/*
 omit =
-    */abl/__init__.py
+    */abl/__init__.py
+    abl/bridge/__init__.py
+    abl/dataset/__init__.py
+    abl/evaluation/__init__.py
+    abl/learning/__init__.py
+    abl/reasoning/__init__.py
+    abl/structures/__init__.py
+    abl/utils/__init__.py
@@ -24,7 +24,7 @@ jobs:
       - name: Install package dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -r ./requirements.txt
+          pip install -r build_tools/requirements.txt
       - name: Run tests
         run: |
          pytest --cov-config=.coveragerc --cov-report=xml --cov=abl ./tests
@@ -20,5 +20,5 @@ jobs:
       - name: flake8 Lint
         uses: py-actions/flake8@v2
         with:
-          max-line-length: "110"
-          plugins: "flake8-bugbear flake8-black"
+          max-line-length: "100"
+          args: --ignore=E203,W503
@@ -1,2 +1,11 @@
-from .learning import abl_model, basic_nn
-from .reasoning import reasoner, kb
+from . import bridge, dataset, evaluation, learning, reasoning, structures, utils
+
+__all__ = [
+    "bridge",
+    "dataset",
+    "evaluation",
+    "learning",
+    "reasoning",
+    "structures",
+    "utils",
+]
@@ -1,3 +1,3 @@
 VERSION = (0, 0, 1)

-__version__ = ".".join(map(str, VERSION))
+__version__ = ".".join(map(str, VERSION))
@@ -1,2 +1,4 @@
 from .base_bridge import BaseBridge
-from .simple_bridge import SimpleBridge
+from .simple_bridge import SimpleBridge
+
+__all__ = ["BaseBridge", "SimpleBridge"]
@@ -12,53 +12,40 @@ class BaseBridge(metaclass=ABCMeta):
     def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
         if not isinstance(model, ABLModel):
             raise TypeError(
-                "Expected an instance of ABLModel, but received type: {}".format(
-                    type(model)
-                )
+                "Expected an instance of ABLModel, but received type: {}".format(type(model))
             )
         if not isinstance(reasoner, Reasoner):
             raise TypeError(
-                "Expected an instance of Reasoner, but received type: {}".format(
-                    type(reasoner)
-                )
+                "Expected an instance of Reasoner, but received type: {}".format(type(reasoner))
             )

         self.model = model
         self.reasoner = reasoner

     @abstractmethod
-    def predict(
-        self, data_samples: ListData
-    ) -> Tuple[List[List[Any]], List[List[Any]]]:
+    def predict(self, data_samples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
         """Placeholder for predicting labels from input."""
-        pass

     @abstractmethod
     def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
         """Placeholder for abducing pseudo labels."""
-        pass

     @abstractmethod
     def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
         """Placeholder for mapping label space to symbol space."""
-        pass

     @abstractmethod
     def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
         """Placeholder for mapping symbol space to label space."""
-        pass

     @abstractmethod
     def train(self, train_data: Union[ListData, DataSet]):
         """Placeholder for the train loop of Abductive Learning."""
-        pass

     @abstractmethod
     def valid(self, valid_data: Union[ListData, DataSet]) -> None:
         """Placeholder for model validation."""
-        pass

     @abstractmethod
     def test(self, test_data: Union[ListData, DataSet]) -> None:
         """Placeholder for model test."""
-        pass
@@ -1,5 +1,5 @@
 import os.path as osp
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

 from numpy import ndarray
@@ -32,8 +32,7 @@ class SimpleBridge(BaseBridge):
     def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
         pred_idx = data_samples.pred_idx
         data_samples.pred_pseudo_label = [
-            [self.reasoner.mapping[_idx] for _idx in sub_list]
-            for sub_list in pred_idx
+            [self.reasoner.mapping[_idx] for _idx in sub_list] for sub_list in pred_idx
         ]
         return data_samples.pred_pseudo_label
@@ -81,7 +80,9 @@ class SimpleBridge(BaseBridge):
             loss = self.model.train(sub_data_samples)
             print_log(
-                f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}",
+                f"loop(train) [{loop + 1}/{loops}] "
+                f"segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] "
+                f"model loss is {loss:.5f}",
                 logger="current",
             )
@@ -2,3 +2,10 @@ from .bridge_dataset import BridgeDataset
 from .classification_dataset import ClassificationDataset
 from .prediction_dataset import PredictionDataset
 from .regression_dataset import RegressionDataset
+
+__all__ = [
+    "BridgeDataset",
+    "ClassificationDataset",
+    "PredictionDataset",
+    "RegressionDataset",
+]
@@ -13,11 +13,15 @@ class BridgeDataset(Dataset):
     gt_pseudo_label : List[List[Any]], optional
         A list of objects representing the ground truth label of each element in ``X``.
     Y : List[Any]
-        A list of objects representing the ground truth of the reasoning result of each instance in ``X``.
+        A list of objects representing the ground truth of the reasoning result of
+        each instance in ``X``.
     """

     def __init__(
-        self, X: List[List[Any]], gt_pseudo_label: Optional[List[List[Any]]], Y: List[Any]
+        self,
+        X: List[List[Any]],
+        gt_pseudo_label: Optional[List[List[Any]]],
+        Y: List[Any],
     ):
         if (not isinstance(X, list)) or (not isinstance(Y, list)):
             raise ValueError("X and Y should be of type list.")
@@ -15,16 +15,16 @@ class ClassificationDataset(Dataset):
     Y : List[int]
         The target data.
     transform : Callable[..., Any], optional
-        A function/transform that takes in an object and returns a transformed version. Defaults to None.
+        A function/transform that takes in an object and returns a transformed version.
+        Defaults to None.
     """

-    def __init__(
-        self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None
-    ):
+    def __init__(self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None):
         if (not isinstance(X, list)) or (not isinstance(Y, list)):
             raise ValueError("X and Y should be of type list.")
         if len(X) != len(Y):
             raise ValueError("Length of X and Y must be equal.")

         self.X = X
         self.Y = torch.LongTensor(Y)
         self.transform = transform
@@ -13,8 +13,10 @@ class PredictionDataset(Dataset):
     X : List[Any]
         The input data.
     transform : Callable[..., Any], optional
-        A function/transform that takes in an object and returns a transformed version. Defaults to None.
+        A function/transform that takes in an object and returns a transformed version.
+        Defaults to None.
     """

     def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
         if not isinstance(X, list):
             raise ValueError("X should be of type list.")
@@ -1,6 +1,5 @@
 from typing import Any, List, Tuple

-import torch
 from torch.utils.data import Dataset
@@ -15,12 +14,13 @@ class RegressionDataset(Dataset):
     Y : List[Any]
         A list of objects representing the output data.
     """

     def __init__(self, X: List[Any], Y: List[Any]):
         if (not isinstance(X, list)) or (not isinstance(Y, list)):
             raise ValueError("X and Y should be of type list.")
         if len(X) != len(Y):
             raise ValueError("Length of X and Y must be equal.")

         self.X = X
         self.Y = Y
@@ -1,3 +1,5 @@
 from .base_metric import BaseMetric
 from .semantics_metric import SemanticsMetric
 from .symbol_metric import SymbolMetric
+
+__all__ = ["BaseMetric", "SemanticsMetric", "SymbolMetric"]
@@ -20,8 +20,10 @@ class BaseMetric(metaclass=ABCMeta):
         will be used instead. Default: None
     """

-    def __init__(self,
-                 prefix: Optional[str] = None,) -> None:
+    def __init__(
+        self,
+        prefix: Optional[str] = None,
+    ) -> None:
         self.results: List[Any] = []
         self.prefix = prefix or self.default_prefix
@@ -65,20 +67,18 @@ class BaseMetric(metaclass=ABCMeta):
         """
         if len(self.results) == 0:
             print_log(
-                f'{self.__class__.__name__} got empty `self.results`. Please '
-                'ensure that the processed results are properly added into '
-                '`self.results` in `process` method.',
-                logger='current',
-                level=logging.WARNING)
+                f"{self.__class__.__name__} got empty `self.results`. Please "
+                "ensure that the processed results are properly added into "
+                "`self.results` in `process` method.",
+                logger="current",
+                level=logging.WARNING,
+            )

         metrics = self.compute_metrics(self.results)
         # Add prefix to metric names
         if self.prefix:
-            metrics = {
-                '/'.join((self.prefix, k)): v
-                for k, v in metrics.items()
-            }
+            metrics = {"/".join((self.prefix, k)): v for k, v in metrics.items()}
         # reset the results list
         self.results.clear()
-        return metrics
+        return metrics
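The prefix join simply namespaces metric names per evaluator. A minimal self-contained sketch of what that dict comprehension produces (the metric values here are made up for illustration):

    metrics = {"accuracy": 0.98, "loss": 0.07}  # hypothetical compute_metrics() output
    prefix = "val"
    print({"/".join((prefix, k)): v for k, v in metrics.items()})
    # -> {'val/accuracy': 0.98, 'val/loss': 0.07}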
@@ -14,15 +14,15 @@ class SymbolMetric(BaseMetric):
         if not len(pred_pseudo_label) == len(gt_pseudo_label):
             raise ValueError("lengths of pred_pseudo_label and gt_pseudo_label should be equal")

         for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
             correct_num = 0
             for pred_symbol, symbol in zip(pred_z, z):
                 if pred_symbol == symbol:
                     correct_num += 1
             self.results.append(correct_num / len(z))

     def compute_metrics(self, results: list) -> dict:
         metrics = dict()
         metrics["character_accuracy"] = sum(results) / len(results)
-        return metrics
+        return metrics
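`SymbolMetric` scores each sequence by the fraction of matching symbols and then averages over sequences. A quick self-contained check of that logic with toy values:

    pred = [[1, 2, 3], [4, 5]]
    gt = [[1, 2, 0], [4, 5]]
    results = [sum(p == g for p, g in zip(pz, gz)) / len(gz) for pz, gz in zip(pred, gt)]
    print(sum(results) / len(results))  # (2/3 + 1.0) / 2 ≈ 0.833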
@@ -1,2 +1,4 @@
 from .abl_model import ABLModel
-from .basic_nn import BasicNN
+from .basic_nn import BasicNN
+
+__all__ = ["ABLModel", "BasicNN"]
@@ -58,7 +58,8 @@ class ABLModel:
         Parameters
         ----------
         data_samples : ListData
-            A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`.
+            A batch of data to train on, which typically contains the data, `X`, and the
+            corresponding labels, `abduced_idx`.

         Returns
         -------
@@ -68,7 +69,7 @@ class ABLModel:
         data_X = data_samples.flatten("X")
         data_y = data_samples.flatten("abduced_idx")
         return self.base_model.fit(X=data_X, y=data_y)

     def valid(self, data_samples: ListData) -> float:
         """
         Validate the model on the given data.
@@ -76,7 +77,8 @@ class ABLModel:
         Parameters
         ----------
         data_samples : ListData
-            A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`.
+            A batch of data to validate on, which typically contains the data, `X`,
+            and the corresponding labels, `abduced_idx`.

         Returns
         -------
@@ -94,7 +96,7 @@ class ABLModel:
             method = getattr(model, operation)
             method(*args, **kwargs)
         else:
-            if not f"{operation}_path" in kwargs.keys():
+            if f"{operation}_path" not in kwargs.keys():
                 raise ValueError(f"'{operation}_path' should not be None")
             else:
                 try:
@@ -104,9 +106,10 @@ class ABLModel:
                 elif operation == "load":
                     with open(kwargs["load_path"], "rb") as file:
                         self.base_model = pickle.load(file)
-                except:
+                except (OSError, pickle.PickleError):
                     raise NotImplementedError(
-                        f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed."
+                        f"{type(model).__name__} object doesn't have the {operation} method "
+                        f"and the default pickle-based {operation} method failed."
                    )

     def save(self, *args, **kwargs) -> None:
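The persistence path above follows a common pattern: defer to the estimator's own method when it exists, otherwise pickle the whole object. A minimal self-contained sketch of the same idea (the function name and paths are illustrative, not the package API):

    import pickle

    def save_model(model, save_path):
        # Prefer the model's own save(); fall back to pickling the object.
        if hasattr(model, "save"):
            model.save(save_path)
        else:
            with open(save_path, "wb") as f:
                pickle.dump(model, f)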
@@ -1,5 +1,5 @@
-import os
 import logging
+import os
 from typing import Any, Callable, List, Optional, T, Tuple

 import numpy
@@ -23,7 +23,8 @@ class BasicNN:
     optimizer : torch.optim.Optimizer
         The optimizer used for training.
     device : torch.device, optional
-        The device on which the model will be trained or used for prediction, by default torch.device("cpu").
+        The device on which the model will be trained or used for prediction,
+        by default torch.device("cpu").
     batch_size : int, optional
         The batch size used for training, by default 32.
     num_epochs : int, optional
@@ -37,9 +38,11 @@ class BasicNN:
     save_dir : Optional[str], optional
         The directory in which to save the model during training, by default None.
     train_transform : Callable[..., Any], optional
-        A function/transform that takes in an object and returns a transformed version used in the `fit` and `train_epoch` methods, by default None.
+        A function/transform that takes in an object and returns a transformed version used
+        in the `fit` and `train_epoch` methods, by default None.
     test_transform : Callable[..., Any], optional
-        A function/transform that takes in an object and returns a transformed version in the `predict`, `predict_proba` and `score` methods, , by default None.
+        A function/transform that takes in an object and returns a transformed version used in
+        the `predict`, `predict_proba` and `score` methods, by default None.
     collate_fn : Callable[[List[T]], Any], optional
         The function used to collate data, by default None.
     """
@@ -1,2 +1,4 @@
-from .kb import KBBase, GroundKB, PrologKB
-from .reasoner import Reasoner
+from .kb import GroundKB, KBBase, PrologKB
+from .reasoner import Reasoner
+
+__all__ = ["KBBase", "GroundKB", "PrologKB", "Reasoner"]
@@ -1,16 +1,15 @@
-from abc import ABC, abstractmethod
 import bisect
 import os
+from abc import ABC, abstractmethod
 from collections import defaultdict
-from itertools import product, combinations
+from itertools import combinations, product
 from multiprocessing import Pool
-from functools import lru_cache

 import numpy as np
 import pyswip

-from ..utils.utils import flatten, reform_list, hamming_dist, to_hashable
 from ..utils.cache import abl_cache
+from ..utils.utils import flatten, hamming_dist, reform_list, to_hashable


 class KBBase(ABC):
@@ -20,19 +19,19 @@ class KBBase(ABC):
     Parameters
     ----------
     pseudo_label_list : list
-        List of possible pseudo labels. It's recommended to arrange the pseudo labels in this
-        list so that each aligns with its corresponding index in the base model: the first with
+        List of possible pseudo labels. It's recommended to arrange the pseudo labels in this
+        list so that each aligns with its corresponding index in the base model: the first with
         the 0th index, the second with the 1st, and so forth.
     max_err : float, optional
-        The upper tolerance limit when comparing the similarity between a pseudo label sample's reasoning
-        result and the ground truth. This is only applicable when the reasoning result is of a numerical type.
-        This is particularly relevant for regression problems where exact matches might not be
-        feasible. Defaults to 1e-10.
+        The upper tolerance limit when comparing the similarity between a pseudo label sample's
+        reasoning result and the ground truth. This is only applicable when the reasoning
+        result is of a numerical type. This is particularly relevant for regression problems
+        where exact matches might not be feasible. Defaults to 1e-10.
     use_cache : bool, optional
         Whether to use abl_cache for previously abduced candidates to speed up subsequent
         operations. Defaults to True.
     key_func : func, optional
-        A function employed for hashing in abl_cache. This is only operational when use_cache
+        A function employed for hashing in abl_cache. This is only operational when use_cache
         is set to True. Defaults to to_hashable.
     cache_size: int, optional
         The cache size in abl_cache. This is only operational when use_cache is set to
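For orientation, defining a knowledge base amounts to subclassing ``KBBase`` and implementing ``logic_forward``. A minimal sketch for an addition-style task, assuming the ``abl.reasoning`` import path this PR exports:

    from abl.reasoning import KBBase

    class AddKB(KBBase):
        def __init__(self):
            super().__init__(pseudo_label_list=list(range(10)))

        def logic_forward(self, pseudo_label):
            # The reasoning result of a sample is the sum of its digit pseudo labels.
            return sum(pseudo_label)

    kb = AddKB()
    print(kb.logic_forward([3, 9]))  # 12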
@@ -75,7 +74,6 @@ class KBBase(ABC):
         pseudo_label : List[Any]
             Pseudo label sample.
         """
-        pass

     def abduce_candidates(self, pseudo_label, y, max_revision_num, require_more_revision):
         """
@@ -104,7 +102,7 @@ class KBBase(ABC):
         """
         Check whether the reasoning result of a pseudo label sample is equal to the ground truth
         (or, within the maximum error allowed for numerical results).

         Returns
         -------
         bool
@@ -130,7 +128,7 @@ class KBBase(ABC):
             Ground truth of the reasoning result for the sample.
         revision_idx : array-like
             Indices of where revisions should be made to the pseudo label sample.

         Returns
         -------
         List[List[Any]]
@@ -149,8 +147,8 @@ class KBBase(ABC):

     def _revision(self, revision_num, pseudo_label, y):
         """
-        For a specified number of labels in a pseudo label sample to revise, iterate through all possible
-        indices to find any candidates that are compatible with the knowledge base.
+        For a specified number of labels in a pseudo label sample to revise, iterate through
+        all possible indices to find any candidates that are compatible with the knowledge base.
         """
         new_candidates = []
         revision_idx_list = combinations(range(len(pseudo_label)), revision_num)
@@ -164,8 +162,8 @@ class KBBase(ABC):
     def _abduce_by_search(self, pseudo_label, y, max_revision_num, require_more_revision):
         """
         Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and
-        continuously increase the number of labels in a pseudo label sample to revise, until candidates
-        that are compatible with the knowledge base are found.
+        continuously increase the number of labels in a pseudo label sample to revise, until
+        candidates that are compatible with the knowledge base are found.

         Parameters
         ----------
@@ -177,8 +175,8 @@ class KBBase(ABC):
             The upper limit on the number of revisions.
         require_more_revision : int
             If larger than 0, then after having found any candidates compatible with the
-            knowledge base, continue to increase the number of labels in a pseudo label sample to revise to
-            get more possible compatible candidates.
+            knowledge base, continue to increase the number of labels in a pseudo label sample
+            to revise to get more possible compatible candidates.

         Returns
         -------
@@ -286,7 +284,7 @@ class GroundKB(KBBase):
         Perform abductive reasoning by directly retrieving compatible candidates from
         the prebuilt GKB. In this way, the time-consuming exhaustive search can be
         avoided.

         Parameters
         ----------
         pseudo_label : List[Any]
@@ -347,7 +345,7 @@ class GroundKB(KBBase):
             num_candidates = len(self.GKB[i]) if i in self.GKB else 0
             GKB_info_parts.append(f"{num_candidates} candidates of length {i}")
         GKB_info = ", ".join(GKB_info_parts)

         return (
             f"{self.__class__.__name__} is a KB with "
             f"pseudo_label_list={self.pseudo_label_list!r}, "
@@ -400,7 +398,7 @@ class PrologKB(KBBase):
         returned `Res` as the reasoning results. To use this default function, there must be
         a `logic_forward` method in the pl file to perform reasoning.
         Otherwise, users would override this function.

         Parameters
         ----------
         pseudo_label : List[Any]
@@ -429,9 +427,10 @@ class PrologKB(KBBase):
     def get_query_string(self, pseudo_label, y, revision_idx):
         """
         Get the query to be used for consulting Prolog.
-        This is a default function for demo, users would override this function to adapt to their own
-        Prolog file. In this demo function, return query `logic_forward([kept_labels, Revise_labels], Res).`.
+        This is a default function for demo; users can override it to adapt to
+        their own Prolog file. In this demo function, return the query
+        `logic_forward([kept_labels, Revise_labels], Res).`.

         Parameters
         ----------
         pseudo_label : List[Any]
@@ -440,7 +439,7 @@ class PrologKB(KBBase):
             Ground truth of the reasoning result for the sample.
         revision_idx : array-like
             Indices of where revisions should be made to the pseudo label sample.

         Returns
         -------
         str
@@ -448,14 +447,14 @@ class PrologKB(KBBase):
         """
         query_string = "logic_forward("
         query_string += self._revision_pseudo_label(pseudo_label, revision_idx)
-        key_is_none_flag = y is None or (type(y) == list and y[0] is None)
+        key_is_none_flag = y is None or (isinstance(y, list) and y[0] is None)
         query_string += ",%s)." % y if not key_is_none_flag else ")."
         return query_string

     def revise_at_idx(self, pseudo_label, y, revision_idx):
         """
         Revise the pseudo label sample at specified index positions by querying Prolog.

         Parameters
         ----------
         pseudo_label : List[Any]
@@ -464,7 +463,7 @@ class PrologKB(KBBase):
             Ground truth of the reasoning result for the sample.
         revision_idx : array-like
             Indices of where revisions should be made to the pseudo label sample.

         Returns
         -------
         List[List[Any]]
@@ -1,11 +1,7 @@
 import numpy as np
-from zoopt import Dimension, Objective, Parameter, Opt
-
-from ..utils.utils import (
-    confidence_dist,
-    flatten,
-    reform_list,
-    hamming_dist,
-)
+from zoopt import Dimension, Objective, Opt, Parameter
+
+from ..utils.utils import confidence_dist, hamming_dist


 class Reasoner:
@@ -124,7 +120,7 @@ class Reasoner:
     def zoopt_get_solution(self, symbol_num, data_sample, max_revision_num):
         """
-        Get the optimal solution using ZOOpt library. The solution is a list of 
+        Get the optimal solution using ZOOpt library. The solution is a list of
         boolean values, where '1' (True) indicates the indices chosen to be revised.

         Parameters
@@ -148,7 +144,7 @@ class Reasoner:
     def zoopt_revision_score(self, symbol_num, data_sample, sol):
         """
-        Get the revision score for a solution. A lower score suggests that ZOOpt library 
+        Get the revision score for a solution. A lower score suggests that ZOOpt library
         has a higher preference for this solution.
         """
         revision_idx = np.where(sol.get_x() != 0)[0]
@@ -198,7 +194,7 @@ class Reasoner:
         Returns
         -------
         List[Any]
-            A revised pseudo label sample through abductive reasoning, which is compatible 
+            A revised pseudo label sample through abductive reasoning, which is compatible
             with the knowledge base.
         """
         symbol_num = data_sample.elements_num("pred_pseudo_label")
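For readers unfamiliar with ZOOpt: the reasoner searches over binary masks of length ``symbol_num`` and scores each candidate mask. A minimal self-contained sketch of that pattern, with a toy objective standing in for the package's score function:

    import numpy as np
    from zoopt import Dimension, Objective, Opt, Parameter

    def score(sol):
        mask = np.array(sol.get_x())
        return float(abs(mask.sum() - 2))  # toy score: prefer masks with two revisions

    dim = Dimension(5, [[0, 1]] * 5, [False] * 5)  # five binary decision variables
    solution = Opt.min(Objective(score, dim), Parameter(budget=100))
    print(np.where(np.array(solution.get_x()) != 0)[0])  # indices chosen to revise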
@@ -1,2 +1,4 @@
 from .base_data_element import BaseDataElement
-from .list_data import ListData
+from .list_data import ListData
+
+__all__ = ["BaseDataElement", "ListData"]
@@ -224,9 +224,7 @@ class BaseDataElement:
             metainfo (dict): A dict contains the meta information
                 of image, such as ``img_shape``, ``scale_factor``, etc.
         """
-        assert isinstance(
-            metainfo, dict
-        ), f"metainfo should be a ``dict`` but got {type(metainfo)}"
+        assert isinstance(metainfo, dict), f"metainfo should be a ``dict`` but got {type(metainfo)}"
         meta = copy.deepcopy(metainfo)
         for k, v in meta.items():
             self.set_field(name=k, value=v, field_type="metainfo", dtype=None)
@@ -388,8 +386,7 @@ class BaseDataElement:
                 super().__setattr__(name, value)
             else:
                 raise AttributeError(
-                    f"{name} has been used as a "
-                    "private attribute, which is immutable."
+                    f"{name} has been used as a private attribute, which is immutable."
                 )
         else:
             self.set_field(name=name, value=value, field_type="data", dtype=None)
@@ -458,9 +455,7 @@ class BaseDataElement:
         functions."""
         assert field_type in ["metainfo", "data"]
         if dtype is not None:
-            assert isinstance(
-                value, dtype
-            ), f"{value} should be a {dtype} but got {type(value)}"
+            assert isinstance(value, dtype), f"{value} should be a {dtype} but got {type(value)}"

         if field_type == "metainfo":
             if name in self._data_fields:
@@ -571,8 +566,7 @@ class BaseDataElement:
     def to_dict(self) -> dict:
         """Convert BaseDataElement to dict."""
         return {
-            k: v.to_dict() if isinstance(v, BaseDataElement) else v
-            for k, v in self.all_items()
+            k: v.to_dict() if isinstance(v, BaseDataElement) else v for k, v in self.all_items()
         }

     def __repr__(self) -> str:
@@ -1,7 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import itertools
-from collections.abc import Sized
-from typing import Any, List, Union
+from typing import List, Union

 import numpy as np
 import torch
@@ -1,3 +1,23 @@
 from .cache import Cache, abl_cache
 from .logger import ABLLogger, print_log
-from .utils import *
+from .utils import (
+    calculate_revision_num,
+    confidence_dist,
+    flatten,
+    hamming_dist,
+    reform_list,
+    to_hashable,
+)
+
+__all__ = [
+    "Cache",
+    "ABLLogger",
+    "print_log",
+    "calculate_revision_num",
+    "confidence_dist",
+    "flatten",
+    "hamming_dist",
+    "reform_list",
+    "to_hashable",
+    "abl_cache",
+]
@@ -1,10 +1,5 @@
-import pickle
-import os
-import os.path as osp
 from typing import Callable, Generic, TypeVar

-from .logger import print_log, ABLLogger
-
 K = TypeVar("K")
 T = TypeVar("T")
 PREV, NEXT, KEY, RESULT = 0, 1, 2, 3  # names for the link fields
@@ -73,7 +68,6 @@ class Cache(Generic[K, T]):
                     # Empty the oldest link and make it the new root.
                     self.root = oldroot[NEXT]
                     oldkey = self.root[KEY]
-                    oldresult = self.root[RESULT]
                     self.root[KEY] = self.root[RESULT] = None
                     # Now update the cache dictionary.
                     del self.cache_dict[oldkey]
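The ``PREV, NEXT, KEY, RESULT`` link fields mirror the circular doubly linked list trick in ``functools.lru_cache``: the root is a sentinel, ``root[NEXT]`` is the oldest entry, ``root[PREV]`` the newest, and on eviction the sentinel slot is reused for the new entry. A compact self-contained sketch of that bookkeeping (not the package's ``Cache`` class):

    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3

    class TinyLRU:
        def __init__(self, maxsize=2):
            self.maxsize, self.cache = maxsize, {}
            self.root = []                      # sentinel link in a circular list
            self.root[:] = [self.root, self.root, None, None]

        def put(self, key, value):
            if len(self.cache) >= self.maxsize:
                oldroot = self.root             # reuse the sentinel as the new link
                oldroot[KEY], oldroot[RESULT] = key, value
                self.root = oldroot[NEXT]       # oldest link becomes the new sentinel
                oldkey, self.root[KEY], self.root[RESULT] = self.root[KEY], None, None
                del self.cache[oldkey]
                self.cache[key] = oldroot
            else:
                last = self.root[PREV]          # splice a new link in before the sentinel
                link = [last, self.root, key, value]
                last[NEXT] = self.root[PREV] = self.cache[key] = link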
@@ -15,7 +15,8 @@ class FilterDuplicateWarning(logging.Filter):
     """
     Filter for eliminating repeated warning messages in logging.

-    This filter checks for duplicate warning messages and allows only the first occurrence of each message to be logged, filtering out subsequent duplicates.
+    This filter checks for duplicate warning messages and allows only the first occurrence of
+    each message to be logged, filtering out subsequent duplicates.

     Parameters
     ----------
@@ -145,7 +146,8 @@ class ABLLogger(Logger, ManagerMixin):
     `ABLLogger` provides a formatted logger that can log messages with different
     log levels. It allows the creation of logger instances in a similar manner to `ManagerMixin`.
-    The logger has features like distributed log storage and colored terminal output for different log levels.
+    The logger has features like distributed log storage and colored terminal output for different
+    log levels.

     Parameters
     ----------
@@ -154,7 +156,8 @@ class ABLLogger(Logger, ManagerMixin):
     logger_name : str, optional
         `name` attribute of `logging.Logger` instance. Defaults to 'abl'.
     log_file : str, optional
-        The log filename. If specified, a `FileHandler` will be added to the logger. Defaults to None.
+        The log filename. If specified, a `FileHandler` will be added to the logger.
+        Defaults to None.
     log_level : Union[int, str]
         The log level of the handler. Defaults to 'INFO'.
         If log level is 'DEBUG', distributed logs will be saved during distributed training.
@@ -287,20 +290,25 @@ def print_log(msg, logger: Optional[Union[Logger, str]] = None, level=logging.IN
     """
     Print a log message using the specified logger or a default method.

-    This function logs a message with a given logger, if provided, or prints it using the standard `print` function. It supports special logger types such as 'silent' and 'current'.
+    This function logs a message with a given logger, if provided, or prints it using
+    the standard `print` function. It supports special logger types such as 'silent' and 'current'.

     Parameters
     ----------
     msg : str
         The message to be logged.
     logger : Optional[Union[Logger, str]], optional
-        The logger to use for logging the message. It can be a `logging.Logger` instance, a string specifying the logger name, 'silent', 'current', or None. If None, the `print` method is used.
+        The logger to use for logging the message. It can be a `logging.Logger` instance, a string
+        specifying the logger name, 'silent', 'current', or None. If None, the `print`
+        method is used.
         - 'silent': No message will be printed.
         - 'current': Use the latest created logger to log the message.
-        - other str: The instance name of the logger. A `ValueError` is raised if the logger has not been created.
+        - other str: The instance name of the logger. A `ValueError` is raised if the logger
+          has not been created.
         - None: The `print()` method is used for logging.
     level : int, optional
-        The logging level. This is only applicable when `logger` is a Logger object, 'current', or a named logger instance. The default is `logging.INFO`.
+        The logging level. This is only applicable when `logger` is a Logger object, 'current',
+        or a named logger instance. The default is `logging.INFO`.
     """
     if logger is None:
         print(msg)
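A minimal usage sketch of the `'current'` routing, assuming `ABLLogger` and `print_log` are exported from `abl.utils` as in this PR:

    from abl.utils import ABLLogger, print_log

    ABLLogger.get_instance("abl")                    # create a logger; it becomes the latest one
    print_log("training started", logger="current")  # routed to the latest ABLLogger
    print_log("plain message")                       # no logger given: falls back to print()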
@@ -6,7 +6,7 @@ from collections import OrderedDict
 from typing import Type, TypeVar

 _lock = threading.RLock()
-T = TypeVar('T')
+T = TypeVar("T")


 def _accquire_lock() -> None:
@@ -47,7 +47,7 @@ class ManagerMeta(type):
         cls._instance_dict = OrderedDict()
         params = inspect.getfullargspec(cls)
         params_names = params[0] if params[0] else []
-        assert 'name' in params_names, f'{cls} must have the `name` argument'
+        assert "name" in params_names, f"{cls} must have the `name` argument"
         super().__init__(*args)
@@ -72,9 +72,8 @@ class ManagerMixin(metaclass=ManagerMeta):
         name (str): Name of the instance. Defaults to ''.
     """

-    def __init__(self, name: str = '', **kwargs):
-        assert isinstance(name, str) and name, \
-            'name argument must be an non-empty string.'
+    def __init__(self, name: str = "", **kwargs):
+        assert isinstance(name, str) and name, "name argument must be a non-empty string."
         self._instance_name = name
@classmethod | @classmethod | ||||
@@ -102,8 +101,7 @@ class ManagerMixin(metaclass=ManagerMeta): | |||||
instance. | instance. | ||||
""" | """ | ||||
_accquire_lock() | _accquire_lock() | ||||
assert isinstance(name, str), \ | |||||
f'type of name should be str, but got {type(cls)}' | |||||
assert isinstance(name, str), f"type of name should be str, but got {type(cls)}" | |||||
instance_dict = cls._instance_dict # type: ignore | instance_dict = cls._instance_dict # type: ignore | ||||
# Get the instance by name. | # Get the instance by name. | ||||
if name not in instance_dict: | if name not in instance_dict: | ||||
@@ -111,9 +109,10 @@ class ManagerMixin(metaclass=ManagerMeta):
             instance_dict[name] = instance  # type: ignore
         elif kwargs:
             warnings.warn(
-                f'{cls} instance named of {name} has been created, '
-                'the method `get_instance` should not accept any other '
-                'arguments')
+                f"{cls} instance named {name} has been created; "
+                "the method `get_instance` should not accept any other "
+                "arguments"
+            )
         # Get latest instantiated instance or root instance.
         _release_lock()
         return instance_dict[name]
@@ -141,8 +140,9 @@ class ManagerMixin(metaclass=ManagerMeta):
         _accquire_lock()
         if not cls._instance_dict:
             raise RuntimeError(
-                f'Before calling {cls.__name__}.get_current_instance(), you '
-                'should call get_instance(name=xxx) at least once.')
+                f"Before calling {cls.__name__}.get_current_instance(), you "
+                "should call get_instance(name=xxx) at least once."
+            )
         name = next(iter(reversed(cls._instance_dict)))
         _release_lock()
         return cls._instance_dict[name]
@@ -221,60 +221,3 @@ def calculate_revision_num(parameter, total_length):
     if parameter < 0:
         raise ValueError("If parameter is an int, it must be non-negative.")
     return parameter
-
-
-if __name__ == "__main__":
-    A = np.array(
-        [
-            [
-                0.18401675,
-                0.06797526,
-                0.06797541,
-                0.06801736,
-                0.06797528,
-                0.06797526,
-                0.06818808,
-                0.06797527,
-                0.06800033,
-                0.06797526,
-                0.06797526,
-                0.06797526,
-                0.06797526,
-            ],
-            [
-                0.07223161,
-                0.0685229,
-                0.06852708,
-                0.17227574,
-                0.06852163,
-                0.07018146,
-                0.06860291,
-                0.06852849,
-                0.06852163,
-                0.0685216,
-                0.0685216,
-                0.06852174,
-                0.0685216,
-            ],
-            [
-                0.06794382,
-                0.0679436,
-                0.06794395,
-                0.06794346,
-                0.06794346,
-                0.18467231,
-                0.06794345,
-                0.06794871,
-                0.06794345,
-                0.06794345,
-                0.06794345,
-                0.06794345,
-                0.06794345,
-            ],
-        ],
-        dtype=np.float32,
-    )
-    B = [[0, 9, 3], [0, 11, 4]]
-    print(ori_confidence_dist(A, B))
-    print(confidence_dist(A, B))
@@ -0,0 +1,3 @@
+-r ../requirements.txt
+pytest
+pytest-cov
@@ -3,7 +3,7 @@ MNIST Addition

 MNIST Addition was first introduced in [1]. The inputs of this task are pairs of MNIST images and the outputs are their sums. The dataset looks like this:

-.. image:: ../img/image_1.jpg
+.. image:: ../img/Datasets_1.png
    :width: 350px
    :align: center
@@ -11,9 +11,4 @@ MNIST Addition was first introduced in [1] and the inputs of this task are pairs

 The ``gt_pseudo_label`` is only used to test the performance of the machine learning model and is not used in the training phase.

-In the Abductive Learning framework, the inference process is as follows:
-
-.. image:: ../img/image_2.jpg
-    :width: 700px
-
 [1] Robin Manhaeve, Sebastijan Dumancic, Angelika Kimmig, Thomas Demeester, and Luc De Raedt. DeepProbLog: Neural probabilistic logic programming. In Advances in Neural Information Processing Systems 31 (NeurIPS), pages 3749-3759, 2018.
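To make the data format concrete, a single training sample in this task can be pictured as below (variable names and array stand-ins are illustrative only, not the dataset loader's API):

    import numpy as np

    image_of_3 = np.zeros((28, 28))  # stand-ins for real 28x28 MNIST images
    image_of_9 = np.zeros((28, 28))

    X = [image_of_3, image_of_9]  # a pair of MNIST images (the raw inputs)
    gt_pseudo_label = [3, 9]      # per-image digit labels, used only for evaluation
    Y = 12                        # the known reasoning result: 3 + 9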
@@ -1,14 +1,12 @@
-import sys
 import os
 import re
+import sys

-if not "READTHEDOCS" in os.environ:
+if "READTHEDOCS" not in os.environ:
     sys.path.insert(0, os.path.abspath(".."))
     sys.path.append(os.path.abspath("./ABL/"))

-# from sphinx.locale import _
-from sphinx_rtd_theme import __version__
-
 project = "ABL"
 slug = re.sub(r"\W+", "-", project.lower())
@@ -48,8 +46,8 @@ pygments_style = "default"

 html_theme = "sphinx_rtd_theme"
 html_theme_options = {"display_version": True}
-html_static_path = ['_static']
-html_css_files = ['custom.css']
+html_static_path = ["_static"]
+html_css_files = ["custom.css"]

 # html_theme_path = ["../.."]
 # html_logo = "demo/static/logo-wordmark-light.svg"
 # html_show_sourcelink = True
@@ -1,11 +1,11 @@
 import os
 import os.path as osp
-import cv2
 import pickle
-import numpy as np
 import random
 from collections import defaultdict

+import cv2
+import numpy as np
 from torchvision.transforms import transforms

 CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
@@ -1,388 +0,0 @@
-# coding: utf-8
-# ================================================================#
-#   Copyright (C) 2021 Freecss All rights reserved.
-#
-#   File Name     :framework.py
-#   Author        :freecss
-#   Email         :karlfreecss@gmail.com
-#   Created Date  :2021/06/07
-#   Description   :
-#
-# ================================================================#
-
-import torch
-import torch.nn as nn
-import numpy as np
-import os
-
-from abl.utils.plog import INFO
-from abl.utils.utils import flatten, reform_idx
-from abl.learning.basic_nn import BasicNN, BasicDataset
-
-from utils import gen_mappings, mapping_res, remapping_res
-from models.nn import SymbolNetAutoencoder
-
-from torch.utils.data import RandomSampler
-from datasets.get_hed import get_pretrain_data
-
-
-def hed_pretrain(kb, cls, recorder):
-    cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list))
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-    if not os.path.exists("./weights/pretrain_weights.pth"):
-        INFO("Pretrain Start")
-        pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"])
-        pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y)
-        pretrain_data_loader = torch.utils.data.DataLoader(
-            pretrain_data, batch_size=64, shuffle=True
-        )
-
-        criterion = nn.MSELoss()
-        optimizer = torch.optim.RMSprop(
-            cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6
-        )
-        pretrain_model = BasicNN(
-            cls_autoencoder,
-            criterion,
-            optimizer,
-            device,
-            save_interval=1,
-            save_dir=recorder.save_dir,
-            num_epochs=10,
-            recorder=recorder,
-        )
-        pretrain_model.fit(pretrain_data_loader)
-        torch.save(
-            cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth"
-        )
-        cls.load_state_dict(cls_autoencoder.base_model.state_dict())
-    else:
-        cls.load_state_dict(torch.load("./weights/pretrain_weights.pth"))
-
-
-def _get_char_acc(model, X, consistent_pred_res, mapping):
-    original_pred_res = model.predict(X)["label"]
-    pred_res = flatten(mapping_res(original_pred_res, mapping))
-    INFO("Current model's output: ", pred_res)
-    INFO("Abduced labels: ", flatten(consistent_pred_res))
-    assert len(pred_res) == len(flatten(consistent_pred_res))
-    return sum(
-        [
-            pred_res[idx] == flatten(consistent_pred_res)[idx]
-            for idx in range(len(pred_res))
-        ]
-    ) / len(pred_res)
-
-
-def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
-    select_idx = RandomSampler(train_X_true, num_samples=select_num, replacement=False)
-    X = [train_X_true[idx] for idx in select_idx]
-    # original_pred_res = model.predict(X)['label']
-    pred_label = model.predict(X)["label"]
-
-    if mapping == None:
-        mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1])
-    else:
-        mappings = [mapping]
-
-    consistent_idx = []
-    consistent_pred_res = []
-
-    for m in mappings:
-        pred_pseudo_label = mapping_res(pred_label, m)
-        max_revision_num = 20
-        solution = abducer.zoopt_get_solution(
-            pred_label,
-            pred_pseudo_label,
-            [None] * len(pred_label),
-            [None] * len(pred_label),
-            max_revision_num,
-        )
-        all_address_flag = reform_idx(solution, pred_label)
-
-        consistent_idx_tmp = []
-        consistent_pred_res_tmp = []
-        for idx in range(len(pred_label)):
-            address_idx = [
-                i for i, flag in enumerate(all_address_flag[idx]) if flag != 0
-            ]
-            candidate = abducer.revise_by_idx([pred_pseudo_label[idx]], None, address_idx)
-            if len(candidate) > 0:
-                consistent_idx_tmp.append(idx)
-                consistent_pred_res_tmp.append(candidate[0][0])
-
-        if len(consistent_idx_tmp) > len(consistent_idx):
-            consistent_idx = consistent_idx_tmp
-            consistent_pred_res = consistent_pred_res_tmp
-            if len(mappings) > 1:
-                mapping = m
-
-    if len(consistent_idx) == 0:
-        return 0, 0, None
-
-    INFO("Train pool size is:", len(flatten(consistent_pred_res)))
-    INFO("Start to use abduced pseudo label to train model...")
-    model.train(
-        [X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping)
-    )
-
-    consistent_acc = len(consistent_idx) / select_num
-    char_acc = _get_char_acc(
-        model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping
-    )
-    INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc))
-    return consistent_acc, char_acc, mapping
def _remove_duplicate_rule(rule_dict): | |||||
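# For rules that share the same pair of operands, keep only the rule with | |||||
# the highest count in rule_dict. | |||||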
add_nums_dict = {} | |||||
for r in list(rule_dict): | |||||
add_nums = str(r.split("]")[0].split("[")[1]) + str( | |||||
r.split("]")[1].split("[")[1] | |||||
) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10' | |||||
if add_nums in add_nums_dict: | |||||
old_r = add_nums_dict[add_nums] | |||||
if rule_dict[r] >= rule_dict[old_r]: | |||||
rule_dict.pop(old_r) | |||||
add_nums_dict[add_nums] = r | |||||
else: | |||||
rule_dict.pop(r) | |||||
else: | |||||
add_nums_dict[add_nums] = r | |||||
return list(rule_dict) | |||||
def get_rules_from_data( | |||||
model, abducer, mapping, train_X_true, samples_per_rule, samples_num | |||||
): | |||||
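# Repeatedly sample small batches of equations, keep those the KB judges | |||||
# consistent, and abduce a rule set from them; across samples_num rounds, | |||||
# rules seen at least 5 times survive the frequency filter and deduplication. | |||||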
rules = [] | |||||
for _ in range(samples_num): | |||||
while True: | |||||
select_idx = np.random.randint(len(train_X_true), size=samples_per_rule) | |||||
X = [train_X_true[idx] for idx in select_idx] | |||||
original_pred_res = model.predict(X)["label"] | |||||
pred_res = mapping_res(original_pred_res, mapping) | |||||
consistent_idx = [] | |||||
consistent_pred_res = [] | |||||
for idx in range(len(pred_res)): | |||||
if abducer.kb.logic_forward([pred_res[idx]]): | |||||
consistent_idx.append(idx) | |||||
consistent_pred_res.append(pred_res[idx]) | |||||
if len(consistent_pred_res) != 0: | |||||
rule = abducer.abduce_rules(consistent_pred_res) | |||||
if rule is not None: | |||||
break | |||||
rules.append(rule) | |||||
all_rule_dict = {} | |||||
for rule in rules: | |||||
for r in rule: | |||||
all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1 | |||||
rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5} | |||||
rules = _remove_duplicate_rule(rule_dict) | |||||
return rules | |||||
def _get_consist_rule_acc(model, abducer, mapping, rules, X): | |||||
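# Return the fraction of equations whose predicted symbol sequence is | |||||
# consistent with the learned rules. | |||||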
cnt = 0 | |||||
for x in X: | |||||
original_pred_res = model.predict([x])["label"] | |||||
pred_res = flatten(mapping_res(original_pred_res, mapping)) | |||||
if abducer.kb.consist_rule(pred_res, rules): | |||||
cnt += 1 | |||||
return cnt / len(X) | |||||
def train_with_rule( | |||||
model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8 | |||||
): | |||||
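# Curriculum training: for each equation length, alternate abduction and | |||||
# training until the learned rules separate true and false validation | |||||
# equations well enough to advance to the next course; otherwise reload the | |||||
# last checkpoint and retry. | |||||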
train_X = train_data | |||||
val_X = val_data | |||||
samples_num = 50 | |||||
samples_per_rule = 3 | |||||
# Start curriculum training: each course covers two adjacent equation lengths | |||||
for equation_len in range(min_len, max_len): | |||||
INFO( | |||||
"============== equation_len: %d-%d ================" | |||||
% (equation_len, equation_len + 1) | |||||
) | |||||
train_X_true = train_X[1][equation_len] | |||||
train_X_false = train_X[0][equation_len] | |||||
val_X_true = val_X[1][equation_len] | |||||
val_X_false = val_X[0][equation_len] | |||||
train_X_true.extend(train_X[1][equation_len + 1]) | |||||
train_X_false.extend(train_X[0][equation_len + 1]) | |||||
val_X_true.extend(val_X[1][equation_len + 1]) | |||||
val_X_false.extend(val_X[0][equation_len + 1]) | |||||
condition_cnt = 0 | |||||
while True: | |||||
if equation_len == min_len: | |||||
mapping = None | |||||
# Abduce and train NN | |||||
consistent_acc, char_acc, mapping = abduce_and_train( | |||||
model, abducer, mapping, train_X_true, select_num | |||||
) | |||||
if consistent_acc == 0: | |||||
continue | |||||
# Check whether accuracies are high enough to start evaluating learned rules | |||||
if consistent_acc >= 0.9 and char_acc >= 0.9: | |||||
condition_cnt += 1 | |||||
else: | |||||
condition_cnt = 0 | |||||
# The condition has held for five consecutive rounds | |||||
if condition_cnt >= 5: | |||||
INFO("Now checking if we can go to next course") | |||||
rules = get_rules_from_data( | |||||
model, abducer, mapping, train_X_true, samples_per_rule, samples_num | |||||
) | |||||
INFO("Learned rules from data:", rules) | |||||
true_consist_rule_acc = _get_consist_rule_acc( | |||||
model, abducer, mapping, rules, val_X_true | |||||
) | |||||
false_consist_rule_acc = _get_consist_rule_acc( | |||||
model, abducer, mapping, rules, val_X_false | |||||
) | |||||
INFO( | |||||
"consist_rule_acc is %f, %f\n" | |||||
% (true_consist_rule_acc, false_consist_rule_acc) | |||||
) | |||||
# Decide whether to advance to the next course or to restart | |||||
if true_consist_rule_acc > 0.95 and false_consist_rule_acc < 0.1: | |||||
torch.save( | |||||
model.classifier_list[0].model.state_dict(), | |||||
"./weights/weights_%d.pth" % equation_len, | |||||
) | |||||
break | |||||
else: | |||||
if equation_len == min_len: | |||||
INFO("Final mapping is: ", mapping) | |||||
model.classifier_list[0].model.load_state_dict( | |||||
torch.load("./weights/pretrain_weights.pth") | |||||
) | |||||
else: | |||||
model.classifier_list[0].model.load_state_dict( | |||||
torch.load("./weights/weights_%d.pth" % (equation_len - 1)) | |||||
) | |||||
condition_cnt = 0 | |||||
INFO("Reload Model and retrain") | |||||
return model, mapping | |||||
def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8): | |||||
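# Final evaluation: abduce a rule set from training equations of each length, | |||||
# then report how often true test equations are consistent with the rules | |||||
# and false ones are not, for lengths 5 to 26. | |||||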
train_X = train_data | |||||
test_X = test_data | |||||
# Calculate how many equations should be selected for each length | |||||
# for each length, there are equation_samples_num[equation_len] rules | |||||
print("Now begin to train final mlp model") | |||||
equation_samples_num = [] | |||||
len_cnt = max_len - min_len + 1 | |||||
samples_num = 50 | |||||
equation_samples_num += [0] * min_len | |||||
equation_samples_num += [samples_num // len_cnt] * len_cnt | |||||
equation_samples_num[-1] += samples_num % len_cnt | |||||
assert sum(equation_samples_num) == samples_num | |||||
# Abduce rules | |||||
rules = [] | |||||
samples_per_rule = 3 | |||||
for equation_len in range(min_len, max_len + 1): | |||||
equation_rules = get_rules_from_data( | |||||
model, | |||||
abducer, | |||||
mapping, | |||||
train_X[1][equation_len], | |||||
samples_per_rule, | |||||
equation_samples_num[equation_len], | |||||
) | |||||
rules.extend(equation_rules) | |||||
rules = list(set(rules)) | |||||
INFO("Learned rules from data:", rules) | |||||
for equation_len in range(5, 27): | |||||
true_consist_rule_acc = _get_consist_rule_acc( | |||||
model, abducer, mapping, rules, test_X[1][equation_len] | |||||
) | |||||
false_consist_rule_acc = _get_consist_rule_acc( | |||||
model, abducer, mapping, rules, test_X[0][equation_len] | |||||
) | |||||
INFO( | |||||
"consist_rule_acc of testing length %d equations are %f, %f" | |||||
% (equation_len, true_consist_rule_acc, false_consist_rule_acc) | |||||
) | |||||
if __name__ == "__main__": | |||||
pass |
@@ -1,20 +1,18 @@ | |||||
import os | import os | ||||
from collections import defaultdict | from collections import defaultdict | ||||
from typing import Any, List | |||||
import torch | import torch | ||||
from torch.utils.data import DataLoader | |||||
from abl.reasoning import ReasonerBase | |||||
from abl.learning import ABLModel, BasicNN | |||||
from abl.bridge import SimpleBridge | from abl.bridge import SimpleBridge | ||||
from abl.dataset import RegressionDataset | |||||
from abl.evaluation import BaseMetric | from abl.evaluation import BaseMetric | ||||
from abl.dataset import BridgeDataset, RegressionDataset | |||||
from abl.learning import ABLModel, BasicNN | |||||
from abl.reasoning import ReasonerBase | |||||
from abl.structures import ListData | from abl.structures import ListData | ||||
from abl.utils import print_log | from abl.utils import print_log | ||||
from examples.hed.utils import gen_mappings, InfiniteSampler | |||||
from examples.models.nn import SymbolNetAutoencoder | |||||
from examples.hed.datasets.get_hed import get_pretrain_data | from examples.hed.datasets.get_hed import get_pretrain_data | ||||
from examples.hed.utils import InfiniteSampler, gen_mappings | |||||
from examples.models.nn import SymbolNetAutoencoder | |||||
class HEDBridge(SimpleBridge): | class HEDBridge(SimpleBridge): | ||||
@@ -95,7 +93,8 @@ class HEDBridge(SimpleBridge): | |||||
character_accuracy = self.model.valid(filtered_data_samples) | character_accuracy = self.model.valid(filtered_data_samples) | ||||
revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X) | revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X) | ||||
print_log( | print_log( | ||||
f"Revisible ratio is {revisible_ratio:.3f}, Character accuracy is {character_accuracy:.3f}", | |||||
f"Revisible ratio is {revisible_ratio:.3f}, Character \ | |||||
accuracy is {character_accuracy:.3f}", | |||||
logger="current", | logger="current", | ||||
) | ) | ||||
@@ -111,7 +110,8 @@ class HEDBridge(SimpleBridge): | |||||
false_ratio = self.calc_consistent_ratio(val_X_false, rule) | false_ratio = self.calc_consistent_ratio(val_X_false, rule) | ||||
print_log( | print_log( | ||||
f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio is {1 - false_ratio:.3f}", | |||||
f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio \ | |||||
is {1 - false_ratio:.3f}", | |||||
logger="current", | logger="current", | ||||
) | ) | ||||
@@ -143,7 +143,7 @@ class HEDBridge(SimpleBridge): | |||||
if len(consistent_instance) != 0: | if len(consistent_instance) != 0: | ||||
rule = self.reasoner.abduce_rules(consistent_instance) | rule = self.reasoner.abduce_rules(consistent_instance) | ||||
if rule != None: | |||||
if rule is not None: | |||||
rules.append(rule) | rules.append(rule) | ||||
break | break | ||||
@@ -214,7 +214,8 @@ class HEDBridge(SimpleBridge): | |||||
loss = self.model.train(filtered_sub_data_samples) | loss = self.model.train(filtered_sub_data_samples) | ||||
print_log( | print_log( | ||||
f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] model loss is {loss:.5f}", | |||||
f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] \ | |||||
model loss is {loss:.5f}", | |||||
logger="current", | logger="current", | ||||
) | ) | ||||
@@ -224,11 +225,11 @@ class HEDBridge(SimpleBridge): | |||||
condition_num = 0 | condition_num = 0 | ||||
if condition_num >= 5: | if condition_num >= 5: | ||||
print_log(f"Now checking if we can go to next course", logger="current") | |||||
print_log("Now checking if we can go to next course", logger="current") | |||||
rules = self.get_rules_from_data( | rules = self.get_rules_from_data( | ||||
data_samples, samples_per_rule=3, samples_num=50 | data_samples, samples_per_rule=3, samples_num=50 | ||||
) | ) | ||||
print_log(f"Learned rules from data: " + str(rules), logger="current") | |||||
print_log("Learned rules from data: " + str(rules), logger="current") | |||||
seems_good = self.check_rule_quality(rules, val_data, equation_len) | seems_good = self.check_rule_quality(rules, val_data, equation_len) | ||||
if seems_good: | if seems_good: | ||||
@@ -66,7 +66,8 @@ | |||||
" prolog_rules = prolog_result[0][\"X\"]\n", | " prolog_rules = prolog_result[0][\"X\"]\n", | ||||
" rules = [rule.value for rule in prolog_rules]\n", | " rules = [rule.value for rule in prolog_rules]\n", | ||||
" return rules\n", | " return rules\n", | ||||
" \n", | |||||
"\n", | |||||
"\n", | |||||
"class HedReasoner(ReasonerBase):\n", | "class HedReasoner(ReasonerBase):\n", | ||||
" def revise_at_idx(self, data_sample):\n", | " def revise_at_idx(self, data_sample):\n", | ||||
" revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n", | " revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n", | ||||
@@ -76,7 +77,9 @@ | |||||
" return candidate\n", | " return candidate\n", | ||||
"\n", | "\n", | ||||
" def zoopt_revision_score(self, symbol_num, data_sample, sol):\n", | " def zoopt_revision_score(self, symbol_num, data_sample, sol):\n", | ||||
" revision_flag = reform_list(list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label)\n", | |||||
" revision_flag = reform_list(\n", | |||||
" list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label\n", | |||||
" )\n", | |||||
" data_sample.revision_flag = revision_flag\n", | " data_sample.revision_flag = revision_flag\n", | ||||
"\n", | "\n", | ||||
" lefted_idxs = [i for i in range(len(data_sample.pred_idx))]\n", | " lefted_idxs = [i for i in range(len(data_sample.pred_idx))]\n", | ||||
@@ -108,7 +111,7 @@ | |||||
" for i in range(0, len(candidate_size)):\n", | " for i in range(0, len(candidate_size)):\n", | ||||
" score -= math.exp(-i) * candidate_size[i]\n", | " score -= math.exp(-i) * candidate_size[i]\n", | ||||
" return score\n", | " return score\n", | ||||
" \n", | |||||
"\n", | |||||
" def abduce(self, data_sample):\n", | " def abduce(self, data_sample):\n", | ||||
" symbol_num = data_sample.elements_num(\"pred_pseudo_label\")\n", | " symbol_num = data_sample.elements_num(\"pred_pseudo_label\")\n", | ||||
" max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n", | " max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n", | ||||
@@ -134,9 +137,8 @@ | |||||
" def abduce_rules(self, pred_res):\n", | " def abduce_rules(self, pred_res):\n", | ||||
" return self.kb.abduce_rules(pred_res)\n", | " return self.kb.abduce_rules(pred_res)\n", | ||||
"\n", | "\n", | ||||
"kb = HedKB(\n", | |||||
" pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\"\n", | |||||
")\n", | |||||
"\n", | |||||
"kb = HedKB(pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\")\n", | |||||
"reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=20)" | "reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=20)" | ||||
] | ] | ||||
}, | }, | ||||
@@ -1,69 +0,0 @@ | |||||
# coding: utf-8 | |||||
# ================================================================# | |||||
# Copyright (C) 2021 Freecss All rights reserved. | |||||
# | |||||
# File Name :share_example.py | |||||
# Author :freecss | |||||
# Email :karlfreecss@gmail.com | |||||
# Created Date :2021/06/07 | |||||
# Description : | |||||
# | |||||
# ================================================================# | |||||
import sys | |||||
sys.path.append("../") | |||||
from abl.utils.plog import logger, INFO | |||||
from abl.utils.utils import reduce_dimension | |||||
import torch.nn as nn | |||||
import torch | |||||
from abl.models.nn import LeNet5, SymbolNet | |||||
from abl.models.basic_model import BasicModel, BasicDataset | |||||
from abl.models.wabl_models import DecisionTree, WABLBasicModel | |||||
from sklearn.neighbors import KNeighborsClassifier | |||||
from abl.abducer.abducer_base import AbducerBase | |||||
from abl.abducer.kb import add_KB, HWF_KB, prolog_KB | |||||
from datasets.mnist_add.get_mnist_add import get_mnist_add | |||||
from datasets.hwf.get_hwf import get_hwf | |||||
from datasets.hed.get_hed import get_hed, split_equation | |||||
from abl import framework_hed_knn | |||||
def run_test(): | |||||
# kb = add_KB(True) | |||||
# kb = HWF_KB(True) | |||||
# abducer = AbducerBase(kb) | |||||
kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') | |||||
abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True) | |||||
recorder = logger() | |||||
total_train_data = get_hed(train=True) | |||||
train_data, val_data = split_equation(total_train_data, 3, 1) | |||||
test_data = get_hed(train=False) | |||||
# ========================= KNN model ============================ # | |||||
reduce_dimension(train_data) | |||||
reduce_dimension(val_data) | |||||
reduce_dimension(test_data) | |||||
base_model = KNeighborsClassifier(n_neighbors=3) | |||||
pretrain_data_X, pretrain_data_Y = framework_hed_knn.hed_pretrain(base_model) | |||||
model = WABLBasicModel(base_model, kb.pseudo_label_list) | |||||
model, mapping = framework_hed_knn.train_with_rule( | |||||
model, abducer, train_data, val_data, (pretrain_data_X, pretrain_data_Y), select_num=10, min_len=5, max_len=8 | |||||
) | |||||
framework_hed_knn.hed_test( | |||||
model, abducer, mapping, train_data, test_data, min_len=5, max_len=8 | |||||
) | |||||
# ============================ End =============================== # | |||||
recorder.dump() | |||||
return True | |||||
if __name__ == "__main__": | |||||
run_test() |
@@ -1,159 +0,0 @@ | |||||
import os.path as osp | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
from abl.evaluation import SemanticsMetric, SymbolMetric | |||||
from abl.learning import ABLModel, BasicNN | |||||
from abl.reasoning import PrologKB, ReasonerBase | |||||
from abl.utils import ABLLogger, print_log, reform_list | |||||
from examples.hed.datasets.get_hed import get_hed, split_equation | |||||
from examples.hed.hed_bridge import HEDBridge | |||||
from examples.models.nn import SymbolNet | |||||
# Build logger | |||||
print_log("Abductive Learning on the HED example.", logger="current") | |||||
# Retrieve the directory of the Log file and define the directory for saving the model weights. | |||||
log_dir = ABLLogger.get_current_instance().log_dir | |||||
weights_dir = osp.join(log_dir, "weights") | |||||
### Logic Part | |||||
# Initialize knowledge base and abducer | |||||
class HedKB(PrologKB): | |||||
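# HedKB extends PrologKB with two queries: consist_rule checks whether the | |||||
# given examples satisfy a rule set, and abduce_rules asks Prolog for rules | |||||
# consistent with the predicted instances. | |||||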
def __init__(self, pseudo_label_list, pl_file): | |||||
super().__init__(pseudo_label_list, pl_file) | |||||
def consist_rule(self, exs, rules): | |||||
rules = str(rules).replace("'", "") | |||||
return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 | |||||
def abduce_rules(self, pred_res): | |||||
prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) | |||||
if len(prolog_result) == 0: | |||||
return None | |||||
prolog_rules = prolog_result[0]["X"] | |||||
rules = [rule.value for rule in prolog_rules] | |||||
return rules | |||||
class HedReasoner(ReasonerBase): | |||||
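# HedReasoner drives abduction with ZOOpt: zoopt_revision_score greedily | |||||
# groups instances that can be revised consistently together, so solutions | |||||
# that yield such consistent groups receive lower (better) scores. | |||||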
def revise_at_idx(self, data_sample): | |||||
revision_idx = np.where(np.array(data_sample.flatten("revision_flag")) != 0)[0] | |||||
candidate = self.kb.revise_at_idx( | |||||
data_sample.pred_pseudo_label, data_sample.Y, revision_idx | |||||
) | |||||
return candidate | |||||
def zoopt_revision_score(self, symbol_num, data_sample, sol): | |||||
revision_flag = reform_list(list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label) | |||||
data_sample.revision_flag = revision_flag | |||||
lefted_idxs = [i for i in range(len(data_sample.pred_idx))] | |||||
candidate_size = [] | |||||
while lefted_idxs: | |||||
idxs = [] | |||||
idxs.append(lefted_idxs.pop(0)) | |||||
max_candidate_idxs = [] | |||||
found = False | |||||
for idx in range(-1, len(data_sample.pred_idx)): | |||||
if (not idx in idxs) and (idx >= 0): | |||||
idxs.append(idx) | |||||
candidate = self.revise_at_idx(data_sample[idxs]) | |||||
if len(candidate) == 0: | |||||
if len(idxs) > 1: | |||||
idxs.pop() | |||||
else: | |||||
if len(idxs) > len(max_candidate_idxs): | |||||
found = True | |||||
max_candidate_idxs = idxs.copy() | |||||
removed = [i for i in lefted_idxs if i in max_candidate_idxs] | |||||
if found: | |||||
candidate_size.append(len(removed) + 1) | |||||
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] | |||||
candidate_size.sort() | |||||
score = 0 | |||||
import math | |||||
for i in range(0, len(candidate_size)): | |||||
score -= math.exp(-i) * candidate_size[i] | |||||
return score | |||||
def abduce(self, data_sample): | |||||
symbol_num = data_sample.elements_num("pred_pseudo_label") | |||||
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) | |||||
solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num) | |||||
data_sample.revision_flag = reform_list( | |||||
solution.astype(np.int32), data_sample.pred_pseudo_label | |||||
) | |||||
abduced_pseudo_label = [] | |||||
for single_instance in data_sample: | |||||
single_instance.pred_pseudo_label = [single_instance.pred_pseudo_label] | |||||
candidates = self.revise_at_idx(single_instance) | |||||
if len(candidates) == 0: | |||||
abduced_pseudo_label.append([]) | |||||
else: | |||||
abduced_pseudo_label.append(candidates[0][0]) | |||||
data_sample.abduced_pseudo_label = abduced_pseudo_label | |||||
return abduced_pseudo_label | |||||
def abduce_rules(self, pred_res): | |||||
return self.kb.abduce_rules(pred_res) | |||||
import os | |||||
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||||
kb = HedKB( | |||||
pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "./datasets/learn_add.pl") | |||||
) | |||||
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=20) | |||||
### Machine Learning Part | |||||
# Build necessary components for BasicNN | |||||
cls = SymbolNet(num_classes=4) | |||||
criterion = nn.CrossEntropyLoss() | |||||
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001) | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
# Build BasicNN | |||||
# The function of BasicNN is to wrap NN models into the form of an sklearn estimator | |||||
base_model = BasicNN( | |||||
cls, | |||||
criterion, | |||||
optimizer, | |||||
device, | |||||
batch_size=32, | |||||
num_epochs=1, | |||||
save_interval=1, | |||||
save_dir=weights_dir, | |||||
) | |||||
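# A minimal usage sketch (hypothetical data, assuming the sklearn-style | |||||
# interface noted above): base_model.fit(X, y) trains the wrapped network, | |||||
# and base_model.predict(X) returns predicted symbol indices. | |||||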
# Build ABLModel | |||||
# The main function of the ABL model is to serialize data and | |||||
# provide a unified interface for different machine learning models | |||||
model = ABLModel(base_model) | |||||
### Metric | |||||
# Set up metrics | |||||
metric_list = [SymbolMetric(prefix="hed"), SemanticsMetric(prefix="hed")] | |||||
### Bridge Machine Learning and Logic Reasoning | |||||
bridge = HEDBridge(model, reasoner, metric_list) | |||||
### Dataset | |||||
total_train_data = get_hed(train=True) | |||||
train_data, val_data = split_equation(total_train_data, 3, 1) | |||||
test_data = get_hed(train=False) | |||||
### Train and Test | |||||
bridge.pretrain("examples/hed/weights") | |||||
bridge.train(train_data, val_data) |
@@ -1,6 +1,6 @@ | |||||
import numpy as np | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import numpy as np | |||||
import torch.utils.data.sampler as sampler | import torch.utils.data.sampler as sampler | ||||
@@ -13,7 +13,7 @@ class InfiniteSampler(sampler.Sampler): | |||||
while True: | while True: | ||||
order = np.random.permutation(self.num_samples) | order = np.random.permutation(self.num_samples) | ||||
# Walk the shuffled order in disjoint batch-size chunks | # Walk the shuffled order in disjoint batch-size chunks | ||||
for i in range(0, self.num_samples, self.batch_size): | for i in range(0, self.num_samples, self.batch_size): | ||||
yield order[i: i + self.batch_size] | |||||
yield order[i : i + self.batch_size] | |||||
def __len__(self): | def __len__(self): | ||||
@@ -58,7 +58,6 @@ def reduce_dimension(data): | |||||
for equation_len in range(5, 27): | for equation_len in range(5, 27): | ||||
equations = data[truth_value][equation_len] | equations = data[truth_value][equation_len] | ||||
reduced_equations = [ | reduced_equations = [ | ||||
[extract_feature(symbol_img) for symbol_img in equation] | |||||
for equation in equations | |||||
[extract_feature(symbol_img) for symbol_img in equation] for equation in equations | |||||
] | ] | ||||
data[truth_value][equation_len] = reduced_equations | |||||
data[truth_value][equation_len] = reduced_equations |
@@ -1,14 +1,12 @@ | |||||
import os | |||||
import json | import json | ||||
import os | |||||
from PIL import Image | from PIL import Image | ||||
from torchvision.transforms import transforms | from torchvision.transforms import transforms | ||||
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | ||||
img_transform = transforms.Compose( | |||||
[transforms.ToTensor(), transforms.Normalize((0.5,), (1,))] | |||||
) | |||||
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))]) | |||||
def get_data(file, get_pseudo_label): | def get_data(file, get_pseudo_label): | ||||
@@ -51,14 +51,13 @@ | |||||
"source": [ | "source": [ | ||||
"# Initialize knowledge base and reasoner\n", | "# Initialize knowledge base and reasoner\n", | ||||
"class HWF_KB(KBBase):\n", | "class HWF_KB(KBBase):\n", | ||||
"\n", | |||||
" def _valid_candidate(self, formula):\n", | " def _valid_candidate(self, formula):\n", | ||||
" if len(formula) % 2 == 0:\n", | " if len(formula) % 2 == 0:\n", | ||||
" return False\n", | " return False\n", | ||||
" for i in range(len(formula)):\n", | " for i in range(len(formula)):\n", | ||||
" if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:\n", | |||||
" if i % 2 == 0 and formula[i] not in [\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\"]:\n", | |||||
" return False\n", | " return False\n", | ||||
" if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:\n", | |||||
" if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"times\", \"div\"]:\n", | |||||
" return False\n", | " return False\n", | ||||
" return True\n", | " return True\n", | ||||
"\n", | "\n", | ||||
@@ -66,12 +65,17 @@ | |||||
" if not self._valid_candidate(formula):\n", | " if not self._valid_candidate(formula):\n", | ||||
" return np.inf\n", | " return np.inf\n", | ||||
" mapping = {str(i): str(i) for i in range(1, 10)}\n", | " mapping = {str(i): str(i) for i in range(1, 10)}\n", | ||||
" mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})\n", | |||||
" mapping.update({\"+\": \"+\", \"-\": \"-\", \"times\": \"*\", \"div\": \"/\"})\n", | |||||
" formula = [mapping[f] for f in formula]\n", | " formula = [mapping[f] for f in formula]\n", | ||||
" return eval(''.join(formula))\n", | |||||
" return eval(\"\".join(formula))\n", | |||||
"\n", | |||||
"\n", | "\n", | ||||
"kb = HWF_KB(pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], max_err=1e-10, use_cache=False)\n", | |||||
"reasoner = ReasonerBase(kb, dist_func='confidence')" | |||||
"kb = HWF_KB(\n", | |||||
" pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"times\", \"div\"],\n", | |||||
" max_err=1e-10,\n", | |||||
" use_cache=False,\n", | |||||
")\n", | |||||
"reasoner = ReasonerBase(kb, dist_func=\"confidence\")" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -122,7 +126,7 @@ | |||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"# Initialize ABL model\n", | "# Initialize ABL model\n", | ||||
"# The main function of the ABL model is to serialize data and \n", | |||||
"# The main function of the ABL model is to serialize data and\n", | |||||
"# provide a unified interface for different machine learning models\n", | "# provide a unified interface for different machine learning models\n", | ||||
"model = ABLModel(base_model)" | "model = ABLModel(base_model)" | ||||
] | ] | ||||
@@ -80,7 +80,7 @@ | |||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"# Build ABLModel\n", | "# Build ABLModel\n", | ||||
"# The main function of the ABL model is to serialize data and \n", | |||||
"# The main function of the ABL model is to serialize data and\n", | |||||
"# provide a unified interface for different machine learning models\n", | "# provide a unified interface for different machine learning models\n", | ||||
"model = ABLModel(base_model)" | "model = ABLModel(base_model)" | ||||
] | ] | ||||
@@ -11,8 +11,8 @@ | |||||
# ================================================================# | # ================================================================# | ||||
import torch | |||||
import numpy as np | import numpy as np | ||||
import torch | |||||
from torch import nn | from torch import nn | ||||
@@ -84,9 +84,7 @@ class SymbolNetAutoencoder(nn.Module): | |||||
self.base_model = SymbolNet(num_classes, image_size) | self.base_model = SymbolNet(num_classes, image_size) | ||||
self.softmax = nn.Softmax(dim=1) | self.softmax = nn.Softmax(dim=1) | ||||
self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) | self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) | ||||
self.fc2 = nn.Sequential( | |||||
nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU() | |||||
) | |||||
self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()) | |||||
def forward(self, x): | def forward(self, x): | ||||
x = self.base_model(x) | x = self.base_model(x) | ||||
@@ -1,4 +1,5 @@ | |||||
import os | import os | ||||
from setuptools import find_packages, setup | from setuptools import find_packages, setup | ||||
@@ -27,7 +28,13 @@ here = os.path.abspath(os.path.dirname(__file__)) | |||||
try: | try: | ||||
with open(os.path.join(here, "requirements.txt"), encoding="utf-8") as f: | with open(os.path.join(here, "requirements.txt"), encoding="utf-8") as f: | ||||
REQUIRED = f.read().split("\n") | REQUIRED = f.read().split("\n") | ||||
except: | |||||
except FileNotFoundError: | |||||
# Handle the case where the file does not exist | |||||
print("requirements.txt file not found.") | |||||
REQUIRED = [] | |||||
except Exception as e: | |||||
# Handle other possible exceptions | |||||
print(f"An error occurred: {e}") | |||||
REQUIRED = [] | REQUIRED = [] | ||||
EXTRAS = { | EXTRAS = { | ||||
@@ -64,7 +71,7 @@ if __name__ == "__main__": | |||||
install_requires=REQUIRED, | install_requires=REQUIRED, | ||||
extras_require=EXTRAS, | extras_require=EXTRAS, | ||||
classifiers=[ | classifiers=[ | ||||
'Development Status :: 3 - Alpha', | |||||
"Development Status :: 3 - Alpha", | |||||
"Intended Audience :: Science/Research", | "Intended Audience :: Science/Research", | ||||
"Intended Audience :: Developers", | "Intended Audience :: Developers", | ||||
"Programming Language :: Python", | "Programming Language :: Python", | ||||
@@ -74,4 +81,3 @@ if __name__ == "__main__": | |||||
"Programming Language :: Python :: 3.8", | "Programming Language :: Python :: 3.8", | ||||
], | ], | ||||
) | ) | ||||
@@ -4,10 +4,11 @@ import torch.nn as nn | |||||
import torch.optim as optim | import torch.optim as optim | ||||
from abl.learning import BasicNN | from abl.learning import BasicNN | ||||
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner | |||||
from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner | |||||
from abl.structures import ListData | from abl.structures import ListData | ||||
from examples.models.nn import LeNet5 | from examples.models.nn import LeNet5 | ||||
# Fixture for BasicNN instance | # Fixture for BasicNN instance | ||||
@pytest.fixture | @pytest.fixture | ||||
def basic_nn_instance(): | def basic_nn_instance(): | ||||
@@ -16,6 +17,7 @@ def basic_nn_instance(): | |||||
optimizer = optim.Adam(model.parameters()) | optimizer = optim.Adam(model.parameters()) | ||||
return BasicNN(model, criterion, optimizer) | return BasicNN(model, criterion, optimizer) | ||||
# Fixture for base_model instance | # Fixture for base_model instance | ||||
@pytest.fixture | @pytest.fixture | ||||
def base_model_instance(): | def base_model_instance(): | ||||
@@ -24,6 +26,7 @@ def base_model_instance(): | |||||
optimizer = optim.Adam(model.parameters()) | optimizer = optim.Adam(model.parameters()) | ||||
return BasicNN(model, criterion, optimizer) | return BasicNN(model, criterion, optimizer) | ||||
# Fixture for ListData instance | # Fixture for ListData instance | ||||
@pytest.fixture | @pytest.fixture | ||||
def list_data_instance(): | def list_data_instance(): | ||||
@@ -37,47 +40,71 @@ def list_data_instance(): | |||||
@pytest.fixture | @pytest.fixture | ||||
def data_samples_add(): | def data_samples_add(): | ||||
# favor 1 in first one | # favor 1 in first one | ||||
prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||||
prob1 = [ | |||||
[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
] | |||||
# favor 7 in first one | # favor 7 in first one | ||||
prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||||
prob2 = [ | |||||
[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
] | |||||
data_samples_add = ListData() | data_samples_add = ListData() | ||||
data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] | data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] | ||||
data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] | data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] | ||||
data_samples_add.Y = [8, 8, 17, 10] | data_samples_add.Y = [8, 8, 17, 10] | ||||
return data_samples_add | return data_samples_add | ||||
@pytest.fixture | @pytest.fixture | ||||
def data_samples_hwf(): | def data_samples_hwf(): | ||||
data_samples_hwf = ListData() | data_samples_hwf = ListData() | ||||
data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]] | |||||
data_samples_hwf.pred_pseudo_label = [ | |||||
["5", "+", "2"], | |||||
["5", "+", "9"], | |||||
["5", "+", "9"], | |||||
["5", "-", "8", "8", "8"], | |||||
] | |||||
data_samples_hwf.pred_prob = [None, None, None, None] | data_samples_hwf.pred_prob = [None, None, None, None] | ||||
data_samples_hwf.Y = [3, 64, 65, 3.17] | data_samples_hwf.Y = [3, 64, 65, 3.17] | ||||
return data_samples_hwf | return data_samples_hwf | ||||
class AddKB(KBBase): | class AddKB(KBBase): | ||||
def __init__(self, pseudo_label_list=list(range(10)), | |||||
use_cache=False): | |||||
def __init__(self, pseudo_label_list=list(range(10)), use_cache=False): | |||||
super().__init__(pseudo_label_list, use_cache=use_cache) | super().__init__(pseudo_label_list, use_cache=use_cache) | ||||
def logic_forward(self, nums): | def logic_forward(self, nums): | ||||
return sum(nums) | return sum(nums) | ||||
class AddGroundKB(GroundKB): | class AddGroundKB(GroundKB): | ||||
def __init__(self, pseudo_label_list=list(range(10)), | |||||
GKB_len_list=[2]): | |||||
def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): | |||||
super().__init__(pseudo_label_list, GKB_len_list) | super().__init__(pseudo_label_list, GKB_len_list) | ||||
def logic_forward(self, nums): | def logic_forward(self, nums): | ||||
return sum(nums) | return sum(nums) | ||||
class HwfKB(KBBase): | class HwfKB(KBBase): | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||||
"+", "-", "times", "div"], | |||||
pseudo_label_list=[ | |||||
"1", | |||||
"2", | |||||
"3", | |||||
"4", | |||||
"5", | |||||
"6", | |||||
"7", | |||||
"8", | |||||
"9", | |||||
"+", | |||||
"-", | |||||
"times", | |||||
"div", | |||||
], | |||||
max_err=1e-3, | max_err=1e-3, | ||||
use_cache=False, | use_cache=False, | ||||
): | ): | ||||
@@ -87,7 +114,17 @@ class HwfKB(KBBase): | |||||
if len(formula) % 2 == 0: | if len(formula) % 2 == 0: | ||||
return False | return False | ||||
for i in range(len(formula)): | for i in range(len(formula)): | ||||
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||||
if i % 2 == 0 and formula[i] not in [ | |||||
"1", | |||||
"2", | |||||
"3", | |||||
"4", | |||||
"5", | |||||
"6", | |||||
"7", | |||||
"8", | |||||
"9", | |||||
]: | |||||
return False | return False | ||||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | ||||
return False | return False | ||||
@@ -100,7 +137,8 @@ class HwfKB(KBBase): | |||||
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | ||||
formula = [mapping[f] for f in formula] | formula = [mapping[f] for f in formula] | ||||
return eval("".join(formula)) | return eval("".join(formula)) | ||||
class HedKB(PrologKB): | class HedKB(PrologKB): | ||||
def __init__(self, pseudo_label_list, pl_file): | def __init__(self, pseudo_label_list, pl_file): | ||||
super().__init__(pseudo_label_list, pl_file) | super().__init__(pseudo_label_list, pl_file) | ||||
@@ -110,24 +148,28 @@ class HedKB(PrologKB): | |||||
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) | pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) | ||||
return len(list(self.prolog.query(pl_query))) != 0 | return len(list(self.prolog.query(pl_query))) != 0 | ||||
@pytest.fixture | @pytest.fixture | ||||
def kb_add(): | def kb_add(): | ||||
return AddKB() | return AddKB() | ||||
@pytest.fixture | @pytest.fixture | ||||
def kb_add_cache(): | def kb_add_cache(): | ||||
return AddKB(use_cache=True) | |||||
return AddKB(use_cache=True) | |||||
@pytest.fixture | @pytest.fixture | ||||
def kb_add_ground(): | def kb_add_ground(): | ||||
return AddGroundKB() | return AddGroundKB() | ||||
@pytest.fixture | @pytest.fixture | ||||
def kb_add_prolog(): | def kb_add_prolog(): | ||||
kb = PrologKB(pseudo_label_list=list(range(10)), | |||||
pl_file="examples/mnist_add/datasets/add.pl") | |||||
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl") | |||||
return kb | return kb | ||||
@pytest.fixture | @pytest.fixture | ||||
def kb_hed(): | def kb_hed(): | ||||
kb = HedKB( | kb = HedKB( | ||||
@@ -136,6 +178,7 @@ def kb_hed(): | |||||
) | ) | ||||
return kb | return kb | ||||
@pytest.fixture | @pytest.fixture | ||||
def reasoner_instance(kb_add): | def reasoner_instance(kb_add): | ||||
return Reasoner(kb_add, "confidence") | |||||
return Reasoner(kb_add, "confidence") |
@@ -1,8 +1,9 @@ | |||||
from unittest.mock import Mock, create_autospec | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
from abl.learning import ABLModel | from abl.learning import ABLModel | ||||
from unittest.mock import Mock, create_autospec | |||||
class TestABLModel(object): | class TestABLModel(object): | ||||
@@ -1,63 +1,66 @@ | |||||
import pytest | import pytest | ||||
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner | |||||
from abl.reasoning import PrologKB, Reasoner | |||||
class TestKBBase(object): | class TestKBBase(object): | ||||
def test_init(self, kb_add): | def test_init(self, kb_add): | ||||
assert kb_add.pseudo_label_list == list(range(10)) | assert kb_add.pseudo_label_list == list(range(10)) | ||||
def test_init_cache(self, kb_add_cache): | def test_init_cache(self, kb_add_cache): | ||||
assert kb_add_cache.pseudo_label_list == list(range(10)) | assert kb_add_cache.pseudo_label_list == list(range(10)) | ||||
assert kb_add_cache.use_cache == True | |||||
assert kb_add_cache.use_cache is True | |||||
def test_logic_forward(self, kb_add): | def test_logic_forward(self, kb_add): | ||||
result = kb_add.logic_forward([1, 2]) | result = kb_add.logic_forward([1, 2]) | ||||
assert result == 3 | assert result == 3 | ||||
def test_revise_at_idx(self, kb_add): | def test_revise_at_idx(self, kb_add): | ||||
result = kb_add.revise_at_idx([1, 2], 2, [0]) | result = kb_add.revise_at_idx([1, 2], 2, [0]) | ||||
assert result == [[0, 2]] | assert result == [[0, 2]] | ||||
def test_abduce_candidates(self, kb_add): | def test_abduce_candidates(self, kb_add): | ||||
result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2, | |||||
require_more_revision=0) | |||||
result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2, require_more_revision=0) | |||||
assert result == [[1, 0]] | assert result == [[1, 0]] | ||||
class TestGroundKB(object): | class TestGroundKB(object): | ||||
def test_init(self, kb_add_ground): | def test_init(self, kb_add_ground): | ||||
assert kb_add_ground.pseudo_label_list == list(range(10)) | assert kb_add_ground.pseudo_label_list == list(range(10)) | ||||
assert kb_add_ground.GKB_len_list == [2] | assert kb_add_ground.GKB_len_list == [2] | ||||
assert kb_add_ground.GKB | assert kb_add_ground.GKB | ||||
def test_logic_forward_ground(self, kb_add_ground): | def test_logic_forward_ground(self, kb_add_ground): | ||||
result = kb_add_ground.logic_forward([1, 2]) | result = kb_add_ground.logic_forward([1, 2]) | ||||
assert result == 3 | assert result == 3 | ||||
def test_abduce_candidates_ground(self, kb_add_ground): | def test_abduce_candidates_ground(self, kb_add_ground): | ||||
result = kb_add_ground.abduce_candidates([1, 2], 1, max_revision_num=2, | |||||
require_more_revision=0) | |||||
result = kb_add_ground.abduce_candidates( | |||||
[1, 2], 1, max_revision_num=2, require_more_revision=0 | |||||
) | |||||
assert result == [(1, 0)] | assert result == [(1, 0)] | ||||
class TestPrologKB(object): | |||||
class TestPrologKB(object): | |||||
def test_init_pl1(self, kb_add_prolog): | def test_init_pl1(self, kb_add_prolog): | ||||
assert kb_add_prolog.pseudo_label_list == list(range(10)) | assert kb_add_prolog.pseudo_label_list == list(range(10)) | ||||
assert kb_add_prolog.pl_file == "examples/mnist_add/datasets/add.pl" | assert kb_add_prolog.pl_file == "examples/mnist_add/datasets/add.pl" | ||||
def test_init_pl2(self, kb_hed): | def test_init_pl2(self, kb_hed): | ||||
assert kb_hed.pseudo_label_list == [1, 0, "+", "="] | assert kb_hed.pseudo_label_list == [1, 0, "+", "="] | ||||
assert kb_hed.pl_file == "examples/hed/datasets/learn_add.pl" | assert kb_hed.pl_file == "examples/hed/datasets/learn_add.pl" | ||||
def test_prolog_file_not_exist(self): | def test_prolog_file_not_exist(self): | ||||
pseudo_label_list = [1, 2] | pseudo_label_list = [1, 2] | ||||
non_existing_file = "path/to/non_existing_file.pl" | non_existing_file = "path/to/non_existing_file.pl" | ||||
with pytest.raises(FileNotFoundError) as excinfo: | with pytest.raises(FileNotFoundError) as excinfo: | ||||
PrologKB(pseudo_label_list=pseudo_label_list, | |||||
pl_file=non_existing_file) | |||||
PrologKB(pseudo_label_list=pseudo_label_list, pl_file=non_existing_file) | |||||
assert non_existing_file in str(excinfo.value) | assert non_existing_file in str(excinfo.value) | ||||
def test_logic_forward_pl1(self, kb_add_prolog): | def test_logic_forward_pl1(self, kb_add_prolog): | ||||
result = kb_add_prolog.logic_forward([1, 2]) | result = kb_add_prolog.logic_forward([1, 2]) | ||||
assert result == 3 | assert result == 3 | ||||
def test_logic_forward_pl2(self, kb_hed): | def test_logic_forward_pl2(self, kb_hed): | ||||
consist_exs = [ | consist_exs = [ | ||||
[1, 1, "+", 0, "=", 1, 1], | [1, 1, "+", 0, "=", 1, 1], | ||||
@@ -70,21 +73,24 @@ class TestPrologKB(object): | |||||
[0, "+", 0, "=", 0], | [0, "+", 0, "=", 0], | ||||
[0, "+", 0, "=", 1], | [0, "+", 0, "=", 1], | ||||
] | ] | ||||
assert kb_hed.logic_forward(consist_exs) == True | |||||
assert kb_hed.logic_forward(inconsist_exs) == False | |||||
assert kb_hed.logic_forward(consist_exs) is True | |||||
assert kb_hed.logic_forward(inconsist_exs) is False | |||||
def test_revise_at_idx(self, kb_add_prolog): | def test_revise_at_idx(self, kb_add_prolog): | ||||
result = kb_add_prolog.revise_at_idx([1, 2], 2, [0]) | result = kb_add_prolog.revise_at_idx([1, 2], 2, [0]) | ||||
assert result == [[0, 2]] | assert result == [[0, 2]] | ||||
class TestReasoner(object): | class TestReasoner(object): | ||||
def test_reasoner_init(self, reasoner_instance): | def test_reasoner_init(self, reasoner_instance): | ||||
assert reasoner_instance.dist_func == "confidence" | assert reasoner_instance.dist_func == "confidence" | ||||
def test_invalid_dist_func(self, kb_add): | def test_invalid_dist_func(self, kb_add): | ||||
with pytest.raises(NotImplementedError) as excinfo: | with pytest.raises(NotImplementedError) as excinfo: | ||||
Reasoner(kb_add, "invalid_dist_func") | Reasoner(kb_add, "invalid_dist_func") | ||||
assert "Valid options for dist_func include \"hamming\" and \"confidence\"" in str(excinfo.value) | |||||
assert 'Valid options for dist_func include "hamming" and "confidence"' in str( | |||||
excinfo.value | |||||
) | |||||
class TestBatchAbduce(object): | class TestBatchAbduce(object): | ||||
@@ -95,8 +101,18 @@ class TestBatchAbduce(object): | |||||
reasoner4 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=1) | reasoner4 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=1) | ||||
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | ||||
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | ||||
assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]] | |||||
assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]] | |||||
assert reasoner3.batch_abduce(data_samples_add) == [ | |||||
[1, 7], | |||||
[7, 1], | |||||
[8, 9], | |||||
[1, 9], | |||||
] | |||||
assert reasoner4.batch_abduce(data_samples_add) == [ | |||||
[1, 7], | |||||
[7, 1], | |||||
[8, 9], | |||||
[7, 3], | |||||
] | |||||
def test_batch_abduce_ground(self, kb_add_ground, data_samples_add): | def test_batch_abduce_ground(self, kb_add_ground, data_samples_add): | ||||
reasoner1 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=0) | reasoner1 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=0) | ||||
@@ -105,8 +121,18 @@ class TestBatchAbduce(object): | |||||
reasoner4 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=1) | reasoner4 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=1) | ||||
assert reasoner1.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)] | assert reasoner1.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)] | ||||
assert reasoner2.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)] | assert reasoner2.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)] | ||||
assert reasoner3.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (1, 9)] | |||||
assert reasoner4.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (7, 3)] | |||||
assert reasoner3.batch_abduce(data_samples_add) == [ | |||||
(1, 7), | |||||
(7, 1), | |||||
(8, 9), | |||||
(1, 9), | |||||
] | |||||
assert reasoner4.batch_abduce(data_samples_add) == [ | |||||
(1, 7), | |||||
(7, 1), | |||||
(8, 9), | |||||
(7, 3), | |||||
] | |||||
def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add): | def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add): | ||||
reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0) | reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0) | ||||
@@ -115,35 +141,73 @@ class TestBatchAbduce(object): | |||||
reasoner4 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=1) | reasoner4 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=1) | ||||
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | ||||
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | ||||
assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]] | |||||
assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]] | |||||
assert reasoner3.batch_abduce(data_samples_add) == [ | |||||
[1, 7], | |||||
[7, 1], | |||||
[8, 9], | |||||
[1, 9], | |||||
] | |||||
assert reasoner4.batch_abduce(data_samples_add) == [ | |||||
[1, 7], | |||||
[7, 1], | |||||
[8, 9], | |||||
[7, 3], | |||||
] | |||||
def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add): | def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add): | ||||
reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1) | reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1) | ||||
reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2) | reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2) | ||||
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||||
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]] | |||||
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||||
assert reasoner2.batch_abduce(data_samples_add) == [ | |||||
[1, 7], | |||||
[7, 1], | |||||
[8, 9], | |||||
[7, 3], | |||||
] | |||||
def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf): | def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf): | ||||
reasoner1 = Reasoner(kb_hwf1, "hamming", max_revision=3, require_more_revision=0) | reasoner1 = Reasoner(kb_hwf1, "hamming", max_revision=3, require_more_revision=0) | ||||
reasoner2 = Reasoner(kb_hwf1, "hamming", max_revision=0.5, require_more_revision=0) | reasoner2 = Reasoner(kb_hwf1, "hamming", max_revision=0.5, require_more_revision=0) | ||||
reasoner3 = Reasoner(kb_hwf1, "hamming", max_revision=0.9, require_more_revision=0) | reasoner3 = Reasoner(kb_hwf1, "hamming", max_revision=0.9, require_more_revision=0) | ||||
res = reasoner1.batch_abduce(data_samples_hwf) | res = reasoner1.batch_abduce(data_samples_hwf) | ||||
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']] | |||||
assert res == [ | |||||
["1", "+", "2"], | |||||
["8", "times", "8"], | |||||
[], | |||||
["4", "-", "6", "div", "8"], | |||||
] | |||||
res = reasoner2.batch_abduce(data_samples_hwf) | res = reasoner2.batch_abduce(data_samples_hwf) | ||||
assert res == [['1', '+', '2'], [], [], []] | |||||
assert res == [["1", "+", "2"], [], [], []] | |||||
res = reasoner3.batch_abduce(data_samples_hwf) | res = reasoner3.batch_abduce(data_samples_hwf) | ||||
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']] | |||||
assert res == [ | |||||
["1", "+", "2"], | |||||
["8", "times", "8"], | |||||
[], | |||||
["4", "-", "6", "div", "8"], | |||||
] | |||||
def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf): | def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf): | ||||
reasoner1 = Reasoner(kb_hwf2, "hamming", max_revision=3, require_more_revision=0) | reasoner1 = Reasoner(kb_hwf2, "hamming", max_revision=3, require_more_revision=0) | ||||
reasoner2 = Reasoner(kb_hwf2, "hamming", max_revision=0.5, require_more_revision=0) | reasoner2 = Reasoner(kb_hwf2, "hamming", max_revision=0.5, require_more_revision=0) | ||||
reasoner3 = Reasoner(kb_hwf2, "hamming", max_revision=0.9, require_more_revision=0) | reasoner3 = Reasoner(kb_hwf2, "hamming", max_revision=0.9, require_more_revision=0) | ||||
res = reasoner1.batch_abduce(data_samples_hwf) | res = reasoner1.batch_abduce(data_samples_hwf) | ||||
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']] | |||||
assert res == [ | |||||
["1", "+", "2"], | |||||
["7", "times", "9"], | |||||
["8", "times", "8"], | |||||
["5", "-", "8", "div", "8"], | |||||
] | |||||
res = reasoner2.batch_abduce(data_samples_hwf) | res = reasoner2.batch_abduce(data_samples_hwf) | ||||
assert res == [['1', '+', '2'], ['7', 'times', '9'], [], ['5', '-', '8', 'div', '8']] | |||||
assert res == [ | |||||
["1", "+", "2"], | |||||
["7", "times", "9"], | |||||
[], | |||||
["5", "-", "8", "div", "8"], | |||||
] | |||||
res = reasoner3.batch_abduce(data_samples_hwf) | res = reasoner3.batch_abduce(data_samples_hwf) | ||||
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']] | |||||
assert res == [ | |||||
["1", "+", "2"], | |||||
["7", "times", "9"], | |||||
["8", "times", "8"], | |||||
["5", "-", "8", "div", "8"], | |||||
] |