@@ -24,10 +24,8 @@ class SimpleBridge(BaseBridge):
     # TODO: add abducer.mapping to the property of SimpleBridge
     def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
-        pred_res = self.model.predict(data_samples)
-        data_samples.pred_idx = pred_res["label"]
-        data_samples.pred_prob = pred_res["prob"]
-        return data_samples["pred_idx"], data_samples["pred_prob"]
+        self.model.predict(data_samples)
+        return data_samples["pred_idx"], data_samples.get("pred_prob", None)

     def abduce_pseudo_label(
         self,
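Caller-facing effect of this hunk, as a hedged sketch (a constructed SimpleBridge instance `bridge` and a populated ListData `data_samples` are assumed; they are not part of the diff). The model's predict now writes its outputs onto `data_samples`, so the bridge only reads them back:

    # Hypothetical call site; names are illustrative.
    pred_idx, pred_prob = bridge.predict(data_samples)
    # pred_prob is None when the wrapped model exposes no predict_proba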
@@ -1,3 +1,4 @@
 from .bridge_dataset import BridgeDataset
 from .classification_dataset import ClassificationDataset
-from .regression_dataset import RegressionDataset
+from .prediction_dataset import PredictionDataset
+from .regression_dataset import RegressionDataset
@@ -0,0 +1,56 @@
+from typing import Any, Callable, List, Tuple
+
+import torch
+from torch.utils.data import Dataset
+
+
+class PredictionDataset(Dataset):
+    def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
+        """
+        Initialize the dataset used for the prediction task.
+
+        Parameters
+        ----------
+        X : List[Any]
+            The input data.
+        transform : Callable[..., Any], optional
+            A function/transform that takes in an object and returns a transformed version. Defaults to None.
+        """
+        if not isinstance(X, list):
+            raise ValueError("X should be of type list.")
+
+        self.X = X
+        self.transform = transform
+
+    def __len__(self) -> int:
+        """
+        Return the length of the dataset.
+
+        Returns
+        -------
+        int
+            The length of the dataset.
+        """
+        return len(self.X)
+
+    def __getitem__(self, index: int) -> Any:
+        """
+        Get the item at the given index.
+
+        Parameters
+        ----------
+        index : int
+            The index of the item to get.
+
+        Returns
+        -------
+        Any
+            The item at the given index, transformed if a transform was provided.
+        """
+        if index >= len(self):
+            raise ValueError("index out of range")
+        x = self.X[index]
+        if self.transform is not None:
+            x = self.transform(x)
+        return x
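A minimal usage sketch for the PredictionDataset added above (the raw inputs and the per-item transform are illustrative placeholders, not part of the change):

    import torch
    from torch.utils.data import DataLoader

    X = [1.0, 2.0, 3.0]  # hypothetical raw inputs

    def to_tensor(v):
        # hypothetical per-item transform
        return torch.tensor([v])

    dataset = PredictionDataset(X, transform=to_tensor)
    loader = DataLoader(dataset, batch_size=2)
    for batch in loader:
        print(batch.shape)  # torch.Size([2, 1]) for the first batch, torch.Size([1, 1]) for the last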
@@ -71,11 +71,14 @@ class ABLModel:
             label = prob.argmax(axis=1)
             prob = reform_idx(prob, data_samples["X"])
         else:
-            prob = [None] * len(data_samples)
+            prob = None
             label = model.predict(X=data_X)
         label = reform_idx(label, data_samples["X"])
+        data_samples.pred_idx = label
+        if prob is not None:
+            data_samples.pred_prob = prob
         return {"label": label, "prob": prob}

     def train(self, data_samples: ListData) -> float:
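Taken together with the SimpleBridge hunk above, this changes the prediction data flow: ABLModel.predict still returns a {"label", "prob"} dict, but now also attaches the results to data_samples in place. A hedged sketch (`model` is an ABLModel and `data_samples` a ListData, both assumed to exist; the `.get` access mirrors the new bridge code):

    out = model.predict(data_samples)   # {"label": ..., "prob": ... or None}
    pred_idx = data_samples["pred_idx"]                # now attached in place
    pred_prob = data_samples.get("pred_prob", None)    # absent when the base model has no predict_proba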
@@ -11,13 +11,14 @@
 # ================================================================#

 import os
+import logging
 from typing import Any, Callable, List, Optional, T, Tuple

 import numpy
 import torch
 from torch.utils.data import DataLoader

-from ..dataset import ClassificationDataset
+from ..dataset import ClassificationDataset, PredictionDataset
 from ..utils.logger import print_log
@@ -197,7 +198,12 @@ class BasicNN:
         return torch.cat(results, axis=0)

-    def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
+    def predict(
+        self,
+        data_loader: DataLoader = None,
+        X: List[Any] = None,
+        test_transform: Callable[..., Any] = None,
+    ) -> numpy.ndarray:
         """
         Predict the class of the input data.
@@ -215,12 +221,29 @@
         """
         if data_loader is None:
-            if self.transform is not None:
-                X = [self.transform(x) for x in X]
-            data_loader = DataLoader(X, batch_size=self.batch_size)
+            if test_transform is None:
+                print_log(
+                    "Transform used in the training phase will be used in prediction.",
+                    "current",
+                    level=logging.WARNING,
+                )
+                dataset = PredictionDataset(X, self.transform)
+            else:
+                dataset = PredictionDataset(X, test_transform)
+            data_loader = DataLoader(
+                dataset,
+                batch_size=self.batch_size,
+                num_workers=int(self.num_workers),
+                collate_fn=self.collate_fn,
+            )

         return self._predict(data_loader).argmax(axis=1).cpu().numpy()

-    def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
+    def predict_proba(
+        self,
+        data_loader: DataLoader = None,
+        X: List[Any] = None,
+        test_transform: Callable[..., Any] = None,
+    ) -> numpy.ndarray:
         """
         Predict the probability of each class for the input data.
@@ -238,9 +261,21 @@
         """
         if data_loader is None:
-            if self.transform is not None:
-                X = [self.transform(x) for x in X]
-            data_loader = DataLoader(X, batch_size=self.batch_size)
+            if test_transform is None:
+                print_log(
+                    "Transform used in the training phase will be used in prediction.",
+                    "current",
+                    level=logging.WARNING,
+                )
+                dataset = PredictionDataset(X, self.transform)
+            else:
+                dataset = PredictionDataset(X, test_transform)
+            data_loader = DataLoader(
+                dataset,
+                batch_size=self.batch_size,
+                num_workers=int(self.num_workers),
+                collate_fn=self.collate_fn,
+            )

         return self._predict(data_loader).softmax(axis=1).cpu().numpy()

     def _score(self, data_loader) -> Tuple[float, float]:
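With the new test_transform keyword, callers can hand BasicNN.predict and BasicNN.predict_proba a dedicated inference-time transform; when it is omitted, the training transform is reused and a warning is logged. A hedged usage sketch (the torchvision transforms, the `test_images` list, and the already-constructed `base_nn` are illustrative assumptions, not part of the diff):

    import torchvision.transforms as transforms

    train_tf = transforms.Compose([transforms.RandomRotation(10), transforms.ToTensor()])
    eval_tf = transforms.Compose([transforms.ToTensor()])

    # base_nn is a BasicNN assumed to have been built with transform=train_tf elsewhere.
    labels = base_nn.predict(X=test_images, test_transform=eval_tf)        # no augmentation at test time
    probs = base_nn.predict_proba(X=test_images, test_transform=eval_tf)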
@@ -1,8 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Hashable, List

-from abl.structures import ListData
+from ..structures import ListData
 from .base_kb import BaseKB