You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prediction_dataset.py 1.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. """
  2. Implementation of PyTorch dataset class used for Prediction.
  3. Copyright (c) 2024 LAMDA. All rights reserved.
  4. """
  5. from typing import Any, Callable, List, Tuple, Optional
  6. import torch
  7. from torch.utils.data import Dataset
  8. class PredictionDataset(Dataset):
  9. """
  10. Dataset used for prediction.
  11. Parameters
  12. ----------
  13. X : List[Any]
  14. The input data.
  15. transform : Callable[..., Any], optional
  16. A function/transform that takes an object and returns a transformed version.
  17. Defaults to None.
  18. """
  19. def __init__(self, X: List[Any], transform: Optional[Callable[..., Any]] = None):
  20. if not isinstance(X, list):
  21. raise ValueError("X should be of type list.")
  22. self.X = X
  23. self.transform = transform
  24. def __len__(self) -> int:
  25. """
  26. Return the length of the dataset.
  27. Returns
  28. -------
  29. int
  30. The length of the dataset.
  31. """
  32. return len(self.X)
  33. def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
  34. """
  35. Get the item at the given index.
  36. Parameters
  37. ----------
  38. index : int
  39. The index of the item to get.
  40. Returns
  41. -------
  42. Tuple[Any, torch.Tensor]
  43. A tuple containing the object and its label.
  44. """
  45. if index >= len(self):
  46. raise ValueError("index range error")
  47. x = self.X[index]
  48. if self.transform is not None:
  49. x = self.transform(x)
  50. return x

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.