@@ -1,2 +1,2 @@ | |||
from .learning import abl_model, basic_nn | |||
from .reasoning import reasoner, kb | |||
from .reasoning import base_kb, ground_kb, reasoner, search_based_kb |
@@ -1,52 +1,64 @@ | |||
from abc import ABCMeta, abstractmethod | |||
from typing import Any, List, Tuple | |||
from typing import Any, List, Optional, Tuple, Union | |||
from ..learning import ABLModel | |||
from ..reasoning import ReasonerBase | |||
from ..structures import ListData | |||
DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] | |||
class BaseBridge(metaclass=ABCMeta): | |||
def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None: | |||
if not isinstance(model, ABLModel): | |||
raise TypeError("Expected an ABLModel") | |||
raise TypeError( | |||
"Expected an instance of ABLModel, but received type: {}".format( | |||
type(model) | |||
) | |||
) | |||
if not isinstance(abducer, ReasonerBase): | |||
raise TypeError("Expected an ReasonerBase") | |||
raise TypeError( | |||
"Expected an instance of ReasonerBase, but received type: {}".format( | |||
type(abducer) | |||
) | |||
) | |||
self.model = model | |||
self.abducer = abducer | |||
@abstractmethod | |||
def predict(self, X: List[List[Any]]) -> Tuple[List[List[Any]], List[List[Any]]]: | |||
def predict( | |||
self, data_samples: ListData | |||
) -> Tuple[List[List[Any]], List[List[Any]]]: | |||
"""Placeholder for predict labels from input.""" | |||
pass | |||
@abstractmethod | |||
def abduce_pseudo_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: | |||
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
"""Placeholder for abduce pseudo labels.""" | |||
pass | |||
@abstractmethod | |||
def idx_to_pseudo_label(self, idx: List[List[Any]]) -> List[List[Any]]: | |||
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
"""Placeholder for map label space to symbol space.""" | |||
pass | |||
@abstractmethod | |||
def pseudo_label_to_idx(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: | |||
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]: | |||
"""Placeholder for map symbol space to label space.""" | |||
pass | |||
@abstractmethod | |||
def train(self, train_data): | |||
def train(self, train_data: Union[ListData, DataSet]): | |||
"""Placeholder for train loop of ABductive Learning.""" | |||
pass | |||
@abstractmethod | |||
def test(self, test_data): | |||
def valid(self, valid_data: Union[ListData, DataSet]) -> None: | |||
"""Placeholder for model test.""" | |||
pass | |||
@abstractmethod | |||
def valid(self, valid_data): | |||
def test(self, test_data: Union[ListData, DataSet]) -> None: | |||
"""Placeholder for model validation.""" | |||
pass | |||
@@ -1,13 +1,14 @@ | |||
from ..learning import ABLModel | |||
from ..reasoning import ReasonerBase | |||
from ..evaluation import BaseMetric | |||
from .base_bridge import BaseBridge | |||
from typing import List, Union, Any, Tuple, Dict, Optional | |||
import os.path as osp | |||
from typing import Any, Dict, List, Optional, Tuple, Union | |||
from numpy import ndarray | |||
from torch.utils.data import DataLoader | |||
from ..dataset import BridgeDataset | |||
from ..utils.logger import print_log | |||
from ..evaluation import BaseMetric | |||
from ..learning import ABLModel | |||
from ..reasoning import ReasonerBase | |||
from ..structures import ListData | |||
from ..utils import print_log | |||
from .base_bridge import BaseBridge, DataSet | |||
class SimpleBridge(BaseBridge): | |||
@@ -20,85 +21,99 @@ class SimpleBridge(BaseBridge): | |||
super().__init__(model, abducer) | |||
self.metric_list = metric_list | |||
def predict(self, X) -> Tuple[List[List[Any]], ndarray]: | |||
pred_res = self.model.predict(X) | |||
pred_idx, pred_prob = pred_res["label"], pred_res["prob"] | |||
return pred_idx, pred_prob | |||
# TODO: expose abducer.mapping as a property of SimpleBridge | |||
def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: | |||
self.model.predict(data_samples) | |||
return data_samples["pred_idx"], data_samples.get("pred_prob", None) | |||
def abduce_pseudo_label( | |||
self, | |||
pred_prob: ndarray, | |||
pred_pseudo_label: List[List[Any]], | |||
Y: List[Any], | |||
data_samples: ListData, | |||
max_revision: int = -1, | |||
require_more_revision: int = 0, | |||
) -> List[List[Any]]: | |||
return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) | |||
self.abducer.batch_abduce(data_samples, max_revision, require_more_revision) | |||
return data_samples["abduced_pseudo_label"] | |||
def idx_to_pseudo_label( | |||
self, idx: List[List[Any]], mapping: Dict = None | |||
self, data_samples: ListData, mapping: Optional[Dict] = None | |||
) -> List[List[Any]]: | |||
if mapping is None: | |||
mapping = self.abducer.mapping | |||
return [[mapping[_idx] for _idx in sub_list] for sub_list in idx] | |||
pred_idx = data_samples.pred_idx | |||
data_samples.pred_pseudo_label = [ | |||
[mapping[_idx] for _idx in sub_list] for sub_list in pred_idx | |||
] | |||
return data_samples["pred_pseudo_label"] | |||
def pseudo_label_to_idx( | |||
self, pseudo_label: List[List[Any]], mapping: Dict = None | |||
self, data_samples: ListData, mapping: Optional[Dict] = None | |||
) -> List[List[Any]]: | |||
if mapping is None: | |||
mapping = self.abducer.remapping | |||
return [ | |||
[mapping[_pseudo_label] for _pseudo_label in sub_list] | |||
for sub_list in pseudo_label | |||
abduced_idx = [ | |||
[mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] | |||
for sub_list in data_samples.abduced_pseudo_label | |||
] | |||
data_samples.abduced_idx = abduced_idx | |||
return data_samples["abduced_idx"] | |||
def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData: | |||
data_samples = ListData() | |||
data_samples.X = X | |||
data_samples.gt_pseudo_label = gt_pseudo_label | |||
data_samples.Y = Y | |||
return data_samples | |||
def train( | |||
self, | |||
train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]], | |||
epochs: int = 50, | |||
batch_size: Union[int, float] = -1, | |||
train_data: Union[ListData, DataSet], | |||
loops: int = 50, | |||
segment_size: Union[int, float] = -1, | |||
eval_interval: int = 1, | |||
save_interval: Optional[int] = None, | |||
save_dir: Optional[str] = None, | |||
): | |||
dataset = BridgeDataset(*train_data) | |||
data_loader = DataLoader( | |||
dataset, | |||
batch_size=batch_size, | |||
collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], | |||
) | |||
for epoch in range(epochs): | |||
for seg_idx, (X, Z, Y) in enumerate(data_loader): | |||
pred_idx, pred_prob = self.predict(X) | |||
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) | |||
abduced_pseudo_label = self.abduce_pseudo_label( | |||
pred_prob, pred_pseudo_label, Y | |||
) | |||
abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) | |||
loss = self.model.train(X, abduced_label) | |||
if isinstance(train_data, ListData): | |||
data_samples = train_data | |||
else: | |||
data_samples = self.data_preprocess(*train_data) | |||
for loop in range(loops): | |||
for seg_idx in range((len(data_samples) - 1) // segment_size + 1): | |||
sub_data_samples = data_samples[ | |||
seg_idx * segment_size : (seg_idx + 1) * segment_size | |||
] | |||
self.predict(sub_data_samples) | |||
self.idx_to_pseudo_label(sub_data_samples) | |||
self.abduce_pseudo_label(sub_data_samples) | |||
self.pseudo_label_to_idx(sub_data_samples) | |||
loss = self.model.train(sub_data_samples) | |||
print_log( | |||
f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}", | |||
f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}", | |||
logger="current", | |||
) | |||
if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1: | |||
print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current") | |||
if (loop + 1) % eval_interval == 0 or loop == loops - 1: | |||
print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") | |||
self.valid(train_data) | |||
def _valid(self, data_loader): | |||
for X, Z, Y in data_loader: | |||
pred_idx, pred_prob = self.predict(X) | |||
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) | |||
data_samples = dict( | |||
pred_idx=pred_idx, | |||
pred_prob=pred_prob, | |||
pred_pseudo_label=pred_pseudo_label, | |||
gt_pseudo_label=Z, | |||
Y=Y, | |||
logic_forward=self.abducer.kb.logic_forward, | |||
) | |||
if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): | |||
print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") | |||
self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")) | |||
def _valid(self, data_samples: ListData, batch_size: int = 128) -> None: | |||
for seg_idx in range((len(data_samples) - 1) // batch_size + 1): | |||
sub_data_samples = data_samples[seg_idx * batch_size : (seg_idx + 1) * batch_size] | |||
self.predict(sub_data_samples) | |||
self.idx_to_pseudo_label(sub_data_samples) | |||
for metric in self.metric_list: | |||
metric.process(data_samples) | |||
metric.process(sub_data_samples) | |||
res = dict() | |||
for metric in self.metric_list: | |||
@@ -108,14 +123,12 @@ class SimpleBridge(BaseBridge): | |||
msg += k + f": {v:.3f} " | |||
print_log(msg, logger="current") | |||
def valid(self, valid_data, batch_size=1000): | |||
dataset = BridgeDataset(*valid_data) | |||
data_loader = DataLoader( | |||
dataset, | |||
batch_size=batch_size, | |||
collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], | |||
) | |||
self._valid(data_loader) | |||
def test(self, test_data, batch_size=1000): | |||
self.valid(test_data, batch_size) | |||
def valid(self, valid_data: Union[ListData, DataSet], batch_size: int = 128) -> None: | |||
if not isinstance(valid_data, ListData): | |||
data_samples = self.data_preprocess(*valid_data) | |||
else: | |||
data_samples = valid_data | |||
self._valid(data_samples, batch_size=batch_size) | |||
def test(self, test_data: Union[ListData, DataSet], batch_size: int = 128) -> None: | |||
self.valid(test_data, batch_size=batch_size) |
@@ -1,3 +1,4 @@ | |||
from .bridge_dataset import BridgeDataset | |||
from .classification_dataset import ClassificationDataset | |||
from .regression_dataset import RegressionDataset | |||
from .prediction_dataset import PredictionDataset | |||
from .regression_dataset import RegressionDataset |
@@ -1,5 +1,6 @@ | |||
from typing import Any, List, Tuple | |||
from torch.utils.data import Dataset | |||
from typing import List, Any, Tuple | |||
class BridgeDataset(Dataset): | |||
@@ -1,6 +1,7 @@ | |||
from typing import Any, Callable, List, Tuple | |||
import torch | |||
from torch.utils.data import Dataset | |||
from typing import List, Any, Tuple, Callable | |||
class ClassificationDataset(Dataset): | |||
@@ -0,0 +1,56 @@ | |||
from typing import Any, Callable, List, Tuple | |||
import torch | |||
from torch.utils.data import Dataset | |||
class PredictionDataset(Dataset): | |||
def __init__(self, X: List[Any], transform: Callable[..., Any] = None): | |||
""" | |||
Initialize the dataset used for the prediction task. | |||
Parameters | |||
---------- | |||
X : List[Any] | |||
The input data. | |||
transform : Callable[..., Any], optional | |||
A function/transform that takes in an object and returns a transformed version. Defaults to None. | |||
""" | |||
if not isinstance(X, list): | |||
raise ValueError("X should be of type list.") | |||
self.X = X | |||
self.transform = transform | |||
def __len__(self) -> int: | |||
""" | |||
Return the length of the dataset. | |||
Returns | |||
------- | |||
int | |||
The length of the dataset. | |||
""" | |||
return len(self.X) | |||
def __getitem__(self, index: int) -> Any: | |||
""" | |||
Get the item at the given index. | |||
Parameters | |||
---------- | |||
index : int | |||
The index of the item to get. | |||
Returns | |||
------- | |||
Any | |||
The input object at the given index, transformed if a transform was provided. | |||
""" | |||
if index >= len(self): | |||
raise ValueError("index range error") | |||
x = self.X[index] | |||
if self.transform is not None: | |||
x = self.transform(x) | |||
return x |
@@ -1,6 +1,7 @@ | |||
from typing import Any, List, Tuple | |||
import torch | |||
from torch.utils.data import Dataset | |||
from typing import List, Any, Tuple | |||
class RegressionDataset(Dataset): | |||
@@ -1,3 +1,3 @@ | |||
from .base_metric import BaseMetric | |||
from .symbol_metric import SymbolMetric | |||
from .semantics_metric import SemanticsMetric | |||
from .symbol_metric import SymbolMetric |
@@ -1,8 +1,8 @@ | |||
import logging | |||
from abc import ABCMeta, abstractmethod | |||
from typing import Any, List, Optional, Sequence | |||
from ..utils import print_log | |||
import logging | |||
from ..utils import print_log | |||
class BaseMetric(metaclass=ABCMeta): | |||
@@ -1,25 +1,22 @@ | |||
from typing import Optional, Sequence | |||
from ..reasoning import BaseKB | |||
from .base_metric import BaseMetric | |||
class ABLMetric(): | |||
pass | |||
class SemanticsMetric(BaseMetric): | |||
def __init__(self, prefix: Optional[str] = None) -> None: | |||
def __init__(self, kb: Optional[BaseKB] = None, prefix: Optional[str] = None) -> None: | |||
super().__init__(prefix) | |||
self.kb = kb | |||
def process(self, data_samples: Sequence[dict]) -> None: | |||
pred_pseudo_label = data_samples["pred_pseudo_label"] | |||
gt_Y = data_samples["Y"] | |||
logic_forward = data_samples["logic_forward"] | |||
for pred_z, y in zip(pred_pseudo_label, gt_Y): | |||
if logic_forward(pred_z) == y: | |||
for data_sample in data_samples: | |||
if self.kb.check_equal(data_sample, data_sample["Y"][0]): | |||
self.results.append(1) | |||
else: | |||
self.results.append(0) | |||
def compute_metrics(self, results: list) -> dict: | |||
metrics = dict() | |||
metrics["semantics_accuracy"] = sum(results) / len(results) | |||
return metrics | |||
return metrics |
@@ -1,4 +1,5 @@ | |||
from typing import Optional, Sequence, Callable | |||
from typing import Optional, Sequence | |||
from .base_metric import BaseMetric | |||
@@ -10,8 +10,10 @@ | |||
# | |||
# ================================================================# | |||
import pickle | |||
from utils import flatten, reform_idx | |||
from typing import List, Any, Optional | |||
from typing import Any, Dict | |||
from ..structures import ListData | |||
from ..utils import reform_idx | |||
class ABLModel: | |||
@@ -30,7 +32,7 @@ class ABLModel: | |||
Methods | |||
------- | |||
predict(X: List[List[Any]], mapping: Optional[dict] = None) -> dict | |||
predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict | |||
Predict the labels and probabilities for the given data. | |||
valid(X: List[List[Any]], Y: List[Any]) -> float | |||
Calculate the accuracy score for the given data. | |||
@@ -42,20 +44,13 @@ class ABLModel: | |||
Load the model from a file. | |||
""" | |||
def __init__(self, base_model) -> None: | |||
self.classifier_list = [] | |||
self.classifier_list.append(base_model) | |||
def __init__(self, base_model: Any) -> None: | |||
if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")): | |||
raise NotImplementedError("The base_model should implement fit and predict methods.") | |||
if not ( | |||
hasattr(base_model, "fit") | |||
and hasattr(base_model, "predict") | |||
and hasattr(base_model, "score") | |||
): | |||
raise NotImplementedError( | |||
"base_model should have fit, predict and score methods." | |||
) | |||
self.base_model = base_model | |||
def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict: | |||
def predict(self, data_samples: ListData) -> Dict: | |||
""" | |||
Predict the labels and probabilities for the given data. | |||
@@ -63,53 +58,30 @@ class ABLModel: | |||
---------- | |||
X : List[List[Any]] | |||
The data to predict on. | |||
mapping : Optional[dict], optional | |||
A mapping dictionary to map labels to their original values, by default None. | |||
Returns | |||
------- | |||
dict | |||
A dictionary containing the predicted labels and probabilities. | |||
""" | |||
model = self.classifier_list[0] | |||
data_X = flatten(X) | |||
model = self.base_model | |||
data_X = data_samples.flatten("X") | |||
if hasattr(model, "predict_proba"): | |||
prob = model.predict_proba(X=data_X) | |||
label = prob.argmax(axis=1) | |||
prob = reform_idx(prob, X) | |||
prob = reform_idx(prob, data_samples["X"]) | |||
else: | |||
prob = None | |||
label = model.predict(X=data_X) | |||
label = reform_idx(label, data_samples["X"]) | |||
if mapping is not None: | |||
label = [mapping[y] for y in label] | |||
label = reform_idx(label, X) | |||
data_samples.pred_idx = label | |||
if prob is not None: | |||
data_samples.pred_prob = prob | |||
return {"label": label, "prob": prob} | |||
def valid(self, X: List[List[Any]], Y: List[Any]) -> float: | |||
""" | |||
Calculate the accuracy for the given data. | |||
Parameters | |||
---------- | |||
X : List[List[Any]] | |||
The data to calculate the accuracy on. | |||
Y : List[Any] | |||
The true labels for the given data. | |||
Returns | |||
------- | |||
float | |||
The accuracy score for the given data. | |||
""" | |||
data_X = flatten(X) | |||
data_Y = flatten(Y) | |||
score = self.classifier_list[0].score(X=data_X, y=data_Y) | |||
return score | |||
def train(self, X: List[List[Any]], Y: List[Any]) -> float: | |||
def train(self, data_samples: ListData) -> float: | |||
""" | |||
Train the model on the given data. | |||
@@ -125,29 +97,30 @@ class ABLModel: | |||
float | |||
The loss value of the trained model. | |||
""" | |||
data_X = flatten(X) | |||
data_Y = flatten(Y) | |||
return self.classifier_list[0].fit(X=data_X, y=data_Y) | |||
data_X = data_samples.flatten("X") | |||
data_y = data_samples.flatten("abduced_idx") | |||
return self.base_model.fit(X=data_X, y=data_y) | |||
def _model_operation(self, operation: str, *args, **kwargs): | |||
model = self.classifier_list[0] | |||
model = self.base_model | |||
if hasattr(model, operation): | |||
method = getattr(model, operation) | |||
method(*args, **kwargs) | |||
else: | |||
try: | |||
if not f"{operation}_path" in kwargs.keys(): | |||
raise ValueError(f"'{operation}_path' should not be None") | |||
if operation == "save": | |||
with open(kwargs["save_path"], 'wb') as file: | |||
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) | |||
elif operation == "load": | |||
with open(kwargs["load_path"], 'rb') as file: | |||
self.classifier_list[0] = pickle.load(file) | |||
except: | |||
raise NotImplementedError( | |||
f"{type(model).__name__} object doesn't have the {operation} method" | |||
) | |||
if not f"{operation}_path" in kwargs.keys(): | |||
raise ValueError(f"'{operation}_path' should not be None") | |||
else: | |||
try: | |||
if operation == "save": | |||
with open(kwargs["save_path"], "wb") as file: | |||
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) | |||
elif operation == "load": | |||
with open(kwargs["load_path"], "rb") as file: | |||
self.base_model = pickle.load(file) | |||
except Exception: | |||
raise NotImplementedError( | |||
f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed." | |||
) | |||
def save(self, *args, **kwargs) -> None: | |||
""" | |||
@@ -10,14 +10,16 @@ | |||
# | |||
# ================================================================# | |||
import torch | |||
import os | |||
import logging | |||
from typing import Any, Callable, List, Optional, T, Tuple | |||
import numpy | |||
import torch | |||
from torch.utils.data import DataLoader | |||
from ..utils.logger import print_log | |||
from ..dataset import ClassificationDataset | |||
import os | |||
from typing import List, Any, T, Optional, Callable, Tuple | |||
from ..dataset import ClassificationDataset, PredictionDataset | |||
from ..utils.logger import print_log | |||
class BasicNN: | |||
@@ -99,9 +101,7 @@ class BasicNN: | |||
loss_value = self.train_epoch(data_loader) | |||
if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: | |||
if self.save_dir is None: | |||
raise ValueError( | |||
"save_dir should not be None if save_interval is not None." | |||
) | |||
raise ValueError("save_dir should not be None if save_interval is not None.") | |||
self.save(epoch + 1) | |||
if self.stop_loss is not None and loss_value < self.stop_loss: | |||
break | |||
@@ -191,7 +191,7 @@ class BasicNN: | |||
with torch.no_grad(): | |||
results = [] | |||
for data, _ in data_loader: | |||
for data in data_loader: | |||
data = data.to(device) | |||
out = model(data) | |||
results.append(out) | |||
@@ -199,7 +199,10 @@ class BasicNN: | |||
return torch.cat(results, axis=0) | |||
def predict( | |||
self, data_loader: DataLoader = None, X: List[Any] = None | |||
self, | |||
data_loader: DataLoader = None, | |||
X: List[Any] = None, | |||
test_transform: Callable[..., Any] = None, | |||
) -> numpy.ndarray: | |||
""" | |||
Predict the class of the input data. | |||
@@ -218,11 +221,28 @@ class BasicNN: | |||
""" | |||
if data_loader is None: | |||
data_loader = self._data_loader(X) | |||
if test_transform is None: | |||
print_log( | |||
"Transform used in the training phase will be used in prediction.", | |||
"current", | |||
level=logging.WARNING, | |||
) | |||
dataset = PredictionDataset(X, self.transform) | |||
else: | |||
dataset = PredictionDataset(X, test_transform) | |||
data_loader = DataLoader( | |||
dataset, | |||
batch_size=self.batch_size, | |||
num_workers=int(self.num_workers), | |||
collate_fn=self.collate_fn, | |||
) | |||
return self._predict(data_loader).argmax(axis=1).cpu().numpy() | |||
def predict_proba( | |||
self, data_loader: DataLoader = None, X: List[Any] = None | |||
self, | |||
data_loader: DataLoader = None, | |||
X: List[Any] = None, | |||
test_transform: Callable[..., Any] = None, | |||
) -> numpy.ndarray: | |||
""" | |||
Predict the probability of each class for the input data. | |||
@@ -241,7 +261,21 @@ class BasicNN: | |||
""" | |||
if data_loader is None: | |||
data_loader = self._data_loader(X) | |||
if test_transform is None: | |||
print_log( | |||
"Transform used in the training phase will be used in prediction.", | |||
"current", | |||
level=logging.WARNING, | |||
) | |||
dataset = PredictionDataset(X, self.transform) | |||
else: | |||
dataset = PredictionDataset(X, test_transform) | |||
data_loader = DataLoader( | |||
dataset, | |||
batch_size=self.batch_size, | |||
num_workers=int(self.num_workers), | |||
collate_fn=self.collate_fn, | |||
) | |||
return self._predict(data_loader).softmax(axis=1).cpu().numpy() | |||
def _score(self, data_loader) -> Tuple[float, float]: | |||
@@ -313,15 +347,14 @@ class BasicNN: | |||
if data_loader is None: | |||
data_loader = self._data_loader(X, y) | |||
mean_loss, accuracy = self._score(data_loader) | |||
print_log( | |||
f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current" | |||
) | |||
print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current") | |||
return accuracy | |||
def _data_loader( | |||
self, | |||
X: List[Any], | |||
y: List[int] = None, | |||
shuffle: bool = True, | |||
) -> DataLoader: | |||
""" | |||
Generate a DataLoader for user-provided input and target data. | |||
@@ -350,7 +383,7 @@ class BasicNN: | |||
data_loader = DataLoader( | |||
dataset, | |||
batch_size=self.batch_size, | |||
shuffle=True, | |||
shuffle=shuffle, | |||
num_workers=int(self.num_workers), | |||
collate_fn=self.collate_fn, | |||
) | |||
@@ -368,14 +401,13 @@ class BasicNN: | |||
The path to save the model, by default None. | |||
""" | |||
if self.save_dir is None and save_path is None: | |||
raise ValueError( | |||
"'save_dir' and 'save_path' should not be None simultaneously." | |||
) | |||
raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.") | |||
if save_path is None: | |||
save_path = os.path.join( | |||
self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth" | |||
) | |||
if save_path is not None: | |||
if not os.path.exists(os.path.dirname(save_path)): | |||
os.makedirs(os.path.dirname(save_path)) | |||
else: | |||
save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth") | |||
if not os.path.exists(self.save_dir): | |||
os.makedirs(self.save_dir) | |||
@@ -1,2 +1,6 @@ | |||
from .base_kb import BaseKB | |||
from .ground_kb import GroundKB | |||
from .prolog_based_kb import PrologBasedKB | |||
from .reasoner import ReasonerBase | |||
from .kb import KBBase, prolog_KB | |||
from .search_based_kb import SearchBasedKB | |||
from .search_engine import BFS, BaseSearchEngine |
@@ -0,0 +1,14 @@ | |||
from abc import ABC | |||
class BaseKB(ABC): | |||
def __init__(self, pseudo_label_list) -> None: | |||
self.pseudo_label_list = pseudo_label_list | |||
# TODO: When the output is excessively long, use ellipses as a substitute. | |||
def __repr__(self): | |||
return ( | |||
f"<{self.__class__.__name__}(\n" | |||
f" pseudo_label_list: {self.pseudo_label_list!r}\n" | |||
f") at {hex(id(self))}>" | |||
) |
@@ -0,0 +1,60 @@ | |||
from abc import ABC, abstractmethod | |||
from typing import Any, Hashable, List | |||
from ..structures import ListData | |||
from .base_kb import BaseKB | |||
class GroundKB(BaseKB, ABC): | |||
def __init__(self, pseudo_label_list: List) -> None: | |||
super().__init__(pseudo_label_list) | |||
self.GKB = self.construct_base() | |||
@abstractmethod | |||
def construct_base(self) -> dict: | |||
pass | |||
@abstractmethod | |||
def get_key(self, data_sample: ListData) -> Hashable: | |||
pass | |||
def key2candidates(self, key: Hashable) -> List[List[Any]]: | |||
return self.GKB[key] | |||
def filter_candidates( | |||
self, | |||
data_sample: ListData, | |||
candidates: List[List[Any]], | |||
max_revision_num: int, | |||
require_more_revision: int = 0, | |||
) -> List[List[Any]]: | |||
return candidates | |||
def abduce_candidates( | |||
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 | |||
): | |||
return self._abduce_by_GKB( | |||
data_sample=data_sample, | |||
max_revision_num=max_revision_num, | |||
require_more_revision=require_more_revision, | |||
) | |||
def _abduce_by_GKB( | |||
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 | |||
): | |||
candidates = self.key2candidates(self.get_key(data_sample)) | |||
return self.filter_candidates( | |||
data_sample=data_sample, | |||
max_revision_num=max_revision_num, | |||
require_more_revision=require_more_revision, | |||
candidates=candidates, | |||
) | |||
# TODO: When the output is excessively long, use ellipses as a substitute. | |||
def __repr__(self): | |||
return ( | |||
f"<{self.__class__.__name__}(\n" | |||
f" pseudo_label_list: {self.pseudo_label_list!r}\n" | |||
f" GKB: {self.GKB!r}\n" | |||
f") at {hex(id(self))}>" | |||
) |
@@ -1,222 +0,0 @@ | |||
from abc import ABC, abstractmethod | |||
import bisect | |||
import numpy as np | |||
from collections import defaultdict | |||
from itertools import product, combinations | |||
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list | |||
from multiprocessing import Pool | |||
from functools import lru_cache | |||
import pyswip | |||
class KBBase(ABC): | |||
def __init__(self, pseudo_label_list, max_err=0, use_cache=True): | |||
# TODO: add some type checking, e.g. | |||
# if not isinstance(X, (np.ndarray, spmatrix)): | |||
# raise TypeError("X should be numpy array or sparse matrix") | |||
self.pseudo_label_list = pseudo_label_list | |||
self.max_err = max_err | |||
self.use_cache = use_cache | |||
@abstractmethod | |||
def logic_forward(self, pseudo_labels): | |||
pass | |||
def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): | |||
if not self.use_cache: | |||
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision) | |||
else: | |||
return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(y), max_revision_num, require_more_revision) | |||
def revise_by_idx(self, pred_res, y, revision_idx): | |||
candidates = [] | |||
abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) | |||
for c in abduce_c: | |||
candidate = pred_res.copy() | |||
for i, idx in enumerate(revision_idx): | |||
candidate[idx] = c[i] | |||
if check_equal(self.logic_forward(candidate), y, self.max_err): | |||
candidates.append(candidate) | |||
return candidates | |||
def _revision(self, revision_num, pred_res, y): | |||
new_candidates = [] | |||
revision_idx_list = combinations(range(len(pred_res)), revision_num) | |||
for revision_idx in revision_idx_list: | |||
candidates = self.revise_by_idx(pred_res, y, revision_idx) | |||
new_candidates.extend(candidates) | |||
return new_candidates | |||
def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision): | |||
candidates = [] | |||
for revision_num in range(len(pred_res) + 1): | |||
if revision_num == 0 and check_equal(self.logic_forward(pred_res), y, self.max_err): | |||
candidates.append(pred_res) | |||
elif revision_num > 0: | |||
candidates.extend(self._revision(revision_num, pred_res, y)) | |||
if len(candidates) > 0: | |||
min_revision_num = revision_num | |||
break | |||
if revision_num >= max_revision_num: | |||
return [] | |||
for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): | |||
if revision_num > max_revision_num: | |||
return candidates | |||
candidates.extend(self._revision(revision_num, pred_res, y)) | |||
return candidates | |||
@lru_cache(maxsize=None) | |||
def _abduce_by_search_cache(self, pred_res, y, max_revision_num, require_more_revision): | |||
pred_res = hashable_to_list(pred_res) | |||
y = hashable_to_list(y) | |||
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision) | |||
def _dict_len(self, dic): | |||
if not self.GKB_flag: | |||
return 0 | |||
else: | |||
return sum(len(c) for c in dic.values()) | |||
def __len__(self): | |||
if not self.GKB_flag: | |||
return 0 | |||
else: | |||
return sum(self._dict_len(v) for v in self.base.values()) | |||
class ground_KB(KBBase): | |||
def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0): | |||
super().__init__(pseudo_label_list, max_err) | |||
self.GKB_len_list = GKB_len_list | |||
self.base = {} | |||
X, Y = self._get_GKB() | |||
for x, y in zip(X, Y): | |||
self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||
# For parallel version of _get_GKB | |||
def _get_XY_list(self, args): | |||
pre_x, post_x_it = args[0], args[1] | |||
XY_list = [] | |||
for post_x in post_x_it: | |||
x = (pre_x,) + post_x | |||
y = self.logic_forward(x) | |||
if y is not None: | |||
XY_list.append((x, y)) | |||
return XY_list | |||
# Parallel _get_GKB | |||
def _get_GKB(self): | |||
X, Y = [], [] | |||
for length in self.GKB_len_list: | |||
arg_list = [] | |||
for pre_x in self.pseudo_label_list: | |||
post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||
arg_list.append((pre_x, post_x_it)) | |||
with Pool(processes=len(arg_list)) as pool: | |||
ret_list = pool.map(self._get_XY_list, arg_list) | |||
for XY_list in ret_list: | |||
if len(XY_list) == 0: | |||
continue | |||
part_X, part_Y = zip(*XY_list) | |||
X.extend(part_X) | |||
Y.extend(part_Y) | |||
if Y and isinstance(Y[0], (int, float)): | |||
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | |||
return X, Y | |||
def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): | |||
return self._abduce_by_GKB(pred_res, y, max_revision_num, require_more_revision) | |||
def _find_candidate_GKB(self, pred_res, y): | |||
if self.max_err == 0: | |||
return self.base[len(pred_res)][y] | |||
else: | |||
potential_candidates = self.base[len(pred_res)] | |||
key_list = list(potential_candidates.keys()) | |||
key_idx = bisect.bisect_left(key_list, y) | |||
all_candidates = [] | |||
for idx in range(key_idx - 1, 0, -1): | |||
k = key_list[idx] | |||
if abs(k - y) <= self.max_err: | |||
all_candidates.extend(potential_candidates[k]) | |||
else: | |||
break | |||
for idx in range(key_idx, len(key_list)): | |||
k = key_list[idx] | |||
if abs(k - y) <= self.max_err: | |||
all_candidates.extend(potential_candidates[k]) | |||
else: | |||
break | |||
return all_candidates | |||
def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision): | |||
if self.base == {} or len(pred_res) not in self.GKB_len_list: | |||
return [] | |||
all_candidates = self._find_candidate_GKB(pred_res, y) | |||
if len(all_candidates) == 0: | |||
return [] | |||
cost_list = hamming_dist(pred_res, all_candidates) | |||
min_revision_num = np.min(cost_list) | |||
revision_num = min(max_revision_num, min_revision_num + require_more_revision) | |||
idxs = np.where(cost_list <= revision_num)[0] | |||
candidates = [all_candidates[idx] for idx in idxs] | |||
return candidates | |||
class prolog_KB(KBBase): | |||
def __init__(self, pseudo_label_list, pl_file, max_err=0): | |||
super().__init__(pseudo_label_list, max_err) | |||
self.prolog = pyswip.Prolog() | |||
self.prolog.consult(pl_file) | |||
def logic_forward(self, pseudo_labels): | |||
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] | |||
if result == 'true': | |||
return True | |||
elif result == 'false': | |||
return False | |||
return result | |||
def _revision_pred_res(self, pred_res, revision_idx): | |||
import re | |||
revision_pred_res = pred_res.copy() | |||
revision_pred_res = flatten(revision_pred_res) | |||
for idx in revision_idx: | |||
revision_pred_res[idx] = 'P' + str(idx) | |||
revision_pred_res = reform_idx(revision_pred_res, pred_res) | |||
# TODO: not sure whether there is a more concise way to do this | |||
regex = r"'P\d+'" | |||
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_res)) | |||
def get_query_string(self, pred_res, y, revision_idx): | |||
query_string = "logic_forward(" | |||
query_string += self._revision_pred_res(pred_res, revision_idx) | |||
key_is_none_flag = y is None or (type(y) == list and y[0] is None) | |||
query_string += ",%s)." % y if not key_is_none_flag else ")." | |||
return query_string | |||
def revise_by_idx(self, pred_res, y, revision_idx): | |||
candidates = [] | |||
query_string = self.get_query_string(pred_res, y, revision_idx) | |||
save_pred_res = pred_res | |||
pred_res = flatten(pred_res) | |||
abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] | |||
for c in abduce_c: | |||
candidate = pred_res.copy() | |||
for i, idx in enumerate(revision_idx): | |||
candidate[idx] = c[i] | |||
candidate = reform_idx(candidate, save_pred_res) | |||
candidates.append(candidate) | |||
return candidates |
@@ -0,0 +1,44 @@ | |||
from abc import ABC, abstractmethod | |||
from typing import Any, Generator, List, Tuple, Union | |||
import numpy as np | |||
import pyswip | |||
from ..structures import ListData | |||
from .base_kb import BaseKB | |||
class PrologBasedKB(BaseKB, ABC): | |||
def __init__(self, pseudo_label_list, pl_file): | |||
self.pseudo_label_list = pseudo_label_list | |||
self.prolog = pyswip.Prolog() | |||
self.prolog.consult(pl_file) | |||
def logic_forward( | |||
self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray] = None | |||
) -> Generator[Union[Any, pyswip.Variable, list, dict, None], Any, None]: | |||
return self.prolog.query(self.to_query(data_sample, revision_idx)) | |||
@abstractmethod | |||
def to_query(self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray] = None): | |||
pass | |||
@abstractmethod | |||
def postprocess( | |||
self, query_res, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray] | |||
): | |||
return list(query_res) | |||
@abstractmethod | |||
def filter_candidates( | |||
self, | |||
data_sample: ListData, | |||
candidates: List[List[Any]], | |||
max_revision_num: int, | |||
require_more_revision: int = 0, | |||
) -> List[List[Any]]: | |||
return candidates | |||
def revise_at_idx(self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray]): | |||
query_res = self.logic_forward(data_sample, revision_idx) | |||
return self.postprocess(query_res, data_sample, revision_idx) |
@@ -1,25 +1,33 @@ | |||
from typing import Any, List, Mapping, Optional | |||
import numpy as np | |||
from zoopt import Dimension, Objective, Parameter, Opt | |||
from ..utils.utils import ( | |||
confidence_dist, | |||
flatten, | |||
reform_idx, | |||
hamming_dist, | |||
calculate_revision_num, | |||
) | |||
from ..structures import ListData | |||
from ..utils import Cache, calculate_revision_num, confidence_dist, hamming_dist | |||
from .base_kb import BaseKB | |||
from .search_engine import BFS, BaseSearchEngine | |||
class ReasonerBase: | |||
def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False): | |||
def __init__( | |||
self, | |||
kb: BaseKB, | |||
dist_func: str = "confidence", | |||
mapping: Optional[Mapping] = None, | |||
search_engine: Optional[BaseSearchEngine] = None, | |||
use_cache: bool = False, | |||
cache_file: Optional[str] = None, | |||
cache_size: Optional[int] = 4096, | |||
): | |||
""" | |||
Base class for all reasoners in the ABL system. | |||
Parameters | |||
---------- | |||
kb : KBBase | |||
kb : BaseKB | |||
The knowledge base to be used for reasoning. | |||
dist_func : str, optional | |||
The distance function to be used. Can be "hamming" or "confidence". Default is "hamming". | |||
The distance function to be used. Can be "hamming" or "confidence". Default is "confidence". | |||
mapping : dict, optional | |||
A mapping of indices to labels. If None, a default mapping is generated. | |||
use_zoopt : bool, optional | |||
@@ -31,207 +39,204 @@ class ReasonerBase: | |||
If the specified distance function is neither "hamming" nor "confidence". | |||
""" | |||
if not (dist_func == "hamming" or dist_func == "confidence"): | |||
raise NotImplementedError # Only hamming or confidence distance is available. | |||
if not isinstance(kb, BaseKB): | |||
raise ValueError("The kb should be of type BaseKB.") | |||
self.kb = kb | |||
if dist_func not in ["hamming", "confidence"]: | |||
raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.") | |||
self.dist_func = dist_func | |||
self.use_zoopt = use_zoopt | |||
if mapping is None: | |||
self.mapping = { | |||
index: label for index, label in enumerate(self.kb.pseudo_label_list) | |||
} | |||
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} | |||
else: | |||
self.mapping = mapping | |||
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) | |||
if not isinstance(mapping, dict): | |||
raise ValueError("mapping must be of type dict") | |||
def _get_cost_list(self, pred_pseudo_label, pred_prob, candidates): | |||
""" | |||
Get the list of costs between each pseudo label and candidate. | |||
for key, value in mapping.items(): | |||
if not isinstance(key, int): | |||
raise ValueError("All keys in the mapping must be integers") | |||
Parameters | |||
---------- | |||
pred_pseudo_label : list | |||
The pseudo label to be used for computing costs of candidates. | |||
pred_prob : list | |||
Probabilities of the predictions. Used when distance function is "confidence". | |||
candidates : list | |||
List of candidate abduction result. | |||
if value not in self.kb.pseudo_label_list: | |||
raise ValueError("All values in the mapping must be in the pseudo_label_list") | |||
Returns | |||
------- | |||
numpy.ndarray | |||
Array of computed costs for each candidate. | |||
""" | |||
if self.dist_func == "hamming": | |||
return hamming_dist(pred_pseudo_label, candidates) | |||
elif self.dist_func == "confidence": | |||
candidates = [[self.remapping[x] for x in c] for c in candidates] | |||
return confidence_dist(pred_prob, candidates) | |||
def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): | |||
""" | |||
Get one candidate. If multiple candidates exist, return the one with minimum cost. | |||
Parameters | |||
---------- | |||
pred_pseudo_label : list | |||
The pseudo label to be used for selecting a candidate. | |||
pred_prob : list | |||
Probabilities of the predictions. | |||
candidates : list | |||
List of candidate abduction result. | |||
self.mapping = mapping | |||
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) | |||
Returns | |||
------- | |||
list | |||
The chosen candidate based on minimum cost. | |||
If no candidates, an empty list is returned. | |||
""" | |||
if len(candidates) == 0: | |||
return [] | |||
elif len(candidates) == 1: | |||
return candidates[0] | |||
if search_engine is None: | |||
self.search_engine = BFS() | |||
else: | |||
cost_array = self._get_cost_list(pred_pseudo_label, pred_prob, candidates) | |||
candidate = candidates[np.argmin(cost_array)] | |||
return candidate | |||
if not isinstance(search_engine, BaseSearchEngine): | |||
raise ValueError("The search_engine should be of type BaseSearchEngine.") | |||
else: | |||
self.search_engine = search_engine | |||
self.use_cache = use_cache | |||
self.cache_file = cache_file | |||
if self.use_cache: | |||
if not hasattr(self, "get_key"): | |||
raise NotImplementedError("If use_cache is True, get_key should be implemented.") | |||
key_func = self.get_key | |||
else: | |||
key_func = lambda x: x | |||
self.cache = Cache[ListData, List[List[Any]]]( | |||
func=self.abduce_candidates, | |||
cache=self.use_cache, | |||
cache_file=self.cache_file, | |||
key_func=key_func, | |||
max_size=cache_size, | |||
) | |||
def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, y, sol): | |||
def abduce( | |||
self, | |||
data_sample: ListData, | |||
max_revision: int = -1, | |||
require_more_revision: int = 0, | |||
): | |||
""" | |||
Get the revision score for a single solution. | |||
Perform revision by abduction on the given data. | |||
Parameters | |||
---------- | |||
symbol_num : int | |||
Number of total symbols. | |||
pred_pseudo_label : list | |||
List of predicted pseudo labels. | |||
pred_prob : list | |||
List of probabilities for predicted results. | |||
pred_pseudo_label : list | |||
List of predicted pseudo labels. | |||
y : any | |||
Ground truth for the predicted results. | |||
sol : array-like | |||
Solution to evaluate. | |||
max_revision : int or float, optional | |||
Maximum number of revisions to use. If float, represents the fraction of total revisions to use. | |||
If -1, any revisions are allowed. Defaults to -1. | |||
require_more_revision : int, optional | |||
Number of additional revisions to require. Defaults to 0. | |||
Returns | |||
------- | |||
float | |||
The revision score for the given solution. | |||
list | |||
The abduced revisions. | |||
""" | |||
revision_idx = np.where(sol.get_x() != 0)[0] | |||
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) | |||
if len(candidates) > 0: | |||
return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates)) | |||
else: | |||
return symbol_num | |||
symbol_num = data_sample.elements_num("pred_pseudo_label") | |||
max_revision_num = calculate_revision_num(max_revision, symbol_num) | |||
data_sample.set_metainfo(dict(symbol_num=symbol_num)) | |||
def _constrain_revision_num(self, solution, max_revision_num): | |||
x = solution.get_x() | |||
return max_revision_num - x.sum() | |||
candidates = self.cache.get(data_sample, max_revision_num, require_more_revision) | |||
candidate = self.select_one_candidate(data_sample, candidates) | |||
return candidate | |||
def zoopt_get_solution( | |||
self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num | |||
def abduce_candidates( | |||
self, | |||
data_sample: ListData, | |||
max_revision_num: int = -1, | |||
require_more_revision: int = 0, | |||
): | |||
"""Get the optimal solution using the Zoopt library. | |||
""" | |||
Perform revision by abduction on the given data. | |||
Parameters | |||
---------- | |||
symbol_num : int | |||
Number of total symbols. | |||
pred_pseudo_label : list | |||
List of predicted pseudo labels. | |||
pred_prob : list | |||
List of probabilities for predicted results. | |||
pred_pseudo_label : list | |||
List of predicted pseudo labels. | |||
y : any | |||
Ground truth for the predicted results. | |||
max_revision_num : int | |||
Maximum number of revisions to use. | |||
max_revision : int or float, optional | |||
Maximum number of revisions to use. If float, represents the fraction of total revisions to use. | |||
If -1, any revisions are allowed. Defaults to -1. | |||
require_more_revision : int, optional | |||
Number of additional revisions to require. Defaults to 0. | |||
Returns | |||
------- | |||
array-like | |||
The optimal solution, i.e., where to revise predict pseudo label. | |||
list | |||
The abduced revisions. | |||
""" | |||
dimension = Dimension( | |||
size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num | |||
) | |||
objective = Objective( | |||
lambda sol: self.zoopt_revision_score( | |||
symbol_num, pred_pseudo_label, pred_prob, y, sol | |||
), | |||
dim=dimension, | |||
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | |||
) | |||
parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | |||
solution = Opt.min(objective, parameter).get_x() | |||
return solution | |||
def revise_by_idx(self, pred_pseudo_label, y, revision_idx): | |||
if hasattr(self.kb, "abduce_candidates"): | |||
candidates = self.kb.abduce_candidates( | |||
data_sample, max_revision_num, require_more_revision | |||
) | |||
elif hasattr(self.kb, "revise_at_idx"): | |||
candidates = [] | |||
gen = self.search_engine.generator( | |||
data_sample, | |||
max_revision_num=max_revision_num, | |||
require_more_revision=require_more_revision, | |||
) | |||
send_signal = True | |||
for revision_idx in gen: | |||
candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx)) | |||
if len(candidates) > 0 and send_signal: | |||
try: | |||
revision_idx = gen.send("success") | |||
candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx)) | |||
send_signal = False | |||
except StopIteration: | |||
break | |||
else: | |||
raise NotImplementedError( | |||
"The kb should either implement abduce_candidates or revise_at_idx." | |||
) | |||
return candidates | |||
def select_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]): | |||
""" | |||
Revise the pseudo label according to the given indices. | |||
Get one candidate. If multiple candidates exist, return the one with minimum cost. | |||
Parameters | |||
---------- | |||
pred_pseudo_label : list | |||
List of predicted pseudo labels. | |||
y : any | |||
Ground truth for the predicted results. | |||
revision_idx : array-like | |||
Indices of the revisions to retrieve. | |||
The pseudo label to be used for selecting a candidate. | |||
pred_prob : list | |||
Probabilities of the predictions. | |||
candidates : list | |||
List of candidate abduction result. | |||
Returns | |||
------- | |||
list | |||
The revisions according to the given indices. | |||
The chosen candidate based on minimum cost. | |||
If no candidates, an empty list is returned. | |||
""" | |||
return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx) | |||
if len(candidates) == 0: | |||
return [] | |||
elif len(candidates) == 1: | |||
return candidates[0] | |||
else: | |||
cost_array = self._get_dist_list(data_sample, candidates) | |||
candidate = candidates[np.argmin(cost_array)] | |||
return candidate | |||
def abduce( | |||
self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 | |||
): | |||
def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]): | |||
""" | |||
Perform revision by abduction on the given data. | |||
Get the list of costs between each pseudo label and candidate. | |||
Parameters | |||
---------- | |||
pred_prob : list | |||
List of probabilities for predicted results. | |||
pred_pseudo_label : list | |||
List of predicted pseudo labels. | |||
y : any | |||
Ground truth for the predicted results. | |||
max_revision : int or float, optional | |||
Maximum number of revisions to use. If float, represents the fraction of total revisions to use. | |||
If -1, any revisions are allowed. Defaults to -1. | |||
require_more_revision : int, optional | |||
Number of additional revisions to require. Defaults to 0. | |||
The pseudo label to be used for computing costs of candidates. | |||
pred_prob : list | |||
Probabilities of the predictions. Used when distance function is "confidence". | |||
candidates : list | |||
List of candidate abduction result. | |||
Returns | |||
------- | |||
list | |||
The abduced revisions. | |||
numpy.ndarray | |||
Array of computed costs for each candidate. | |||
""" | |||
symbol_num = len(flatten(pred_pseudo_label)) | |||
max_revision_num = calculate_revision_num(max_revision, symbol_num) | |||
if self.use_zoopt: | |||
solution = self.zoopt_get_solution( | |||
symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num | |||
) | |||
revision_idx = np.where(solution != 0)[0] | |||
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) | |||
else: | |||
candidates = self.kb.abduce_candidates( | |||
pred_pseudo_label, y, max_revision_num, require_more_revision | |||
) | |||
if self.dist_func == "hamming": | |||
return hamming_dist(data_sample["pred_pseudo_label"][0], candidates) | |||
candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) | |||
return candidate | |||
elif self.dist_func == "confidence": | |||
candidates = [[self.remapping[x] for x in c] for c in candidates] | |||
return confidence_dist(data_sample["pred_prob"][0], candidates) | |||
def batch_abduce( | |||
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 | |||
self, | |||
data_samples: ListData, | |||
max_revision: int = -1, | |||
require_more_revision: int = 0, | |||
): | |||
""" | |||
Perform abduction on the given data in batches. | |||
@@ -255,384 +260,13 @@ class ReasonerBase: | |||
list | |||
The abduced revisions in batches. | |||
""" | |||
return [ | |||
abduced_pseudo_label = [ | |||
self.abduce( | |||
_pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision | |||
) | |||
for _pred_prob, _pred_pseudo_label, _Y in zip( | |||
pred_prob, pred_pseudo_label, Y | |||
data_sample, | |||
max_revision=max_revision, | |||
require_more_revision=require_more_revision, | |||
) | |||
for data_sample in data_samples | |||
] | |||
# def _batch_abduce_helper(self, args): | |||
# z, prob, y, max_revision, require_more_revision = args | |||
# return self.abduce((z, prob, y), max_revision, require_more_revision) | |||
# def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0): | |||
# with Pool(processes=os.cpu_count()) as pool: | |||
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) | |||
# return results | |||
def __call__( | |||
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 | |||
): | |||
return self.batch_abduce( | |||
pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision | |||
) | |||
if __name__ == "__main__": | |||
from kb import KBBase, ground_KB, prolog_KB | |||
prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
class add_KB(KBBase): | |||
def __init__(self, pseudo_label_list=list(range(10)), | |||
use_cache=True): | |||
super().__init__(pseudo_label_list, use_cache=use_cache) | |||
def logic_forward(self, nums): | |||
return sum(nums) | |||
class add_ground_KB(ground_KB): | |||
def __init__(self, pseudo_label_list=list(range(10)), | |||
GKB_len_list=[2]): | |||
super().__init__(pseudo_label_list, GKB_len_list) | |||
def logic_forward(self, nums): | |||
return sum(nums) | |||
def test_add(reasoner): | |||
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
print(res) | |||
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
print(res) | |||
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||
print(res) | |||
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||
print(res) | |||
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||
print(res) | |||
print() | |||
print("add_KB with GKB:") | |||
kb = add_ground_KB() | |||
reasoner = ReasonerBase(kb, "confidence") | |||
test_add(reasoner) | |||
print("add_KB without GKB:") | |||
kb = add_KB() | |||
reasoner = ReasonerBase(kb, "confidence") | |||
test_add(reasoner) | |||
print("add_KB without GKB, no cache") | |||
kb = add_KB(use_cache=False) | |||
reasoner = ReasonerBase(kb, "confidence") | |||
test_add(reasoner) | |||
print("prolog_KB with add.pl:") | |||
kb = prolog_KB(pseudo_label_list=list(range(10)), | |||
pl_file="examples/mnist_add/datasets/add.pl") | |||
reasoner = ReasonerBase(kb, "confidence") | |||
test_add(reasoner) | |||
print("prolog_KB with add.pl using zoopt:") | |||
kb = prolog_KB( | |||
pseudo_label_list=list(range(10)), | |||
pl_file="examples/mnist_add/datasets/add.pl", | |||
) | |||
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | |||
test_add(reasoner) | |||
print("add_KB with multiple inputs at once:") | |||
multiple_prob = [[ | |||
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
], | |||
[ | |||
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
]] | |||
kb = add_KB() | |||
reasoner = ReasonerBase(kb, "confidence") | |||
res = reasoner.batch_abduce( | |||
multiple_prob, | |||
[[1, 1], [1, 2]], | |||
[4, 8], | |||
max_revision=2, | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
res = reasoner.batch_abduce( | |||
multiple_prob, | |||
[[1, 1], [1, 2]], | |||
[4, 8], | |||
max_revision=2, | |||
require_more_revision=1, | |||
) | |||
print(res) | |||
print() | |||
class HWF_KB(KBBase): | |||
def __init__( | |||
self, | |||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||
"+", "-", "times", "div"], | |||
max_err=1e-3, | |||
): | |||
super().__init__(pseudo_label_list, max_err) | |||
def _valid_candidate(self, formula): | |||
if len(formula) % 2 == 0: | |||
return False | |||
for i in range(len(formula)): | |||
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
return False | |||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
return False | |||
return True | |||
def logic_forward(self, formula): | |||
if not self._valid_candidate(formula): | |||
return np.inf | |||
mapping = {str(i): str(i) for i in range(1, 10)} | |||
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
formula = [mapping[f] for f in formula] | |||
return eval("".join(formula)) | |||
class HWF_ground_KB(ground_KB): | |||
def __init__( | |||
self, | |||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||
"+", "-", "times", "div"], | |||
GKB_len_list=[1, 3, 5, 7], | |||
max_err=1e-3, | |||
): | |||
super().__init__(pseudo_label_list, GKB_len_list, max_err) | |||
def _valid_candidate(self, formula): | |||
if len(formula) % 2 == 0: | |||
return False | |||
for i in range(len(formula)): | |||
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
return False | |||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
return False | |||
return True | |||
def logic_forward(self, formula): | |||
if not self._valid_candidate(formula): | |||
return np.inf | |||
mapping = {str(i): str(i) for i in range(1, 10)} | |||
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
formula = [mapping[f] for f in formula] | |||
return eval("".join(formula)) | |||
def test_hwf(reasoner): | |||
res = reasoner.batch_abduce( | |||
[None], | |||
[["5", "+", "2"]], | |||
[3], | |||
max_revision=2, | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
res = reasoner.batch_abduce( | |||
[None], | |||
[["5", "+", "9"]], | |||
[65], | |||
max_revision=3, | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
res = reasoner.batch_abduce( | |||
[None], | |||
[["5", "8", "8", "8", "8"]], | |||
[3.17], | |||
max_revision=5, | |||
require_more_revision=3, | |||
) | |||
print(res) | |||
print() | |||
def test_hwf_multiple(reasoner, max_revisions): | |||
res = reasoner.batch_abduce( | |||
[None, None], | |||
[["5", "+", "2"], ["5", "+", "9"]], | |||
[3, 64], | |||
max_revision=max_revisions[0], | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
res = reasoner.batch_abduce( | |||
[None, None], | |||
[["5", "+", "2"], ["5", "+", "9"]], | |||
[3, 64], | |||
max_revision=max_revisions[1], | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
res = reasoner.batch_abduce( | |||
[None, None], | |||
[["5", "+", "2"], ["5", "+", "9"]], | |||
[3, 65], | |||
max_revision=max_revisions[2], | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
print() | |||
print("HWF_KB with GKB, max_err=0.1") | |||
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=0.1) | |||
reasoner = ReasonerBase(kb, "hamming") | |||
test_hwf(reasoner) | |||
print("HWF_KB without GKB, max_err=0.1") | |||
kb = HWF_KB(max_err=0.1) | |||
reasoner = ReasonerBase(kb, "hamming") | |||
test_hwf(reasoner) | |||
print("HWF_KB with GKB, max_err=1") | |||
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=1) | |||
reasoner = ReasonerBase(kb, "hamming") | |||
test_hwf(reasoner) | |||
print("HWF_KB without GKB, max_err=1") | |||
kb = HWF_KB(max_err=1) | |||
reasoner = ReasonerBase(kb, "hamming") | |||
test_hwf(reasoner) | |||
print("HWF_KB with multiple inputs at once:") | |||
kb = HWF_KB(max_err=0.1) | |||
reasoner = ReasonerBase(kb, "hamming") | |||
test_hwf_multiple(reasoner, max_revisions=[1, 3, 3]) | |||
print("max_revision is float") | |||
test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9]) | |||
class HED_prolog_KB(prolog_KB): | |||
def __init__(self, pseudo_label_list, pl_file): | |||
super().__init__(pseudo_label_list, pl_file) | |||
def consist_rule(self, exs, rules): | |||
rules = str(rules).replace("'", "") | |||
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) | |||
return len(list(self.prolog.query(pl_query))) != 0 | |||
def abduce_rules(self, pred_res): | |||
pl_query = "consistent_inst_feature(%s, X)." % pred_res | |||
prolog_result = list(self.prolog.query(pl_query)) | |||
if len(prolog_result) == 0: | |||
return None | |||
prolog_rules = prolog_result[0]["X"] | |||
rules = [rule.value for rule in prolog_rules] | |||
return rules | |||
class HED_Reasoner(ReasonerBase): | |||
def __init__(self, kb, dist_func="hamming"): | |||
super().__init__(kb, dist_func, use_zoopt=True) | |||
def _revise_by_idxs(self, pred_res, y, all_revision_flag, idxs): | |||
pred = [] | |||
k = [] | |||
revision_flag = [] | |||
for idx in idxs: | |||
pred.append(pred_res[idx]) | |||
k.append(y[idx]) | |||
revision_flag += list(all_revision_flag[idx]) | |||
revision_idx = np.where(np.array(revision_flag) != 0)[0] | |||
candidate = self.revise_by_idx(pred, k, revision_idx) | |||
return candidate | |||
def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): | |||
all_revision_flag = reform_idx(sol.get_x(), pred_res) | |||
lefted_idxs = [i for i in range(len(pred_res))] | |||
candidate_size = [] | |||
while lefted_idxs: | |||
idxs = [] | |||
idxs.append(lefted_idxs.pop(0)) | |||
max_candidate_idxs = [] | |||
found = False | |||
for idx in range(-1, len(pred_res)): | |||
if (not idx in idxs) and (idx >= 0): | |||
idxs.append(idx) | |||
candidate = self._revise_by_idxs( | |||
pred_res, y, all_revision_flag, idxs | |||
) | |||
if len(candidate) == 0: | |||
if len(idxs) > 1: | |||
idxs.pop() | |||
else: | |||
if len(idxs) > len(max_candidate_idxs): | |||
found = True | |||
max_candidate_idxs = idxs.copy() | |||
removed = [i for i in lefted_idxs if i in max_candidate_idxs] | |||
if found: | |||
candidate_size.append(len(removed) + 1) | |||
lefted_idxs = [ | |||
i for i in lefted_idxs if i not in max_candidate_idxs | |||
] | |||
candidate_size.sort() | |||
score = 0 | |||
import math | |||
for i in range(0, len(candidate_size)): | |||
score -= math.exp(-i) * candidate_size[i] | |||
return score | |||
def abduce_rules(self, pred_res): | |||
return self.kb.abduce_rules(pred_res) | |||
kb = HED_prolog_KB( | |||
pseudo_label_list=[1, 0, "+", "="], | |||
pl_file="examples/hed/datasets/learn_add.pl", | |||
) | |||
reasoner = HED_Reasoner(kb) | |||
consist_exs = [ | |||
[1, 1, "+", 0, "=", 1, 1], | |||
[1, "+", 1, "=", 1, 0], | |||
[0, "+", 0, "=", 0], | |||
] | |||
inconsist_exs1 = [ | |||
[1, 1, "+", 0, "=", 1, 1], | |||
[1, "+", 1, "=", 1, 0], | |||
[0, "+", 0, "=", 0], | |||
[0, "+", 0, "=", 1], | |||
] | |||
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | |||
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] | |||
print("HED_kb logic forward") | |||
print(kb.logic_forward(consist_exs)) | |||
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) | |||
print() | |||
print("HED_kb consist rule") | |||
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) | |||
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) | |||
print() | |||
print("HED_Reasoner abduce") | |||
res = reasoner.abduce( | |||
[[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs) | |||
) | |||
print(res) | |||
res = reasoner.abduce( | |||
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) | |||
) | |||
print(res) | |||
res = reasoner.abduce( | |||
[[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) | |||
) | |||
print(res) | |||
print() | |||
print("HED_Reasoner abduce rules") | |||
abduced_rules = reasoner.abduce_rules(consist_exs) | |||
print(abduced_rules) | |||
data_samples.abduced_pseudo_label = abduced_pseudo_label | |||
return abduced_pseudo_label |
@@ -0,0 +1,49 @@ | |||
from abc import ABC, abstractmethod | |||
from itertools import product | |||
from typing import Any, List, Tuple, Union | |||
import numpy | |||
from ..structures import ListData | |||
from .base_kb import BaseKB | |||
class SearchBasedKB(BaseKB, ABC): | |||
def __init__( | |||
self, | |||
pseudo_label_list: List, | |||
) -> None: | |||
super().__init__(pseudo_label_list) | |||
@abstractmethod | |||
def check_equal(self, data_sample: ListData, y: Any): | |||
"""Placeholder for check_equal.""" | |||
pass | |||
def revise_at_idx( | |||
self, | |||
data_sample: ListData, | |||
revision_idx: Union[List, Tuple, numpy.ndarray], | |||
): | |||
candidates = [] | |||
abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) | |||
for c in abduce_c: | |||
new_data_sample = data_sample.clone() | |||
candidate = new_data_sample["pred_pseudo_label"][0].copy() | |||
for i, idx in enumerate(revision_idx): | |||
candidate[idx] = c[i] | |||
new_data_sample["pred_pseudo_label"][0] = candidate | |||
if self.check_equal(new_data_sample, new_data_sample["Y"][0]): | |||
candidates.append(candidate) | |||
return candidates | |||
# TODO: When the output is excessively long, use ellipses as a substitute. | |||
def __repr__(self): | |||
return ( | |||
f"<{self.__class__.__name__}(\n" | |||
f" pseudo_label_list: {self.pseudo_label_list!r}\n" | |||
f" search_strategy: {self.search_strategy!r}\n" | |||
f" use_cache: {self.use_cache!r}\n" | |||
f" cache_root: {self.cache_root!r}\n" | |||
f") at {hex(id(self))}>" | |||
) |
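A quick usage sketch of `revise_at_idx` (illustrative only, not part of the patch), reusing the classes defined above. The toy `AddKB` subclass and the sample values are assumptions; `logic_forward` is included only in case `BaseKB` also declares it abstract.

class AddKB(SearchBasedKB):
    # Toy knowledge base: an example is consistent when its pseudo-labels sum to Y.
    def logic_forward(self, data_sample):
        return sum(data_sample["pred_pseudo_label"][0])

    def check_equal(self, data_sample, y):
        return self.logic_forward(data_sample) == y

kb = AddKB(pseudo_label_list=list(range(10)))
sample = ListData(pred_pseudo_label=[[1, 1]], Y=[8])
# Enumerate every pseudo-label at position 1 and keep revisions whose sum equals 8.
print(kb.revise_at_idx(sample, revision_idx=[1]))  # expected: [[1, 7]]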
@@ -0,0 +1,2 @@ | |||
from .base_search_engine import BaseSearchEngine | |||
from .bfs import BFS |
@@ -0,0 +1,13 @@ | |||
from abc import ABC, abstractmethod | |||
from typing import List, Tuple, Union | |||
import numpy | |||
from ...structures import ListData | |||
class BaseSearchEngine(ABC): | |||
@abstractmethod | |||
    def generator(self, data_sample: ListData) -> Union[List, Tuple, numpy.ndarray]: | |||
"""Placeholder for the generator of revision_idx.""" | |||
pass |
@@ -0,0 +1,28 @@ | |||
from itertools import combinations | |||
from typing import List, Tuple, Union | |||
import numpy | |||
from ...structures import ListData | |||
from .base_search_engine import BaseSearchEngine | |||
class BFS(BaseSearchEngine): | |||
def __init__(self) -> None: | |||
pass | |||
def generator( | |||
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 | |||
) -> Union[List, Tuple, numpy.ndarray]: | |||
symbol_num = data_sample["symbol_num"] | |||
max_revision_num = min(max_revision_num, symbol_num) | |||
real_end = max_revision_num | |||
for revision_num in range(max_revision_num + 1): | |||
if revision_num > real_end: | |||
break | |||
revision_idx_tuple = combinations(range(symbol_num), revision_num) | |||
for revision_idx in revision_idx_tuple: | |||
received = yield revision_idx | |||
if received == "success": | |||
real_end = min(symbol_num, revision_num + require_more_revision) |
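A sketch of how the `BFS` generator is meant to be driven (illustrative only, not part of the patch): the caller pulls candidate `revision_idx` tuples with `next()` and sends back "success" once a consistent revision is found, which caps the search at `revision_num + require_more_revision` further revisions. The consistency check below is a stand-in.

def is_consistent(revision_idx):
    # Stand-in for a real KB consistency check; succeeds once position 1 may be revised.
    return 1 in revision_idx

searcher = BFS()
# The generator only reads data_sample["symbol_num"], so a plain dict suffices for this sketch.
gen = searcher.generator({"symbol_num": 3}, max_revision_num=2)
revision_idx = next(gen)  # first candidate: the empty revision ()
while True:
    try:
        if is_consistent(revision_idx):
            revision_idx = gen.send("success")  # stop enlarging the revision size
        else:
            revision_idx = next(gen)
    except StopIteration:
        break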
@@ -0,0 +1,42 @@ | |||
from typing import List, Tuple, Union | |||
import numpy as np | |||
from zoopt import Dimension, Objective, Opt, Parameter, Solution | |||
from ...structures import ListData | |||
from ..reasoner import ReasonerBase | |||
from ..search_based_kb import SearchBasedKB | |||
from .base_search_engine import BaseSearchEngine | |||
class Zoopt(BaseSearchEngine): | |||
def __init__(self, reasoner: ReasonerBase, kb: SearchBasedKB) -> None: | |||
self.reasoner = reasoner | |||
self.kb = kb | |||
def score_func(self, data_sample: ListData, solution: Solution): | |||
        revision_idx = np.where(np.array(solution.get_x()) != 0)[0] | |||
candidates = self.kb.revise_at_idx(data_sample, revision_idx) | |||
if len(candidates) > 0: | |||
return np.min(self.reasoner._get_dist_list(data_sample, candidates)) | |||
else: | |||
return data_sample["symbol_num"] | |||
@staticmethod | |||
def constraint(solution: Solution, max_revision_num: int): | |||
x = solution.get_x() | |||
        return max_revision_num - np.sum(x) | |||
def generator( | |||
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 | |||
) -> Union[List, Tuple, np.ndarray]: | |||
symbol_num = data_sample["symbol_num"] | |||
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) | |||
objective = Objective( | |||
            lambda solution: self.score_func(data_sample, solution), | |||
dim=dimension, | |||
constraint=lambda solution: self.constraint(solution, max_revision_num), | |||
) | |||
parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | |||
        solution = Opt.min(objective, parameter).get_x() | |||
        revision_idx = np.where(np.array(solution) != 0)[0] | |||
        yield revision_idx | |||
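For reference, a self-contained ZOOpt run (illustrative only, not part of the patch) using the same encoding as `Zoopt.generator`: a 0/1 mask over symbol positions, constrained so at most `max_revision_num` positions are revised. The toy score below stands in for the reasoner's distance.

import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter

symbol_num, max_revision_num = 5, 2
target_mask = np.array([0, 1, 0, 0, 1])  # toy "ideal" revision mask

def score(solution):
    # Lower is better: how far the proposed mask is from the toy target.
    return float(np.abs(np.array(solution.get_x()) - target_mask).sum())

dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
objective = Objective(
    score,
    dim=dimension,
    constraint=lambda s: max_revision_num - np.sum(s.get_x()),
)
solution = Opt.min(objective, Parameter(budget=100, intermediate_result=False, autoset=True))
revision_idx = np.where(np.array(solution.get_x()) != 0)[0]
print(revision_idx)  # ideally array([1, 4])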
@@ -0,0 +1,2 @@ | |||
from .base_data_element import BaseDataElement | |||
from .list_data import ListData |
@@ -0,0 +1,629 @@ | |||
# Copyright (c) OpenMMLab. All rights reserved. | |||
import copy | |||
from typing import Any, Iterator, Optional, Tuple, Type, Union | |||
import numpy as np | |||
import torch | |||
class BaseDataElement: | |||
"""A base data interface that supports Tensor-like and dict-like | |||
operations. | |||
    Typical data elements refer to predicted results or ground truth labels | |||
on a task, such as predicted bboxes, instance masks, semantic | |||
segmentation masks, etc. Because groundtruth labels and predicted results | |||
often have similar properties (for example, the predicted bboxes and the | |||
groundtruth bboxes), MMEngine uses the same abstract data interface to | |||
encapsulate predicted results and groundtruth labels, and it is recommended | |||
to use different name conventions to distinguish them, such as using | |||
``gt_instances`` and ``pred_instances`` to distinguish between labels and | |||
predicted results. Additionally, we distinguish data elements at instance | |||
level, pixel level, and label level. Each of these types has its own | |||
characteristics. Therefore, MMEngine defines the base class | |||
    ``BaseDataElement``, and implements ``InstanceData``, ``PixelData``, and | |||
``LabelData`` inheriting from ``BaseDataElement`` to represent different | |||
types of ground truth labels or predictions. | |||
    Another common data element is sample data. A data sample consists of input | |||
data (such as an image) and its annotations and predictions. In general, | |||
an image can have multiple types of annotations and/or predictions at the | |||
same time (for example, both pixel-level semantic segmentation annotations | |||
and instance-level detection bboxes annotations). All labels and | |||
predictions of a training sample are often passed between Dataset, Model, | |||
Visualizer, and Evaluator components. In order to simplify the interface | |||
between components, we can treat them as a large data element and | |||
encapsulate them. Such data elements are generally called XXDataSample in | |||
    OpenMMLab. Therefore, similar to `nn.Module`, the `BaseDataElement` | |||
allows `BaseDataElement` as its attribute. Such a class generally | |||
encapsulates all the data of a sample in the algorithm library, and its | |||
attributes generally are various types of data elements. For example, | |||
    MMDetection builds on ``BaseDataElement`` to encapsulate all the data | |||
    elements (annotations and predictions) of a sample in the | |||
    algorithm library. | |||
The attributes in ``BaseDataElement`` are divided into two parts, | |||
the ``metainfo`` and the ``data`` respectively. | |||
- ``metainfo``: Usually contains the | |||
information about the image such as filename, | |||
image_shape, pad_shape, etc. The attributes can be accessed or | |||
modified by dict-like or object-like operations, such as | |||
``.`` (for data access and modification), ``in``, ``del``, | |||
``pop(str)``, ``get(str)``, ``metainfo_keys()``, | |||
``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for | |||
set or change key-value pairs in metainfo). | |||
- ``data``: Annotations or model predictions are | |||
stored. The attributes can be accessed or modified by | |||
dict-like or object-like operations, such as | |||
``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, | |||
``values()``, ``items()``. Users can also apply tensor-like | |||
methods to all :obj:`torch.Tensor` in the ``data_fields``, | |||
such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, | |||
``to_tensor()``, ``.detach()``. | |||
Args: | |||
metainfo (dict, optional): A dict contains the meta information | |||
of single image, such as ``dict(img_shape=(512, 512, 3), | |||
scale_factor=(1, 1, 1, 1))``. Defaults to None. | |||
kwargs (dict, optional): A dict contains annotations of single image or | |||
model predictions. Defaults to None. | |||
Examples: | |||
>>> import torch | |||
>>> from mmengine.structures import BaseDataElement | |||
>>> gt_instances = BaseDataElement() | |||
>>> bboxes = torch.rand((5, 4)) | |||
>>> scores = torch.rand((5,)) | |||
>>> img_id = 0 | |||
>>> img_shape = (800, 1333) | |||
>>> gt_instances = BaseDataElement( | |||
... metainfo=dict(img_id=img_id, img_shape=img_shape), | |||
... bboxes=bboxes, scores=scores) | |||
>>> gt_instances = BaseDataElement( | |||
... metainfo=dict(img_id=img_id, img_shape=(640, 640))) | |||
>>> # new | |||
>>> gt_instances1 = gt_instances.new( | |||
... metainfo=dict(img_id=1, img_shape=(640, 640)), | |||
... bboxes=torch.rand((5, 4)), | |||
... scores=torch.rand((5,))) | |||
>>> gt_instances2 = gt_instances1.new() | |||
>>> # add and process property | |||
>>> gt_instances = BaseDataElement() | |||
>>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) | |||
>>> assert 'img_shape' in gt_instances.metainfo_keys() | |||
>>> assert 'img_shape' in gt_instances | |||
>>> assert 'img_shape' not in gt_instances.keys() | |||
>>> assert 'img_shape' in gt_instances.all_keys() | |||
>>> print(gt_instances.img_shape) | |||
(100, 100) | |||
>>> gt_instances.scores = torch.rand((5,)) | |||
>>> assert 'scores' in gt_instances.keys() | |||
>>> assert 'scores' in gt_instances | |||
>>> assert 'scores' in gt_instances.all_keys() | |||
>>> assert 'scores' not in gt_instances.metainfo_keys() | |||
>>> print(gt_instances.scores) | |||
tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) | |||
>>> gt_instances.bboxes = torch.rand((5, 4)) | |||
>>> assert 'bboxes' in gt_instances.keys() | |||
>>> assert 'bboxes' in gt_instances | |||
>>> assert 'bboxes' in gt_instances.all_keys() | |||
>>> assert 'bboxes' not in gt_instances.metainfo_keys() | |||
>>> print(gt_instances.bboxes) | |||
tensor([[0.0900, 0.0424, 0.1755, 0.4469], | |||
[0.8648, 0.0592, 0.3484, 0.0913], | |||
[0.5808, 0.1909, 0.6165, 0.7088], | |||
[0.5490, 0.4209, 0.9416, 0.2374], | |||
[0.3652, 0.1218, 0.8805, 0.7523]]) | |||
>>> # delete and change property | |||
>>> gt_instances = BaseDataElement( | |||
... metainfo=dict(img_id=0, img_shape=(640, 640)), | |||
... bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) | |||
>>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) | |||
>>> gt_instances.img_shape # (1280, 1280) | |||
>>> gt_instances.bboxes = gt_instances.bboxes * 2 | |||
>>> gt_instances.get('img_shape', None) # (1280, 1280) | |||
>>> gt_instances.get('bboxes', None) # 6x4 tensor | |||
>>> del gt_instances.img_shape | |||
>>> del gt_instances.bboxes | |||
>>> assert 'img_shape' not in gt_instances | |||
>>> assert 'bboxes' not in gt_instances | |||
>>> gt_instances.pop('img_shape', None) # None | |||
>>> gt_instances.pop('bboxes', None) # None | |||
>>> # Tensor-like | |||
>>> cuda_instances = gt_instances.cuda() | |||
>>> cuda_instances = gt_instances.to('cuda:0') | |||
>>> cpu_instances = cuda_instances.cpu() | |||
>>> cpu_instances = cuda_instances.to('cpu') | |||
>>> fp16_instances = cuda_instances.to( | |||
... device=None, dtype=torch.float16, non_blocking=False, | |||
... copy=False, memory_format=torch.preserve_format) | |||
>>> cpu_instances = cuda_instances.detach() | |||
>>> np_instances = cpu_instances.numpy() | |||
>>> metainfo = dict(img_shape=(800, 1196, 3)) | |||
>>> gt_instances = BaseDataElement( | |||
... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) | |||
>>> sample = BaseDataElement(metainfo=metainfo, | |||
... gt_instances=gt_instances) | |||
>>> print(sample) | |||
<BaseDataElement( | |||
META INFORMATION | |||
img_shape: (800, 1196, 3) | |||
DATA FIELDS | |||
gt_instances: <BaseDataElement( | |||
META INFORMATION | |||
img_shape: (800, 1196, 3) | |||
DATA FIELDS | |||
det_labels: tensor([0, 1, 2, 3]) | |||
) at 0x7f0ec5eadc70> | |||
) at 0x7f0fea49e130> | |||
>>> # inheritance | |||
>>> class DetDataSample(BaseDataElement): | |||
... @property | |||
... def proposals(self): | |||
... return self._proposals | |||
... @proposals.setter | |||
... def proposals(self, value): | |||
... self.set_field(value, '_proposals', dtype=BaseDataElement) | |||
... @proposals.deleter | |||
... def proposals(self): | |||
... del self._proposals | |||
... @property | |||
... def gt_instances(self): | |||
... return self._gt_instances | |||
... @gt_instances.setter | |||
... def gt_instances(self, value): | |||
... self.set_field(value, '_gt_instances', | |||
... dtype=BaseDataElement) | |||
... @gt_instances.deleter | |||
... def gt_instances(self): | |||
... del self._gt_instances | |||
... @property | |||
... def pred_instances(self): | |||
... return self._pred_instances | |||
... @pred_instances.setter | |||
... def pred_instances(self, value): | |||
... self.set_field(value, '_pred_instances', | |||
... dtype=BaseDataElement) | |||
... @pred_instances.deleter | |||
... def pred_instances(self): | |||
... del self._pred_instances | |||
>>> det_sample = DetDataSample() | |||
>>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) | |||
>>> det_sample.proposals = proposals | |||
>>> assert 'proposals' in det_sample | |||
>>> assert det_sample.proposals == proposals | |||
>>> del det_sample.proposals | |||
>>> assert 'proposals' not in det_sample | |||
>>> with self.assertRaises(AssertionError): | |||
... det_sample.proposals = torch.rand((5, 4)) | |||
""" | |||
def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: | |||
self._metainfo_fields: set = set() | |||
self._data_fields: set = set() | |||
if metainfo is not None: | |||
self.set_metainfo(metainfo=metainfo) | |||
if kwargs: | |||
self.set_data(kwargs) | |||
def set_metainfo(self, metainfo: dict) -> None: | |||
"""Set or change key-value pairs in ``metainfo_field`` by parameter | |||
``metainfo``. | |||
Args: | |||
metainfo (dict): A dict contains the meta information | |||
of image, such as ``img_shape``, ``scale_factor``, etc. | |||
""" | |||
assert isinstance( | |||
metainfo, dict | |||
), f"metainfo should be a ``dict`` but got {type(metainfo)}" | |||
meta = copy.deepcopy(metainfo) | |||
for k, v in meta.items(): | |||
self.set_field(name=k, value=v, field_type="metainfo", dtype=None) | |||
def set_data(self, data: dict) -> None: | |||
"""Set or change key-value pairs in ``data_field`` by parameter | |||
``data``. | |||
Args: | |||
data (dict): A dict contains annotations of image or | |||
model predictions. | |||
""" | |||
assert isinstance(data, dict), f"data should be a `dict` but got {data}" | |||
for k, v in data.items(): | |||
# Use `setattr()` rather than `self.set_field` to allow `set_data` | |||
# to set property method. | |||
setattr(self, k, v) | |||
def update(self, instance: "BaseDataElement") -> None: | |||
"""The update() method updates the BaseDataElement with the elements | |||
from another BaseDataElement object. | |||
Args: | |||
instance (BaseDataElement): Another BaseDataElement object for | |||
update the current object. | |||
""" | |||
assert isinstance( | |||
instance, BaseDataElement | |||
), f"instance should be a `BaseDataElement` but got {type(instance)}" | |||
self.set_metainfo(dict(instance.metainfo_items())) | |||
self.set_data(dict(instance.items())) | |||
def new(self, *, metainfo: Optional[dict] = None, **kwargs) -> "BaseDataElement": | |||
"""Return a new data element with same type. If ``metainfo`` and | |||
``data`` are None, the new data element will have same metainfo and | |||
data. If metainfo or data is not None, the new result will overwrite it | |||
with the input value. | |||
Args: | |||
metainfo (dict, optional): A dict contains the meta information | |||
of image, such as ``img_shape``, ``scale_factor``, etc. | |||
Defaults to None. | |||
kwargs (dict): A dict contains annotations of image or | |||
model predictions. | |||
Returns: | |||
BaseDataElement: A new data element with same type. | |||
""" | |||
new_data = self.__class__() | |||
if metainfo is not None: | |||
new_data.set_metainfo(metainfo) | |||
else: | |||
new_data.set_metainfo(dict(self.metainfo_items())) | |||
if kwargs: | |||
new_data.set_data(kwargs) | |||
else: | |||
new_data.set_data(dict(self.items())) | |||
return new_data | |||
def clone(self): | |||
"""Deep copy the current data element. | |||
Returns: | |||
BaseDataElement: The copy of current data element. | |||
""" | |||
clone_data = self.__class__() | |||
clone_data.set_metainfo(dict(self.metainfo_items())) | |||
clone_data.set_data(dict(self.items())) | |||
return clone_data | |||
def keys(self) -> list: | |||
""" | |||
Returns: | |||
list: Contains all keys in data_fields. | |||
""" | |||
# We assume that the name of the attribute related to property is | |||
# '_' + the name of the property. We use this rule to filter out | |||
# private keys. | |||
# TODO: Use a more robust way to solve this problem | |||
private_keys = { | |||
"_" + key | |||
for key in self._data_fields | |||
if isinstance(getattr(type(self), key, None), property) | |||
} | |||
return list(self._data_fields - private_keys) | |||
def metainfo_keys(self) -> list: | |||
""" | |||
Returns: | |||
list: Contains all keys in metainfo_fields. | |||
""" | |||
return list(self._metainfo_fields) | |||
def values(self) -> list: | |||
""" | |||
Returns: | |||
list: Contains all values in data. | |||
""" | |||
return [getattr(self, k) for k in self.keys()] | |||
def metainfo_values(self) -> list: | |||
""" | |||
Returns: | |||
list: Contains all values in metainfo. | |||
""" | |||
return [getattr(self, k) for k in self.metainfo_keys()] | |||
def all_keys(self) -> list: | |||
""" | |||
Returns: | |||
list: Contains all keys in metainfo and data. | |||
""" | |||
return self.metainfo_keys() + self.keys() | |||
def all_values(self) -> list: | |||
""" | |||
Returns: | |||
list: Contains all values in metainfo and data. | |||
""" | |||
return self.metainfo_values() + self.values() | |||
def all_items(self) -> Iterator[Tuple[str, Any]]: | |||
""" | |||
Returns: | |||
iterator: An iterator object whose element is (key, value) tuple | |||
pairs for ``metainfo`` and ``data``. | |||
""" | |||
for k in self.all_keys(): | |||
yield (k, getattr(self, k)) | |||
def items(self) -> Iterator[Tuple[str, Any]]: | |||
""" | |||
Returns: | |||
iterator: An iterator object whose element is (key, value) tuple | |||
pairs for ``data``. | |||
""" | |||
for k in self.keys(): | |||
yield (k, getattr(self, k)) | |||
def metainfo_items(self) -> Iterator[Tuple[str, Any]]: | |||
""" | |||
Returns: | |||
iterator: An iterator object whose element is (key, value) tuple | |||
pairs for ``metainfo``. | |||
""" | |||
for k in self.metainfo_keys(): | |||
yield (k, getattr(self, k)) | |||
@property | |||
def metainfo(self) -> dict: | |||
"""dict: A dict contains metainfo of current data element.""" | |||
return dict(self.metainfo_items()) | |||
def __setattr__(self, name: str, value: Any): | |||
"""setattr is only used to set data.""" | |||
if name in ("_metainfo_fields", "_data_fields"): | |||
if not hasattr(self, name): | |||
super().__setattr__(name, value) | |||
else: | |||
raise AttributeError( | |||
f"{name} has been used as a " | |||
"private attribute, which is immutable." | |||
) | |||
else: | |||
self.set_field(name=name, value=value, field_type="data", dtype=None) | |||
def __delattr__(self, item: str): | |||
"""Delete the item in dataelement. | |||
Args: | |||
item (str): The key to delete. | |||
""" | |||
if item in ("_metainfo_fields", "_data_fields"): | |||
raise AttributeError( | |||
f"{item} has been used as a " "private attribute, which is immutable." | |||
) | |||
super().__delattr__(item) | |||
if item in self._metainfo_fields: | |||
self._metainfo_fields.remove(item) | |||
elif item in self._data_fields: | |||
self._data_fields.remove(item) | |||
# dict-like methods | |||
__delitem__ = __delattr__ | |||
def get(self, key, default=None) -> Any: | |||
"""Get property in data and metainfo as the same as python.""" | |||
# Use `getattr()` rather than `self.__dict__.get()` to allow getting | |||
# properties. | |||
return getattr(self, key, default) | |||
def pop(self, *args) -> Any: | |||
"""Pop property in data and metainfo as the same as python.""" | |||
assert len(args) < 3, "``pop`` get more than 2 arguments" | |||
name = args[0] | |||
if name in self._metainfo_fields: | |||
self._metainfo_fields.remove(args[0]) | |||
return self.__dict__.pop(*args) | |||
elif name in self._data_fields: | |||
self._data_fields.remove(args[0]) | |||
return self.__dict__.pop(*args) | |||
# with default value | |||
elif len(args) == 2: | |||
return args[1] | |||
else: | |||
# don't just use 'self.__dict__.pop(*args)' for only popping key in | |||
# metainfo or data | |||
raise KeyError(f"{args[0]} is not contained in metainfo or data") | |||
def __contains__(self, item: str) -> bool: | |||
"""Whether the item is in dataelement. | |||
Args: | |||
item (str): The key to inquire. | |||
""" | |||
return item in self._data_fields or item in self._metainfo_fields | |||
def set_field( | |||
self, | |||
value: Any, | |||
name: str, | |||
dtype: Optional[Union[Type, Tuple[Type, ...]]] = None, | |||
field_type: str = "data", | |||
) -> None: | |||
"""Special method for set union field, used as property.setter | |||
functions.""" | |||
assert field_type in ["metainfo", "data"] | |||
if dtype is not None: | |||
assert isinstance( | |||
value, dtype | |||
), f"{value} should be a {dtype} but got {type(value)}" | |||
if field_type == "metainfo": | |||
if name in self._data_fields: | |||
raise AttributeError( | |||
f"Cannot set {name} to be a field of metainfo " | |||
f"because {name} is already a data field" | |||
) | |||
self._metainfo_fields.add(name) | |||
else: | |||
if name in self._metainfo_fields: | |||
raise AttributeError( | |||
f"Cannot set {name} to be a field of data " | |||
f"because {name} is already a metainfo field" | |||
) | |||
self._data_fields.add(name) | |||
super().__setattr__(name, value) | |||
# Tensor-like methods | |||
def to(self, *args, **kwargs) -> "BaseDataElement": | |||
"""Apply same name function to all tensors in data_fields.""" | |||
new_data = self.new() | |||
for k, v in self.items(): | |||
if hasattr(v, "to"): | |||
v = v.to(*args, **kwargs) | |||
data = {k: v} | |||
new_data.set_data(data) | |||
return new_data | |||
# Tensor-like methods | |||
def cpu(self) -> "BaseDataElement": | |||
"""Convert all tensors to CPU in data.""" | |||
new_data = self.new() | |||
for k, v in self.items(): | |||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
v = v.cpu() | |||
data = {k: v} | |||
new_data.set_data(data) | |||
return new_data | |||
# Tensor-like methods | |||
def cuda(self) -> "BaseDataElement": | |||
"""Convert all tensors to GPU in data.""" | |||
new_data = self.new() | |||
for k, v in self.items(): | |||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
v = v.cuda() | |||
data = {k: v} | |||
new_data.set_data(data) | |||
return new_data | |||
# Tensor-like methods | |||
def npu(self) -> "BaseDataElement": | |||
"""Convert all tensors to NPU in data.""" | |||
new_data = self.new() | |||
for k, v in self.items(): | |||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
v = v.npu() | |||
data = {k: v} | |||
new_data.set_data(data) | |||
return new_data | |||
def mlu(self) -> "BaseDataElement": | |||
"""Convert all tensors to MLU in data.""" | |||
new_data = self.new() | |||
for k, v in self.items(): | |||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
v = v.mlu() | |||
data = {k: v} | |||
new_data.set_data(data) | |||
return new_data | |||
# Tensor-like methods | |||
def detach(self) -> "BaseDataElement": | |||
"""Detach all tensors in data.""" | |||
new_data = self.new() | |||
for k, v in self.items(): | |||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
v = v.detach() | |||
data = {k: v} | |||
new_data.set_data(data) | |||
return new_data | |||
# Tensor-like methods | |||
def numpy(self) -> "BaseDataElement": | |||
"""Convert all tensors to np.ndarray in data.""" | |||
new_data = self.new() | |||
for k, v in self.items(): | |||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||
v = v.detach().cpu().numpy() | |||
data = {k: v} | |||
new_data.set_data(data) | |||
return new_data | |||
def to_tensor(self) -> "BaseDataElement": | |||
"""Convert all np.ndarray to tensor in data.""" | |||
new_data = self.new() | |||
for k, v in self.items(): | |||
data = {} | |||
if isinstance(v, np.ndarray): | |||
v = torch.from_numpy(v) | |||
data[k] = v | |||
elif isinstance(v, BaseDataElement): | |||
v = v.to_tensor() | |||
data[k] = v | |||
new_data.set_data(data) | |||
return new_data | |||
def to_dict(self) -> dict: | |||
"""Convert BaseDataElement to dict.""" | |||
return { | |||
k: v.to_dict() if isinstance(v, BaseDataElement) else v | |||
for k, v in self.all_items() | |||
} | |||
def __repr__(self) -> str: | |||
"""Represent the object.""" | |||
def _addindent(s_: str, num_spaces: int) -> str: | |||
"""This func is modified from `pytorch` https://github.com/pytorch/ | |||
pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu | |||
les/module.py#L29. | |||
Args: | |||
s_ (str): The string to add spaces. | |||
num_spaces (int): The num of space to add. | |||
Returns: | |||
str: The string after add indent. | |||
""" | |||
s = s_.split("\n") | |||
# don't do anything for single-line stuff | |||
if len(s) == 1: | |||
return s_ | |||
first = s.pop(0) | |||
s = [(num_spaces * " ") + line for line in s] | |||
s = "\n".join(s) # type: ignore | |||
s = first + "\n" + s # type: ignore | |||
return s # type: ignore | |||
def dump(obj: Any) -> str: | |||
"""Represent the object. | |||
Args: | |||
obj (Any): The obj to represent. | |||
Returns: | |||
str: The represented str. | |||
""" | |||
_repr = "" | |||
if isinstance(obj, dict): | |||
for k, v in obj.items(): | |||
_repr += f"\n{k}: {_addindent(dump(v), 4)}" | |||
elif isinstance(obj, BaseDataElement): | |||
_repr += "\n\n META INFORMATION" | |||
metainfo_items = dict(obj.metainfo_items()) | |||
_repr += _addindent(dump(metainfo_items), 4) | |||
_repr += "\n\n DATA FIELDS" | |||
items = dict(obj.items()) | |||
_repr += _addindent(dump(items), 4) | |||
classname = obj.__class__.__name__ | |||
_repr = f"<{classname}({_repr}\n) at {hex(id(obj))}>" | |||
else: | |||
_repr += repr(obj) | |||
return _repr | |||
return dump(self) |
@@ -0,0 +1,321 @@ | |||
# Copyright (c) OpenMMLab. All rights reserved. | |||
import itertools | |||
from collections.abc import Sized | |||
from typing import Any, List, Union | |||
import numpy as np | |||
import torch | |||
from ..utils import flatten as flatten_list | |||
from ..utils import to_hashable | |||
from .base_data_element import BaseDataElement | |||
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] | |||
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor] | |||
IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray] | |||
# Modified from | |||
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa | |||
class ListData(BaseDataElement): | |||
"""Data structure for instance-level annotations or predictions. | |||
    Subclass of :class:`BaseDataElement`. All values in ``data_fields`` | |||
    should have the same length. This design refers to | |||
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 | |||
    ListData also supports extra functions: ``index``, ``slice`` and ``cat`` for data fields. The type of a value | |||
    in a data field can be a base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`, | |||
    or a customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes. | |||
Examples: | |||
>>> # custom data structure | |||
>>> class TmpObject: | |||
... def __init__(self, tmp) -> None: | |||
... assert isinstance(tmp, list) | |||
... self.tmp = tmp | |||
... def __len__(self): | |||
... return len(self.tmp) | |||
... def __getitem__(self, item): | |||
... if isinstance(item, int): | |||
... if item >= len(self) or item < -len(self): # type:ignore | |||
... raise IndexError(f'Index {item} out of range!') | |||
... else: | |||
... # keep the dimension | |||
... item = slice(item, None, len(self)) | |||
... return TmpObject(self.tmp[item]) | |||
... @staticmethod | |||
... def cat(tmp_objs): | |||
... assert all(isinstance(results, TmpObject) for results in tmp_objs) | |||
... if len(tmp_objs) == 1: | |||
... return tmp_objs[0] | |||
... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs] | |||
... tmp_list = list(itertools.chain(*tmp_list)) | |||
... new_data = TmpObject(tmp_list) | |||
... return new_data | |||
... def __repr__(self): | |||
... return str(self.tmp) | |||
>>> from mmengine.structures import ListData | |||
>>> import numpy as np | |||
>>> import torch | |||
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) | |||
>>> instance_data = ListData(metainfo=img_meta) | |||
>>> 'img_shape' in instance_data | |||
True | |||
>>> instance_data.det_labels = torch.LongTensor([2, 3]) | |||
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) | |||
>>> instance_data.bboxes = torch.rand((2, 4)) | |||
>>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]]) | |||
>>> len(instance_data) | |||
2 | |||
>>> print(instance_data) | |||
<ListData( | |||
META INFORMATION | |||
img_shape: (800, 1196, 3) | |||
pad_shape: (800, 1216, 3) | |||
DATA FIELDS | |||
det_labels: tensor([2, 3]) | |||
det_scores: tensor([0.8000, 0.7000]) | |||
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], | |||
[0.8101, 0.3105, 0.5123, 0.6263]]) | |||
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]] | |||
) at 0x7fb492de6280> | |||
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices] | |||
>>> sorted_results.det_scores | |||
tensor([0.7000, 0.8000]) | |||
>>> print(instance_data[instance_data.det_scores > 0.75]) | |||
<ListData( | |||
META INFORMATION | |||
img_shape: (800, 1196, 3) | |||
pad_shape: (800, 1216, 3) | |||
DATA FIELDS | |||
det_labels: tensor([2]) | |||
det_scores: tensor([0.8000]) | |||
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]]) | |||
polygons: [[1, 2, 3, 4]] | |||
) at 0x7f64ecf0ec40> | |||
>>> print(instance_data[instance_data.det_scores > 1]) | |||
<ListData( | |||
META INFORMATION | |||
img_shape: (800, 1196, 3) | |||
pad_shape: (800, 1216, 3) | |||
DATA FIELDS | |||
det_labels: tensor([], dtype=torch.int64) | |||
det_scores: tensor([]) | |||
bboxes: tensor([], size=(0, 4)) | |||
polygons: [] | |||
) at 0x7f660a6a7f70> | |||
>>> print(instance_data.cat([instance_data, instance_data])) | |||
<ListData( | |||
META INFORMATION | |||
img_shape: (800, 1196, 3) | |||
pad_shape: (800, 1216, 3) | |||
DATA FIELDS | |||
det_labels: tensor([2, 3, 2, 3]) | |||
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000]) | |||
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], | |||
[0.8101, 0.3105, 0.5123, 0.6263], | |||
[0.4997, 0.7707, 0.0595, 0.4188], | |||
[0.8101, 0.3105, 0.5123, 0.6263]]) | |||
polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]] | |||
) at 0x7f203542feb0> | |||
""" | |||
def __setattr__(self, name: str, value: list): | |||
"""setattr is only used to set data. | |||
        The value must have the attribute ``__len__`` and have the same length | |||
        as this ``ListData``. | |||
""" | |||
if name in ("_metainfo_fields", "_data_fields"): | |||
if not hasattr(self, name): | |||
super().__setattr__(name, value) | |||
else: | |||
raise AttributeError( | |||
f"{name} has been used as a " | |||
"private attribute, which is immutable." | |||
) | |||
else: | |||
assert isinstance(value, list), "value must be of type `list`" | |||
if len(self) > 0: | |||
assert len(value) == len(self), ( | |||
"The length of " | |||
f"values {len(value)} is " | |||
"not consistent with " | |||
"the length of this " | |||
":obj:`ListData` " | |||
f"{len(self)}" | |||
) | |||
super().__setattr__(name, value) | |||
__setitem__ = __setattr__ | |||
def __getitem__(self, item: IndexType) -> "ListData": | |||
""" | |||
Args: | |||
item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`, | |||
:obj:`torch.LongTensor`, :obj:`torch.BoolTensor`): | |||
Get the corresponding values according to item. | |||
Returns: | |||
:obj:`ListData`: Corresponding values. | |||
""" | |||
assert isinstance(item, IndexType.__args__) | |||
if isinstance(item, list): | |||
item = np.array(item) | |||
if isinstance(item, np.ndarray): | |||
# The default int type of numpy is platform dependent, int32 for | |||
# windows and int64 for linux. `torch.Tensor` requires the index | |||
# should be int64, therefore we simply convert it to int64 here. | |||
# More details in https://github.com/numpy/numpy/issues/9464 | |||
item = item.astype(np.int64) if item.dtype == np.int32 else item | |||
item = torch.from_numpy(item) | |||
if isinstance(item, str): | |||
return getattr(self, item) | |||
if isinstance(item, int): | |||
if item >= len(self) or item < -len(self): # type:ignore | |||
raise IndexError(f"Index {item} out of range!") | |||
else: | |||
# keep the dimension | |||
item = slice(item, None, len(self)) | |||
new_data = self.__class__(metainfo=self.metainfo) | |||
if isinstance(item, torch.Tensor): | |||
assert item.dim() == 1, ( | |||
"Only support to get the" " values along the first dimension." | |||
) | |||
if isinstance(item, BoolTypeTensor.__args__): | |||
assert len(item) == len(self), ( | |||
"The shape of the " | |||
"input(BoolTensor) " | |||
f"{len(item)} " | |||
"does not match the shape " | |||
"of the indexed tensor " | |||
"in results_field " | |||
f"{len(self)} at " | |||
"first dimension." | |||
) | |||
for k, v in self.items(): | |||
if isinstance(v, torch.Tensor): | |||
new_data[k] = v[item] | |||
elif isinstance(v, np.ndarray): | |||
new_data[k] = v[item.cpu().numpy()] | |||
elif isinstance(v, (str, list, tuple)) or ( | |||
hasattr(v, "__getitem__") and hasattr(v, "cat") | |||
): | |||
# convert to indexes from BoolTensor | |||
if isinstance(item, BoolTypeTensor.__args__): | |||
indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist() | |||
else: | |||
indexes = item.cpu().numpy().tolist() | |||
slice_list = [] | |||
if indexes: | |||
for index in indexes: | |||
slice_list.append(slice(index, None, len(v))) | |||
else: | |||
slice_list.append(slice(None, 0, None)) | |||
r_list = [v[s] for s in slice_list] | |||
if isinstance(v, (str, list, tuple)): | |||
new_value = r_list[0] | |||
for r in r_list[1:]: | |||
new_value = new_value + r | |||
else: | |||
new_value = v.cat(r_list) | |||
new_data[k] = new_value | |||
else: | |||
raise ValueError( | |||
f"The type of `{k}` is `{type(v)}`, which has no " | |||
"attribute of `cat`, so it does not " | |||
"support slice with `bool`" | |||
) | |||
else: | |||
# item is a slice | |||
for k, v in self.items(): | |||
new_data[k] = v[item] | |||
return new_data # type:ignore | |||
@staticmethod | |||
def cat(instances_list: List["ListData"]) -> "ListData": | |||
"""Concat the instances of all :obj:`ListData` in the list. | |||
        Note: To ensure that cat returns as expected, make sure that | |||
        all elements in the list have exactly the same keys. | |||
Args: | |||
instances_list (list[:obj:`ListData`]): A list | |||
of :obj:`ListData`. | |||
Returns: | |||
:obj:`ListData` | |||
""" | |||
assert all(isinstance(results, ListData) for results in instances_list) | |||
assert len(instances_list) > 0 | |||
if len(instances_list) == 1: | |||
return instances_list[0] | |||
# metainfo and data_fields must be exactly the | |||
# same for each element to avoid exceptions. | |||
field_keys_list = [instances.all_keys() for instances in instances_list] | |||
assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len( | |||
set(itertools.chain(*field_keys_list)) | |||
) == len(field_keys_list[0]), ( | |||
"There are different keys in " | |||
"`instances_list`, which may " | |||
"cause the cat operation " | |||
"to fail. Please make sure all " | |||
"elements in `instances_list` " | |||
"have the exact same key." | |||
) | |||
new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo) | |||
for k in instances_list[0].keys(): | |||
values = [results[k] for results in instances_list] | |||
v0 = values[0] | |||
if isinstance(v0, torch.Tensor): | |||
new_values = torch.cat(values, dim=0) | |||
elif isinstance(v0, np.ndarray): | |||
new_values = np.concatenate(values, axis=0) | |||
elif isinstance(v0, (str, list, tuple)): | |||
new_values = v0[:] | |||
for v in values[1:]: | |||
new_values += v | |||
elif hasattr(v0, "cat"): | |||
new_values = v0.cat(values) | |||
else: | |||
raise ValueError( | |||
f"The type of `{k}` is `{type(v0)}` which has no " | |||
"attribute of `cat`" | |||
) | |||
new_data[k] = new_values | |||
return new_data # type:ignore | |||
def flatten(self, item: IndexType) -> List: | |||
"""Flatten self[item]. | |||
Returns: | |||
list: Flattened data fields. | |||
""" | |||
return flatten_list(self[item]) | |||
def elements_num(self, item: IndexType) -> int: | |||
"""int: The number of elements in self[item].""" | |||
return len(self.flatten(item)) | |||
def to_tuple(self, item: IndexType) -> tuple: | |||
"""tuple: The data fields in self[item] converted to tuple.""" | |||
return to_hashable(self[item]) | |||
def __len__(self) -> int: | |||
"""int: The length of ListData.""" | |||
if len(self._data_fields) > 0: | |||
one_element = next(iter(self._data_fields)) | |||
return len(getattr(self, one_element)) | |||
# return len(self.values()[0]) | |||
else: | |||
return 0 |
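A short sketch of the ABL-specific helpers added to `ListData` (illustrative only, not part of the patch); it assumes `abl.utils.flatten` fully flattens nested lists and `to_hashable` converts them to nested tuples.

data = ListData(pred_pseudo_label=[[1, 2], [3]], pred_idx=[[0, 1], [2]])
print(len(data))                               # 2 -- two examples
print(data.flatten("pred_pseudo_label"))       # expected: [1, 2, 3]
print(data.elements_num("pred_pseudo_label"))  # expected: 3
print(data.to_tuple("pred_pseudo_label"))      # expected: ((1, 2), (3,))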
@@ -1,2 +1,3 @@ | |||
from .cache import Cache | |||
from .logger import ABLLogger, print_log | |||
from .utils import * | |||
@@ -0,0 +1,112 @@ | |||
import pickle | |||
from os import PathLike | |||
from pathlib import Path | |||
from typing import Callable, Generic, Hashable, TypeVar, Union | |||
from .logger import print_log | |||
K = TypeVar("K") | |||
T = TypeVar("T") | |||
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields | |||
class Cache(Generic[K, T]): | |||
def __init__( | |||
self, | |||
func: Callable[[K], T], | |||
cache: bool, | |||
cache_file: Union[None, str, PathLike], | |||
key_func: Callable[[K], Hashable] = lambda x: x, | |||
max_size: int = 4096, | |||
): | |||
"""Create cache | |||
:param func: Function this cache evaluates | |||
:param cache: If true, do in memory caching. | |||
:param cache_root: If not None, cache to files at the provided path. | |||
:param key_func: Convert the key into a hashable object if needed | |||
""" | |||
self.func = func | |||
self.key_func = key_func | |||
self.cache = cache | |||
if cache is True or cache_file is not None: | |||
print_log("Caching is activated", logger="current") | |||
self._init_cache(cache_file, max_size) | |||
self.first = self.get_from_dict | |||
else: | |||
self.first = self.func | |||
def __getitem__(self, item: K, *args) -> T: | |||
return self.first(item, *args) | |||
    def invalidate(self): | |||
        """Invalidate the entire cache, including the on-disk cache file if one was given.""" | |||
        self.cache_dict.clear() | |||
        if self.cache_file is not None: | |||
            Path(self.cache_file).unlink(missing_ok=True) | |||
    def _init_cache(self, cache_file, max_size): | |||
        self.cache = True | |||
        self.cache_file = cache_file | |||
        self.cache_dict = dict() | |||
self.hits, self.misses, self.maxsize = 0, 0, max_size | |||
self.full = False | |||
self.root = [] # root of the circular doubly linked list | |||
self.root[:] = [self.root, self.root, None, None] | |||
if cache_file is not None: | |||
with open(cache_file, "rb") as f: | |||
cache_dict_from_file = pickle.load(f) | |||
self.maxsize += len(cache_dict_from_file) | |||
print_log( | |||
f"Max size of the cache has been enlarged to {self.maxsize}.", logger="current" | |||
) | |||
for cache_key, result in cache_dict_from_file.items(): | |||
last = self.root[PREV] | |||
link = [last, self.root, cache_key, result] | |||
last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link | |||
def get(self, item: K, *args) -> T: | |||
return self.first(item, *args) | |||
def get_from_dict(self, item: K, *args) -> T: | |||
"""Implements dict based cache.""" | |||
cache_key = (self.key_func(item), *args) | |||
link = self.cache_dict.get(cache_key) | |||
if link is not None: | |||
# Move the link to the front of the circular queue | |||
link_prev, link_next, _key, result = link | |||
link_prev[NEXT] = link_next | |||
link_next[PREV] = link_prev | |||
last = self.root[PREV] | |||
last[NEXT] = self.root[PREV] = link | |||
link[PREV] = last | |||
link[NEXT] = self.root | |||
self.hits += 1 | |||
return result | |||
self.misses += 1 | |||
result = self.func(item, *args) | |||
if self.full: | |||
# Use the old root to store the new key and result. | |||
oldroot = self.root | |||
oldroot[KEY] = cache_key | |||
oldroot[RESULT] = result | |||
# Empty the oldest link and make it the new root. | |||
self.root = oldroot[NEXT] | |||
oldkey = self.root[KEY] | |||
oldresult = self.root[RESULT] | |||
self.root[KEY] = self.root[RESULT] = None | |||
# Now update the cache dictionary. | |||
del self.cache_dict[oldkey] | |||
self.cache_dict[cache_key] = oldroot | |||
else: | |||
# Put result in a new link at the front of the queue. | |||
last = self.root[PREV] | |||
link = [last, self.root, cache_key, result] | |||
last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link | |||
if isinstance(self.maxsize, int): | |||
self.full = len(self.cache_dict) >= self.maxsize | |||
return result |
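A usage sketch of `Cache` (illustrative only, not part of the patch): wrapping a toy function with the in-memory LRU cache; `max_size=2` is deliberately tiny to show eviction.

def slow_square(x):
    return x * x

cache = Cache(func=slow_square, cache=True, cache_file=None, max_size=2)
print(cache.get(3))  # miss: computes 9 and stores it
print(cache.get(3))  # hit: returns the cached 9
print(cache.get(4))  # miss: the cache is now full
print(cache.get(5))  # miss: evicts the least recently used key (3)
print(cache.hits, cache.misses)  # 1 3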
@@ -1,6 +1,7 @@ | |||
import numpy as np | |||
from itertools import chain | |||
import numpy as np | |||
def flatten(nested_list): | |||
""" | |||
@@ -1,8 +1,8 @@ | |||
# -*- coding: utf-8 -*- | |||
import sys | |||
import os | |||
import re | |||
import sys | |||
if not 'READTHEDOCS' in os.environ: | |||
sys.path.insert(0, os.path.abspath('..')) | |||
@@ -11,7 +11,6 @@ sys.path.append(os.path.abspath('./ABL/')) | |||
# from sphinx.locale import _ | |||
from sphinx_rtd_theme import __version__ | |||
project = u'ABL' | |||
slug = re.sub(r'\W+', '-', project.lower()) | |||
author = u'Yu-Xuan Huang, Wen-Chao Hu, En-Hao Gao' | |||
@@ -1,11 +1,12 @@ | |||
import os | |||
import cv2 | |||
import torch | |||
import torchvision | |||
import pickle | |||
import numpy as np | |||
import random | |||
from collections import defaultdict | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
import torchvision | |||
from torch.utils.data import Dataset | |||
from torchvision.transforms import transforms | |||
@@ -1,18 +1,18 @@ | |||
import os | |||
from collections import defaultdict | |||
import torch | |||
from torch.utils.data import DataLoader | |||
from abl.reasoning import ReasonerBase | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.bridge import SimpleBridge | |||
from abl.evaluation import BaseMetric | |||
from abl.dataset import BridgeDataset, RegressionDataset | |||
from abl.evaluation import BaseMetric | |||
from abl.learning import ABLModel, BasicNN | |||
from abl.reasoning import ReasonerBase | |||
from abl.utils import print_log | |||
from examples.hed.utils import gen_mappings, InfiniteSampler | |||
from examples.models.nn import SymbolNetAutoencoder | |||
from examples.hed.datasets.get_hed import get_pretrain_data | |||
from examples.hed.utils import InfiniteSampler, gen_mappings | |||
from examples.models.nn import SymbolNetAutoencoder | |||
class HEDBridge(SimpleBridge): | |||
@@ -12,7 +12,7 @@ | |||
"\n", | |||
"from abl.reasoning import ReasonerBase, prolog_KB\n", | |||
"from abl.learning import BasicNN, ABLModel\n", | |||
"from abl.evaluation import SymbolMetric, ABLMetric\n", | |||
"from abl.evaluation import SymbolMetric, SemanticsMetric\n", | |||
"from abl.utils import ABLLogger, reform_idx\n", | |||
"\n", | |||
"from examples.hed.hed_bridge import HEDBridge\n", | |||
@@ -206,7 +206,7 @@ | |||
"outputs": [], | |||
"source": [ | |||
"# Add metric\n", | |||
"metric = [SymbolMetric(prefix=\"hed\"), ABLMetric(prefix=\"hed\")]" | |||
"metric = [SymbolMetric(prefix=\"hed\"), SemanticsMetric(prefix=\"hed\")]" | |||
] | |||
}, | |||
{ | |||
@@ -1,6 +1,6 @@ | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import numpy as np | |||
import torch.utils.data.sampler as sampler | |||
@@ -1,10 +1,10 @@ | |||
import os | |||
import json | |||
import os.path as osp | |||
from PIL import Image | |||
from torchvision.transforms import transforms | |||
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
CURRENT_DIR = osp.abspath(osp.dirname(__file__)) | |||
img_transform = transforms.Compose( | |||
[transforms.ToTensor(), transforms.Normalize((0.5,), (1,))] | |||
@@ -15,7 +15,7 @@ def get_data(file, get_pseudo_label): | |||
X, Y = [], [] | |||
if get_pseudo_label: | |||
Z = [] | |||
img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/") | |||
img_dir = osp.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/") | |||
with open(file) as f: | |||
data = json.load(f) | |||
for idx in range(len(data)): | |||
@@ -40,8 +40,8 @@ def get_data(file, get_pseudo_label): | |||
def get_hwf(train=True, get_gt_pseudo_label=False): | |||
if train: | |||
file = os.path.join(CURRENT_DIR, "data/expr_train.json") | |||
file = osp.join(CURRENT_DIR, "data/expr_train.json") | |||
else: | |||
file = os.path.join(CURRENT_DIR, "data/expr_test.json") | |||
file = osp.join(CURRENT_DIR, "data/expr_test.json") | |||
return get_data(file, get_gt_pseudo_label) |
@@ -6,19 +6,19 @@ | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"import os.path as osp\n", | |||
"\n", | |||
"import torch\n", | |||
"import numpy as np\n", | |||
"import torch.nn as nn\n", | |||
"import os.path as osp\n", | |||
"\n", | |||
"from abl.reasoning import ReasonerBase, KBBase\n", | |||
"from abl.learning import BasicNN, ABLModel\n", | |||
"from abl.bridge import SimpleBridge\n", | |||
"from abl.evaluation import SymbolMetric, SemanticsMetric\n", | |||
"from abl.evaluation import SemanticsMetric, SymbolMetric\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from abl.reasoning import ReasonerBase\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"\n", | |||
"from examples.models.nn import SymbolNet\n", | |||
"from datasets.get_hwf import get_hwf" | |||
"from examples.hwf.datasets.get_hwf import get_hwf\n", | |||
"from examples.hwf.hwf_kb import HWF_KB\n", | |||
"from examples.models.nn import SymbolNet" | |||
] | |||
}, | |||
{ | |||
@@ -50,37 +50,8 @@ | |||
"outputs": [], | |||
"source": [ | |||
"# Initialize knowledge base and abducer\n", | |||
"class HWF_KB(KBBase):\n", | |||
" def __init__(\n", | |||
" self, \n", | |||
" pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n", | |||
" prebuild_GKB=False,\n", | |||
" GKB_len_list=[1, 3, 5, 7],\n", | |||
" max_err=1e-3,\n", | |||
" use_cache=True\n", | |||
" ):\n", | |||
" super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", | |||
"\n", | |||
" def _valid_candidate(self, formula):\n", | |||
" if len(formula) % 2 == 0:\n", | |||
" return False\n", | |||
" for i in range(len(formula)):\n", | |||
" if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:\n", | |||
" return False\n", | |||
" if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:\n", | |||
" return False\n", | |||
" return True\n", | |||
"\n", | |||
" def logic_forward(self, formula):\n", | |||
" if not self._valid_candidate(formula):\n", | |||
" return np.inf\n", | |||
" mapping = {str(i): str(i) for i in range(1, 10)}\n", | |||
" mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})\n", | |||
" formula = [mapping[f] for f in formula]\n", | |||
" return eval(''.join(formula))\n", | |||
"\n", | |||
"kb = HWF_KB(prebuild_GKB=True)\n", | |||
"abducer = ReasonerBase(kb, dist_func='confidence')" | |||
"kb = HWF_KB()\n", | |||
"abducer = ReasonerBase(kb, dist_func=\"confidence\")" | |||
] | |||
}, | |||
{ | |||
@@ -117,10 +88,8 @@ | |||
" criterion=criterion,\n", | |||
" optimizer=optimizer,\n", | |||
" device=device,\n", | |||
" save_interval=1,\n", | |||
" save_dir=weights_dir,\n", | |||
" batch_size=128,\n", | |||
" num_epochs=3,\n", | |||
" num_epochs=1,\n", | |||
")" | |||
] | |||
}, | |||
@@ -131,7 +100,7 @@ | |||
"outputs": [], | |||
"source": [ | |||
"# Initialize ABL model\n", | |||
"# The main function of the ABL model is to serialize data and \n", | |||
"# The main function of the ABL model is to serialize data and\n", | |||
"# provide a unified interface for different machine learning models\n", | |||
"model = ABLModel(base_model)" | |||
] | |||
@@ -151,7 +120,7 @@ | |||
"outputs": [], | |||
"source": [ | |||
"# Add metric\n", | |||
"metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(prefix=\"hwf\")]" | |||
"metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(kb=kb, prefix=\"hwf\")]" | |||
] | |||
}, | |||
{ | |||
@@ -204,7 +173,7 @@ | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"bridge.train(train_data, epochs=3, batch_size=1000)\n", | |||
"bridge.train(train_data, loops=5, segment_size=1000, save_interval=1, save_dir=weights_dir)\n", | |||
"bridge.test(test_data)" | |||
] | |||
} | |||
@@ -0,0 +1,129 @@ | |||
import bisect | |||
from collections import defaultdict | |||
from itertools import product | |||
from multiprocessing import Pool | |||
from typing import Any, Hashable, List | |||
import numpy as np | |||
from abl.reasoning import GroundKB | |||
from abl.structures import ListData | |||
from abl.utils import hamming_dist | |||
class HWF_KB(GroundKB): | |||
def __init__( | |||
self, | |||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "times", "div"], | |||
GKB_len_list=[1, 3, 5, 7], | |||
max_err=1e-10, | |||
): | |||
self.GKB_len_list = GKB_len_list | |||
self.max_err = max_err | |||
self.label2evaluable = {str(i): str(i) for i in range(1, 10)} | |||
self.label2evaluable.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
super().__init__(pseudo_label_list) | |||
def logic_forward(self, data_sample: ListData): | |||
if not self._valid_candidate(data_sample): | |||
return None | |||
formula = data_sample["pred_pseudo_label"][0] | |||
formula = [self.label2evaluable[f] for f in formula] | |||
data_sample["Y"] = [eval("".join(formula))] | |||
return data_sample["Y"][0] | |||
def check_equal(self, data_sample: ListData, y: Any): | |||
if not self._valid_candidate(data_sample): | |||
return False | |||
formula = data_sample["pred_pseudo_label"][0] | |||
formula = [self.label2evaluable[f] for f in formula] | |||
return abs(eval("".join(formula)) - y) < self.max_err | |||
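    # construct_base enumerates every well-formed formula of each length in
    # GKB_len_list (one worker process per leading symbol) and indexes the
    # results as {formula_length: {evaluated_value: [formulas]}}.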
def construct_base(self) -> dict: | |||
X, Y = [], [] | |||
for length in self.GKB_len_list: | |||
arg_list = [] | |||
for pre_x in self.pseudo_label_list: | |||
post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||
arg_list.append((pre_x, post_x_it)) | |||
with Pool(processes=len(arg_list)) as pool: | |||
ret_list = pool.map(self._get_XY_list, arg_list) | |||
for XY_list in ret_list: | |||
if len(XY_list) == 0: | |||
continue | |||
part_X, part_Y = zip(*XY_list) | |||
X.extend(part_X) | |||
Y.extend(part_Y) | |||
if Y and isinstance(Y[0], (int, float)): | |||
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | |||
GKB = {} | |||
for x, y in zip(X, Y): | |||
GKB.setdefault(len(x), defaultdict(list))[y].append(x) | |||
return GKB | |||
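    # Ground formulas are retrieved by (number of symbols, target value); with a
    # non-zero max_err, key2candidates widens the lookup to every stored value
    # within max_err of the query, using bisect over the sorted value keys.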
@staticmethod | |||
def get_key(data_sample: ListData) -> Hashable: | |||
return (data_sample["symbol_num"], data_sample["Y"][0]) | |||
def key2candidates(self, key: Hashable) -> List[List[Any]]: | |||
equation_len, y = key | |||
if self.max_err == 0: | |||
return self.GKB[equation_len][y] | |||
else: | |||
potential_candidates = self.GKB[equation_len] | |||
key_list = list(potential_candidates.keys()) | |||
key_idx = bisect.bisect_left(key_list, y) | |||
all_candidates = [] | |||
for idx in range(key_idx - 1, -1, -1): | |||
k = key_list[idx] | |||
if abs(k - y) <= self.max_err: | |||
all_candidates.extend(potential_candidates[k]) | |||
else: | |||
break | |||
for idx in range(key_idx, len(key_list)): | |||
k = key_list[idx] | |||
if abs(k - y) <= self.max_err: | |||
all_candidates.extend(potential_candidates[k]) | |||
else: | |||
break | |||
return all_candidates | |||
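    # Keep only the candidates whose Hamming distance from the predicted
    # pseudo-labels stays within the allowed number of revisions.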
def filter_candidates( | |||
self, | |||
data_sample: ListData, | |||
candidates: List[List[Any]], | |||
max_revision_num: int, | |||
require_more_revision: int = 0, | |||
) -> List[List[Any]]: | |||
cost_list = hamming_dist(data_sample["pred_pseudo_label"][0], candidates) | |||
min_revision_num = np.min(cost_list) | |||
revision_num = min(max_revision_num, min_revision_num + require_more_revision) | |||
idxs = np.where(cost_list <= revision_num)[0] | |||
filtered_candidates = [candidates[idx] for idx in idxs] | |||
return filtered_candidates | |||
# TODO: change return value to List[ListData] | |||
def _get_XY_list(self, args): | |||
pre_x, post_x_it = args[0], args[1] | |||
XY_list = [] | |||
for post_x in post_x_it: | |||
x = (pre_x,) + post_x | |||
data_sample = ListData(pred_pseudo_label=[x]) | |||
y = self.logic_forward(data_sample) | |||
if y is not None: | |||
XY_list.append((x, y)) | |||
return XY_list | |||
@staticmethod | |||
def _valid_candidate(data_sample): | |||
formula = data_sample["pred_pseudo_label"][0] | |||
if len(formula) % 2 == 0: | |||
return False | |||
for i in range(len(formula)): | |||
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
return False | |||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
return False | |||
return True |
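
# A minimal usage sketch for the class above. It assumes the GroundKB base
# class calls construct_base() during __init__ and stores the result in
# self.GKB, and that ListData accepts the keyword fields used below; both are
# inferred from this file rather than confirmed against the library.
if __name__ == "__main__":
    kb = HWF_KB(GKB_len_list=[1, 3])  # small lengths keep GKB construction cheap
    # The model predicted "3 times 5", but the ground-truth result is 8.
    sample = ListData(pred_pseudo_label=[["3", "times", "5"]], symbol_num=3, Y=[8])
    candidates = kb.key2candidates(kb.get_key(sample))  # length-3 formulas evaluating to 8
    revised = kb.filter_candidates(sample, candidates, max_revision_num=1)
    print(revised)  # likely [('3', '+', '5')]: the only candidate one symbol away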
@@ -1,39 +1,49 @@ | |||
import os.path as osp | |||
import torchvision | |||
from torchvision.transforms import transforms | |||
CURRENT_DIR = osp.abspath(osp.dirname(__file__)) | |||
def get_data(file, img_dataset, get_pseudo_label): | |||
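    # Each line of the data file holds two MNIST indices and their label (the
    # digit sum in the MNIST Add task): X collects the image pairs, Z the digit
    # pseudo-labels, and Y the sums.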
X = [] | |||
X, Y = [], [] | |||
if get_pseudo_label: | |||
Z = [] | |||
Y = [] | |||
with open(file) as f: | |||
for line in f: | |||
line = line.strip().split(' ') | |||
# if len(X) == 1000: | |||
# break | |||
line = line.strip().split(" ") | |||
X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) | |||
if get_pseudo_label: | |||
Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]]) | |||
Y.append(int(line[2])) | |||
if get_pseudo_label: | |||
return X, Z, Y | |||
else: | |||
return X, None, Y | |||
def get_mnist_add(train = True, get_pseudo_label = False): | |||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) | |||
img_dataset = torchvision.datasets.MNIST(root='./datasets/', train=train, download=True, transform=transform) | |||
def get_mnist_add(train=True, get_pseudo_label=False): | |||
transform = transforms.Compose( | |||
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] | |||
) | |||
img_dataset = torchvision.datasets.MNIST( | |||
root=CURRENT_DIR, train=train, download=True, transform=transform | |||
) | |||
if train: | |||
file = './datasets/train_data.txt' | |||
file = osp.join(CURRENT_DIR, "train_data.txt") | |||
else: | |||
file = './datasets/test_data.txt' | |||
file = osp.join(CURRENT_DIR, "test_data.txt") | |||
return get_data(file, img_dataset, get_pseudo_label) | |||
if __name__ == "__main__": | |||
train_X, train_Y = get_mnist_add(train = True) | |||
test_X, test_Y = get_mnist_add(train = False) | |||
train_X, train_Z, train_Y = get_mnist_add(train=True) | |||
test_X, test_Z, test_Y = get_mnist_add(train=False) | |||
print(len(train_X), len(test_X)) | |||
print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0]) | |||
@@ -2,32 +2,37 @@ | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 3, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"import torch.nn as nn\n", | |||
"import torch\n", | |||
"import os.path as osp\n", | |||
"\n", | |||
"from abl.reasoning import ReasonerBase, KBBase\n", | |||
"import torch\n", | |||
"import torch.nn as nn\n", | |||
"\n", | |||
"from abl.learning import BasicNN, ABLModel\n", | |||
"from abl.bridge import SimpleBridge\n", | |||
"from abl.evaluation import SymbolMetric, ABLMetric\n", | |||
"from abl.utils import ABLLogger\n", | |||
"\n", | |||
"from models.nn import LeNet5\n", | |||
"from examples.mnist_add.datasets.get_mnist_add import get_mnist_add" | |||
"from abl.evaluation import SemanticsMetric, SymbolMetric\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from abl.reasoning import ReasonerBase\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"from examples.mnist_add.datasets.get_mnist_add import get_mnist_add\n", | |||
"from examples.mnist_add.mnist_add_kb import AddKB\n", | |||
"from examples.models.nn import LeNet5" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 4, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Initialize logger\n", | |||
"logger = ABLLogger.get_instance(\"abl\")" | |||
"print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n", | |||
"\n", | |||
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", | |||
"log_dir = ABLLogger.get_current_instance().log_dir\n", | |||
"weights_dir = osp.join(log_dir, \"weights\")" | |||
] | |||
}, | |||
{ | |||
@@ -40,22 +45,19 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 5, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Initialize knowledge base and abducer\n", | |||
"class add_KB(KBBase):\n", | |||
" def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n", | |||
" super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", | |||
"\n", | |||
" def logic_forward(self, nums):\n", | |||
" return sum(nums)\n", | |||
"kb = AddKB()\n", | |||
"\n", | |||
"kb = add_KB(prebuild_GKB=True)\n", | |||
"# If use cache, get_key should be implemented in the abducer\n", | |||
"class AddAbducer(ReasonerBase):\n", | |||
" def get_key(self, data_sample):\n", | |||
" return (data_sample.to_tuple(\"pred_pseudo_label\"), data_sample[\"Y\"][0])\n", | |||
"\n", | |||
"# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", | |||
"abducer = ReasonerBase(kb, dist_func=\"confidence\")" | |||
"abducer = AddAbducer(kb, dist_func=\"confidence\", use_cache=True)" | |||
] | |||
}, | |||
{ | |||
@@ -68,7 +70,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 6, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -81,19 +83,17 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 7, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Initialize BasicNN\n", | |||
"# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n", | |||
"base_model = BasicNN(\n", | |||
" cls,\n", | |||
" criterion,\n", | |||
" optimizer,\n", | |||
" device,\n", | |||
" save_interval=1,\n", | |||
" save_dir=logger.save_dir,\n", | |||
" model=cls,\n", | |||
" criterion=criterion,\n", | |||
" optimizer=optimizer,\n", | |||
" device=device,\n", | |||
" batch_size=32,\n", | |||
" num_epochs=1,\n", | |||
")" | |||
@@ -109,12 +109,12 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Initialize ABL model\n", | |||
"# The main function of the ABL model is to serialize data and \n", | |||
"# The main function of the ABL model is to serialize data and\n", | |||
"# provide a unified interface for different machine learning models\n", | |||
"model = ABLModel(base_model)" | |||
] | |||
@@ -129,12 +129,12 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 9, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Add metric\n", | |||
"metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]" | |||
"metric = [SymbolMetric(prefix=\"mnist_add\"), SemanticsMetric(kb=kb, prefix=\"mnist_add\")]" | |||
] | |||
}, | |||
{ | |||
@@ -147,7 +147,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 10, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -187,7 +187,7 @@ | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"bridge.train(train_data, epochs=5, batch_size=10000)\n", | |||
"bridge.train(train_data, loops=10, segment_size=10000)\n", | |||
"bridge.test(test_data)" | |||
] | |||
} | |||
@@ -208,7 +208,7 @@ | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.8.13" | |||
"version": "3.8.16" | |||
}, | |||
"orig_nbformat": 4, | |||
"vscode": { | |||
@@ -0,0 +1,17 @@ | |||
from typing import Any | |||
from abl.reasoning import SearchBasedKB | |||
from abl.structures import ListData | |||
class AddKB(SearchBasedKB): | |||
def __init__(self, pseudo_label_list=list(range(10))): | |||
super().__init__( | |||
pseudo_label_list=pseudo_label_list | |||
) | |||
def check_equal(self, data_sample: ListData, y: Any): | |||
return self.logic_forward(data_sample) == y | |||
def logic_forward(self, data_sample): | |||
return sum(data_sample["pred_pseudo_label"][0]) |
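
# A minimal usage sketch for AddKB, assuming ListData accepts pred_pseudo_label
# as a keyword field the way the example notebooks construct it.
if __name__ == "__main__":
    kb = AddKB()
    sample = ListData(pred_pseudo_label=[[1, 7]])
    print(kb.logic_forward(sample))   # 8: the sum of the predicted digits
    print(kb.check_equal(sample, 8))  # True
    print(kb.check_equal(sample, 9))  # False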
@@ -11,8 +11,8 @@ | |||
# ================================================================# | |||
import torch | |||
import numpy as np | |||
import torch | |||
from torch import nn | |||
@@ -66,7 +66,8 @@ class SymbolNet(nn.Module): | |||
num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | |||
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | |||
self.fc3 = nn.Sequential(nn.Linear(84, num_classes), nn.Softmax(dim=1)) | |||
# self.fc3 = nn.Sequential(nn.Linear(84, num_classes), nn.Softmax(dim=1)) | |||
self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
@@ -84,9 +85,7 @@ class SymbolNetAutoencoder(nn.Module): | |||
self.base_model = SymbolNet(num_classes, image_size) | |||
self.softmax = nn.Softmax(dim=1) | |||
self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) | |||
self.fc2 = nn.Sequential( | |||
nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU() | |||
) | |||
self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()) | |||
def forward(self, x): | |||
x = self.base_model(x) | |||
@@ -1,4 +1,5 @@ | |||
import os | |||
from setuptools import find_packages, setup | |||
@@ -0,0 +1,403 @@ | |||
import numpy as np

from abl.reasoning import ReasonerBase, BaseKB, GroundKB, PrologBasedKB
# reform_idx is assumed to be exported by abl.utils alongside the library's other helpers.
from abl.utils import reform_idx
if __name__ == "__main__": | |||
prob1 = [ | |||
[ | |||
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
] | |||
] | |||
prob2 = [ | |||
[ | |||
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
] | |||
] | |||
class add_KB(BaseKB): | |||
def __init__(self, pseudo_label_list=list(range(10)), use_cache=True): | |||
super().__init__(pseudo_label_list, use_cache=use_cache) | |||
def logic_forward(self, nums): | |||
return sum(nums) | |||
class add_GroundKB(GroundKB): | |||
def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): | |||
super().__init__(pseudo_label_list, GKB_len_list) | |||
def logic_forward(self, nums): | |||
return sum(nums) | |||
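    # Both KBs encode the same addition rule: add_KB has no pre-built ground
    # base and finds consistent candidates at abduction time, while
    # add_GroundKB enumerates all length-2 label sequences up front.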
def test_add(reasoner): | |||
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
print(res) | |||
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
print(res) | |||
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||
print(res) | |||
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||
print(res) | |||
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||
print(res) | |||
print() | |||
print("add_KB with GKB:") | |||
kb = add_GroundKB() | |||
reasoner = ReasonerBase(kb, "confidence") | |||
test_add(reasoner) | |||
print("add_KB without GKB:") | |||
kb = add_KB() | |||
reasoner = ReasonerBase(kb, "confidence") | |||
test_add(reasoner) | |||
print("add_KB without GKB, no cache") | |||
kb = add_KB(use_cache=False) | |||
reasoner = ReasonerBase(kb, "confidence") | |||
test_add(reasoner) | |||
print("PrologBasedKB with add.pl:") | |||
kb = PrologBasedKB( | |||
pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl" | |||
) | |||
reasoner = ReasonerBase(kb, "confidence") | |||
test_add(reasoner) | |||
print("PrologBasedKB with add.pl using zoopt:") | |||
kb = PrologBasedKB( | |||
pseudo_label_list=list(range(10)), | |||
pl_file="examples/mnist_add/datasets/add.pl", | |||
) | |||
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | |||
test_add(reasoner) | |||
print("add_KB with multiple inputs at once:") | |||
multiple_prob = [ | |||
[ | |||
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
], | |||
[ | |||
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
], | |||
] | |||
kb = add_KB() | |||
reasoner = ReasonerBase(kb, "confidence") | |||
res = reasoner.batch_abduce( | |||
multiple_prob, | |||
[[1, 1], [1, 2]], | |||
[4, 8], | |||
max_revision=2, | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
res = reasoner.batch_abduce( | |||
multiple_prob, | |||
[[1, 1], [1, 2]], | |||
[4, 8], | |||
max_revision=2, | |||
require_more_revision=1, | |||
) | |||
print(res) | |||
print() | |||
class HWF_KB(BaseKB): | |||
def __init__( | |||
self, | |||
pseudo_label_list=[ | |||
"1", | |||
"2", | |||
"3", | |||
"4", | |||
"5", | |||
"6", | |||
"7", | |||
"8", | |||
"9", | |||
"+", | |||
"-", | |||
"times", | |||
"div", | |||
], | |||
max_err=1e-3, | |||
): | |||
super().__init__(pseudo_label_list, max_err) | |||
def _valid_candidate(self, formula): | |||
if len(formula) % 2 == 0: | |||
return False | |||
for i in range(len(formula)): | |||
if i % 2 == 0 and formula[i] not in [ | |||
"1", | |||
"2", | |||
"3", | |||
"4", | |||
"5", | |||
"6", | |||
"7", | |||
"8", | |||
"9", | |||
]: | |||
return False | |||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
return False | |||
return True | |||
def logic_forward(self, formula): | |||
if not self._valid_candidate(formula): | |||
return np.inf | |||
mapping = {str(i): str(i) for i in range(1, 10)} | |||
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
formula = [mapping[f] for f in formula] | |||
return eval("".join(formula)) | |||
class HWF_GroundKB(GroundKB): | |||
def __init__( | |||
self, | |||
pseudo_label_list=[ | |||
"1", | |||
"2", | |||
"3", | |||
"4", | |||
"5", | |||
"6", | |||
"7", | |||
"8", | |||
"9", | |||
"+", | |||
"-", | |||
"times", | |||
"div", | |||
], | |||
GKB_len_list=[1, 3, 5, 7], | |||
max_err=1e-3, | |||
): | |||
super().__init__(pseudo_label_list, GKB_len_list, max_err) | |||
def _valid_candidate(self, formula): | |||
if len(formula) % 2 == 0: | |||
return False | |||
for i in range(len(formula)): | |||
if i % 2 == 0 and formula[i] not in [ | |||
"1", | |||
"2", | |||
"3", | |||
"4", | |||
"5", | |||
"6", | |||
"7", | |||
"8", | |||
"9", | |||
]: | |||
return False | |||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
return False | |||
return True | |||
def logic_forward(self, formula): | |||
if not self._valid_candidate(formula): | |||
return np.inf | |||
mapping = {str(i): str(i) for i in range(1, 10)} | |||
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
formula = [mapping[f] for f in formula] | |||
return eval("".join(formula)) | |||
def test_hwf(reasoner): | |||
res = reasoner.batch_abduce( | |||
[None], | |||
[["5", "+", "2"]], | |||
[3], | |||
max_revision=2, | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
res = reasoner.batch_abduce( | |||
[None], | |||
[["5", "+", "9"]], | |||
[65], | |||
max_revision=3, | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
res = reasoner.batch_abduce( | |||
[None], | |||
[["5", "8", "8", "8", "8"]], | |||
[3.17], | |||
max_revision=5, | |||
require_more_revision=3, | |||
) | |||
print(res) | |||
print() | |||
def test_hwf_multiple(reasoner, max_revisions): | |||
res = reasoner.batch_abduce( | |||
[None, None], | |||
[["5", "+", "2"], ["5", "+", "9"]], | |||
[3, 64], | |||
max_revision=max_revisions[0], | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
res = reasoner.batch_abduce( | |||
[None, None], | |||
[["5", "+", "2"], ["5", "+", "9"]], | |||
[3, 64], | |||
max_revision=max_revisions[1], | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
res = reasoner.batch_abduce( | |||
[None, None], | |||
[["5", "+", "2"], ["5", "+", "9"]], | |||
[3, 65], | |||
max_revision=max_revisions[2], | |||
require_more_revision=0, | |||
) | |||
print(res) | |||
print() | |||
print("HWF_KB with GKB, max_err=0.1") | |||
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) | |||
reasoner = ReasonerBase(kb, "hamming") | |||
test_hwf(reasoner) | |||
print("HWF_KB without GKB, max_err=0.1") | |||
kb = HWF_KB(max_err=0.1) | |||
reasoner = ReasonerBase(kb, "hamming") | |||
test_hwf(reasoner) | |||
print("HWF_KB with GKB, max_err=1") | |||
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=1) | |||
reasoner = ReasonerBase(kb, "hamming") | |||
test_hwf(reasoner) | |||
print("HWF_KB without GKB, max_err=1") | |||
kb = HWF_KB(max_err=1) | |||
reasoner = ReasonerBase(kb, "hamming") | |||
test_hwf(reasoner) | |||
print("HWF_KB with multiple inputs at once:") | |||
kb = HWF_KB(max_err=0.1) | |||
reasoner = ReasonerBase(kb, "hamming") | |||
test_hwf_multiple(reasoner, max_revisions=[1, 3, 3]) | |||
print("max_revision is float") | |||
test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9]) | |||
class HED_prolog_KB(PrologBasedKB): | |||
def __init__(self, pseudo_label_list, pl_file): | |||
super().__init__(pseudo_label_list, pl_file) | |||
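    # consist_rule checks one example against a candidate rule set via the
    # eval_inst_feature/2 Prolog predicate; abduce_rules asks Prolog (through
    # consistent_inst_feature/2) for a rule set consistent with every example.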
def consist_rule(self, exs, rules): | |||
rules = str(rules).replace("'", "") | |||
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) | |||
return len(list(self.prolog.query(pl_query))) != 0 | |||
def abduce_rules(self, pred_res): | |||
pl_query = "consistent_inst_feature(%s, X)." % pred_res | |||
prolog_result = list(self.prolog.query(pl_query)) | |||
if len(prolog_result) == 0: | |||
return None | |||
prolog_rules = prolog_result[0]["X"] | |||
rules = [rule.value for rule in prolog_rules] | |||
return rules | |||
class HED_Reasoner(ReasonerBase): | |||
def __init__(self, kb, dist_func="hamming"): | |||
super().__init__(kb, dist_func, use_zoopt=True) | |||
def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs): | |||
pred = [] | |||
k = [] | |||
revision_flag = [] | |||
for idx in idxs: | |||
pred.append(pred_res[idx]) | |||
k.append(y[idx]) | |||
revision_flag += list(all_revision_flag[idx]) | |||
revision_idx = np.where(np.array(revision_flag) != 0)[0] | |||
candidate = self.revise_at_idx(pred, k, revision_idx) | |||
return candidate | |||
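    # Score a zoopt solution: greedily group example indices into maximal
    # subsets that can be revised consistently at the flagged positions, then
    # subtract exp(-i) * group_size for each group, so solutions admitting
    # larger consistent groups receive lower scores.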
def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): | |||
all_revision_flag = reform_idx(sol.get_x(), pred_res) | |||
lefted_idxs = [i for i in range(len(pred_res))] | |||
candidate_size = [] | |||
while lefted_idxs: | |||
idxs = [] | |||
idxs.append(lefted_idxs.pop(0)) | |||
max_candidate_idxs = [] | |||
found = False | |||
for idx in range(-1, len(pred_res)): | |||
if (not idx in idxs) and (idx >= 0): | |||
idxs.append(idx) | |||
candidate = self._revise_at_idxs(pred_res, y, all_revision_flag, idxs) | |||
if len(candidate) == 0: | |||
if len(idxs) > 1: | |||
idxs.pop() | |||
else: | |||
if len(idxs) > len(max_candidate_idxs): | |||
found = True | |||
max_candidate_idxs = idxs.copy() | |||
removed = [i for i in lefted_idxs if i in max_candidate_idxs] | |||
if found: | |||
candidate_size.append(len(removed) + 1) | |||
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] | |||
candidate_size.sort() | |||
score = 0 | |||
import math | |||
for i in range(0, len(candidate_size)): | |||
score -= math.exp(-i) * candidate_size[i] | |||
return score | |||
def abduce_rules(self, pred_res): | |||
return self.kb.abduce_rules(pred_res) | |||
kb = HED_prolog_KB( | |||
pseudo_label_list=[1, 0, "+", "="], | |||
pl_file="examples/hed/datasets/learn_add.pl", | |||
) | |||
reasoner = HED_Reasoner(kb) | |||
consist_exs = [ | |||
[1, 1, "+", 0, "=", 1, 1], | |||
[1, "+", 1, "=", 1, 0], | |||
[0, "+", 0, "=", 0], | |||
] | |||
inconsist_exs1 = [ | |||
[1, 1, "+", 0, "=", 1, 1], | |||
[1, "+", 1, "=", 1, 0], | |||
[0, "+", 0, "=", 0], | |||
[0, "+", 0, "=", 1], | |||
] | |||
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | |||
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] | |||
print("HED_kb logic forward") | |||
print(kb.logic_forward(consist_exs)) | |||
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) | |||
print() | |||
print("HED_kb consist rule") | |||
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) | |||
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) | |||
print() | |||
print("HED_Reasoner abduce") | |||
res = reasoner.abduce([[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs)) | |||
print(res) | |||
res = reasoner.abduce( | |||
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) | |||
) | |||
print(res) | |||
res = reasoner.abduce( | |||
[[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) | |||
) | |||
print(res) | |||
print() | |||
print("HED_Reasoner abduce rules") | |||
abduced_rules = reasoner.abduce_rules(consist_exs) | |||
print(abduced_rules) |