
Modify the prepare_*_dataloader functions in fdl so that they support fastNLP's dataset

tags/v1.0.0alpha
MorningForest · 2 years ago
commit ec17562f20
7 changed files with 176 additions and 107 deletions
1. fastNLP/core/dataloaders/jittor_dataloader/fdl.py (+20 -31)
2. fastNLP/core/dataloaders/paddle_dataloader/fdl.py (+15 -29)
3. fastNLP/core/dataloaders/torch_dataloader/fdl.py (+20 -37)
4. fastNLP/core/dataloaders/utils.py (+16 -1)
5. tests/core/dataloaders/jittor_dataloader/test_fdl.py (+36 -3)
6. tests/core/dataloaders/paddle_dataloader/test_fdl.py (+35 -2)
7. tests/core/dataloaders/torch_dataloader/test_fdl.py (+34 -4)

fastNLP/core/dataloaders/jittor_dataloader/fdl.py (+20 -31)

@@ -19,6 +19,7 @@ else:
from fastNLP.core.collators import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet
+ from ..utils import HasLenGetitemType


class _JittorDataset(Dataset):
@@ -204,10 +205,10 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
collate_fn: Union[None, str, Callable] = "auto",
non_train_batch_size: int = None) \
-         -> Union[Sequence[JittorDataLoader], Dict[str, JittorDataLoader], JittorDataLoader]:
+         -> Union[Dict[str, JittorDataLoader], JittorDataLoader]:
"""
``prepare_jittor_dataloader`` converts one or more input datasets into :class:`JittorDataLoader` objects at once; see :class:`~fastNLP.core.dataloaders.JittorDataLoader` for details.
- The return value differs with the type of ds_or_db, one of ``[DataSet, DataBundle, Sequence[Dataset], Dict[name, Dataset]]``, as follows:
+ The return value differs with the type of ds_or_db, one of ``[DataSet, DataBundle, Dict[name, Dataset]]``, as follows:

* When ds_or_db is a ``DataSet``, ``prepare_jittor_dataloader`` uses all of the passed arguments except non_train_batch_size and non_train_sampler
to instantiate a :class:`JittorDataLoader` and returns it. See :class:`~fastNLP.core.dataloaders.JittorDataLoader` for details.
@@ -219,19 +220,15 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
:class:`JittorDataLoader` objects; when a key contains the string 'train', ``prepare_jittor_dataloader`` treats that value as the train dataset and passes batch_size and sampler as arguments,
while datasets whose keys do not contain 'train' use non_train_batch_size and non_train_sampler as arguments. The result is returned as a
``Dict[key, JittorDataLoader]`` dictionary.
- * When ds_or_db is a ``Sequence[Dataset]``, prepare_jittor_dataloader treats Sequence[0] as the train dataset,
- passing batch_size and sampler as arguments, while every dataset in Sequence[1:] is treated as a non-train dataset and uses non_train_batch_size and non_train_sampler as arguments.
- All instantiated :class:`JittorDataLoader` objects are returned as a ``Sequence[JittorDataLoader]``.

:param ds_or_db: an object implementing __getitem__() and __len__(), a sequence of such objects, or a dictionary. Its value must be one of ``[DataSet, DataBundle,
- Sequence[DataSet], Dict[str, DataSet]]``.
+ Dict[str, DataSet]]``.

* if ds_or_db is a :class:`~fastNLP.core.dataset.DataSet`, the return value is a :class:`~fastNLP.core.dataloaders.JittorDataLoader`
* if ds_or_db is a :class:`~fastNLP.io.DataBundle`, the return value is a :class:`Dict[str, JittorDataLoader]` dictionary
- * if ds_or_db is a :class:`Sequence[DataSet]` sequence, the return value is a :class:`Sequence[JittorDataLoader]` sequence
* if ds_or_db is a :class:`Dict[str, DataSet]` dictionary, the return value is likewise a :class:`Dict[str, JittorDataLoader]` dictionary
- :param non_train_batch_size: if the ``ds_or_db`` passed in is a :class:`Dict`, :class:`Sequence` or :class:`~fastNLP.io.DataBundle` object, this parameter can be used to
+ :param non_train_batch_size: if the ``ds_or_db`` passed in is a :class:`Dict` or :class:`~fastNLP.io.DataBundle` object, this parameter can be used to
set the ``batch_size`` of every ``dataset`` whose name is not `train`. Defaults to ``16``.
:param batch_size: batch size; defaults to ``16`` and takes effect only when batch_sampler is None.
:param shuffle: whether to shuffle the dataset; defaults to ``False``.
@@ -253,28 +250,23 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
* when collate_fn is a ``Callable``, the Callable should accept a single batch argument as input, where batch is a List and every entry in it is
one sample from the dataset; the Callable should also return an object.
- :return: one of :class:`Sequence[JittorDataLoader]`, :class:`Dict[str, JittorDataLoader]` or :class:`JittorDataLoader`, depending on the input
+ :return: one of :class:`Dict[str, JittorDataLoader]` or :class:`JittorDataLoader`, depending on the input
``ds_or_db``.
"""
from fastNLP.io.data_bundle import DataBundle
-     if isinstance(ds_or_db, Dataset):
-         dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle,
-                               drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
-                               stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
-                               collate_fn=collate_fn)
-         return dl
-     elif isinstance(ds_or_db, DataBundle):
+
+     if isinstance(ds_or_db, DataBundle):
dl_bundle = {}
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:
-                 dl_bundle[name] = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle,
+                 dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle,
drop_last=drop_last, num_workers=num_workers,
buffer_size=buffer_size,
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array,
endless=endless,
collate_fn=collate_fn)
else:
-                 dl_bundle[name] = JittorDataLoader(ds_or_db,
+                 dl_bundle[name] = JittorDataLoader(ds,
batch_size=non_train_batch_size if non_train_batch_size else batch_size,
shuffle=shuffle,
drop_last=drop_last, num_workers=num_workers,
@@ -283,17 +275,6 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
endless=endless,
collate_fn=collate_fn)
return dl_bundle
-     elif isinstance(ds_or_db, Sequence):
-         ds_seq = []
-         for idx, ds in enumerate(ds_or_db):
-             if idx > 0:
-                 batch_size = non_train_batch_size if non_train_batch_size else batch_size
-             dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle,
-                                   drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
-                                   stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
-                                   collate_fn=collate_fn)
-             ds_seq.append(dl)
-         return ds_seq

elif isinstance(ds_or_db, Dict):
ds_dict = {}
@@ -304,7 +285,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
collate_fn=collate_fn)
else:
-                 dl = JittorDataLoader(ds_or_db,
+                 dl = JittorDataLoader(ds,
batch_size=non_train_batch_size if non_train_batch_size else batch_size,
shuffle=shuffle,
drop_last=drop_last, num_workers=num_workers,
@@ -314,5 +295,13 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
collate_fn=collate_fn)
ds_dict[name] = dl
return ds_dict

+     elif isinstance(ds_or_db, HasLenGetitemType):
+         dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle,
+                               drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size,
+                               stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless,
+                               collate_fn=collate_fn)
+         return dl

else:
-         raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!")
+         raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!")
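Taken together, the hunks above collapse the old five-branch dispatch into a ``DataBundle`` branch, a mapping branch, and a duck-typed fallback. The sketch below is a hypothetical condensation of that control flow (the function name, the ``bs`` helper and the ``**kwargs`` forwarding are mine; the real function forwards the full argument list shown in the diff), and the paddle and torch versions mirror the same shape:

# Hypothetical condensed sketch of the post-commit dispatch, not the
# verbatim fastNLP implementation (most keyword arguments are elided).
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
from fastNLP.core.dataloaders.utils import HasLenGetitemType
from fastNLP.io.data_bundle import DataBundle


def prepare_jittor_dataloader_sketch(ds_or_db, batch_size=16,
                                     non_train_batch_size=None, **kwargs):
    def bs(name):
        # splits whose name lacks 'train' fall back to non_train_batch_size
        return batch_size if 'train' in name else (non_train_batch_size or batch_size)

    if isinstance(ds_or_db, DataBundle):
        return {name: JittorDataLoader(ds, batch_size=bs(name), **kwargs)
                for name, ds in ds_or_db.iter_datasets()}
    elif isinstance(ds_or_db, dict):
        # checked before HasLenGetitemType on purpose: a plain dict also
        # defines __len__ and __getitem__
        return {name: JittorDataLoader(ds, batch_size=bs(name), **kwargs)
                for name, ds in ds_or_db.items()}
    elif isinstance(ds_or_db, HasLenGetitemType):
        # duck-typed branch: fastNLP's DataSet, jittor's Dataset, or any
        # user object implementing __len__ and __getitem__
        return JittorDataLoader(ds_or_db, batch_size=batch_size, **kwargs)
    raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset "
                     f"or data_bundle or mapping!")

This ordering also explains why the removed ``isinstance(ds_or_db, Dataset)`` fast path is no longer needed: jittor's ``Dataset`` satisfies the structural check and lands in the final branch.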

fastNLP/core/dataloaders/paddle_dataloader/fdl.py (+15 -29)

@@ -19,7 +19,7 @@ from fastNLP.core.collators.collator import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet
from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler
- from ..utils import _match_param
+ from ..utils import _match_param, HasLenGetitemType


class _PaddleDataset(Dataset):
@@ -256,10 +256,10 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
use_shared_memory: bool = True, timeout: int = 0,
worker_init_fn: Callable = None, persistent_workers=False,
non_train_batch_size: int = None) \
-         -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
+         -> Union[Dict[str, PaddleDataLoader], PaddleDataLoader]:
"""
``prepare_paddle_dataloader`` converts one or more input datasets into ``PaddleDataLoader`` objects at once; see :class:`~fastNLP.core.dataloaders.PaddleDataLoader` for details.
- The return value differs with the type of ds_or_db, one of ``[DataSet, DataBundle, Sequence[Dataset], Dict[name, Dataset]]``, as follows:
+ The return value differs with the type of ds_or_db, one of ``[DataSet, DataBundle, Dict[name, Dataset]]``, as follows:

* When ds_or_db is a ``DataSet``, ``prepare_paddle_dataloader`` uses all of the passed arguments except non_train_batch_size and non_train_sampler
to instantiate a ``PaddleDataLoader`` and returns it. See :class:`~fastNLP.core.dataloaders.PaddleDataLoader` for details.
@@ -271,16 +271,12 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
``PaddleDataLoader`` objects; when a key contains the string 'train', ``prepare_paddle_dataloader`` treats that value as the train dataset and passes batch_size and sampler as arguments,
while datasets whose keys do not contain 'train' use non_train_batch_size and non_train_sampler as arguments. The result is returned as a
``Dict[key, PaddleDataLoader]`` dictionary.
- * When ds_or_db is a ``Sequence[Dataset]``, prepare_paddle_dataloader treats Sequence[0] as the train dataset,
- passing batch_size and sampler as arguments, while every dataset in Sequence[1:] is treated as a non-train dataset and uses non_train_batch_size and non_train_sampler as arguments.
- All instantiated ``PaddleDataLoader`` objects are returned as a ``Sequence[PaddleDataLoader]``.

:param ds_or_db: an object implementing __getitem__() and __len__(), a sequence of such objects, or a dictionary. Its value must be one of ``[DataSet, DataBundle,
- Sequence[DataSet], Dict[str, DataSet]]``.
+ Dict[str, DataSet]]``.

* if ds_or_db is a :class:`~fastNLP.core.dataset.DataSet`, the return value is a :class:`~fastNLP.core.dataloaders.PaddleDataLoader`
* if ds_or_db is a :class:`~fastNLP.io.DataBundle`, the return value is a ``Dict[str, PaddleDataLoader]`` dictionary
- * if ds_or_db is a ``Sequence[DataSet]`` sequence, the return value is a ``Sequence[PaddleDataLoader]`` sequence
* if ds_or_db is a ``Dict[str, DataSet]`` dictionary, the return value is likewise a ``Dict[str, PaddleDataLoader]`` dictionary

:param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list.
@@ -321,14 +317,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,

"""
from fastNLP.io.data_bundle import DataBundle
-     if isinstance(ds_or_db, Dataset):
-         dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
-                               batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
-                               drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
-                               use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
-                               timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
-         return dl
-     elif isinstance(ds_or_db, DataBundle):
+
+     if isinstance(ds_or_db, DataBundle):
dl_bundle = {}
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:
@@ -353,18 +343,6 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
timeout=timeout, worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers)
return dl_bundle
-     elif isinstance(ds_or_db, Sequence):
-         ds_seq = []
-         for idx, ds in enumerate(ds_or_db):
-             if idx > 0:
-                 batch_size = non_train_batch_size if non_train_batch_size else batch_size
-             dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
-                                   batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
-                                   drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
-                                   use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
-                                   timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
-             ds_seq.append(dl)
-         return ds_seq

elif isinstance(ds_or_db, Dict):
ds_dict = {}
@@ -387,5 +365,13 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
persistent_workers=persistent_workers)
ds_dict[name] = dl
return ds_dict

+     elif isinstance(ds_or_db, HasLenGetitemType):
+         dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
+                               batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
+                               drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
+                               use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
+                               timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
+         return dl
else:
-         raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!")
+         raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!")

fastNLP/core/dataloaders/torch_dataloader/fdl.py (+20 -37)

@@ -3,7 +3,8 @@ __all__ = [
'prepare_torch_dataloader'
]

- from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List
+ from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List, Any
+ from abc import ABC
from copy import deepcopy

from fastNLP.core.dataset import DataSet
@@ -12,9 +13,10 @@ from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
from ..utils import _match_param
+ from ..utils import HasLenGetitemType

if _NEED_IMPORT_TORCH:
-     from torch.utils.data import DataLoader, Sampler
+     from torch.utils.data import DataLoader, Sampler, Dataset
else:
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader

@@ -223,10 +225,10 @@ def prepare_torch_dataloader(ds_or_db,
persistent_workers: bool = False,
non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
non_train_batch_size: int = None) \
-         -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
+         -> Union[TorchDataLoader, Dict[str, TorchDataLoader]]:
"""
``prepare_torch_dataloader`` converts one or more input datasets into ``TorchDataLoader`` objects at once; see :class:`~fastNLP.core.dataloaders.TorchDataLoader` for details.
- The return value differs with the type of ds_or_db, one of ``[DataSet, DataBundle, Sequence[Dataset], Dict[name, Dataset]]``, as follows:
+ The return value differs with the type of ds_or_db, one of ``[DataSet, DataBundle, Dict[name, Dataset]]``, as follows:

* When ds_or_db is a ``DataSet``, ``prepare_torch_dataloader`` uses all of the passed arguments except non_train_batch_size and non_train_sampler
to instantiate a ``TorchDataLoader`` and returns it. See :class:`~fastNLP.core.dataloaders.TorchDataLoader` for details.
@@ -238,16 +240,12 @@ def prepare_torch_dataloader(ds_or_db,
``TorchDataLoader`` objects; when a key contains the string 'train', ``prepare_torch_dataloader`` treats that value as the train dataset and passes batch_size and sampler as arguments,
while datasets whose keys do not contain 'train' use non_train_batch_size and non_train_sampler as arguments. The result is returned as a
``Dict[key, TorchDataLoader]`` dictionary.
- * When ds_or_db is a ``Sequence[Dataset]``, prepare_torch_dataloader treats Sequence[0] as the train dataset,
- passing batch_size and sampler as arguments, while every dataset in Sequence[1:] is treated as a non-train dataset and uses non_train_batch_size and non_train_sampler as arguments.
- All instantiated ``TorchDataLoader`` objects are returned as a ``Sequence[TorchDataLoader]``.

:param ds_or_db: an object implementing __getitem__() and __len__(), a sequence of such objects, or a dictionary. Its value must be one of ``[DataSet, DataBundle,
- Sequence[DataSet], Dict[str, DataSet]]``.
+ Dict[str, DataSet]]``.

* if ds_or_db is a :class:`~fastNLP.core.dataset.DataSet`, the return value is a :class:`~fastNLP.core.dataloaders.TorchDataLoader`
* if ds_or_db is a :class:`~fastNLP.io.DataBundle`, the return value is a ``Dict[str, TorchDataLoader]`` dictionary
- * if ds_or_db is a ``Sequence[DataSet]`` sequence, the return value is a ``Sequence[TorchDataLoader]`` sequence
* if ds_or_db is a ``Dict[str, DataSet]`` dictionary, the return value is likewise a ``Dict[str, TorchDataLoader]`` dictionary

:param batch_size: batch size; defaults to ``16`` and takes effect only when batch_sampler is None.
@@ -284,17 +282,8 @@ def prepare_torch_dataloader(ds_or_db,
"""

from fastNLP.io import DataBundle
-     if isinstance(ds_or_db, DataSet):
-         dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
-                              shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
-                              num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
-                              drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
-                              multiprocessing_context=multiprocessing_context, generator=generator,
-                              prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
-                              )
-         return dl
-
-     elif isinstance(ds_or_db, DataBundle):
+     if isinstance(ds_or_db, DataBundle):
dl_bundle = {}
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:
@@ -320,23 +309,6 @@ def prepare_torch_dataloader(ds_or_db,
)
return dl_bundle

-     elif isinstance(ds_or_db, Sequence):
-         dl_bundle = []
-         for idx, ds in enumerate(ds_or_db):
-             if idx > 0:
-                 batch_size = non_train_batch_size if non_train_batch_size else batch_size
-                 sampler = non_train_sampler if non_train_sampler else sampler
-             dl_bundle.append(
-                 TorchDataLoader(dataset=ds, batch_size=batch_size,
-                                 shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
-                                 num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
-                                 drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
-                                 multiprocessing_context=multiprocessing_context, generator=generator,
-                                 prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
-                                 )
-             )
-         return dl_bundle

elif isinstance(ds_or_db, Mapping):
dl_bundle = {}
for name, ds in ds_or_db.items():
@@ -363,5 +335,16 @@ def prepare_torch_dataloader(ds_or_db,
)

return dl_bundle

+     elif isinstance(ds_or_db, HasLenGetitemType):
+         dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
+                              shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
+                              num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
+                              drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+                              multiprocessing_context=multiprocessing_context, generator=generator,
+                              prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
+                              )
+         return dl

else:
-         raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!")
+         raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!")
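From the caller's side, the visible effect is the narrowed input set. A short usage sketch of the surviving call patterns (the split names and sizes are made up for illustration):

from fastNLP.core.dataloaders.torch_dataloader.fdl import prepare_torch_dataloader
from fastNLP.core.dataset import DataSet

train_ds = DataSet({"x": [[1, 2], [2, 3]] * 8, "y": [0, 1] * 8})
dev_ds = DataSet({"x": [[3, 4], [5, 6]] * 4, "y": [1, 0] * 4})

# single dataset -> a single TorchDataLoader
train_dl = prepare_torch_dataloader(train_ds, batch_size=8, shuffle=True)

# mapping -> Dict[str, TorchDataLoader]; the 'dev' split, whose key does not
# contain 'train', is built with non_train_batch_size instead of batch_size
dls = prepare_torch_dataloader({'train': train_ds, 'dev': dev_ds},
                               batch_size=8, non_train_batch_size=32)

# Sequence[DataSet] is no longer a documented call pattern; the corresponding
# test below was commented out along with this change.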

fastNLP/core/dataloaders/utils.py (+16 -1)

@@ -1,4 +1,5 @@
- from typing import Callable
+ from typing import Callable, Any, Union
+ from abc import ABC
import inspect
import ast

@@ -96,6 +97,20 @@ def _match_param(fun, call_fn:Callable, fn_name:str=None):
logger.debug(f"Exception happens when match parameters for {fn_name}: {e}")
return None


+ class HasLenGetitemType(ABC):
+     """
+     An ABC for checking whether a class implements the __len__ and __getitem__ methods.
+
+     """
+     @classmethod
+     def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
+         if cls is HasLenGetitemType:
+             flag = callable(getattr(subclass, '__getitem__', None)) and callable(getattr(subclass, '__len__', None))
+             return flag
+         return NotImplemented


if __name__ == '__main__':
def demo(*args, **kwargs):
pass
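Because the new ABC implements ``__subclasshook__``, ``isinstance`` against it becomes a purely structural test: no inheritance and no ``register()`` call is required, in the same spirit as ``collections.abc.Sized``. A minimal self-contained demonstration (the ``Corpus`` class is invented for illustration):

from fastNLP.core.dataloaders.utils import HasLenGetitemType


class Corpus:
    """Hypothetical dataset: defines __len__/__getitem__, never touches the ABC."""

    def __init__(self, samples):
        self.samples = samples

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)


assert isinstance(Corpus([1, 2, 3]), HasLenGetitemType)  # structural match
assert not isinstance(object(), HasLenGetitemType)       # lacks both methods
# dicts and lists pass the check too, which is why the prepare_*_dataloader
# functions test the DataBundle and Dict/Mapping branches first
assert isinstance({'train': None}, HasLenGetitemType)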


tests/core/dataloaders/jittor_dataloader/test_fdl.py (+36 -3)

@@ -5,9 +5,10 @@ from fastNLP.envs import _module_available
if _module_available('datasets'):
from datasets import Dataset as HfDataset

- from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
+ from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
from fastNLP.core.dataset import DataSet as Fdataset
from fastNLP.core.collators import Collator
+ from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
if _NEED_IMPORT_JITTOR:
from jittor.dataset import Dataset
@@ -61,7 +62,6 @@ class TestJittor:
for batch in jtl1:
print(batch)


def test_huggingface_datasets(self):
dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100})
jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True, shuffle=False)
@@ -90,4 +90,37 @@ class TestJittor:
dl = MyDataset()
dl = dl.set_attrs(collate_batch=collate_bacth, batch_size=256)
for batch in dl:
-             print(batch)
+             print(batch)
+
+     def test_prepare_jittor_dataloader(self):
+         # test with fastNLP's DataSet
+         ds = Fdataset({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+         dl = prepare_jittor_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
+         assert isinstance(dl, JittorDataLoader)
+
+         ds1 = Fdataset({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+         dbl = DataBundle(datasets={'train': ds, 'val': ds1})
+         dl_bundle = prepare_jittor_dataloader(dbl)
+         assert isinstance(dl_bundle['train'], JittorDataLoader)
+         assert isinstance(dl_bundle['val'], JittorDataLoader)
+
+         ds_dict = {'train_1': ds, 'val': ds1}
+         dl_dict = prepare_jittor_dataloader(ds_dict)
+         assert isinstance(dl_dict['train_1'], JittorDataLoader)
+         assert isinstance(dl_dict['val'], JittorDataLoader)
+
+         # test with jittor's Dataset
+         ds1 = MyDataset()
+         dl1 = prepare_jittor_dataloader(ds1, batch_size=8, shuffle=True, num_workers=2)
+         assert isinstance(dl1, JittorDataLoader)
+
+         ds2 = MyDataset()
+         dbl1 = DataBundle(datasets={'train': ds1, 'val': ds2})
+         dl_bundle1 = prepare_jittor_dataloader(dbl1)
+         assert isinstance(dl_bundle1['train'], JittorDataLoader)
+         assert isinstance(dl_bundle1['val'], JittorDataLoader)
+
+         ds_dict1 = {'train_1': ds1, 'val': ds2}
+         dl_dict1 = prepare_jittor_dataloader(ds_dict1)
+         assert isinstance(dl_dict1['train_1'], JittorDataLoader)
+         assert isinstance(dl_dict1['val'], JittorDataLoader)

tests/core/dataloaders/paddle_dataloader/test_fdl.py (+35 -2)

@@ -1,8 +1,9 @@
import pytest
import numpy as np

- from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader
+ from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader, prepare_paddle_dataloader
from fastNLP.core.dataset import DataSet
+ from fastNLP.io.data_bundle import DataBundle
from fastNLP.core.log import logger
from fastNLP.core.collators import Collator

@@ -90,4 +91,36 @@ class TestPaddle:
ds = PaddleArgMaxDataset(100, 2)
dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4)
for batch in dl:
-             print(batch)
+             print(batch)
+
+     def test_prepare_paddle_dataloader(self):
+         # test with fastNLP's DataSet
+         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+         dl = prepare_paddle_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
+         assert isinstance(dl, PaddleDataLoader)
+
+         ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+         dbl = DataBundle(datasets={'train': ds, 'val': ds1})
+         dl_bundle = prepare_paddle_dataloader(dbl)
+         assert isinstance(dl_bundle['train'], PaddleDataLoader)
+         assert isinstance(dl_bundle['val'], PaddleDataLoader)
+
+         ds_dict = {'train_1': ds, 'val': ds1}
+         dl_dict = prepare_paddle_dataloader(ds_dict)
+         assert isinstance(dl_dict['train_1'], PaddleDataLoader)
+         assert isinstance(dl_dict['val'], PaddleDataLoader)
+
+         ds2 = RandomDataset()
+         dl1 = prepare_paddle_dataloader(ds2, batch_size=8, shuffle=True, num_workers=2)
+         assert isinstance(dl1, PaddleDataLoader)
+
+         ds3 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+         dbl1 = DataBundle(datasets={'train': ds2, 'val': ds3})
+         dl_bundle1 = prepare_paddle_dataloader(dbl1)
+         assert isinstance(dl_bundle1['train'], PaddleDataLoader)
+         assert isinstance(dl_bundle1['val'], PaddleDataLoader)
+
+         ds_dict1 = {'train_1': ds2, 'val': ds3}
+         dl_dict1 = prepare_paddle_dataloader(ds_dict1)
+         assert isinstance(dl_dict1['train_1'], PaddleDataLoader)
+         assert isinstance(dl_dict1['val'], PaddleDataLoader)

tests/core/dataloaders/torch_dataloader/test_fdl.py (+34 -4)

@@ -96,6 +96,7 @@ class TestFdl:
assert batch['y'] == [1, 0, 1]

def test_prepare_torch_dataloader(self):
+         # test with fastNLP's DataSet
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
assert isinstance(dl, TorchDataLoader)
@@ -111,10 +112,39 @@ class TestFdl:
assert isinstance(dl_dict['train_1'], TorchDataLoader)
assert isinstance(dl_dict['val'], TorchDataLoader)

-         sequence = [ds, ds1]
-         seq_ds = prepare_torch_dataloader(sequence)
-         assert isinstance(seq_ds[0], TorchDataLoader)
-         assert isinstance(seq_ds[1], TorchDataLoader)
+         # test other kinds of dataset
+         class _DataSet:
+
+             def __init__(self):
+                 pass
+
+             def __getitem__(self, item):
+                 return np.random.randn(5), [[1, 2], [2, 3, 4]]
+
+             def __len__(self):
+                 return 10
+
+             def __getattribute__(self, item):
+                 return object.__getattribute__(self, item)
+
+         ds2 = _DataSet()
+         dl1 = prepare_torch_dataloader(ds2, batch_size=8, shuffle=True, num_workers=2)
+         assert isinstance(dl1, TorchDataLoader)
+
+         ds3 = _DataSet()
+         dbl1 = DataBundle(datasets={'train': ds2, 'val': ds3})
+         dl_bundle1 = prepare_torch_dataloader(dbl1)
+         assert isinstance(dl_bundle1['train'], TorchDataLoader)
+         assert isinstance(dl_bundle1['val'], TorchDataLoader)
+
+         ds_dict1 = {'train_1': ds2, 'val': ds3}
+         dl_dict1 = prepare_torch_dataloader(ds_dict1)
+         assert isinstance(dl_dict1['train_1'], TorchDataLoader)
+         assert isinstance(dl_dict1['val'], TorchDataLoader)
+         # sequence = [ds, ds1]
+         # seq_ds = prepare_torch_dataloader(sequence)
+         # assert isinstance(seq_ds[0], TorchDataLoader)
+         # assert isinstance(seq_ds[1], TorchDataLoader)

def test_get_backend(self):
from fastNLP.core.collators import Collator

