diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 3b989ec0..41f760e3 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -311,6 +311,7 @@ try: from tqdm.auto import tqdm except: from .utils import _pseudo_tqdm as tqdm +import warnings from .batch import DataSetIter, BatchIter from .callback import CallbackManager, CallbackException @@ -320,7 +321,6 @@ from .metrics import _prepare_metrics from .optimizer import Optimizer from .sampler import Sampler from .sampler import RandomSampler -from .sampler import SequentialSampler from .tester import Tester from .utils import _CheckError from .utils import _build_args @@ -395,11 +395,16 @@ class Trainer(object): """ def __init__(self, train_data, model, optimizer=None, loss=None, - batch_size=32, sampler=None, drop_last=False,update_every=1, + batch_size=32, sampler=None, drop_last=False, update_every=1, num_workers=0, n_epochs=10, print_every=5, dev_data=None, metrics=None, metric_key=None, - validate_every=-1, save_path=None, use_tqdm=True, device=None, + validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False, callbacks=None, check_code_level=0): + if prefetch and num_workers==0: + num_workers = 1 + if prefetch: + warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.") + super(Trainer, self).__init__() if not isinstance(model, nn.Module): raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") diff --git a/reproduction/utils.py b/reproduction/utils.py index 58883b43..bbfed4dd 100644 --- a/reproduction/utils.py +++ b/reproduction/utils.py @@ -13,7 +13,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: } 如果paths为不合法的,将直接进行raise相应的错误 - :param paths: 路径 + :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train.txt, + test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 :return: """ if isinstance(paths, str):