|
|
@@ -311,6 +311,7 @@ try: |
|
|
|
from tqdm.auto import tqdm |
|
|
|
except: |
|
|
|
from .utils import _pseudo_tqdm as tqdm |
|
|
|
import warnings |
|
|
|
|
|
|
|
from .batch import DataSetIter, BatchIter |
|
|
|
from .callback import CallbackManager, CallbackException |
|
|
@@ -320,7 +321,6 @@ from .metrics import _prepare_metrics |
|
|
|
from .optimizer import Optimizer |
|
|
|
from .sampler import Sampler |
|
|
|
from .sampler import RandomSampler |
|
|
|
from .sampler import SequentialSampler |
|
|
|
from .tester import Tester |
|
|
|
from .utils import _CheckError |
|
|
|
from .utils import _build_args |
|
|
@@ -395,11 +395,16 @@ class Trainer(object): |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, train_data, model, optimizer=None, loss=None, |
|
|
|
batch_size=32, sampler=None, drop_last=False,update_every=1, |
|
|
|
batch_size=32, sampler=None, drop_last=False, update_every=1, |
|
|
|
num_workers=0, n_epochs=10, print_every=5, |
|
|
|
dev_data=None, metrics=None, metric_key=None, |
|
|
|
validate_every=-1, save_path=None, use_tqdm=True, device=None, |
|
|
|
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False, |
|
|
|
callbacks=None, check_code_level=0): |
|
|
|
if prefetch and num_workers==0: |
|
|
|
num_workers = 1 |
|
|
|
if prefetch: |
|
|
|
warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.") |
|
|
|
|
|
|
|
super(Trainer, self).__init__() |
|
|
|
if not isinstance(model, nn.Module): |
|
|
|
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") |
|
|
|