From 11e7dda440e86c5ff6ff8bef6333dcdc918719ef Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Tue, 10 May 2022 02:38:38 +0800 Subject: [PATCH] =?UTF-8?q?fdl=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataloaders/torch_dataloader/fdl.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index b5fb5eab..0eae1a49 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -178,7 +178,7 @@ class TorchDataLoader(DataLoader): def prepare_torch_dataloader(ds_or_db, - batch_size: int = 16, + train_batch_size: int = 16, shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, @@ -214,7 +214,7 @@ def prepare_torch_dataloader(ds_or_db, from fastNLP.io import DataBundle if isinstance(ds_or_db, DataSet): - dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, + dl = TorchDataLoader(dataset=ds_or_db, batch_size=train_batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, @@ -227,7 +227,7 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle = {} for name, ds in ds_or_db.iter_datasets(): if 'train' in name: - dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, + dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, @@ -250,8 +250,10 @@ def prepare_torch_dataloader(ds_or_db, elif isinstance(ds_or_db, Sequence): dl_bundle = [] for idx, ds in enumerate(ds_or_db): + if idx > 0: + train_batch_size = non_train_batch_size dl_bundle.append( - TorchDataLoader(dataset=ds, batch_size=batch_size, + TorchDataLoader(dataset=ds, batch_size=train_batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, @@ -265,7 +267,7 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle = {} for name, ds in ds_or_db.items(): if 'train' in name: - dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, + dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,