Compare commits

...

33 Commits

Author SHA1 Message Date
  Gao Enhao c3147dcce7 [MNT] delete redundant import 1 year ago
  Gao Enhao 3d878135b0 [MNT] accelerate __len__ 1 year ago
  Gao Enhao 4f388e1e2b [MNT] resolve comments in basic_nn and abl_model 1 year ago
  Gao Enhao 831aa855e7 [FIX] fix bugs and run mnist and hwf successfully 1 year ago
  Gao Enhao 8fe2c68cc5 [ENH] move cache to reasoner 1 year ago
  Gao Enhao 2951e5fe5a [ENH] add search engine 1 year ago
  Gao Enhao 99c9aa37b1 [ENH] add lru mechanism to Cache 1 year ago
  Gao Enhao 7a825e1033 [MNT] delete redundant test code 1 year ago
  Gao Enhao 6022f702d9 [MNT] add repr to kbs and change base to GKB 1 year ago
  Gao Enhao 15d79ee1f2 [MNT] change entail to check_equal 1 year ago
  Gao Enhao 0a6429b410 [MNT] change np.inf to None 1 year ago
  Gao Enhao 41d52ef6c4 [MNT] resolve comments in reasoner.py 1 year ago
  Gao Enhao 5a740b418a [MNT] delete redundant setting of metainfo 1 year ago
  Gao Enhao d6d29d632b [MNT] resolve some comments in kbs 1 year ago
  Gao Enhao b18e42b0e0 [MNT] resolve comments in abl_model.py 1 year ago
  Gao Enhao 0cc75e11dd [MNT] use deepcopy when set_metainfo 1 year ago
  Gao Enhao 82c9936853 [MNT] add type check to __setattr__ 1 year ago
  Gao Enhao d6f310c406 [MNT] sort import 1 year ago
  Gao Enhao aa88cefe56 [MNT] sort import 1 year ago
  Gao Enhao 29713d6f7d [ENH] run hwf and mnist_add successfully 1 year ago
  Gao Enhao 23bcbf1e01 [ENH] create Cache 1 year ago
  Gao Enhao ddf7b7a3e1 [ENH] create ListData 1 year ago
  Gao Enhao 9ae76d6552 [ENH] reformat interface of kb 1 year ago
  Gao Enhao 5111f03f7d [ENH] add abstract data interface to abl_model 1 year ago
  Gao Enhao 3a3b0ee575 [ENH] add abstract data interface to reasoner 1 year ago
  Gao Enhao 53bb17bf37 [ENH] add abstract data interface to bridge 1 year ago
  Gao Enhao 25e527f0fe [ENH] add abstract data interface to evaluation 1 year ago
  Gao Enhao 7e79dccd6e [MNT] add CURRENT_DIR 1 year ago
  Gao Enhao 92c466a7e6 [MNT] sort import 1 year ago
  Gao Enhao d80313b213 [ENH] remove dataset creation in predict 1 year ago
  Gao Enhao 95fa385ce6 [ENH] sort import 1 year ago
  Gao Enhao ebb4d0090e [ENH] use abstract data interface in SimpleBridge 1 year ago
  Gao Enhao b043bf6aee [FIX] fix a few bugs 1 year ago
45 changed files with 2417 additions and 1067 deletions
  1. +1 -1  abl/__init__.py
  2. +26 -14  abl/bridge/base_bridge.py
  3. +82 -69  abl/bridge/simple_bridge.py
  4. +2 -1  abl/dataset/__init__.py
  5. +2 -1  abl/dataset/bridge_dataset.py
  6. +2 -1  abl/dataset/classification_dataset.py
  7. +56 -0  abl/dataset/prediction_dataset.py
  8. +2 -1  abl/dataset/regression_dataset.py
  9. +1 -1  abl/evaluation/__init__.py
  10. +2 -2  abl/evaluation/base_metric.py
  11. +8 -11  abl/evaluation/semantics_metric.py
  12. +2 -1  abl/evaluation/symbol_metric.py
  13. +36 -63  abl/learning/abl_model.py
  14. +56 -24  abl/learning/basic_nn.py
  15. +5 -1  abl/reasoning/__init__.py
  16. +14 -0  abl/reasoning/base_kb.py
  17. +60 -0  abl/reasoning/ground_kb.py
  18. +0 -222  abl/reasoning/kb.py
  19. +44 -0  abl/reasoning/prolog_based_kb.py
  20. +164 -530  abl/reasoning/reasoner.py
  21. +49 -0  abl/reasoning/search_based_kb.py
  22. +2 -0  abl/reasoning/search_engine/__init__.py
  23. +13 -0  abl/reasoning/search_engine/base_search_engine.py
  24. +28 -0  abl/reasoning/search_engine/bfs.py
  25. +42 -0  abl/reasoning/search_engine/zoopt.py
  26. +2 -0  abl/structures/__init__.py
  27. +629 -0  abl/structures/base_data_element.py
  28. +321 -0  abl/structures/list_data.py
  29. +2 -1  abl/utils/__init__.py
  30. +112 -0  abl/utils/cache.py
  31. +2 -1  abl/utils/utils.py
  32. +1 -2  docs/conf.py
  33. +5 -4  examples/hed/datasets/get_hed.py
  34. +6 -6  examples/hed/hed_bridge.py
  35. +2 -2  examples/hed/hed_example.ipynb
  36. +1 -1  examples/hed/utils.py
  37. +5 -5  examples/hwf/datasets/get_hwf.py
  38. +14 -45  examples/hwf/hwf_example.ipynb
  39. +129 -0  examples/hwf/hwf_kb.py
  40. +25 -15  examples/mnist_add/datasets/get_mnist_add.py
  41. +37 -37  examples/mnist_add/mnist_add_example.ipynb
  42. +17 -0  examples/mnist_add/mnist_add_kb.py
  43. +4 -5  examples/models/nn.py
  44. +1 -0  setup.py
  45. +403 -0  tests/test_reasoning.py

+1 -1  abl/__init__.py

@@ -1,2 +1,2 @@
from .learning import abl_model, basic_nn
from .reasoning import reasoner, kb
from .reasoning import base_kb, ground_kb, reasoner, search_based_kb

+26 -14  abl/bridge/base_bridge.py

@@ -1,52 +1,64 @@
from abc import ABCMeta, abstractmethod
from typing import Any, List, Tuple
from typing import Any, List, Optional, Tuple, Union

from ..learning import ABLModel
from ..reasoning import ReasonerBase
from ..structures import ListData

DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]]

class BaseBridge(metaclass=ABCMeta):

class BaseBridge(metaclass=ABCMeta):
def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None:
if not isinstance(model, ABLModel):
raise TypeError("Expected an ABLModel")
raise TypeError(
"Expected an instance of ABLModel, but received type: {}".format(
type(model)
)
)
if not isinstance(abducer, ReasonerBase):
raise TypeError("Expected an ReasonerBase")
raise TypeError(
"Expected an instance of ReasonerBase, but received type: {}".format(
type(abducer)
)
)

self.model = model
self.abducer = abducer

@abstractmethod
def predict(self, X: List[List[Any]]) -> Tuple[List[List[Any]], List[List[Any]]]:
def predict(
self, data_samples: ListData
) -> Tuple[List[List[Any]], List[List[Any]]]:
"""Placeholder for predict labels from input."""
pass

@abstractmethod
def abduce_pseudo_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]:
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for abduce pseudo labels."""
pass

@abstractmethod
def idx_to_pseudo_label(self, idx: List[List[Any]]) -> List[List[Any]]:
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for map label space to symbol space."""
pass

@abstractmethod
def pseudo_label_to_idx(self, pseudo_label: List[List[Any]]) -> List[List[Any]]:
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for map symbol space to label space."""
pass
@abstractmethod
def train(self, train_data):
def train(self, train_data: Union[ListData, DataSet]):
"""Placeholder for train loop of ABductive Learning."""
pass

@abstractmethod
def test(self, test_data):
def valid(self, valid_data: Union[ListData, DataSet]) -> None:
"""Placeholder for model test."""
pass

@abstractmethod
def valid(self, valid_data):
def test(self, test_data: Union[ListData, DataSet]) -> None:
"""Placeholder for model validation."""
pass
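
Note: the new DataSet alias fixes the layout that train, valid and test accept when no ListData is passed: three parallel lists holding raw inputs, optional ground-truth pseudo labels, and reasoning results. A minimal sketch with hypothetical toy values (not part of this diff):

# Hypothetical toy values standing in for real images and labels.
X = [["img_1a", "img_1b"], ["img_2a", "img_2b"]]  # raw inputs, one inner list per sample
gt_pseudo_label = [[1, 1], [1, 2]]                # ground-truth symbols, or None if unavailable
Y = [2, 3]                                        # reasoning results, one per sample
train_data = (X, gt_pseudo_label, Y)              # matches the DataSet alias above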

+82 -69  abl/bridge/simple_bridge.py

@@ -1,13 +1,14 @@
from ..learning import ABLModel
from ..reasoning import ReasonerBase
from ..evaluation import BaseMetric
from .base_bridge import BaseBridge
from typing import List, Union, Any, Tuple, Dict, Optional
import os.path as osp
from typing import Any, Dict, List, Optional, Tuple, Union

from numpy import ndarray

from torch.utils.data import DataLoader
from ..dataset import BridgeDataset
from ..utils.logger import print_log
from ..evaluation import BaseMetric
from ..learning import ABLModel
from ..reasoning import ReasonerBase
from ..structures import ListData
from ..utils import print_log
from .base_bridge import BaseBridge, DataSet


class SimpleBridge(BaseBridge):
@@ -20,85 +21,99 @@ class SimpleBridge(BaseBridge):
super().__init__(model, abducer)
self.metric_list = metric_list

def predict(self, X) -> Tuple[List[List[Any]], ndarray]:
pred_res = self.model.predict(X)
pred_idx, pred_prob = pred_res["label"], pred_res["prob"]
return pred_idx, pred_prob
# TODO: add abducer.mapping to the property of SimpleBridge

def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
self.model.predict(data_samples)
return data_samples["pred_idx"], data_samples.get("pred_prob", None)

def abduce_pseudo_label(
self,
pred_prob: ndarray,
pred_pseudo_label: List[List[Any]],
Y: List[Any],
data_samples: ListData,
max_revision: int = -1,
require_more_revision: int = 0,
) -> List[List[Any]]:
return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision)
self.abducer.batch_abduce(data_samples, max_revision, require_more_revision)
return data_samples["abduced_pseudo_label"]

def idx_to_pseudo_label(
self, idx: List[List[Any]], mapping: Dict = None
self, data_samples: ListData, mapping: Optional[Dict] = None
) -> List[List[Any]]:
if mapping is None:
mapping = self.abducer.mapping
return [[mapping[_idx] for _idx in sub_list] for sub_list in idx]
pred_idx = data_samples.pred_idx
data_samples.pred_pseudo_label = [
[mapping[_idx] for _idx in sub_list] for sub_list in pred_idx
]
return data_samples["pred_pseudo_label"]

def pseudo_label_to_idx(
self, pseudo_label: List[List[Any]], mapping: Dict = None
self, data_samples: ListData, mapping: Optional[Dict] = None
) -> List[List[Any]]:
if mapping is None:
mapping = self.abducer.remapping
return [
[mapping[_pseudo_label] for _pseudo_label in sub_list]
for sub_list in pseudo_label
abduced_idx = [
[mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
for sub_list in data_samples.abduced_pseudo_label
]
data_samples.abduced_idx = abduced_idx
return data_samples["abduced_idx"]

def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData:
data_samples = ListData()

data_samples.X = X
data_samples.gt_pseudo_label = gt_pseudo_label
data_samples.Y = Y

return data_samples

def train(
self,
train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]],
epochs: int = 50,
batch_size: Union[int, float] = -1,
train_data: Union[ListData, DataSet],
loops: int = 50,
segment_size: Union[int, float] = -1,
eval_interval: int = 1,
save_interval: Optional[int] = None,
save_dir: Optional[str] = None,
):
dataset = BridgeDataset(*train_data)
data_loader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=lambda data_list: [list(data) for data in zip(*data_list)],
)

for epoch in range(epochs):
for seg_idx, (X, Z, Y) in enumerate(data_loader):
pred_idx, pred_prob = self.predict(X)
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
abduced_pseudo_label = self.abduce_pseudo_label(
pred_prob, pred_pseudo_label, Y
)
abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label)
loss = self.model.train(X, abduced_label)
if isinstance(train_data, ListData):
data_samples = train_data
else:
data_samples = self.data_preprocess(*train_data)

for loop in range(loops):
for seg_idx in range((len(data_samples) - 1) // segment_size + 1):
sub_data_samples = data_samples[
seg_idx * segment_size : (seg_idx + 1) * segment_size
]
self.predict(sub_data_samples)
self.idx_to_pseudo_label(sub_data_samples)
self.abduce_pseudo_label(sub_data_samples)
self.pseudo_label_to_idx(sub_data_samples)
loss = self.model.train(sub_data_samples)

print_log(
f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}",
f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}",
logger="current",
)

if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1:
print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current")
if (loop + 1) % eval_interval == 0 or loop == loops - 1:
print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current")
self.valid(train_data)

def _valid(self, data_loader):
for X, Z, Y in data_loader:
pred_idx, pred_prob = self.predict(X)
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
data_samples = dict(
pred_idx=pred_idx,
pred_prob=pred_prob,
pred_pseudo_label=pred_pseudo_label,
gt_pseudo_label=Z,
Y=Y,
logic_forward=self.abducer.kb.logic_forward,
)
if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth"))

def _valid(self, data_samples: ListData, batch_size: int = 128) -> None:
for seg_idx in range((len(data_samples) - 1) // batch_size + 1):
sub_data_samples = data_samples[seg_idx * batch_size : (seg_idx + 1) * batch_size]
self.predict(sub_data_samples)
self.idx_to_pseudo_label(sub_data_samples)

for metric in self.metric_list:
metric.process(data_samples)
metric.process(sub_data_samples)

res = dict()
for metric in self.metric_list:
@@ -108,14 +123,12 @@ class SimpleBridge(BaseBridge):
msg += k + f": {v:.3f} "
print_log(msg, logger="current")

def valid(self, valid_data, batch_size=1000):
dataset = BridgeDataset(*valid_data)
data_loader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=lambda data_list: [list(data) for data in zip(*data_list)],
)
self._valid(data_loader)

def test(self, test_data, batch_size=1000):
self.valid(test_data, batch_size)
def valid(self, valid_data: Union[ListData, DataSet], batch_size: int = 128) -> None:
if not isinstance(valid_data, ListData):
data_samples = self.data_preprocess(*valid_data)
else:
data_samples = valid_data
self._valid(data_samples, batch_size=batch_size)

def test(self, test_data: Union[ListData, DataSet], batch_size: int = 128) -> None:
self.valid(test_data, batch_size=batch_size)
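
Note: with the reworked loop above, train now takes loops/segment_size (formerly epochs/batch_size) and can checkpoint the model. A hedged usage sketch, assuming model, abducer and metric_list are an ABLModel, a ReasonerBase and a list of BaseMetric built as in the examples, and that the X/gt_pseudo_label/Y names (and their test_* counterparts) hold a DataSet tuple:

from abl.bridge.simple_bridge import SimpleBridge  # module path as in this diff

bridge = SimpleBridge(model, abducer, metric_list)
bridge.train(
    train_data=(X, gt_pseudo_label, Y),  # a DataSet tuple or a ready-made ListData
    loops=10,                            # renamed from `epochs`
    segment_size=64,                     # renamed from `batch_size`
    eval_interval=1,
    save_interval=5,                     # writes model_checkpoint_loop_{n}.pth into save_dir
    save_dir="./checkpoints",
)
bridge.test((test_X, test_gt_pseudo_label, test_Y), batch_size=128)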

+2 -1  abl/dataset/__init__.py

@@ -1,3 +1,4 @@
from .bridge_dataset import BridgeDataset
from .classification_dataset import ClassificationDataset
from .regression_dataset import RegressionDataset
from .prediction_dataset import PredictionDataset
from .regression_dataset import RegressionDataset

+2 -1  abl/dataset/bridge_dataset.py

@@ -1,5 +1,6 @@
from typing import Any, List, Tuple

from torch.utils.data import Dataset
from typing import List, Any, Tuple


class BridgeDataset(Dataset):


+2 -1  abl/dataset/classification_dataset.py

@@ -1,6 +1,7 @@
from typing import Any, Callable, List, Tuple

import torch
from torch.utils.data import Dataset
from typing import List, Any, Tuple, Callable


class ClassificationDataset(Dataset):


+56 -0  abl/dataset/prediction_dataset.py

@@ -0,0 +1,56 @@
from typing import Any, Callable, List, Tuple

import torch
from torch.utils.data import Dataset


class PredictionDataset(Dataset):
def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
"""
Initialize the dataset used for classification task.

Parameters
----------
X : List[Any]
The input data.
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
"""
if not isinstance(X, list):
raise ValueError("X should be of type list.")

self.X = X
self.transform = transform

def __len__(self) -> int:
"""
Return the length of the dataset.

Returns
-------
int
The length of the dataset.
"""
return len(self.X)

def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
"""
Get the item at the given index.

Parameters
----------
index : int
The index of the item to get.

Returns
-------
Tuple[Any, torch.Tensor]
A tuple containing the object and its label.
"""
if index >= len(self):
raise ValueError("index range error")

x = self.X[index]
if self.transform is not None:
x = self.transform(x)
return x
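
Note: unlike ClassificationDataset, the new PredictionDataset yields inputs only, which is why _predict in basic_nn.py below now iterates `for data in data_loader` instead of `for data, _ in data_loader`. A minimal, self-contained sketch of pairing it with a DataLoader:

import torch
from torch.utils.data import DataLoader
from abl.dataset import PredictionDataset

X = [torch.randn(1, 28, 28) for _ in range(4)]    # toy inputs already shaped like tensors
loader = DataLoader(PredictionDataset(X), batch_size=2)

for batch in loader:       # batches contain inputs only, no labels
    print(batch.shape)     # torch.Size([2, 1, 28, 28])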

+2 -1  abl/dataset/regression_dataset.py

@@ -1,6 +1,7 @@
from typing import Any, List, Tuple

import torch
from torch.utils.data import Dataset
from typing import List, Any, Tuple


class RegressionDataset(Dataset):


+1 -1  abl/evaluation/__init__.py

@@ -1,3 +1,3 @@
from .base_metric import BaseMetric
from .symbol_metric import SymbolMetric
from .semantics_metric import SemanticsMetric
from .symbol_metric import SymbolMetric

+2 -2  abl/evaluation/base_metric.py

@@ -1,8 +1,8 @@
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence
from ..utils import print_log

import logging
from ..utils import print_log


class BaseMetric(metaclass=ABCMeta):


+8 -11  abl/evaluation/semantics_metric.py

@@ -1,25 +1,22 @@
from typing import Optional, Sequence

from ..reasoning import BaseKB
from .base_metric import BaseMetric

class ABLMetric():
pass

class SemanticsMetric(BaseMetric):
def __init__(self, prefix: Optional[str] = None) -> None:
def __init__(self, kb: BaseKB = None, prefix: Optional[str] = None) -> None:
super().__init__(prefix)
self.kb = kb

def process(self, data_samples: Sequence[dict]) -> None:
pred_pseudo_label = data_samples["pred_pseudo_label"]
gt_Y = data_samples["Y"]
logic_forward = data_samples["logic_forward"]

for pred_z, y in zip(pred_pseudo_label, gt_Y):
if logic_forward(pred_z) == y:
for data_sample in data_samples:
if self.kb.check_equal(data_sample, data_sample["Y"][0]):
self.results.append(1)
else:
self.results.append(0)

def compute_metrics(self, results: list) -> dict:
metrics = dict()
metrics["semantics_accuracy"] = sum(results) / len(results)
return metrics
return metrics

+2 -1  abl/evaluation/symbol_metric.py

@@ -1,4 +1,5 @@
from typing import Optional, Sequence, Callable
from typing import Optional, Sequence

from .base_metric import BaseMetric




+36 -63  abl/learning/abl_model.py

@@ -10,8 +10,10 @@
#
# ================================================================#
import pickle
from utils import flatten, reform_idx
from typing import List, Any, Optional
from typing import Any, Dict

from ..structures import ListData
from ..utils import reform_idx


class ABLModel:
@@ -30,7 +32,7 @@ class ABLModel:

Methods
-------
predict(X: List[List[Any]], mapping: Optional[dict] = None) -> dict
predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict
Predict the labels and probabilities for the given data.
valid(X: List[List[Any]], Y: List[Any]) -> float
Calculate the accuracy score for the given data.
@@ -42,20 +44,13 @@ class ABLModel:
Load the model from a file.
"""

def __init__(self, base_model) -> None:
self.classifier_list = []
self.classifier_list.append(base_model)
def __init__(self, base_model: Any) -> None:
if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")):
raise NotImplementedError("The base_model should implement fit and predict methods.")

if not (
hasattr(base_model, "fit")
and hasattr(base_model, "predict")
and hasattr(base_model, "score")
):
raise NotImplementedError(
"base_model should have fit, predict and score methods."
)
self.base_model = base_model

def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict:
def predict(self, data_samples: ListData) -> Dict:
"""
Predict the labels and probabilities for the given data.

@@ -63,53 +58,30 @@ class ABLModel:
----------
X : List[List[Any]]
The data to predict on.
mapping : Optional[dict], optional
A mapping dictionary to map labels to their original values, by default None.

Returns
-------
dict
A dictionary containing the predicted labels and probabilities.
"""
model = self.classifier_list[0]
data_X = flatten(X)
model = self.base_model
data_X = data_samples.flatten("X")
if hasattr(model, "predict_proba"):
prob = model.predict_proba(X=data_X)
label = prob.argmax(axis=1)
prob = reform_idx(prob, X)
prob = reform_idx(prob, data_samples["X"])
else:
prob = None
label = model.predict(X=data_X)
label = reform_idx(label, data_samples["X"])

if mapping is not None:
label = [mapping[y] for y in label]

label = reform_idx(label, X)
data_samples.pred_idx = label
if prob is not None:
data_samples.pred_prob = prob

return {"label": label, "prob": prob}

def valid(self, X: List[List[Any]], Y: List[Any]) -> float:
"""
Calculate the accuracy for the given data.

Parameters
----------
X : List[List[Any]]
The data to calculate the accuracy on.
Y : List[Any]
The true labels for the given data.

Returns
-------
float
The accuracy score for the given data.
"""
data_X = flatten(X)
data_Y = flatten(Y)
score = self.classifier_list[0].score(X=data_X, y=data_Y)
return score

def train(self, X: List[List[Any]], Y: List[Any]) -> float:
def train(self, data_samples: ListData) -> float:
"""
Train the model on the given data.

@@ -125,29 +97,30 @@ class ABLModel:
float
The loss value of the trained model.
"""
data_X = flatten(X)
data_Y = flatten(Y)
return self.classifier_list[0].fit(X=data_X, y=data_Y)
data_X = data_samples.flatten("X")
data_y = data_samples.flatten("abduced_idx")
return self.base_model.fit(X=data_X, y=data_y)

def _model_operation(self, operation: str, *args, **kwargs):
model = self.classifier_list[0]
model = self.base_model
if hasattr(model, operation):
method = getattr(model, operation)
method(*args, **kwargs)
else:
try:
if not f"{operation}_path" in kwargs.keys():
raise ValueError(f"'{operation}_path' should not be None")
if operation == "save":
with open(kwargs["save_path"], 'wb') as file:
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
elif operation == "load":
with open(kwargs["load_path"], 'rb') as file:
self.classifier_list[0] = pickle.load(file)
except:
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method"
)
if not f"{operation}_path" in kwargs.keys():
raise ValueError(f"'{operation}_path' should not be None")
else:
try:
if operation == "save":
with open(kwargs["save_path"], "wb") as file:
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
elif operation == "load":
with open(kwargs["load_path"], "rb") as file:
self.base_model = pickle.load(file)
except:
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed."
)

def save(self, *args, **kwargs) -> None:
"""


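Note: ABLModel now wraps a single base_model and only requires fit and predict (score is no longer mandatory). A hedged sketch of the constructor check, using scikit-learn as one example of a compatible backend:

from sklearn.ensemble import RandomForestClassifier
from abl.learning import ABLModel

model = ABLModel(RandomForestClassifier())   # accepted: exposes fit and predict

class FitOnly:
    def fit(self, X, y):
        return 0.0

try:
    ABLModel(FitOnly())                      # rejected: predict is missing
except NotImplementedError as err:
    print(err)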
+56 -24  abl/learning/basic_nn.py

@@ -10,14 +10,16 @@
#
# ================================================================#

import torch
import os
import logging
from typing import Any, Callable, List, Optional, T, Tuple

import numpy
import torch
from torch.utils.data import DataLoader
from ..utils.logger import print_log
from ..dataset import ClassificationDataset

import os
from typing import List, Any, T, Optional, Callable, Tuple
from ..dataset import ClassificationDataset, PredictionDataset
from ..utils.logger import print_log


class BasicNN:
@@ -99,9 +101,7 @@ class BasicNN:
loss_value = self.train_epoch(data_loader)
if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
if self.save_dir is None:
raise ValueError(
"save_dir should not be None if save_interval is not None."
)
raise ValueError("save_dir should not be None if save_interval is not None.")
self.save(epoch + 1)
if self.stop_loss is not None and loss_value < self.stop_loss:
break
@@ -191,7 +191,7 @@ class BasicNN:

with torch.no_grad():
results = []
for data, _ in data_loader:
for data in data_loader:
data = data.to(device)
out = model(data)
results.append(out)
@@ -199,7 +199,10 @@ class BasicNN:
return torch.cat(results, axis=0)

def predict(
self, data_loader: DataLoader = None, X: List[Any] = None
self,
data_loader: DataLoader = None,
X: List[Any] = None,
test_transform: Callable[..., Any] = None,
) -> numpy.ndarray:
"""
Predict the class of the input data.
@@ -218,11 +221,28 @@ class BasicNN:
"""

if data_loader is None:
data_loader = self._data_loader(X)
if test_transform is None:
print_log(
"Transform used in the training phase will be used in prediction.",
"current",
level=logging.WARNING,
)
dataset = PredictionDataset(X, self.transform)
else:
dataset = PredictionDataset(X, test_transform)
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=int(self.num_workers),
collate_fn=self.collate_fn,
)
return self._predict(data_loader).argmax(axis=1).cpu().numpy()

def predict_proba(
self, data_loader: DataLoader = None, X: List[Any] = None
self,
data_loader: DataLoader = None,
X: List[Any] = None,
test_transform: Callable[..., Any] = None,
) -> numpy.ndarray:
"""
Predict the probability of each class for the input data.
@@ -241,7 +261,21 @@ class BasicNN:
"""

if data_loader is None:
data_loader = self._data_loader(X)
if test_transform is None:
print_log(
"Transform used in the training phase will be used in prediction.",
"current",
level=logging.WARNING,
)
dataset = PredictionDataset(X, self.transform)
else:
dataset = PredictionDataset(X, test_transform)
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=int(self.num_workers),
collate_fn=self.collate_fn,
)
return self._predict(data_loader).softmax(axis=1).cpu().numpy()

def _score(self, data_loader) -> Tuple[float, float]:
@@ -313,15 +347,14 @@ class BasicNN:
if data_loader is None:
data_loader = self._data_loader(X, y)
mean_loss, accuracy = self._score(data_loader)
print_log(
f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current"
)
print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current")
return accuracy

def _data_loader(
self,
X: List[Any],
y: List[int] = None,
shuffle: bool = True,
) -> DataLoader:
"""
Generate a DataLoader for user-provided input and target data.
@@ -350,7 +383,7 @@ class BasicNN:
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=True,
shuffle=shuffle,
num_workers=int(self.num_workers),
collate_fn=self.collate_fn,
)
@@ -368,14 +401,13 @@ class BasicNN:
The path to save the model, by default None.
"""
if self.save_dir is None and save_path is None:
raise ValueError(
"'save_dir' and 'save_path' should not be None simultaneously."
)
raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.")

if save_path is None:
save_path = os.path.join(
self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth"
)
if save_path is not None:
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
else:
save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth")
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
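
Note: predict and predict_proba now build a PredictionDataset internally and accept an optional test_transform; if it is omitted, the training transform is reused and a warning is logged. A hedged sketch, assuming basic_nn is a BasicNN instance and test_images a list of raw inputs, both prepared as in the examples:

from torchvision import transforms

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # illustrative MNIST-style normalization
])

labels = basic_nn.predict(X=test_images, test_transform=test_transform)
probs = basic_nn.predict_proba(X=test_images, test_transform=test_transform)
# Omitting test_transform logs a warning and falls back to the training transform.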



+5 -1  abl/reasoning/__init__.py

@@ -1,2 +1,6 @@
from .base_kb import BaseKB
from .ground_kb import GroundKB
from .prolog_based_kb import PrologBasedKB
from .reasoner import ReasonerBase
from .kb import KBBase, prolog_KB
from .search_based_kb import SearchBasedKB
from .search_engine import BFS, BaseSearchEngine

+14 -0  abl/reasoning/base_kb.py

@@ -0,0 +1,14 @@
from abc import ABC


class BaseKB(ABC):
def __init__(self, pseudo_label_list) -> None:
self.pseudo_label_list = pseudo_label_list

# TODO: When the output is excessively long, use ellipses as a substitute.
def __repr__(self):
return (
f"<{self.__class__.__name__}(\n"
f" pseudo_label_list: {self.pseudo_label_list!r}\n"
f") at {hex(id(self))}>"
)

+60 -0  abl/reasoning/ground_kb.py

@@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
from typing import Any, Hashable, List

from ..structures import ListData
from .base_kb import BaseKB


class GroundKB(BaseKB, ABC):
def __init__(self, pseudo_label_list: List) -> None:
super().__init__(pseudo_label_list)
self.GKB = self.construct_base()

@abstractmethod
def construct_base(self) -> dict:
pass

@abstractmethod
def get_key(self, data_sample: ListData) -> Hashable:
pass

def key2candidates(self, key: Hashable) -> List[List[Any]]:
return self.GKB[key]

def filter_candidates(
self,
data_sample: ListData,
candidates: List[List[Any]],
max_revision_num: int,
require_more_revision: int = 0,
) -> List[List[Any]]:
return candidates

def abduce_candidates(
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
):
return self._abduce_by_GKB(
data_sample=data_sample,
max_revision_num=max_revision_num,
require_more_revision=require_more_revision,
)

def _abduce_by_GKB(
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
):
candidates = self.key2candidates(self.get_key(data_sample))
return self.filter_candidates(
data_sample=data_sample,
max_revision_num=max_revision_num,
require_more_revision=require_more_revision,
candidates=candidates,
)

# TODO: When the output is excessively long, use ellipses as a substitute.
def __repr__(self):
return (
f"<{self.__class__.__name__}(\n"
f" pseudo_label_list: {self.pseudo_label_list!r}\n"
f" GKB: {self.GKB!r}\n"
f") at {hex(id(self))}>"
)
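
Note: GroundKB splits the pre-computed ground knowledge base (construct_base) from the lookup key (get_key); filter_candidates keeps everything by default. A hedged sketch of a two-digit addition KB in the spirit of the mnist_add example, with the ListData field access assumed from the conventions used elsewhere in this diff:

from itertools import product

from abl.reasoning import GroundKB

class AddGroundKB(GroundKB):
    def __init__(self, pseudo_label_list=list(range(10))):
        super().__init__(pseudo_label_list)     # construct_base() is called here

    def construct_base(self):
        # Enumerate every digit pair and index it by its sum.
        GKB = {}
        for a, b in product(self.pseudo_label_list, repeat=2):
            GKB.setdefault(a + b, []).append([a, b])
        return GKB

    def get_key(self, data_sample):
        # Candidates are looked up by the sample's reasoning result Y.
        return data_sample["Y"][0]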

+0 -222  abl/reasoning/kb.py

@@ -1,222 +0,0 @@
from abc import ABC, abstractmethod
import bisect
import numpy as np

from collections import defaultdict
from itertools import product, combinations

from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list

from multiprocessing import Pool

from functools import lru_cache
import pyswip

class KBBase(ABC):
def __init__(self, pseudo_label_list, max_err=0, use_cache=True):
# TODO: add some type checks here, for example
# if not isinstance(X, (np.ndarray, spmatrix)):
# raise TypeError("X should be numpy array or sparse matrix")

self.pseudo_label_list = pseudo_label_list
self.max_err = max_err
self.use_cache = use_cache

@abstractmethod
def logic_forward(self, pseudo_labels):
pass

def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0):
if not self.use_cache:
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision)
else:
return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(y), max_revision_num, require_more_revision)
def revise_by_idx(self, pred_res, y, revision_idx):
candidates = []
abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx))
for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
if check_equal(self.logic_forward(candidate), y, self.max_err):
candidates.append(candidate)
return candidates

def _revision(self, revision_num, pred_res, y):
new_candidates = []
revision_idx_list = combinations(range(len(pred_res)), revision_num)

for revision_idx in revision_idx_list:
candidates = self.revise_by_idx(pred_res, y, revision_idx)
new_candidates.extend(candidates)
return new_candidates

def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision):
candidates = []
for revision_num in range(len(pred_res) + 1):
if revision_num == 0 and check_equal(self.logic_forward(pred_res), y, self.max_err):
candidates.append(pred_res)
elif revision_num > 0:
candidates.extend(self._revision(revision_num, pred_res, y))
if len(candidates) > 0:
min_revision_num = revision_num
break
if revision_num >= max_revision_num:
return []

for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1):
if revision_num > max_revision_num:
return candidates
candidates.extend(self._revision(revision_num, pred_res, y))
return candidates
@lru_cache(maxsize=None)
def _abduce_by_search_cache(self, pred_res, y, max_revision_num, require_more_revision):
pred_res = hashable_to_list(pred_res)
y = hashable_to_list(y)
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision)
def _dict_len(self, dic):
if not self.GKB_flag:
return 0
else:
return sum(len(c) for c in dic.values())

def __len__(self):
if not self.GKB_flag:
return 0
else:
return sum(self._dict_len(v) for v in self.base.values())
class ground_KB(KBBase):
def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0):
super().__init__(pseudo_label_list, max_err)
self.GKB_len_list = GKB_len_list
self.base = {}
X, Y = self._get_GKB()
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(x)
# For parallel version of _get_GKB
def _get_XY_list(self, args):
pre_x, post_x_it = args[0], args[1]
XY_list = []
for post_x in post_x_it:
x = (pre_x,) + post_x
y = self.logic_forward(x)
if y is not None:
XY_list.append((x, y))
return XY_list

# Parallel _get_GKB
def _get_GKB(self):
X, Y = [], []
for length in self.GKB_len_list:
arg_list = []
for pre_x in self.pseudo_label_list:
post_x_it = product(self.pseudo_label_list, repeat=length - 1)
arg_list.append((pre_x, post_x_it))
with Pool(processes=len(arg_list)) as pool:
ret_list = pool.map(self._get_XY_list, arg_list)
for XY_list in ret_list:
if len(XY_list) == 0:
continue
part_X, part_Y = zip(*XY_list)
X.extend(part_X)
Y.extend(part_Y)
if Y and isinstance(Y[0], (int, float)):
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
return X, Y
def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0):
return self._abduce_by_GKB(pred_res, y, max_revision_num, require_more_revision)
def _find_candidate_GKB(self, pred_res, y):
if self.max_err == 0:
return self.base[len(pred_res)][y]
else:
potential_candidates = self.base[len(pred_res)]
key_list = list(potential_candidates.keys())
key_idx = bisect.bisect_left(key_list, y)
all_candidates = []
for idx in range(key_idx - 1, 0, -1):
k = key_list[idx]
if abs(k - y) <= self.max_err:
all_candidates.extend(potential_candidates[k])
else:
break
for idx in range(key_idx, len(key_list)):
k = key_list[idx]
if abs(k - y) <= self.max_err:
all_candidates.extend(potential_candidates[k])
else:
break
return all_candidates
def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision):
if self.base == {} or len(pred_res) not in self.GKB_len_list:
return []
all_candidates = self._find_candidate_GKB(pred_res, y)
if len(all_candidates) == 0:
return []

cost_list = hamming_dist(pred_res, all_candidates)
min_revision_num = np.min(cost_list)
revision_num = min(max_revision_num, min_revision_num + require_more_revision)
idxs = np.where(cost_list <= revision_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates


class prolog_KB(KBBase):
def __init__(self, pseudo_label_list, pl_file, max_err=0):
super().__init__(pseudo_label_list, max_err)
self.prolog = pyswip.Prolog()
self.prolog.consult(pl_file)

def logic_forward(self, pseudo_labels):
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res']
if result == 'true':
return True
elif result == 'false':
return False
return result
def _revision_pred_res(self, pred_res, revision_idx):
import re
revision_pred_res = pred_res.copy()
revision_pred_res = flatten(revision_pred_res)
for idx in revision_idx:
revision_pred_res[idx] = 'P' + str(idx)
revision_pred_res = reform_idx(revision_pred_res, pred_res)
# TODO: not sure whether there is a more concise way to do this
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_res))
def get_query_string(self, pred_res, y, revision_idx):
query_string = "logic_forward("
query_string += self._revision_pred_res(pred_res, revision_idx)
key_is_none_flag = y is None or (type(y) == list and y[0] is None)
query_string += ",%s)." % y if not key_is_none_flag else ")."
return query_string
def revise_by_idx(self, pred_res, y, revision_idx):
candidates = []
query_string = self.get_query_string(pred_res, y, revision_idx)
save_pred_res = pred_res
pred_res = flatten(pred_res)
abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]
for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
candidate = reform_idx(candidate, save_pred_res)
candidates.append(candidate)
return candidates

+44 -0  abl/reasoning/prolog_based_kb.py

@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from typing import Any, Generator, List, Tuple, Union

import numpy as np
import pyswip

from ..structures import ListData
from .base_kb import BaseKB


class PrologBasedKB(BaseKB, ABC):
def __init__(self, pseudo_label_list, pl_file):
self.pseudo_label_list = pseudo_label_list
self.prolog = pyswip.Prolog()
self.prolog.consult(pl_file)

def logic_forward(
self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray] = None
) -> Generator[Union[Any, pyswip.Variable, list, dict, None], Any, None]:
return self.prolog.query(self.to_query(data_sample, revision_idx))

@abstractmethod
def to_query(self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray] = None):
pass

@abstractmethod
def postprocess(
self, query_res, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray]
):
return list(query_res)

@abstractmethod
def filter_candidates(
self,
data_sample: ListData,
candidates: List[List[Any]],
max_revision_num: int,
require_more_revision: int = 0,
) -> List[List[Any]]:
return candidates

def revise_at_idx(self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray]):
query_res = self.logic_forward(data_sample, revision_idx)
return self.postprocess(query_res, data_sample, revision_idx)
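
Note: PrologBasedKB now delegates query construction and result handling to to_query/postprocess. A hedged sketch for two-digit addition, assuming a Prolog file that defines logic_forward/2 such as the add.pl used by the mnist_add example; the variable naming and result patching mirror the old prolog_KB that kb.py used to provide:

from abl.reasoning import PrologBasedKB

class AddPrologKB(PrologBasedKB):
    def to_query(self, data_sample, revision_idx=None):
        pseudo_label = list(data_sample["pred_pseudo_label"][0])
        if revision_idx is not None:
            for idx in revision_idx:
                pseudo_label[idx] = "P%d" % idx        # Prolog variables for revised slots
        args = ",".join(str(x) for x in pseudo_label)
        return "logic_forward([%s], %s)." % (args, data_sample["Y"][0])

    def postprocess(self, query_res, data_sample, revision_idx):
        candidates = []
        for solution in query_res:                     # one dict of bindings per solution
            candidate = list(data_sample["pred_pseudo_label"][0])
            for idx in revision_idx:
                candidate[idx] = solution["P%d" % idx]
            candidates.append(candidate)
        return candidates

    def filter_candidates(self, data_sample, candidates,
                          max_revision_num, require_more_revision=0):
        return candidates

kb = AddPrologKB(pseudo_label_list=list(range(10)),
                 pl_file="examples/mnist_add/datasets/add.pl")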

+164 -530  abl/reasoning/reasoner.py

@@ -1,25 +1,33 @@
from typing import Any, List, Mapping, Optional

import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import (
confidence_dist,
flatten,
reform_idx,
hamming_dist,
calculate_revision_num,
)

from ..structures import ListData
from ..utils import Cache, calculate_revision_num, confidence_dist, hamming_dist
from .base_kb import BaseKB
from .search_engine import BFS, BaseSearchEngine


class ReasonerBase:
def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False):
def __init__(
self,
kb: BaseKB,
dist_func: str = "confidence",
mapping: Optional[Mapping] = None,
search_engine: Optional[BaseSearchEngine] = None,
use_cache: bool = False,
cache_file: Optional[str] = None,
cache_size: Optional[int] = 4096,
):
"""
Base class for all reasoner in the ABL system.

Parameters
----------
kb : KBBase
kb : BaseKB
The knowledge base to be used for reasoning.
dist_func : str, optional
The distance function to be used. Can be "hamming" or "confidence". Default is "hamming".
The distance function to be used. Can be "hamming" or "confidence". Default is "confidence".
mapping : dict, optional
A mapping of indices to labels. If None, a default mapping is generated.
use_zoopt : bool, optional
@@ -31,207 +39,204 @@ class ReasonerBase:
If the specified distance function is neither "hamming" nor "confidence".
"""

if not (dist_func == "hamming" or dist_func == "confidence"):
raise NotImplementedError # Only hamming or confidence distance is available.

if not isinstance(kb, BaseKB):
raise ValueError("The kb should be of type BaseKB.")
self.kb = kb

if dist_func not in ["hamming", "confidence"]:
raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.")
self.dist_func = dist_func
self.use_zoopt = use_zoopt
if mapping is None:
self.mapping = {
index: label for index, label in enumerate(self.kb.pseudo_label_list)
}
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
else:
self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))
if not isinstance(mapping, dict):
raise ValueError("mapping must be of type dict")

def _get_cost_list(self, pred_pseudo_label, pred_prob, candidates):
"""
Get the list of costs between each pseudo label and candidate.
for key, value in mapping.items():
if not isinstance(key, int):
raise ValueError("All keys in the mapping must be integers")

Parameters
----------
pred_pseudo_label : list
The pseudo label to be used for computing costs of candidates.
pred_prob : list
Probabilities of the predictions. Used when distance function is "confidence".
candidates : list
List of candidate abduction result.
if value not in self.kb.pseudo_label_list:
raise ValueError("All values in the mapping must be in the pseudo_label_list")

Returns
-------
numpy.ndarray
Array of computed costs for each candidate.
"""
if self.dist_func == "hamming":
return hamming_dist(pred_pseudo_label, candidates)

elif self.dist_func == "confidence":
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(pred_prob, candidates)

def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates):
"""
Get one candidate. If multiple candidates exist, return the one with minimum cost.

Parameters
----------
pred_pseudo_label : list
The pseudo label to be used for selecting a candidate.
pred_prob : list
Probabilities of the predictions.
candidates : list
List of candidate abduction result.
self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

Returns
-------
list
The chosen candidate based on minimum cost.
If no candidates, an empty list is returned.
"""
if len(candidates) == 0:
return []
elif len(candidates) == 1:
return candidates[0]
if search_engine is None:
self.search_engine = BFS()
else:
cost_array = self._get_cost_list(pred_pseudo_label, pred_prob, candidates)
candidate = candidates[np.argmin(cost_array)]
return candidate
if not isinstance(search_engine, BaseSearchEngine):
raise ValueError("The search_engine should be of type BaseSearchEngine.")
else:
self.search_engine = search_engine

self.use_cache = use_cache
self.cache_file = cache_file
if self.use_cache:
if not hasattr(self, "get_key"):
raise NotImplementedError("If use_cache is True, get_key should be implemented.")
key_func = self.get_key
else:
key_func = lambda x: x
self.cache = Cache[ListData, List[List[Any]]](
func=self.abduce_candidates,
cache=self.use_cache,
cache_file=self.cache_file,
key_func=key_func,
max_size=cache_size,
)

def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, y, sol):
def abduce(
self,
data_sample: ListData,
max_revision: int = -1,
require_more_revision: int = 0,
):
"""
Get the revision score for a single solution.
Perform revision by abduction on the given data.

Parameters
----------
symbol_num : int
Number of total symbols.
pred_pseudo_label : list
List of predicted pseudo labels.
pred_prob : list
List of probabilities for predicted results.
pred_pseudo_label : list
List of predicted pseudo labels.
y : any
Ground truth for the predicted results.
sol : array-like
Solution to evaluate.
max_revision : int or float, optional
Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
If -1, any revisions are allowed. Defaults to -1.
require_more_revision : int, optional
Number of additional revisions to require. Defaults to 0.

Returns
-------
float
The revision score for the given solution.
list
The abduced revisions.
"""
revision_idx = np.where(sol.get_x() != 0)[0]
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates))
else:
return symbol_num
symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = calculate_revision_num(max_revision, symbol_num)
data_sample.set_metainfo(dict(symbol_num=symbol_num))

def _constrain_revision_num(self, solution, max_revision_num):
x = solution.get_x()
return max_revision_num - x.sum()
candidates = self.cache.get(data_sample, max_revision_num, require_more_revision)
candidate = self.select_one_candidate(data_sample, candidates)
return candidate

def zoopt_get_solution(
self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
def abduce_candidates(
self,
data_sample: ListData,
max_revision_num: int = -1,
require_more_revision: int = 0,
):
"""Get the optimal solution using the Zoopt library.
"""
Perform revision by abduction on the given data.

Parameters
----------
symbol_num : int
Number of total symbols.
pred_pseudo_label : list
List of predicted pseudo labels.
pred_prob : list
List of probabilities for predicted results.
pred_pseudo_label : list
List of predicted pseudo labels.
y : any
Ground truth for the predicted results.
max_revision_num : int
Maximum number of revisions to use.
max_revision : int or float, optional
Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
If -1, any revisions are allowed. Defaults to -1.
require_more_revision : int, optional
Number of additional revisions to require. Defaults to 0.

Returns
-------
array-like
The optimal solution, i.e., where to revise predict pseudo label.
list
The abduced revisions.
"""
dimension = Dimension(
size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num
)
objective = Objective(
lambda sol: self.zoopt_revision_score(
symbol_num, pred_pseudo_label, pred_prob, y, sol
),
dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
)
parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
solution = Opt.min(objective, parameter).get_x()
return solution

def revise_by_idx(self, pred_pseudo_label, y, revision_idx):
if hasattr(self.kb, "abduce_candidates"):
candidates = self.kb.abduce_candidates(
data_sample, max_revision_num, require_more_revision
)
elif hasattr(self.kb, "revise_at_idx"):
candidates = []
gen = self.search_engine.generator(
data_sample,
max_revision_num=max_revision_num,
require_more_revision=require_more_revision,
)
send_signal = True
for revision_idx in gen:
candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx))
if len(candidates) > 0 and send_signal:
try:
revision_idx = gen.send("success")
candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx))
send_signal = False
except StopIteration:
break
else:
raise NotImplementedError(
"The kb should either implement abduce_candidates or revise_at_idx."
)
return candidates

def select_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]):
"""
Revise the pseudo label according to the given indices.
Get one candidate. If multiple candidates exist, return the one with minimum cost.

Parameters
----------
pred_pseudo_label : list
List of predicted pseudo labels.
y : any
Ground truth for the predicted results.
revision_idx : array-like
Indices of the revisions to retrieve.
The pseudo label to be used for selecting a candidate.
pred_prob : list
Probabilities of the predictions.
candidates : list
List of candidate abduction result.

Returns
-------
list
The revisions according to the given indices.
The chosen candidate based on minimum cost.
If no candidates, an empty list is returned.
"""
return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx)
if len(candidates) == 0:
return []
elif len(candidates) == 1:
return candidates[0]
else:
cost_array = self._get_dist_list(data_sample, candidates)
candidate = candidates[np.argmin(cost_array)]
return candidate

def abduce(
self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0
):
def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]):
"""
Perform revision by abduction on the given data.
Get the list of costs between each pseudo label and candidate.

Parameters
----------
pred_prob : list
List of probabilities for predicted results.
pred_pseudo_label : list
List of predicted pseudo labels.
y : any
Ground truth for the predicted results.
max_revision : int or float, optional
Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
If -1, any revisions are allowed. Defaults to -1.
require_more_revision : int, optional
Number of additional revisions to require. Defaults to 0.
The pseudo label to be used for computing costs of candidates.
pred_prob : list
Probabilities of the predictions. Used when distance function is "confidence".
candidates : list
List of candidate abduction result.

Returns
-------
list
The abduced revisions.
numpy.ndarray
Array of computed costs for each candidate.
"""
symbol_num = len(flatten(pred_pseudo_label))
max_revision_num = calculate_revision_num(max_revision, symbol_num)

if self.use_zoopt:
solution = self.zoopt_get_solution(
symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
)
revision_idx = np.where(solution != 0)[0]
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx)
else:
candidates = self.kb.abduce_candidates(
pred_pseudo_label, y, max_revision_num, require_more_revision
)
if self.dist_func == "hamming":
return hamming_dist(data_sample["pred_pseudo_label"][0], candidates)

candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates)
return candidate
elif self.dist_func == "confidence":
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(data_sample["pred_prob"][0], candidates)

def batch_abduce(
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0
self,
data_samples: ListData,
max_revision: int = -1,
require_more_revision: int = 0,
):
"""
Perform abduction on the given data in batches.
@@ -255,384 +260,13 @@ class ReasonerBase:
list
The abduced revisions in batches.
"""
return [
abduced_pseudo_label = [
self.abduce(
_pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision
)
for _pred_prob, _pred_pseudo_label, _Y in zip(
pred_prob, pred_pseudo_label, Y
data_sample,
max_revision=max_revision,
require_more_revision=require_more_revision,
)
for data_sample in data_samples
]

# def _batch_abduce_helper(self, args):
# z, prob, y, max_revision, require_more_revision = args
# return self.abduce((z, prob, y), max_revision, require_more_revision)

# def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0):
# with Pool(processes=os.cpu_count()) as pool:
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# return results

def __call__(
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0
):
return self.batch_abduce(
pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision
)




if __name__ == "__main__":
from kb import KBBase, ground_KB, prolog_KB

prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]

class add_KB(KBBase):
def __init__(self, pseudo_label_list=list(range(10)),
use_cache=True):
super().__init__(pseudo_label_list, use_cache=use_cache)

def logic_forward(self, nums):
return sum(nums)
class add_ground_KB(ground_KB):
def __init__(self, pseudo_label_list=list(range(10)),
GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list)

def logic_forward(self, nums):
return sum(nums)
def test_add(reasoner):
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
print(res)
print()

print("add_KB with GKB:")
kb = add_ground_KB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("add_KB without GKB:")
kb = add_KB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("add_KB without GKB, no cache")
kb = add_KB(use_cache=False)
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("prolog_KB with add.pl:")
kb = prolog_KB(pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl")
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("prolog_KB with add.pl using zoopt:")
kb = prolog_KB(
pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl",
)
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True)
test_add(reasoner)

print("add_KB with multiple inputs at once:")
multiple_prob = [[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
],
[
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]]

kb = add_KB()
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=1,
)
print(res)
print()

class HWF_KB(KBBase):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
"+", "-", "times", "div"],
max_err=1e-3,
):
super().__init__(pseudo_label_list, max_err)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True

def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))
class HWF_ground_KB(ground_KB):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
"+", "-", "times", "div"],
GKB_len_list=[1, 3, 5, 7],
max_err=1e-3,
):
super().__init__(pseudo_label_list, GKB_len_list, max_err)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True

def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))
def test_hwf(reasoner):
res = reasoner.batch_abduce(
[None],
[["5", "+", "2"]],
[3],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "+", "9"]],
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "8", "8", "8", "8"]],
[3.17],
max_revision=5,
require_more_revision=3,
)
print(res)
print()
def test_hwf_multiple(reasoner, max_revisions):
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[0],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[1],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 65],
max_revision=max_revisions[2],
require_more_revision=0,
)
print(res)
print()

print("HWF_KB with GKB, max_err=0.1")
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB without GKB, max_err=0.1")
kb = HWF_KB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB with GKB, max_err=1")
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB without GKB, max_err=1")
kb = HWF_KB(max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB with multiple inputs at once:")
kb = HWF_KB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf_multiple(reasoner, max_revisions=[1,3,3])
print("max_revision is float")
test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9])

class HED_prolog_KB(prolog_KB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)

def consist_rule(self, exs, rules):
rules = str(rules).replace("'", "")
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules)
return len(list(self.prolog.query(pl_query))) != 0

def abduce_rules(self, pred_res):
pl_query = "consistent_inst_feature(%s, X)." % pred_res
prolog_result = list(self.prolog.query(pl_query))
if len(prolog_result) == 0:
return None
prolog_rules = prolog_result[0]["X"]
rules = [rule.value for rule in prolog_rules]
return rules

class HED_Reasoner(ReasonerBase):
def __init__(self, kb, dist_func="hamming"):
super().__init__(kb, dist_func, use_zoopt=True)

def _revise_by_idxs(self, pred_res, y, all_revision_flag, idxs):
pred = []
k = []
revision_flag = []
for idx in idxs:
pred.append(pred_res[idx])
k.append(y[idx])
revision_flag += list(all_revision_flag[idx])
revision_idx = np.where(np.array(revision_flag) != 0)[0]
candidate = self.revise_by_idx(pred, k, revision_idx)
return candidate

def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol):
all_revision_flag = reform_idx(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))]
candidate_size = []
while lefted_idxs:
idxs = []
idxs.append(lefted_idxs.pop(0))
max_candidate_idxs = []
found = False
for idx in range(-1, len(pred_res)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidate = self._revise_by_idxs(
pred_res, y, all_revision_flag, idxs
)
if len(candidate) == 0:
if len(idxs) > 1:
idxs.pop()
else:
if len(idxs) > len(max_candidate_idxs):
found = True
max_candidate_idxs = idxs.copy()
removed = [i for i in lefted_idxs if i in max_candidate_idxs]
if found:
candidate_size.append(len(removed) + 1)
lefted_idxs = [
i for i in lefted_idxs if i not in max_candidate_idxs
]
candidate_size.sort()
score = 0
import math

for i in range(0, len(candidate_size)):
score -= math.exp(-i) * candidate_size[i]
return score

def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)

kb = HED_prolog_KB(
pseudo_label_list=[1, 0, "+", "="],
pl_file="examples/hed/datasets/learn_add.pl",
)
reasoner = HED_Reasoner(kb)
consist_exs = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
]
inconsist_exs1 = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
[0, "+", 0, "=", 1],
]
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"]

print("HED_kb logic forward")
print(kb.logic_forward(consist_exs))
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2))
print()
print("HED_kb consist rule")
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules))
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules))
print()

print("HED_Reasoner abduce")
res = reasoner.abduce(
[[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs)
)
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1)
)
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2)
)
print(res)
print()

print("HED_Reasoner abduce rules")
abduced_rules = reasoner.abduce_rules(consist_exs)
print(abduced_rules)
data_samples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label

+ 49
- 0
abl/reasoning/search_based_kb.py View File

@@ -0,0 +1,49 @@
from abc import ABC, abstractmethod
from itertools import product
from typing import Any, List, Tuple, Union

import numpy

from ..structures import ListData
from .base_kb import BaseKB


class SearchBasedKB(BaseKB, ABC):
def __init__(
self,
pseudo_label_list: List,
) -> None:
super().__init__(pseudo_label_list)

@abstractmethod
def check_equal(self, data_sample: ListData, y: Any):
"""Placeholder for check_equal."""
pass

def revise_at_idx(
self,
data_sample: ListData,
revision_idx: Union[List, Tuple, numpy.ndarray],
):
candidates = []
abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx))
for c in abduce_c:
new_data_sample = data_sample.clone()
candidate = new_data_sample["pred_pseudo_label"][0].copy()
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
new_data_sample["pred_pseudo_label"][0] = candidate
if self.check_equal(new_data_sample, new_data_sample["Y"][0]):
candidates.append(candidate)
return candidates

# TODO: When the output is excessively long, use ellipses as a substitute.
    def __repr__(self):
        # Only report attributes that are actually defined on the KB itself.
        return (
            f"<{self.__class__.__name__}(\n"
            f"    pseudo_label_list: {self.pseudo_label_list!r}\n"
            f") at {hex(id(self))}>"
        )
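
A minimal usage sketch (not part of the diff) of how revise_at_idx interacts with check_equal. The toy subclass below mirrors the AddKB defined later in examples/mnist_add/mnist_add_kb.py; the field names pred_pseudo_label and Y follow the conventions used in this file.

# Hypothetical sketch: a SearchBasedKB whose rule is "the symbols sum to Y".
from abl.reasoning import SearchBasedKB
from abl.structures import ListData

class TinyAddKB(SearchBasedKB):
    def __init__(self, pseudo_label_list=list(range(10))):
        super().__init__(pseudo_label_list=pseudo_label_list)

    def logic_forward(self, data_sample):
        return sum(data_sample["pred_pseudo_label"][0])

    def check_equal(self, data_sample, y):
        return self.logic_forward(data_sample) == y

kb = TinyAddKB()
sample = ListData(pred_pseudo_label=[[1, 1]], Y=[8])
# Allow revising only position 1: candidates are all labels c with 1 + c == 8.
print(kb.revise_at_idx(sample, revision_idx=[1]))  # [[1, 7]]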

+ 2
- 0
abl/reasoning/search_engine/__init__.py View File

@@ -0,0 +1,2 @@
from .base_search_engine import BaseSearchEngine
from .bfs import BFS

+ 13
- 0
abl/reasoning/search_engine/base_search_engine.py View File

@@ -0,0 +1,13 @@
from abc import ABC, abstractmethod
from typing import List, Tuple, Union

import numpy

from ...structures import ListData


class BaseSearchEngine(ABC):
@abstractmethod
    def generator(self, data_sample: ListData) -> Union[List, Tuple, numpy.ndarray]:
"""Placeholder for the generator of revision_idx."""
pass

+ 28
- 0
abl/reasoning/search_engine/bfs.py View File

@@ -0,0 +1,28 @@
from itertools import combinations
from typing import List, Tuple, Union

import numpy

from ...structures import ListData
from .base_search_engine import BaseSearchEngine


class BFS(BaseSearchEngine):
def __init__(self) -> None:
pass

def generator(
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
) -> Union[List, Tuple, numpy.ndarray]:
symbol_num = data_sample["symbol_num"]
max_revision_num = min(max_revision_num, symbol_num)
real_end = max_revision_num
for revision_num in range(max_revision_num + 1):
if revision_num > real_end:
break

revision_idx_tuple = combinations(range(symbol_num), revision_num)
for revision_idx in revision_idx_tuple:
received = yield revision_idx
if received == "success":
real_end = min(symbol_num, revision_num + require_more_revision)
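
The generator above is meant to be driven as a coroutine: the caller pulls candidate revision_idx tuples and, once a consistent revision is found, sends "success" back so the search stops after revision_num + require_more_revision symbols. A rough driver sketch (hypothetical; the reasoner's actual loop may differ) that never reports success and therefore enumerates every index subset up to the revision limit:

# Hypothetical driver for BFS.generator; only the symbol_num metainfo is needed here.
from abl.reasoning.search_engine import BFS
from abl.structures import ListData

data_sample = ListData()
data_sample.set_metainfo(dict(symbol_num=3))

gen = BFS().generator(data_sample, max_revision_num=2, require_more_revision=0)
idxs = []
try:
    revision_idx = next(gen)
    while True:
        idxs.append(revision_idx)
        # A real caller would test revise_at_idx here and send "success"
        # when candidates are found; sending None keeps the search going.
        revision_idx = gen.send(None)
except StopIteration:
    pass
print(idxs)  # [(), (0,), (1,), (2,), (0, 1), (0, 2), (1, 2)]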

+ 42
- 0
abl/reasoning/search_engine/zoopt.py View File

@@ -0,0 +1,42 @@
from typing import List, Tuple, Union

import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter, Solution

from ...structures import ListData
from ..reasoner import ReasonerBase
from ..search_based_kb import SearchBasedKB
from .base_search_engine import BaseSearchEngine


class Zoopt(BaseSearchEngine):
def __init__(self, reasoner: ReasonerBase, kb: SearchBasedKB) -> None:
self.reasoner = reasoner
self.kb = kb

def score_func(self, data_sample: ListData, solution: Solution):
        revision_idx = np.where(np.array(solution.get_x()) != 0)[0]
candidates = self.kb.revise_at_idx(data_sample, revision_idx)
if len(candidates) > 0:
return np.min(self.reasoner._get_dist_list(data_sample, candidates))
else:
return data_sample["symbol_num"]

@staticmethod
def constraint(solution: Solution, max_revision_num: int):
x = solution.get_x()
        return max_revision_num - np.sum(x)

def generator(
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
) -> Union[List, Tuple, np.ndarray]:
symbol_num = data_sample["symbol_num"]
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
        objective = Objective(
            lambda solution: self.score_func(data_sample, solution),
            dim=dimension,
            constraint=lambda solution: self.constraint(solution, max_revision_num),
        )
parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
        solution = Opt.min(objective, parameter).get_x()
        # Yield the revision indices (non-zero positions), matching the BFS generator contract.
        yield np.where(np.array(solution) != 0)[0]

+ 2
- 0
abl/structures/__init__.py View File

@@ -0,0 +1,2 @@
from .base_data_element import BaseDataElement
from .list_data import ListData

+ 629
- 0
abl/structures/base_data_element.py View File

@@ -0,0 +1,629 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Any, Iterator, Optional, Tuple, Type, Union

import numpy as np
import torch


class BaseDataElement:
"""A base data interface that supports Tensor-like and dict-like
operations.

    Typical data elements refer to predicted results or ground truth labels
    on a task, such as predicted bboxes, instance masks, semantic
    segmentation masks, etc. Because ground truth labels and predicted results
    often have similar properties (for example, the predicted bboxes and the
    ground truth bboxes), MMEngine uses the same abstract data interface to
    encapsulate predicted results and ground truth labels, and it is recommended
    to use different naming conventions to distinguish them, such as using
    ``gt_instances`` and ``pred_instances`` to distinguish between labels and
    predicted results. Additionally, we distinguish data elements at instance
    level, pixel level, and label level. Each of these types has its own
    characteristics. Therefore, MMEngine defines the base class
    ``BaseDataElement``, and implements ``InstanceData``, ``PixelData``, and
    ``LabelData`` inheriting from ``BaseDataElement`` to represent different
    types of ground truth labels or predictions.

    Another common data element is sample data. A data sample consists of input
    data (such as an image) together with its annotations and predictions. In
    general, an image can have multiple types of annotations and/or predictions
    at the same time (for example, both pixel-level semantic segmentation
    annotations and instance-level detection bbox annotations). All labels and
    predictions of a training sample are often passed between Dataset, Model,
    Visualizer, and Evaluator components. In order to simplify the interface
    between components, we can treat them as one large data element and
    encapsulate them. Such data elements are generally called ``XXDataSample``
    in OpenMMLab. Therefore, similar to ``nn.Module``, a ``BaseDataElement``
    may hold other ``BaseDataElement`` instances as its attributes. Such a
    class generally encapsulates all the data of a sample, and its attributes
    are various types of data elements. For example, MMDetection uses a
    subclass of ``BaseDataElement`` to encapsulate all the labels and
    predictions of a sample in the algorithm library.

The attributes in ``BaseDataElement`` are divided into two parts,
the ``metainfo`` and the ``data`` respectively.

- ``metainfo``: Usually contains the
information about the image such as filename,
image_shape, pad_shape, etc. The attributes can be accessed or
modified by dict-like or object-like operations, such as
``.`` (for data access and modification), ``in``, ``del``,
``pop(str)``, ``get(str)``, ``metainfo_keys()``,
``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for
set or change key-value pairs in metainfo).

- ``data``: Annotations or model predictions are
stored. The attributes can be accessed or modified by
dict-like or object-like operations, such as
``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``,
``values()``, ``items()``. Users can also apply tensor-like
methods to all :obj:`torch.Tensor` in the ``data_fields``,
such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``,
``to_tensor()``, ``.detach()``.

Args:
metainfo (dict, optional): A dict contains the meta information
of single image, such as ``dict(img_shape=(512, 512, 3),
scale_factor=(1, 1, 1, 1))``. Defaults to None.
kwargs (dict, optional): A dict contains annotations of single image or
model predictions. Defaults to None.

Examples:
>>> import torch
>>> from mmengine.structures import BaseDataElement
>>> gt_instances = BaseDataElement()
>>> bboxes = torch.rand((5, 4))
>>> scores = torch.rand((5,))
>>> img_id = 0
>>> img_shape = (800, 1333)
>>> gt_instances = BaseDataElement(
... metainfo=dict(img_id=img_id, img_shape=img_shape),
... bboxes=bboxes, scores=scores)
>>> gt_instances = BaseDataElement(
... metainfo=dict(img_id=img_id, img_shape=(640, 640)))

>>> # new
>>> gt_instances1 = gt_instances.new(
... metainfo=dict(img_id=1, img_shape=(640, 640)),
... bboxes=torch.rand((5, 4)),
... scores=torch.rand((5,)))
>>> gt_instances2 = gt_instances1.new()

>>> # add and process property
>>> gt_instances = BaseDataElement()
>>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100)))
>>> assert 'img_shape' in gt_instances.metainfo_keys()
>>> assert 'img_shape' in gt_instances
>>> assert 'img_shape' not in gt_instances.keys()
>>> assert 'img_shape' in gt_instances.all_keys()
>>> print(gt_instances.img_shape)
(100, 100)
>>> gt_instances.scores = torch.rand((5,))
>>> assert 'scores' in gt_instances.keys()
>>> assert 'scores' in gt_instances
>>> assert 'scores' in gt_instances.all_keys()
>>> assert 'scores' not in gt_instances.metainfo_keys()
>>> print(gt_instances.scores)
tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876])
>>> gt_instances.bboxes = torch.rand((5, 4))
>>> assert 'bboxes' in gt_instances.keys()
>>> assert 'bboxes' in gt_instances
>>> assert 'bboxes' in gt_instances.all_keys()
>>> assert 'bboxes' not in gt_instances.metainfo_keys()
>>> print(gt_instances.bboxes)
tensor([[0.0900, 0.0424, 0.1755, 0.4469],
[0.8648, 0.0592, 0.3484, 0.0913],
[0.5808, 0.1909, 0.6165, 0.7088],
[0.5490, 0.4209, 0.9416, 0.2374],
[0.3652, 0.1218, 0.8805, 0.7523]])

>>> # delete and change property
>>> gt_instances = BaseDataElement(
... metainfo=dict(img_id=0, img_shape=(640, 640)),
... bboxes=torch.rand((6, 4)), scores=torch.rand((6,)))
>>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280)))
>>> gt_instances.img_shape # (1280, 1280)
>>> gt_instances.bboxes = gt_instances.bboxes * 2
>>> gt_instances.get('img_shape', None) # (1280, 1280)
>>> gt_instances.get('bboxes', None) # 6x4 tensor
>>> del gt_instances.img_shape
>>> del gt_instances.bboxes
>>> assert 'img_shape' not in gt_instances
>>> assert 'bboxes' not in gt_instances
>>> gt_instances.pop('img_shape', None) # None
>>> gt_instances.pop('bboxes', None) # None

>>> # Tensor-like
>>> cuda_instances = gt_instances.cuda()
>>> cuda_instances = gt_instances.to('cuda:0')
>>> cpu_instances = cuda_instances.cpu()
>>> cpu_instances = cuda_instances.to('cpu')
>>> fp16_instances = cuda_instances.to(
... device=None, dtype=torch.float16, non_blocking=False,
... copy=False, memory_format=torch.preserve_format)
>>> cpu_instances = cuda_instances.detach()
>>> np_instances = cpu_instances.numpy()

>>> # print
>>> metainfo = dict(img_shape=(800, 1196, 3))
>>> gt_instances = BaseDataElement(
... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3]))
>>> sample = BaseDataElement(metainfo=metainfo,
... gt_instances=gt_instances)
>>> print(sample)
<BaseDataElement(
META INFORMATION
img_shape: (800, 1196, 3)
DATA FIELDS
gt_instances: <BaseDataElement(
META INFORMATION
img_shape: (800, 1196, 3)
DATA FIELDS
det_labels: tensor([0, 1, 2, 3])
) at 0x7f0ec5eadc70>
) at 0x7f0fea49e130>

>>> # inheritance
>>> class DetDataSample(BaseDataElement):
... @property
... def proposals(self):
... return self._proposals
... @proposals.setter
... def proposals(self, value):
... self.set_field(value, '_proposals', dtype=BaseDataElement)
... @proposals.deleter
... def proposals(self):
... del self._proposals
... @property
... def gt_instances(self):
... return self._gt_instances
... @gt_instances.setter
... def gt_instances(self, value):
... self.set_field(value, '_gt_instances',
... dtype=BaseDataElement)
... @gt_instances.deleter
... def gt_instances(self):
... del self._gt_instances
... @property
... def pred_instances(self):
... return self._pred_instances
... @pred_instances.setter
... def pred_instances(self, value):
... self.set_field(value, '_pred_instances',
... dtype=BaseDataElement)
... @pred_instances.deleter
... def pred_instances(self):
... del self._pred_instances
>>> det_sample = DetDataSample()
>>> proposals = BaseDataElement(bboxes=torch.rand((5, 4)))
>>> det_sample.proposals = proposals
>>> assert 'proposals' in det_sample
>>> assert det_sample.proposals == proposals
>>> del det_sample.proposals
>>> assert 'proposals' not in det_sample
>>> with self.assertRaises(AssertionError):
... det_sample.proposals = torch.rand((5, 4))
"""

def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:
self._metainfo_fields: set = set()
self._data_fields: set = set()

if metainfo is not None:
self.set_metainfo(metainfo=metainfo)
if kwargs:
self.set_data(kwargs)

def set_metainfo(self, metainfo: dict) -> None:
"""Set or change key-value pairs in ``metainfo_field`` by parameter
``metainfo``.

Args:
metainfo (dict): A dict contains the meta information
of image, such as ``img_shape``, ``scale_factor``, etc.
"""
assert isinstance(
metainfo, dict
), f"metainfo should be a ``dict`` but got {type(metainfo)}"
meta = copy.deepcopy(metainfo)
for k, v in meta.items():
self.set_field(name=k, value=v, field_type="metainfo", dtype=None)

def set_data(self, data: dict) -> None:
"""Set or change key-value pairs in ``data_field`` by parameter
``data``.

Args:
data (dict): A dict contains annotations of image or
model predictions.
"""
assert isinstance(data, dict), f"data should be a `dict` but got {data}"
for k, v in data.items():
# Use `setattr()` rather than `self.set_field` to allow `set_data`
# to set property method.
setattr(self, k, v)

def update(self, instance: "BaseDataElement") -> None:
"""The update() method updates the BaseDataElement with the elements
from another BaseDataElement object.

        Args:
            instance (BaseDataElement): Another BaseDataElement object used to
                update the current object.
"""
assert isinstance(
instance, BaseDataElement
), f"instance should be a `BaseDataElement` but got {type(instance)}"
self.set_metainfo(dict(instance.metainfo_items()))
self.set_data(dict(instance.items()))

def new(self, *, metainfo: Optional[dict] = None, **kwargs) -> "BaseDataElement":
"""Return a new data element with same type. If ``metainfo`` and
``data`` are None, the new data element will have same metainfo and
data. If metainfo or data is not None, the new result will overwrite it
with the input value.

Args:
metainfo (dict, optional): A dict contains the meta information
of image, such as ``img_shape``, ``scale_factor``, etc.
Defaults to None.
kwargs (dict): A dict contains annotations of image or
model predictions.

Returns:
BaseDataElement: A new data element with same type.
"""
new_data = self.__class__()

if metainfo is not None:
new_data.set_metainfo(metainfo)
else:
new_data.set_metainfo(dict(self.metainfo_items()))
if kwargs:
new_data.set_data(kwargs)
else:
new_data.set_data(dict(self.items()))
return new_data

def clone(self):
"""Deep copy the current data element.

Returns:
BaseDataElement: The copy of current data element.
"""
clone_data = self.__class__()
clone_data.set_metainfo(dict(self.metainfo_items()))
clone_data.set_data(dict(self.items()))
return clone_data

def keys(self) -> list:
"""
Returns:
list: Contains all keys in data_fields.
"""
# We assume that the name of the attribute related to property is
# '_' + the name of the property. We use this rule to filter out
# private keys.
# TODO: Use a more robust way to solve this problem
private_keys = {
"_" + key
for key in self._data_fields
if isinstance(getattr(type(self), key, None), property)
}
return list(self._data_fields - private_keys)

def metainfo_keys(self) -> list:
"""
Returns:
list: Contains all keys in metainfo_fields.
"""
return list(self._metainfo_fields)

def values(self) -> list:
"""
Returns:
list: Contains all values in data.
"""
return [getattr(self, k) for k in self.keys()]

def metainfo_values(self) -> list:
"""
Returns:
list: Contains all values in metainfo.
"""
return [getattr(self, k) for k in self.metainfo_keys()]

def all_keys(self) -> list:
"""
Returns:
list: Contains all keys in metainfo and data.
"""
return self.metainfo_keys() + self.keys()

def all_values(self) -> list:
"""
Returns:
list: Contains all values in metainfo and data.
"""
return self.metainfo_values() + self.values()

def all_items(self) -> Iterator[Tuple[str, Any]]:
"""
Returns:
iterator: An iterator object whose element is (key, value) tuple
pairs for ``metainfo`` and ``data``.
"""
for k in self.all_keys():
yield (k, getattr(self, k))

def items(self) -> Iterator[Tuple[str, Any]]:
"""
Returns:
iterator: An iterator object whose element is (key, value) tuple
pairs for ``data``.
"""
for k in self.keys():
yield (k, getattr(self, k))

def metainfo_items(self) -> Iterator[Tuple[str, Any]]:
"""
Returns:
iterator: An iterator object whose element is (key, value) tuple
pairs for ``metainfo``.
"""
for k in self.metainfo_keys():
yield (k, getattr(self, k))

@property
def metainfo(self) -> dict:
"""dict: A dict contains metainfo of current data element."""
return dict(self.metainfo_items())

def __setattr__(self, name: str, value: Any):
"""setattr is only used to set data."""
if name in ("_metainfo_fields", "_data_fields"):
if not hasattr(self, name):
super().__setattr__(name, value)
else:
raise AttributeError(
f"{name} has been used as a "
"private attribute, which is immutable."
)
else:
self.set_field(name=name, value=value, field_type="data", dtype=None)

def __delattr__(self, item: str):
"""Delete the item in dataelement.

Args:
item (str): The key to delete.
"""
if item in ("_metainfo_fields", "_data_fields"):
raise AttributeError(
f"{item} has been used as a " "private attribute, which is immutable."
)
super().__delattr__(item)
if item in self._metainfo_fields:
self._metainfo_fields.remove(item)
elif item in self._data_fields:
self._data_fields.remove(item)

# dict-like methods
__delitem__ = __delattr__

def get(self, key, default=None) -> Any:
"""Get property in data and metainfo as the same as python."""
# Use `getattr()` rather than `self.__dict__.get()` to allow getting
# properties.
return getattr(self, key, default)

def pop(self, *args) -> Any:
"""Pop property in data and metainfo as the same as python."""
        assert len(args) < 3, "``pop`` got more than 2 arguments"
name = args[0]
if name in self._metainfo_fields:
self._metainfo_fields.remove(args[0])
return self.__dict__.pop(*args)

elif name in self._data_fields:
self._data_fields.remove(args[0])
return self.__dict__.pop(*args)

# with default value
elif len(args) == 2:
return args[1]
else:
# don't just use 'self.__dict__.pop(*args)' for only popping key in
# metainfo or data
raise KeyError(f"{args[0]} is not contained in metainfo or data")

def __contains__(self, item: str) -> bool:
"""Whether the item is in dataelement.

Args:
item (str): The key to inquire.
"""
return item in self._data_fields or item in self._metainfo_fields

def set_field(
self,
value: Any,
name: str,
dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
field_type: str = "data",
) -> None:
"""Special method for set union field, used as property.setter
functions."""
assert field_type in ["metainfo", "data"]
if dtype is not None:
assert isinstance(
value, dtype
), f"{value} should be a {dtype} but got {type(value)}"

if field_type == "metainfo":
if name in self._data_fields:
raise AttributeError(
f"Cannot set {name} to be a field of metainfo "
f"because {name} is already a data field"
)
self._metainfo_fields.add(name)
else:
if name in self._metainfo_fields:
raise AttributeError(
f"Cannot set {name} to be a field of data "
f"because {name} is already a metainfo field"
)
self._data_fields.add(name)
super().__setattr__(name, value)

# Tensor-like methods
def to(self, *args, **kwargs) -> "BaseDataElement":
"""Apply same name function to all tensors in data_fields."""
new_data = self.new()
for k, v in self.items():
if hasattr(v, "to"):
v = v.to(*args, **kwargs)
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def cpu(self) -> "BaseDataElement":
"""Convert all tensors to CPU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.cpu()
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def cuda(self) -> "BaseDataElement":
"""Convert all tensors to GPU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.cuda()
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def npu(self) -> "BaseDataElement":
"""Convert all tensors to NPU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.npu()
data = {k: v}
new_data.set_data(data)
return new_data

def mlu(self) -> "BaseDataElement":
"""Convert all tensors to MLU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.mlu()
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def detach(self) -> "BaseDataElement":
"""Detach all tensors in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.detach()
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def numpy(self) -> "BaseDataElement":
"""Convert all tensors to np.ndarray in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.detach().cpu().numpy()
data = {k: v}
new_data.set_data(data)
return new_data

def to_tensor(self) -> "BaseDataElement":
"""Convert all np.ndarray to tensor in data."""
new_data = self.new()
for k, v in self.items():
data = {}
if isinstance(v, np.ndarray):
v = torch.from_numpy(v)
data[k] = v
elif isinstance(v, BaseDataElement):
v = v.to_tensor()
data[k] = v
new_data.set_data(data)
return new_data

def to_dict(self) -> dict:
"""Convert BaseDataElement to dict."""
return {
k: v.to_dict() if isinstance(v, BaseDataElement) else v
for k, v in self.all_items()
}

def __repr__(self) -> str:
"""Represent the object."""

def _addindent(s_: str, num_spaces: int) -> str:
"""This func is modified from `pytorch` https://github.com/pytorch/
pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu
les/module.py#L29.

Args:
s_ (str): The string to add spaces.
num_spaces (int): The num of space to add.

Returns:
str: The string after add indent.
"""
s = s_.split("\n")
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s) # type: ignore
s = first + "\n" + s # type: ignore
return s # type: ignore

def dump(obj: Any) -> str:
"""Represent the object.

Args:
obj (Any): The obj to represent.

Returns:
str: The represented str.
"""
_repr = ""
if isinstance(obj, dict):
for k, v in obj.items():
_repr += f"\n{k}: {_addindent(dump(v), 4)}"
elif isinstance(obj, BaseDataElement):
_repr += "\n\n META INFORMATION"
metainfo_items = dict(obj.metainfo_items())
_repr += _addindent(dump(metainfo_items), 4)
_repr += "\n\n DATA FIELDS"
items = dict(obj.items())
_repr += _addindent(dump(items), 4)
classname = obj.__class__.__name__
_repr = f"<{classname}({_repr}\n) at {hex(id(obj))}>"
else:
_repr += repr(obj)
return _repr

return dump(self)

+ 321
- 0
abl/structures/list_data.py View File

@@ -0,0 +1,321 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from collections.abc import Sized
from typing import Any, List, Union

import numpy as np
import torch

from ..utils import flatten as flatten_list
from ..utils import to_hashable
from .base_data_element import BaseDataElement

BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]

IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray]


# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
class ListData(BaseDataElement):
"""Data structure for instance-level annotations or predictions.

    Subclass of :class:`BaseDataElement`. All values in `data_fields`
    should have the same length. This design refers to
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
    ListData also supports extra functions ``index``, ``slice`` and ``cat`` for data fields. The type of a value
    in a data field can be a base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
    or a customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.

Examples:
>>> # custom data structure
>>> class TmpObject:
... def __init__(self, tmp) -> None:
... assert isinstance(tmp, list)
... self.tmp = tmp
... def __len__(self):
... return len(self.tmp)
... def __getitem__(self, item):
... if isinstance(item, int):
... if item >= len(self) or item < -len(self): # type:ignore
... raise IndexError(f'Index {item} out of range!')
... else:
... # keep the dimension
... item = slice(item, None, len(self))
... return TmpObject(self.tmp[item])
... @staticmethod
... def cat(tmp_objs):
... assert all(isinstance(results, TmpObject) for results in tmp_objs)
... if len(tmp_objs) == 1:
... return tmp_objs[0]
... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
... tmp_list = list(itertools.chain(*tmp_list))
... new_data = TmpObject(tmp_list)
... return new_data
... def __repr__(self):
... return str(self.tmp)
>>> from mmengine.structures import ListData
>>> import numpy as np
>>> import torch
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
>>> instance_data = ListData(metainfo=img_meta)
>>> 'img_shape' in instance_data
True
>>> instance_data.det_labels = torch.LongTensor([2, 3])
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
>>> instance_data.bboxes = torch.rand((2, 4))
>>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
>>> len(instance_data)
2
>>> print(instance_data)
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2, 3])
det_scores: tensor([0.8000, 0.7000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7fb492de6280>
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
>>> sorted_results.det_scores
tensor([0.7000, 0.8000])
>>> print(instance_data[instance_data.det_scores > 0.75])
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2])
det_scores: tensor([0.8000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
polygons: [[1, 2, 3, 4]]
) at 0x7f64ecf0ec40>
>>> print(instance_data[instance_data.det_scores > 1])
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([], dtype=torch.int64)
det_scores: tensor([])
bboxes: tensor([], size=(0, 4))
polygons: []
) at 0x7f660a6a7f70>
>>> print(instance_data.cat([instance_data, instance_data]))
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2, 3, 2, 3])
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263],
[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7f203542feb0>
"""

def __setattr__(self, name: str, value: list):
"""setattr is only used to set data.

        The value must be a list and have the same length as the
        existing data fields of `ListData`.
        """
if name in ("_metainfo_fields", "_data_fields"):
if not hasattr(self, name):
super().__setattr__(name, value)
else:
raise AttributeError(
f"{name} has been used as a "
"private attribute, which is immutable."
)

else:
assert isinstance(value, list), "value must be of type `list`"

if len(self) > 0:
assert len(value) == len(self), (
"The length of "
f"values {len(value)} is "
"not consistent with "
"the length of this "
":obj:`ListData` "
f"{len(self)}"
)
super().__setattr__(name, value)

__setitem__ = __setattr__

def __getitem__(self, item: IndexType) -> "ListData":
"""
Args:
item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
:obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
Get the corresponding values according to item.

Returns:
:obj:`ListData`: Corresponding values.
"""
assert isinstance(item, IndexType.__args__)
if isinstance(item, list):
item = np.array(item)
if isinstance(item, np.ndarray):
# The default int type of numpy is platform dependent, int32 for
# windows and int64 for linux. `torch.Tensor` requires the index
# should be int64, therefore we simply convert it to int64 here.
# More details in https://github.com/numpy/numpy/issues/9464
item = item.astype(np.int64) if item.dtype == np.int32 else item
item = torch.from_numpy(item)

if isinstance(item, str):
return getattr(self, item)

if isinstance(item, int):
if item >= len(self) or item < -len(self): # type:ignore
raise IndexError(f"Index {item} out of range!")
else:
# keep the dimension
item = slice(item, None, len(self))

new_data = self.__class__(metainfo=self.metainfo)
if isinstance(item, torch.Tensor):
            assert item.dim() == 1, (
                "Only getting values along the first dimension is supported."
            )
if isinstance(item, BoolTypeTensor.__args__):
assert len(item) == len(self), (
"The shape of the "
"input(BoolTensor) "
f"{len(item)} "
"does not match the shape "
"of the indexed tensor "
"in results_field "
f"{len(self)} at "
"first dimension."
)

for k, v in self.items():
if isinstance(v, torch.Tensor):
new_data[k] = v[item]
elif isinstance(v, np.ndarray):
new_data[k] = v[item.cpu().numpy()]
elif isinstance(v, (str, list, tuple)) or (
hasattr(v, "__getitem__") and hasattr(v, "cat")
):
# convert to indexes from BoolTensor
if isinstance(item, BoolTypeTensor.__args__):
indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist()
else:
indexes = item.cpu().numpy().tolist()
slice_list = []
if indexes:
for index in indexes:
slice_list.append(slice(index, None, len(v)))
else:
slice_list.append(slice(None, 0, None))
r_list = [v[s] for s in slice_list]
if isinstance(v, (str, list, tuple)):
new_value = r_list[0]
for r in r_list[1:]:
new_value = new_value + r
else:
new_value = v.cat(r_list)
new_data[k] = new_value
else:
raise ValueError(
f"The type of `{k}` is `{type(v)}`, which has no "
"attribute of `cat`, so it does not "
"support slice with `bool`"
)

else:
# item is a slice
for k, v in self.items():
new_data[k] = v[item]
return new_data # type:ignore

@staticmethod
def cat(instances_list: List["ListData"]) -> "ListData":
"""Concat the instances of all :obj:`ListData` in the list.

Note: To ensure that cat returns as expected, make sure that
all elements in the list must have exactly the same keys.

Args:
instances_list (list[:obj:`ListData`]): A list
of :obj:`ListData`.

Returns:
:obj:`ListData`
"""
assert all(isinstance(results, ListData) for results in instances_list)
assert len(instances_list) > 0
if len(instances_list) == 1:
return instances_list[0]

# metainfo and data_fields must be exactly the
# same for each element to avoid exceptions.
field_keys_list = [instances.all_keys() for instances in instances_list]
assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len(
set(itertools.chain(*field_keys_list))
) == len(field_keys_list[0]), (
"There are different keys in "
"`instances_list`, which may "
"cause the cat operation "
"to fail. Please make sure all "
"elements in `instances_list` "
"have the exact same key."
)

new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo)
for k in instances_list[0].keys():
values = [results[k] for results in instances_list]
v0 = values[0]
if isinstance(v0, torch.Tensor):
new_values = torch.cat(values, dim=0)
elif isinstance(v0, np.ndarray):
new_values = np.concatenate(values, axis=0)
elif isinstance(v0, (str, list, tuple)):
new_values = v0[:]
for v in values[1:]:
new_values += v
elif hasattr(v0, "cat"):
new_values = v0.cat(values)
else:
raise ValueError(
f"The type of `{k}` is `{type(v0)}` which has no "
"attribute of `cat`"
)
new_data[k] = new_values
return new_data # type:ignore

def flatten(self, item: IndexType) -> List:
"""Flatten self[item].

Returns:
list: Flattened data fields.
"""
        return flatten_list(self[item])

    def elements_num(self, item: IndexType) -> int:
        """int: The number of elements in self[item]."""
        return len(self.flatten(item))

    def to_tuple(self, item: IndexType) -> tuple:
        """tuple: The data fields in self[item] converted to tuple."""
        return to_hashable(self[item])

    def __len__(self) -> int:
"""int: The length of ListData."""
if len(self._data_fields) > 0:
one_element = next(iter(self._data_fields))
return len(getattr(self, one_element))
# return len(self.values()[0])
else:
return 0
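
A short sketch of how ListData is typically used with ABL-style fields (the field names and the exact outputs below are illustrative; flatten and to_hashable come from abl.utils and are not shown in this diff):

from abl.structures import ListData

sample = ListData(pred_pseudo_label=[[1, 7], [3, 9]], Y=[8, 12])
print(len(sample))                               # 2 examples
print(sample.flatten("pred_pseudo_label"))       # e.g. [1, 7, 3, 9]
print(sample.elements_num("pred_pseudo_label"))  # 4
print(sample.to_tuple("pred_pseudo_label"))      # hashable form, e.g. ((1, 7), (3, 9))
print(sample[0].pred_pseudo_label)               # [[1, 7]]  (dimension kept when indexing)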

+ 2
- 1
abl/utils/__init__.py View File

@@ -1,2 +1,3 @@
from .cache import Cache
from .logger import ABLLogger, print_log
from .utils import *
from .utils import *

+ 112
- 0
abl/utils/cache.py View File

@@ -0,0 +1,112 @@
import pickle
from os import PathLike
from pathlib import Path
from typing import Callable, Generic, Hashable, TypeVar, Union

from .logger import print_log

K = TypeVar("K")
T = TypeVar("T")
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields


class Cache(Generic[K, T]):
def __init__(
self,
func: Callable[[K], T],
cache: bool,
cache_file: Union[None, str, PathLike],
key_func: Callable[[K], Hashable] = lambda x: x,
max_size: int = 4096,
):
"""Create cache

:param func: Function this cache evaluates
:param cache: If true, do in memory caching.
:param cache_root: If not None, cache to files at the provided path.
:param key_func: Convert the key into a hashable object if needed
"""
self.func = func
self.key_func = key_func
self.cache = cache
if cache is True or cache_file is not None:
print_log("Caching is activated", logger="current")
self._init_cache(cache_file, max_size)
self.first = self.get_from_dict
else:
self.first = self.func

def __getitem__(self, item: K, *args) -> T:
return self.first(item, *args)

    def invalidate(self):
        """Invalidate the entire in-memory cache."""
        self.cache_dict.clear()
        self.full = False
        # Reset the circular doubly linked list used for LRU bookkeeping.
        self.root = []
        self.root[:] = [self.root, self.root, None, None]

def _init_cache(self, cache_file, max_size):
self.cache = True
self.cache_dict = dict()

self.hits, self.misses, self.maxsize = 0, 0, max_size
self.full = False
self.root = [] # root of the circular doubly linked list
self.root[:] = [self.root, self.root, None, None]

if cache_file is not None:
with open(cache_file, "rb") as f:
cache_dict_from_file = pickle.load(f)
self.maxsize += len(cache_dict_from_file)
print_log(
f"Max size of the cache has been enlarged to {self.maxsize}.", logger="current"
)
for cache_key, result in cache_dict_from_file.items():
last = self.root[PREV]
link = [last, self.root, cache_key, result]
last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link

def get(self, item: K, *args) -> T:
return self.first(item, *args)

def get_from_dict(self, item: K, *args) -> T:
"""Implements dict based cache."""
cache_key = (self.key_func(item), *args)
link = self.cache_dict.get(cache_key)
if link is not None:
# Move the link to the front of the circular queue
link_prev, link_next, _key, result = link
link_prev[NEXT] = link_next
link_next[PREV] = link_prev
last = self.root[PREV]
last[NEXT] = self.root[PREV] = link
link[PREV] = last
link[NEXT] = self.root
self.hits += 1
return result
self.misses += 1

result = self.func(item, *args)

if self.full:
# Use the old root to store the new key and result.
oldroot = self.root
oldroot[KEY] = cache_key
oldroot[RESULT] = result
# Empty the oldest link and make it the new root.
self.root = oldroot[NEXT]
oldkey = self.root[KEY]
oldresult = self.root[RESULT]
self.root[KEY] = self.root[RESULT] = None
# Now update the cache dictionary.
del self.cache_dict[oldkey]
self.cache_dict[cache_key] = oldroot
else:
# Put result in a new link at the front of the queue.
last = self.root[PREV]
link = [last, self.root, cache_key, result]
last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link
if isinstance(self.maxsize, int):
self.full = len(self.cache_dict) >= self.maxsize
return result
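
A brief, hypothetical usage sketch of Cache (the wrapped function below is made up for illustration; cache_file would point to a pickled dict of precomputed results):

from abl.utils import Cache

def slow_square(x):
    return x * x

cache = Cache(func=slow_square, cache=True, cache_file=None, max_size=2)
print(cache.get(3))   # 9, computed (miss)
print(cache.get(3))   # 9, served from the LRU dict (hit)
print(cache.get(4))   # 16 (miss); the cache is now full
print(cache.get(5))   # 25 (miss); the least recently used key, 3, is evicted
print(cache.hits, cache.misses)  # 1 3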

+ 2
- 1
abl/utils/utils.py View File

@@ -1,6 +1,7 @@
import numpy as np
from itertools import chain

import numpy as np


def flatten(nested_list):
"""


+ 1
- 2
docs/conf.py View File

@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-

import sys
import os
import re
import sys

if not 'READTHEDOCS' in os.environ:
sys.path.insert(0, os.path.abspath('..'))
@@ -11,7 +11,6 @@ sys.path.append(os.path.abspath('./ABL/'))
# from sphinx.locale import _
from sphinx_rtd_theme import __version__


project = u'ABL'
slug = re.sub(r'\W+', '-', project.lower())
author = u'Yu-Xuan Huang, Wen-Chao Hu, En-Hao Gao'


+ 5
- 4
examples/hed/datasets/get_hed.py View File

@@ -1,11 +1,12 @@
import os
import cv2
import torch
import torchvision
import pickle
import numpy as np
import random
from collections import defaultdict

import cv2
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision.transforms import transforms



+ 6
- 6
examples/hed/hed_bridge.py View File

@@ -1,18 +1,18 @@
import os
from collections import defaultdict

import torch
from torch.utils.data import DataLoader

from abl.reasoning import ReasonerBase
from abl.learning import ABLModel, BasicNN
from abl.bridge import SimpleBridge
from abl.evaluation import BaseMetric
from abl.dataset import BridgeDataset, RegressionDataset
from abl.evaluation import BaseMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import ReasonerBase
from abl.utils import print_log

from examples.hed.utils import gen_mappings, InfiniteSampler
from examples.models.nn import SymbolNetAutoencoder
from examples.hed.datasets.get_hed import get_pretrain_data
from examples.hed.utils import InfiniteSampler, gen_mappings
from examples.models.nn import SymbolNetAutoencoder


class HEDBridge(SimpleBridge):


+ 2
- 2
examples/hed/hed_example.ipynb View File

@@ -12,7 +12,7 @@
"\n",
"from abl.reasoning import ReasonerBase, prolog_KB\n",
"from abl.learning import BasicNN, ABLModel\n",
"from abl.evaluation import SymbolMetric, ABLMetric\n",
"from abl.evaluation import SymbolMetric, SemanticsMetric\n",
"from abl.utils import ABLLogger, reform_idx\n",
"\n",
"from examples.hed.hed_bridge import HEDBridge\n",
@@ -206,7 +206,7 @@
"outputs": [],
"source": [
"# Add metric\n",
"metric = [SymbolMetric(prefix=\"hed\"), ABLMetric(prefix=\"hed\")]"
"metric = [SymbolMetric(prefix=\"hed\"), SemanticsMetric(prefix=\"hed\")]"
]
},
{


+ 1
- 1
examples/hed/utils.py View File

@@ -1,6 +1,6 @@
import numpy as np
import torch
import torch.nn as nn
import numpy as np
import torch.utils.data.sampler as sampler




+ 5
- 5
examples/hwf/datasets/get_hwf.py View File

@@ -1,10 +1,10 @@
import os
import json
import os.path as osp

from PIL import Image
from torchvision.transforms import transforms

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
CURRENT_DIR = osp.abspath(osp.dirname(__file__))

img_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (1,))]
@@ -15,7 +15,7 @@ def get_data(file, get_pseudo_label):
X, Y = [], []
if get_pseudo_label:
Z = []
img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/")
img_dir = osp.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/")
with open(file) as f:
data = json.load(f)
for idx in range(len(data)):
@@ -40,8 +40,8 @@ def get_data(file, get_pseudo_label):

def get_hwf(train=True, get_gt_pseudo_label=False):
if train:
file = os.path.join(CURRENT_DIR, "data/expr_train.json")
file = osp.join(CURRENT_DIR, "data/expr_train.json")
else:
file = os.path.join(CURRENT_DIR, "data/expr_test.json")
file = osp.join(CURRENT_DIR, "data/expr_test.json")

return get_data(file, get_gt_pseudo_label)

+ 14
- 45
examples/hwf/hwf_example.ipynb View File

@@ -6,19 +6,19 @@
"metadata": {},
"outputs": [],
"source": [
"import os.path as osp\n",
"\n",
"import torch\n",
"import numpy as np\n",
"import torch.nn as nn\n",
"import os.path as osp\n",
"\n",
"from abl.reasoning import ReasonerBase, KBBase\n",
"from abl.learning import BasicNN, ABLModel\n",
"from abl.bridge import SimpleBridge\n",
"from abl.evaluation import SymbolMetric, SemanticsMetric\n",
"from abl.evaluation import SemanticsMetric, SymbolMetric\n",
"from abl.learning import ABLModel, BasicNN\n",
"from abl.reasoning import ReasonerBase\n",
"from abl.utils import ABLLogger, print_log\n",
"\n",
"from examples.models.nn import SymbolNet\n",
"from datasets.get_hwf import get_hwf"
"from examples.hwf.datasets.get_hwf import get_hwf\n",
"from examples.hwf.hwf_kb import HWF_KB\n",
"from examples.models.nn import SymbolNet"
]
},
{
@@ -50,37 +50,8 @@
"outputs": [],
"source": [
"# Initialize knowledge base and abducer\n",
"class HWF_KB(KBBase):\n",
" def __init__(\n",
" self, \n",
" pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n",
" prebuild_GKB=False,\n",
" GKB_len_list=[1, 3, 5, 7],\n",
" max_err=1e-3,\n",
" use_cache=True\n",
" ):\n",
" super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n",
"\n",
" def _valid_candidate(self, formula):\n",
" if len(formula) % 2 == 0:\n",
" return False\n",
" for i in range(len(formula)):\n",
" if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:\n",
" return False\n",
" if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:\n",
" return False\n",
" return True\n",
"\n",
" def logic_forward(self, formula):\n",
" if not self._valid_candidate(formula):\n",
" return np.inf\n",
" mapping = {str(i): str(i) for i in range(1, 10)}\n",
" mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})\n",
" formula = [mapping[f] for f in formula]\n",
" return eval(''.join(formula))\n",
"\n",
"kb = HWF_KB(prebuild_GKB=True)\n",
"abducer = ReasonerBase(kb, dist_func='confidence')"
"kb = HWF_KB()\n",
"abducer = ReasonerBase(kb, dist_func=\"confidence\")"
]
},
{
@@ -117,10 +88,8 @@
" criterion=criterion,\n",
" optimizer=optimizer,\n",
" device=device,\n",
" save_interval=1,\n",
" save_dir=weights_dir,\n",
" batch_size=128,\n",
" num_epochs=3,\n",
" num_epochs=1,\n",
")"
]
},
@@ -131,7 +100,7 @@
"outputs": [],
"source": [
"# Initialize ABL model\n",
"# The main function of the ABL model is to serialize data and \n",
"# The main function of the ABL model is to serialize data and\n",
"# provide a unified interface for different machine learning models\n",
"model = ABLModel(base_model)"
]
@@ -151,7 +120,7 @@
"outputs": [],
"source": [
"# Add metric\n",
"metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(prefix=\"hwf\")]"
"metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(kb=kb, prefix=\"hwf\")]"
]
},
{
@@ -204,7 +173,7 @@
"metadata": {},
"outputs": [],
"source": [
"bridge.train(train_data, epochs=3, batch_size=1000)\n",
"bridge.train(train_data, loops=5, segment_size=1000, save_interval=1, save_dir=weights_dir)\n",
"bridge.test(test_data)"
]
}


+ 129
- 0
examples/hwf/hwf_kb.py View File

@@ -0,0 +1,129 @@
import bisect
from collections import defaultdict
from itertools import product
from multiprocessing import Pool
from typing import Any, Hashable, List

import numpy as np

from abl.reasoning import GroundKB
from abl.structures import ListData
from abl.utils import hamming_dist


class HWF_KB(GroundKB):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "times", "div"],
GKB_len_list=[1, 3, 5, 7],
max_err=1e-10,
):
self.GKB_len_list = GKB_len_list
self.max_err = max_err
self.label2evaluable = {str(i): str(i) for i in range(1, 10)}
self.label2evaluable.update({"+": "+", "-": "-", "times": "*", "div": "/"})
super().__init__(pseudo_label_list)

def logic_forward(self, data_sample: ListData):
if not self._valid_candidate(data_sample):
return None
formula = data_sample["pred_pseudo_label"][0]
formula = [self.label2evaluable[f] for f in formula]
data_sample["Y"] = [eval("".join(formula))]
return data_sample["Y"][0]

def check_equal(self, data_sample: ListData, y: Any):
if not self._valid_candidate(data_sample):
return False
formula = data_sample["pred_pseudo_label"][0]
formula = [self.label2evaluable[f] for f in formula]
return abs(eval("".join(formula)) - y) < self.max_err

def construct_base(self) -> dict:
X, Y = [], []
for length in self.GKB_len_list:
arg_list = []
for pre_x in self.pseudo_label_list:
post_x_it = product(self.pseudo_label_list, repeat=length - 1)
arg_list.append((pre_x, post_x_it))
with Pool(processes=len(arg_list)) as pool:
ret_list = pool.map(self._get_XY_list, arg_list)
for XY_list in ret_list:
if len(XY_list) == 0:
continue
part_X, part_Y = zip(*XY_list)
X.extend(part_X)
Y.extend(part_Y)
if Y and isinstance(Y[0], (int, float)):
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
GKB = {}
for x, y in zip(X, Y):
GKB.setdefault(len(x), defaultdict(list))[y].append(x)
return GKB

@staticmethod
def get_key(data_sample: ListData) -> Hashable:
return (data_sample["symbol_num"], data_sample["Y"][0])

def key2candidates(self, key: Hashable) -> List[List[Any]]:
equation_len, y = key
if self.max_err == 0:
return self.GKB[equation_len][y]
else:
potential_candidates = self.GKB[equation_len]
key_list = list(potential_candidates.keys())
key_idx = bisect.bisect_left(key_list, y)

all_candidates = []
for idx in range(key_idx - 1, -1, -1):
k = key_list[idx]
if abs(k - y) <= self.max_err:
all_candidates.extend(potential_candidates[k])
else:
break

for idx in range(key_idx, len(key_list)):
k = key_list[idx]
if abs(k - y) <= self.max_err:
all_candidates.extend(potential_candidates[k])
else:
break
return all_candidates

def filter_candidates(
self,
data_sample: ListData,
candidates: List[List[Any]],
max_revision_num: int,
require_more_revision: int = 0,
) -> List[List[Any]]:
cost_list = hamming_dist(data_sample["pred_pseudo_label"][0], candidates)
min_revision_num = np.min(cost_list)
revision_num = min(max_revision_num, min_revision_num + require_more_revision)
idxs = np.where(cost_list <= revision_num)[0]
filtered_candidates = [candidates[idx] for idx in idxs]
return filtered_candidates

# TODO: change return value to List[ListData]
def _get_XY_list(self, args):
pre_x, post_x_it = args[0], args[1]
XY_list = []
for post_x in post_x_it:
x = (pre_x,) + post_x
data_sample = ListData(pred_pseudo_label=[x])
y = self.logic_forward(data_sample)
if y is not None:
XY_list.append((x, y))
return XY_list

@staticmethod
def _valid_candidate(data_sample):
formula = data_sample["pred_pseudo_label"][0]
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True
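
A quick, hypothetical check of HWF_KB. This assumes GroundKB builds the ground knowledge base from construct_base on initialization (not shown in this diff), which is why GKB_len_list is kept small; since construct_base spawns a multiprocessing Pool, the snippet should run under an if __name__ == "__main__": guard on some platforms.

from abl.structures import ListData
from examples.hwf.hwf_kb import HWF_KB

kb = HWF_KB(GKB_len_list=[1, 3])

sample = ListData(pred_pseudo_label=[["5", "times", "2"]])
print(kb.logic_forward(sample))    # 10 (also fills sample["Y"])
print(kb.check_equal(sample, 10))  # True
print(kb.check_equal(sample, 11))  # False

# Ground formulas of length 3 whose value lies within max_err of 7:
print(kb.key2candidates((3, 7))[:3])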

+ 25
- 15
examples/mnist_add/datasets/get_mnist_add.py View File

@@ -1,39 +1,49 @@
import os.path as osp

import torchvision
from torchvision.transforms import transforms

CURRENT_DIR = osp.abspath(osp.dirname(__file__))


def get_data(file, img_dataset, get_pseudo_label):
X = []
X, Y = [], []
if get_pseudo_label:
Z = []
Y = []
with open(file) as f:
for line in f:
line = line.strip().split(' ')
# if len(X) == 1000:
# break
line = line.strip().split(" ")
X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]])
if get_pseudo_label:
Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]])
Y.append(int(line[2]))
if get_pseudo_label:
return X, Z, Y
else:
return X, None, Y

def get_mnist_add(train = True, get_pseudo_label = False):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))])
img_dataset = torchvision.datasets.MNIST(root='./datasets/', train=train, download=True, transform=transform)

def get_mnist_add(train=True, get_pseudo_label=False):
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
img_dataset = torchvision.datasets.MNIST(
root=CURRENT_DIR, train=train, download=True, transform=transform
)

if train:
file = './datasets/train_data.txt'
file = osp.join(CURRENT_DIR, "train_data.txt")
else:
file = './datasets/test_data.txt'
file = osp.join(CURRENT_DIR, "test_data.txt")
return get_data(file, img_dataset, get_pseudo_label)

if __name__ == "__main__":
train_X, train_Y = get_mnist_add(train = True)
test_X, test_Y = get_mnist_add(train = False)
train_X, train_Z, train_Y = get_mnist_add(train=True)
test_X, test_Z, test_Y = get_mnist_add(train=False)
print(len(train_X), len(test_X))
print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0])

+ 37
- 37
examples/mnist_add/mnist_add_example.ipynb View File

@@ -2,32 +2,37 @@
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch\n",
"import os.path as osp\n",
"\n",
"from abl.reasoning import ReasonerBase, KBBase\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"from abl.learning import BasicNN, ABLModel\n",
"from abl.bridge import SimpleBridge\n",
"from abl.evaluation import SymbolMetric, ABLMetric\n",
"from abl.utils import ABLLogger\n",
"\n",
"from models.nn import LeNet5\n",
"from examples.mnist_add.datasets.get_mnist_add import get_mnist_add"
"from abl.evaluation import SemanticsMetric, SymbolMetric\n",
"from abl.learning import ABLModel, BasicNN\n",
"from abl.reasoning import ReasonerBase\n",
"from abl.utils import ABLLogger, print_log\n",
"from examples.mnist_add.datasets.get_mnist_add import get_mnist_add\n",
"from examples.mnist_add.mnist_add_kb import AddKB\n",
"from examples.models.nn import LeNet5"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize logger\n",
"logger = ABLLogger.get_instance(\"abl\")"
"print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n",
"\n",
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n",
"log_dir = ABLLogger.get_current_instance().log_dir\n",
"weights_dir = osp.join(log_dir, \"weights\")"
]
},
{
@@ -40,22 +45,19 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize knowledge base and abducer\n",
"class add_KB(KBBase):\n",
" def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n",
" super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n",
"\n",
" def logic_forward(self, nums):\n",
" return sum(nums)\n",
"kb = AddKB()\n",
"\n",
"kb = add_KB(prebuild_GKB=True)\n",
"# If use cache, get_key should be implemented in the abducer\n",
"class AddAbducer(ReasonerBase):\n",
" def get_key(self, data_sample):\n",
" return (data_sample.to_tuple(\"pred_pseudo_label\"), data_sample[\"Y\"][0])\n",
"\n",
"# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n",
"abducer = ReasonerBase(kb, dist_func=\"confidence\")"
"abducer = AddAbducer(kb, dist_func=\"confidence\", use_cache=True)"
]
},
{
@@ -68,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -81,19 +83,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize BasicNN\n",
"# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicNN(\n",
" cls,\n",
" criterion,\n",
" optimizer,\n",
" device,\n",
" save_interval=1,\n",
" save_dir=logger.save_dir,\n",
" model=cls,\n",
" criterion=criterion,\n",
" optimizer=optimizer,\n",
" device=device,\n",
" batch_size=32,\n",
" num_epochs=1,\n",
")"
@@ -109,12 +109,12 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize ABL model\n",
"# The main function of the ABL model is to serialize data and \n",
"# The main function of the ABL model is to serialize data and\n",
"# provide a unified interface for different machine learning models\n",
"model = ABLModel(base_model)"
]
@@ -129,12 +129,12 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add metric\n",
"metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]"
"metric = [SymbolMetric(prefix=\"mnist_add\"), SemanticsMetric(kb=kb, prefix=\"mnist_add\")]"
]
},
{
@@ -147,7 +147,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -187,7 +187,7 @@
"metadata": {},
"outputs": [],
"source": [
"bridge.train(train_data, epochs=5, batch_size=10000)\n",
"bridge.train(train_data, loops=10, segment_size=10000)\n",
"bridge.test(test_data)"
]
}
@@ -208,7 +208,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.16"
},
"orig_nbformat": 4,
"vscode": {


+ 17
- 0
examples/mnist_add/mnist_add_kb.py View File

@@ -0,0 +1,17 @@
from typing import Any

from abl.reasoning import SearchBasedKB
from abl.structures import ListData


class AddKB(SearchBasedKB):
def __init__(self, pseudo_label_list=list(range(10))):
super().__init__(
pseudo_label_list=pseudo_label_list
)

def check_equal(self, data_sample: ListData, y: Any):
return self.logic_forward(data_sample) == y

def logic_forward(self, data_sample):
return sum(data_sample["pred_pseudo_label"][0])
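
A sketch of plugging AddKB into the reasoner, mirroring the call pattern used in tests/test_reasoning.py below; the probability matrix and parameter values are illustrative, and whether caching is enabled by default is not shown in this diff.

from abl.reasoning import ReasonerBase
from examples.mnist_add.mnist_add_kb import AddKB

kb = AddKB()
reasoner = ReasonerBase(kb, dist_func="confidence")
prob = [
    [
        [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
    ]
]
# Abduce pseudo-labels for one example predicted as [1, 1] with target sum 8.
print(reasoner.batch_abduce(prob, [[1, 1]], [8], max_revision=2, require_more_revision=0))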

+ 4
- 5
examples/models/nn.py View File

@@ -11,8 +11,8 @@
# ================================================================#


import torch
import numpy as np
import torch
from torch import nn


@@ -66,7 +66,8 @@ class SymbolNet(nn.Module):
num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1)
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU())
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU())
self.fc3 = nn.Sequential(nn.Linear(84, num_classes), nn.Softmax(dim=1))
# self.fc3 = nn.Sequential(nn.Linear(84, num_classes), nn.Softmax(dim=1))
self.fc3 = nn.Sequential(nn.Linear(84, num_classes))

def forward(self, x):
x = self.conv1(x)
@@ -84,9 +85,7 @@ class SymbolNetAutoencoder(nn.Module):
self.base_model = SymbolNet(num_classes, image_size)
self.softmax = nn.Softmax(dim=1)
self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU())
self.fc2 = nn.Sequential(
nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()
)
self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU())

def forward(self, x):
x = self.base_model(x)
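
The change above removes the trailing Softmax from SymbolNet, so the network now emits raw logits. This is the usual pattern when training with nn.CrossEntropyLoss, which applies log-softmax internally; probabilities are computed explicitly only where they are needed (e.g. for the reasoner). A short self-contained illustration:

import torch
from torch import nn

logits = torch.randn(4, 10)  # e.g. a batch of 4 SymbolNet outputs (num_classes=10)
loss = nn.CrossEntropyLoss()(logits, torch.tensor([0, 3, 7, 9]))  # loss expects raw logits
probs = torch.softmax(logits, dim=1)  # normalise only when probabilities are needed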


+ 1
- 0
setup.py View File

@@ -1,4 +1,5 @@
import os

from setuptools import find_packages, setup




+ 403
- 0
tests/test_reasoning.py View File

@@ -0,0 +1,403 @@

import math

import numpy as np

from abl.reasoning import BaseKB, GroundKB, PrologBasedKB, ReasonerBase

# NOTE: `reform_idx` (used in HED_Reasoner.zoopt_revision_score below) is assumed to be
# importable from the package's utility module; its exact import path is not shown in this diff.

if __name__ == "__main__":
prob1 = [
[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
]

prob2 = [
[
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
]

class add_KB(BaseKB):
def __init__(self, pseudo_label_list=list(range(10)), use_cache=True):
super().__init__(pseudo_label_list, use_cache=use_cache)

def logic_forward(self, nums):
return sum(nums)

class add_GroundKB(GroundKB):
def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list)

def logic_forward(self, nums):
return sum(nums)

def test_add(reasoner):
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
print(res)
print()

print("add_KB with GKB:")
kb = add_GroundKB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("add_KB without GKB:")
kb = add_KB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("add_KB without GKB, no cache")
kb = add_KB(use_cache=False)
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("PrologBasedKB with add.pl:")
kb = PrologBasedKB(
pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl"
)
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("PrologBasedKB with add.pl using zoopt:")
kb = PrologBasedKB(
pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl",
)
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True)
test_add(reasoner)

print("add_KB with multiple inputs at once:")
multiple_prob = [
[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
],
[
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
],
]

kb = add_KB()
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=1,
)
print(res)
print()

class HWF_KB(BaseKB):
def __init__(
self,
pseudo_label_list=[
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"+",
"-",
"times",
"div",
],
max_err=1e-3,
):
super().__init__(pseudo_label_list, max_err)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in [
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True

def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))

class HWF_GroundKB(GroundKB):
def __init__(
self,
pseudo_label_list=[
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"+",
"-",
"times",
"div",
],
GKB_len_list=[1, 3, 5, 7],
max_err=1e-3,
):
super().__init__(pseudo_label_list, GKB_len_list, max_err)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in [
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True

def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))

def test_hwf(reasoner):
res = reasoner.batch_abduce(
[None],
[["5", "+", "2"]],
[3],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "+", "9"]],
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "8", "8", "8", "8"]],
[3.17],
max_revision=5,
require_more_revision=3,
)
print(res)
print()

def test_hwf_multiple(reasoner, max_revisions):
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[0],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[1],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 65],
max_revision=max_revisions[2],
require_more_revision=0,
)
print(res)
print()

print("HWF_KB with GKB, max_err=0.1")
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB without GKB, max_err=0.1")
kb = HWF_KB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB with GKB, max_err=1")
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB without GKB, max_err=1")
kb = HWF_KB(max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB with multiple inputs at once:")
kb = HWF_KB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf_multiple(reasoner, max_revisions=[1, 3, 3])

print("max_revision is float")
test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9])

class HED_prolog_KB(PrologBasedKB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)

def consist_rule(self, exs, rules):
rules = str(rules).replace("'", "")
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules)
return len(list(self.prolog.query(pl_query))) != 0

def abduce_rules(self, pred_res):
pl_query = "consistent_inst_feature(%s, X)." % pred_res
prolog_result = list(self.prolog.query(pl_query))
if len(prolog_result) == 0:
return None
prolog_rules = prolog_result[0]["X"]
rules = [rule.value for rule in prolog_rules]
return rules

class HED_Reasoner(ReasonerBase):
def __init__(self, kb, dist_func="hamming"):
super().__init__(kb, dist_func, use_zoopt=True)

def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs):
pred = []
k = []
revision_flag = []
for idx in idxs:
pred.append(pred_res[idx])
k.append(y[idx])
revision_flag += list(all_revision_flag[idx])
revision_idx = np.where(np.array(revision_flag) != 0)[0]
candidate = self.revise_at_idx(pred, k, revision_idx)
return candidate

def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol):
all_revision_flag = reform_idx(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))]
candidate_size = []
while lefted_idxs:
idxs = []
idxs.append(lefted_idxs.pop(0))
max_candidate_idxs = []
found = False
for idx in range(-1, len(pred_res)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidate = self._revise_at_idxs(pred_res, y, all_revision_flag, idxs)
if len(candidate) == 0:
if len(idxs) > 1:
idxs.pop()
else:
if len(idxs) > len(max_candidate_idxs):
found = True
max_candidate_idxs = idxs.copy()
removed = [i for i in lefted_idxs if i in max_candidate_idxs]
if found:
candidate_size.append(len(removed) + 1)
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
candidate_size.sort()
score = 0
for i in range(0, len(candidate_size)):
score -= math.exp(-i) * candidate_size[i]
return score

def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)

kb = HED_prolog_KB(
pseudo_label_list=[1, 0, "+", "="],
pl_file="examples/hed/datasets/learn_add.pl",
)
reasoner = HED_Reasoner(kb)
consist_exs = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
]
inconsist_exs1 = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
[0, "+", 0, "=", 1],
]
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"]

print("HED_kb logic forward")
print(kb.logic_forward(consist_exs))
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2))
print()
print("HED_kb consist rule")
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules))
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules))
print()

print("HED_Reasoner abduce")
res = reasoner.abduce([[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs))
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1)
)
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2)
)
print(res)
print()

print("HED_Reasoner abduce rules")
abduced_rules = reasoner.abduce_rules(consist_exs)
print(abduced_rules)
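
The HWF calls above also pass float values for max_revision. The assumed interpretation (not confirmed by this diff) is that a float in [0, 1] is treated as a fraction of the example's symbol count, while an int is an absolute revision budget; a hypothetical helper spelling that out:

# Hypothetical illustration of the assumed semantics of a float max_revision;
# the exact rounding policy used by the reasoner is also an assumption.
def resolve_max_revision(max_revision, symbol_num):
    if isinstance(max_revision, float):
        return round(max_revision * symbol_num)
    return max_revision

print(resolve_max_revision(0.9, 3))  # 3 -- roughly an absolute budget of 3 for a 3-symbol formula
print(resolve_max_revision(2, 3))    # 2 -- ints pass through unchanged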
