|
|
@@ -66,9 +66,9 @@ class BasicNN: |
|
|
|
num_workers: int = 0, |
|
|
|
save_interval: Optional[int] = None, |
|
|
|
save_dir: Optional[str] = None, |
|
|
|
train_transform: Callable[..., Any] = None, |
|
|
|
test_transform: Callable[..., Any] = None, |
|
|
|
collate_fn: Callable[[List[Any]], Any] = None, |
|
|
|
train_transform: Optional[Callable[..., Any]] = None, |
|
|
|
test_transform: Optional[Callable[..., Any]] = None, |
|
|
|
collate_fn: Optional[Callable[[List[Any]], Any]] = None, |
|
|
|
) -> None: |
|
|
|
if not isinstance(model, torch.nn.Module): |
|
|
|
raise TypeError("model must be an instance of torch.nn.Module") |
|
|
|