You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

bdd_nn.py 3.3 kB

3 months ago
3 months ago
3 months ago
3 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import logging
  2. from typing import Any, Callable, List, Optional
  3. import numpy
  4. import torch
  5. from torch.utils.data import DataLoader
  6. from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset
  7. from ablkit.utils.logger import print_log
  8. class MultiLabelClassificationDataset(ClassificationDataset):
  9. def __init__(self, X: List[Any], Y: List[int], transform: Optional[Callable[..., Any]] = None):
  10. if (not isinstance(X, list)) or (not isinstance(Y, list)):
  11. raise ValueError("X and Y should be of type list.")
  12. self.X = X
  13. self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss
  14. self.transform = transform
  15. class BDDNN(BasicNN):
  16. def predict(
  17. self,
  18. data_loader: Optional[DataLoader] = None,
  19. X: Optional[List[Any]] = None,
  20. ) -> numpy.ndarray:
  21. if data_loader is not None and X is not None:
  22. print_log(
  23. "Predict the class of input data in data_loader instead of X.",
  24. logger="current",
  25. level=logging.WARNING,
  26. )
  27. if data_loader is None:
  28. dataset = PredictionDataset(X, self.test_transform)
  29. data_loader = DataLoader(
  30. dataset,
  31. batch_size=self.batch_size,
  32. num_workers=self.num_workers,
  33. collate_fn=self.collate_fn,
  34. pin_memory=torch.cuda.is_available(),
  35. )
  36. pred_probs = self._predict(data_loader).sigmoid()
  37. pred = torch.where(pred_probs > 0.5, 1, 0).int()
  38. return pred.cpu().numpy()
  39. def predict_proba(
  40. self,
  41. data_loader: Optional[DataLoader] = None,
  42. X: Optional[List[Any]] = None,
  43. ) -> numpy.ndarray:
  44. if data_loader is not None and X is not None:
  45. print_log(
  46. "Predict the class probability of input data in data_loader instead of X.",
  47. logger="current",
  48. level=logging.WARNING,
  49. )
  50. if data_loader is None:
  51. dataset = PredictionDataset(X, self.test_transform)
  52. data_loader = DataLoader(
  53. dataset,
  54. batch_size=self.batch_size,
  55. num_workers=self.num_workers,
  56. collate_fn=self.collate_fn,
  57. pin_memory=torch.cuda.is_available(),
  58. )
  59. pred_probs = self._predict(data_loader).sigmoid() # B x NC
  60. return pred_probs.cpu().numpy()
  61. def _data_loader(
  62. self,
  63. X: Optional[List[Any]],
  64. y: Optional[List[int]] = None,
  65. shuffle: Optional[bool] = True,
  66. ) -> DataLoader:
  67. if X is None:
  68. raise ValueError("X should not be None.")
  69. if y is None:
  70. y = [0] * len(X)
  71. if not len(y) == len(X):
  72. raise ValueError("X and y should have equal length.")
  73. dataset = MultiLabelClassificationDataset(X, y, transform=self.train_transform)
  74. data_loader = DataLoader(
  75. dataset,
  76. batch_size=self.batch_size,
  77. shuffle=shuffle,
  78. num_workers=self.num_workers,
  79. collate_fn=self.collate_fn,
  80. pin_memory=torch.cuda.is_available(),
  81. )
  82. return data_loader

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.