@@ -54,7 +54,7 @@ class JittorDataLoader: | |||||
:param endless: | :param endless: | ||||
:param collate_fn: 对取得到的数据进行打包的callable函数 | :param collate_fn: 对取得到的数据进行打包的callable函数 | ||||
""" | """ | ||||
# TODO 验证支持replacesampler (以后完成) | |||||
# TODO 验证支持replacesampler (以后完成) 增加Sampler | |||||
# 将内部dataset批次设置为1 | # 将内部dataset批次设置为1 | ||||
if isinstance(dataset, Dataset): | if isinstance(dataset, Dataset): | ||||
dataset.set_attrs(batch_size=1) | dataset.set_attrs(batch_size=1) | ||||
@@ -172,7 +172,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
return_list: bool = True, | return_list: bool = True, | ||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | ||||
train_batch_size: int = 1, shuffle: bool = False, | train_batch_size: int = 1, shuffle: bool = False, | ||||
drop_last: bool = False, collate_fn: Union[Callable, str, None] = None, | |||||
drop_last: bool = False, collate_fn: Union[Callable, str, None] = 'auto', | |||||
num_workers: int = 0, use_buffer_reader: bool = True, | num_workers: int = 0, use_buffer_reader: bool = True, | ||||
use_shared_memory: bool = True, timeout: int = 0, | use_shared_memory: bool = True, timeout: int = 0, | ||||
worker_init_fn: Callable = None, persistent_workers=False, | worker_init_fn: Callable = None, persistent_workers=False, | ||||
@@ -177,12 +177,12 @@ class TorchDataLoader(DataLoader): | |||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]], | |||||
def prepare_torch_dataloader(ds_or_db, | |||||
batch_size: int = 16, | batch_size: int = 16, | ||||
shuffle: bool = True, | |||||
shuffle: bool = False, | |||||
sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | ||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | ||||
num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto', | |||||
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', | |||||
pin_memory: bool = False, drop_last: bool = False, | pin_memory: bool = False, drop_last: bool = False, | ||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | timeout: float = 0, worker_init_fn: Optional[Callable] = None, | ||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | ||||