diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index ee0e43e..9093bc1 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -24,10 +24,8 @@ class SimpleBridge(BaseBridge): # TODO: add abducer.mapping to the property of SimpleBridge def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: - pred_res = self.model.predict(data_samples) - data_samples.pred_idx = pred_res["label"] - data_samples.pred_prob = pred_res["prob"] - return data_samples["pred_idx"], data_samples["pred_prob"] + self.model.predict(data_samples) + return data_samples["pred_idx"], data_samples.get("pred_prob", None) def abduce_pseudo_label( self, diff --git a/abl/dataset/__init__.py b/abl/dataset/__init__.py index 6be0df1..a487476 100644 --- a/abl/dataset/__init__.py +++ b/abl/dataset/__init__.py @@ -1,3 +1,4 @@ from .bridge_dataset import BridgeDataset from .classification_dataset import ClassificationDataset -from .regression_dataset import RegressionDataset \ No newline at end of file +from .prediction_dataset import PredictionDataset +from .regression_dataset import RegressionDataset diff --git a/abl/dataset/prediction_dataset.py b/abl/dataset/prediction_dataset.py new file mode 100644 index 0000000..8e3c717 --- /dev/null +++ b/abl/dataset/prediction_dataset.py @@ -0,0 +1,56 @@ +from typing import Any, Callable, List, Tuple + +import torch +from torch.utils.data import Dataset + + +class PredictionDataset(Dataset): + def __init__(self, X: List[Any], transform: Callable[..., Any] = None): + """ + Initialize the dataset used for prediction task. + + Parameters + ---------- + X : List[Any] + The input data. + transform : Callable[..., Any], optional + A function/transform that takes in an object and returns a transformed version. Defaults to None. 
+ """ + if not isinstance(X, list): + raise ValueError("X should be of type list.") + + self.X = X + self.transform = transform + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + The length of the dataset. + """ + return len(self.X) + + def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]: + """ + Get the item at the given index. + + Parameters + ---------- + index : int + The index of the item to get. + + Returns + ------- + Tuple[Any, torch.Tensor] + A tuple containing the object and its label. + """ + if index >= len(self): + raise ValueError("index range error") + + x = self.X[index] + if self.transform is not None: + x = self.transform(x) + return x diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index 4c6fbad..6685cc4 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -71,11 +71,14 @@ class ABLModel: label = prob.argmax(axis=1) prob = reform_idx(prob, data_samples["X"]) else: - prob = [None] * len(data_samples) + prob = None label = model.predict(X=data_X) - label = reform_idx(label, data_samples["X"]) + data_samples.pred_idx = label + if prob is not None: + data_samples.pred_prob = prob + return {"label": label, "prob": prob} def train(self, data_samples: ListData) -> float: diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index 305ed65..b1da93c 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -11,13 +11,14 @@ # ================================================================# import os +import logging from typing import Any, Callable, List, Optional, T, Tuple import numpy import torch from torch.utils.data import DataLoader -from ..dataset import ClassificationDataset +from ..dataset import ClassificationDataset, PredictionDataset from ..utils.logger import print_log @@ -197,7 +198,12 @@ class BasicNN: return torch.cat(results, axis=0) - def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> 
numpy.ndarray: + def predict( + self, + data_loader: DataLoader = None, + X: List[Any] = None, + test_transform: Callable[..., Any] = None, + ) -> numpy.ndarray: """ Predict the class of the input data. @@ -215,12 +221,29 @@ class BasicNN: """ if data_loader is None: - if self.transform is not None: - X = [self.transform(x) for x in X] - data_loader = DataLoader(X, batch_size=self.batch_size) + if test_transform is None: + print_log( + "Transform used in the training phase will be used in prediction.", + "current", + level=logging.WARNING, + ) + dataset = PredictionDataset(X, self.transform) + else: + dataset = PredictionDataset(X, test_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=int(self.num_workers), + collate_fn=self.collate_fn, + ) return self._predict(data_loader).argmax(axis=1).cpu().numpy() - def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray: + def predict_proba( + self, + data_loader: DataLoader = None, + X: List[Any] = None, + test_transform: Callable[..., Any] = None, + ) -> numpy.ndarray: """ Predict the probability of each class for the input data. 
@@ -238,9 +261,21 @@ class BasicNN: """ if data_loader is None: - if self.transform is not None: - X = [self.transform(x) for x in X] - data_loader = DataLoader(X, batch_size=self.batch_size) + if test_transform is None: + print_log( + "Transform used in the training phase will be used in prediction.", + "current", + level=logging.WARNING, + ) + dataset = PredictionDataset(X, self.transform) + else: + dataset = PredictionDataset(X, test_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=int(self.num_workers), + collate_fn=self.collate_fn, + ) return self._predict(data_loader).softmax(axis=1).cpu().numpy() def _score(self, data_loader) -> Tuple[float, float]: diff --git a/abl/reasoning/ground_kb.py b/abl/reasoning/ground_kb.py index 9f9428d..4b241d2 100644 --- a/abl/reasoning/ground_kb.py +++ b/abl/reasoning/ground_kb.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Hashable, List -from abl.structures import ListData - +from ..structures import ListData from .base_kb import BaseKB