|
|
@@ -2,7 +2,7 @@ from __future__ import annotations |
|
|
|
|
|
|
|
import logging |
|
|
|
import os |
|
|
|
from typing import Any, Callable, List, Optional, Tuple |
|
|
|
from typing import Any, Callable, List, Optional, Tuple, Union |
|
|
|
|
|
|
|
import numpy |
|
|
|
import torch |
|
|
@@ -28,7 +28,7 @@ class BasicNN: |
|
|
|
The learning rate scheduler used for training, which will be called |
|
|
|
at the end of each run of the ``fit`` method. It should implement the |
|
|
|
``step`` method, by default None. |
|
|
|
device : torch.device, optional |
|
|
|
device : Union[torch.device, str] |
|
|
|
The device on which the model will be trained or used for prediction, |
|
|
|
by default torch.device("cpu"). |
|
|
|
batch_size : int, optional |
|
|
@@ -59,7 +59,7 @@ class BasicNN: |
|
|
|
loss_fn: torch.nn.Module, |
|
|
|
optimizer: torch.optim.Optimizer, |
|
|
|
scheduler: Optional[Callable[..., Any]] = None, |
|
|
|
device: torch.device = torch.device("cpu"), |
|
|
|
device: Union[torch.device, str] = torch.device("cpu"), |
|
|
|
batch_size: int = 32, |
|
|
|
num_epochs: int = 1, |
|
|
|
stop_loss: Optional[float] = 0.0001, |
|
|
@@ -79,7 +79,12 @@ class BasicNN: |
|
|
|
if scheduler is not None and not hasattr(scheduler, "step"): |
|
|
|
raise NotImplementedError("scheduler should implement the ``step`` method") |
|
|
|
if not isinstance(device, torch.device): |
|
|
|
raise TypeError("device must be an instance of torch.device") |
|
|
|
if not isinstance(device, str): |
|
|
|
raise TypeError( |
|
|
|
"device must be an instance of torch.device or a str indicates the target device" |
|
|
|
) |
|
|
|
else: |
|
|
|
device = torch.device(device) |
|
|
|
if not isinstance(batch_size, int): |
|
|
|
raise TypeError("batch_size must be an integer") |
|
|
|
if not isinstance(num_epochs, int): |
|
|
|