You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

classification_dataset.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
"""
Implementation of PyTorch dataset class used for classification.

Copyright (c) 2024 LAMDA. All rights reserved.
"""
from typing import Any, Callable, Iterator, List, Optional, Tuple

import torch
from torch.utils.data import Dataset
  8. class ClassificationDataset(Dataset):
  9. """
  10. Dataset used for classification task.
  11. Parameters
  12. ----------
  13. X : List[Any]
  14. The input data.
  15. Y : List[int]
  16. The target data.
  17. transform : Callable[..., Any], optional
  18. A function/transform that takes an object and returns a transformed version.
  19. Defaults to None.
  20. """
  21. def __init__(self, X: List[Any], Y: List[int], transform: Optional[Callable[..., Any]] = None):
  22. if (not isinstance(X, list)) or (not isinstance(Y, list)):
  23. raise ValueError("X and Y should be of type list.")
  24. if len(X) != len(Y):
  25. raise ValueError("Length of X and Y must be equal.")
  26. self.X = X
  27. self.Y = torch.LongTensor(Y)
  28. self.transform = transform
  29. def __len__(self) -> int:
  30. """
  31. Return the length of the dataset.
  32. Returns
  33. -------
  34. int
  35. The length of the dataset.
  36. """
  37. return len(self.X)
  38. def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
  39. """
  40. Get the item at the given index.
  41. Parameters
  42. ----------
  43. index : int
  44. The index of the item to get.
  45. Returns
  46. -------
  47. Tuple[Any, torch.Tensor]
  48. A tuple containing the object and its label.
  49. """
  50. if index >= len(self):
  51. raise ValueError("index range error")
  52. x = self.X[index]
  53. if self.transform is not None:
  54. x = self.transform(x)
  55. y = self.Y[index]
  56. return x, y

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.