@@ -185,7 +185,7 @@ class JittorDataLoader: | |||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: bool = True, | |||||
def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = True, | |||||
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | ||||
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | ||||
collate_fn: Union[None, str, Callable] = "auto", | collate_fn: Union[None, str, Callable] = "auto", | ||||
@@ -211,8 +211,9 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo | |||||
:param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, | :param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, | ||||
Sequence[Dataset], Dict[name, Dataset]]``. | Sequence[Dataset], Dict[name, Dataset]]``. | ||||
:param train_batch_size: 'train'数据集使用的batch_size,跟non_train_batch_size是互斥的。 | |||||
:param non_train_batch_size: 非'train'数据使用batch_size,跟train_batch_size是互斥的。 | |||||
:param batch_size: batch 的大小。 | |||||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 ``Dict`` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 | |||||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 | |||||
:param shuffle: 是否打乱数据集 | :param shuffle: 是否打乱数据集 | ||||
:param drop_last: 是否去掉最后一个不符合``batch_size``的数据 | :param drop_last: 是否去掉最后一个不符合``batch_size``的数据 | ||||
:param num_workers: 进程的数量,当``num_workers=0``时不开启多进程 | :param num_workers: 进程的数量,当``num_workers=0``时不开启多进程 | ||||
@@ -234,7 +235,7 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo | |||||
""" | """ | ||||
from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
if isinstance(ds_or_db, Dataset): | if isinstance(ds_or_db, Dataset): | ||||
dl = JittorDataLoader(ds_or_db, batch_size=train_batch_size, shuffle=shuffle, | |||||
dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | ||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
@@ -243,7 +244,7 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = JittorDataLoader(ds_or_db, batch_size=train_batch_size, shuffle=shuffle, | |||||
dl_bundle[name] = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
buffer_size=buffer_size, | buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | ||||
@@ -251,7 +252,7 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo | |||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
else: | else: | ||||
dl_bundle[name] = JittorDataLoader(ds_or_db, | dl_bundle[name] = JittorDataLoader(ds_or_db, | ||||
batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, | |||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | |||||
shuffle=shuffle, | shuffle=shuffle, | ||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
buffer_size=buffer_size, | buffer_size=buffer_size, | ||||
@@ -263,8 +264,8 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo | |||||
ds_seq = [] | ds_seq = [] | ||||
for idx, ds in enumerate(ds_or_db): | for idx, ds in enumerate(ds_or_db): | ||||
if idx > 0: | if idx > 0: | ||||
train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size | |||||
dl = JittorDataLoader(ds, batch_size=train_batch_size, shuffle=shuffle, | |||||
batch_size = non_train_batch_size if non_train_batch_size else batch_size | |||||
dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | ||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
@@ -275,13 +276,13 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo | |||||
ds_dict = {} | ds_dict = {} | ||||
for name, ds in ds_or_db.items(): | for name, ds in ds_or_db.items(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl = JittorDataLoader(ds, batch_size=train_batch_size, shuffle=shuffle, | |||||
dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | ||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
else: | else: | ||||
dl = JittorDataLoader(ds_or_db, | dl = JittorDataLoader(ds_or_db, | ||||
batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, | |||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | |||||
shuffle=shuffle, | shuffle=shuffle, | ||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
buffer_size=buffer_size, | buffer_size=buffer_size, | ||||
@@ -253,7 +253,7 @@ class PaddleDataLoader(DataLoader): | |||||
def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | ||||
return_list: bool = True, | return_list: bool = True, | ||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | ||||
train_batch_size: int = 16, shuffle: bool = False, | |||||
batch_size: int = 16, shuffle: bool = False, | |||||
drop_last: bool = False, collate_fn: Union[Callable, str, None] = 'auto', | drop_last: bool = False, collate_fn: Union[Callable, str, None] = 'auto', | ||||
num_workers: int = 0, use_buffer_reader: bool = True, | num_workers: int = 0, use_buffer_reader: bool = True, | ||||
use_shared_memory: bool = True, timeout: int = 0, | use_shared_memory: bool = True, timeout: int = 0, | ||||
@@ -280,8 +280,9 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
:param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, | :param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, | ||||
Sequence[Dataset], Dict[name, Dataset]]``. | Sequence[Dataset], Dict[name, Dataset]]``. | ||||
:param train_batch_size: 'train'数据集使用的batch_size,跟non_train_batch_size是互斥的。 | |||||
:param non_train_batch_size: 非'train'数据使用batch_size,跟train_batch_size是互斥的。 | |||||
:param batch_size: batch 的大小。 | |||||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 ``Dict`` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 | |||||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 | |||||
:param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list. | :param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list. | ||||
The Tensors should be created by :code:`paddle.static.data()`. | The Tensors should be created by :code:`paddle.static.data()`. | ||||
:attr:`feed_list` must be set if :attr:`return_list` is | :attr:`feed_list` must be set if :attr:`return_list` is | ||||
@@ -327,7 +328,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
if isinstance(ds_or_db, Dataset): | if isinstance(ds_or_db, Dataset): | ||||
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, | |||||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | ||||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | ||||
@@ -338,7 +339,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, | dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, | ||||
return_list=return_list, | return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=train_batch_size, | |||||
batch_sampler=batch_sampler, batch_size=batch_size, | |||||
shuffle=shuffle, | shuffle=shuffle, | ||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, | use_shared_memory=use_shared_memory, | ||||
@@ -349,7 +350,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, | dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, | ||||
return_list=return_list, | return_list=return_list, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, | |||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | |||||
shuffle=shuffle, | shuffle=shuffle, | ||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, | use_shared_memory=use_shared_memory, | ||||
@@ -361,9 +362,9 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
ds_seq = [] | ds_seq = [] | ||||
for idx, ds in enumerate(ds_or_db): | for idx, ds in enumerate(ds_or_db): | ||||
if idx > 0: | if idx > 0: | ||||
train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size | |||||
batch_size = non_train_batch_size if non_train_batch_size else batch_size | |||||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, | |||||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | ||||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | ||||
@@ -375,7 +376,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
for name, ds in ds_or_db.items(): | for name, ds in ds_or_db.items(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, | |||||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | ||||
timeout=timeout, worker_init_fn=worker_init_fn, | timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -383,7 +384,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
else: | else: | ||||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, | |||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | |||||
shuffle=shuffle, | shuffle=shuffle, | ||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | ||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | ||||
@@ -197,7 +197,7 @@ class TorchDataLoader(DataLoader): | |||||
def prepare_torch_dataloader(ds_or_db, | def prepare_torch_dataloader(ds_or_db, | ||||
train_batch_size: int = 16, | |||||
batch_size: int = 16, | |||||
shuffle: bool = False, | shuffle: bool = False, | ||||
train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | ||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | ||||
@@ -229,8 +229,9 @@ def prepare_torch_dataloader(ds_or_db, | |||||
:param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, | :param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, | ||||
Sequence[Dataset], Dict[name, Dataset]]``. | Sequence[Dataset], Dict[name, Dataset]]``. | ||||
:param shuffle: 是否打乱数据集 | :param shuffle: 是否打乱数据集 | ||||
:param train_batch_size: 'train'数据集使用的batch_size,跟non_train_batch_size是互斥的。 | |||||
:param non_train_batch_size: 非'train'数据使用batch_size,跟train_batch_size是互斥的。 | |||||
:param batch_size: batch 的大小。 | |||||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 ``Dict`` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 | |||||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 | |||||
:param train_sampler: train'数据集使用的sampler, 现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 | :param train_sampler: train'数据集使用的sampler, 现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 | ||||
:param non_train_sampler: 非'train'数据使用sampler, 实现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 | :param non_train_sampler: 非'train'数据使用sampler, 实现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 | ||||
:param batch_sampler: 实现了__len__和__iter__方法的实例化对象,,其能迭代返回一个list的index数据, index不超过dataset的大小, | :param batch_sampler: 实现了__len__和__iter__方法的实例化对象,,其能迭代返回一个list的index数据, index不超过dataset的大小, | ||||
@@ -259,7 +260,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
from fastNLP.io import DataBundle | from fastNLP.io import DataBundle | ||||
if isinstance(ds_or_db, DataSet): | if isinstance(ds_or_db, DataSet): | ||||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=train_batch_size, | |||||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | |||||
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, | shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -272,7 +273,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size, | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, | shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -282,7 +283,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
) | ) | ||||
else: | else: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, | dl_bundle[name] = TorchDataLoader(dataset=ds, | ||||
batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, | |||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | |||||
shuffle=shuffle, | shuffle=shuffle, | ||||
sampler=non_train_sampler if non_train_sampler else train_sampler, | sampler=non_train_sampler if non_train_sampler else train_sampler, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
@@ -298,10 +299,10 @@ def prepare_torch_dataloader(ds_or_db, | |||||
dl_bundle = [] | dl_bundle = [] | ||||
for idx, ds in enumerate(ds_or_db): | for idx, ds in enumerate(ds_or_db): | ||||
if idx > 0: | if idx > 0: | ||||
train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size | |||||
batch_size = non_train_batch_size if non_train_batch_size else batch_size | |||||
train_sampler = non_train_sampler if non_train_sampler else train_sampler | train_sampler = non_train_sampler if non_train_sampler else train_sampler | ||||
dl_bundle.append( | dl_bundle.append( | ||||
TorchDataLoader(dataset=ds, batch_size=train_batch_size, | |||||
TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, | shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -315,7 +316,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.items(): | for name, ds in ds_or_db.items(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size, | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, | shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -325,7 +326,7 @@ def prepare_torch_dataloader(ds_or_db, | |||||
) | ) | ||||
else: | else: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, | dl_bundle[name] = TorchDataLoader(dataset=ds, | ||||
batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, | |||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | |||||
shuffle=shuffle, | shuffle=shuffle, | ||||
sampler=non_train_sampler if non_train_sampler else train_sampler, | sampler=non_train_sampler if non_train_sampler else train_sampler, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||