|
- from typing import Any, Callable, List, Tuple
-
- import torch
- from torch.utils.data import Dataset
-
-
class _IndexRangeError(IndexError, ValueError):
    """
    Error raised when a dataset index is out of range.

    It derives from ``IndexError`` so that Python's sequence iteration
    protocol (``for item in dataset``) stops cleanly at the end of the data,
    and from ``ValueError`` so that callers catching the previously raised
    ``ValueError`` keep working.
    """


class ClassificationDataset(Dataset):
    """
    Dataset used for classification task.

    Parameters
    ----------
    X : List[Any]
        The input data.
    Y : List[int]
        The target data. It is stored as a ``torch.LongTensor``.
    transform : Callable[..., Any], optional
        A function/transform that takes in an object and returns a
        transformed version. Defaults to None (no transformation).

    Raises
    ------
    ValueError
        If ``X`` or ``Y`` is not a list, or if their lengths differ.
    """

    def __init__(
        self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None
    ):
        if (not isinstance(X, list)) or (not isinstance(Y, list)):
            raise ValueError("X and Y should be of type list.")
        if len(X) != len(Y):
            raise ValueError("Length of X and Y must be equal.")

        self.X = X
        self.Y = torch.LongTensor(Y)
        self.transform = transform

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
        """
        Get the item at the given index.

        Parameters
        ----------
        index : int
            The index of the item to get. Negative indices count from the
            end, as with a list.

        Returns
        -------
        Tuple[Any, torch.Tensor]
            A tuple containing the (optionally transformed) object and its
            label.

        Raises
        ------
        IndexError
            If ``index`` is outside ``[-len(self), len(self))``. The raised
            exception is also a ``ValueError`` for backward compatibility.
        """
        size = len(self)
        # Raising an IndexError subclass (rather than a plain ValueError) is
        # what lets ``for x, y in dataset`` terminate: the class defines no
        # ``__iter__``, so iteration relies on ``__getitem__`` signalling the
        # end of the data with IndexError.
        if not (-size <= index < size):
            raise _IndexRangeError("index range error")

        x = self.X[index]
        if self.transform is not None:
            x = self.transform(x)

        y = self.Y[index]

        return x, y
|