@@ -1,2 +1,2 @@ | |||||
from .learning import abl_model, basic_nn | from .learning import abl_model, basic_nn | ||||
from .reasoning import reasoner, kb | |||||
from .reasoning import base_kb, ground_kb, reasoner, search_based_kb |
@@ -1,52 +1,64 @@ | |||||
from abc import ABCMeta, abstractmethod | from abc import ABCMeta, abstractmethod | ||||
from typing import Any, List, Tuple | |||||
from typing import Any, List, Optional, Tuple, Union | |||||
from ..learning import ABLModel | from ..learning import ABLModel | ||||
from ..reasoning import ReasonerBase | from ..reasoning import ReasonerBase | ||||
from ..structures import ListData | |||||
DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] | |||||
class BaseBridge(metaclass=ABCMeta): | |||||
class BaseBridge(metaclass=ABCMeta): | |||||
def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None: | def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None: | ||||
if not isinstance(model, ABLModel): | if not isinstance(model, ABLModel): | ||||
raise TypeError("Expected an ABLModel") | |||||
raise TypeError( | |||||
"Expected an instance of ABLModel, but received type: {}".format( | |||||
type(model) | |||||
) | |||||
) | |||||
if not isinstance(abducer, ReasonerBase): | if not isinstance(abducer, ReasonerBase): | ||||
raise TypeError("Expected an ReasonerBase") | |||||
raise TypeError( | |||||
"Expected an instance of ReasonerBase, but received type: {}".format( | |||||
type(abducer) | |||||
) | |||||
) | |||||
self.model = model | self.model = model | ||||
self.abducer = abducer | self.abducer = abducer | ||||
@abstractmethod | @abstractmethod | ||||
def predict(self, X: List[List[Any]]) -> Tuple[List[List[Any]], List[List[Any]]]: | |||||
def predict( | |||||
self, data_samples: ListData | |||||
) -> Tuple[List[List[Any]], List[List[Any]]]: | |||||
"""Placeholder for predict labels from input.""" | """Placeholder for predict labels from input.""" | ||||
pass | pass | ||||
@abstractmethod | @abstractmethod | ||||
def abduce_pseudo_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: | |||||
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||||
"""Placeholder for abduce pseudo labels.""" | """Placeholder for abduce pseudo labels.""" | ||||
pass | |||||
@abstractmethod | @abstractmethod | ||||
def idx_to_pseudo_label(self, idx: List[List[Any]]) -> List[List[Any]]: | |||||
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||||
"""Placeholder for map label space to symbol space.""" | """Placeholder for map label space to symbol space.""" | ||||
pass | pass | ||||
@abstractmethod | @abstractmethod | ||||
def pseudo_label_to_idx(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: | |||||
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]: | |||||
"""Placeholder for map symbol space to label space.""" | """Placeholder for map symbol space to label space.""" | ||||
pass | pass | ||||
@abstractmethod | @abstractmethod | ||||
def train(self, train_data): | |||||
def train(self, train_data: Union[ListData, DataSet]): | |||||
"""Placeholder for train loop of ABductive Learning.""" | """Placeholder for train loop of ABductive Learning.""" | ||||
pass | pass | ||||
@abstractmethod | @abstractmethod | ||||
def test(self, test_data): | |||||
def valid(self, valid_data: Union[ListData, DataSet]) -> None: | |||||
"""Placeholder for model test.""" | """Placeholder for model test.""" | ||||
pass | pass | ||||
@abstractmethod | @abstractmethod | ||||
def valid(self, valid_data): | |||||
def test(self, test_data: Union[ListData, DataSet]) -> None: | |||||
"""Placeholder for model validation.""" | """Placeholder for model validation.""" | ||||
pass | pass | ||||
@@ -1,13 +1,14 @@ | |||||
from ..learning import ABLModel | |||||
from ..reasoning import ReasonerBase | |||||
from ..evaluation import BaseMetric | |||||
from .base_bridge import BaseBridge | |||||
from typing import List, Union, Any, Tuple, Dict, Optional | |||||
import os.path as osp | |||||
from typing import Any, Dict, List, Optional, Tuple, Union | |||||
from numpy import ndarray | from numpy import ndarray | ||||
from torch.utils.data import DataLoader | |||||
from ..dataset import BridgeDataset | |||||
from ..utils.logger import print_log | |||||
from ..evaluation import BaseMetric | |||||
from ..learning import ABLModel | |||||
from ..reasoning import ReasonerBase | |||||
from ..structures import ListData | |||||
from ..utils import print_log | |||||
from .base_bridge import BaseBridge, DataSet | |||||
class SimpleBridge(BaseBridge): | class SimpleBridge(BaseBridge): | ||||
@@ -20,85 +21,99 @@ class SimpleBridge(BaseBridge): | |||||
super().__init__(model, abducer) | super().__init__(model, abducer) | ||||
self.metric_list = metric_list | self.metric_list = metric_list | ||||
def predict(self, X) -> Tuple[List[List[Any]], ndarray]: | |||||
pred_res = self.model.predict(X) | |||||
pred_idx, pred_prob = pred_res["label"], pred_res["prob"] | |||||
return pred_idx, pred_prob | |||||
# TODO: expose abducer.mapping as a property of SimpleBridge | |||||
def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: | |||||
self.model.predict(data_samples) | |||||
return data_samples["pred_idx"], data_samples.get("pred_prob", None) | |||||
def abduce_pseudo_label( | def abduce_pseudo_label( | ||||
self, | self, | ||||
pred_prob: ndarray, | |||||
pred_pseudo_label: List[List[Any]], | |||||
Y: List[Any], | |||||
data_samples: ListData, | |||||
max_revision: int = -1, | max_revision: int = -1, | ||||
require_more_revision: int = 0, | require_more_revision: int = 0, | ||||
) -> List[List[Any]]: | ) -> List[List[Any]]: | ||||
return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) | |||||
self.abducer.batch_abduce(data_samples, max_revision, require_more_revision) | |||||
return data_samples["abduced_pseudo_label"] | |||||
def idx_to_pseudo_label( | def idx_to_pseudo_label( | ||||
self, idx: List[List[Any]], mapping: Dict = None | |||||
self, data_samples: ListData, mapping: Optional[Dict] = None | |||||
) -> List[List[Any]]: | ) -> List[List[Any]]: | ||||
if mapping is None: | if mapping is None: | ||||
mapping = self.abducer.mapping | mapping = self.abducer.mapping | ||||
return [[mapping[_idx] for _idx in sub_list] for sub_list in idx] | |||||
pred_idx = data_samples.pred_idx | |||||
data_samples.pred_pseudo_label = [ | |||||
[mapping[_idx] for _idx in sub_list] for sub_list in pred_idx | |||||
] | |||||
return data_samples["pred_pseudo_label"] | |||||
def pseudo_label_to_idx( | def pseudo_label_to_idx( | ||||
self, pseudo_label: List[List[Any]], mapping: Dict = None | |||||
self, data_samples: ListData, mapping: Optional[Dict] = None | |||||
) -> List[List[Any]]: | ) -> List[List[Any]]: | ||||
if mapping is None: | if mapping is None: | ||||
mapping = self.abducer.remapping | mapping = self.abducer.remapping | ||||
return [ | |||||
[mapping[_pseudo_label] for _pseudo_label in sub_list] | |||||
for sub_list in pseudo_label | |||||
abduced_idx = [ | |||||
[mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] | |||||
for sub_list in data_samples.abduced_pseudo_label | |||||
] | ] | ||||
data_samples.abduced_idx = abduced_idx | |||||
return data_samples["abduced_idx"] | |||||
def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData: | |||||
data_samples = ListData() | |||||
data_samples.X = X | |||||
data_samples.gt_pseudo_label = gt_pseudo_label | |||||
data_samples.Y = Y | |||||
return data_samples | |||||
def train( | def train( | ||||
self, | self, | ||||
train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]], | |||||
epochs: int = 50, | |||||
batch_size: Union[int, float] = -1, | |||||
train_data: Union[ListData, DataSet], | |||||
loops: int = 50, | |||||
segment_size: Union[int, float] = -1, | |||||
eval_interval: int = 1, | eval_interval: int = 1, | ||||
save_interval: Optional[int] = None, | |||||
save_dir: Optional[str] = None, | |||||
): | ): | ||||
dataset = BridgeDataset(*train_data) | |||||
data_loader = DataLoader( | |||||
dataset, | |||||
batch_size=batch_size, | |||||
collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], | |||||
) | |||||
for epoch in range(epochs): | |||||
for seg_idx, (X, Z, Y) in enumerate(data_loader): | |||||
pred_idx, pred_prob = self.predict(X) | |||||
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) | |||||
abduced_pseudo_label = self.abduce_pseudo_label( | |||||
pred_prob, pred_pseudo_label, Y | |||||
) | |||||
abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) | |||||
loss = self.model.train(X, abduced_label) | |||||
if isinstance(train_data, ListData): | |||||
data_samples = train_data | |||||
else: | |||||
data_samples = self.data_preprocess(*train_data) | |||||
# A segment_size of -1 (the default) is treated as a single segment covering the whole dataset. | |||||
segment_size = len(data_samples) if segment_size == -1 else segment_size | |||||
for loop in range(loops): | |||||
for seg_idx in range((len(data_samples) - 1) // segment_size + 1): | |||||
sub_data_samples = data_samples[ | |||||
seg_idx * segment_size : (seg_idx + 1) * segment_size | |||||
] | |||||
self.predict(sub_data_samples) | |||||
self.idx_to_pseudo_label(sub_data_samples) | |||||
self.abduce_pseudo_label(sub_data_samples) | |||||
self.pseudo_label_to_idx(sub_data_samples) | |||||
loss = self.model.train(sub_data_samples) | |||||
print_log( | print_log( | ||||
f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}", | |||||
f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}", | |||||
logger="current", | logger="current", | ||||
) | ) | ||||
if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1: | |||||
print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current") | |||||
if (loop + 1) % eval_interval == 0 or loop == loops - 1: | |||||
print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") | |||||
self.valid(train_data) | self.valid(train_data) | ||||
def _valid(self, data_loader): | |||||
for X, Z, Y in data_loader: | |||||
pred_idx, pred_prob = self.predict(X) | |||||
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) | |||||
data_samples = dict( | |||||
pred_idx=pred_idx, | |||||
pred_prob=pred_prob, | |||||
pred_pseudo_label=pred_pseudo_label, | |||||
gt_pseudo_label=Z, | |||||
Y=Y, | |||||
logic_forward=self.abducer.kb.logic_forward, | |||||
) | |||||
if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): | |||||
print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") | |||||
self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")) | |||||
def _valid(self, data_samples: ListData, batch_size: int = 128) -> None: | |||||
for seg_idx in range((len(data_samples) - 1) // batch_size + 1): | |||||
sub_data_samples = data_samples[seg_idx * batch_size : (seg_idx + 1) * batch_size] | |||||
self.predict(sub_data_samples) | |||||
self.idx_to_pseudo_label(sub_data_samples) | |||||
for metric in self.metric_list: | for metric in self.metric_list: | ||||
metric.process(data_samples) | |||||
metric.process(sub_data_samples) | |||||
res = dict() | res = dict() | ||||
for metric in self.metric_list: | for metric in self.metric_list: | ||||
@@ -108,14 +123,12 @@ class SimpleBridge(BaseBridge): | |||||
msg += k + f": {v:.3f} " | msg += k + f": {v:.3f} " | ||||
print_log(msg, logger="current") | print_log(msg, logger="current") | ||||
def valid(self, valid_data, batch_size=1000): | |||||
dataset = BridgeDataset(*valid_data) | |||||
data_loader = DataLoader( | |||||
dataset, | |||||
batch_size=batch_size, | |||||
collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], | |||||
) | |||||
self._valid(data_loader) | |||||
def test(self, test_data, batch_size=1000): | |||||
self.valid(test_data, batch_size) | |||||
def valid(self, valid_data: Union[ListData, DataSet], batch_size: int = 128) -> None: | |||||
if not isinstance(valid_data, ListData): | |||||
data_samples = self.data_preprocess(*valid_data) | |||||
else: | |||||
data_samples = valid_data | |||||
self._valid(data_samples, batch_size=batch_size) | |||||
def test(self, test_data: Union[ListData, DataSet], batch_size: int = 128) -> None: | |||||
self.valid(test_data, batch_size=batch_size) |
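A rough usage sketch of the refactored training loop; my_model, my_reasoner, my_metric and the data triples are placeholders for objects built elsewhere, and the keyword values are only illustrative:

# bridge = SimpleBridge(model=my_model, abducer=my_reasoner, metric_list=[my_metric])
# bridge.train((X, gt_pseudo_label, Y), loops=10, segment_size=1000,
#              eval_interval=1, save_interval=5, save_dir="./ckpt")
# bridge.test((X_test, gt_pseudo_label_test, Y_test), batch_size=128)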
@@ -1,3 +1,4 @@ | |||||
from .bridge_dataset import BridgeDataset | from .bridge_dataset import BridgeDataset | ||||
from .classification_dataset import ClassificationDataset | from .classification_dataset import ClassificationDataset | ||||
from .regression_dataset import RegressionDataset | |||||
from .prediction_dataset import PredictionDataset | |||||
from .regression_dataset import RegressionDataset |
@@ -1,5 +1,6 @@ | |||||
from typing import Any, List, Tuple | |||||
from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
from typing import List, Any, Tuple | |||||
class BridgeDataset(Dataset): | class BridgeDataset(Dataset): | ||||
@@ -1,6 +1,7 @@ | |||||
from typing import Any, Callable, List, Tuple | |||||
import torch | import torch | ||||
from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
from typing import List, Any, Tuple, Callable | |||||
class ClassificationDataset(Dataset): | class ClassificationDataset(Dataset): | ||||
@@ -0,0 +1,56 @@ | |||||
from typing import Any, Callable, List, Tuple | |||||
import torch | |||||
from torch.utils.data import Dataset | |||||
class PredictionDataset(Dataset): | |||||
def __init__(self, X: List[Any], transform: Callable[..., Any] = None): | |||||
""" | |||||
Initialize the dataset used for the prediction task. | |||||
Parameters | |||||
---------- | |||||
X : List[Any] | |||||
The input data. | |||||
transform : Callable[..., Any], optional | |||||
A function/transform that takes in an object and returns a transformed version. Defaults to None. | |||||
""" | |||||
if not isinstance(X, list): | |||||
raise ValueError("X should be of type list.") | |||||
self.X = X | |||||
self.transform = transform | |||||
def __len__(self) -> int: | |||||
""" | |||||
Return the length of the dataset. | |||||
Returns | |||||
------- | |||||
int | |||||
The length of the dataset. | |||||
""" | |||||
return len(self.X) | |||||
def __getitem__(self, index: int) -> Any: | |||||
""" | |||||
Get the item at the given index. | |||||
Parameters | |||||
---------- | |||||
index : int | |||||
The index of the item to get. | |||||
Returns | |||||
------- | |||||
Any | |||||
The object at the given index, transformed if a transform is provided. | |||||
""" | |||||
if index >= len(self): | |||||
raise ValueError("index range error") | |||||
x = self.X[index] | |||||
if self.transform is not None: | |||||
x = self.transform(x) | |||||
return x |
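A small usage sketch for the new dataset; PredictionDataset here is the class defined above, and its exact import path depends on the installed package name:

from torch.utils.data import DataLoader

# PredictionDataset only yields (optionally transformed) inputs, with no labels.
dataset = PredictionDataset(X=[0.1, 0.5, 0.9], transform=None)
loader = DataLoader(dataset, batch_size=2)
for batch in loader:
    print(batch)  # tensors built by the default collate function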
@@ -1,6 +1,7 @@ | |||||
from typing import Any, List, Tuple | |||||
import torch | import torch | ||||
from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
from typing import List, Any, Tuple | |||||
class RegressionDataset(Dataset): | class RegressionDataset(Dataset): | ||||
@@ -1,3 +1,3 @@ | |||||
from .base_metric import BaseMetric | from .base_metric import BaseMetric | ||||
from .symbol_metric import SymbolMetric | |||||
from .semantics_metric import SemanticsMetric | from .semantics_metric import SemanticsMetric | ||||
from .symbol_metric import SymbolMetric |
@@ -1,8 +1,8 @@ | |||||
import logging | |||||
from abc import ABCMeta, abstractmethod | from abc import ABCMeta, abstractmethod | ||||
from typing import Any, List, Optional, Sequence | from typing import Any, List, Optional, Sequence | ||||
from ..utils import print_log | |||||
import logging | |||||
from ..utils import print_log | |||||
class BaseMetric(metaclass=ABCMeta): | class BaseMetric(metaclass=ABCMeta): | ||||
@@ -1,25 +1,22 @@ | |||||
from typing import Optional, Sequence | from typing import Optional, Sequence | ||||
from ..reasoning import BaseKB | |||||
from .base_metric import BaseMetric | from .base_metric import BaseMetric | ||||
class ABLMetric(): | |||||
pass | |||||
class SemanticsMetric(BaseMetric): | class SemanticsMetric(BaseMetric): | ||||
def __init__(self, prefix: Optional[str] = None) -> None: | |||||
def __init__(self, kb: Optional[BaseKB] = None, prefix: Optional[str] = None) -> None: | |||||
super().__init__(prefix) | super().__init__(prefix) | ||||
self.kb = kb | |||||
def process(self, data_samples: Sequence[dict]) -> None: | def process(self, data_samples: Sequence[dict]) -> None: | ||||
pred_pseudo_label = data_samples["pred_pseudo_label"] | |||||
gt_Y = data_samples["Y"] | |||||
logic_forward = data_samples["logic_forward"] | |||||
for pred_z, y in zip(pred_pseudo_label, gt_Y): | |||||
if logic_forward(pred_z) == y: | |||||
for data_sample in data_samples: | |||||
if self.kb.check_equal(data_sample, data_sample["Y"][0]): | |||||
self.results.append(1) | self.results.append(1) | ||||
else: | else: | ||||
self.results.append(0) | self.results.append(0) | ||||
def compute_metrics(self, results: list) -> dict: | def compute_metrics(self, results: list) -> dict: | ||||
metrics = dict() | metrics = dict() | ||||
metrics["semantics_accuracy"] = sum(results) / len(results) | metrics["semantics_accuracy"] = sum(results) / len(results) | ||||
return metrics | |||||
return metrics |
@@ -1,4 +1,5 @@ | |||||
from typing import Optional, Sequence, Callable | |||||
from typing import Optional, Sequence | |||||
from .base_metric import BaseMetric | from .base_metric import BaseMetric | ||||
@@ -10,8 +10,10 @@ | |||||
# | # | ||||
# ================================================================# | # ================================================================# | ||||
import pickle | import pickle | ||||
from utils import flatten, reform_idx | |||||
from typing import List, Any, Optional | |||||
from typing import Any, Dict | |||||
from ..structures import ListData | |||||
from ..utils import reform_idx | |||||
class ABLModel: | class ABLModel: | ||||
@@ -30,7 +32,7 @@ class ABLModel: | |||||
Methods | Methods | ||||
------- | ------- | ||||
predict(X: List[List[Any]], mapping: Optional[dict] = None) -> dict | |||||
predict(data_samples: ListData) -> Dict | |||||
Predict the labels and probabilities for the given data. | Predict the labels and probabilities for the given data. | ||||
valid(X: List[List[Any]], Y: List[Any]) -> float | valid(X: List[List[Any]], Y: List[Any]) -> float | ||||
Calculate the accuracy score for the given data. | Calculate the accuracy score for the given data. | ||||
@@ -42,20 +44,13 @@ class ABLModel: | |||||
Load the model from a file. | Load the model from a file. | ||||
""" | """ | ||||
def __init__(self, base_model) -> None: | |||||
self.classifier_list = [] | |||||
self.classifier_list.append(base_model) | |||||
def __init__(self, base_model: Any) -> None: | |||||
if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")): | |||||
raise NotImplementedError("The base_model should implement fit and predict methods.") | |||||
if not ( | |||||
hasattr(base_model, "fit") | |||||
and hasattr(base_model, "predict") | |||||
and hasattr(base_model, "score") | |||||
): | |||||
raise NotImplementedError( | |||||
"base_model should have fit, predict and score methods." | |||||
) | |||||
self.base_model = base_model | |||||
def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict: | |||||
def predict(self, data_samples: ListData) -> Dict: | |||||
""" | """ | ||||
Predict the labels and probabilities for the given data. | Predict the labels and probabilities for the given data. | ||||
@@ -63,53 +58,30 @@ class ABLModel: | |||||
---------- | ---------- | ||||
X : List[List[Any]] | |||||
The data to predict on. | |||||
data_samples : ListData | |||||
The data samples to predict on; prediction results are also written back into them. | |||||
mapping : Optional[dict], optional | |||||
A mapping dictionary to map labels to their original values, by default None. | |||||
Returns | Returns | ||||
------- | ------- | ||||
dict | dict | ||||
A dictionary containing the predicted labels and probabilities. | A dictionary containing the predicted labels and probabilities. | ||||
""" | """ | ||||
model = self.classifier_list[0] | |||||
data_X = flatten(X) | |||||
model = self.base_model | |||||
data_X = data_samples.flatten("X") | |||||
if hasattr(model, "predict_proba"): | if hasattr(model, "predict_proba"): | ||||
prob = model.predict_proba(X=data_X) | prob = model.predict_proba(X=data_X) | ||||
label = prob.argmax(axis=1) | label = prob.argmax(axis=1) | ||||
prob = reform_idx(prob, X) | |||||
prob = reform_idx(prob, data_samples["X"]) | |||||
else: | else: | ||||
prob = None | prob = None | ||||
label = model.predict(X=data_X) | label = model.predict(X=data_X) | ||||
label = reform_idx(label, data_samples["X"]) | |||||
if mapping is not None: | |||||
label = [mapping[y] for y in label] | |||||
label = reform_idx(label, X) | |||||
data_samples.pred_idx = label | |||||
if prob is not None: | |||||
data_samples.pred_prob = prob | |||||
return {"label": label, "prob": prob} | return {"label": label, "prob": prob} | ||||
def valid(self, X: List[List[Any]], Y: List[Any]) -> float: | |||||
""" | |||||
Calculate the accuracy for the given data. | |||||
Parameters | |||||
---------- | |||||
X : List[List[Any]] | |||||
The data to calculate the accuracy on. | |||||
Y : List[Any] | |||||
The true labels for the given data. | |||||
Returns | |||||
------- | |||||
float | |||||
The accuracy score for the given data. | |||||
""" | |||||
data_X = flatten(X) | |||||
data_Y = flatten(Y) | |||||
score = self.classifier_list[0].score(X=data_X, y=data_Y) | |||||
return score | |||||
def train(self, X: List[List[Any]], Y: List[Any]) -> float: | |||||
def train(self, data_samples: ListData) -> float: | |||||
""" | """ | ||||
Train the model on the given data. | Train the model on the given data. | ||||
@@ -125,29 +97,30 @@ class ABLModel: | |||||
float | float | ||||
The loss value of the trained model. | The loss value of the trained model. | ||||
""" | """ | ||||
data_X = flatten(X) | |||||
data_Y = flatten(Y) | |||||
return self.classifier_list[0].fit(X=data_X, y=data_Y) | |||||
data_X = data_samples.flatten("X") | |||||
data_y = data_samples.flatten("abduced_idx") | |||||
return self.base_model.fit(X=data_X, y=data_y) | |||||
def _model_operation(self, operation: str, *args, **kwargs): | def _model_operation(self, operation: str, *args, **kwargs): | ||||
model = self.classifier_list[0] | |||||
model = self.base_model | |||||
if hasattr(model, operation): | if hasattr(model, operation): | ||||
method = getattr(model, operation) | method = getattr(model, operation) | ||||
method(*args, **kwargs) | method(*args, **kwargs) | ||||
else: | else: | ||||
try: | |||||
if not f"{operation}_path" in kwargs.keys(): | |||||
raise ValueError(f"'{operation}_path' should not be None") | |||||
if operation == "save": | |||||
with open(kwargs["save_path"], 'wb') as file: | |||||
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) | |||||
elif operation == "load": | |||||
with open(kwargs["load_path"], 'rb') as file: | |||||
self.classifier_list[0] = pickle.load(file) | |||||
except: | |||||
raise NotImplementedError( | |||||
f"{type(model).__name__} object doesn't have the {operation} method" | |||||
) | |||||
if not f"{operation}_path" in kwargs.keys(): | |||||
raise ValueError(f"'{operation}_path' should not be None") | |||||
else: | |||||
try: | |||||
if operation == "save": | |||||
with open(kwargs["save_path"], "wb") as file: | |||||
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) | |||||
elif operation == "load": | |||||
with open(kwargs["load_path"], "rb") as file: | |||||
self.base_model = pickle.load(file) | |||||
except Exception as exc: | |||||
raise NotImplementedError( | |||||
f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed." | |||||
) from exc | |||||
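Caller-side sketch of the save/load fallback: when the wrapped estimator has no save or load method of its own, the operation falls back to pickling the estimator. The estimator and paths below are placeholders:

# model = ABLModel(base_model=some_sklearn_style_classifier)
# model.save(save_path="ckpt/base_model.pkl")   # pickles the estimator if it lacks .save
# model.load(load_path="ckpt/base_model.pkl")   # unpickles it back into self.base_model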
def save(self, *args, **kwargs) -> None: | def save(self, *args, **kwargs) -> None: | ||||
""" | """ | ||||
@@ -10,14 +10,16 @@ | |||||
# | # | ||||
# ================================================================# | # ================================================================# | ||||
import torch | |||||
import os | |||||
import logging | |||||
from typing import Any, Callable, List, Optional, T, Tuple | |||||
import numpy | import numpy | ||||
import torch | |||||
from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
from ..utils.logger import print_log | |||||
from ..dataset import ClassificationDataset | |||||
import os | |||||
from typing import List, Any, T, Optional, Callable, Tuple | |||||
from ..dataset import ClassificationDataset, PredictionDataset | |||||
from ..utils.logger import print_log | |||||
class BasicNN: | class BasicNN: | ||||
@@ -99,9 +101,7 @@ class BasicNN: | |||||
loss_value = self.train_epoch(data_loader) | loss_value = self.train_epoch(data_loader) | ||||
if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: | if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: | ||||
if self.save_dir is None: | if self.save_dir is None: | ||||
raise ValueError( | |||||
"save_dir should not be None if save_interval is not None." | |||||
) | |||||
raise ValueError("save_dir should not be None if save_interval is not None.") | |||||
self.save(epoch + 1) | self.save(epoch + 1) | ||||
if self.stop_loss is not None and loss_value < self.stop_loss: | if self.stop_loss is not None and loss_value < self.stop_loss: | ||||
break | break | ||||
@@ -191,7 +191,7 @@ class BasicNN: | |||||
with torch.no_grad(): | with torch.no_grad(): | ||||
results = [] | results = [] | ||||
for data, _ in data_loader: | |||||
for data in data_loader: | |||||
data = data.to(device) | data = data.to(device) | ||||
out = model(data) | out = model(data) | ||||
results.append(out) | results.append(out) | ||||
@@ -199,7 +199,10 @@ class BasicNN: | |||||
return torch.cat(results, axis=0) | return torch.cat(results, axis=0) | ||||
def predict( | def predict( | ||||
self, data_loader: DataLoader = None, X: List[Any] = None | |||||
self, | |||||
data_loader: DataLoader = None, | |||||
X: List[Any] = None, | |||||
test_transform: Callable[..., Any] = None, | |||||
) -> numpy.ndarray: | ) -> numpy.ndarray: | ||||
""" | """ | ||||
Predict the class of the input data. | Predict the class of the input data. | ||||
@@ -218,11 +221,28 @@ class BasicNN: | |||||
""" | """ | ||||
if data_loader is None: | if data_loader is None: | ||||
data_loader = self._data_loader(X) | |||||
if test_transform is None: | |||||
print_log( | |||||
"Transform used in the training phase will be used in prediction.", | |||||
"current", | |||||
level=logging.WARNING, | |||||
) | |||||
dataset = PredictionDataset(X, self.transform) | |||||
else: | |||||
dataset = PredictionDataset(X, test_transform) | |||||
data_loader = DataLoader( | |||||
dataset, | |||||
batch_size=self.batch_size, | |||||
num_workers=int(self.num_workers), | |||||
collate_fn=self.collate_fn, | |||||
) | |||||
return self._predict(data_loader).argmax(axis=1).cpu().numpy() | return self._predict(data_loader).argmax(axis=1).cpu().numpy() | ||||
def predict_proba( | def predict_proba( | ||||
self, data_loader: DataLoader = None, X: List[Any] = None | |||||
self, | |||||
data_loader: DataLoader = None, | |||||
X: List[Any] = None, | |||||
test_transform: Callable[..., Any] = None, | |||||
) -> numpy.ndarray: | ) -> numpy.ndarray: | ||||
""" | """ | ||||
Predict the probability of each class for the input data. | Predict the probability of each class for the input data. | ||||
@@ -241,7 +261,21 @@ class BasicNN: | |||||
""" | """ | ||||
if data_loader is None: | if data_loader is None: | ||||
data_loader = self._data_loader(X) | |||||
if test_transform is None: | |||||
print_log( | |||||
"Transform used in the training phase will be used in prediction.", | |||||
"current", | |||||
level=logging.WARNING, | |||||
) | |||||
dataset = PredictionDataset(X, self.transform) | |||||
else: | |||||
dataset = PredictionDataset(X, test_transform) | |||||
data_loader = DataLoader( | |||||
dataset, | |||||
batch_size=self.batch_size, | |||||
num_workers=int(self.num_workers), | |||||
collate_fn=self.collate_fn, | |||||
) | |||||
return self._predict(data_loader).softmax(axis=1).cpu().numpy() | return self._predict(data_loader).softmax(axis=1).cpu().numpy() | ||||
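Illustrative calls for the new X-plus-transform prediction path; basic_nn, raw_images and eval_transform stand for objects created elsewhere:

# probs = basic_nn.predict_proba(X=raw_images, test_transform=eval_transform)
# labels = basic_nn.predict(X=raw_images)  # no test_transform: warns and reuses the training transform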
def _score(self, data_loader) -> Tuple[float, float]: | def _score(self, data_loader) -> Tuple[float, float]: | ||||
@@ -313,15 +347,14 @@ class BasicNN: | |||||
if data_loader is None: | if data_loader is None: | ||||
data_loader = self._data_loader(X, y) | data_loader = self._data_loader(X, y) | ||||
mean_loss, accuracy = self._score(data_loader) | mean_loss, accuracy = self._score(data_loader) | ||||
print_log( | |||||
f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current" | |||||
) | |||||
print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current") | |||||
return accuracy | return accuracy | ||||
def _data_loader( | def _data_loader( | ||||
self, | self, | ||||
X: List[Any], | X: List[Any], | ||||
y: List[int] = None, | y: List[int] = None, | ||||
shuffle: bool = True, | |||||
) -> DataLoader: | ) -> DataLoader: | ||||
""" | """ | ||||
Generate a DataLoader for user-provided input and target data. | Generate a DataLoader for user-provided input and target data. | ||||
@@ -350,7 +383,7 @@ class BasicNN: | |||||
data_loader = DataLoader( | data_loader = DataLoader( | ||||
dataset, | dataset, | ||||
batch_size=self.batch_size, | batch_size=self.batch_size, | ||||
shuffle=True, | |||||
shuffle=shuffle, | |||||
num_workers=int(self.num_workers), | num_workers=int(self.num_workers), | ||||
collate_fn=self.collate_fn, | collate_fn=self.collate_fn, | ||||
) | ) | ||||
@@ -368,14 +401,13 @@ class BasicNN: | |||||
The path to save the model, by default None. | The path to save the model, by default None. | ||||
""" | """ | ||||
if self.save_dir is None and save_path is None: | if self.save_dir is None and save_path is None: | ||||
raise ValueError( | |||||
"'save_dir' and 'save_path' should not be None simultaneously." | |||||
) | |||||
raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.") | |||||
if save_path is None: | |||||
save_path = os.path.join( | |||||
self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth" | |||||
) | |||||
if save_path is not None: | |||||
if not os.path.exists(os.path.dirname(save_path)): | |||||
os.makedirs(os.path.dirname(save_path)) | |||||
else: | |||||
save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth") | |||||
if not os.path.exists(self.save_dir): | |||||
if self.save_dir is not None and not os.path.exists(self.save_dir): | |||||
os.makedirs(self.save_dir) | os.makedirs(self.save_dir) | ||||
@@ -1,2 +1,6 @@ | |||||
from .base_kb import BaseKB | |||||
from .ground_kb import GroundKB | |||||
from .prolog_based_kb import PrologBasedKB | |||||
from .reasoner import ReasonerBase | from .reasoner import ReasonerBase | ||||
from .kb import KBBase, prolog_KB | |||||
from .search_based_kb import SearchBasedKB | |||||
from .search_engine import BFS, BaseSearchEngine |
@@ -0,0 +1,14 @@ | |||||
from abc import ABC | |||||
class BaseKB(ABC): | |||||
def __init__(self, pseudo_label_list) -> None: | |||||
self.pseudo_label_list = pseudo_label_list | |||||
# TODO: When the output is excessively long, use ellipses as a substitute. | |||||
def __repr__(self): | |||||
return ( | |||||
f"<{self.__class__.__name__}(\n" | |||||
f" pseudo_label_list: {self.pseudo_label_list!r}\n" | |||||
f") at {hex(id(self))}>" | |||||
) |
@@ -0,0 +1,60 @@ | |||||
from abc import ABC, abstractmethod | |||||
from typing import Any, Hashable, List | |||||
from ..structures import ListData | |||||
from .base_kb import BaseKB | |||||
class GroundKB(BaseKB, ABC): | |||||
def __init__(self, pseudo_label_list: List) -> None: | |||||
super().__init__(pseudo_label_list) | |||||
self.GKB = self.construct_base() | |||||
@abstractmethod | |||||
def construct_base(self) -> dict: | |||||
pass | |||||
@abstractmethod | |||||
def get_key(self, data_sample: ListData) -> Hashable: | |||||
pass | |||||
def key2candidates(self, key: Hashable) -> List[List[Any]]: | |||||
return self.GKB[key] | |||||
def filter_candidates( | |||||
self, | |||||
data_sample: ListData, | |||||
candidates: List[List[Any]], | |||||
max_revision_num: int, | |||||
require_more_revision: int = 0, | |||||
) -> List[List[Any]]: | |||||
return candidates | |||||
def abduce_candidates( | |||||
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 | |||||
): | |||||
return self._abduce_by_GKB( | |||||
data_sample=data_sample, | |||||
max_revision_num=max_revision_num, | |||||
require_more_revision=require_more_revision, | |||||
) | |||||
def _abduce_by_GKB( | |||||
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 | |||||
): | |||||
candidates = self.key2candidates(self.get_key(data_sample)) | |||||
return self.filter_candidates( | |||||
data_sample=data_sample, | |||||
max_revision_num=max_revision_num, | |||||
require_more_revision=require_more_revision, | |||||
candidates=candidates, | |||||
) | |||||
# TODO: When the output is excessively long, use ellipses as a substitute. | |||||
def __repr__(self): | |||||
return ( | |||||
f"<{self.__class__.__name__}(\n" | |||||
f" pseudo_label_list: {self.pseudo_label_list!r}\n" | |||||
f" GKB: {self.GKB!r}\n" | |||||
f") at {hex(id(self))}>" | |||||
) |
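A toy GroundKB subclass for two-digit addition, showing what construct_base and get_key are expected to provide. Reading the target from data_sample["Y"][0] is an assumption of this sketch (it mirrors how SemanticsMetric accesses Y), not something the diff prescribes, and GroundKB is the class defined above:

from itertools import product

class AddGroundKB(GroundKB):
    # Ground knowledge base enumerating every pseudo-label pair and its sum.
    def construct_base(self) -> dict:
        GKB = {}
        for a, b in product(self.pseudo_label_list, repeat=2):
            GKB.setdefault(a + b, []).append([a, b])
        return GKB

    def get_key(self, data_sample):
        # Assumed layout: the reasoning target is stored under "Y".
        return data_sample["Y"][0]

kb = AddGroundKB(pseudo_label_list=list(range(10)))
print(kb.key2candidates(3))  # [[0, 3], [1, 2], [2, 1], [3, 0]]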
@@ -1,222 +0,0 @@ | |||||
from abc import ABC, abstractmethod | |||||
import bisect | |||||
import numpy as np | |||||
from collections import defaultdict | |||||
from itertools import product, combinations | |||||
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list | |||||
from multiprocessing import Pool | |||||
from functools import lru_cache | |||||
import pyswip | |||||
class KBBase(ABC): | |||||
def __init__(self, pseudo_label_list, max_err=0, use_cache=True): | |||||
# TODO: add some type checking here, e.g. | |||||
# if not isinstance(X, (np.ndarray, spmatrix)): | |||||
# raise TypeError("X should be numpy array or sparse matrix") | |||||
self.pseudo_label_list = pseudo_label_list | |||||
self.max_err = max_err | |||||
self.use_cache = use_cache | |||||
@abstractmethod | |||||
def logic_forward(self, pseudo_labels): | |||||
pass | |||||
def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): | |||||
if not self.use_cache: | |||||
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision) | |||||
else: | |||||
return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(y), max_revision_num, require_more_revision) | |||||
def revise_by_idx(self, pred_res, y, revision_idx): | |||||
candidates = [] | |||||
abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) | |||||
for c in abduce_c: | |||||
candidate = pred_res.copy() | |||||
for i, idx in enumerate(revision_idx): | |||||
candidate[idx] = c[i] | |||||
if check_equal(self.logic_forward(candidate), y, self.max_err): | |||||
candidates.append(candidate) | |||||
return candidates | |||||
def _revision(self, revision_num, pred_res, y): | |||||
new_candidates = [] | |||||
revision_idx_list = combinations(range(len(pred_res)), revision_num) | |||||
for revision_idx in revision_idx_list: | |||||
candidates = self.revise_by_idx(pred_res, y, revision_idx) | |||||
new_candidates.extend(candidates) | |||||
return new_candidates | |||||
def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision): | |||||
candidates = [] | |||||
for revision_num in range(len(pred_res) + 1): | |||||
if revision_num == 0 and check_equal(self.logic_forward(pred_res), y, self.max_err): | |||||
candidates.append(pred_res) | |||||
elif revision_num > 0: | |||||
candidates.extend(self._revision(revision_num, pred_res, y)) | |||||
if len(candidates) > 0: | |||||
min_revision_num = revision_num | |||||
break | |||||
if revision_num >= max_revision_num: | |||||
return [] | |||||
for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): | |||||
if revision_num > max_revision_num: | |||||
return candidates | |||||
candidates.extend(self._revision(revision_num, pred_res, y)) | |||||
return candidates | |||||
@lru_cache(maxsize=None) | |||||
def _abduce_by_search_cache(self, pred_res, y, max_revision_num, require_more_revision): | |||||
pred_res = hashable_to_list(pred_res) | |||||
y = hashable_to_list(y) | |||||
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision) | |||||
def _dict_len(self, dic): | |||||
if not self.GKB_flag: | |||||
return 0 | |||||
else: | |||||
return sum(len(c) for c in dic.values()) | |||||
def __len__(self): | |||||
if not self.GKB_flag: | |||||
return 0 | |||||
else: | |||||
return sum(self._dict_len(v) for v in self.base.values()) | |||||
class ground_KB(KBBase): | |||||
def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0): | |||||
super().__init__(pseudo_label_list, max_err) | |||||
self.GKB_len_list = GKB_len_list | |||||
self.base = {} | |||||
X, Y = self._get_GKB() | |||||
for x, y in zip(X, Y): | |||||
self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||||
# For parallel version of _get_GKB | |||||
def _get_XY_list(self, args): | |||||
pre_x, post_x_it = args[0], args[1] | |||||
XY_list = [] | |||||
for post_x in post_x_it: | |||||
x = (pre_x,) + post_x | |||||
y = self.logic_forward(x) | |||||
if y is not None: | |||||
XY_list.append((x, y)) | |||||
return XY_list | |||||
# Parallel _get_GKB | |||||
def _get_GKB(self): | |||||
X, Y = [], [] | |||||
for length in self.GKB_len_list: | |||||
arg_list = [] | |||||
for pre_x in self.pseudo_label_list: | |||||
post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||||
arg_list.append((pre_x, post_x_it)) | |||||
with Pool(processes=len(arg_list)) as pool: | |||||
ret_list = pool.map(self._get_XY_list, arg_list) | |||||
for XY_list in ret_list: | |||||
if len(XY_list) == 0: | |||||
continue | |||||
part_X, part_Y = zip(*XY_list) | |||||
X.extend(part_X) | |||||
Y.extend(part_Y) | |||||
if Y and isinstance(Y[0], (int, float)): | |||||
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | |||||
return X, Y | |||||
def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): | |||||
return self._abduce_by_GKB(pred_res, y, max_revision_num, require_more_revision) | |||||
def _find_candidate_GKB(self, pred_res, y): | |||||
if self.max_err == 0: | |||||
return self.base[len(pred_res)][y] | |||||
else: | |||||
potential_candidates = self.base[len(pred_res)] | |||||
key_list = list(potential_candidates.keys()) | |||||
key_idx = bisect.bisect_left(key_list, y) | |||||
all_candidates = [] | |||||
for idx in range(key_idx - 1, 0, -1): | |||||
k = key_list[idx] | |||||
if abs(k - y) <= self.max_err: | |||||
all_candidates.extend(potential_candidates[k]) | |||||
else: | |||||
break | |||||
for idx in range(key_idx, len(key_list)): | |||||
k = key_list[idx] | |||||
if abs(k - y) <= self.max_err: | |||||
all_candidates.extend(potential_candidates[k]) | |||||
else: | |||||
break | |||||
return all_candidates | |||||
def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision): | |||||
if self.base == {} or len(pred_res) not in self.GKB_len_list: | |||||
return [] | |||||
all_candidates = self._find_candidate_GKB(pred_res, y) | |||||
if len(all_candidates) == 0: | |||||
return [] | |||||
cost_list = hamming_dist(pred_res, all_candidates) | |||||
min_revision_num = np.min(cost_list) | |||||
revision_num = min(max_revision_num, min_revision_num + require_more_revision) | |||||
idxs = np.where(cost_list <= revision_num)[0] | |||||
candidates = [all_candidates[idx] for idx in idxs] | |||||
return candidates | |||||
class prolog_KB(KBBase): | |||||
def __init__(self, pseudo_label_list, pl_file, max_err=0): | |||||
super().__init__(pseudo_label_list, max_err) | |||||
self.prolog = pyswip.Prolog() | |||||
self.prolog.consult(pl_file) | |||||
def logic_forward(self, pseudo_labels): | |||||
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] | |||||
if result == 'true': | |||||
return True | |||||
elif result == 'false': | |||||
return False | |||||
return result | |||||
def _revision_pred_res(self, pred_res, revision_idx): | |||||
import re | |||||
revision_pred_res = pred_res.copy() | |||||
revision_pred_res = flatten(revision_pred_res) | |||||
for idx in revision_idx: | |||||
revision_pred_res[idx] = 'P' + str(idx) | |||||
revision_pred_res = reform_idx(revision_pred_res, pred_res) | |||||
# TODO: not sure whether there is a more concise way to do this | |||||
regex = r"'P\d+'" | |||||
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_res)) | |||||
def get_query_string(self, pred_res, y, revision_idx): | |||||
query_string = "logic_forward(" | |||||
query_string += self._revision_pred_res(pred_res, revision_idx) | |||||
key_is_none_flag = y is None or (type(y) == list and y[0] is None) | |||||
query_string += ",%s)." % y if not key_is_none_flag else ")." | |||||
return query_string | |||||
def revise_by_idx(self, pred_res, y, revision_idx): | |||||
candidates = [] | |||||
query_string = self.get_query_string(pred_res, y, revision_idx) | |||||
save_pred_res = pred_res | |||||
pred_res = flatten(pred_res) | |||||
abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] | |||||
for c in abduce_c: | |||||
candidate = pred_res.copy() | |||||
for i, idx in enumerate(revision_idx): | |||||
candidate[idx] = c[i] | |||||
candidate = reform_idx(candidate, save_pred_res) | |||||
candidates.append(candidate) | |||||
return candidates |
@@ -0,0 +1,44 @@ | |||||
from abc import ABC, abstractmethod | |||||
from typing import Any, Generator, List, Tuple, Union | |||||
import numpy as np | |||||
import pyswip | |||||
from ..structures import ListData | |||||
from .base_kb import BaseKB | |||||
class PrologBasedKB(BaseKB, ABC): | |||||
def __init__(self, pseudo_label_list, pl_file): | |||||
self.pseudo_label_list = pseudo_label_list | |||||
self.prolog = pyswip.Prolog() | |||||
self.prolog.consult(pl_file) | |||||
def logic_forward( | |||||
self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray] = None | |||||
) -> Generator[Union[Any, pyswip.Variable, list, dict, None], Any, None]: | |||||
return self.prolog.query(self.to_query(data_sample, revision_idx)) | |||||
@abstractmethod | |||||
def to_query(self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray] = None): | |||||
pass | |||||
@abstractmethod | |||||
def postprocess( | |||||
self, query_res, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray] | |||||
): | |||||
return list(query_res) | |||||
@abstractmethod | |||||
def filter_candidates( | |||||
self, | |||||
data_sample: ListData, | |||||
candidates: List[List[Any]], | |||||
max_revision_num: int, | |||||
require_more_revision: int = 0, | |||||
) -> List[List[Any]]: | |||||
return candidates | |||||
def revise_at_idx(self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray]): | |||||
query_res = self.logic_forward(data_sample, revision_idx) | |||||
return self.postprocess(query_res, data_sample, revision_idx) |
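A hedged sketch of a concrete subclass whose Prolog file defines a logic_forward/2 predicate. The query construction follows the removed prolog_KB code further down in this diff; the keys pred_pseudo_label and Y, and the assumption that pseudo-labels are numeric, are choices of this sketch rather than requirements of PrologBasedKB:

class ToyPrologKB(PrologBasedKB):
    def to_query(self, data_sample, revision_idx=None):
        pseudo_label = list(data_sample["pred_pseudo_label"][0])
        if revision_idx is not None:
            # Replace the positions to revise with fresh Prolog variables X0, X1, ...
            for i, idx in enumerate(revision_idx):
                pseudo_label[idx] = f"X{i}"
        # Strip quotes so X0, X1, ... are parsed as Prolog variables (fine for numeric labels).
        return "logic_forward({}, {}).".format(pseudo_label, data_sample["Y"][0]).replace("'", "")

    def postprocess(self, query_res, data_sample, revision_idx):
        return list(query_res)

    def filter_candidates(self, data_sample, candidates, max_revision_num, require_more_revision=0):
        return candidates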
@@ -1,25 +1,33 @@ | |||||
from typing import Any, List, Mapping, Optional | |||||
import numpy as np | import numpy as np | ||||
from zoopt import Dimension, Objective, Parameter, Opt | |||||
from ..utils.utils import ( | |||||
confidence_dist, | |||||
flatten, | |||||
reform_idx, | |||||
hamming_dist, | |||||
calculate_revision_num, | |||||
) | |||||
from ..structures import ListData | |||||
from ..utils import Cache, calculate_revision_num, confidence_dist, hamming_dist | |||||
from .base_kb import BaseKB | |||||
from .search_engine import BFS, BaseSearchEngine | |||||
class ReasonerBase: | class ReasonerBase: | ||||
def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False): | |||||
def __init__( | |||||
self, | |||||
kb: BaseKB, | |||||
dist_func: str = "confidence", | |||||
mapping: Optional[Mapping] = None, | |||||
search_engine: Optional[BaseSearchEngine] = None, | |||||
use_cache: bool = False, | |||||
cache_file: Optional[str] = None, | |||||
cache_size: Optional[int] = 4096, | |||||
): | |||||
""" | """ | ||||
Base class for all reasoner in the ABL system. | Base class for all reasoner in the ABL system. | ||||
Parameters | Parameters | ||||
---------- | ---------- | ||||
kb : KBBase | |||||
kb : BaseKB | |||||
The knowledge base to be used for reasoning. | The knowledge base to be used for reasoning. | ||||
dist_func : str, optional | dist_func : str, optional | ||||
The distance function to be used. Can be "hamming" or "confidence". Default is "hamming". | |||||
The distance function to be used. Can be "hamming" or "confidence". Default is "confidence". | |||||
mapping : dict, optional | mapping : dict, optional | ||||
A mapping of indices to labels. If None, a default mapping is generated. | A mapping of indices to labels. If None, a default mapping is generated. | ||||
use_zoopt : bool, optional | use_zoopt : bool, optional | ||||
@@ -31,207 +39,204 @@ class ReasonerBase: | |||||
If the specified distance function is neither "hamming" nor "confidence". | If the specified distance function is neither "hamming" nor "confidence". | ||||
""" | """ | ||||
if not (dist_func == "hamming" or dist_func == "confidence"): | |||||
raise NotImplementedError # Only hamming or confidence distance is available. | |||||
if not isinstance(kb, BaseKB): | |||||
raise ValueError("The kb should be of type BaseKB.") | |||||
self.kb = kb | self.kb = kb | ||||
if dist_func not in ["hamming", "confidence"]: | |||||
raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.") | |||||
self.dist_func = dist_func | self.dist_func = dist_func | ||||
self.use_zoopt = use_zoopt | |||||
if mapping is None: | if mapping is None: | ||||
self.mapping = { | |||||
index: label for index, label in enumerate(self.kb.pseudo_label_list) | |||||
} | |||||
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} | |||||
else: | else: | ||||
self.mapping = mapping | |||||
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) | |||||
if not isinstance(mapping, dict): | |||||
raise ValueError("mapping must be of type dict") | |||||
def _get_cost_list(self, pred_pseudo_label, pred_prob, candidates): | |||||
""" | |||||
Get the list of costs between each pseudo label and candidate. | |||||
for key, value in mapping.items(): | |||||
if not isinstance(key, int): | |||||
raise ValueError("All keys in the mapping must be integers") | |||||
Parameters | |||||
---------- | |||||
pred_pseudo_label : list | |||||
The pseudo label to be used for computing costs of candidates. | |||||
pred_prob : list | |||||
Probabilities of the predictions. Used when distance function is "confidence". | |||||
candidates : list | |||||
List of candidate abduction result. | |||||
if value not in self.kb.pseudo_label_list: | |||||
raise ValueError("All values in the mapping must be in the pseudo_label_list") | |||||
Returns | |||||
------- | |||||
numpy.ndarray | |||||
Array of computed costs for each candidate. | |||||
""" | |||||
if self.dist_func == "hamming": | |||||
return hamming_dist(pred_pseudo_label, candidates) | |||||
elif self.dist_func == "confidence": | |||||
candidates = [[self.remapping[x] for x in c] for c in candidates] | |||||
return confidence_dist(pred_prob, candidates) | |||||
def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): | |||||
""" | |||||
Get one candidate. If multiple candidates exist, return the one with minimum cost. | |||||
Parameters | |||||
---------- | |||||
pred_pseudo_label : list | |||||
The pseudo label to be used for selecting a candidate. | |||||
pred_prob : list | |||||
Probabilities of the predictions. | |||||
candidates : list | |||||
List of candidate abduction result. | |||||
self.mapping = mapping | |||||
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) | |||||
Returns | |||||
------- | |||||
list | |||||
The chosen candidate based on minimum cost. | |||||
If no candidates, an empty list is returned. | |||||
""" | |||||
if len(candidates) == 0: | |||||
return [] | |||||
elif len(candidates) == 1: | |||||
return candidates[0] | |||||
if search_engine is None: | |||||
self.search_engine = BFS() | |||||
else: | else: | ||||
cost_array = self._get_cost_list(pred_pseudo_label, pred_prob, candidates) | |||||
candidate = candidates[np.argmin(cost_array)] | |||||
return candidate | |||||
if not isinstance(search_engine, BaseSearchEngine): | |||||
raise ValueError("The search_engine should be of type BaseSearchEngine.") | |||||
else: | |||||
self.search_engine = search_engine | |||||
self.use_cache = use_cache | |||||
self.cache_file = cache_file | |||||
if self.use_cache: | |||||
if not hasattr(self, "get_key"): | |||||
raise NotImplementedError("If use_cache is True, get_key should be implemented.") | |||||
key_func = self.get_key | |||||
else: | |||||
key_func = lambda x: x | |||||
self.cache = Cache[ListData, List[List[Any]]]( | |||||
func=self.abduce_candidates, | |||||
cache=self.use_cache, | |||||
cache_file=self.cache_file, | |||||
key_func=key_func, | |||||
max_size=cache_size, | |||||
) | |||||
def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, y, sol): | |||||
def abduce( | |||||
self, | |||||
data_sample: ListData, | |||||
max_revision: int = -1, | |||||
require_more_revision: int = 0, | |||||
): | |||||
""" | """ | ||||
Get the revision score for a single solution. | |||||
Perform revision by abduction on the given data. | |||||
Parameters | Parameters | ||||
---------- | ---------- | ||||
symbol_num : int | |||||
Number of total symbols. | |||||
pred_pseudo_label : list | |||||
List of predicted pseudo labels. | |||||
pred_prob : list | pred_prob : list | ||||
List of probabilities for predicted results. | List of probabilities for predicted results. | ||||
pred_pseudo_label : list | |||||
List of predicted pseudo labels. | |||||
y : any | y : any | ||||
Ground truth for the predicted results. | Ground truth for the predicted results. | ||||
sol : array-like | |||||
Solution to evaluate. | |||||
max_revision : int or float, optional | |||||
Maximum number of revisions to use. If float, represents the fraction of total revisions to use. | |||||
If -1, any revisions are allowed. Defaults to -1. | |||||
require_more_revision : int, optional | |||||
Number of additional revisions to require. Defaults to 0. | |||||
Returns | Returns | ||||
------- | ------- | ||||
float | |||||
The revision score for the given solution. | |||||
list | |||||
The abduced revisions. | |||||
""" | """ | ||||
revision_idx = np.where(sol.get_x() != 0)[0] | |||||
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) | |||||
if len(candidates) > 0: | |||||
return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates)) | |||||
else: | |||||
return symbol_num | |||||
symbol_num = data_sample.elements_num("pred_pseudo_label") | |||||
max_revision_num = calculate_revision_num(max_revision, symbol_num) | |||||
data_sample.set_metainfo(dict(symbol_num=symbol_num)) | |||||
def _constrain_revision_num(self, solution, max_revision_num): | |||||
x = solution.get_x() | |||||
return max_revision_num - x.sum() | |||||
candidates = self.cache.get(data_sample, max_revision_num, require_more_revision) | |||||
candidate = self.select_one_candidate(data_sample, candidates) | |||||
return candidate | |||||
def zoopt_get_solution( | |||||
self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num | |||||
def abduce_candidates( | |||||
self, | |||||
data_sample: ListData, | |||||
max_revision_num: int = -1, | |||||
require_more_revision: int = 0, | |||||
): | ): | ||||
"""Get the optimal solution using the Zoopt library. | |||||
""" | |||||
Perform revision by abduction on the given data. | |||||
Parameters | Parameters | ||||
---------- | ---------- | ||||
symbol_num : int | |||||
Number of total symbols. | |||||
pred_pseudo_label : list | |||||
List of predicted pseudo labels. | |||||
pred_prob : list | pred_prob : list | ||||
List of probabilities for predicted results. | List of probabilities for predicted results. | ||||
pred_pseudo_label : list | |||||
List of predicted pseudo labels. | |||||
y : any | y : any | ||||
Ground truth for the predicted results. | Ground truth for the predicted results. | ||||
max_revision_num : int | |||||
Maximum number of revisions to use. | |||||
max_revision : int or float, optional | |||||
Maximum number of revisions to use. If float, represents the fraction of total revisions to use. | |||||
If -1, any revisions are allowed. Defaults to -1. | |||||
require_more_revision : int, optional | |||||
Number of additional revisions to require. Defaults to 0. | |||||
Returns | Returns | ||||
------- | ------- | ||||
array-like | |||||
The optimal solution, i.e., where to revise predict pseudo label. | |||||
list | |||||
The abduced revisions. | |||||
""" | """ | ||||
dimension = Dimension( | |||||
size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num | |||||
) | |||||
objective = Objective( | |||||
lambda sol: self.zoopt_revision_score( | |||||
symbol_num, pred_pseudo_label, pred_prob, y, sol | |||||
), | |||||
dim=dimension, | |||||
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | |||||
) | |||||
parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | |||||
solution = Opt.min(objective, parameter).get_x() | |||||
return solution | |||||
def revise_by_idx(self, pred_pseudo_label, y, revision_idx): | |||||
if hasattr(self.kb, "abduce_candidates"): | |||||
candidates = self.kb.abduce_candidates( | |||||
data_sample, max_revision_num, require_more_revision | |||||
) | |||||
elif hasattr(self.kb, "revise_at_idx"): | |||||
candidates = [] | |||||
gen = self.search_engine.generator( | |||||
data_sample, | |||||
max_revision_num=max_revision_num, | |||||
require_more_revision=require_more_revision, | |||||
) | |||||
send_signal = True | |||||
for revision_idx in gen: | |||||
candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx)) | |||||
if len(candidates) > 0 and send_signal: | |||||
try: | |||||
revision_idx = gen.send("success") | |||||
candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx)) | |||||
send_signal = False | |||||
except StopIteration: | |||||
break | |||||
else: | |||||
raise NotImplementedError( | |||||
"The kb should either implement abduce_candidates or revise_at_idx." | |||||
) | |||||
return candidates | |||||
def select_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]): | |||||
""" | """ | ||||
Revise the pseudo label according to the given indices. | |||||
Get one candidate. If multiple candidates exist, return the one with minimum cost. | |||||
Parameters | Parameters | ||||
---------- | ---------- | ||||
pred_pseudo_label : list | pred_pseudo_label : list | ||||
List of predicted pseudo labels. | |||||
y : any | |||||
Ground truth for the predicted results. | |||||
revision_idx : array-like | |||||
Indices of the revisions to retrieve. | |||||
The pseudo label to be used for selecting a candidate. | |||||
pred_prob : list | |||||
Probabilities of the predictions. | |||||
candidates : list | |||||
List of candidate abduction result. | |||||
Returns | Returns | ||||
------- | ------- | ||||
list | list | ||||
The revisions according to the given indices. | |||||
The chosen candidate based on minimum cost. | |||||
If no candidates, an empty list is returned. | |||||
""" | """ | ||||
return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx) | |||||
if len(candidates) == 0: | |||||
return [] | |||||
elif len(candidates) == 1: | |||||
return candidates[0] | |||||
else: | |||||
cost_array = self._get_dist_list(data_sample, candidates) | |||||
candidate = candidates[np.argmin(cost_array)] | |||||
return candidate | |||||
def abduce( | |||||
self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 | |||||
): | |||||
def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]): | |||||
""" | """ | ||||
Perform revision by abduction on the given data. | |||||
Get the list of costs between each pseudo label and candidate. | |||||
Parameters | Parameters | ||||
---------- | ---------- | ||||
pred_prob : list | |||||
List of probabilities for predicted results. | |||||
pred_pseudo_label : list | pred_pseudo_label : list | ||||
List of predicted pseudo labels. | |||||
y : any | |||||
Ground truth for the predicted results. | |||||
max_revision : int or float, optional | |||||
Maximum number of revisions to use. If float, represents the fraction of total revisions to use. | |||||
If -1, any revisions are allowed. Defaults to -1. | |||||
require_more_revision : int, optional | |||||
Number of additional revisions to require. Defaults to 0. | |||||
The pseudo label to be used for computing costs of candidates. | |||||
pred_prob : list | |||||
Probabilities of the predictions. Used when distance function is "confidence". | |||||
candidates : list | |||||
List of candidate abduction results.
Returns | Returns | ||||
------- | ------- | ||||
list | |||||
The abduced revisions. | |||||
numpy.ndarray | |||||
Array of computed costs for each candidate. | |||||
""" | """ | ||||
symbol_num = len(flatten(pred_pseudo_label)) | |||||
max_revision_num = calculate_revision_num(max_revision, symbol_num) | |||||
if self.use_zoopt: | |||||
solution = self.zoopt_get_solution( | |||||
symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num | |||||
) | |||||
revision_idx = np.where(solution != 0)[0] | |||||
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) | |||||
else: | |||||
candidates = self.kb.abduce_candidates( | |||||
pred_pseudo_label, y, max_revision_num, require_more_revision | |||||
) | |||||
if self.dist_func == "hamming": | |||||
return hamming_dist(data_sample["pred_pseudo_label"][0], candidates) | |||||
candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) | |||||
return candidate | |||||
elif self.dist_func == "confidence": | |||||
candidates = [[self.remapping[x] for x in c] for c in candidates] | |||||
return confidence_dist(data_sample["pred_prob"][0], candidates) | |||||
def batch_abduce( | def batch_abduce( | ||||
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 | |||||
self, | |||||
data_samples: ListData, | |||||
max_revision: int = -1, | |||||
require_more_revision: int = 0, | |||||
): | ): | ||||
""" | """ | ||||
Perform abduction on the given data in batches. | Perform abduction on the given data in batches. | ||||
@@ -255,384 +260,13 @@ class ReasonerBase: | |||||
list | list | ||||
The abduced revisions in batches. | The abduced revisions in batches. | ||||
""" | """ | ||||
return [ | |||||
abduced_pseudo_label = [ | |||||
self.abduce( | self.abduce( | ||||
_pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision | |||||
) | |||||
for _pred_prob, _pred_pseudo_label, _Y in zip( | |||||
pred_prob, pred_pseudo_label, Y | |||||
data_sample, | |||||
max_revision=max_revision, | |||||
require_more_revision=require_more_revision, | |||||
) | ) | ||||
for data_sample in data_samples | |||||
] | ] | ||||
# def _batch_abduce_helper(self, args): | |||||
# z, prob, y, max_revision, require_more_revision = args | |||||
# return self.abduce((z, prob, y), max_revision, require_more_revision) | |||||
# def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0): | |||||
# with Pool(processes=os.cpu_count()) as pool: | |||||
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) | |||||
# return results | |||||
def __call__( | |||||
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0 | |||||
): | |||||
return self.batch_abduce( | |||||
pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision | |||||
) | |||||
if __name__ == "__main__": | |||||
from kb import KBBase, ground_KB, prolog_KB | |||||
prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||||
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||||
class add_KB(KBBase): | |||||
def __init__(self, pseudo_label_list=list(range(10)), | |||||
use_cache=True): | |||||
super().__init__(pseudo_label_list, use_cache=use_cache) | |||||
def logic_forward(self, nums): | |||||
return sum(nums) | |||||
class add_ground_KB(ground_KB): | |||||
def __init__(self, pseudo_label_list=list(range(10)), | |||||
GKB_len_list=[2]): | |||||
super().__init__(pseudo_label_list, GKB_len_list) | |||||
def logic_forward(self, nums): | |||||
return sum(nums) | |||||
def test_add(reasoner): | |||||
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||||
print(res) | |||||
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||||
print(res) | |||||
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||||
print(res) | |||||
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||||
print(res) | |||||
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||||
print(res) | |||||
print() | |||||
print("add_KB with GKB:") | |||||
kb = add_ground_KB() | |||||
reasoner = ReasonerBase(kb, "confidence") | |||||
test_add(reasoner) | |||||
print("add_KB without GKB:") | |||||
kb = add_KB() | |||||
reasoner = ReasonerBase(kb, "confidence") | |||||
test_add(reasoner) | |||||
print("add_KB without GKB, no cache") | |||||
kb = add_KB(use_cache=False) | |||||
reasoner = ReasonerBase(kb, "confidence") | |||||
test_add(reasoner) | |||||
print("prolog_KB with add.pl:") | |||||
kb = prolog_KB(pseudo_label_list=list(range(10)), | |||||
pl_file="examples/mnist_add/datasets/add.pl") | |||||
reasoner = ReasonerBase(kb, "confidence") | |||||
test_add(reasoner) | |||||
print("prolog_KB with add.pl using zoopt:") | |||||
kb = prolog_KB( | |||||
pseudo_label_list=list(range(10)), | |||||
pl_file="examples/mnist_add/datasets/add.pl", | |||||
) | |||||
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | |||||
test_add(reasoner) | |||||
print("add_KB with multiple inputs at once:") | |||||
multiple_prob = [[ | |||||
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
], | |||||
[ | |||||
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
]] | |||||
kb = add_KB() | |||||
reasoner = ReasonerBase(kb, "confidence") | |||||
res = reasoner.batch_abduce( | |||||
multiple_prob, | |||||
[[1, 1], [1, 2]], | |||||
[4, 8], | |||||
max_revision=2, | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
res = reasoner.batch_abduce( | |||||
multiple_prob, | |||||
[[1, 1], [1, 2]], | |||||
[4, 8], | |||||
max_revision=2, | |||||
require_more_revision=1, | |||||
) | |||||
print(res) | |||||
print() | |||||
class HWF_KB(KBBase): | |||||
def __init__( | |||||
self, | |||||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||||
"+", "-", "times", "div"], | |||||
max_err=1e-3, | |||||
): | |||||
super().__init__(pseudo_label_list, max_err) | |||||
def _valid_candidate(self, formula): | |||||
if len(formula) % 2 == 0: | |||||
return False | |||||
for i in range(len(formula)): | |||||
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||||
return False | |||||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||||
return False | |||||
return True | |||||
def logic_forward(self, formula): | |||||
if not self._valid_candidate(formula): | |||||
return np.inf | |||||
mapping = {str(i): str(i) for i in range(1, 10)} | |||||
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||||
formula = [mapping[f] for f in formula] | |||||
return eval("".join(formula)) | |||||
class HWF_ground_KB(ground_KB): | |||||
def __init__( | |||||
self, | |||||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||||
"+", "-", "times", "div"], | |||||
GKB_len_list=[1, 3, 5, 7], | |||||
max_err=1e-3, | |||||
): | |||||
super().__init__(pseudo_label_list, GKB_len_list, max_err) | |||||
def _valid_candidate(self, formula): | |||||
if len(formula) % 2 == 0: | |||||
return False | |||||
for i in range(len(formula)): | |||||
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||||
return False | |||||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||||
return False | |||||
return True | |||||
def logic_forward(self, formula): | |||||
if not self._valid_candidate(formula): | |||||
return np.inf | |||||
mapping = {str(i): str(i) for i in range(1, 10)} | |||||
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||||
formula = [mapping[f] for f in formula] | |||||
return eval("".join(formula)) | |||||
def test_hwf(reasoner): | |||||
res = reasoner.batch_abduce( | |||||
[None], | |||||
[["5", "+", "2"]], | |||||
[3], | |||||
max_revision=2, | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
res = reasoner.batch_abduce( | |||||
[None], | |||||
[["5", "+", "9"]], | |||||
[65], | |||||
max_revision=3, | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
res = reasoner.batch_abduce( | |||||
[None], | |||||
[["5", "8", "8", "8", "8"]], | |||||
[3.17], | |||||
max_revision=5, | |||||
require_more_revision=3, | |||||
) | |||||
print(res) | |||||
print() | |||||
def test_hwf_multiple(reasoner, max_revisions): | |||||
res = reasoner.batch_abduce( | |||||
[None, None], | |||||
[["5", "+", "2"], ["5", "+", "9"]], | |||||
[3, 64], | |||||
max_revision=max_revisions[0], | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
res = reasoner.batch_abduce( | |||||
[None, None], | |||||
[["5", "+", "2"], ["5", "+", "9"]], | |||||
[3, 64], | |||||
max_revision=max_revisions[1], | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
res = reasoner.batch_abduce( | |||||
[None, None], | |||||
[["5", "+", "2"], ["5", "+", "9"]], | |||||
[3, 65], | |||||
max_revision=max_revisions[2], | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
print() | |||||
print("HWF_KB with GKB, max_err=0.1") | |||||
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=0.1) | |||||
reasoner = ReasonerBase(kb, "hamming") | |||||
test_hwf(reasoner) | |||||
print("HWF_KB without GKB, max_err=0.1") | |||||
kb = HWF_KB(max_err=0.1) | |||||
reasoner = ReasonerBase(kb, "hamming") | |||||
test_hwf(reasoner) | |||||
print("HWF_KB with GKB, max_err=1") | |||||
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=1) | |||||
reasoner = ReasonerBase(kb, "hamming") | |||||
test_hwf(reasoner) | |||||
print("HWF_KB without GKB, max_err=1") | |||||
kb = HWF_KB(max_err=1) | |||||
reasoner = ReasonerBase(kb, "hamming") | |||||
test_hwf(reasoner) | |||||
print("HWF_KB with multiple inputs at once:") | |||||
kb = HWF_KB(max_err=0.1) | |||||
reasoner = ReasonerBase(kb, "hamming") | |||||
test_hwf_multiple(reasoner, max_revisions=[1,3,3]) | |||||
print("max_revision is float") | |||||
test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) | |||||
class HED_prolog_KB(prolog_KB): | |||||
def __init__(self, pseudo_label_list, pl_file): | |||||
super().__init__(pseudo_label_list, pl_file) | |||||
def consist_rule(self, exs, rules): | |||||
rules = str(rules).replace("'", "") | |||||
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) | |||||
return len(list(self.prolog.query(pl_query))) != 0 | |||||
def abduce_rules(self, pred_res): | |||||
pl_query = "consistent_inst_feature(%s, X)." % pred_res | |||||
prolog_result = list(self.prolog.query(pl_query)) | |||||
if len(prolog_result) == 0: | |||||
return None | |||||
prolog_rules = prolog_result[0]["X"] | |||||
rules = [rule.value for rule in prolog_rules] | |||||
return rules | |||||
class HED_Reasoner(ReasonerBase): | |||||
def __init__(self, kb, dist_func="hamming"): | |||||
super().__init__(kb, dist_func, use_zoopt=True) | |||||
def _revise_by_idxs(self, pred_res, y, all_revision_flag, idxs): | |||||
pred = [] | |||||
k = [] | |||||
revision_flag = [] | |||||
for idx in idxs: | |||||
pred.append(pred_res[idx]) | |||||
k.append(y[idx]) | |||||
revision_flag += list(all_revision_flag[idx]) | |||||
revision_idx = np.where(np.array(revision_flag) != 0)[0] | |||||
candidate = self.revise_by_idx(pred, k, revision_idx) | |||||
return candidate | |||||
def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): | |||||
all_revision_flag = reform_idx(sol.get_x(), pred_res) | |||||
lefted_idxs = [i for i in range(len(pred_res))] | |||||
candidate_size = [] | |||||
while lefted_idxs: | |||||
idxs = [] | |||||
idxs.append(lefted_idxs.pop(0)) | |||||
max_candidate_idxs = [] | |||||
found = False | |||||
for idx in range(-1, len(pred_res)): | |||||
if (not idx in idxs) and (idx >= 0): | |||||
idxs.append(idx) | |||||
candidate = self._revise_by_idxs( | |||||
pred_res, y, all_revision_flag, idxs | |||||
) | |||||
if len(candidate) == 0: | |||||
if len(idxs) > 1: | |||||
idxs.pop() | |||||
else: | |||||
if len(idxs) > len(max_candidate_idxs): | |||||
found = True | |||||
max_candidate_idxs = idxs.copy() | |||||
removed = [i for i in lefted_idxs if i in max_candidate_idxs] | |||||
if found: | |||||
candidate_size.append(len(removed) + 1) | |||||
lefted_idxs = [ | |||||
i for i in lefted_idxs if i not in max_candidate_idxs | |||||
] | |||||
candidate_size.sort() | |||||
score = 0 | |||||
import math | |||||
for i in range(0, len(candidate_size)): | |||||
score -= math.exp(-i) * candidate_size[i] | |||||
return score | |||||
def abduce_rules(self, pred_res): | |||||
return self.kb.abduce_rules(pred_res) | |||||
kb = HED_prolog_KB( | |||||
pseudo_label_list=[1, 0, "+", "="], | |||||
pl_file="examples/hed/datasets/learn_add.pl", | |||||
) | |||||
reasoner = HED_Reasoner(kb) | |||||
consist_exs = [ | |||||
[1, 1, "+", 0, "=", 1, 1], | |||||
[1, "+", 1, "=", 1, 0], | |||||
[0, "+", 0, "=", 0], | |||||
] | |||||
inconsist_exs1 = [ | |||||
[1, 1, "+", 0, "=", 1, 1], | |||||
[1, "+", 1, "=", 1, 0], | |||||
[0, "+", 0, "=", 0], | |||||
[0, "+", 0, "=", 1], | |||||
] | |||||
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | |||||
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] | |||||
print("HED_kb logic forward") | |||||
print(kb.logic_forward(consist_exs)) | |||||
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) | |||||
print() | |||||
print("HED_kb consist rule") | |||||
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) | |||||
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) | |||||
print() | |||||
print("HED_Reasoner abduce") | |||||
res = reasoner.abduce( | |||||
[[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs) | |||||
) | |||||
print(res) | |||||
res = reasoner.abduce( | |||||
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) | |||||
) | |||||
print(res) | |||||
res = reasoner.abduce( | |||||
[[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) | |||||
) | |||||
print(res) | |||||
print() | |||||
print("HED_Reasoner abduce rules") | |||||
abduced_rules = reasoner.abduce_rules(consist_exs) | |||||
print(abduced_rules) | |||||
data_samples.abduced_pseudo_label = abduced_pseudo_label | |||||
return abduced_pseudo_label |
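A minimal usage sketch (added here for illustration, not part of the diff) of the batch layout the refactored reasoner reads. The field names follow the accesses above (pred_pseudo_label, pred_prob, Y, abduced_pseudo_label); the import path and the concrete values are assumptions.

from abl.structures import ListData  # assumed top-level package name

uniform = [0.1] * 10
batch = ListData()
batch.pred_pseudo_label = [[1, 1], [1, 2]]   # one pseudo-label sequence per sample
batch.pred_prob = [[uniform, uniform]] * 2   # per-symbol class probabilities
batch.Y = [8, 8]                             # reasoning targets

sample = batch[0]                            # a length-1 ListData view
print(sample["pred_pseudo_label"][0], sample["Y"][0])  # [1, 1] 8
# Given a ReasonerBase built on a concrete KB (see the SearchBasedKB sketch below),
# reasoner.batch_abduce(batch, max_revision=2) would fill batch.abduced_pseudo_label.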
@@ -0,0 +1,49 @@ | |||||
from abc import ABC, abstractmethod | |||||
from itertools import product | |||||
from typing import Any, List, Tuple, Union | |||||
import numpy | |||||
from ..structures import ListData | |||||
from .base_kb import BaseKB | |||||
class SearchBasedKB(BaseKB, ABC): | |||||
def __init__( | |||||
self, | |||||
pseudo_label_list: List, | |||||
) -> None: | |||||
super().__init__(pseudo_label_list) | |||||
@abstractmethod | |||||
def check_equal(self, data_sample: ListData, y: Any): | |||||
"""Placeholder for check_equal.""" | |||||
pass | |||||
def revise_at_idx( | |||||
self, | |||||
data_sample: ListData, | |||||
revision_idx: Union[List, Tuple, numpy.ndarray], | |||||
): | |||||
candidates = [] | |||||
abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) | |||||
for c in abduce_c: | |||||
new_data_sample = data_sample.clone() | |||||
candidate = new_data_sample["pred_pseudo_label"][0].copy() | |||||
for i, idx in enumerate(revision_idx): | |||||
candidate[idx] = c[i] | |||||
new_data_sample["pred_pseudo_label"][0] = candidate | |||||
if self.check_equal(new_data_sample, new_data_sample["Y"][0]): | |||||
candidates.append(candidate) | |||||
return candidates | |||||
# TODO: When the output is excessively long, use ellipses as a substitute. | |||||
def __repr__(self): | |||||
return ( | |||||
f"<{self.__class__.__name__}(\n" | |||||
f" pseudo_label_list: {self.pseudo_label_list!r}\n" | |||||
f" search_strategy: {self.search_strategy!r}\n" | |||||
f" use_cache: {self.use_cache!r}\n" | |||||
f" cache_root: {self.cache_root!r}\n" | |||||
f") at {hex(id(self))}>" | |||||
) |
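As a hedged illustration of the interface above, a hypothetical subclass for a digit-addition task: a sample is consistent when its pseudo labels sum to the target Y. BaseKB is not shown in this hunk and may require further methods (e.g. logic_forward); only check_equal is sketched.

class AddKB(SearchBasedKB):
    # Hypothetical KB: consistency means the pseudo labels sum to the target Y.
    def check_equal(self, data_sample: ListData, y) -> bool:
        return sum(data_sample["pred_pseudo_label"][0]) == y

kb = AddKB(pseudo_label_list=list(range(10)))
sample = ListData()
sample.pred_pseudo_label = [[1, 1]]
sample.Y = [8]
# Re-label position 1 with every pseudo label and keep the consistent results, e.g. [[1, 7]].
print(kb.revise_at_idx(sample, revision_idx=[1]))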
@@ -0,0 +1,2 @@ | |||||
from .base_search_engine import BaseSearchEngine | |||||
from .bfs import BFS |
@@ -0,0 +1,13 @@ | |||||
from abc import ABC, abstractmethod | |||||
from typing import List, Tuple, Union | |||||
import numpy | |||||
from ...structures import ListData | |||||
class BaseSearchEngine(ABC): | |||||
@abstractmethod | |||||
def generator(self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0) -> Union[List, Tuple, numpy.ndarray]:
"""Placeholder for the generator of revision_idx.""" | |||||
pass |
@@ -0,0 +1,28 @@ | |||||
from itertools import combinations | |||||
from typing import List, Tuple, Union | |||||
import numpy | |||||
from ...structures import ListData | |||||
from .base_search_engine import BaseSearchEngine | |||||
class BFS(BaseSearchEngine): | |||||
def __init__(self) -> None: | |||||
pass | |||||
def generator( | |||||
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 | |||||
) -> Union[List, Tuple, numpy.ndarray]: | |||||
symbol_num = data_sample["symbol_num"] | |||||
max_revision_num = min(max_revision_num, symbol_num) | |||||
real_end = max_revision_num | |||||
for revision_num in range(max_revision_num + 1): | |||||
if revision_num > real_end: | |||||
break | |||||
revision_idx_tuple = combinations(range(symbol_num), revision_num) | |||||
for revision_idx in revision_idx_tuple: | |||||
received = yield revision_idx | |||||
if received == "success": | |||||
real_end = min(symbol_num, revision_num + require_more_revision) |
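A hedged illustration of the generator protocol the reasoner drives above: the caller iterates tuples of positions to revise and may send "success" once candidates are found, which caps the remaining search at the current depth plus require_more_revision. A plain dict stands in for the ListData sample because only data_sample["symbol_num"] is read here.

bfs = BFS()
gen = bfs.generator({"symbol_num": 3}, max_revision_num=3, require_more_revision=1)
print(next(gen))            # ()    zero revised positions are tried first
print(next(gen))            # (0,)  then every single position
print(gen.send("success"))  # (1,)  success at depth 1 caps the search at depth 1 + 1 = 2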
@@ -0,0 +1,42 @@ | |||||
from typing import List, Tuple, Union | |||||
import numpy as np | |||||
from zoopt import Dimension, Objective, Opt, Parameter, Solution | |||||
from ...structures import ListData | |||||
from ..reasoner import ReasonerBase | |||||
from ..search_based_kb import SearchBasedKB | |||||
from .base_search_engine import BaseSearchEngine | |||||
class Zoopt(BaseSearchEngine): | |||||
def __init__(self, reasoner: ReasonerBase, kb: SearchBasedKB) -> None: | |||||
self.reasoner = reasoner | |||||
self.kb = kb | |||||
def score_func(self, data_sample: ListData, solution: Solution): | |||||
revision_idx = np.where(np.array(solution.get_x()) != 0)[0]
candidates = self.kb.revise_at_idx(data_sample, revision_idx) | |||||
if len(candidates) > 0: | |||||
return np.min(self.reasoner._get_dist_list(data_sample, candidates)) | |||||
else: | |||||
return data_sample["symbol_num"] | |||||
@staticmethod | |||||
def constraint(solution: Solution, max_revision_num: int): | |||||
x = solution.get_x() | |||||
return max_revision_num - sum(x)
def generator( | |||||
self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 | |||||
) -> Union[List, Tuple, np.ndarray]: | |||||
symbol_num = data_sample["symbol_num"] | |||||
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) | |||||
objective = Objective( | |||||
lambda solution: self.score_func(data_sample, solution),
dim=dimension, | |||||
constraint=lambda solution: self.constraint(solution, max_revision_num), | |||||
) | |||||
parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | |||||
solution = Opt.min(objective, parameter).get_x()
# Yield the revised positions (the non-zero entries of the 0/1 solution vector),
# matching the revision_idx contract of BaseSearchEngine.
revision_idx = np.where(np.array(solution) != 0)[0]
yield revision_idx
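Unlike BFS, this engine asks ZOOpt for a single near-optimal revision mask, scoring a mask by the reasoner's distance to the closest consistent candidate, so the "success" signalling used with BFS has no further effect here. A construction sketch under the assumption that reasoner, kb and data_sample (carrying an integer symbol_num) already exist:

engine = Zoopt(reasoner=reasoner, kb=kb)  # assumed ReasonerBase / SearchBasedKB instances
revision_idx = next(engine.generator(data_sample, max_revision_num=2))
candidates = kb.revise_at_idx(data_sample, revision_idx)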
@@ -0,0 +1,2 @@ | |||||
from .base_data_element import BaseDataElement | |||||
from .list_data import ListData |
@@ -0,0 +1,629 @@ | |||||
# Copyright (c) OpenMMLab. All rights reserved. | |||||
import copy | |||||
from typing import Any, Iterator, Optional, Tuple, Type, Union | |||||
import numpy as np | |||||
import torch | |||||
class BaseDataElement: | |||||
"""A base data interface that supports Tensor-like and dict-like | |||||
operations. | |||||
Typical data elements refer to predicted results or ground truth labels
on a task, such as predicted bboxes, instance masks, semantic | |||||
segmentation masks, etc. Because groundtruth labels and predicted results | |||||
often have similar properties (for example, the predicted bboxes and the | |||||
groundtruth bboxes), MMEngine uses the same abstract data interface to | |||||
encapsulate predicted results and groundtruth labels, and it is recommended | |||||
to use different name conventions to distinguish them, such as using | |||||
``gt_instances`` and ``pred_instances`` to distinguish between labels and | |||||
predicted results. Additionally, we distinguish data elements at instance | |||||
level, pixel level, and label level. Each of these types has its own | |||||
characteristics. Therefore, MMEngine defines the base class | |||||
``BaseDataElement``, and implements ``InstanceData``, ``PixelData``, and
``LabelData`` inheriting from ``BaseDataElement`` to represent different | |||||
types of ground truth labels or predictions. | |||||
Another common data element is sample data. A sample data consists of input | |||||
data (such as an image) and its annotations and predictions. In general, | |||||
an image can have multiple types of annotations and/or predictions at the | |||||
same time (for example, both pixel-level semantic segmentation annotations | |||||
and instance-level detection bboxes annotations). All labels and | |||||
predictions of a training sample are often passed between Dataset, Model, | |||||
Visualizer, and Evaluator components. In order to simplify the interface | |||||
between components, we can treat them as a large data element and | |||||
encapsulate them. Such data elements are generally called XXDataSample in | |||||
the OpenMMLab. Therefore, similar to `nn.Module`, the `BaseDataElement`
allows `BaseDataElement` as its attribute. Such a class generally | |||||
encapsulates all the data of a sample in the algorithm library, and its | |||||
attributes generally are various types of data elements. For example, | |||||
MMDetection defines such a subclass (e.g. ``DetDataSample``) to encapsulate
all the annotation and prediction data elements of a sample in the
algorithm library.
The attributes in ``BaseDataElement`` are divided into two parts, | |||||
the ``metainfo`` and the ``data`` respectively. | |||||
- ``metainfo``: Usually contains the | |||||
information about the image such as filename, | |||||
image_shape, pad_shape, etc. The attributes can be accessed or | |||||
modified by dict-like or object-like operations, such as | |||||
``.`` (for data access and modification), ``in``, ``del``, | |||||
``pop(str)``, ``get(str)``, ``metainfo_keys()``, | |||||
``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for | |||||
set or change key-value pairs in metainfo). | |||||
- ``data``: Annotations or model predictions are | |||||
stored. The attributes can be accessed or modified by | |||||
dict-like or object-like operations, such as | |||||
``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, | |||||
``values()``, ``items()``. Users can also apply tensor-like | |||||
methods to all :obj:`torch.Tensor` in the ``data_fields``, | |||||
such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, | |||||
``to_tensor()``, ``.detach()``. | |||||
Args: | |||||
metainfo (dict, optional): A dict contains the meta information | |||||
of single image, such as ``dict(img_shape=(512, 512, 3), | |||||
scale_factor=(1, 1, 1, 1))``. Defaults to None. | |||||
kwargs (dict, optional): A dict contains annotations of single image or | |||||
model predictions. Defaults to None. | |||||
Examples: | |||||
>>> import torch | |||||
>>> from mmengine.structures import BaseDataElement | |||||
>>> gt_instances = BaseDataElement() | |||||
>>> bboxes = torch.rand((5, 4)) | |||||
>>> scores = torch.rand((5,)) | |||||
>>> img_id = 0 | |||||
>>> img_shape = (800, 1333) | |||||
>>> gt_instances = BaseDataElement( | |||||
... metainfo=dict(img_id=img_id, img_shape=img_shape), | |||||
... bboxes=bboxes, scores=scores) | |||||
>>> gt_instances = BaseDataElement( | |||||
... metainfo=dict(img_id=img_id, img_shape=(640, 640))) | |||||
>>> # new | |||||
>>> gt_instances1 = gt_instances.new( | |||||
... metainfo=dict(img_id=1, img_shape=(640, 640)), | |||||
... bboxes=torch.rand((5, 4)), | |||||
... scores=torch.rand((5,))) | |||||
>>> gt_instances2 = gt_instances1.new() | |||||
>>> # add and process property | |||||
>>> gt_instances = BaseDataElement() | |||||
>>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) | |||||
>>> assert 'img_shape' in gt_instances.metainfo_keys() | |||||
>>> assert 'img_shape' in gt_instances | |||||
>>> assert 'img_shape' not in gt_instances.keys() | |||||
>>> assert 'img_shape' in gt_instances.all_keys() | |||||
>>> print(gt_instances.img_shape) | |||||
(100, 100) | |||||
>>> gt_instances.scores = torch.rand((5,)) | |||||
>>> assert 'scores' in gt_instances.keys() | |||||
>>> assert 'scores' in gt_instances | |||||
>>> assert 'scores' in gt_instances.all_keys() | |||||
>>> assert 'scores' not in gt_instances.metainfo_keys() | |||||
>>> print(gt_instances.scores) | |||||
tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) | |||||
>>> gt_instances.bboxes = torch.rand((5, 4)) | |||||
>>> assert 'bboxes' in gt_instances.keys() | |||||
>>> assert 'bboxes' in gt_instances | |||||
>>> assert 'bboxes' in gt_instances.all_keys() | |||||
>>> assert 'bboxes' not in gt_instances.metainfo_keys() | |||||
>>> print(gt_instances.bboxes) | |||||
tensor([[0.0900, 0.0424, 0.1755, 0.4469], | |||||
[0.8648, 0.0592, 0.3484, 0.0913], | |||||
[0.5808, 0.1909, 0.6165, 0.7088], | |||||
[0.5490, 0.4209, 0.9416, 0.2374], | |||||
[0.3652, 0.1218, 0.8805, 0.7523]]) | |||||
>>> # delete and change property | |||||
>>> gt_instances = BaseDataElement( | |||||
... metainfo=dict(img_id=0, img_shape=(640, 640)), | |||||
... bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) | |||||
>>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) | |||||
>>> gt_instances.img_shape # (1280, 1280) | |||||
>>> gt_instances.bboxes = gt_instances.bboxes * 2 | |||||
>>> gt_instances.get('img_shape', None) # (1280, 1280) | |||||
>>> gt_instances.get('bboxes', None) # 6x4 tensor | |||||
>>> del gt_instances.img_shape | |||||
>>> del gt_instances.bboxes | |||||
>>> assert 'img_shape' not in gt_instances | |||||
>>> assert 'bboxes' not in gt_instances | |||||
>>> gt_instances.pop('img_shape', None) # None | |||||
>>> gt_instances.pop('bboxes', None) # None | |||||
>>> # Tensor-like | |||||
>>> cuda_instances = gt_instances.cuda() | |||||
>>> cuda_instances = gt_instances.to('cuda:0') | |||||
>>> cpu_instances = cuda_instances.cpu() | |||||
>>> cpu_instances = cuda_instances.to('cpu') | |||||
>>> fp16_instances = cuda_instances.to( | |||||
... device=None, dtype=torch.float16, non_blocking=False, | |||||
... copy=False, memory_format=torch.preserve_format) | |||||
>>> cpu_instances = cuda_instances.detach() | |||||
>>> np_instances = cpu_instances.numpy() | |||||
>>> metainfo = dict(img_shape=(800, 1196, 3)) | |||||
>>> gt_instances = BaseDataElement( | |||||
... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) | |||||
>>> sample = BaseDataElement(metainfo=metainfo, | |||||
... gt_instances=gt_instances) | |||||
>>> print(sample) | |||||
<BaseDataElement( | |||||
META INFORMATION | |||||
img_shape: (800, 1196, 3) | |||||
DATA FIELDS | |||||
gt_instances: <BaseDataElement( | |||||
META INFORMATION | |||||
img_shape: (800, 1196, 3) | |||||
DATA FIELDS | |||||
det_labels: tensor([0, 1, 2, 3]) | |||||
) at 0x7f0ec5eadc70> | |||||
) at 0x7f0fea49e130> | |||||
>>> # inheritance | |||||
>>> class DetDataSample(BaseDataElement): | |||||
... @property | |||||
... def proposals(self): | |||||
... return self._proposals | |||||
... @proposals.setter | |||||
... def proposals(self, value): | |||||
... self.set_field(value, '_proposals', dtype=BaseDataElement) | |||||
... @proposals.deleter | |||||
... def proposals(self): | |||||
... del self._proposals | |||||
... @property | |||||
... def gt_instances(self): | |||||
... return self._gt_instances | |||||
... @gt_instances.setter | |||||
... def gt_instances(self, value): | |||||
... self.set_field(value, '_gt_instances', | |||||
... dtype=BaseDataElement) | |||||
... @gt_instances.deleter | |||||
... def gt_instances(self): | |||||
... del self._gt_instances | |||||
... @property | |||||
... def pred_instances(self): | |||||
... return self._pred_instances | |||||
... @pred_instances.setter | |||||
... def pred_instances(self, value): | |||||
... self.set_field(value, '_pred_instances', | |||||
... dtype=BaseDataElement) | |||||
... @pred_instances.deleter | |||||
... def pred_instances(self): | |||||
... del self._pred_instances | |||||
>>> det_sample = DetDataSample() | |||||
>>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) | |||||
>>> det_sample.proposals = proposals | |||||
>>> assert 'proposals' in det_sample | |||||
>>> assert det_sample.proposals == proposals | |||||
>>> del det_sample.proposals | |||||
>>> assert 'proposals' not in det_sample | |||||
>>> with self.assertRaises(AssertionError): | |||||
... det_sample.proposals = torch.rand((5, 4)) | |||||
""" | |||||
def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: | |||||
self._metainfo_fields: set = set() | |||||
self._data_fields: set = set() | |||||
if metainfo is not None: | |||||
self.set_metainfo(metainfo=metainfo) | |||||
if kwargs: | |||||
self.set_data(kwargs) | |||||
def set_metainfo(self, metainfo: dict) -> None: | |||||
"""Set or change key-value pairs in ``metainfo_field`` by parameter | |||||
``metainfo``. | |||||
Args: | |||||
metainfo (dict): A dict contains the meta information | |||||
of image, such as ``img_shape``, ``scale_factor``, etc. | |||||
""" | |||||
assert isinstance( | |||||
metainfo, dict | |||||
), f"metainfo should be a ``dict`` but got {type(metainfo)}" | |||||
meta = copy.deepcopy(metainfo) | |||||
for k, v in meta.items(): | |||||
self.set_field(name=k, value=v, field_type="metainfo", dtype=None) | |||||
def set_data(self, data: dict) -> None: | |||||
"""Set or change key-value pairs in ``data_field`` by parameter | |||||
``data``. | |||||
Args: | |||||
data (dict): A dict contains annotations of image or | |||||
model predictions. | |||||
""" | |||||
assert isinstance(data, dict), f"data should be a `dict` but got {data}" | |||||
for k, v in data.items(): | |||||
# Use `setattr()` rather than `self.set_field` to allow `set_data` | |||||
# to set property method. | |||||
setattr(self, k, v) | |||||
def update(self, instance: "BaseDataElement") -> None: | |||||
"""The update() method updates the BaseDataElement with the elements | |||||
from another BaseDataElement object. | |||||
Args: | |||||
instance (BaseDataElement): Another BaseDataElement object for | |||||
update the current object. | |||||
""" | |||||
assert isinstance( | |||||
instance, BaseDataElement | |||||
), f"instance should be a `BaseDataElement` but got {type(instance)}" | |||||
self.set_metainfo(dict(instance.metainfo_items())) | |||||
self.set_data(dict(instance.items())) | |||||
def new(self, *, metainfo: Optional[dict] = None, **kwargs) -> "BaseDataElement": | |||||
"""Return a new data element with same type. If ``metainfo`` and | |||||
``data`` are None, the new data element will have same metainfo and | |||||
data. If metainfo or data is not None, the new result will overwrite it | |||||
with the input value. | |||||
Args: | |||||
metainfo (dict, optional): A dict contains the meta information | |||||
of image, such as ``img_shape``, ``scale_factor``, etc. | |||||
Defaults to None. | |||||
kwargs (dict): A dict contains annotations of image or | |||||
model predictions. | |||||
Returns: | |||||
BaseDataElement: A new data element with same type. | |||||
""" | |||||
new_data = self.__class__() | |||||
if metainfo is not None: | |||||
new_data.set_metainfo(metainfo) | |||||
else: | |||||
new_data.set_metainfo(dict(self.metainfo_items())) | |||||
if kwargs: | |||||
new_data.set_data(kwargs) | |||||
else: | |||||
new_data.set_data(dict(self.items())) | |||||
return new_data | |||||
def clone(self): | |||||
"""Deep copy the current data element. | |||||
Returns: | |||||
BaseDataElement: The copy of current data element. | |||||
""" | |||||
clone_data = self.__class__() | |||||
clone_data.set_metainfo(dict(self.metainfo_items())) | |||||
clone_data.set_data(dict(self.items())) | |||||
return clone_data | |||||
def keys(self) -> list: | |||||
""" | |||||
Returns: | |||||
list: Contains all keys in data_fields. | |||||
""" | |||||
# We assume that the name of the attribute related to property is | |||||
# '_' + the name of the property. We use this rule to filter out | |||||
# private keys. | |||||
# TODO: Use a more robust way to solve this problem | |||||
private_keys = { | |||||
"_" + key | |||||
for key in self._data_fields | |||||
if isinstance(getattr(type(self), key, None), property) | |||||
} | |||||
return list(self._data_fields - private_keys) | |||||
def metainfo_keys(self) -> list: | |||||
""" | |||||
Returns: | |||||
list: Contains all keys in metainfo_fields. | |||||
""" | |||||
return list(self._metainfo_fields) | |||||
def values(self) -> list: | |||||
""" | |||||
Returns: | |||||
list: Contains all values in data. | |||||
""" | |||||
return [getattr(self, k) for k in self.keys()] | |||||
def metainfo_values(self) -> list: | |||||
""" | |||||
Returns: | |||||
list: Contains all values in metainfo. | |||||
""" | |||||
return [getattr(self, k) for k in self.metainfo_keys()] | |||||
def all_keys(self) -> list: | |||||
""" | |||||
Returns: | |||||
list: Contains all keys in metainfo and data. | |||||
""" | |||||
return self.metainfo_keys() + self.keys() | |||||
def all_values(self) -> list: | |||||
""" | |||||
Returns: | |||||
list: Contains all values in metainfo and data. | |||||
""" | |||||
return self.metainfo_values() + self.values() | |||||
def all_items(self) -> Iterator[Tuple[str, Any]]: | |||||
""" | |||||
Returns: | |||||
iterator: An iterator object whose element is (key, value) tuple | |||||
pairs for ``metainfo`` and ``data``. | |||||
""" | |||||
for k in self.all_keys(): | |||||
yield (k, getattr(self, k)) | |||||
def items(self) -> Iterator[Tuple[str, Any]]: | |||||
""" | |||||
Returns: | |||||
iterator: An iterator object whose element is (key, value) tuple | |||||
pairs for ``data``. | |||||
""" | |||||
for k in self.keys(): | |||||
yield (k, getattr(self, k)) | |||||
def metainfo_items(self) -> Iterator[Tuple[str, Any]]: | |||||
""" | |||||
Returns: | |||||
iterator: An iterator object whose element is (key, value) tuple | |||||
pairs for ``metainfo``. | |||||
""" | |||||
for k in self.metainfo_keys(): | |||||
yield (k, getattr(self, k)) | |||||
@property | |||||
def metainfo(self) -> dict: | |||||
"""dict: A dict contains metainfo of current data element.""" | |||||
return dict(self.metainfo_items()) | |||||
def __setattr__(self, name: str, value: Any): | |||||
"""setattr is only used to set data.""" | |||||
if name in ("_metainfo_fields", "_data_fields"): | |||||
if not hasattr(self, name): | |||||
super().__setattr__(name, value) | |||||
else: | |||||
raise AttributeError( | |||||
f"{name} has been used as a " | |||||
"private attribute, which is immutable." | |||||
) | |||||
else: | |||||
self.set_field(name=name, value=value, field_type="data", dtype=None) | |||||
def __delattr__(self, item: str): | |||||
"""Delete the item in dataelement. | |||||
Args: | |||||
item (str): The key to delete. | |||||
""" | |||||
if item in ("_metainfo_fields", "_data_fields"): | |||||
raise AttributeError( | |||||
f"{item} has been used as a " "private attribute, which is immutable." | |||||
) | |||||
super().__delattr__(item) | |||||
if item in self._metainfo_fields: | |||||
self._metainfo_fields.remove(item) | |||||
elif item in self._data_fields: | |||||
self._data_fields.remove(item) | |||||
# dict-like methods | |||||
__delitem__ = __delattr__ | |||||
def get(self, key, default=None) -> Any: | |||||
"""Get property in data and metainfo as the same as python.""" | |||||
# Use `getattr()` rather than `self.__dict__.get()` to allow getting | |||||
# properties. | |||||
return getattr(self, key, default) | |||||
def pop(self, *args) -> Any: | |||||
"""Pop property in data and metainfo as the same as python.""" | |||||
assert len(args) < 3, "``pop`` get more than 2 arguments" | |||||
name = args[0] | |||||
if name in self._metainfo_fields: | |||||
self._metainfo_fields.remove(args[0]) | |||||
return self.__dict__.pop(*args) | |||||
elif name in self._data_fields: | |||||
self._data_fields.remove(args[0]) | |||||
return self.__dict__.pop(*args) | |||||
# with default value | |||||
elif len(args) == 2: | |||||
return args[1] | |||||
else: | |||||
# don't just use 'self.__dict__.pop(*args)' for only popping key in | |||||
# metainfo or data | |||||
raise KeyError(f"{args[0]} is not contained in metainfo or data") | |||||
def __contains__(self, item: str) -> bool: | |||||
"""Whether the item is in dataelement. | |||||
Args: | |||||
item (str): The key to inquire. | |||||
""" | |||||
return item in self._data_fields or item in self._metainfo_fields | |||||
def set_field( | |||||
self, | |||||
value: Any, | |||||
name: str, | |||||
dtype: Optional[Union[Type, Tuple[Type, ...]]] = None, | |||||
field_type: str = "data", | |||||
) -> None: | |||||
"""Special method for set union field, used as property.setter | |||||
functions.""" | |||||
assert field_type in ["metainfo", "data"] | |||||
if dtype is not None: | |||||
assert isinstance( | |||||
value, dtype | |||||
), f"{value} should be a {dtype} but got {type(value)}" | |||||
if field_type == "metainfo": | |||||
if name in self._data_fields: | |||||
raise AttributeError( | |||||
f"Cannot set {name} to be a field of metainfo " | |||||
f"because {name} is already a data field" | |||||
) | |||||
self._metainfo_fields.add(name) | |||||
else: | |||||
if name in self._metainfo_fields: | |||||
raise AttributeError( | |||||
f"Cannot set {name} to be a field of data " | |||||
f"because {name} is already a metainfo field" | |||||
) | |||||
self._data_fields.add(name) | |||||
super().__setattr__(name, value) | |||||
# Tensor-like methods | |||||
def to(self, *args, **kwargs) -> "BaseDataElement": | |||||
"""Apply same name function to all tensors in data_fields.""" | |||||
new_data = self.new() | |||||
for k, v in self.items(): | |||||
if hasattr(v, "to"): | |||||
v = v.to(*args, **kwargs) | |||||
data = {k: v} | |||||
new_data.set_data(data) | |||||
return new_data | |||||
# Tensor-like methods | |||||
def cpu(self) -> "BaseDataElement": | |||||
"""Convert all tensors to CPU in data.""" | |||||
new_data = self.new() | |||||
for k, v in self.items(): | |||||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||||
v = v.cpu() | |||||
data = {k: v} | |||||
new_data.set_data(data) | |||||
return new_data | |||||
# Tensor-like methods | |||||
def cuda(self) -> "BaseDataElement": | |||||
"""Convert all tensors to GPU in data.""" | |||||
new_data = self.new() | |||||
for k, v in self.items(): | |||||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||||
v = v.cuda() | |||||
data = {k: v} | |||||
new_data.set_data(data) | |||||
return new_data | |||||
# Tensor-like methods | |||||
def npu(self) -> "BaseDataElement": | |||||
"""Convert all tensors to NPU in data.""" | |||||
new_data = self.new() | |||||
for k, v in self.items(): | |||||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||||
v = v.npu() | |||||
data = {k: v} | |||||
new_data.set_data(data) | |||||
return new_data | |||||
def mlu(self) -> "BaseDataElement": | |||||
"""Convert all tensors to MLU in data.""" | |||||
new_data = self.new() | |||||
for k, v in self.items(): | |||||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||||
v = v.mlu() | |||||
data = {k: v} | |||||
new_data.set_data(data) | |||||
return new_data | |||||
# Tensor-like methods | |||||
def detach(self) -> "BaseDataElement": | |||||
"""Detach all tensors in data.""" | |||||
new_data = self.new() | |||||
for k, v in self.items(): | |||||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||||
v = v.detach() | |||||
data = {k: v} | |||||
new_data.set_data(data) | |||||
return new_data | |||||
# Tensor-like methods | |||||
def numpy(self) -> "BaseDataElement": | |||||
"""Convert all tensors to np.ndarray in data.""" | |||||
new_data = self.new() | |||||
for k, v in self.items(): | |||||
if isinstance(v, (torch.Tensor, BaseDataElement)): | |||||
v = v.detach().cpu().numpy() | |||||
data = {k: v} | |||||
new_data.set_data(data) | |||||
return new_data | |||||
def to_tensor(self) -> "BaseDataElement": | |||||
"""Convert all np.ndarray to tensor in data.""" | |||||
new_data = self.new() | |||||
for k, v in self.items(): | |||||
data = {} | |||||
if isinstance(v, np.ndarray): | |||||
v = torch.from_numpy(v) | |||||
data[k] = v | |||||
elif isinstance(v, BaseDataElement): | |||||
v = v.to_tensor() | |||||
data[k] = v | |||||
new_data.set_data(data) | |||||
return new_data | |||||
def to_dict(self) -> dict: | |||||
"""Convert BaseDataElement to dict.""" | |||||
return { | |||||
k: v.to_dict() if isinstance(v, BaseDataElement) else v | |||||
for k, v in self.all_items() | |||||
} | |||||
def __repr__(self) -> str: | |||||
"""Represent the object.""" | |||||
def _addindent(s_: str, num_spaces: int) -> str: | |||||
"""This func is modified from `pytorch` https://github.com/pytorch/ | |||||
pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu | |||||
les/module.py#L29. | |||||
Args: | |||||
s_ (str): The string to add spaces. | |||||
num_spaces (int): The num of space to add. | |||||
Returns: | |||||
str: The string after add indent. | |||||
""" | |||||
s = s_.split("\n") | |||||
# don't do anything for single-line stuff | |||||
if len(s) == 1: | |||||
return s_ | |||||
first = s.pop(0) | |||||
s = [(num_spaces * " ") + line for line in s] | |||||
s = "\n".join(s) # type: ignore | |||||
s = first + "\n" + s # type: ignore | |||||
return s # type: ignore | |||||
def dump(obj: Any) -> str: | |||||
"""Represent the object. | |||||
Args: | |||||
obj (Any): The obj to represent. | |||||
Returns: | |||||
str: The represented str. | |||||
""" | |||||
_repr = "" | |||||
if isinstance(obj, dict): | |||||
for k, v in obj.items(): | |||||
_repr += f"\n{k}: {_addindent(dump(v), 4)}" | |||||
elif isinstance(obj, BaseDataElement): | |||||
_repr += "\n\n META INFORMATION" | |||||
metainfo_items = dict(obj.metainfo_items()) | |||||
_repr += _addindent(dump(metainfo_items), 4) | |||||
_repr += "\n\n DATA FIELDS" | |||||
items = dict(obj.items()) | |||||
_repr += _addindent(dump(items), 4) | |||||
classname = obj.__class__.__name__ | |||||
_repr = f"<{classname}({_repr}\n) at {hex(id(obj))}>" | |||||
else: | |||||
_repr += repr(obj) | |||||
return _repr | |||||
return dump(self) |
@@ -0,0 +1,321 @@ | |||||
# Copyright (c) OpenMMLab. All rights reserved. | |||||
import itertools | |||||
from collections.abc import Sized | |||||
from typing import Any, List, Union | |||||
import numpy as np | |||||
import torch | |||||
from ..utils import flatten as flatten_list | |||||
from ..utils import to_hashable | |||||
from .base_data_element import BaseDataElement | |||||
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] | |||||
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor] | |||||
IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray] | |||||
# Modified from | |||||
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa | |||||
class ListData(BaseDataElement): | |||||
"""Data structure for instance-level annotations or predictions. | |||||
Subclass of :class:`BaseDataElement`. All values in `data_fields`
should have the same length. This design refers to
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
ListData also supports extra functions: ``index``, ``slice`` and ``cat`` for data fields. The type of a value
in a data field can be a base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
or a customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.
Examples: | |||||
>>> # custom data structure | |||||
>>> class TmpObject: | |||||
... def __init__(self, tmp) -> None: | |||||
... assert isinstance(tmp, list) | |||||
... self.tmp = tmp | |||||
... def __len__(self): | |||||
... return len(self.tmp) | |||||
... def __getitem__(self, item): | |||||
... if isinstance(item, int): | |||||
... if item >= len(self) or item < -len(self): # type:ignore | |||||
... raise IndexError(f'Index {item} out of range!') | |||||
... else: | |||||
... # keep the dimension | |||||
... item = slice(item, None, len(self)) | |||||
... return TmpObject(self.tmp[item]) | |||||
... @staticmethod | |||||
... def cat(tmp_objs): | |||||
... assert all(isinstance(results, TmpObject) for results in tmp_objs) | |||||
... if len(tmp_objs) == 1: | |||||
... return tmp_objs[0] | |||||
... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs] | |||||
... tmp_list = list(itertools.chain(*tmp_list)) | |||||
... new_data = TmpObject(tmp_list) | |||||
... return new_data | |||||
... def __repr__(self): | |||||
... return str(self.tmp) | |||||
>>> from mmengine.structures import ListData | |||||
>>> import numpy as np | |||||
>>> import torch | |||||
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) | |||||
>>> instance_data = ListData(metainfo=img_meta) | |||||
>>> 'img_shape' in instance_data | |||||
True | |||||
>>> instance_data.det_labels = torch.LongTensor([2, 3]) | |||||
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) | |||||
>>> instance_data.bboxes = torch.rand((2, 4)) | |||||
>>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]]) | |||||
>>> len(instance_data) | |||||
2 | |||||
>>> print(instance_data) | |||||
<ListData( | |||||
META INFORMATION | |||||
img_shape: (800, 1196, 3) | |||||
pad_shape: (800, 1216, 3) | |||||
DATA FIELDS | |||||
det_labels: tensor([2, 3]) | |||||
det_scores: tensor([0.8000, 0.7000]) | |||||
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], | |||||
[0.8101, 0.3105, 0.5123, 0.6263]]) | |||||
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]] | |||||
) at 0x7fb492de6280> | |||||
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices] | |||||
>>> sorted_results.det_scores | |||||
tensor([0.7000, 0.8000]) | |||||
>>> print(instance_data[instance_data.det_scores > 0.75]) | |||||
<ListData( | |||||
META INFORMATION | |||||
img_shape: (800, 1196, 3) | |||||
pad_shape: (800, 1216, 3) | |||||
DATA FIELDS | |||||
det_labels: tensor([2]) | |||||
det_scores: tensor([0.8000]) | |||||
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]]) | |||||
polygons: [[1, 2, 3, 4]] | |||||
) at 0x7f64ecf0ec40> | |||||
>>> print(instance_data[instance_data.det_scores > 1]) | |||||
<ListData( | |||||
META INFORMATION | |||||
img_shape: (800, 1196, 3) | |||||
pad_shape: (800, 1216, 3) | |||||
DATA FIELDS | |||||
det_labels: tensor([], dtype=torch.int64) | |||||
det_scores: tensor([]) | |||||
bboxes: tensor([], size=(0, 4)) | |||||
polygons: [] | |||||
) at 0x7f660a6a7f70> | |||||
>>> print(instance_data.cat([instance_data, instance_data])) | |||||
<ListData( | |||||
META INFORMATION | |||||
img_shape: (800, 1196, 3) | |||||
pad_shape: (800, 1216, 3) | |||||
DATA FIELDS | |||||
det_labels: tensor([2, 3, 2, 3]) | |||||
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000]) | |||||
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], | |||||
[0.8101, 0.3105, 0.5123, 0.6263], | |||||
[0.4997, 0.7707, 0.0595, 0.4188], | |||||
[0.8101, 0.3105, 0.5123, 0.6263]]) | |||||
polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]] | |||||
) at 0x7f203542feb0> | |||||
""" | |||||
def __setattr__(self, name: str, value: list): | |||||
"""setattr is only used to set data. | |||||
The value must be a ``list`` and have the same length as the
``ListData``.
""" | |||||
if name in ("_metainfo_fields", "_data_fields"): | |||||
if not hasattr(self, name): | |||||
super().__setattr__(name, value) | |||||
else: | |||||
raise AttributeError( | |||||
f"{name} has been used as a " | |||||
"private attribute, which is immutable." | |||||
) | |||||
else: | |||||
assert isinstance(value, list), "value must be of type `list`" | |||||
if len(self) > 0: | |||||
assert len(value) == len(self), ( | |||||
"The length of " | |||||
f"values {len(value)} is " | |||||
"not consistent with " | |||||
"the length of this " | |||||
":obj:`ListData` " | |||||
f"{len(self)}" | |||||
) | |||||
super().__setattr__(name, value) | |||||
__setitem__ = __setattr__ | |||||
def __getitem__(self, item: IndexType) -> "ListData": | |||||
""" | |||||
Args: | |||||
item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`, | |||||
:obj:`torch.LongTensor`, :obj:`torch.BoolTensor`): | |||||
Get the corresponding values according to item. | |||||
Returns: | |||||
:obj:`ListData`: Corresponding values. | |||||
""" | |||||
assert isinstance(item, IndexType.__args__) | |||||
if isinstance(item, list): | |||||
item = np.array(item) | |||||
if isinstance(item, np.ndarray): | |||||
# The default int type of numpy is platform dependent, int32 for | |||||
# windows and int64 for linux. `torch.Tensor` requires the index | |||||
# should be int64, therefore we simply convert it to int64 here. | |||||
# More details in https://github.com/numpy/numpy/issues/9464 | |||||
item = item.astype(np.int64) if item.dtype == np.int32 else item | |||||
item = torch.from_numpy(item) | |||||
if isinstance(item, str): | |||||
return getattr(self, item) | |||||
if isinstance(item, int): | |||||
if item >= len(self) or item < -len(self): # type:ignore | |||||
raise IndexError(f"Index {item} out of range!") | |||||
else: | |||||
# keep the dimension | |||||
item = slice(item, None, len(self)) | |||||
new_data = self.__class__(metainfo=self.metainfo) | |||||
if isinstance(item, torch.Tensor): | |||||
assert item.dim() == 1, ( | |||||
"Only support to get the" " values along the first dimension." | |||||
) | |||||
if isinstance(item, BoolTypeTensor.__args__): | |||||
assert len(item) == len(self), ( | |||||
"The shape of the " | |||||
"input(BoolTensor) " | |||||
f"{len(item)} " | |||||
"does not match the shape " | |||||
"of the indexed tensor " | |||||
"in results_field " | |||||
f"{len(self)} at " | |||||
"first dimension." | |||||
) | |||||
for k, v in self.items(): | |||||
if isinstance(v, torch.Tensor): | |||||
new_data[k] = v[item] | |||||
elif isinstance(v, np.ndarray): | |||||
new_data[k] = v[item.cpu().numpy()] | |||||
elif isinstance(v, (str, list, tuple)) or ( | |||||
hasattr(v, "__getitem__") and hasattr(v, "cat") | |||||
): | |||||
# convert to indexes from BoolTensor | |||||
if isinstance(item, BoolTypeTensor.__args__): | |||||
indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist() | |||||
else: | |||||
indexes = item.cpu().numpy().tolist() | |||||
slice_list = [] | |||||
if indexes: | |||||
for index in indexes: | |||||
slice_list.append(slice(index, None, len(v))) | |||||
else: | |||||
slice_list.append(slice(None, 0, None)) | |||||
r_list = [v[s] for s in slice_list] | |||||
if isinstance(v, (str, list, tuple)): | |||||
new_value = r_list[0] | |||||
for r in r_list[1:]: | |||||
new_value = new_value + r | |||||
else: | |||||
new_value = v.cat(r_list) | |||||
new_data[k] = new_value | |||||
else: | |||||
raise ValueError( | |||||
f"The type of `{k}` is `{type(v)}`, which has no " | |||||
"attribute of `cat`, so it does not " | |||||
"support slice with `bool`" | |||||
) | |||||
else: | |||||
# item is a slice | |||||
for k, v in self.items(): | |||||
new_data[k] = v[item] | |||||
return new_data # type:ignore | |||||
@staticmethod | |||||
def cat(instances_list: List["ListData"]) -> "ListData": | |||||
"""Concat the instances of all :obj:`ListData` in the list. | |||||
Note: To ensure that cat returns as expected, make sure that | |||||
all elements in the list must have exactly the same keys. | |||||
Args: | |||||
instances_list (list[:obj:`ListData`]): A list | |||||
of :obj:`ListData`. | |||||
Returns: | |||||
:obj:`ListData` | |||||
""" | |||||
assert all(isinstance(results, ListData) for results in instances_list) | |||||
assert len(instances_list) > 0 | |||||
if len(instances_list) == 1: | |||||
return instances_list[0] | |||||
# metainfo and data_fields must be exactly the | |||||
# same for each element to avoid exceptions. | |||||
field_keys_list = [instances.all_keys() for instances in instances_list] | |||||
assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len( | |||||
set(itertools.chain(*field_keys_list)) | |||||
) == len(field_keys_list[0]), ( | |||||
"There are different keys in " | |||||
"`instances_list`, which may " | |||||
"cause the cat operation " | |||||
"to fail. Please make sure all " | |||||
"elements in `instances_list` " | |||||
"have the exact same key." | |||||
) | |||||
new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo) | |||||
for k in instances_list[0].keys(): | |||||
values = [results[k] for results in instances_list] | |||||
v0 = values[0] | |||||
if isinstance(v0, torch.Tensor): | |||||
new_values = torch.cat(values, dim=0) | |||||
elif isinstance(v0, np.ndarray): | |||||
new_values = np.concatenate(values, axis=0) | |||||
elif isinstance(v0, (str, list, tuple)): | |||||
new_values = v0[:] | |||||
for v in values[1:]: | |||||
new_values += v | |||||
elif hasattr(v0, "cat"): | |||||
new_values = v0.cat(values) | |||||
else: | |||||
raise ValueError( | |||||
f"The type of `{k}` is `{type(v0)}` which has no " | |||||
"attribute of `cat`" | |||||
) | |||||
new_data[k] = new_values | |||||
return new_data # type:ignore | |||||
def flatten(self, item: IndexType) -> List: | |||||
"""Flatten self[item]. | |||||
Returns: | |||||
list: Flattened data fields. | |||||
""" | |||||
return flatten_list(self[item]) | |||||
def elements_num(self, item: IndexType) -> int: | |||||
"""int: The number of elements in self[item].""" | |||||
return len(self.flatten(item)) | |||||
def to_tuple(self, item: IndexType) -> tuple: | |||||
"""tuple: The data fields in self[item] converted to tuple.""" | |||||
return to_hashable(self[item]) | |||||
def __len__(self) -> int: | |||||
"""int: The length of ListData.""" | |||||
if len(self._data_fields) > 0: | |||||
one_element = next(iter(self._data_fields)) | |||||
return len(getattr(self, one_element)) | |||||
# return len(self.values()[0]) | |||||
else: | |||||
return 0 |
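In this repository ListData fields are plain Python lists with one entry per sample, as enforced by __setattr__ above. A brief sketch of the list-based behaviour with illustrative values:

batch = ListData()
batch.pred_pseudo_label = [[1, 1], [1, 2], [3, 4]]
batch.Y = [8, 8, 7]
print(len(batch))                      # 3
print(batch[1].Y)                      # [8]  (integer indexing keeps the batch dimension)
print(ListData.cat([batch, batch]).Y)  # [8, 8, 7, 8, 8, 7]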
@@ -1,2 +1,3 @@ | |||||
from .cache import Cache | |||||
from .logger import ABLLogger, print_log | from .logger import ABLLogger, print_log | ||||
from .utils import * | |||||
from .utils import * |
@@ -0,0 +1,112 @@ | |||||
import pickle | |||||
from os import PathLike | |||||
from pathlib import Path | |||||
from typing import Callable, Generic, Hashable, TypeVar, Union | |||||
from .logger import print_log | |||||
K = TypeVar("K") | |||||
T = TypeVar("T") | |||||
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields | |||||
class Cache(Generic[K, T]): | |||||
def __init__( | |||||
self, | |||||
func: Callable[[K], T], | |||||
cache: bool, | |||||
cache_file: Union[None, str, PathLike], | |||||
key_func: Callable[[K], Hashable] = lambda x: x, | |||||
max_size: int = 4096, | |||||
): | |||||
"""Create cache | |||||
:param func: Function this cache evaluates | |||||
:param cache: If true, do in memory caching. | |||||
:param cache_file: If not None, preload the in-memory cache from a pickled dict at the provided path.
:param key_func: Convert the key into a hashable object if needed
:param max_size: Maximum number of in-memory entries kept before the oldest are evicted
""" | |||||
self.func = func | |||||
self.key_func = key_func | |||||
self.cache = cache | |||||
if cache is True or cache_file is not None: | |||||
print_log("Caching is activated", logger="current") | |||||
self._init_cache(cache_file, max_size) | |||||
self.first = self.get_from_dict | |||||
else: | |||||
self.first = self.func | |||||
def __getitem__(self, item: K, *args) -> T: | |||||
return self.first(item, *args) | |||||
def invalidate(self): | |||||
"""Invalidate entire cache.""" | |||||
self.cache_dict.clear() | |||||
if self.cache_file: | |||||
for p in self.cache_root.iterdir(): | |||||
p.unlink() | |||||
def _init_cache(self, cache_file, max_size): | |||||
self.cache = True | |||||
self.cache_dict = dict() | |||||
self.hits, self.misses, self.maxsize = 0, 0, max_size | |||||
self.full = False | |||||
self.root = [] # root of the circular doubly linked list | |||||
self.root[:] = [self.root, self.root, None, None] | |||||
if cache_file is not None: | |||||
with open(cache_file, "rb") as f: | |||||
cache_dict_from_file = pickle.load(f) | |||||
self.maxsize += len(cache_dict_from_file) | |||||
print_log( | |||||
f"Max size of the cache has been enlarged to {self.maxsize}.", logger="current" | |||||
) | |||||
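            # Pre-populate the linked list and lookup dict with entries loaded from the cache file.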
for cache_key, result in cache_dict_from_file.items(): | |||||
last = self.root[PREV] | |||||
link = [last, self.root, cache_key, result] | |||||
last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link | |||||
def get(self, item: K, *args) -> T: | |||||
return self.first(item, *args) | |||||
def get_from_dict(self, item: K, *args) -> T: | |||||
"""Implements dict based cache.""" | |||||
cache_key = (self.key_func(item), *args) | |||||
link = self.cache_dict.get(cache_key) | |||||
if link is not None: | |||||
# Move the link to the front of the circular queue | |||||
link_prev, link_next, _key, result = link | |||||
link_prev[NEXT] = link_next | |||||
link_next[PREV] = link_prev | |||||
last = self.root[PREV] | |||||
last[NEXT] = self.root[PREV] = link | |||||
link[PREV] = last | |||||
link[NEXT] = self.root | |||||
self.hits += 1 | |||||
return result | |||||
self.misses += 1 | |||||
result = self.func(item, *args) | |||||
if self.full: | |||||
# Use the old root to store the new key and result. | |||||
oldroot = self.root | |||||
oldroot[KEY] = cache_key | |||||
oldroot[RESULT] = result | |||||
# Empty the oldest link and make it the new root. | |||||
self.root = oldroot[NEXT] | |||||
oldkey = self.root[KEY] | |||||
oldresult = self.root[RESULT] | |||||
self.root[KEY] = self.root[RESULT] = None | |||||
# Now update the cache dictionary. | |||||
del self.cache_dict[oldkey] | |||||
self.cache_dict[cache_key] = oldroot | |||||
else: | |||||
# Put result in a new link at the front of the queue. | |||||
last = self.root[PREV] | |||||
link = [last, self.root, cache_key, result] | |||||
last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link | |||||
if isinstance(self.maxsize, int): | |||||
self.full = len(self.cache_dict) >= self.maxsize | |||||
return result |
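A quick, hedged usage sketch of the cache above, constructed directly as its __init__ expects; the wrapped function and key function are illustrative:

from abl.utils import Cache

def slow_square(x):
    # Stand-in for an expensive call, e.g. abduction over a knowledge base.
    return x * x

cache = Cache(func=slow_square, cache=True, cache_file=None, key_func=lambda x: x, max_size=128)
print(cache[4])                   # miss: computes and stores 16
print(cache[4])                   # hit: served from the in-memory LRU structure
print(cache.hits, cache.misses)   # 1 1
cache.invalidate()                # clears the in-memory entries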
@@ -1,6 +1,7 @@ | |||||
import numpy as np | |||||
from itertools import chain | from itertools import chain | ||||
import numpy as np | |||||
def flatten(nested_list): | def flatten(nested_list): | ||||
""" | """ | ||||
@@ -1,8 +1,8 @@ | |||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
import sys | |||||
import os | import os | ||||
import re | import re | ||||
import sys | |||||
if not 'READTHEDOCS' in os.environ: | if not 'READTHEDOCS' in os.environ: | ||||
sys.path.insert(0, os.path.abspath('..')) | sys.path.insert(0, os.path.abspath('..')) | ||||
@@ -11,7 +11,6 @@ sys.path.append(os.path.abspath('./ABL/')) | |||||
# from sphinx.locale import _ | # from sphinx.locale import _ | ||||
from sphinx_rtd_theme import __version__ | from sphinx_rtd_theme import __version__ | ||||
project = u'ABL' | project = u'ABL' | ||||
slug = re.sub(r'\W+', '-', project.lower()) | slug = re.sub(r'\W+', '-', project.lower()) | ||||
author = u'Yu-Xuan Huang, Wen-Chao Hu, En-Hao Gao' | author = u'Yu-Xuan Huang, Wen-Chao Hu, En-Hao Gao' | ||||
@@ -1,11 +1,12 @@ | |||||
import os | import os | ||||
import cv2 | |||||
import torch | |||||
import torchvision | |||||
import pickle | import pickle | ||||
import numpy as np | |||||
import random | import random | ||||
from collections import defaultdict | from collections import defaultdict | ||||
import cv2 | |||||
import numpy as np | |||||
import torch | |||||
import torchvision | |||||
from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
from torchvision.transforms import transforms | from torchvision.transforms import transforms | ||||
@@ -1,18 +1,18 @@ | |||||
import os | import os | ||||
from collections import defaultdict | from collections import defaultdict | ||||
import torch | import torch | ||||
from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
from abl.reasoning import ReasonerBase | |||||
from abl.learning import ABLModel, BasicNN | |||||
from abl.bridge import SimpleBridge | from abl.bridge import SimpleBridge | ||||
from abl.evaluation import BaseMetric | |||||
from abl.dataset import BridgeDataset, RegressionDataset | from abl.dataset import BridgeDataset, RegressionDataset | ||||
from abl.evaluation import BaseMetric | |||||
from abl.learning import ABLModel, BasicNN | |||||
from abl.reasoning import ReasonerBase | |||||
from abl.utils import print_log | from abl.utils import print_log | ||||
from examples.hed.utils import gen_mappings, InfiniteSampler | |||||
from examples.models.nn import SymbolNetAutoencoder | |||||
from examples.hed.datasets.get_hed import get_pretrain_data | from examples.hed.datasets.get_hed import get_pretrain_data | ||||
from examples.hed.utils import InfiniteSampler, gen_mappings | |||||
from examples.models.nn import SymbolNetAutoencoder | |||||
class HEDBridge(SimpleBridge): | class HEDBridge(SimpleBridge): | ||||
@@ -12,7 +12,7 @@ | |||||
"\n", | "\n", | ||||
"from abl.reasoning import ReasonerBase, prolog_KB\n", | "from abl.reasoning import ReasonerBase, prolog_KB\n", | ||||
"from abl.learning import BasicNN, ABLModel\n", | "from abl.learning import BasicNN, ABLModel\n", | ||||
"from abl.evaluation import SymbolMetric, ABLMetric\n", | |||||
"from abl.evaluation import SymbolMetric, SemanticsMetric\n", | |||||
"from abl.utils import ABLLogger, reform_idx\n", | "from abl.utils import ABLLogger, reform_idx\n", | ||||
"\n", | "\n", | ||||
"from examples.hed.hed_bridge import HEDBridge\n", | "from examples.hed.hed_bridge import HEDBridge\n", | ||||
@@ -206,7 +206,7 @@ | |||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"# Add metric\n", | "# Add metric\n", | ||||
"metric = [SymbolMetric(prefix=\"hed\"), ABLMetric(prefix=\"hed\")]" | |||||
"metric = [SymbolMetric(prefix=\"hed\"), SemanticsMetric(prefix=\"hed\")]" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -1,6 +1,6 @@ | |||||
import numpy as np | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import numpy as np | |||||
import torch.utils.data.sampler as sampler | import torch.utils.data.sampler as sampler | ||||
@@ -1,10 +1,10 @@ | |||||
import os | |||||
import json | import json | ||||
import os.path as osp | |||||
from PIL import Image | from PIL import Image | ||||
from torchvision.transforms import transforms | from torchvision.transforms import transforms | ||||
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||||
CURRENT_DIR = osp.abspath(osp.dirname(__file__)) | |||||
img_transform = transforms.Compose( | img_transform = transforms.Compose( | ||||
[transforms.ToTensor(), transforms.Normalize((0.5,), (1,))] | [transforms.ToTensor(), transforms.Normalize((0.5,), (1,))] | ||||
@@ -15,7 +15,7 @@ def get_data(file, get_pseudo_label): | |||||
X, Y = [], [] | X, Y = [], [] | ||||
if get_pseudo_label: | if get_pseudo_label: | ||||
Z = [] | Z = [] | ||||
img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/") | |||||
img_dir = osp.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/") | |||||
with open(file) as f: | with open(file) as f: | ||||
data = json.load(f) | data = json.load(f) | ||||
for idx in range(len(data)): | for idx in range(len(data)): | ||||
@@ -40,8 +40,8 @@ def get_data(file, get_pseudo_label): | |||||
def get_hwf(train=True, get_gt_pseudo_label=False): | def get_hwf(train=True, get_gt_pseudo_label=False): | ||||
if train: | if train: | ||||
file = os.path.join(CURRENT_DIR, "data/expr_train.json") | |||||
file = osp.join(CURRENT_DIR, "data/expr_train.json") | |||||
else: | else: | ||||
file = os.path.join(CURRENT_DIR, "data/expr_test.json") | |||||
file = osp.join(CURRENT_DIR, "data/expr_test.json") | |||||
return get_data(file, get_gt_pseudo_label) | return get_data(file, get_gt_pseudo_label) |
@@ -6,19 +6,19 @@ | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"import os.path as osp\n", | |||||
"\n", | |||||
"import torch\n", | "import torch\n", | ||||
"import numpy as np\n", | |||||
"import torch.nn as nn\n", | "import torch.nn as nn\n", | ||||
"import os.path as osp\n", | |||||
"\n", | "\n", | ||||
"from abl.reasoning import ReasonerBase, KBBase\n", | |||||
"from abl.learning import BasicNN, ABLModel\n", | |||||
"from abl.bridge import SimpleBridge\n", | "from abl.bridge import SimpleBridge\n", | ||||
"from abl.evaluation import SymbolMetric, SemanticsMetric\n", | |||||
"from abl.evaluation import SemanticsMetric, SymbolMetric\n", | |||||
"from abl.learning import ABLModel, BasicNN\n", | |||||
"from abl.reasoning import ReasonerBase\n", | |||||
"from abl.utils import ABLLogger, print_log\n", | "from abl.utils import ABLLogger, print_log\n", | ||||
"\n", | |||||
"from examples.models.nn import SymbolNet\n", | |||||
"from datasets.get_hwf import get_hwf" | |||||
"from examples.hwf.datasets.get_hwf import get_hwf\n", | |||||
"from examples.hwf.hwf_kb import HWF_KB\n", | |||||
"from examples.models.nn import SymbolNet" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -50,37 +50,8 @@ | |||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"# Initialize knowledge base and abducer\n", | "# Initialize knowledge base and abducer\n", | ||||
"class HWF_KB(KBBase):\n", | |||||
" def __init__(\n", | |||||
" self, \n", | |||||
" pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n", | |||||
" prebuild_GKB=False,\n", | |||||
" GKB_len_list=[1, 3, 5, 7],\n", | |||||
" max_err=1e-3,\n", | |||||
" use_cache=True\n", | |||||
" ):\n", | |||||
" super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", | |||||
"\n", | |||||
" def _valid_candidate(self, formula):\n", | |||||
" if len(formula) % 2 == 0:\n", | |||||
" return False\n", | |||||
" for i in range(len(formula)):\n", | |||||
" if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:\n", | |||||
" return False\n", | |||||
" if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:\n", | |||||
" return False\n", | |||||
" return True\n", | |||||
"\n", | |||||
" def logic_forward(self, formula):\n", | |||||
" if not self._valid_candidate(formula):\n", | |||||
" return np.inf\n", | |||||
" mapping = {str(i): str(i) for i in range(1, 10)}\n", | |||||
" mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})\n", | |||||
" formula = [mapping[f] for f in formula]\n", | |||||
" return eval(''.join(formula))\n", | |||||
"\n", | |||||
"kb = HWF_KB(prebuild_GKB=True)\n", | |||||
"abducer = ReasonerBase(kb, dist_func='confidence')" | |||||
"kb = HWF_KB()\n", | |||||
"abducer = ReasonerBase(kb, dist_func=\"confidence\")" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -117,10 +88,8 @@ | |||||
" criterion=criterion,\n", | " criterion=criterion,\n", | ||||
" optimizer=optimizer,\n", | " optimizer=optimizer,\n", | ||||
" device=device,\n", | " device=device,\n", | ||||
" save_interval=1,\n", | |||||
" save_dir=weights_dir,\n", | |||||
" batch_size=128,\n", | " batch_size=128,\n", | ||||
" num_epochs=3,\n", | |||||
" num_epochs=1,\n", | |||||
")" | ")" | ||||
] | ] | ||||
}, | }, | ||||
@@ -131,7 +100,7 @@ | |||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"# Initialize ABL model\n", | "# Initialize ABL model\n", | ||||
"# The main function of the ABL model is to serialize data and \n", | |||||
"# The main function of the ABL model is to serialize data and\n", | |||||
"# provide a unified interface for different machine learning models\n", | "# provide a unified interface for different machine learning models\n", | ||||
"model = ABLModel(base_model)" | "model = ABLModel(base_model)" | ||||
] | ] | ||||
@@ -151,7 +120,7 @@ | |||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"# Add metric\n", | "# Add metric\n", | ||||
"metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(prefix=\"hwf\")]" | |||||
"metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(kb=kb, prefix=\"hwf\")]" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -204,7 +173,7 @@ | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"bridge.train(train_data, epochs=3, batch_size=1000)\n", | |||||
"bridge.train(train_data, loops=5, segment_size=1000, save_interval=1, save_dir=weights_dir)\n", | |||||
"bridge.test(test_data)" | "bridge.test(test_data)" | ||||
] | ] | ||||
} | } | ||||
@@ -0,0 +1,129 @@ | |||||
import bisect | |||||
from collections import defaultdict | |||||
from itertools import product | |||||
from multiprocessing import Pool | |||||
from typing import Any, Hashable, List | |||||
import numpy as np | |||||
from abl.reasoning import GroundKB | |||||
from abl.structures import ListData | |||||
from abl.utils import hamming_dist | |||||
class HWF_KB(GroundKB): | |||||
def __init__( | |||||
self, | |||||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "times", "div"], | |||||
GKB_len_list=[1, 3, 5, 7], | |||||
max_err=1e-10, | |||||
): | |||||
self.GKB_len_list = GKB_len_list | |||||
self.max_err = max_err | |||||
self.label2evaluable = {str(i): str(i) for i in range(1, 10)} | |||||
self.label2evaluable.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||||
super().__init__(pseudo_label_list) | |||||
def logic_forward(self, data_sample: ListData): | |||||
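        # Map pseudo-labels (e.g. "times") to Python operators and evaluate the formula;
        # invalid formulas return None so they are skipped during GKB construction.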
if not self._valid_candidate(data_sample): | |||||
return None | |||||
formula = data_sample["pred_pseudo_label"][0] | |||||
formula = [self.label2evaluable[f] for f in formula] | |||||
data_sample["Y"] = [eval("".join(formula))] | |||||
return data_sample["Y"][0] | |||||
def check_equal(self, data_sample: ListData, y: Any): | |||||
if not self._valid_candidate(data_sample): | |||||
return False | |||||
formula = data_sample["pred_pseudo_label"][0] | |||||
formula = [self.label2evaluable[f] for f in formula] | |||||
return abs(eval("".join(formula)) - y) < self.max_err | |||||
def construct_base(self) -> dict: | |||||
X, Y = [], [] | |||||
for length in self.GKB_len_list: | |||||
arg_list = [] | |||||
for pre_x in self.pseudo_label_list: | |||||
post_x_it = product(self.pseudo_label_list, repeat=length - 1) | |||||
arg_list.append((pre_x, post_x_it)) | |||||
with Pool(processes=len(arg_list)) as pool: | |||||
ret_list = pool.map(self._get_XY_list, arg_list) | |||||
for XY_list in ret_list: | |||||
if len(XY_list) == 0: | |||||
continue | |||||
part_X, part_Y = zip(*XY_list) | |||||
X.extend(part_X) | |||||
Y.extend(part_Y) | |||||
if Y and isinstance(Y[0], (int, float)): | |||||
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) | |||||
GKB = {} | |||||
for x, y in zip(X, Y): | |||||
GKB.setdefault(len(x), defaultdict(list))[y].append(x) | |||||
return GKB | |||||
@staticmethod | |||||
def get_key(data_sample: ListData) -> Hashable: | |||||
return (data_sample["symbol_num"], data_sample["Y"][0]) | |||||
def key2candidates(self, key: Hashable) -> List[List[Any]]: | |||||
equation_len, y = key | |||||
if self.max_err == 0: | |||||
return self.GKB[equation_len][y] | |||||
else: | |||||
potential_candidates = self.GKB[equation_len] | |||||
key_list = list(potential_candidates.keys()) | |||||
key_idx = bisect.bisect_left(key_list, y) | |||||
all_candidates = [] | |||||
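            # key_list is sorted by value (construct_base sorts numeric results), so scan
            # outward from the bisect insertion point and stop once |k - y| exceeds max_err.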
for idx in range(key_idx - 1, -1, -1): | |||||
k = key_list[idx] | |||||
if abs(k - y) <= self.max_err: | |||||
all_candidates.extend(potential_candidates[k]) | |||||
else: | |||||
break | |||||
for idx in range(key_idx, len(key_list)): | |||||
k = key_list[idx] | |||||
if abs(k - y) <= self.max_err: | |||||
all_candidates.extend(potential_candidates[k]) | |||||
else: | |||||
break | |||||
return all_candidates | |||||
def filter_candidates( | |||||
self, | |||||
data_sample: ListData, | |||||
candidates: List[List[Any]], | |||||
max_revision_num: int, | |||||
require_more_revision: int = 0, | |||||
) -> List[List[Any]]: | |||||
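        # Keep only candidates whose Hamming distance to the current prediction is within
        # the revision budget: at most max_revision_num, and at most require_more_revision
        # changes beyond the minimal number needed.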
cost_list = hamming_dist(data_sample["pred_pseudo_label"][0], candidates) | |||||
min_revision_num = np.min(cost_list) | |||||
revision_num = min(max_revision_num, min_revision_num + require_more_revision) | |||||
idxs = np.where(cost_list <= revision_num)[0] | |||||
filtered_candidates = [candidates[idx] for idx in idxs] | |||||
return filtered_candidates | |||||
# TODO: change return value to List[ListData] | |||||
def _get_XY_list(self, args): | |||||
pre_x, post_x_it = args[0], args[1] | |||||
XY_list = [] | |||||
for post_x in post_x_it: | |||||
x = (pre_x,) + post_x | |||||
data_sample = ListData(pred_pseudo_label=[x]) | |||||
y = self.logic_forward(data_sample) | |||||
if y is not None: | |||||
XY_list.append((x, y)) | |||||
return XY_list | |||||
@staticmethod | |||||
def _valid_candidate(data_sample): | |||||
formula = data_sample["pred_pseudo_label"][0] | |||||
if len(formula) % 2 == 0: | |||||
return False | |||||
for i in range(len(formula)): | |||||
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||||
return False | |||||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||||
return False | |||||
return True |
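A rough, hedged check of the example knowledge base above; the small GKB_len_list keeps ground-KB construction fast, and the field names and value shapes mirror how the class itself indexes `data_sample`:

from abl.structures import ListData
from examples.hwf.hwf_kb import HWF_KB

if __name__ == "__main__":
    # Small GKB so that the multiprocessing build in construct_base stays cheap.
    kb = HWF_KB(GKB_len_list=[1, 3])
    sample = ListData(pred_pseudo_label=[["3", "times", "4"]], symbol_num=3)
    print(kb.logic_forward(sample))      # 12, and sample["Y"] is set to [12]
    print(kb.check_equal(sample, 12.0))  # True (within max_err)
    print(kb.get_key(sample))            # (3, 12)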
@@ -1,39 +1,49 @@ | |||||
import os.path as osp | |||||
import torchvision | import torchvision | ||||
from torchvision.transforms import transforms | from torchvision.transforms import transforms | ||||
CURRENT_DIR = osp.abspath(osp.dirname(__file__)) | |||||
def get_data(file, img_dataset, get_pseudo_label): | def get_data(file, img_dataset, get_pseudo_label): | ||||
X = [] | |||||
X, Y = [], [] | |||||
if get_pseudo_label: | if get_pseudo_label: | ||||
Z = [] | Z = [] | ||||
Y = [] | |||||
with open(file) as f: | with open(file) as f: | ||||
for line in f: | for line in f: | ||||
line = line.strip().split(' ') | |||||
# if len(X) == 1000: | |||||
# break | |||||
line = line.strip().split(" ") | |||||
X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) | X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) | ||||
if get_pseudo_label: | if get_pseudo_label: | ||||
Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]]) | Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]]) | ||||
Y.append(int(line[2])) | Y.append(int(line[2])) | ||||
if get_pseudo_label: | if get_pseudo_label: | ||||
return X, Z, Y | return X, Z, Y | ||||
else: | else: | ||||
return X, None, Y | return X, None, Y | ||||
def get_mnist_add(train = True, get_pseudo_label = False): | |||||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) | |||||
img_dataset = torchvision.datasets.MNIST(root='./datasets/', train=train, download=True, transform=transform) | |||||
def get_mnist_add(train=True, get_pseudo_label=False): | |||||
transform = transforms.Compose( | |||||
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] | |||||
) | |||||
img_dataset = torchvision.datasets.MNIST( | |||||
root=CURRENT_DIR, train=train, download=True, transform=transform | |||||
) | |||||
if train: | if train: | ||||
file = './datasets/train_data.txt' | |||||
file = osp.join(CURRENT_DIR, "train_data.txt") | |||||
else: | else: | ||||
file = './datasets/test_data.txt' | |||||
file = osp.join(CURRENT_DIR, "test_data.txt") | |||||
return get_data(file, img_dataset, get_pseudo_label) | return get_data(file, img_dataset, get_pseudo_label) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
train_X, train_Y = get_mnist_add(train = True) | |||||
test_X, test_Y = get_mnist_add(train = False) | |||||
train_X, train_Z, train_Y = get_mnist_add(train=True) | |||||
test_X, test_Z, test_Y = get_mnist_add(train=False) | |||||
print(len(train_X), len(test_X)) | print(len(train_X), len(test_X)) | ||||
print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0]) | print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0]) | ||||
@@ -2,32 +2,37 @@ | |||||
"cells": [ | "cells": [ | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 3, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"import torch.nn as nn\n", | |||||
"import torch\n", | |||||
"import os.path as osp\n", | |||||
"\n", | "\n", | ||||
"from abl.reasoning import ReasonerBase, KBBase\n", | |||||
"import torch\n", | |||||
"import torch.nn as nn\n", | |||||
"\n", | "\n", | ||||
"from abl.learning import BasicNN, ABLModel\n", | |||||
"from abl.bridge import SimpleBridge\n", | "from abl.bridge import SimpleBridge\n", | ||||
"from abl.evaluation import SymbolMetric, ABLMetric\n", | |||||
"from abl.utils import ABLLogger\n", | |||||
"\n", | |||||
"from models.nn import LeNet5\n", | |||||
"from examples.mnist_add.datasets.get_mnist_add import get_mnist_add" | |||||
"from abl.evaluation import SemanticsMetric, SymbolMetric\n", | |||||
"from abl.learning import ABLModel, BasicNN\n", | |||||
"from abl.reasoning import ReasonerBase\n", | |||||
"from abl.utils import ABLLogger, print_log\n", | |||||
"from examples.mnist_add.datasets.get_mnist_add import get_mnist_add\n", | |||||
"from examples.mnist_add.mnist_add_kb import AddKB\n", | |||||
"from examples.models.nn import LeNet5" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 4, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"# Initialize logger\n", | "# Initialize logger\n", | ||||
"logger = ABLLogger.get_instance(\"abl\")" | |||||
"print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n", | |||||
"\n", | |||||
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", | |||||
"log_dir = ABLLogger.get_current_instance().log_dir\n", | |||||
"weights_dir = osp.join(log_dir, \"weights\")" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -40,22 +45,19 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 5, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"# Initialize knowledge base and abducer\n", | "# Initialize knowledge base and abducer\n", | ||||
"class add_KB(KBBase):\n", | |||||
" def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n", | |||||
" super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n", | |||||
"\n", | |||||
" def logic_forward(self, nums):\n", | |||||
" return sum(nums)\n", | |||||
"kb = AddKB()\n", | |||||
"\n", | "\n", | ||||
"kb = add_KB(prebuild_GKB=True)\n", | |||||
"# If use cache, get_key should be implemented in the abducer\n", | |||||
"class AddAbducer(ReasonerBase):\n", | |||||
" def get_key(self, data_sample):\n", | |||||
" return (data_sample.to_tuple(\"pred_pseudo_label\"), data_sample[\"Y\"][0])\n", | |||||
"\n", | "\n", | ||||
"# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", | |||||
"abducer = ReasonerBase(kb, dist_func=\"confidence\")" | |||||
"abducer = AddAbducer(kb, dist_func=\"confidence\", use_cache=True)" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -68,7 +70,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 6, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -81,19 +83,17 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 7, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"# Initialize BasicNN\n", | "# Initialize BasicNN\n", | ||||
"# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n", | "# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n", | ||||
"base_model = BasicNN(\n", | "base_model = BasicNN(\n", | ||||
" cls,\n", | |||||
" criterion,\n", | |||||
" optimizer,\n", | |||||
" device,\n", | |||||
" save_interval=1,\n", | |||||
" save_dir=logger.save_dir,\n", | |||||
" model=cls,\n", | |||||
" criterion=criterion,\n", | |||||
" optimizer=optimizer,\n", | |||||
" device=device,\n", | |||||
" batch_size=32,\n", | " batch_size=32,\n", | ||||
" num_epochs=1,\n", | " num_epochs=1,\n", | ||||
")" | ")" | ||||
@@ -109,12 +109,12 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 8, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"# Initialize ABL model\n", | "# Initialize ABL model\n", | ||||
"# The main function of the ABL model is to serialize data and \n", | |||||
"# The main function of the ABL model is to serialize data and\n", | |||||
"# provide a unified interface for different machine learning models\n", | "# provide a unified interface for different machine learning models\n", | ||||
"model = ABLModel(base_model)" | "model = ABLModel(base_model)" | ||||
] | ] | ||||
@@ -129,12 +129,12 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 9, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"# Add metric\n", | "# Add metric\n", | ||||
"metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]" | |||||
"metric = [SymbolMetric(prefix=\"mnist_add\"), SemanticsMetric(kb=kb, prefix=\"mnist_add\")]" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -147,7 +147,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 10, | |||||
"execution_count": null, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -187,7 +187,7 @@ | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"bridge.train(train_data, epochs=5, batch_size=10000)\n", | |||||
"bridge.train(train_data, loops=10, segment_size=10000)\n", | |||||
"bridge.test(test_data)" | "bridge.test(test_data)" | ||||
] | ] | ||||
} | } | ||||
@@ -208,7 +208,7 @@ | |||||
"name": "python", | "name": "python", | ||||
"nbconvert_exporter": "python", | "nbconvert_exporter": "python", | ||||
"pygments_lexer": "ipython3", | "pygments_lexer": "ipython3", | ||||
"version": "3.8.13" | |||||
"version": "3.8.16" | |||||
}, | }, | ||||
"orig_nbformat": 4, | "orig_nbformat": 4, | ||||
"vscode": { | "vscode": { | ||||
@@ -0,0 +1,17 @@ | |||||
from typing import Any | |||||
from abl.reasoning import SearchBasedKB | |||||
from abl.structures import ListData | |||||
class AddKB(SearchBasedKB): | |||||
def __init__(self, pseudo_label_list=list(range(10))): | |||||
super().__init__( | |||||
pseudo_label_list=pseudo_label_list | |||||
) | |||||
def check_equal(self, data_sample: ListData, y: Any): | |||||
return self.logic_forward(data_sample) == y | |||||
def logic_forward(self, data_sample): | |||||
return sum(data_sample["pred_pseudo_label"][0]) |
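A minimal, hedged check of `AddKB`; the `ListData` field name follows the class's own `logic_forward`:

from abl.structures import ListData
from examples.mnist_add.mnist_add_kb import AddKB

kb = AddKB()
sample = ListData(pred_pseudo_label=[[3, 5]])
print(kb.logic_forward(sample))   # 8
print(kb.check_equal(sample, 8))  # True
print(kb.check_equal(sample, 9))  # False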
@@ -11,8 +11,8 @@ | |||||
# ================================================================# | # ================================================================# | ||||
import torch | |||||
import numpy as np | import numpy as np | ||||
import torch | |||||
from torch import nn | from torch import nn | ||||
@@ -66,7 +66,8 @@ class SymbolNet(nn.Module): | |||||
num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) | num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) | ||||
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | ||||
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | ||||
self.fc3 = nn.Sequential(nn.Linear(84, num_classes), nn.Softmax(dim=1)) | |||||
# self.fc3 = nn.Sequential(nn.Linear(84, num_classes), nn.Softmax(dim=1)) | |||||
self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) | |||||
def forward(self, x): | def forward(self, x): | ||||
x = self.conv1(x) | x = self.conv1(x) | ||||
@@ -84,9 +85,7 @@ class SymbolNetAutoencoder(nn.Module): | |||||
self.base_model = SymbolNet(num_classes, image_size) | self.base_model = SymbolNet(num_classes, image_size) | ||||
self.softmax = nn.Softmax(dim=1) | self.softmax = nn.Softmax(dim=1) | ||||
self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) | self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) | ||||
self.fc2 = nn.Sequential( | |||||
nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU() | |||||
) | |||||
self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()) | |||||
def forward(self, x): | def forward(self, x): | ||||
x = self.base_model(x) | x = self.base_model(x) | ||||
@@ -1,4 +1,5 @@ | |||||
import os | import os | ||||
from setuptools import find_packages, setup | from setuptools import find_packages, setup | ||||
@@ -0,0 +1,403 @@ | |||||
import math | |||||
import numpy as np | |||||
from abl.reasoning import BaseKB, GroundKB, PrologBasedKB, ReasonerBase | |||||
from abl.utils import reform_idx | |||||
if __name__ == "__main__": | |||||
prob1 = [ | |||||
[ | |||||
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
] | |||||
] | |||||
prob2 = [ | |||||
[ | |||||
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
] | |||||
] | |||||
class add_KB(BaseKB): | |||||
def __init__(self, pseudo_label_list=list(range(10)), use_cache=True): | |||||
super().__init__(pseudo_label_list, use_cache=use_cache) | |||||
def logic_forward(self, nums): | |||||
return sum(nums) | |||||
class add_GroundKB(GroundKB): | |||||
def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]): | |||||
super().__init__(pseudo_label_list, GKB_len_list) | |||||
def logic_forward(self, nums): | |||||
return sum(nums) | |||||
def test_add(reasoner): | |||||
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||||
print(res) | |||||
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||||
print(res) | |||||
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||||
print(res) | |||||
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||||
print(res) | |||||
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||||
print(res) | |||||
print() | |||||
print("add_KB with GKB:") | |||||
kb = add_GroundKB() | |||||
reasoner = ReasonerBase(kb, "confidence") | |||||
test_add(reasoner) | |||||
print("add_KB without GKB:") | |||||
kb = add_KB() | |||||
reasoner = ReasonerBase(kb, "confidence") | |||||
test_add(reasoner) | |||||
print("add_KB without GKB, no cache") | |||||
kb = add_KB(use_cache=False) | |||||
reasoner = ReasonerBase(kb, "confidence") | |||||
test_add(reasoner) | |||||
print("PrologBasedKB with add.pl:") | |||||
kb = PrologBasedKB( | |||||
pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl" | |||||
) | |||||
reasoner = ReasonerBase(kb, "confidence") | |||||
test_add(reasoner) | |||||
print("PrologBasedKB with add.pl using zoopt:") | |||||
kb = PrologBasedKB( | |||||
pseudo_label_list=list(range(10)), | |||||
pl_file="examples/mnist_add/datasets/add.pl", | |||||
) | |||||
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | |||||
test_add(reasoner) | |||||
print("add_KB with multiple inputs at once:") | |||||
multiple_prob = [ | |||||
[ | |||||
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
], | |||||
[ | |||||
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
], | |||||
] | |||||
kb = add_KB() | |||||
reasoner = ReasonerBase(kb, "confidence") | |||||
res = reasoner.batch_abduce( | |||||
multiple_prob, | |||||
[[1, 1], [1, 2]], | |||||
[4, 8], | |||||
max_revision=2, | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
res = reasoner.batch_abduce( | |||||
multiple_prob, | |||||
[[1, 1], [1, 2]], | |||||
[4, 8], | |||||
max_revision=2, | |||||
require_more_revision=1, | |||||
) | |||||
print(res) | |||||
print() | |||||
class HWF_KB(BaseKB): | |||||
def __init__( | |||||
self, | |||||
pseudo_label_list=[ | |||||
"1", | |||||
"2", | |||||
"3", | |||||
"4", | |||||
"5", | |||||
"6", | |||||
"7", | |||||
"8", | |||||
"9", | |||||
"+", | |||||
"-", | |||||
"times", | |||||
"div", | |||||
], | |||||
max_err=1e-3, | |||||
): | |||||
super().__init__(pseudo_label_list, max_err) | |||||
def _valid_candidate(self, formula): | |||||
if len(formula) % 2 == 0: | |||||
return False | |||||
for i in range(len(formula)): | |||||
if i % 2 == 0 and formula[i] not in [ | |||||
"1", | |||||
"2", | |||||
"3", | |||||
"4", | |||||
"5", | |||||
"6", | |||||
"7", | |||||
"8", | |||||
"9", | |||||
]: | |||||
return False | |||||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||||
return False | |||||
return True | |||||
def logic_forward(self, formula): | |||||
if not self._valid_candidate(formula): | |||||
return np.inf | |||||
mapping = {str(i): str(i) for i in range(1, 10)} | |||||
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||||
formula = [mapping[f] for f in formula] | |||||
return eval("".join(formula)) | |||||
class HWF_GroundKB(GroundKB): | |||||
def __init__( | |||||
self, | |||||
pseudo_label_list=[ | |||||
"1", | |||||
"2", | |||||
"3", | |||||
"4", | |||||
"5", | |||||
"6", | |||||
"7", | |||||
"8", | |||||
"9", | |||||
"+", | |||||
"-", | |||||
"times", | |||||
"div", | |||||
], | |||||
GKB_len_list=[1, 3, 5, 7], | |||||
max_err=1e-3, | |||||
): | |||||
super().__init__(pseudo_label_list, GKB_len_list, max_err) | |||||
def _valid_candidate(self, formula): | |||||
if len(formula) % 2 == 0: | |||||
return False | |||||
for i in range(len(formula)): | |||||
if i % 2 == 0 and formula[i] not in [ | |||||
"1", | |||||
"2", | |||||
"3", | |||||
"4", | |||||
"5", | |||||
"6", | |||||
"7", | |||||
"8", | |||||
"9", | |||||
]: | |||||
return False | |||||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||||
return False | |||||
return True | |||||
def logic_forward(self, formula): | |||||
if not self._valid_candidate(formula): | |||||
return np.inf | |||||
mapping = {str(i): str(i) for i in range(1, 10)} | |||||
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||||
formula = [mapping[f] for f in formula] | |||||
return eval("".join(formula)) | |||||
def test_hwf(reasoner): | |||||
res = reasoner.batch_abduce( | |||||
[None], | |||||
[["5", "+", "2"]], | |||||
[3], | |||||
max_revision=2, | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
res = reasoner.batch_abduce( | |||||
[None], | |||||
[["5", "+", "9"]], | |||||
[65], | |||||
max_revision=3, | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
res = reasoner.batch_abduce( | |||||
[None], | |||||
[["5", "8", "8", "8", "8"]], | |||||
[3.17], | |||||
max_revision=5, | |||||
require_more_revision=3, | |||||
) | |||||
print(res) | |||||
print() | |||||
def test_hwf_multiple(reasoner, max_revisions): | |||||
res = reasoner.batch_abduce( | |||||
[None, None], | |||||
[["5", "+", "2"], ["5", "+", "9"]], | |||||
[3, 64], | |||||
max_revision=max_revisions[0], | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
res = reasoner.batch_abduce( | |||||
[None, None], | |||||
[["5", "+", "2"], ["5", "+", "9"]], | |||||
[3, 64], | |||||
max_revision=max_revisions[1], | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
res = reasoner.batch_abduce( | |||||
[None, None], | |||||
[["5", "+", "2"], ["5", "+", "9"]], | |||||
[3, 65], | |||||
max_revision=max_revisions[2], | |||||
require_more_revision=0, | |||||
) | |||||
print(res) | |||||
print() | |||||
print("HWF_KB with GKB, max_err=0.1") | |||||
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) | |||||
reasoner = ReasonerBase(kb, "hamming") | |||||
test_hwf(reasoner) | |||||
print("HWF_KB without GKB, max_err=0.1") | |||||
kb = HWF_KB(max_err=0.1) | |||||
reasoner = ReasonerBase(kb, "hamming") | |||||
test_hwf(reasoner) | |||||
print("HWF_KB with GKB, max_err=1") | |||||
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=1) | |||||
reasoner = ReasonerBase(kb, "hamming") | |||||
test_hwf(reasoner) | |||||
print("HWF_KB without GKB, max_err=1") | |||||
kb = HWF_KB(max_err=1) | |||||
reasoner = ReasonerBase(kb, "hamming") | |||||
test_hwf(reasoner) | |||||
print("HWF_KB with multiple inputs at once:") | |||||
kb = HWF_KB(max_err=0.1) | |||||
reasoner = ReasonerBase(kb, "hamming") | |||||
test_hwf_multiple(reasoner, max_revisions=[1, 3, 3]) | |||||
print("max_revision is float") | |||||
test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9]) | |||||
class HED_prolog_KB(PrologBasedKB): | |||||
def __init__(self, pseudo_label_list, pl_file): | |||||
super().__init__(pseudo_label_list, pl_file) | |||||
def consist_rule(self, exs, rules): | |||||
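            # Ask Prolog whether every example in `exs` is consistent with the candidate rules
            # (a non-empty query result means they are).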
rules = str(rules).replace("'", "") | |||||
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules) | |||||
return len(list(self.prolog.query(pl_query))) != 0 | |||||
def abduce_rules(self, pred_res): | |||||
pl_query = "consistent_inst_feature(%s, X)." % pred_res | |||||
prolog_result = list(self.prolog.query(pl_query)) | |||||
if len(prolog_result) == 0: | |||||
return None | |||||
prolog_rules = prolog_result[0]["X"] | |||||
rules = [rule.value for rule in prolog_rules] | |||||
return rules | |||||
class HED_Reasoner(ReasonerBase): | |||||
def __init__(self, kb, dist_func="hamming"): | |||||
super().__init__(kb, dist_func, use_zoopt=True) | |||||
def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs): | |||||
pred = [] | |||||
k = [] | |||||
revision_flag = [] | |||||
for idx in idxs: | |||||
pred.append(pred_res[idx]) | |||||
k.append(y[idx]) | |||||
revision_flag += list(all_revision_flag[idx]) | |||||
revision_idx = np.where(np.array(revision_flag) != 0)[0] | |||||
candidate = self.revise_at_idx(pred, k, revision_idx) | |||||
return candidate | |||||
def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol): | |||||
all_revision_flag = reform_idx(sol.get_x(), pred_res) | |||||
lefted_idxs = [i for i in range(len(pred_res))] | |||||
candidate_size = [] | |||||
while lefted_idxs: | |||||
idxs = [] | |||||
idxs.append(lefted_idxs.pop(0)) | |||||
max_candidate_idxs = [] | |||||
found = False | |||||
for idx in range(-1, len(pred_res)): | |||||
if (not idx in idxs) and (idx >= 0): | |||||
idxs.append(idx) | |||||
candidate = self._revise_at_idxs(pred_res, y, all_revision_flag, idxs) | |||||
if len(candidate) == 0: | |||||
if len(idxs) > 1: | |||||
idxs.pop() | |||||
else: | |||||
if len(idxs) > len(max_candidate_idxs): | |||||
found = True | |||||
max_candidate_idxs = idxs.copy() | |||||
removed = [i for i in lefted_idxs if i in max_candidate_idxs] | |||||
if found: | |||||
candidate_size.append(len(removed) + 1) | |||||
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] | |||||
candidate_size.sort() | |||||
score = 0 | |||||
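            # Heuristic score: every consistent group of examples subtracts its size, damped by
            # exp(-i) over the sorted group sizes, so groupings that keep more examples
            # consistent yield a smaller value.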
for i in range(0, len(candidate_size)): | |||||
score -= math.exp(-i) * candidate_size[i] | |||||
return score | |||||
def abduce_rules(self, pred_res): | |||||
return self.kb.abduce_rules(pred_res) | |||||
kb = HED_prolog_KB( | |||||
pseudo_label_list=[1, 0, "+", "="], | |||||
pl_file="examples/hed/datasets/learn_add.pl", | |||||
) | |||||
reasoner = HED_Reasoner(kb) | |||||
consist_exs = [ | |||||
[1, 1, "+", 0, "=", 1, 1], | |||||
[1, "+", 1, "=", 1, 0], | |||||
[0, "+", 0, "=", 0], | |||||
] | |||||
inconsist_exs1 = [ | |||||
[1, 1, "+", 0, "=", 1, 1], | |||||
[1, "+", 1, "=", 1, 0], | |||||
[0, "+", 0, "=", 0], | |||||
[0, "+", 0, "=", 1], | |||||
] | |||||
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | |||||
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] | |||||
print("HED_kb logic forward") | |||||
print(kb.logic_forward(consist_exs)) | |||||
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) | |||||
print() | |||||
print("HED_kb consist rule") | |||||
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) | |||||
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) | |||||
print() | |||||
print("HED_Reasoner abduce") | |||||
res = reasoner.abduce([[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs)) | |||||
print(res) | |||||
res = reasoner.abduce( | |||||
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) | |||||
) | |||||
print(res) | |||||
res = reasoner.abduce( | |||||
[[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) | |||||
) | |||||
print(res) | |||||
print() | |||||
print("HED_Reasoner abduce rules") | |||||
abduced_rules = reasoner.abduce_rules(consist_exs) | |||||
print(abduced_rules) |