Browse Source

Modify fdl: rename the `train_sampler` parameter of `prepare_torch_dataloader` to `sampler`

tags/v1.0.0alpha
MorningForest 2 years ago
parent
commit
d2672c62b1
1 changed files with 8 additions and 8 deletions
  1. +8
    -8
      fastNLP/core/dataloaders/torch_dataloader/fdl.py

+ 8
- 8
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -199,7 +199,7 @@ class TorchDataLoader(DataLoader):
def prepare_torch_dataloader(ds_or_db, def prepare_torch_dataloader(ds_or_db,
batch_size: int = 16, batch_size: int = 16,
shuffle: bool = False, shuffle: bool = False,
train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto',
pin_memory: bool = False, drop_last: bool = False, pin_memory: bool = False, drop_last: bool = False,
@@ -261,7 +261,7 @@ def prepare_torch_dataloader(ds_or_db,
from fastNLP.io import DataBundle from fastNLP.io import DataBundle
if isinstance(ds_or_db, DataSet): if isinstance(ds_or_db, DataSet):
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator, multiprocessing_context=multiprocessing_context, generator=generator,
@@ -274,7 +274,7 @@ def prepare_torch_dataloader(ds_or_db,
for name, ds in ds_or_db.iter_datasets(): for name, ds in ds_or_db.iter_datasets():
if 'train' in name: if 'train' in name:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator, multiprocessing_context=multiprocessing_context, generator=generator,
@@ -285,7 +285,7 @@ def prepare_torch_dataloader(ds_or_db,
dl_bundle[name] = TorchDataLoader(dataset=ds, dl_bundle[name] = TorchDataLoader(dataset=ds,
batch_size=non_train_batch_size if non_train_batch_size else batch_size, batch_size=non_train_batch_size if non_train_batch_size else batch_size,
shuffle=shuffle, shuffle=shuffle,
sampler=non_train_sampler if non_train_sampler else train_sampler,
sampler=non_train_sampler if non_train_sampler else sampler,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -300,10 +300,10 @@ def prepare_torch_dataloader(ds_or_db,
for idx, ds in enumerate(ds_or_db): for idx, ds in enumerate(ds_or_db):
if idx > 0: if idx > 0:
batch_size = non_train_batch_size if non_train_batch_size else batch_size batch_size = non_train_batch_size if non_train_batch_size else batch_size
train_sampler = non_train_sampler if non_train_sampler else train_sampler
sampler = non_train_sampler if non_train_sampler else sampler
dl_bundle.append( dl_bundle.append(
TorchDataLoader(dataset=ds, batch_size=batch_size, TorchDataLoader(dataset=ds, batch_size=batch_size,
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator, multiprocessing_context=multiprocessing_context, generator=generator,
@@ -317,7 +317,7 @@ def prepare_torch_dataloader(ds_or_db,
for name, ds in ds_or_db.items(): for name, ds in ds_or_db.items():
if 'train' in name: if 'train' in name:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator, multiprocessing_context=multiprocessing_context, generator=generator,
@@ -328,7 +328,7 @@ def prepare_torch_dataloader(ds_or_db,
dl_bundle[name] = TorchDataLoader(dataset=ds, dl_bundle[name] = TorchDataLoader(dataset=ds,
batch_size=non_train_batch_size if non_train_batch_size else batch_size, batch_size=non_train_batch_size if non_train_batch_size else batch_size,
shuffle=shuffle, shuffle=shuffle,
sampler=non_train_sampler if non_train_sampler else train_sampler,
sampler=non_train_sampler if non_train_sampler else sampler,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,


Loading…
Cancel
Save