|
|
@@ -177,12 +177,12 @@ class TorchDataLoader(DataLoader): |
|
|
|
return self.cur_batch_indices |
|
|
|
|
|
|
|
|
|
|
|
def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]], |
|
|
|
def prepare_torch_dataloader(ds_or_db, |
|
|
|
batch_size: int = 16, |
|
|
|
shuffle: bool = True, |
|
|
|
shuffle: bool = False, |
|
|
|
sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, |
|
|
|
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, |
|
|
|
num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto', |
|
|
|
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', |
|
|
|
pin_memory: bool = False, drop_last: bool = False, |
|
|
|
timeout: float = 0, worker_init_fn: Optional[Callable] = None, |
|
|
|
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, |
|
|
|