
basic_model.py 17 kB

# coding: utf-8
# ================================================================#
# Copyright (C) 2020 Freecss All rights reserved.
#
# File Name    : basic_model.py
# Author       : freecss
# Email        : karlfreecss@gmail.com
# Created Date : 2020/11/21
# Description  :
#
# ================================================================#
import sys
sys.path.append("..")
import torch
import numpy
from torch.utils.data import Dataset, DataLoader
import os
from multiprocessing import Pool
from typing import List, Any, Tuple, Optional, Callable, TypeVar

# typing.T is not a public name; define a generic type variable instead.
T = TypeVar("T")


class BasicDataset(Dataset):
    def __init__(self, X: List[Any], Y: List[Any]):
        """Initialize a basic dataset.

        Parameters
        ----------
        X : List[Any]
            A list of objects representing the input data.
        Y : List[Any]
            A list of objects representing the output data.
        """
        self.X = X
        self.Y = Y

    def __len__(self):
        """Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Get an item from the dataset.

        Parameters
        ----------
        index : int
            The index of the item to retrieve.

        Returns
        -------
        Tuple[Any, Any]
            A tuple containing the input and output data at the specified index.
        """
        assert index < len(self), "index range error"
        img = self.X[index]
        label = self.Y[index]
        return (img, label)


class XYDataset(Dataset):
    def __init__(self, X: List[Any], Y: List[int], transform: Optional[Callable[..., Any]] = None):
        """
        Initialize the dataset used for classification tasks.

        Parameters
        ----------
        X : List[Any]
            The input data.
        Y : List[int]
            The target data.
        transform : Callable[..., Any], optional
            A function/transform that takes in an object and returns a transformed version. Defaults to None.
        """
        self.X = X
        self.Y = torch.LongTensor(Y)
        self.n_sample = len(X)
        self.transform = transform

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
        """
        Get the item at the given index.

        Parameters
        ----------
        index : int
            The index of the item to get.

        Returns
        -------
        Tuple[Any, torch.Tensor]
            A tuple containing the object and its label.
        """
        assert index < len(self), "index range error"
        img = self.X[index]
        if self.transform is not None:
            img = self.transform(img)
        label = self.Y[index]
        return (img, label)


class FakeRecorder:
    def __init__(self):
        pass

    def print(self, *x):
        pass


class BasicModel:
    """
    Wrap NN models into the form of an sklearn estimator.

    Parameters
    ----------
    model : torch.nn.Module
        The PyTorch model to be trained or used for prediction.
    criterion : torch.nn.Module
        The loss function used for training.
    optimizer : torch.optim.Optimizer
        The optimizer used for training.
    device : torch.device
        The device on which the model will be trained or used for prediction.
    batch_size : int, optional
        The batch size used for training, by default 1.
    num_epochs : int, optional
        The number of epochs used for training, by default 1.
    stop_loss : Optional[float], optional
        The loss value at which to stop training, by default 0.01.
    num_workers : int, optional
        The number of workers used for loading data, by default 0.
    save_interval : Optional[int], optional
        The interval at which to save the model during training, by default None.
    save_dir : Optional[str], optional
        The directory in which to save the model during training, by default None.
    transform : Callable[..., Any], optional
        The transformation function used for data augmentation, by default None.
    collate_fn : Callable[[List[T]], Any], optional
        The function used to collate data, by default None.
    recorder : Any, optional
        The recorder used to record training progress, by default None.

    Attributes
    ----------
    model : torch.nn.Module
        The PyTorch model to be trained or used for prediction.
    batch_size : int
        The batch size used for training.
    num_epochs : int
        The number of epochs used for training.
    stop_loss : Optional[float]
        The loss value at which to stop training.
    num_workers : int
        The number of workers used for loading data.
    criterion : torch.nn.Module
        The loss function used for training.
    optimizer : torch.optim.Optimizer
        The optimizer used for training.
    transform : Callable[..., Any]
        The transformation function used for data augmentation.
    device : torch.device
        The device on which the model will be trained or used for prediction.
    recorder : Any
        The recorder used to record training progress.
    save_interval : Optional[int]
        The interval at which to save the model during training.
    save_dir : Optional[str]
        The directory in which to save the model during training.
    collate_fn : Callable[[List[T]], Any]
        The function used to collate data.

    Methods
    -------
    fit(data_loader=None, X=None, y=None)
        Train the model.
    train_epoch(data_loader)
        Train the model for one epoch.
    predict(data_loader=None, X=None, print_prefix="")
        Predict the class of the input data.
    predict_proba(data_loader=None, X=None, print_prefix="")
        Predict the probability of each class for the input data.
    val(data_loader=None, X=None, y=None, print_prefix="")
        Validate the model.
    score(data_loader=None, X=None, y=None, print_prefix="")
        Score the model.
    _data_loader(X, y=None)
        Load data.
    save(epoch_id, save_dir="")
        Save the model.
    load(epoch_id, load_dir="")
        Load the model.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        batch_size: int = 1,
        num_epochs: int = 1,
        stop_loss: Optional[float] = 0.01,
        num_workers: int = 0,
        save_interval: Optional[int] = None,
        save_dir: Optional[str] = None,
        transform: Optional[Callable[..., Any]] = None,
        collate_fn: Optional[Callable[[List[T]], Any]] = None,
        recorder=None,
    ):
        self.model = model.to(device)
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.stop_loss = stop_loss
        self.num_workers = num_workers
        self.criterion = criterion
        self.optimizer = optimizer
        self.transform = transform
        self.device = device
        if recorder is None:
            recorder = FakeRecorder()
        self.recorder = recorder
        self.save_interval = save_interval
        self.save_dir = save_dir
        self.collate_fn = collate_fn

    def _fit(self, data_loader, n_epoch, stop_loss):
        recorder = self.recorder
        recorder.print("model fitting")
        min_loss = 1e10
        for epoch in range(n_epoch):
            loss_value = self.train_epoch(data_loader)
            recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
            if loss_value < min_loss:
                min_loss = loss_value
            if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
                assert self.save_dir is not None
                self.save(epoch + 1, self.save_dir)
            if stop_loss is not None and loss_value < stop_loss:
                break
        recorder.print("Model fitted, minimal loss is ", min_loss)
        return loss_value

    def fit(
        self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
    ) -> float:
        """
        Train the model.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for training, by default None.
        X : List[Any], optional
            The input data, by default None.
        y : List[int], optional
            The target data, by default None.

        Returns
        -------
        float
            The loss value of the trained model.
        """
        if data_loader is None:
            data_loader = self._data_loader(X, y)
        return self._fit(data_loader, self.num_epochs, self.stop_loss)

    def train_epoch(self, data_loader: DataLoader):
        """
        Train the model for one epoch.

        Parameters
        ----------
        data_loader : DataLoader
            The data loader used for training.

        Returns
        -------
        float
            The loss value of the trained model.
        """
        model = self.model
        criterion = self.criterion
        optimizer = self.optimizer
        device = self.device
        model.train()
        total_loss, total_num = 0.0, 0
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            out = model(data)
            loss = criterion(out, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * data.size(0)
            total_num += data.size(0)
        return total_loss / total_num

    def _predict(self, data_loader):
        model = self.model
        device = self.device
        model.eval()
        with torch.no_grad():
            results = []
            for data, _ in data_loader:
                data = data.to(device)
                out = model(data)
                results.append(out)
        return torch.cat(results, dim=0)

    def predict(
        self,
        data_loader: DataLoader = None,
        X: List[Any] = None,
        print_prefix: str = "",
    ) -> numpy.ndarray:
        """
        Predict the class of the input data.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for prediction, by default None.
        X : List[Any], optional
            The input data, by default None.
        print_prefix : str, optional
            The prefix used for printing, by default "".

        Returns
        -------
        numpy.ndarray
            The predicted class of the input data.
        """
        recorder = self.recorder
        recorder.print("Start Predict Class ", print_prefix)
        if data_loader is None:
            data_loader = self._data_loader(X)
        return self._predict(data_loader).argmax(dim=1).cpu().numpy()

    def predict_proba(
        self,
        data_loader: DataLoader = None,
        X: List[Any] = None,
        print_prefix: str = "",
    ) -> numpy.ndarray:
        """
        Predict the probability of each class for the input data.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for prediction, by default None.
        X : List[Any], optional
            The input data, by default None.
        print_prefix : str, optional
            The prefix used for printing, by default "".

        Returns
        -------
        numpy.ndarray
            The predicted probability of each class for the input data.
        """
        recorder = self.recorder
        recorder.print("Start Predict Probability ", print_prefix)
        if data_loader is None:
            data_loader = self._data_loader(X)
        return self._predict(data_loader).softmax(dim=1).cpu().numpy()

    def _val(self, data_loader):
        model = self.model
        criterion = self.criterion
        device = self.device
        model.eval()
        total_correct_num, total_num, total_loss = 0, 0, 0.0
        with torch.no_grad():
            for data, target in data_loader:
                data, target = data.to(device), target.to(device)
                out = model(data)
                if len(out.shape) > 1:
                    correct_num = sum(target == out.argmax(dim=1)).item()
                else:
                    correct_num = sum(target == (out > 0.5)).item()
                loss = criterion(out, target)
                total_loss += loss.item() * data.size(0)
                total_correct_num += correct_num
                total_num += data.size(0)
        mean_loss = total_loss / total_num
        accuracy = total_correct_num / total_num
        return mean_loss, accuracy

    def val(
        self,
        data_loader: DataLoader = None,
        X: List[Any] = None,
        y: List[int] = None,
        print_prefix: str = "",
    ) -> float:
        """
        Validate the model.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for validation, by default None.
        X : List[Any], optional
            The input data, by default None.
        y : List[int], optional
            The target data, by default None.
        print_prefix : str, optional
            The prefix used for printing, by default "".

        Returns
        -------
        float
            The accuracy of the model.
        """
        recorder = self.recorder
        recorder.print("Start val ", print_prefix)
        if data_loader is None:
            data_loader = self._data_loader(X, y)
        mean_loss, accuracy = self._val(data_loader)
        recorder.print(
            "[%s] Val loss: %f, accuracy: %f" % (print_prefix, mean_loss, accuracy)
        )
        return accuracy

    def score(
        self,
        data_loader: DataLoader = None,
        X: List[Any] = None,
        y: List[int] = None,
        print_prefix: str = "",
    ) -> float:
        """
        Score the model.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for scoring, by default None.
        X : List[Any], optional
            The input data, by default None.
        y : List[int], optional
            The target data, by default None.
        print_prefix : str, optional
            The prefix used for printing, by default "".

        Returns
        -------
        float
            The accuracy of the model.
        """
        return self.val(data_loader, X, y, print_prefix)

    def _data_loader(
        self,
        X: List[Any],
        y: List[int] = None,
    ) -> DataLoader:
        """
        Generate a data_loader for user-provided data.

        Parameters
        ----------
        X : List[Any]
            The input data.
        y : List[int], optional
            The target data, by default None.

        Returns
        -------
        DataLoader
            The data loader.
        """
        collate_fn = self.collate_fn
        transform = self.transform
        if y is None:
            y = [0] * len(X)
        dataset = XYDataset(X, y, transform=transform)
        sampler = None
        data_loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            sampler=sampler,
            num_workers=int(self.num_workers),
            collate_fn=collate_fn,
        )
        return data_loader
  455. def save(self, epoch_id: int, save_dir: str = ""):
  456. """
  457. Save the model and the optimizer.
  458. Parameters
  459. ----------
  460. epoch_id : int
  461. The epoch id.
  462. save_dir : str, optional
  463. The directory to save the model, by default ""
  464. """
  465. recorder = self.recorder
  466. if not os.path.exists(save_dir):
  467. os.makedirs(save_dir)
  468. recorder.print("Saving model and opter")
  469. save_path = os.path.join(save_dir, str(epoch_id) + "_net.pth")
  470. torch.save(self.model.state_dict(), save_path)
  471. save_path = os.path.join(save_dir, str(epoch_id) + "_opt.pth")
  472. torch.save(self.optimizer.state_dict(), save_path)
  473. def load(self, epoch_id: int, load_dir: str = ""):
  474. """
  475. Load the model and the optimizer.
  476. Parameters
  477. ----------
  478. epoch_id : int
  479. The epoch id.
  480. load_dir : str, optional
  481. The directory to load the model, by default ""
  482. """
  483. recorder = self.recorder
  484. recorder.print("Loading model and opter")
  485. load_path = os.path.join(load_dir, str(epoch_id) + "_net.pth")
  486. self.model.load_state_dict(torch.load(load_path))
  487. load_path = os.path.join(load_dir, str(epoch_id) + "_opt.pth")
  488. self.optimizer.load_state_dict(torch.load(load_path))
  489. if __name__ == "__main__":
  490. pass
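
As a rough illustration of how the estimator-style wrapper above is meant to be used, here is a minimal usage sketch (not part of the repository). The SmallNet module, the toy tensors, and the import path (basic_model on the Python path) are assumptions made only for this example; swap in your own network, data, and layout.

# Minimal usage sketch for BasicModel (hypothetical example, not from the repo).
import torch
from torch import nn

from basic_model import BasicModel  # assumes basic_model.py is importable


class SmallNet(nn.Module):
    """A tiny MLP classifier used only for this illustration."""

    def __init__(self, in_dim=10, n_classes=3):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, 32), nn.ReLU(), nn.Linear(32, n_classes)
        )

    def forward(self, x):
        return self.layers(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = SmallNet()
model = BasicModel(
    model=net,
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(net.parameters(), lr=1e-3),
    device=device,
    batch_size=16,
    num_epochs=5,
)

# Toy data: 100 random 10-dim feature vectors with random labels in {0, 1, 2}.
X = [torch.randn(10) for _ in range(100)]
y = torch.randint(0, 3, (100,)).tolist()

loss = model.fit(X=X, y=y)        # builds a DataLoader internally via _data_loader
preds = model.predict(X=X)        # predicted class indices as a numpy array
probs = model.predict_proba(X=X)  # per-class probabilities
acc = model.score(X=X, y=y)       # accuracy on the provided data

Because _data_loader wraps the lists in XYDataset with shuffle=False, the same call pattern serves both training and inference; pass a prebuilt DataLoader to fit or predict instead if you need shuffling, a custom sampler, or a non-default collate function.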

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.