You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

regression_dataset.py 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. """
  2. Implementation of PyTorch dataset class used for regression.
  3. Copyright (c) 2024 LAMDA. All rights reserved.
  4. """
  5. from typing import Any, List, Tuple
  6. from torch.utils.data import Dataset
  7. class RegressionDataset(Dataset):
  8. """
  9. Dataset used for regression task.
  10. Parameters
  11. ----------
  12. X : List[Any]
  13. A list of objects representing the input data.
  14. Y : List[Any]
  15. A list of objects representing the output data.
  16. """
  17. def __init__(self, X: List[Any], Y: List[Any]):
  18. if (not isinstance(X, list)) or (not isinstance(Y, list)):
  19. raise ValueError("X and Y should be of type list.")
  20. if len(X) != len(Y):
  21. raise ValueError("Length of X and Y must be equal.")
  22. self.X = X
  23. self.Y = Y
  24. def __len__(self):
  25. """Return the length of the dataset.
  26. Returns
  27. -------
  28. int
  29. The length of the dataset.
  30. """
  31. return len(self.X)
  32. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  33. """Get an item from the dataset.
  34. Parameters
  35. ----------
  36. index : int
  37. The index of the item to retrieve.
  38. Returns
  39. -------
  40. Tuple[Any, Any]
  41. A tuple containing the input and output data at the specified index.
  42. """
  43. if index >= len(self):
  44. raise ValueError("index range error")
  45. x = self.X[index]
  46. y = self.Y[index]
  47. return x, y

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.