From e6d17ba57fbc4e3fb90d2c41bb810c8681d12252 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Thu, 21 Dec 2023 10:44:15 +0800 Subject: [PATCH] [MNT] remove BridgeDataset --- abl/dataset/__init__.py | 2 -- abl/dataset/bridge_dataset.py | 68 ----------------------------------- 2 files changed, 70 deletions(-) delete mode 100644 abl/dataset/bridge_dataset.py diff --git a/abl/dataset/__init__.py b/abl/dataset/__init__.py index 3e1d0e1..b8237a2 100644 --- a/abl/dataset/__init__.py +++ b/abl/dataset/__init__.py @@ -1,10 +1,8 @@ -from .bridge_dataset import BridgeDataset from .classification_dataset import ClassificationDataset from .prediction_dataset import PredictionDataset from .regression_dataset import RegressionDataset __all__ = [ - "BridgeDataset", "ClassificationDataset", "PredictionDataset", "RegressionDataset", diff --git a/abl/dataset/bridge_dataset.py b/abl/dataset/bridge_dataset.py deleted file mode 100644 index 8ce525a..0000000 --- a/abl/dataset/bridge_dataset.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Any, List, Optional, Tuple - -from torch.utils.data import Dataset - - -class BridgeDataset(Dataset): - """Dataset used in ``BaseBridge``. - - Parameters - ---------- - X : List[List[Any]] - A list of objects representing the input data. - gt_pseudo_label : List[List[Any]], optional - A list of objects representing the ground truth label of each element in ``X``. - Y : List[Any] - A list of objects representing the ground truth of the reasoning result of - each instance in ``X``. - """ - - def __init__( - self, - X: List[List[Any]], - gt_pseudo_label: Optional[List[List[Any]]], - Y: List[Any], - ): - if (not isinstance(X, list)) or (not isinstance(Y, list)): - raise ValueError("X and Y should be of type list.") - if len(X) != len(Y): - raise ValueError("Length of X and Y must be equal.") - - self.X = X - self.gt_pseudo_label = gt_pseudo_label - self.Y = Y - - if self.gt_pseudo_label is None: - self.gt_pseudo_label = [None] * len(self.X) - - def __len__(self): - """Return the length of the dataset. - - Returns - ------- - int - The length of the dataset. - """ - return len(self.X) - - def __getitem__(self, index: int) -> Tuple[List, List, Any]: - """Get an item from the dataset. - - Parameters - ---------- - index : int - The index of the item to retrieve. - - Returns - ------- - Tuple[List, List, Any] - A tuple containing the input and output data at the specified index. - """ - if index >= len(self): - raise ValueError("index range error") - - X = self.X[index] - gt_pseudo_label = self.gt_pseudo_label[index] - Y = self.Y[index] - - return (X, gt_pseudo_label, Y)