From d2672c62b1134b60117b507e60bae21bf9088810 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Wed, 18 May 2022 21:54:12 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9fdl=20train=5Fsampler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataloaders/torch_dataloader/fdl.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 99faec7e..9818ab39 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -199,7 +199,7 @@ class TorchDataLoader(DataLoader): def prepare_torch_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = False, - train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, + sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', pin_memory: bool = False, drop_last: bool = False, @@ -261,7 +261,7 @@ def prepare_torch_dataloader(ds_or_db, from fastNLP.io import DataBundle if isinstance(ds_or_db, DataSet): dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, - shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -274,7 +274,7 @@ def prepare_torch_dataloader(ds_or_db, for name, ds in ds_or_db.iter_datasets(): if 'train' in name: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -285,7 +285,7 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, shuffle=shuffle, - sampler=non_train_sampler if non_train_sampler else train_sampler, + sampler=non_train_sampler if non_train_sampler else sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, @@ -300,10 +300,10 @@ def prepare_torch_dataloader(ds_or_db, for idx, ds in enumerate(ds_or_db): if idx > 0: batch_size = non_train_batch_size if non_train_batch_size else batch_size - train_sampler = non_train_sampler if non_train_sampler else train_sampler + sampler = non_train_sampler if non_train_sampler else sampler dl_bundle.append( TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -317,7 +317,7 @@ def prepare_torch_dataloader(ds_or_db, for name, ds in ds_or_db.items(): if 'train' in name: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -328,7 +328,7 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, shuffle=shuffle, - sampler=non_train_sampler if non_train_sampler else train_sampler, + sampler=non_train_sampler if non_train_sampler else sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,