import copy
from typing import Any, Callable, List, Optional

import torch
from lambdaLearn.Base.DeepModelMixin import DeepModelMixin

from .abl_model import ABLModel
from .basic_nn import BasicNN


class ModelConverter:
    """
    This class provides functionality to convert LambdaLearn models to ABLkit models.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def _instantiate_from_dict(spec: dict, class_key: str, first_arg: Any) -> Any:
        """
        Build an object described by a construction dict, without mutating the dict.

        ``spec[class_key]`` is the class (or factory) to call. It is called with
        ``first_arg`` as its first positional argument and every other entry of
        ``spec`` as a keyword argument.

        Raises
        ------
        KeyError
            If ``class_key`` is not present in ``spec``.
        """
        # Work on a shallow copy so the caller's dict keeps its class key and can be
        # reused for later conversions (popping from the original made a second
        # conversion with the same dict fail with KeyError).
        kwargs = dict(spec)
        component_class = kwargs.pop(class_key)
        return component_class(first_arg, **kwargs)

    def convert_lambdalearn_to_ablmodel(
        self,
        lambdalearn_model,
        loss_fn: torch.nn.Module,
        optimizer_dict: dict,
        scheduler_dict: Optional[dict] = None,
        device: Optional[torch.device] = None,
        batch_size: int = 32,
        num_epochs: int = 1,
        stop_loss: Optional[float] = 0.0001,
        num_workers: int = 0,
        save_interval: Optional[int] = None,
        save_dir: Optional[str] = None,
        train_transform: Optional[Callable[..., Any]] = None,
        test_transform: Optional[Callable[..., Any]] = None,
        collate_fn: Optional[Callable[[List[Any]], Any]] = None,
    ):
        """
        Convert a lambdalearn model to an ABLModel.

        If the lambdalearn model is an instance of DeepModelMixin, its network will be
        used as the model of a BasicNN (see ``convert_lambdalearn_to_basicnn``), which
        is then wrapped in an ABLModel. Otherwise, the lambdalearn model must implement
        ``fit`` and ``predict`` methods and is wrapped in an ABLModel directly; in that
        case all parameters other than ``lambdalearn_model`` are ignored.

        Parameters
        ----------
        lambdalearn_model : Union[DeepModelMixin, Any]
            The LambdaLearn model to be converted.
        loss_fn : torch.nn.Module
            The loss function used for training.
        optimizer_dict : dict
            The dict contains necessary parameters to construct an optimizer used for
            training. The optimizer class is specified by the ``optimizer`` key; the
            remaining entries are passed to it as keyword arguments. The dict is not
            modified.
        scheduler_dict : dict, optional
            The dict contains necessary parameters to construct a learning rate scheduler
            used for training, which will be called at the end of each run of the ``fit``
            method. The scheduler class is specified by the ``scheduler`` key; the
            remaining entries are passed to it as keyword arguments. It should implement
            the ``step`` method. The dict is not modified. Defaults to None.
        device : torch.device, optional
            The device on which the model will be trained or used for prediction.
            Defaults to None, in which case the device of ``lambdalearn_model`` is used.
        batch_size : int, optional
            The batch size used for training. Defaults to 32.
        num_epochs : int, optional
            The number of epochs used for training. Defaults to 1.
        stop_loss : float, optional
            The loss value at which to stop training. Defaults to 0.0001.
        num_workers : int
            The number of workers used for loading data. Defaults to 0.
        save_interval : int, optional
            The model will be saved every ``save_interval`` epoch during training.
            Defaults to None.
        save_dir : str, optional
            The directory in which to save the model during training. Defaults to None.
        train_transform : Callable[..., Any], optional
            A function/transform that takes an object and returns a transformed version
            used in the `fit` and `train_epoch` methods. Defaults to None.
        test_transform : Callable[..., Any], optional
            A function/transform that takes an object and returns a transformed version
            in the `predict`, `predict_proba` and `score` methods. Defaults to None.
        collate_fn : Callable[[List[T]], Any], optional
            The function used to collate data. Defaults to None.

        Returns
        -------
        ABLModel
            The converted ABLModel instance.

        Raises
        ------
        NotImplementedError
            If ``lambdalearn_model`` is neither a DeepModelMixin nor exposes ``fit`` and
            ``predict``, or if its ``network`` is not a ``torch.nn.Module``.
        """
        if isinstance(lambdalearn_model, DeepModelMixin):
            # Keyword arguments keep this call correct even if the parameter order of
            # convert_lambdalearn_to_basicnn changes.
            base_model = self.convert_lambdalearn_to_basicnn(
                lambdalearn_model,
                loss_fn=loss_fn,
                optimizer_dict=optimizer_dict,
                scheduler_dict=scheduler_dict,
                device=device,
                batch_size=batch_size,
                num_epochs=num_epochs,
                stop_loss=stop_loss,
                num_workers=num_workers,
                save_interval=save_interval,
                save_dir=save_dir,
                train_transform=train_transform,
                test_transform=test_transform,
                collate_fn=collate_fn,
            )
            return ABLModel(base_model)

        if not (hasattr(lambdalearn_model, "fit") and hasattr(lambdalearn_model, "predict")):
            raise NotImplementedError(
                "The lambdalearn_model should be an instance of DeepModelMixin, or implement "
                + "fit and predict methods."
            )
        return ABLModel(lambdalearn_model)

    def convert_lambdalearn_to_basicnn(
        self,
        lambdalearn_model: DeepModelMixin,
        loss_fn: torch.nn.Module,
        optimizer_dict: dict,
        scheduler_dict: Optional[dict] = None,
        device: Optional[torch.device] = None,
        batch_size: int = 32,
        num_epochs: int = 1,
        stop_loss: Optional[float] = 0.0001,
        num_workers: int = 0,
        save_interval: Optional[int] = None,
        save_dir: Optional[str] = None,
        train_transform: Optional[Callable[..., Any]] = None,
        test_transform: Optional[Callable[..., Any]] = None,
        collate_fn: Optional[Callable[[List[Any]], Any]] = None,
    ):
        """
        Convert a lambdalearn model to a BasicNN.

        The lambdalearn model must be an instance of DeepModelMixin. A deep copy of its
        network is used as the model of the BasicNN, and a fresh optimizer (and,
        optionally, scheduler) is built for that copy.

        Parameters
        ----------
        lambdalearn_model : DeepModelMixin
            The LambdaLearn model to be converted.
        loss_fn : torch.nn.Module
            The loss function used for training.
        optimizer_dict : dict
            The dict contains necessary parameters to construct an optimizer used for
            training. The optimizer class is specified by the ``optimizer`` key; the
            remaining entries are passed to it as keyword arguments. The dict is not
            modified.
        scheduler_dict : dict, optional
            The dict contains necessary parameters to construct a learning rate scheduler
            used for training, which will be called at the end of each run of the ``fit``
            method. The scheduler class is specified by the ``scheduler`` key; the
            remaining entries are passed to it as keyword arguments. It should implement
            the ``step`` method. The dict is not modified. Defaults to None.
        device : torch.device, optional
            The device on which the model will be trained or used for prediction.
            Defaults to None, in which case ``lambdalearn_model.device`` is used.
        batch_size : int, optional
            The batch size used for training. Defaults to 32.
        num_epochs : int, optional
            The number of epochs used for training. Defaults to 1.
        stop_loss : float, optional
            The loss value at which to stop training. Defaults to 0.0001.
        num_workers : int
            The number of workers used for loading data. Defaults to 0.
        save_interval : int, optional
            The model will be saved every ``save_interval`` epoch during training.
            Defaults to None.
        save_dir : str, optional
            The directory in which to save the model during training. Defaults to None.
        train_transform : Callable[..., Any], optional
            A function/transform that takes an object and returns a transformed version
            used in the `fit` and `train_epoch` methods. Defaults to None.
        test_transform : Callable[..., Any], optional
            A function/transform that takes an object and returns a transformed version
            in the `predict`, `predict_proba` and `score` methods. Defaults to None.
        collate_fn : Callable[[List[T]], Any], optional
            The function used to collate data. Defaults to None.

        Returns
        -------
        BasicNN
            The converted BasicNN instance.

        Raises
        ------
        NotImplementedError
            If ``lambdalearn_model`` is not a DeepModelMixin, or its ``network`` is not a
            ``torch.nn.Module``.
        KeyError
            If ``optimizer_dict`` lacks the ``optimizer`` key, or ``scheduler_dict`` is
            given but lacks the ``scheduler`` key.
        """
        if not isinstance(lambdalearn_model, DeepModelMixin):
            raise NotImplementedError(
                "The lambdalearn_model should be an instance of DeepModelMixin."
            )
        if not isinstance(lambdalearn_model.network, torch.nn.Module):
            raise NotImplementedError(
                "Expected lambdalearn_model.network to be a torch.nn.Module, "
                + f"but got {type(lambdalearn_model.network)}"
            )

        # Only the network and (by default) the device of the lambdalearn model are
        # reused. The network is deep-copied so the optimizer built below updates the
        # copy, not the parameters of the original LambdaLearn model.
        network = copy.deepcopy(lambdalearn_model.network)
        optimizer = self._instantiate_from_dict(
            optimizer_dict, "optimizer", network.parameters()
        )
        if scheduler_dict is not None:
            scheduler = self._instantiate_from_dict(scheduler_dict, "scheduler", optimizer)
        else:
            scheduler = None

        if device is None:
            device = lambdalearn_model.device

        return BasicNN(
            model=network,
            loss_fn=loss_fn,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            batch_size=batch_size,
            num_epochs=num_epochs,
            stop_loss=stop_loss,
            num_workers=num_workers,
            save_interval=save_interval,
            save_dir=save_dir,
            train_transform=train_transform,
            test_transform=test_transform,
            collate_fn=collate_fn,
        )