Browse Source

prefecth变更为deprecated warning;

tags/v0.4.10
yh_cc 5 years ago
parent
commit
4b5113cbea
2 changed files with 10 additions and 4 deletions
  1. +8
    -3
      fastNLP/core/trainer.py
  2. +2
    -1
      reproduction/utils.py

+ 8
- 3
fastNLP/core/trainer.py View File

@@ -311,6 +311,7 @@ try:
from tqdm.auto import tqdm
except:
from .utils import _pseudo_tqdm as tqdm
import warnings

from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException
@@ -320,7 +321,6 @@ from .metrics import _prepare_metrics
from .optimizer import Optimizer
from .sampler import Sampler
from .sampler import RandomSampler
from .sampler import SequentialSampler
from .tester import Tester
from .utils import _CheckError
from .utils import _build_args
@@ -395,11 +395,16 @@ class Trainer(object):
"""
def __init__(self, train_data, model, optimizer=None, loss=None,
batch_size=32, sampler=None, drop_last=False,update_every=1,
batch_size=32, sampler=None, drop_last=False, update_every=1,
num_workers=0, n_epochs=10, print_every=5,
dev_data=None, metrics=None, metric_key=None,
validate_every=-1, save_path=None, use_tqdm=True, device=None,
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False,
callbacks=None, check_code_level=0):
if prefetch and num_workers==0:
num_workers = 1
if prefetch:
warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.")

super(Trainer, self).__init__()
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")


+ 2
- 1
reproduction/utils.py View File

@@ -13,7 +13,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
}
如果paths为不合法的,将直接进行raise相应的错误

:param paths: 路径
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train.txt,
test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
:return:
"""
if isinstance(paths, str):


Loading…
Cancel
Save