
prefetch changed to a deprecation warning;

tags/v0.4.10
yh_cc committed 5 years ago
commit 4b5113cbea
2 changed files with 10 additions and 4 deletions
  1. fastNLP/core/trainer.py   +8 -3
  2. reproduction/utils.py     +2 -1

fastNLP/core/trainer.py  +8 -3

@@ -311,6 +311,7 @@ try:
     from tqdm.auto import tqdm
 except:
     from .utils import _pseudo_tqdm as tqdm
+import warnings
 
 from .batch import DataSetIter, BatchIter
 from .callback import CallbackManager, CallbackException
@@ -320,7 +321,6 @@ from .metrics import _prepare_metrics
 from .optimizer import Optimizer
 from .sampler import Sampler
 from .sampler import RandomSampler
-from .sampler import SequentialSampler
 from .tester import Tester
 from .utils import _CheckError
 from .utils import _build_args
@@ -395,11 +395,16 @@ class Trainer(object):
     """
     def __init__(self, train_data, model, optimizer=None, loss=None,
-                 batch_size=32, sampler=None, drop_last=False,update_every=1,
+                 batch_size=32, sampler=None, drop_last=False, update_every=1,
                  num_workers=0, n_epochs=10, print_every=5,
                  dev_data=None, metrics=None, metric_key=None,
-                 validate_every=-1, save_path=None, use_tqdm=True, device=None,
+                 validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False,
                  callbacks=None, check_code_level=0):
+        if prefetch and num_workers==0:
+            num_workers = 1
+        if prefetch:
+            warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.")
+
         super(Trainer, self).__init__()
         if not isinstance(model, nn.Module):
             raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
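
The shim added in the last hunk is small enough to check in isolation. Below is a minimal standalone sketch of the same behaviour (the resolve_num_workers helper is hypothetical, written only to exercise the mapping; it is not part of fastNLP): the legacy prefetch flag is folded into num_workers and a warning is emitted, so existing call sites keep working until the flag is removed in 0.5.0.

import warnings

def resolve_num_workers(num_workers=0, prefetch=False):
    # Mirrors the shim above: fold the legacy flag into num_workers and warn.
    if prefetch and num_workers == 0:
        num_workers = 1
    if prefetch:
        warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, "
                      "please use num_workers instead.")
    return num_workers

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert resolve_num_workers(prefetch=True) == 1    # old flag still honoured
    assert resolve_num_workers(num_workers=4) == 4    # new spelling, no warning
    assert len(caught) == 1 and "deprecated" in str(caught[0].message)

Keeping prefetch as a warning rather than dropping the keyword outright gives callers one release cycle to switch to num_workers before the parameter disappears in 0.5.0.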


reproduction/utils.py  +2 -1

@@ -13,7 +13,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
         }
     If paths is invalid, the corresponding error is raised directly.
 
-    :param paths: path
+    :param paths: path. May be a single file path (then treated as the train file); a directory, in which train.txt,
+        test.txt and dev.txt are looked up; or a dict whose keys are user-defined file names and whose values are the file paths.
     :return:
     """
     if isinstance(paths, str):
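
The widened docstring above describes three accepted shapes for paths. A hedged usage sketch, assuming the repository root is on the import path and that the files below exist on disk (the file names are hypothetical examples, not part of this commit):

from reproduction.utils import check_dataloader_paths

# Single file: treated as the train file.
paths = check_dataloader_paths("data/train.txt")

# Directory: train.txt, test.txt and dev.txt are looked up inside it.
paths = check_dataloader_paths("data/conll2003/")

# Dict: keys are user-chosen names, values are file paths.
paths = check_dataloader_paths({"train": "data/train.txt", "dev": "data/my_dev.txt"})

# Per the annotation, each call returns a Dict[str, str] of name -> path;
# invalid paths raise instead of being silently ignored.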

