@@ -24,10 +24,8 @@ class SimpleBridge(BaseBridge):
     # TODO: add abducer.mapping to the property of SimpleBridge
     def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
-        pred_res = self.model.predict(data_samples)
-        data_samples.pred_idx = pred_res["label"]
-        data_samples.pred_prob = pred_res["prob"]
-        return data_samples["pred_idx"], data_samples["pred_prob"]
+        self.model.predict(data_samples)
+        return data_samples["pred_idx"], data_samples.get("pred_prob", None)

     def abduce_pseudo_label(
         self,
@@ -1,3 +1,4 @@
 from .bridge_dataset import BridgeDataset
 from .classification_dataset import ClassificationDataset
-from .regression_dataset import RegressionDataset
+from .prediction_dataset import PredictionDataset
+from .regression_dataset import RegressionDataset
@@ -0,0 +1,56 @@
+from typing import Any, Callable, List
+
+from torch.utils.data import Dataset
+
+
+class PredictionDataset(Dataset):
+    def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
+        """
+        Initialize the dataset used for prediction.
+
+        Parameters
+        ----------
+        X : List[Any]
+            The input data.
+        transform : Callable[..., Any], optional
+            A function/transform that takes in an object and returns a transformed version.
+            Defaults to None.
+        """
+        if not isinstance(X, list):
+            raise ValueError("X should be of type list.")
+
+        self.X = X
+        self.transform = transform
+
+    def __len__(self) -> int:
+        """
+        Return the length of the dataset.
+
+        Returns
+        -------
+        int
+            The length of the dataset.
+        """
+        return len(self.X)
+
+    def __getitem__(self, index: int) -> Any:
+        """
+        Get the item at the given index.
+
+        Parameters
+        ----------
+        index : int
+            The index of the item to get.
+
+        Returns
+        -------
+        Any
+            The (possibly transformed) input object at the given index.
+        """
+        if index >= len(self):
+            raise ValueError("index range error")
+
+        x = self.X[index]
+        if self.transform is not None:
+            x = self.transform(x)
+        return x
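Note: a minimal usage sketch of the new PredictionDataset follows. The inputs and the normalize transform are placeholders (not part of this change), and the abl.dataset import path is assumed from the __init__.py hunk above. The point it illustrates: the transform is applied lazily per item in __getitem__, never eagerly over the whole list.

import torch
from torch.utils.data import DataLoader

from abl.dataset import PredictionDataset  # assumption: the package is importable as `abl`


def normalize(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a real test-time transform.
    return (x - x.mean()) / (x.std() + 1e-8)


# Placeholder inputs: eight unbatched single-channel "images".
raw_inputs = [torch.randn(1, 28, 28) for _ in range(8)]

dataset = PredictionDataset(raw_inputs, transform=normalize)
loader = DataLoader(dataset, batch_size=4)

for batch in loader:
    print(batch.shape)  # torch.Size([4, 1, 28, 28])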
@@ -71,11 +71,14 @@ class ABLModel:
             label = prob.argmax(axis=1)
             prob = reform_idx(prob, data_samples["X"])
         else:
-            prob = [None] * len(data_samples)
+            prob = None
             label = model.predict(X=data_X)

         label = reform_idx(label, data_samples["X"])
+        data_samples.pred_idx = label
+        if prob is not None:
+            data_samples.pred_prob = prob
         return {"label": label, "prob": prob}

     def train(self, data_samples: ListData) -> float:
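Note: together with the SimpleBridge hunk above, this establishes a simple contract: ABLModel.predict writes its outputs onto the ListData it receives, and SimpleBridge.predict only reads them back. A hedged sketch of that contract; the no-arg ListData construction and the concrete values are assumptions for illustration, not shown in this change:

import numpy as np

from abl.structures import ListData  # top-level import path, as used elsewhere in the package

# Assumption: ListData() allows no-arg construction and attribute assignment.
data_samples = ListData()
data_samples.X = [[np.zeros(4), np.ones(4)], [np.ones(4)]]  # grouped raw inputs

# What ABLModel.predict now does (simplified): write predictions onto the ListData,
# keeping the same grouping as X; pred_prob is only written when the base model
# exposes predict_proba.
data_samples.pred_idx = [[0, 1], [1]]

# What SimpleBridge.predict reads back:
pred_idx = data_samples["pred_idx"]
pred_prob = data_samples.get("pred_prob", None)  # None when no probabilities were produced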
@@ -11,13 +11,14 @@
 # ================================================================#
 import os
+import logging
 from typing import Any, Callable, List, Optional, T, Tuple

 import numpy
 import torch
 from torch.utils.data import DataLoader

-from ..dataset import ClassificationDataset
+from ..dataset import ClassificationDataset, PredictionDataset
 from ..utils.logger import print_log
@@ -197,7 +198,12 @@ class BasicNN:
         return torch.cat(results, axis=0)

-    def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
+    def predict(
+        self,
+        data_loader: DataLoader = None,
+        X: List[Any] = None,
+        test_transform: Callable[..., Any] = None,
+    ) -> numpy.ndarray:
         """
         Predict the class of the input data.
@@ -215,12 +221,29 @@ class BasicNN:
         """
         if data_loader is None:
-            if self.transform is not None:
-                X = [self.transform(x) for x in X]
-            data_loader = DataLoader(X, batch_size=self.batch_size)
+            if test_transform is None:
+                print_log(
+                    "Transform used in the training phase will be used in prediction.",
+                    "current",
+                    level=logging.WARNING,
+                )
+                dataset = PredictionDataset(X, self.transform)
+            else:
+                dataset = PredictionDataset(X, test_transform)
+            data_loader = DataLoader(
+                dataset,
+                batch_size=self.batch_size,
+                num_workers=int(self.num_workers),
+                collate_fn=self.collate_fn,
+            )
         return self._predict(data_loader).argmax(axis=1).cpu().numpy()

-    def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
+    def predict_proba(
+        self,
+        data_loader: DataLoader = None,
+        X: List[Any] = None,
+        test_transform: Callable[..., Any] = None,
+    ) -> numpy.ndarray:
         """
         Predict the probability of each class for the input data.
@@ -238,9 +261,21 @@ class BasicNN:
         """
         if data_loader is None:
-            if self.transform is not None:
-                X = [self.transform(x) for x in X]
-            data_loader = DataLoader(X, batch_size=self.batch_size)
+            if test_transform is None:
+                print_log(
+                    "Transform used in the training phase will be used in prediction.",
+                    "current",
+                    level=logging.WARNING,
+                )
+                dataset = PredictionDataset(X, self.transform)
+            else:
+                dataset = PredictionDataset(X, test_transform)
+            data_loader = DataLoader(
+                dataset,
+                batch_size=self.batch_size,
+                num_workers=int(self.num_workers),
+                collate_fn=self.collate_fn,
+            )
         return self._predict(data_loader).softmax(axis=1).cpu().numpy()

     def _score(self, data_loader) -> Tuple[float, float]:
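Note: a hedged usage sketch of the extended predict API. base_model (an already-constructed BasicNN) and test_images (a list of raw inputs) are placeholders, since BasicNN's constructor is not part of this diff; the transform constants are illustrative only.

import torch

# Placeholders: base_model is a constructed BasicNN, test_images is a list of raw inputs.


def mnist_test_transform(x: torch.Tensor) -> torch.Tensor:
    # Deterministic test-time preprocessing (illustrative constants), in place of
    # whatever augmentation self.transform may apply during training.
    return (x - 0.1307) / 0.3081


# Without test_transform: the training transform is reused and a warning is logged.
labels = base_model.predict(X=test_images)

# With test_transform: PredictionDataset applies it per item inside the DataLoader.
probs = base_model.predict_proba(X=test_images, test_transform=mnist_test_transform)
labels = probs.argmax(axis=1)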
@@ -1,8 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Hashable, List

-from abl.structures import ListData
+from ..structures import ListData
 from .base_kb import BaseKB