Browse Source

[MNT] support str-type device

pull/1/head
Gao Enhao 1 year ago
parent
commit
848707d9c6
2 changed files with 10 additions and 5 deletions
  1. +9
    -4
      abl/learning/basic_nn.py
  2. +1
    -1
      docs/Examples/HED.rst

+ 9
- 4
abl/learning/basic_nn.py View File

@@ -2,7 +2,7 @@ from __future__ import annotations

import logging
import os
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy
import torch
@@ -28,7 +28,7 @@ class BasicNN:
The learning rate scheduler used for training, which will be called
at the end of each run of the ``fit`` method. It should implement the
``step`` method, by default None.
device : torch.device, optional
device : Union[torch.device, str]
The device on which the model will be trained or used for prediction,
by default torch.device("cpu").
batch_size : int, optional
@@ -59,7 +59,7 @@ class BasicNN:
loss_fn: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[Callable[..., Any]] = None,
device: torch.device = torch.device("cpu"),
device: Union[torch.device, str] = torch.device("cpu"),
batch_size: int = 32,
num_epochs: int = 1,
stop_loss: Optional[float] = 0.0001,
@@ -79,7 +79,12 @@ class BasicNN:
if scheduler is not None and not hasattr(scheduler, "step"):
raise NotImplementedError("scheduler should implement the ``step`` method")
if not isinstance(device, torch.device):
raise TypeError("device must be an instance of torch.device")
if not isinstance(device, str):
raise TypeError(
"device must be an instance of torch.device or a str indicates the target device"
)
else:
device = torch.device(device)
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
if not isinstance(num_epochs, int):


+ 1
- 1
docs/Examples/HED.rst View File

@@ -188,7 +188,7 @@ sklearn-style interface.
cls,
loss_fn,
optimizer,
device,
device=device,
batch_size=32,
num_epochs=1,
stop_loss=None,


Loading…
Cancel
Save