You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

basic_nn.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559
  1. """
  2. This module contains the class BasicNN, which serves as a wrapper for PyTorch NN models.
  3. Copyright (c) 2024 LAMDA. All rights reserved.
  4. """
  5. from __future__ import annotations
  6. import logging
  7. import os
  8. from typing import Any, Callable, List, Optional, Tuple, Union
  9. import numpy
  10. import torch
  11. from torch.utils.data import DataLoader
  12. from ..utils.logger import print_log
  13. from .torch_dataset import ClassificationDataset, PredictionDataset
  14. class BasicNN:
  15. """
  16. Wrap NN models into the form of an sklearn estimator.
  17. Parameters
  18. ----------
  19. model : torch.nn.Module
  20. The PyTorch model to be trained or used for prediction.
  21. loss_fn : torch.nn.Module
  22. The loss function used for training.
  23. optimizer : torch.optim.Optimizer
  24. The optimizer used for training.
  25. scheduler : Callable[..., Any], optional
  26. The learning rate scheduler used for training, which will be called
  27. at the end of each run of the ``fit`` method. It should implement the
  28. ``step`` method. Defaults to None.
  29. device : Union[torch.device, str]
  30. The device on which the model will be trained or used for prediction,
  31. Defaults to torch.device("cpu").
  32. batch_size : int, optional
  33. The batch size used for training. Defaults to 32.
  34. num_epochs : int, optional
  35. The number of epochs used for training. Defaults to 1.
  36. stop_loss : float, optional
  37. The loss value at which to stop training. Defaults to 0.0001.
  38. num_workers : int
  39. The number of workers used for loading data. Defaults to 0.
  40. save_interval : int, optional
  41. The model will be saved every ``save_interval`` epoch during training. Defaults to None.
  42. save_dir : str, optional
  43. The directory in which to save the model during training. Defaults to None.
  44. train_transform : Callable[..., Any], optional
  45. A function/transform that takes an object and returns a transformed version used
  46. in the ``fit`` and ``train_epoch`` methods. Defaults to None.
  47. test_transform : Callable[..., Any], optional
  48. A function/transform that takes an object and returns a transformed version in the
  49. ``predict``, ``predict_proba`` and ``score`` methods. Defaults to None.
  50. collate_fn : Callable[[List[T]], Any], optional
  51. The function used to collate data. Defaults to None.
  52. """
  53. def __init__(
  54. self,
  55. model: torch.nn.Module,
  56. loss_fn: torch.nn.Module,
  57. optimizer: torch.optim.Optimizer,
  58. scheduler: Optional[Callable[..., Any]] = None,
  59. device: Union[torch.device, str] = torch.device("cpu"),
  60. batch_size: int = 32,
  61. num_epochs: int = 1,
  62. stop_loss: Optional[float] = 0.0001,
  63. num_workers: int = 0,
  64. save_interval: Optional[int] = None,
  65. save_dir: Optional[str] = None,
  66. train_transform: Optional[Callable[..., Any]] = None,
  67. test_transform: Optional[Callable[..., Any]] = None,
  68. collate_fn: Optional[Callable[[List[Any]], Any]] = None,
  69. ) -> None:
  70. if not isinstance(model, torch.nn.Module):
  71. raise TypeError("model must be an instance of torch.nn.Module")
  72. if not isinstance(loss_fn, torch.nn.Module):
  73. raise TypeError("loss_fn must be an instance of torch.nn.Module")
  74. if not isinstance(optimizer, torch.optim.Optimizer):
  75. raise TypeError("optimizer must be an instance of torch.optim.Optimizer")
  76. if scheduler is not None and not hasattr(scheduler, "step"):
  77. raise NotImplementedError("scheduler should implement the ``step`` method")
  78. if not isinstance(device, torch.device):
  79. if not isinstance(device, str):
  80. raise TypeError(
  81. "device must be an instance of torch.device or a str indicating "
  82. + "the target device"
  83. )
  84. else:
  85. device = torch.device(device)
  86. if not isinstance(batch_size, int):
  87. raise TypeError("batch_size must be an integer")
  88. if not isinstance(num_epochs, int):
  89. raise TypeError("num_epochs must be an integer")
  90. if stop_loss is not None and not isinstance(stop_loss, float):
  91. raise TypeError("stop_loss must be a float")
  92. if not isinstance(num_workers, int):
  93. raise TypeError("num_workers must be an integer")
  94. if save_interval is not None and not isinstance(save_interval, int):
  95. raise TypeError("save_interval must be an integer")
  96. if save_dir is not None and not isinstance(save_dir, str):
  97. raise TypeError("save_dir must be a string")
  98. if train_transform is not None and not callable(train_transform):
  99. raise TypeError("train_transform must be callable")
  100. if test_transform is not None and not callable(test_transform):
  101. raise TypeError("test_transform must be callable")
  102. if collate_fn is not None and not callable(collate_fn):
  103. raise TypeError("collate_fn must be callable")
  104. self.model = model.to(device)
  105. self.loss_fn = loss_fn
  106. self.optimizer = optimizer
  107. self.scheduler = scheduler
  108. self.device = device
  109. self.batch_size = batch_size
  110. self.num_epochs = num_epochs
  111. self.stop_loss = stop_loss
  112. self.num_workers = num_workers
  113. self.save_interval = save_interval
  114. self.save_dir = save_dir
  115. self.train_transform = train_transform
  116. self.test_transform = test_transform
  117. self.collate_fn = collate_fn
  118. if self.save_interval is not None and self.save_dir is None:
  119. raise ValueError("save_dir should not be None if save_interval is not None.")
  120. if self.train_transform is not None and self.test_transform is None:
  121. print_log(
  122. "Transform used in the training phase will be used in prediction.",
  123. logger="current",
  124. level=logging.WARNING,
  125. )
  126. self.test_transform = self.train_transform
  127. def _fit(self, data_loader: DataLoader) -> BasicNN:
  128. """
  129. Internal method to fit the model on data for ``self.num_epochs`` times,
  130. with early stopping.
  131. Parameters
  132. ----------
  133. data_loader : DataLoader
  134. Data loader providing training samples.
  135. Returns
  136. -------
  137. BasicNN
  138. The model itself after training.
  139. """
  140. if not isinstance(data_loader, DataLoader):
  141. raise TypeError(
  142. f"data_loader must be an instance of torch.utils.data.DataLoader, "
  143. f"but got {type(data_loader)}"
  144. )
  145. for epoch in range(self.num_epochs):
  146. loss_value = self.train_epoch(data_loader)
  147. if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
  148. self.save(epoch + 1)
  149. if self.stop_loss is not None and loss_value < self.stop_loss:
  150. break
  151. if self.scheduler is not None:
  152. self.scheduler.step()
  153. print_log(f"model loss: {loss_value:.5f}", logger="current")
  154. return self
  155. def fit(
  156. self,
  157. data_loader: Optional[DataLoader] = None,
  158. X: Optional[List[Any]] = None,
  159. y: Optional[List[int]] = None,
  160. ) -> BasicNN:
  161. """
  162. Train the model for self.num_epochs times or until the average loss on one epoch
  163. is less than self.stop_loss. It supports training with either a DataLoader
  164. object (data_loader) or a pair of input data (X) and target labels (y). If both
  165. data_loader and (X, y) are provided, the method will prioritize using the data_loader.
  166. Parameters
  167. ----------
  168. data_loader : DataLoader, optional
  169. The data loader used for training. Defaults to None.
  170. X : List[Any], optional
  171. The input data. Defaults to None.
  172. y : List[int], optional
  173. The target data. Defaults to None.
  174. Returns
  175. -------
  176. BasicNN
  177. The model itself after training.
  178. """
  179. if data_loader is not None and X is not None:
  180. print_log(
  181. "data_loader will be used to train the model instead of X and y.",
  182. logger="current",
  183. level=logging.WARNING,
  184. )
  185. if data_loader is None:
  186. if X is None:
  187. raise ValueError("data_loader and X can not be None simultaneously.")
  188. else:
  189. data_loader = self._data_loader(X, y)
  190. return self._fit(data_loader)
  191. def train_epoch(self, data_loader: DataLoader) -> float:
  192. """
  193. Train the model with an instance of DataLoader (data_loader) for one epoch.
  194. Parameters
  195. ----------
  196. data_loader : DataLoader
  197. The data loader used for training.
  198. Returns
  199. -------
  200. float
  201. The average loss on one epoch.
  202. """
  203. model = self.model
  204. loss_fn = self.loss_fn
  205. optimizer = self.optimizer
  206. device = self.device
  207. model.train()
  208. total_loss, total_num = 0.0, 0
  209. for data, target in data_loader:
  210. data, target = data.to(device), target.to(device)
  211. out = model(data)
  212. loss = loss_fn(out, target)
  213. optimizer.zero_grad()
  214. loss.backward()
  215. optimizer.step()
  216. total_loss += loss.item() * data.size(0)
  217. total_num += data.size(0)
  218. return total_loss / total_num
  219. def _predict(self, data_loader: DataLoader) -> torch.Tensor:
  220. """
  221. Internal method to predict the outputs given a DataLoader.
  222. Parameters
  223. ----------
  224. data_loader : DataLoader
  225. The DataLoader providing input samples.
  226. Returns
  227. -------
  228. torch.Tensor
  229. Raw output from the model.
  230. """
  231. if not isinstance(data_loader, DataLoader):
  232. raise TypeError(
  233. f"data_loader must be an instance of torch.utils.data.DataLoader, "
  234. f"but got {type(data_loader)}"
  235. )
  236. model = self.model
  237. device = self.device
  238. model.eval()
  239. with torch.no_grad():
  240. results = []
  241. for data in data_loader:
  242. data = data.to(device)
  243. out = model(data)
  244. results.append(out)
  245. return torch.cat(results, axis=0)
  246. def predict(
  247. self,
  248. data_loader: Optional[DataLoader] = None,
  249. X: Optional[List[Any]] = None,
  250. ) -> numpy.ndarray:
  251. """
  252. Predict the class of the input data. This method supports prediction with either
  253. a DataLoader object (data_loader) or a list of input data (X). If both data_loader
  254. and X are provided, the method will predict the input data in data_loader
  255. instead of X.
  256. Parameters
  257. ----------
  258. data_loader : DataLoader, optional
  259. The data loader used for prediction. Defaults to None.
  260. X : List[Any], optional
  261. The input data. Defaults to None.
  262. Returns
  263. -------
  264. numpy.ndarray
  265. The predicted class of the input data.
  266. """
  267. if data_loader is not None and X is not None:
  268. print_log(
  269. "Predict the class of input data in data_loader instead of X.",
  270. logger="current",
  271. level=logging.WARNING,
  272. )
  273. if data_loader is None:
  274. dataset = PredictionDataset(X, self.test_transform)
  275. data_loader = DataLoader(
  276. dataset,
  277. batch_size=self.batch_size,
  278. num_workers=self.num_workers,
  279. collate_fn=self.collate_fn,
  280. pin_memory=torch.cuda.is_available(),
  281. )
  282. return self._predict(data_loader).argmax(axis=1).cpu().numpy()
  283. def predict_proba(
  284. self,
  285. data_loader: Optional[DataLoader] = None,
  286. X: Optional[List[Any]] = None,
  287. ) -> numpy.ndarray:
  288. """
  289. Predict the probability of each class for the input data. This method supports
  290. prediction with either a DataLoader object (data_loader) or a list of input data (X).
  291. If both data_loader and X are provided, the method will predict the input data in
  292. data_loader instead of X.
  293. Parameters
  294. ----------
  295. data_loader : DataLoader, optional
  296. The data loader used for prediction. Defaults to None.
  297. X : List[Any], optional
  298. The input data. Defaults to None.
  299. Warning
  300. -------
  301. This method calculates the probability by applying a softmax function to the output
  302. of the neural network. If your neural network already includes a softmax function
  303. as its final activation, applying softmax again here will lead to incorrect probabilities.
  304. Returns
  305. -------
  306. numpy.ndarray
  307. The predicted probability of each class for the input data.
  308. """
  309. if data_loader is not None and X is not None:
  310. print_log(
  311. "Predict the class probability of input data in data_loader instead of X.",
  312. logger="current",
  313. level=logging.WARNING,
  314. )
  315. if data_loader is None:
  316. dataset = PredictionDataset(X, self.test_transform)
  317. data_loader = DataLoader(
  318. dataset,
  319. batch_size=self.batch_size,
  320. num_workers=self.num_workers,
  321. collate_fn=self.collate_fn,
  322. pin_memory=torch.cuda.is_available(),
  323. )
  324. return self._predict(data_loader).softmax(axis=1).cpu().numpy()
  325. def _score(self, data_loader: DataLoader) -> Tuple[float, float]:
  326. """
  327. Internal method to compute loss and accuracy for the data provided through a DataLoader.
  328. Parameters
  329. ----------
  330. data_loader : DataLoader
  331. Data loader to use for evaluation.
  332. Returns
  333. -------
  334. Tuple[float, float]
  335. mean_loss: float, The mean loss of the model on the provided data.
  336. accuracy: float, The accuracy of the model on the provided data.
  337. """
  338. if not isinstance(data_loader, DataLoader):
  339. raise TypeError(
  340. f"data_loader must be an instance of torch.utils.data.DataLoader, "
  341. f"but got {type(data_loader)}"
  342. )
  343. model = self.model
  344. loss_fn = self.loss_fn
  345. device = self.device
  346. model.eval()
  347. total_correct_num, total_num, total_loss = 0, 0, 0.0
  348. with torch.no_grad():
  349. for data, target in data_loader:
  350. data, target = data.to(device), target.to(device)
  351. out = model(data)
  352. if len(out.shape) > 1:
  353. correct_num = (target == out.argmax(axis=1)).sum().item()
  354. else:
  355. correct_num = (target == (out > 0.5)).sum().item()
  356. loss = loss_fn(out, target)
  357. total_loss += loss.item() * data.size(0)
  358. total_correct_num += correct_num
  359. total_num += data.size(0)
  360. mean_loss = total_loss / total_num
  361. accuracy = total_correct_num / total_num
  362. return mean_loss, accuracy
  363. def score(
  364. self,
  365. data_loader: Optional[DataLoader] = None,
  366. X: Optional[List[Any]] = None,
  367. y: Optional[List[int]] = None,
  368. ) -> float:
  369. """
  370. Validate the model. It supports validation with either a DataLoader object (data_loader)
  371. or a pair of input data (X) and ground truth labels (y). If both data_loader and
  372. (X, y) are provided, the method will prioritize using the data_loader.
  373. Parameters
  374. ----------
  375. data_loader : DataLoader, optional
  376. The data loader used for scoring. Defaults to None.
  377. X : List[Any], optional
  378. The input data. Defaults to None.
  379. y : List[int], optional
  380. The target data. Defaults to None.
  381. Returns
  382. -------
  383. float
  384. The accuracy of the model.
  385. """
  386. print_log("Start machine learning model validation", logger="current")
  387. if data_loader is not None and X is not None:
  388. print_log(
  389. "data_loader will be used to validate the model instead of X and y.",
  390. logger="current",
  391. level=logging.WARNING,
  392. )
  393. if data_loader is None:
  394. if X is None or y is None:
  395. raise ValueError("data_loader and (X, y) can not be None simultaneously.")
  396. else:
  397. data_loader = self._data_loader(X, y)
  398. mean_loss, accuracy = self._score(data_loader)
  399. print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current")
  400. return accuracy
  401. def _data_loader(
  402. self,
  403. X: Optional[List[Any]],
  404. y: Optional[List[int]] = None,
  405. shuffle: Optional[bool] = True,
  406. ) -> DataLoader:
  407. """
  408. Generate a DataLoader for user-provided input data and target labels.
  409. Parameters
  410. ----------
  411. X : List[Any]
  412. Input samples.
  413. y : List[int], optional
  414. Target labels. If None, dummy labels are created. Defaults to None.
  415. shuffle : bool, optional
  416. Whether to shuffle the data. Defaults to True.
  417. Returns
  418. -------
  419. DataLoader
  420. A DataLoader providing batches of (X, y) pairs.
  421. """
  422. if X is None:
  423. raise ValueError("X should not be None.")
  424. if y is None:
  425. y = [0] * len(X)
  426. if not len(y) == len(X):
  427. raise ValueError("X and y should have equal length.")
  428. dataset = ClassificationDataset(X, y, transform=self.train_transform)
  429. data_loader = DataLoader(
  430. dataset,
  431. batch_size=self.batch_size,
  432. shuffle=shuffle,
  433. num_workers=self.num_workers,
  434. collate_fn=self.collate_fn,
  435. pin_memory=torch.cuda.is_available(),
  436. )
  437. return data_loader
  438. def save(self, epoch_id: int = 0, save_path: Optional[str] = None) -> None:
  439. """
  440. Save the model and the optimizer. User can either provide a save_path or specify
  441. the epoch_id at which the model and optimizer is saved. if both save_path and
  442. epoch_id are provided, save_path will be used. If only epoch_id is specified,
  443. model and optimizer will be saved to the path f"model_checkpoint_epoch_{epoch_id}.pth"
  444. under ``self.save_dir``. save_path and epoch_id can not be None simultaneously.
  445. Parameters
  446. ----------
  447. epoch_id : int
  448. The epoch id.
  449. save_path : str, optional
  450. The path to save the model. Defaults to None.
  451. """
  452. if self.save_dir is None and save_path is None:
  453. raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.")
  454. if save_path is not None:
  455. if not os.path.exists(os.path.dirname(save_path)):
  456. os.makedirs(os.path.dirname(save_path))
  457. else:
  458. save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth")
  459. if not os.path.exists(self.save_dir):
  460. os.makedirs(self.save_dir)
  461. print_log(f"Checkpoints will be saved to {save_path}", logger="current")
  462. save_parma_dic = {
  463. "model": self.model.state_dict(),
  464. "optimizer": self.optimizer.state_dict(),
  465. }
  466. torch.save(save_parma_dic, save_path)
  467. def load(self, load_path: str) -> None:
  468. """
  469. Load the model and the optimizer.
  470. Parameters
  471. ----------
  472. load_path : str
  473. The directory to load the model. Defaults to "".
  474. """
  475. if load_path is None:
  476. raise ValueError("Load path should not be None.")
  477. print_log(
  478. f"Loads checkpoint by local backend from path: {load_path}",
  479. logger="current",
  480. )
  481. param_dic = torch.load(load_path)
  482. self.model.load_state_dict(param_dic["model"])
  483. if "optimizer" in param_dic.keys():
  484. self.optimizer.load_state_dict(param_dic["optimizer"])

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.