@@ -19,6 +19,7 @@ else: | |||||
from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | from fastNLP.core.dataloaders.utils import indice_collate_wrapper | ||||
from fastNLP.core.dataset import DataSet as FDataSet | from fastNLP.core.dataset import DataSet as FDataSet | ||||
from ..utils import HasLenGetitemType | |||||
class _JittorDataset(Dataset): | class _JittorDataset(Dataset): | ||||
@@ -204,10 +205,10 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | ||||
collate_fn: Union[None, str, Callable] = "auto", | collate_fn: Union[None, str, Callable] = "auto", | ||||
non_train_batch_size: int = None) \ | non_train_batch_size: int = None) \ | ||||
-> Union[Sequence[JittorDataLoader], Dict[str, JittorDataLoader], JittorDataLoader]: | |||||
-> Union[Dict[str, JittorDataLoader], JittorDataLoader]: | |||||
""" | """ | ||||
``prepare_jittor_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 :class:`JittorDataLoader` 对象, 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 | ``prepare_jittor_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 :class:`JittorDataLoader` 对象, 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 | ||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle,Sequence[Dataset], Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||||
* 当 ds_or_db 为 ``DataSet`` 时,``prepare_jittor_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | * 当 ds_or_db 为 ``DataSet`` 时,``prepare_jittor_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | ||||
帮你实例化一个 :class:`JittorDataLoader` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 | 帮你实例化一个 :class:`JittorDataLoader` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 | ||||
@@ -219,19 +220,15 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
:class:`JittorDataLoader` 对象;当 key 中包含'train'字符串时,``prepare_Jittor_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | :class:`JittorDataLoader` 对象;当 key 中包含'train'字符串时,``prepare_Jittor_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | ||||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: JittorDataLoader`` 组成 | 其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: JittorDataLoader`` 组成 | ||||
``Dict[key, JittorDataLoader]`` 的字典返回。 | ``Dict[key, JittorDataLoader]`` 的字典返回。 | ||||
* 当 ds_or_db 为 ``Sequence[Dataset]`` 数据类型时, prepare_jittor_dataloader 会将 Sequence[0] 的数据集默认为 train 数据集对待, | |||||
会将 batch_size 和 sampler 作为参数, 而 Sequence[1:] 数据集均视为非 train 数据集对待,使用 non_train_size 和 non_train_sampler 作为参数。 | |||||
最终将所有实例化好的 :class:`JittorDataLoader` 组成 ``Sequence[JittorDataLoader]`` 返回。 | |||||
:param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。其取值只能为 ``[DataSet, DataBundle, | :param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或字典。其取值只能为 ``[DataSet, DataBundle, | |||||
Sequence[DataSet], Dict[str, DataSet]]``. | |||||
Dict[str, DataSet]]``. | |||||
* ds_or_db 为 :class:`~fastNLP.core.dataset.DataSet`,返回值为 :class:`~fastNLP.core.dataloaders.JittorDataLoader` | * ds_or_db 为 :class:`~fastNLP.core.dataset.DataSet`,返回值为 :class:`~fastNLP.core.dataloaders.JittorDataLoader` | ||||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 :class:`Dict[str, JittorDataLoader]` 的字典 | * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 :class:`Dict[str, JittorDataLoader]` 的字典 | ||||
* ds_or_db 为 :class:`Sequence[DataSet]` 序列, 返回值为 :class:`Sequence[JittorDataLoader]` 的序列 | |||||
* ds_or_db 为 :class:`Dict[str, DataSet]` 字典, 返回值也为 :class:`Dict[str, JittorDataLoader]` 的字典 | * ds_or_db 为 :class:`Dict[str, DataSet]` 字典, 返回值也为 :class:`Dict[str, JittorDataLoader]` 的字典 | ||||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict`, :class:`Sequence` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 | |||||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过该参数 | |||||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。 | 设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。 | ||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | ||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | :param shuffle: 是否打乱数据集, 默认为 ``False``。 | ||||
@@ -253,28 +250,23 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | * collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | ||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | ||||
:return: 返回数据类型为 :class:`Sequence[JittorDataLoader]` , :class:`Dict[str, JittorDataLoader]`, :class:`JittorDataLoader` 其中之一,根据输入 | |||||
:return: 返回数据类型为 :class:`Dict[str, JittorDataLoader]`, :class:`JittorDataLoader` 其中之一,根据输入 | |||||
``ds_or_db`` 变化而变化。 | ``ds_or_db`` 变化而变化。 | ||||
""" | """ | ||||
from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
if isinstance(ds_or_db, Dataset): | |||||
dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | |||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | |||||
collate_fn=collate_fn) | |||||
return dl | |||||
elif isinstance(ds_or_db, DataBundle): | |||||
if isinstance(ds_or_db, DataBundle): | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, | |||||
dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
buffer_size=buffer_size, | buffer_size=buffer_size, | ||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | ||||
endless=endless, | endless=endless, | ||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
else: | else: | ||||
dl_bundle[name] = JittorDataLoader(ds_or_db, | |||||
dl_bundle[name] = JittorDataLoader(ds, | |||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | shuffle=shuffle, | ||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
@@ -283,17 +275,6 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
endless=endless, | endless=endless, | ||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
return dl_bundle | return dl_bundle | ||||
elif isinstance(ds_or_db, Sequence): | |||||
ds_seq = [] | |||||
for idx, ds in enumerate(ds_or_db): | |||||
if idx > 0: | |||||
batch_size = non_train_batch_size if non_train_batch_size else batch_size | |||||
dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | |||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | |||||
collate_fn=collate_fn) | |||||
ds_seq.append(dl) | |||||
return ds_seq | |||||
elif isinstance(ds_or_db, Dict): | elif isinstance(ds_or_db, Dict): | ||||
ds_dict = {} | ds_dict = {} | ||||
@@ -304,7 +285,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | ||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
else: | else: | ||||
dl = JittorDataLoader(ds_or_db, | |||||
dl = JittorDataLoader(ds, | |||||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | batch_size=non_train_batch_size if non_train_batch_size else batch_size, | ||||
shuffle=shuffle, | shuffle=shuffle, | ||||
drop_last=drop_last, num_workers=num_workers, | drop_last=drop_last, num_workers=num_workers, | ||||
@@ -314,5 +295,13 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||||
collate_fn=collate_fn) | collate_fn=collate_fn) | ||||
ds_dict[name] = dl | ds_dict[name] = dl | ||||
return ds_dict | return ds_dict | ||||
elif isinstance(ds_or_db, HasLenGetitemType): | |||||
dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | |||||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | |||||
collate_fn=collate_fn) | |||||
return dl | |||||
else: | else: | ||||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") | |||||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!") |
@@ -19,7 +19,7 @@ from fastNLP.core.collators.collator import Collator | |||||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | from fastNLP.core.dataloaders.utils import indice_collate_wrapper | ||||
from fastNLP.core.dataset import DataSet as FDataSet | from fastNLP.core.dataset import DataSet as FDataSet | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler | from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler | ||||
from ..utils import _match_param | |||||
from ..utils import _match_param, HasLenGetitemType | |||||
class _PaddleDataset(Dataset): | class _PaddleDataset(Dataset): | ||||
@@ -256,10 +256,10 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
use_shared_memory: bool = True, timeout: int = 0, | use_shared_memory: bool = True, timeout: int = 0, | ||||
worker_init_fn: Callable = None, persistent_workers=False, | worker_init_fn: Callable = None, persistent_workers=False, | ||||
non_train_batch_size: int = None) \ | non_train_batch_size: int = None) \ | ||||
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: | |||||
-> Union[Dict[str, PaddleDataLoader], PaddleDataLoader]: | |||||
""" | """ | ||||
``prepare_paddle_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``PaddleDataloader``对象, 详见 :class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 | ``prepare_paddle_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``PaddleDataloader``对象, 详见 :class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 | ||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle,Sequence[Dataset], Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||||
* 当 ds_or_db 为 ``DataSet``时,``prepare_paddle_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | * 当 ds_or_db 为 ``DataSet``时,``prepare_paddle_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | ||||
帮你实例化一个 ``PaddleDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 | 帮你实例化一个 ``PaddleDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 | ||||
@@ -271,16 +271,12 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
``PaddleDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_paddle_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | ``PaddleDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_paddle_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | ||||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: PaddleDataLoader`` 组成 | 其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: PaddleDataLoader`` 组成 | ||||
``Dict[key, PaddleDataLoader]`` 的字典返回。 | ``Dict[key, PaddleDataLoader]`` 的字典返回。 | ||||
* 当 ds_or_db 为 ``Sequence[Dataset]`` 数据类型时, prepare_paddle_dataloader 会将 Sequence[0] 的数据集默认为 train 数据集对待, | |||||
会将 batch_size 和 sampler 作为参数, 而 Sequence[1:] 数据集均视为非 train 数据集对待,使用 non_train_size 和 non_train_sampler 作为参数。 | |||||
最终将所有实例化好的 ``PaddleDataLoader`` 组成 ``Sequence[PaddleDataLoader]`` 返回。 | |||||
::param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。其取值只能为 ``[DataSet, DataBundle, | :param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或字典。其取值只能为 ``[DataSet, DataBundle, | |||||
Sequence[DataSet], Dict[str, DataSet]]``. | |||||
Dict[str, DataSet]]``. | |||||
* ds_or_db 为 :class:`~fastNLP.core.dataset.DataSet`,返回值为:class:`~fastNLP.core.dataloaders.PaddleDataLoader` | * ds_or_db 为 :class:`~fastNLP.core.dataset.DataSet`,返回值为:class:`~fastNLP.core.dataloaders.PaddleDataLoader` | ||||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, PaddleDataLoader]`` 的字典 | * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, PaddleDataLoader]`` 的字典 | ||||
* ds_or_db 为 ``Sequence[DataSet]`` 序列, 返回值为 ``Sequence[PaddleDataLoader]`` 的序列 | |||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值也为 ``Dict[str, PaddleDataLoader]`` 的字典 | * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值也为 ``Dict[str, PaddleDataLoader]`` 的字典 | ||||
:param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list. | :param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list. | ||||
@@ -321,14 +317,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
""" | """ | ||||
from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
if isinstance(ds_or_db, Dataset): | |||||
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | |||||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||||
return dl | |||||
elif isinstance(ds_or_db, DataBundle): | |||||
if isinstance(ds_or_db, DataBundle): | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
@@ -353,18 +343,6 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, | timeout=timeout, worker_init_fn=worker_init_fn, | ||||
persistent_workers=persistent_workers) | persistent_workers=persistent_workers) | ||||
return dl_bundle | return dl_bundle | ||||
elif isinstance(ds_or_db, Sequence): | |||||
ds_seq = [] | |||||
for idx, ds in enumerate(ds_or_db): | |||||
if idx > 0: | |||||
batch_size = non_train_batch_size if non_train_batch_size else batch_size | |||||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | |||||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||||
ds_seq.append(dl) | |||||
return ds_seq | |||||
elif isinstance(ds_or_db, Dict): | elif isinstance(ds_or_db, Dict): | ||||
ds_dict = {} | ds_dict = {} | ||||
@@ -387,5 +365,13 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
persistent_workers=persistent_workers) | persistent_workers=persistent_workers) | ||||
ds_dict[name] = dl | ds_dict[name] = dl | ||||
return ds_dict | return ds_dict | ||||
elif isinstance(ds_or_db, HasLenGetitemType): | |||||
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | |||||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||||
return dl | |||||
else: | else: | ||||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") | |||||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!") |
@@ -3,7 +3,8 @@ __all__ = [ | |||||
'prepare_torch_dataloader' | 'prepare_torch_dataloader' | ||||
] | ] | ||||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List | |||||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List, Any | |||||
from abc import ABC | |||||
from copy import deepcopy | from copy import deepcopy | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
@@ -12,9 +13,10 @@ from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler | from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler | ||||
from ..utils import _match_param | from ..utils import _match_param | ||||
from ..utils import HasLenGetitemType | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
from torch.utils.data import DataLoader, Sampler | |||||
from torch.utils.data import DataLoader, Sampler, Dataset | |||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | ||||
@@ -223,10 +225,10 @@ def prepare_torch_dataloader(ds_or_db, | |||||
persistent_workers: bool = False, | persistent_workers: bool = False, | ||||
non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | ||||
non_train_batch_size: int = None) \ | non_train_batch_size: int = None) \ | ||||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: | |||||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader]]: | |||||
""" | """ | ||||
``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader``对象, 详见 :class:`~fastNLP.core.dataloaders.TorchDataLoader`。 | ``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader``对象, 详见 :class:`~fastNLP.core.dataloaders.TorchDataLoader`。 | ||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle,Sequence[Dataset], Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||||
* 当 ds_or_db 为 ``DataSet``时,``prepare_torch_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | * 当 ds_or_db 为 ``DataSet``时,``prepare_torch_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | ||||
帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.TorchDataLoader`。 | 帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.TorchDataLoader`。 | ||||
@@ -238,16 +240,12 @@ def prepare_torch_dataloader(ds_or_db, | |||||
``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | ``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | ||||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: TorchDataLoader`` 组成 | 其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: TorchDataLoader`` 组成 | ||||
``Dict[key, TorchDataLoader]`` 的字典返回。 | ``Dict[key, TorchDataLoader]`` 的字典返回。 | ||||
* 当 ds_or_db 为 ``Sequence[Dataset]`` 数据类型时, prepare_torch_dataloader 会将 Sequence[0] 的数据集默认为 train 数据集对待, | |||||
会将 batch_size 和 sampler 作为参数, 而 Sequence[1:] 数据集均视为非 train 数据集对待,使用 non_train_size 和 non_train_sampler 作为参数。 | |||||
最终将所有实例化好的 ``TorchDataLoader`` 组成 ``Sequence[TorchDataLoader]`` 返回。 | |||||
:param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。其取值只能为 ``[DataSet, DataBundle, | :param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或字典。其取值只能为 ``[DataSet, DataBundle, | |||||
Sequence[DataSet], Dict[str, DataSet]]``. | |||||
Dict[str, DataSet]]``. | |||||
* ds_or_db 为 :class:`~fastNLP.core.dataset.DataSet`,返回值为:class:`~fastNLP.core.dataloaders.TorchDataLoader` | * ds_or_db 为 :class:`~fastNLP.core.dataset.DataSet`,返回值为:class:`~fastNLP.core.dataloaders.TorchDataLoader` | ||||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | * ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | ||||
* ds_or_db 为 ``Sequence[DataSet]`` 序列, 返回值为 ``Sequence[TorchDataLoader]`` 的序列 | |||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值也为 ``Dict[str, TorchDataLoader]`` 的字典 | * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值也为 ``Dict[str, TorchDataLoader]`` 的字典 | ||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | ||||
@@ -284,17 +282,8 @@ def prepare_torch_dataloader(ds_or_db, | |||||
""" | """ | ||||
from fastNLP.io import DataBundle | from fastNLP.io import DataBundle | ||||
if isinstance(ds_or_db, DataSet): | |||||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | |||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
) | |||||
return dl | |||||
elif isinstance(ds_or_db, DataBundle): | |||||
if isinstance(ds_or_db, DataBundle): | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
@@ -320,23 +309,6 @@ def prepare_torch_dataloader(ds_or_db, | |||||
) | ) | ||||
return dl_bundle | return dl_bundle | ||||
elif isinstance(ds_or_db, Sequence): | |||||
dl_bundle = [] | |||||
for idx, ds in enumerate(ds_or_db): | |||||
if idx > 0: | |||||
batch_size = non_train_batch_size if non_train_batch_size else batch_size | |||||
sampler = non_train_sampler if non_train_sampler else sampler | |||||
dl_bundle.append( | |||||
TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
) | |||||
) | |||||
return dl_bundle | |||||
elif isinstance(ds_or_db, Mapping): | elif isinstance(ds_or_db, Mapping): | ||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.items(): | for name, ds in ds_or_db.items(): | ||||
@@ -363,5 +335,16 @@ def prepare_torch_dataloader(ds_or_db, | |||||
) | ) | ||||
return dl_bundle | return dl_bundle | ||||
elif isinstance(ds_or_db, HasLenGetitemType): | |||||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | |||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
) | |||||
return dl | |||||
else: | else: | ||||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") | |||||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!") |
@@ -1,4 +1,5 @@ | |||||
from typing import Callable | |||||
from typing import Callable, Any, Union | |||||
from abc import ABC | |||||
import inspect | import inspect | ||||
import ast | import ast | ||||
@@ -96,6 +97,20 @@ def _match_param(fun, call_fn:Callable, fn_name:str=None): | |||||
logger.debug(f"Exception happens when match parameters for {fn_name}: {e}") | logger.debug(f"Exception happens when match parameters for {fn_name}: {e}") | ||||
return None | return None | ||||
class HasLenGetitemType(ABC): | |||||
""" | |||||
判断是否实现了 __len__ 和 __getitem__ 方法的类 | |||||
""" | |||||
@classmethod | |||||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||||
if cls is HasLenGetitemType: | |||||
flag = callable(getattr(subclass, '__getitem__', None)) and callable(getattr(subclass, '__len__', None)) | |||||
return flag | |||||
return NotImplemented | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
def demo(*args, **kwargs): | def demo(*args, **kwargs): | ||||
pass | pass | ||||
@@ -5,9 +5,10 @@ from fastNLP.envs import _module_available | |||||
if _module_available('datasets'): | if _module_available('datasets'): | ||||
from datasets import Dataset as HfDataset | from datasets import Dataset as HfDataset | ||||
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader | |||||
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | |||||
from fastNLP.core.dataset import DataSet as Fdataset | from fastNLP.core.dataset import DataSet as Fdataset | ||||
from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
from fastNLP.io.data_bundle import DataBundle | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
from jittor.dataset import Dataset | from jittor.dataset import Dataset | ||||
@@ -61,7 +62,6 @@ class TestJittor: | |||||
for batch in jtl1: | for batch in jtl1: | ||||
print(batch) | print(batch) | ||||
def test_huggingface_datasets(self): | def test_huggingface_datasets(self): | ||||
dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | ||||
jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True, shuffle=False) | jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True, shuffle=False) | ||||
@@ -90,4 +90,37 @@ class TestJittor: | |||||
dl = MyDataset() | dl = MyDataset() | ||||
dl = dl.set_attrs(collate_batch=collate_bacth, batch_size=256) | dl = dl.set_attrs(collate_batch=collate_bacth, batch_size=256) | ||||
for batch in dl: | for batch in dl: | ||||
print(batch) | |||||
print(batch) | |||||
def test_prepare_jittor_dataloader(self): | |||||
# 测试 fastnlp 的 dataset | |||||
ds = Fdataset({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
dl = prepare_jittor_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) | |||||
assert isinstance(dl, JittorDataLoader) | |||||
ds1 = Fdataset({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
dbl = DataBundle(datasets={'train': ds, 'val': ds1}) | |||||
dl_bundle = prepare_jittor_dataloader(dbl) | |||||
assert isinstance(dl_bundle['train'], JittorDataLoader) | |||||
assert isinstance(dl_bundle['val'], JittorDataLoader) | |||||
ds_dict = {'train_1': ds, 'val': ds1} | |||||
dl_dict = prepare_jittor_dataloader(ds_dict) | |||||
assert isinstance(dl_dict['train_1'], JittorDataLoader) | |||||
assert isinstance(dl_dict['val'], JittorDataLoader) | |||||
# 测试 jittor 的 dataset | |||||
ds1 = MyDataset() | |||||
dl1 = prepare_jittor_dataloader(ds1, batch_size=8, shuffle=True, num_workers=2) | |||||
assert isinstance(dl1, JittorDataLoader) | |||||
ds2 = MyDataset() | |||||
dbl1 = DataBundle(datasets={'train': ds1, 'val': ds2}) | |||||
dl_bundle1 = prepare_jittor_dataloader(dbl1) | |||||
assert isinstance(dl_bundle1['train'], JittorDataLoader) | |||||
assert isinstance(dl_bundle1['val'], JittorDataLoader) | |||||
ds_dict1 = {'train_1': ds1, 'val': ds2} | |||||
dl_dict1 = prepare_jittor_dataloader(ds_dict1) | |||||
assert isinstance(dl_dict1['train_1'], JittorDataLoader) | |||||
assert isinstance(dl_dict1['val'], JittorDataLoader) |
@@ -1,8 +1,9 @@ | |||||
import pytest | import pytest | ||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader | |||||
from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader, prepare_paddle_dataloader | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.io.data_bundle import DataBundle | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
@@ -90,4 +91,36 @@ class TestPaddle: | |||||
ds = PaddleArgMaxDataset(100, 2) | ds = PaddleArgMaxDataset(100, 2) | ||||
dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) | dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) | ||||
for batch in dl: | for batch in dl: | ||||
print(batch) | |||||
print(batch) | |||||
def test_prepare_paddle_dataloader(self): | |||||
# 测试 fastNLP 的 dataset | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
dl = prepare_paddle_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) | |||||
assert isinstance(dl, PaddleDataLoader) | |||||
ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
dbl = DataBundle(datasets={'train': ds, 'val': ds1}) | |||||
dl_bundle = prepare_paddle_dataloader(dbl) | |||||
assert isinstance(dl_bundle['train'], PaddleDataLoader) | |||||
assert isinstance(dl_bundle['val'], PaddleDataLoader) | |||||
ds_dict = {'train_1': ds, 'val': ds1} | |||||
dl_dict = prepare_paddle_dataloader(ds_dict) | |||||
assert isinstance(dl_dict['train_1'], PaddleDataLoader) | |||||
assert isinstance(dl_dict['val'], PaddleDataLoader) | |||||
ds2 = RandomDataset() | |||||
dl1 = prepare_paddle_dataloader(ds2, batch_size=8, shuffle=True, num_workers=2) | |||||
assert isinstance(dl1, PaddleDataLoader) | |||||
ds3 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
dbl1 = DataBundle(datasets={'train': ds2, 'val': ds3}) | |||||
dl_bundle1 = prepare_paddle_dataloader(dbl1) | |||||
assert isinstance(dl_bundle1['train'], PaddleDataLoader) | |||||
assert isinstance(dl_bundle1['val'], PaddleDataLoader) | |||||
ds_dict1 = {'train_1': ds2, 'val': ds3} | |||||
dl_dict1 = prepare_paddle_dataloader(ds_dict1) | |||||
assert isinstance(dl_dict1['train_1'], PaddleDataLoader) | |||||
assert isinstance(dl_dict1['val'], PaddleDataLoader) |
@@ -96,6 +96,7 @@ class TestFdl: | |||||
assert batch['y'] == [1, 0, 1] | assert batch['y'] == [1, 0, 1] | ||||
def test_prepare_torch_dataloader(self): | def test_prepare_torch_dataloader(self): | ||||
# 测试 fastNLP 的 dataset | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ||||
dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) | dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) | ||||
assert isinstance(dl, TorchDataLoader) | assert isinstance(dl, TorchDataLoader) | ||||
@@ -111,10 +112,39 @@ class TestFdl: | |||||
assert isinstance(dl_dict['train_1'], TorchDataLoader) | assert isinstance(dl_dict['train_1'], TorchDataLoader) | ||||
assert isinstance(dl_dict['val'], TorchDataLoader) | assert isinstance(dl_dict['val'], TorchDataLoader) | ||||
sequence = [ds, ds1] | |||||
seq_ds = prepare_torch_dataloader(sequence) | |||||
assert isinstance(seq_ds[0], TorchDataLoader) | |||||
assert isinstance(seq_ds[1], TorchDataLoader) | |||||
# 测试其他 dataset | |||||
class _DataSet: | |||||
def __init__(self): | |||||
pass | |||||
def __getitem__(self, item): | |||||
return np.random.randn(5), [[1, 2], [2, 3, 4]] | |||||
def __len__(self): | |||||
return 10 | |||||
def __getattribute__(self, item): | |||||
return object.__getattribute__(self, item) | |||||
ds2 = _DataSet() | |||||
dl1 = prepare_torch_dataloader(ds2, batch_size=8, shuffle=True, num_workers=2) | |||||
assert isinstance(dl1, TorchDataLoader) | |||||
ds3 = _DataSet() | |||||
dbl1 = DataBundle(datasets={'train': ds2, 'val': ds3}) | |||||
dl_bundle1 = prepare_torch_dataloader(dbl1) | |||||
assert isinstance(dl_bundle1['train'], TorchDataLoader) | |||||
assert isinstance(dl_bundle1['val'], TorchDataLoader) | |||||
ds_dict1 = {'train_1': ds2, 'val': ds3} | |||||
dl_dict1 = prepare_torch_dataloader(ds_dict1) | |||||
assert isinstance(dl_dict1['train_1'], TorchDataLoader) | |||||
assert isinstance(dl_dict1['val'], TorchDataLoader) | |||||
# sequence = [ds, ds1] | |||||
# seq_ds = prepare_torch_dataloader(sequence) | |||||
# assert isinstance(seq_ds[0], TorchDataLoader) | |||||
# assert isinstance(seq_ds[1], TorchDataLoader) | |||||
def test_get_backend(self): | def test_get_backend(self): | ||||
from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||