diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
index 75e3ad79..0c51c37b 100644
--- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
@@ -19,6 +19,7 @@ else:
 from fastNLP.core.collators import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.core.dataset import DataSet as FDataSet
+from ..utils import HasLenGetitemType


 class _JittorDataset(Dataset):
@@ -204,10 +205,10 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
                               stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                               collate_fn: Union[None, str, Callable] = "auto",
                               non_train_batch_size: int = None) \
-        -> Union[Sequence[JittorDataLoader], Dict[str, JittorDataLoader], JittorDataLoader]:
+        -> Union[Dict[str, JittorDataLoader], JittorDataLoader]:
     """
     ``prepare_jittor_dataloader`` converts one or more input datasets into :class:`JittorDataLoader` objects in a
     single call; see :class:`~fastNLP.core.dataloaders.JittorDataLoader`.
-    The return value depends on the type of ds_or_db, one of ``[DataSet, DataBundle, Sequence[Dataset], Dict[name, Dataset]]``, as follows:
+    The return value depends on the type of ds_or_db, one of ``[DataSet, DataBundle, Dict[name, Dataset]]``, as follows:

     * When ds_or_db is a ``DataSet``, ``prepare_jittor_dataloader`` uses every parameter except non_train_batch_size
       and non_train_sampler to instantiate a :class:`JittorDataLoader` object and returns it. See
       :class:`~fastNLP.core.dataloaders.JittorDataLoader`.
@@ -219,19 +220,15 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
       :class:`JittorDataLoader` object; when a key contains the string 'train', ``prepare_jittor_dataloader``
       treats that value as the train dataset and passes it batch_size and sampler, while datasets whose keys do not
       contain 'train' use non_train_batch_size and non_train_sampler instead. The loaders are returned as a
       ``Dict[key, JittorDataLoader]`` dictionary.
-    * When ds_or_db is of type ``Sequence[Dataset]``, prepare_jittor_dataloader treats Sequence[0] as the train
-      dataset and passes it batch_size and sampler, while the datasets in Sequence[1:] are treated as non-train
-      datasets and use non_train_batch_size and non_train_sampler instead. All instantiated
-      :class:`JittorDataLoader` objects are returned as a ``Sequence[JittorDataLoader]``.

-    :param ds_or_db: an object implementing __getitem__() and __len__(); or a sequence of such objects; or a
-        dictionary. Its value must be one of ``[DataSet, DataBundle, Sequence[DataSet], Dict[str, DataSet]]``.
+    :param ds_or_db: an object implementing __getitem__() and __len__(); or a dictionary of such objects. Its value
+        must be one of ``[DataSet, DataBundle, Dict[str, DataSet]]``.
     * ds_or_db is a :class:`~fastNLP.core.dataset.DataSet`: the return value is a
       :class:`~fastNLP.core.dataloaders.JittorDataLoader`
     * ds_or_db is a :class:`~fastNLP.io.DataBundle`: the return value is a :class:`Dict[str, JittorDataLoader]`
       dictionary
-    * ds_or_db is a :class:`Sequence[DataSet]` sequence: the return value is a :class:`Sequence[JittorDataLoader]`
-      sequence
     * ds_or_db is a :class:`Dict[str, DataSet]` dictionary: the return value is likewise a
       :class:`Dict[str, JittorDataLoader]` dictionary

-    :param non_train_batch_size: when the given ``ds_or_db`` is a :class:`Dict`, :class:`Sequence` or
-        :class:`~fastNLP.io.DataBundle` object, this parameter sets the ``batch_size`` of every ``dataset`` whose
-        name is not `train`. Defaults to ``16``.
+    :param non_train_batch_size: when the given ``ds_or_db`` is a :class:`Dict` or :class:`~fastNLP.io.DataBundle`
+        object, this parameter sets the ``batch_size`` of every ``dataset`` whose name is not `train`. Defaults
+        to ``16``.
     :param batch_size: batch size; defaults to ``16`` and only takes effect when batch_sampler is None.
     :param shuffle: whether to shuffle the dataset; defaults to ``False``.
@@ -253,28 +250,23 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa

         * When collate_fn is a ``Callable``, the function should accept a single batch argument as input, where
           batch is a List whose entries are individual samples from the dataset; the function should also return
           a single object.

-    :return: one of :class:`Sequence[JittorDataLoader]`, :class:`Dict[str, JittorDataLoader]` or
-        :class:`JittorDataLoader`, depending on the input ``ds_or_db``.
+    :return: one of :class:`Dict[str, JittorDataLoader]` or :class:`JittorDataLoader`, depending on the input
+        ``ds_or_db``.
     """
     from fastNLP.io.data_bundle import DataBundle

-    if isinstance(ds_or_db, Dataset):
-        dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle,
-                              drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
-                              stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
-                              collate_fn=collate_fn)
-        return dl
-    elif isinstance(ds_or_db, DataBundle):
+
+    if isinstance(ds_or_db, DataBundle):
         dl_bundle = {}
         for name, ds in ds_or_db.iter_datasets():
             if 'train' in name:
-                dl_bundle[name] = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle,
+                dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                                                    drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
                                                    stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
                                                    collate_fn=collate_fn)
             else:
-                dl_bundle[name] = JittorDataLoader(ds_or_db,
+                dl_bundle[name] = JittorDataLoader(ds,
                                                    batch_size=non_train_batch_size if non_train_batch_size else batch_size,
                                                    shuffle=shuffle, drop_last=drop_last, num_workers=num_workers,
@@ -283,17 +275,6 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
                                                    endless=endless, collate_fn=collate_fn)
         return dl_bundle
-    elif isinstance(ds_or_db, Sequence):
-        ds_seq = []
-        for idx, ds in enumerate(ds_or_db):
-            if idx > 0:
-                batch_size = non_train_batch_size if non_train_batch_size else batch_size
-            dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle,
-                                  drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
-                                  stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
-                                  collate_fn=collate_fn)
-            ds_seq.append(dl)
-        return ds_seq

     elif isinstance(ds_or_db, Dict):
         ds_dict = {}
@@ -304,7 +285,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
                                       stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
                                       collate_fn=collate_fn)
             else:
-                dl = JittorDataLoader(ds_or_db,
+                dl = JittorDataLoader(ds,
                                       batch_size=non_train_batch_size if non_train_batch_size else batch_size,
                                       shuffle=shuffle, drop_last=drop_last, num_workers=num_workers,
@@ -314,5 +295,13 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
                                       collate_fn=collate_fn)
             ds_dict[name] = dl
         return ds_dict
+
+    elif isinstance(ds_or_db, HasLenGetitemType):
+        dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle,
+                              drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
+                              stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
+                              collate_fn=collate_fn)
+        return dl
+
     else:
-        raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!")
+        raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!")
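The order of these ``isinstance`` branches matters: ``HasLenGetitemType`` (defined in the utils.py hunk below) is a purely structural check, and built-in containers such as ``dict`` and ``list`` also expose callable ``__len__`` and ``__getitem__``. Testing ``DataBundle`` and ``Dict`` first keeps mappings on the per-dataset path, so only everything else falls through to the duck-typed single-dataset branch. A minimal sketch, assuming the patched fastNLP from this diff is importable:

    from fastNLP.core.dataloaders.utils import HasLenGetitemType

    # Built-in containers satisfy the structural check as well, which is why
    # the mapping branch has to be tested before this fallback.
    assert isinstance({'train': None}, HasLenGetitemType)
    assert isinstance([0, 1, 2], HasLenGetitemType)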
diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
index b89ee40e..68992e50 100644
--- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
@@ -19,7 +19,7 @@ from fastNLP.core.collators.collator import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.core.dataset import DataSet as FDataSet
 from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler
-from ..utils import _match_param
+from ..utils import _match_param, HasLenGetitemType


 class _PaddleDataset(Dataset):
@@ -256,10 +256,10 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                               use_shared_memory: bool = True, timeout: int = 0,
                               worker_init_fn: Callable = None, persistent_workers=False,
                               non_train_batch_size: int = None) \
-        -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
+        -> Union[Dict[str, PaddleDataLoader], PaddleDataLoader]:
     """
     ``prepare_paddle_dataloader`` converts one or more input datasets into ``PaddleDataLoader`` objects in a single
     call; see :class:`~fastNLP.core.dataloaders.PaddleDataLoader`.
-    The return value depends on the type of ds_or_db, one of ``[DataSet, DataBundle, Sequence[Dataset], Dict[name, Dataset]]``, as follows:
+    The return value depends on the type of ds_or_db, one of ``[DataSet, DataBundle, Dict[name, Dataset]]``, as follows:

     * When ds_or_db is a ``DataSet``, ``prepare_paddle_dataloader`` uses every parameter except non_train_batch_size
       and non_train_sampler to instantiate a ``PaddleDataLoader`` object and returns it. See
       :class:`~fastNLP.core.dataloaders.PaddleDataLoader`.
@@ -271,16 +271,12 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
       ``PaddleDataLoader`` object; when a key contains the string 'train', ``prepare_paddle_dataloader`` treats
       that value as the train dataset and passes it batch_size and sampler, while datasets whose keys do not
       contain 'train' use non_train_batch_size and non_train_sampler instead. The loaders are returned as a
       ``Dict[key, PaddleDataLoader]`` dictionary.
-    * When ds_or_db is of type ``Sequence[Dataset]``, prepare_paddle_dataloader treats Sequence[0] as the train
-      dataset and passes it batch_size and sampler, while the datasets in Sequence[1:] are treated as non-train
-      datasets and use non_train_batch_size and non_train_sampler instead. All instantiated ``PaddleDataLoader``
-      objects are returned as a ``Sequence[PaddleDataLoader]``.

-    :param ds_or_db: an object implementing __getitem__() and __len__(); or a sequence of such objects; or a
-        dictionary. Its value must be one of ``[DataSet, DataBundle, Sequence[DataSet], Dict[str, DataSet]]``.
+    :param ds_or_db: an object implementing __getitem__() and __len__(); or a dictionary of such objects. Its value
+        must be one of ``[DataSet, DataBundle, Dict[str, DataSet]]``.

     * ds_or_db is a :class:`~fastNLP.core.dataset.DataSet`: the return value is a
       :class:`~fastNLP.core.dataloaders.PaddleDataLoader`
     * ds_or_db is a :class:`~fastNLP.io.DataBundle`: the return value is a ``Dict[str, PaddleDataLoader]``
       dictionary
-    * ds_or_db is a ``Sequence[DataSet]`` sequence: the return value is a ``Sequence[PaddleDataLoader]`` sequence
     * ds_or_db is a ``Dict[str, DataSet]`` dictionary: the return value is likewise a ``Dict[str, PaddleDataLoader]``
       dictionary

     :param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list.
@@ -321,14 +317,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
     """
     from fastNLP.io.data_bundle import DataBundle

-    if isinstance(ds_or_db, Dataset):
-        dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
-                              batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
-                              drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
-                              use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
-                              timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
-        return dl
-    elif isinstance(ds_or_db, DataBundle):
+
+    if isinstance(ds_or_db, DataBundle):
         dl_bundle = {}
         for name, ds in ds_or_db.iter_datasets():
             if 'train' in name:
@@ -353,18 +343,6 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                                   timeout=timeout, worker_init_fn=worker_init_fn,
                                   persistent_workers=persistent_workers)
         return dl_bundle
-    elif isinstance(ds_or_db, Sequence):
-        ds_seq = []
-        for idx, ds in enumerate(ds_or_db):
-            if idx > 0:
-                batch_size = non_train_batch_size if non_train_batch_size else batch_size
-            dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
-                                  batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
-                                  drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
-                                  use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
-                                  timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
-            ds_seq.append(dl)
-        return ds_seq

     elif isinstance(ds_or_db, Dict):
         ds_dict = {}
@@ -387,5 +365,13 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                                   persistent_workers=persistent_workers)
             ds_dict[name] = dl
         return ds_dict
+
+    elif isinstance(ds_or_db, HasLenGetitemType):
+        dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
+                              batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
+                              drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
+                              use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
+                              timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
+        return dl
     else:
-        raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!")
+        raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!")
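Per the docstring above, a mapping input is split into a train path and a non-train path purely by key name. A minimal usage sketch (the dataset contents are illustrative):

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.dataloaders.paddle_dataloader.fdl import prepare_paddle_dataloader

    ds = DataSet({'x': [[1, 2], [2, 3, 4]] * 10, 'y': [0, 1] * 10})
    dls = prepare_paddle_dataloader({'train': ds, 'dev': ds},
                                    batch_size=32, non_train_batch_size=128)
    # 'train' contains the substring 'train', so that loader batches 32 at a
    # time; 'dev' does not, so it falls back to non_train_batch_size (128).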
diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index d18dbc84..707c54ca 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -3,7 +3,8 @@ __all__ = [
     'prepare_torch_dataloader'
 ]

-from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List
+from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List, Any
+from abc import ABC
 from copy import deepcopy

 from fastNLP.core.dataset import DataSet
@@ -12,9 +13,10 @@ from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
 from ..utils import _match_param
+from ..utils import HasLenGetitemType

 if _NEED_IMPORT_TORCH:
-    from torch.utils.data import DataLoader, Sampler
+    from torch.utils.data import DataLoader, Sampler, Dataset
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as DataLoader

@@ -223,10 +225,10 @@ def prepare_torch_dataloader(ds_or_db,
                              persistent_workers: bool = False,
                              non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              non_train_batch_size: int = None) \
-        -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
+        -> Union[TorchDataLoader, Dict[str, TorchDataLoader]]:
     """
     ``prepare_torch_dataloader`` converts one or more input datasets into ``TorchDataLoader`` objects in a single
     call; see :class:`~fastNLP.core.dataloaders.TorchDataLoader`.
-    The return value depends on the type of ds_or_db, one of ``[DataSet, DataBundle, Sequence[Dataset], Dict[name, Dataset]]``, as follows:
+    The return value depends on the type of ds_or_db, one of ``[DataSet, DataBundle, Dict[name, Dataset]]``, as follows:

     * When ds_or_db is a ``DataSet``, ``prepare_torch_dataloader`` uses every parameter except non_train_batch_size
       and non_train_sampler to instantiate a ``TorchDataLoader`` object and returns it. See
       :class:`~fastNLP.core.dataloaders.TorchDataLoader`.
@@ -238,16 +240,12 @@ def prepare_torch_dataloader(ds_or_db,
       ``TorchDataLoader`` object; when a key contains the string 'train', ``prepare_torch_dataloader`` treats that
       value as the train dataset and passes it batch_size and sampler, while datasets whose keys do not contain
       'train' use non_train_batch_size and non_train_sampler instead. The loaders are returned as a
       ``Dict[key, TorchDataLoader]`` dictionary.
-    * When ds_or_db is of type ``Sequence[Dataset]``, prepare_torch_dataloader treats Sequence[0] as the train
-      dataset and passes it batch_size and sampler, while the datasets in Sequence[1:] are treated as non-train
-      datasets and use non_train_batch_size and non_train_sampler instead. All instantiated ``TorchDataLoader``
-      objects are returned as a ``Sequence[TorchDataLoader]``.

-    :param ds_or_db: an object implementing __getitem__() and __len__(); or a sequence of such objects; or a
-        dictionary. Its value must be one of ``[DataSet, DataBundle, Sequence[DataSet], Dict[str, DataSet]]``.
+    :param ds_or_db: an object implementing __getitem__() and __len__(); or a dictionary of such objects. Its value
+        must be one of ``[DataSet, DataBundle, Dict[str, DataSet]]``.

     * ds_or_db is a :class:`~fastNLP.core.dataset.DataSet`: the return value is a
       :class:`~fastNLP.core.dataloaders.TorchDataLoader`
     * ds_or_db is a :class:`~fastNLP.io.DataBundle`: the return value is a ``Dict[str, TorchDataLoader]`` dictionary
-    * ds_or_db is a ``Sequence[DataSet]`` sequence: the return value is a ``Sequence[TorchDataLoader]`` sequence
     * ds_or_db is a ``Dict[str, DataSet]`` dictionary: the return value is likewise a ``Dict[str, TorchDataLoader]``
       dictionary

     :param batch_size: batch size; defaults to ``16`` and only takes effect when batch_sampler is None.
@@ -284,17 +282,8 @@ def prepare_torch_dataloader(ds_or_db,
     """
     from fastNLP.io import DataBundle

-    if isinstance(ds_or_db, DataSet):
-        dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
-                             shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
-                             num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
-                             drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
-                             multiprocessing_context=multiprocessing_context, generator=generator,
-                             prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
-                             )
-        return dl
-    elif isinstance(ds_or_db, DataBundle):
+    if isinstance(ds_or_db, DataBundle):
         dl_bundle = {}
         for name, ds in ds_or_db.iter_datasets():
             if 'train' in name:
@@ -320,23 +309,6 @@ def prepare_torch_dataloader(ds_or_db,
                                     )
         return dl_bundle
-    elif isinstance(ds_or_db, Sequence):
-        dl_bundle = []
-        for idx, ds in enumerate(ds_or_db):
-            if idx > 0:
-                batch_size = non_train_batch_size if non_train_batch_size else batch_size
-                sampler = non_train_sampler if non_train_sampler else sampler
-            dl_bundle.append(
-                TorchDataLoader(dataset=ds, batch_size=batch_size,
-                                shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
-                                num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
-                                drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
-                                multiprocessing_context=multiprocessing_context, generator=generator,
-                                prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
-                                )
-            )
-        return dl_bundle
-
     elif isinstance(ds_or_db, Mapping):
         dl_bundle = {}
         for name, ds in ds_or_db.items():
@@ -363,5 +335,16 @@ def prepare_torch_dataloader(ds_or_db,
                                     )
         return dl_bundle
+
+    elif isinstance(ds_or_db, HasLenGetitemType):
+        dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
+                             shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
+                             num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
+                             drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+                             multiprocessing_context=multiprocessing_context, generator=generator,
+                             prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
+                             )
+        return dl
+
     else:
-        raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!")
+        raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!")
diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py
index 6c6118d9..d905101f 100644
--- a/fastNLP/core/dataloaders/utils.py
+++ b/fastNLP/core/dataloaders/utils.py
@@ -1,4 +1,5 @@
-from typing import Callable
+from typing import Callable, Any, Union
+from abc import ABC
 import inspect
 import ast

@@ -96,6 +97,20 @@ def _match_param(fun, call_fn:Callable, fn_name:str=None):
         logger.debug(f"Exception happens when match parameters for {fn_name}: {e}")
         return None

+
+class HasLenGetitemType(ABC):
+    """
+    Virtual base class for objects that implement both __len__ and __getitem__.
+
+    """
+    @classmethod
+    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
+        if cls is HasLenGetitemType:
+            flag = callable(getattr(subclass, '__getitem__', None)) and callable(getattr(subclass, '__len__', None))
+            return flag
+        return NotImplemented
+
+
 if __name__ == '__main__':
     def demo(*args, **kwargs):
         pass
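``__subclasshook__`` here is the same mechanism behind the standard library's ``collections.abc`` classes: any class that defines callable ``__len__`` and ``__getitem__`` counts as a virtual subclass, with no inheritance or ``register()`` call required. A quick demonstration, assuming the patched fastNLP is importable:

    from fastNLP.core.dataloaders.utils import HasLenGetitemType

    class MySamples:
        # a map-style dataset that knows nothing about fastNLP
        def __getitem__(self, idx):
            return idx

        def __len__(self):
            return 10

    assert issubclass(MySamples, HasLenGetitemType)
    assert isinstance(MySamples(), HasLenGetitemType)
    assert not isinstance(object(), HasLenGetitemType)  # defines neither method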
diff --git a/tests/core/dataloaders/jittor_dataloader/test_fdl.py b/tests/core/dataloaders/jittor_dataloader/test_fdl.py
index a455c265..13bd94ae 100644
--- a/tests/core/dataloaders/jittor_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/jittor_dataloader/test_fdl.py
@@ -5,9 +5,10 @@ from fastNLP.envs import _module_available

 if _module_available('datasets'):
     from datasets import Dataset as HfDataset
-from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
+from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
 from fastNLP.core.dataset import DataSet as Fdataset
 from fastNLP.core.collators import Collator
+from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
 if _NEED_IMPORT_JITTOR:
     from jittor.dataset import Dataset
@@ -61,7 +62,6 @@ class TestJittor:
         for batch in jtl1:
             print(batch)

-
     def test_huggingface_datasets(self):
         dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100})
         jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True, shuffle=False)
@@ -90,4 +90,37 @@ class TestJittor:
         dl = MyDataset()
         dl = dl.set_attrs(collate_batch=collate_bacth, batch_size=256)
         for batch in dl:
-            print(batch)
\ No newline at end of file
+            print(batch)
+
+    def test_prepare_jittor_dataloader(self):
+        # test with fastNLP datasets
+        ds = Fdataset({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        dl = prepare_jittor_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
+        assert isinstance(dl, JittorDataLoader)
+
+        ds1 = Fdataset({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        dbl = DataBundle(datasets={'train': ds, 'val': ds1})
+        dl_bundle = prepare_jittor_dataloader(dbl)
+        assert isinstance(dl_bundle['train'], JittorDataLoader)
+        assert isinstance(dl_bundle['val'], JittorDataLoader)
+
+        ds_dict = {'train_1': ds, 'val': ds1}
+        dl_dict = prepare_jittor_dataloader(ds_dict)
+        assert isinstance(dl_dict['train_1'], JittorDataLoader)
+        assert isinstance(dl_dict['val'], JittorDataLoader)
+
+        # test with jittor datasets
+        ds1 = MyDataset()
+        dl1 = prepare_jittor_dataloader(ds1, batch_size=8, shuffle=True, num_workers=2)
+        assert isinstance(dl1, JittorDataLoader)
+
+        ds2 = MyDataset()
+        dbl1 = DataBundle(datasets={'train': ds1, 'val': ds2})
+        dl_bundle1 = prepare_jittor_dataloader(dbl1)
+        assert isinstance(dl_bundle1['train'], JittorDataLoader)
+        assert isinstance(dl_bundle1['val'], JittorDataLoader)
+
+        ds_dict1 = {'train_1': ds1, 'val': ds2}
+        dl_dict1 = prepare_jittor_dataloader(ds_dict1)
+        assert isinstance(dl_dict1['train_1'], JittorDataLoader)
+        assert isinstance(dl_dict1['val'], JittorDataLoader)
diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py
index 1a90aa11..f514ece8 100644
--- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py
@@ -1,8 +1,9 @@
 import pytest
 import numpy as np

-from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader
+from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader, prepare_paddle_dataloader
 from fastNLP.core.dataset import DataSet
+from fastNLP.io.data_bundle import DataBundle
 from fastNLP.core.log import logger
 from fastNLP.core.collators import Collator

@@ -90,4 +91,36 @@ class TestPaddle:
         ds = PaddleArgMaxDataset(100, 2)
         dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4)
         for batch in dl:
-            print(batch)
\ No newline at end of file
+            print(batch)
+
+    def test_prepare_paddle_dataloader(self):
+        # test with fastNLP datasets
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        dl = prepare_paddle_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
+        assert isinstance(dl, PaddleDataLoader)
+
+        ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        dbl = DataBundle(datasets={'train': ds, 'val': ds1})
+        dl_bundle = prepare_paddle_dataloader(dbl)
+        assert isinstance(dl_bundle['train'], PaddleDataLoader)
+        assert isinstance(dl_bundle['val'], PaddleDataLoader)
+
+        ds_dict = {'train_1': ds, 'val': ds1}
+        dl_dict = prepare_paddle_dataloader(ds_dict)
+        assert isinstance(dl_dict['train_1'], PaddleDataLoader)
+        assert isinstance(dl_dict['val'], PaddleDataLoader)
+
+        # test with a plain paddle-style dataset
+        ds2 = RandomDataset()
+        dl1 = prepare_paddle_dataloader(ds2, batch_size=8, shuffle=True, num_workers=2)
+        assert isinstance(dl1, PaddleDataLoader)
+
+        ds3 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        dbl1 = DataBundle(datasets={'train': ds2, 'val': ds3})
+        dl_bundle1 = prepare_paddle_dataloader(dbl1)
+        assert isinstance(dl_bundle1['train'], PaddleDataLoader)
+        assert isinstance(dl_bundle1['val'], PaddleDataLoader)
+
+        ds_dict1 = {'train_1': ds2, 'val': ds3}
+        dl_dict1 = prepare_paddle_dataloader(ds_dict1)
+        assert isinstance(dl_dict1['train_1'], PaddleDataLoader)
+        assert isinstance(dl_dict1['val'], PaddleDataLoader)
diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py
index b53790bb..1be34a1a 100644
--- a/tests/core/dataloaders/torch_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py
@@ -96,6 +96,7 @@ class TestFdl:
         assert batch['y'] == [1, 0, 1]

     def test_prepare_torch_dataloader(self):
+        # test with fastNLP datasets
         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
         dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
         assert isinstance(dl, TorchDataLoader)
@@ -111,10 +112,39 @@ class TestFdl:
         assert isinstance(dl_dict['train_1'], TorchDataLoader)
         assert isinstance(dl_dict['val'], TorchDataLoader)

-        sequence = [ds, ds1]
-        seq_ds = prepare_torch_dataloader(sequence)
-        assert isinstance(seq_ds[0], TorchDataLoader)
-        assert isinstance(seq_ds[1], TorchDataLoader)
+        # test with other dataset types
+        class _DataSet:
+
+            def __init__(self):
+                pass
+
+            def __getitem__(self, item):
+                return np.random.randn(5), [[1, 2], [2, 3, 4]]
+
+            def __len__(self):
+                return 10
+
+            def __getattribute__(self, item):
+                return object.__getattribute__(self, item)
+
+        ds2 = _DataSet()
+        dl1 = prepare_torch_dataloader(ds2, batch_size=8, shuffle=True, num_workers=2)
+        assert isinstance(dl1, TorchDataLoader)
+
+        ds3 = _DataSet()
+        dbl1 = DataBundle(datasets={'train': ds2, 'val': ds3})
+        dl_bundle1 = prepare_torch_dataloader(dbl1)
+        assert isinstance(dl_bundle1['train'], TorchDataLoader)
+        assert isinstance(dl_bundle1['val'], TorchDataLoader)
+
+        ds_dict1 = {'train_1': ds2, 'val': ds3}
+        dl_dict1 = prepare_torch_dataloader(ds_dict1)
+        assert isinstance(dl_dict1['train_1'], TorchDataLoader)
+        assert isinstance(dl_dict1['val'], TorchDataLoader)
+        # sequence = [ds, ds1]
+        # seq_ds = prepare_torch_dataloader(sequence)
+        # assert isinstance(seq_ds[0], TorchDataLoader)
+        # assert isinstance(seq_ds[1], TorchDataLoader)

     def test_get_backend(self):
         from fastNLP.core.collators import Collator
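With the ``Sequence`` branch gone, callers that used to pass a list of datasets should pass a mapping instead: a list now satisfies the structural ``HasLenGetitemType`` check itself and would be wrapped as a single dataset rather than yielding one loader per element. A migration sketch (dataset contents are illustrative; the import path mirrors the paddle test above):

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.dataloaders.torch_dataloader.fdl import TorchDataLoader, prepare_torch_dataloader

    train_ds = DataSet({'x': [[1, 2], [3, 4]] * 10, 'y': [0, 1] * 10})
    dev_ds = DataSet({'x': [[1, 2], [3, 4]] * 5, 'y': [0, 1] * 5})

    # before this patch: prepare_torch_dataloader([train_ds, dev_ds]) -> Sequence[TorchDataLoader]
    dls = prepare_torch_dataloader({'train': train_ds, 'dev': dev_ds},
                                   batch_size=8, non_train_batch_size=16)
    assert isinstance(dls['train'], TorchDataLoader)
    assert isinstance(dls['dev'], TorchDataLoader)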