
Modify fdl

tags/v1.0.0alpha
MorningForest 3 years ago
commit 11e7dda440
1 changed file with 7 additions and 5 deletions
fastNLP/core/dataloaders/torch_dataloader/fdl.py  (+7 −5)

@@ -178,7 +178,7 @@ class TorchDataLoader(DataLoader):


 def prepare_torch_dataloader(ds_or_db,
-                             batch_size: int = 16,
+                             train_batch_size: int = 16,
                              shuffle: bool = False,
                              sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -214,7 +214,7 @@ def prepare_torch_dataloader(ds_or_db,

     from fastNLP.io import DataBundle
     if isinstance(ds_or_db, DataSet):
-        dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
+        dl = TorchDataLoader(dataset=ds_or_db, batch_size=train_batch_size,
                              shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                              num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                              drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -227,7 +227,7 @@ def prepare_torch_dataloader(ds_or_db,
         dl_bundle = {}
         for name, ds in ds_or_db.iter_datasets():
             if 'train' in name:
-                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
+                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size,
                                                   shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -250,8 +250,10 @@ def prepare_torch_dataloader(ds_or_db,
     elif isinstance(ds_or_db, Sequence):
         dl_bundle = []
         for idx, ds in enumerate(ds_or_db):
+            if idx > 0:
+                train_batch_size = non_train_batch_size
             dl_bundle.append(
-                TorchDataLoader(dataset=ds, batch_size=batch_size,
+                TorchDataLoader(dataset=ds, batch_size=train_batch_size,
                                 shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                 num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                 drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -265,7 +267,7 @@ def prepare_torch_dataloader(ds_or_db,
         dl_bundle = {}
         for name, ds in ds_or_db.items():
             if 'train' in name:
-                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
+                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size,
                                                   shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
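
The net effect of the diff: prepare_torch_dataloader now takes train_batch_size in place of batch_size, and in the Sequence branch every dataset after the first falls back to non_train_batch_size, which the diff references and is assumed here to already be a parameter of the function. A minimal usage sketch of the updated call follows; the import path mirrors the file path in the diff, and the toy datasets and batch sizes are illustrative, not taken from the commit.

from fastNLP import DataSet
from fastNLP.core.dataloaders.torch_dataloader.fdl import prepare_torch_dataloader

# Toy datasets for illustration only; any fastNLP DataSet works here.
train_ds = DataSet({'x': [[1, 2], [3, 4]], 'y': [0, 1]})
dev_ds = DataSet({'x': [[5, 6], [7, 8]], 'y': [1, 0]})

# After this commit the training batch size is passed as `train_batch_size`;
# with a Sequence input, every dataset after the first uses
# `non_train_batch_size` (assumed to be an existing parameter, since the
# diff uses it without adding it to the signature).
dls = prepare_torch_dataloader([train_ds, dev_ds],
                               train_batch_size=16,
                               non_train_batch_size=32,
                               shuffle=True)
train_dl, dev_dl = dls  # the Sequence branch builds and returns a list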

