You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prediction_dataset.py 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch.utils.data import Dataset
  4. class PredictionDataset(Dataset):
  5. def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
  6. """
  7. Initialize the dataset used for classification task.
  8. Parameters
  9. ----------
  10. X : List[Any]
  11. The input data.
  12. transform : Callable[..., Any], optional
  13. A function/transform that takes in an object and returns a transformed version. Defaults to None.
  14. """
  15. if not isinstance(X, list):
  16. raise ValueError("X should be of type list.")
  17. self.X = X
  18. self.transform = transform
  19. def __len__(self) -> int:
  20. """
  21. Return the length of the dataset.
  22. Returns
  23. -------
  24. int
  25. The length of the dataset.
  26. """
  27. return len(self.X)
  28. def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
  29. """
  30. Get the item at the given index.
  31. Parameters
  32. ----------
  33. index : int
  34. The index of the item to get.
  35. Returns
  36. -------
  37. Tuple[Any, torch.Tensor]
  38. A tuple containing the object and its label.
  39. """
  40. if index >= len(self):
  41. raise ValueError("index range error")
  42. x = self.X[index]
  43. if self.transform is not None:
  44. x = self.transform(x)
  45. return x

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.