|
@@ -199,7 +199,7 @@ class TorchDataLoader(DataLoader): |
|
|
def prepare_torch_dataloader(ds_or_db, |
|
|
def prepare_torch_dataloader(ds_or_db, |
|
|
batch_size: int = 16, |
|
|
batch_size: int = 16, |
|
|
shuffle: bool = False, |
|
|
shuffle: bool = False, |
|
|
train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, |
|
|
|
|
|
|
|
|
sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, |
|
|
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, |
|
|
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, |
|
|
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', |
|
|
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', |
|
|
pin_memory: bool = False, drop_last: bool = False, |
|
|
pin_memory: bool = False, drop_last: bool = False, |
|
@@ -261,7 +261,7 @@ def prepare_torch_dataloader(ds_or_db, |
|
|
from fastNLP.io import DataBundle |
|
|
from fastNLP.io import DataBundle |
|
|
if isinstance(ds_or_db, DataSet): |
|
|
if isinstance(ds_or_db, DataSet): |
|
|
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, |
|
|
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, |
|
|
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, |
|
|
|
|
|
|
|
|
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
multiprocessing_context=multiprocessing_context, generator=generator, |
|
|
multiprocessing_context=multiprocessing_context, generator=generator, |
|
@@ -274,7 +274,7 @@ def prepare_torch_dataloader(ds_or_db, |
|
|
for name, ds in ds_or_db.iter_datasets(): |
|
|
for name, ds in ds_or_db.iter_datasets(): |
|
|
if 'train' in name: |
|
|
if 'train' in name: |
|
|
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, |
|
|
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, |
|
|
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, |
|
|
|
|
|
|
|
|
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
multiprocessing_context=multiprocessing_context, generator=generator, |
|
|
multiprocessing_context=multiprocessing_context, generator=generator, |
|
@@ -285,7 +285,7 @@ def prepare_torch_dataloader(ds_or_db, |
|
|
dl_bundle[name] = TorchDataLoader(dataset=ds, |
|
|
dl_bundle[name] = TorchDataLoader(dataset=ds, |
|
|
batch_size=non_train_batch_size if non_train_batch_size else batch_size, |
|
|
batch_size=non_train_batch_size if non_train_batch_size else batch_size, |
|
|
shuffle=shuffle, |
|
|
shuffle=shuffle, |
|
|
sampler=non_train_sampler if non_train_sampler else train_sampler, |
|
|
|
|
|
|
|
|
sampler=non_train_sampler if non_train_sampler else sampler, |
|
|
batch_sampler=batch_sampler, |
|
|
batch_sampler=batch_sampler, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
@@ -300,10 +300,10 @@ def prepare_torch_dataloader(ds_or_db, |
|
|
for idx, ds in enumerate(ds_or_db): |
|
|
for idx, ds in enumerate(ds_or_db): |
|
|
if idx > 0: |
|
|
if idx > 0: |
|
|
batch_size = non_train_batch_size if non_train_batch_size else batch_size |
|
|
batch_size = non_train_batch_size if non_train_batch_size else batch_size |
|
|
train_sampler = non_train_sampler if non_train_sampler else train_sampler |
|
|
|
|
|
|
|
|
sampler = non_train_sampler if non_train_sampler else sampler |
|
|
dl_bundle.append( |
|
|
dl_bundle.append( |
|
|
TorchDataLoader(dataset=ds, batch_size=batch_size, |
|
|
TorchDataLoader(dataset=ds, batch_size=batch_size, |
|
|
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, |
|
|
|
|
|
|
|
|
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
multiprocessing_context=multiprocessing_context, generator=generator, |
|
|
multiprocessing_context=multiprocessing_context, generator=generator, |
|
@@ -317,7 +317,7 @@ def prepare_torch_dataloader(ds_or_db, |
|
|
for name, ds in ds_or_db.items(): |
|
|
for name, ds in ds_or_db.items(): |
|
|
if 'train' in name: |
|
|
if 'train' in name: |
|
|
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, |
|
|
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, |
|
|
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, |
|
|
|
|
|
|
|
|
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
multiprocessing_context=multiprocessing_context, generator=generator, |
|
|
multiprocessing_context=multiprocessing_context, generator=generator, |
|
@@ -328,7 +328,7 @@ def prepare_torch_dataloader(ds_or_db, |
|
|
dl_bundle[name] = TorchDataLoader(dataset=ds, |
|
|
dl_bundle[name] = TorchDataLoader(dataset=ds, |
|
|
batch_size=non_train_batch_size if non_train_batch_size else batch_size, |
|
|
batch_size=non_train_batch_size if non_train_batch_size else batch_size, |
|
|
shuffle=shuffle, |
|
|
shuffle=shuffle, |
|
|
sampler=non_train_sampler if non_train_sampler else train_sampler, |
|
|
|
|
|
|
|
|
sampler=non_train_sampler if non_train_sampler else sampler, |
|
|
batch_sampler=batch_sampler, |
|
|
batch_sampler=batch_sampler, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, |
|
|