You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_converter.py 9.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. import torch
  2. import copy
  3. from typing import Any, Callable, List, Optional
  4. from .abl_model import ABLModel
  5. from .basic_nn import BasicNN
  6. from lambdaLearn.Base.DeepModelMixin import DeepModelMixin
  7. class ModelConverter:
  8. """
  9. This class provides functionality to convert LambdaLearn models to ABLkit models.
  10. """
  11. def __init__(self) -> None:
  12. pass
  13. def convert_lambdalearn_to_ablmodel(
  14. self,
  15. lambdalearn_model,
  16. loss_fn: torch.nn.Module,
  17. optimizer_dict: dict,
  18. scheduler_dict: Optional[dict] = None,
  19. device: Optional[torch.device] = None,
  20. batch_size: int = 32,
  21. num_epochs: int = 1,
  22. stop_loss: Optional[float] = 0.0001,
  23. num_workers: int = 0,
  24. save_interval: Optional[int] = None,
  25. save_dir: Optional[str] = None,
  26. train_transform: Callable[..., Any] = None,
  27. test_transform: Callable[..., Any] = None,
  28. collate_fn: Callable[[List[Any]], Any] = None,
  29. ):
  30. """
  31. Convert a lambdalearn model to an ABLModel. If the lambdalearn model is an instance of
  32. DeepModelMixin, its network will be used as the model of BasicNN. Otherwise, the lambdalearn
  33. model should implement ``fit`` and ``predict`` methods.
  34. Parameters
  35. ----------
  36. lambdalearn_model : Union[DeepModelMixin, Any]
  37. The LambdaLearn model to be converted.
  38. loss_fn : torch.nn.Module
  39. The loss function used for training.
  40. optimizer_dict : dict
  41. The dict contains necessary parameters to construct a optimizer used for training.
  42. The optimizer class is specified by the ``optimizer`` key.
  43. scheduler_dict : dict, optional
  44. The dict contains necessary parameters to construct a learning rate scheduler used
  45. for training, which will be called at the end of each run of the ``fit`` method.
  46. The scheduler class is specified by the ``scheduler`` key. It should implement the
  47. ``step`` method. Defaults to None.
  48. device : torch.device, optional
  49. The device on which the model will be trained or used for prediction,
  50. Defaults to torch.device("cpu").
  51. batch_size : int, optional
  52. The batch size used for training. Defaults to 32.
  53. num_epochs : int, optional
  54. The number of epochs used for training. Defaults to 1.
  55. stop_loss : float, optional
  56. The loss value at which to stop training. Defaults to 0.0001.
  57. num_workers : int
  58. The number of workers used for loading data. Defaults to 0.
  59. save_interval : int, optional
  60. The model will be saved every ``save_interval`` epoch during training. Defaults to None.
  61. save_dir : str, optional
  62. The directory in which to save the model during training. Defaults to None.
  63. train_transform : Callable[..., Any], optional
  64. A function/transform that takes an object and returns a transformed version used
  65. in the `fit` and `train_epoch` methods. Defaults to None.
  66. test_transform : Callable[..., Any], optional
  67. A function/transform that takes an object and returns a transformed version in the
  68. `predict`, `predict_proba` and `score` methods. Defaults to None.
  69. collate_fn : Callable[[List[T]], Any], optional
  70. The function used to collate data. Defaults to None.
  71. Returns
  72. -------
  73. ABLModel
  74. The converted ABLModel instance.
  75. """
  76. if isinstance(lambdalearn_model, DeepModelMixin):
  77. base_model = self.convert_lambdalearn_to_basicnn(
  78. lambdalearn_model,
  79. loss_fn,
  80. optimizer_dict,
  81. scheduler_dict,
  82. device,
  83. batch_size,
  84. num_epochs,
  85. stop_loss,
  86. num_workers,
  87. save_interval,
  88. save_dir,
  89. train_transform,
  90. test_transform,
  91. collate_fn,
  92. )
  93. return ABLModel(base_model)
  94. if not (hasattr(lambdalearn_model, "fit") and hasattr(lambdalearn_model, "predict")):
  95. raise NotImplementedError(
  96. "The lambdalearn_model should be an instance of DeepModelMixin, or implement "
  97. + "fit and predict methods."
  98. )
  99. return ABLModel(lambdalearn_model)
  100. def convert_lambdalearn_to_basicnn(
  101. self,
  102. lambdalearn_model: DeepModelMixin,
  103. loss_fn: torch.nn.Module,
  104. optimizer_dict: dict,
  105. scheduler_dict: Optional[dict] = None,
  106. device: Optional[torch.device] = None,
  107. batch_size: int = 32,
  108. num_epochs: int = 1,
  109. stop_loss: Optional[float] = 0.0001,
  110. num_workers: int = 0,
  111. save_interval: Optional[int] = None,
  112. save_dir: Optional[str] = None,
  113. train_transform: Callable[..., Any] = None,
  114. test_transform: Callable[..., Any] = None,
  115. collate_fn: Callable[[List[Any]], Any] = None,
  116. ):
  117. """
  118. Convert a lambdalearn model to a BasicNN. If the lambdalearn model is an instance of
  119. DeepModelMixin, its network will be used as the model of BasicNN.
  120. Parameters
  121. ----------
  122. lambdalearn_model : Union[DeepModelMixin, Any]
  123. The LambdaLearn model to be converted.
  124. loss_fn : torch.nn.Module
  125. The loss function used for training.
  126. optimizer_dict : dict
  127. The dict contains necessary parameters to construct a optimizer used for training.
  128. scheduler_dict : dict, optional
  129. The dict contains necessary parameters to construct a learning rate scheduler used
  130. for training, which will be called at the end of each run of the ``fit`` method.
  131. The scheduler class is specified by the ``scheduler`` key. It should implement the
  132. ``step`` method. Defaults to None.
  133. device : torch.device, optional
  134. The device on which the model will be trained or used for prediction,
  135. Defaults to torch.device("cpu").
  136. batch_size : int, optional
  137. The batch size used for training. Defaults to 32.
  138. num_epochs : int, optional
  139. The number of epochs used for training. Defaults to 1.
  140. stop_loss : float, optional
  141. The loss value at which to stop training. Defaults to 0.0001.
  142. num_workers : int
  143. The number of workers used for loading data. Defaults to 0.
  144. save_interval : int, optional
  145. The model will be saved every ``save_interval`` epoch during training. Defaults to None.
  146. save_dir : str, optional
  147. The directory in which to save the model during training. Defaults to None.
  148. train_transform : Callable[..., Any], optional
  149. A function/transform that takes an object and returns a transformed version used
  150. in the `fit` and `train_epoch` methods. Defaults to None.
  151. test_transform : Callable[..., Any], optional
  152. A function/transform that takes an object and returns a transformed version in the
  153. `predict`, `predict_proba` and `score` methods. Defaults to None.
  154. collate_fn : Callable[[List[T]], Any], optional
  155. The function used to collate data. Defaults to None.
  156. Returns
  157. -------
  158. BasicNN
  159. The converted BasicNN instance.
  160. """
  161. if isinstance(lambdalearn_model, DeepModelMixin):
  162. if not isinstance(lambdalearn_model.network, torch.nn.Module):
  163. raise NotImplementedError(
  164. "Expected lambdalearn_model.network to be a torch.nn.Module, "
  165. + f"but got {type(lambdalearn_model.network)}"
  166. )
  167. # Only use the network part and device of the lambdalearn model
  168. network = copy.deepcopy(lambdalearn_model.network)
  169. optimizer_class = optimizer_dict["optimizer"]
  170. optimizer_dict.pop("optimizer")
  171. optimizer = optimizer_class(network.parameters(), **optimizer_dict)
  172. if scheduler_dict is not None:
  173. scheduler_class = scheduler_dict["scheduler"]
  174. scheduler_dict.pop("scheduler")
  175. scheduler = scheduler_class(optimizer, **scheduler_dict)
  176. else:
  177. scheduler = None
  178. device = lambdalearn_model.device if device is None else device
  179. base_model = BasicNN(
  180. model=network,
  181. loss_fn=loss_fn,
  182. optimizer=optimizer,
  183. scheduler=scheduler,
  184. device=device,
  185. batch_size=batch_size,
  186. num_epochs=num_epochs,
  187. stop_loss=stop_loss,
  188. num_workers=num_workers,
  189. save_interval=save_interval,
  190. save_dir=save_dir,
  191. train_transform=train_transform,
  192. test_transform=test_transform,
  193. collate_fn=collate_fn,
  194. )
  195. return base_model
  196. else:
  197. raise NotImplementedError(
  198. "The lambdalearn_model should be an instance of DeepModelMixin."
  199. )

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.