"""
Implementation of PyTorch dataset class used for Prediction.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

from typing import Any, Callable, List, Tuple, Optional

import torch
from torch.utils.data import Dataset


class _IndexRangeError(IndexError, ValueError):
    """
    Raised when an index falls outside the dataset bounds.

    Subclasses ``IndexError`` so that Python's legacy sequence-iteration
    protocol (used by ``iter(dataset)`` / ``list(dataset)``, since
    ``torch.utils.data.Dataset`` defines no ``__iter__``) terminates cleanly,
    and ``ValueError`` so callers that already catch ``ValueError`` keep
    working.
    """


class PredictionDataset(Dataset):
    """
    Dataset used for prediction.

    Holds only inputs (no labels); each item is the stored sample, optionally
    passed through ``transform``.

    Parameters
    ----------
    X : List[Any]
        The input data.
    transform : Callable[..., Any], optional
        A function/transform that takes an object and returns a transformed
        version. Defaults to None.

    Raises
    ------
    ValueError
        If ``X`` is not a ``list``.
    """

    def __init__(self, X: List[Any], transform: Optional[Callable[..., Any]] = None):
        if not isinstance(X, list):
            raise ValueError("X should be of type list.")
        self.X = X
        self.transform = transform

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Any:
        """
        Get the item at the given index.

        Parameters
        ----------
        index : int
            The index of the item to get. Negative indices count from the
            end, as for a Python list.

        Returns
        -------
        Any
            The sample at ``index``, after applying ``transform`` if one was
            given. No label is returned.

        Raises
        ------
        IndexError, ValueError
            If ``index`` is out of range. The raised exception is an instance
            of both, so ``for x in dataset`` stops normally and existing
            ``except ValueError`` handlers still match.
        """
        size = len(self)
        # Symmetric bounds check: valid indices are -size .. size-1.
        if not -size <= index < size:
            raise _IndexRangeError("index range error")
        x = self.X[index]
        if self.transform is not None:
            x = self.transform(x)
        return x