|
- import logging
- from typing import Any, Callable, List, Optional
-
- import numpy
- import torch
- from torch.utils.data import DataLoader
-
- from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset
- from ablkit.utils.logger import print_log
-
-
class MultiLabelClassificationDataset(ClassificationDataset):
    """Classification dataset whose labels are kept as a single float tensor.

    The entries of ``Y`` are stacked along a new leading axis with
    ``numpy.stack`` and stored as a ``torch.FloatTensor``, so sample ``i`` is
    labelled by ``self.Y[i]``.

    NOTE(review): ``__len__`` / ``__getitem__`` are inherited from
    ``ClassificationDataset``, which is not visible in this file; the parent
    constructor is deliberately not called. Confirm the inherited accessors
    only rely on ``self.X``, ``self.Y`` and ``self.transform``.
    """

    def __init__(
        self,
        X: List[Any],
        Y: List[Any],
        transform: Optional[Callable[..., Any]] = None,
    ):
        """Store the samples, the stacked label tensor and the transform.

        Parameters
        ----------
        X : List[Any]
            Input samples.
        Y : List[Any]
            One label entry per sample (scalars or equally shaped arrays);
            all entries must be stackable by ``numpy.stack``.
        transform : Optional[Callable[..., Any]]
            Stored as ``self.transform``; not applied here.

        Raises
        ------
        ValueError
            If ``X`` or ``Y`` is not a list, or if their lengths differ.
            ``numpy.stack`` also raises ``ValueError`` when ``Y`` is empty or
            its entries have different shapes.
        """
        if (not isinstance(X, list)) or (not isinstance(Y, list)):
            raise ValueError("X and Y should be of type list.")
        # Fail early: a length mismatch would otherwise only surface later as
        # an out-of-range label lookup or as silently unused labels.
        if len(X) != len(Y):
            raise ValueError("X and Y should have equal length.")
        self.X = X
        self.Y = torch.FloatTensor(numpy.stack(Y, axis=0))  # float32 for BCELoss
        self.transform = transform
-
-
class BDDNN(BasicNN):
    """``BasicNN`` subclass that treats model outputs as multi-label logits.

    The tensor returned by ``self._predict`` is passed through an element-wise
    sigmoid, so every output column is handled as an independent binary label.
    Training loaders wrap their labels in ``MultiLabelClassificationDataset``.

    NOTE(review): ``batch_size``, ``num_workers``, ``collate_fn``,
    ``train_transform``, ``test_transform`` and ``_predict`` are read from
    ``self`` and are presumably provided by ``BasicNN`` — confirm against the
    base class.
    """

    def _prediction_loader(
        self,
        data_loader: Optional[DataLoader],
        X: Optional[List[Any]],
    ) -> DataLoader:
        """Return ``data_loader`` if given, else build an inference loader over ``X``.

        The built loader wraps ``X`` in a ``PredictionDataset`` using
        ``self.test_transform`` and does not shuffle.
        """
        if data_loader is not None:
            return data_loader
        dataset = PredictionDataset(X, self.test_transform)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            # Page-locked host memory only pays off when a GPU is present.
            pin_memory=torch.cuda.is_available(),
        )

    def predict(
        self,
        data_loader: Optional[DataLoader] = None,
        X: Optional[List[Any]] = None,
    ) -> numpy.ndarray:
        """Predict a 0/1 decision for every label of every sample.

        Parameters
        ----------
        data_loader : Optional[DataLoader]
            Loader to predict on. Takes precedence over ``X`` when both are
            given (a warning is logged in that case).
        X : Optional[List[Any]]
            Raw samples, used only when ``data_loader`` is None.

        Returns
        -------
        numpy.ndarray
            Integer array holding 1 where the sigmoid probability is strictly
            greater than 0.5 and 0 elsewhere.
        """
        if data_loader is not None and X is not None:
            print_log(
                "Predict the class of input data in data_loader instead of X.",
                logger="current",
                level=logging.WARNING,
            )

        data_loader = self._prediction_loader(data_loader, X)
        pred_probs = self._predict(data_loader).sigmoid()
        # Direct comparison + cast: same 0/1 int32 result as
        # torch.where(pred_probs > 0.5, 1, 0).int(), without relying on
        # scalar-argument support in torch.where.
        pred = (pred_probs > 0.5).int()
        return pred.cpu().numpy()

    def predict_proba(
        self,
        data_loader: Optional[DataLoader] = None,
        X: Optional[List[Any]] = None,
    ) -> numpy.ndarray:
        """Predict the per-label sigmoid probability of every sample.

        Parameters
        ----------
        data_loader : Optional[DataLoader]
            Loader to predict on. Takes precedence over ``X`` when both are
            given (a warning is logged in that case).
        X : Optional[List[Any]]
            Raw samples, used only when ``data_loader`` is None.

        Returns
        -------
        numpy.ndarray
            Sigmoid of the ``self._predict`` output, as a NumPy array.
        """
        if data_loader is not None and X is not None:
            print_log(
                "Predict the class probability of input data in data_loader instead of X.",
                logger="current",
                level=logging.WARNING,
            )

        data_loader = self._prediction_loader(data_loader, X)
        pred_probs = self._predict(data_loader).sigmoid()  # B x NC
        return pred_probs.cpu().numpy()

    def _data_loader(
        self,
        X: Optional[List[Any]],
        y: Optional[List[Any]] = None,
        shuffle: Optional[bool] = True,
    ) -> DataLoader:
        """Build a loader over ``(X, y)`` backed by ``MultiLabelClassificationDataset``.

        Parameters
        ----------
        X : Optional[List[Any]]
            Input samples; must not be None.
        y : Optional[List[Any]]
            One label entry per sample. When None, a placeholder label ``0``
            is used for every sample.
        shuffle : Optional[bool]
            Forwarded to ``DataLoader``. Defaults to True.

        Returns
        -------
        DataLoader
            Loader that applies ``self.train_transform`` through the dataset.

        Raises
        ------
        ValueError
            If ``X`` is None or ``X`` and ``y`` differ in length.
        """
        if X is None:
            raise ValueError("X should not be None.")
        if y is None:
            # Placeholder labels so the dataset can still be constructed.
            y = [0] * len(X)
        if len(y) != len(X):
            raise ValueError("X and y should have equal length.")

        dataset = MultiLabelClassificationDataset(X, y, transform=self.train_transform)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=torch.cuda.is_available(),
        )
|