You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prediction_dataset.py 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from typing import Any, Callable, List, Tuple, Optional
  2. import torch
  3. from torch.utils.data import Dataset
  4. class PredictionDataset(Dataset):
  5. """
  6. Dataset used for prediction.
  7. Parameters
  8. ----------
  9. X : List[Any]
  10. The input data.
  11. transform : Callable[..., Any], optional
  12. A function/transform that takes an object and returns a transformed version.
  13. Defaults to None.
  14. """
  15. def __init__(self, X: List[Any], transform: Optional[Callable[..., Any]] = None):
  16. if not isinstance(X, list):
  17. raise ValueError("X should be of type list.")
  18. self.X = X
  19. self.transform = transform
  20. def __len__(self) -> int:
  21. """
  22. Return the length of the dataset.
  23. Returns
  24. -------
  25. int
  26. The length of the dataset.
  27. """
  28. return len(self.X)
  29. def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
  30. """
  31. Get the item at the given index.
  32. Parameters
  33. ----------
  34. index : int
  35. The index of the item to get.
  36. Returns
  37. -------
  38. Tuple[Any, torch.Tensor]
  39. A tuple containing the object and its label.
  40. """
  41. if index >= len(self):
  42. raise ValueError("index range error")
  43. x = self.X[index]
  44. if self.transform is not None:
  45. x = self.transform(x)
  46. return x

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.