@@ -19,6 +19,7 @@ else: | |||
from fastNLP.core.collators import Collator | |||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||
from fastNLP.core.dataset import DataSet as FDataSet | |||
from ..utils import HasLenGetitemType | |||
class _JittorDataset(Dataset): | |||
@@ -204,10 +205,10 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | |||
collate_fn: Union[None, str, Callable] = "auto", | |||
non_train_batch_size: int = None) \ | |||
-> Union[Sequence[JittorDataLoader], Dict[str, JittorDataLoader], JittorDataLoader]: | |||
-> Union[Dict[str, JittorDataLoader], JittorDataLoader]: | |||
""" | |||
``prepare_jittor_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 :class:`JittorDataLoader` 对象, 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 | |||
根据 ds_or_db 的类型 ``[DataSet, DataBundle,Sequence[Dataset], Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||
* 当 ds_or_db 为 ``DataSet`` 时,``prepare_jittor_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||
帮你实例化一个 :class:`JittorDataLoader` 对象并返回该对象。 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 | |||
@@ -219,19 +220,15 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||
:class:`JittorDataLoader` 对象;当 key 中包含'train'字符串时,``prepare_Jittor_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | |||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: JittorDataLoader`` 组成 | |||
``Dict[key, JittorDataLoader]`` 的字典返回。 | |||
* 当 ds_or_db 为 ``Sequence[Dataset]`` 数据类型时, prepare_jittor_dataloader 会将 Sequence[0] 的数据集默认为 train 数据集对待, | |||
会将 batch_size 和 sampler 作为参数, 而 Sequence[1:] 数据集均视为非 train 数据集对待,使用 non_train_size 和 non_train_sampler 作为参数。 | |||
最终将所有实例化好的 :class:`JittorDataLoader` 组成 ``Sequence[JittorDataLoader]`` 返回。 | |||
:param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。其取值只能为 ``[DataSet, DataBundle, | |||
Sequence[DataSet], Dict[str, DataSet]]``. | |||
Dict[str, DataSet]]``. | |||
* ds_or_db 为 :class:`~fastNLP.core.dataset.DataSet`,返回值为 :class:`~fastNLP.core.dataloaders.JittorDataLoader` | |||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 :class:`Dict[str, JittorDataLoader]` 的字典 | |||
* ds_or_db 为 :class:`Sequence[DataSet]` 序列, 返回值为 :class:`Sequence[JittorDataLoader]` 的序列 | |||
* ds_or_db 为 :class:`Dict[str, DataSet]` 字典, 返回值也为 :class:`Dict[str, JittorDataLoader]` 的字典 | |||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict`, :class:`Sequence` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过该参数 | |||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 :class:`Dict` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过该参数 | |||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||
@@ -253,28 +250,23 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
:return: 返回数据类型为 :class:`Sequence[JittorDataLoader]` , :class:`Dict[str, JittorDataLoader]`, :class:`JittorDataLoader` 其中之一,根据输入 | |||
:return: 返回数据类型为 :class:`Dict[str, JittorDataLoader]`, :class:`JittorDataLoader` 其中之一,根据输入 | |||
``ds_or_db`` 变化而变化。 | |||
""" | |||
from fastNLP.io.data_bundle import DataBundle | |||
if isinstance(ds_or_db, Dataset): | |||
dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, | |||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | |||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | |||
collate_fn=collate_fn) | |||
return dl | |||
elif isinstance(ds_or_db, DataBundle): | |||
if isinstance(ds_or_db, DataBundle): | |||
dl_bundle = {} | |||
for name, ds in ds_or_db.iter_datasets(): | |||
if 'train' in name: | |||
dl_bundle[name] = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, | |||
dl_bundle[name] = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, | |||
drop_last=drop_last, num_workers=num_workers, | |||
buffer_size=buffer_size, | |||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, | |||
endless=endless, | |||
collate_fn=collate_fn) | |||
else: | |||
dl_bundle[name] = JittorDataLoader(ds_or_db, | |||
dl_bundle[name] = JittorDataLoader(ds, | |||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | |||
shuffle=shuffle, | |||
drop_last=drop_last, num_workers=num_workers, | |||
@@ -283,17 +275,6 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||
endless=endless, | |||
collate_fn=collate_fn) | |||
return dl_bundle | |||
elif isinstance(ds_or_db, Sequence): | |||
ds_seq = [] | |||
for idx, ds in enumerate(ds_or_db): | |||
if idx > 0: | |||
batch_size = non_train_batch_size if non_train_batch_size else batch_size | |||
dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, | |||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | |||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | |||
collate_fn=collate_fn) | |||
ds_seq.append(dl) | |||
return ds_seq | |||
elif isinstance(ds_or_db, Dict): | |||
ds_dict = {} | |||
@@ -304,7 +285,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | |||
collate_fn=collate_fn) | |||
else: | |||
dl = JittorDataLoader(ds_or_db, | |||
dl = JittorDataLoader(ds, | |||
batch_size=non_train_batch_size if non_train_batch_size else batch_size, | |||
shuffle=shuffle, | |||
drop_last=drop_last, num_workers=num_workers, | |||
@@ -314,5 +295,13 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa | |||
collate_fn=collate_fn) | |||
ds_dict[name] = dl | |||
return ds_dict | |||
elif isinstance(ds_or_db, HasLenGetitemType): | |||
dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, | |||
drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, | |||
stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, | |||
collate_fn=collate_fn) | |||
return dl | |||
else: | |||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") | |||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!") |
@@ -19,7 +19,7 @@ from fastNLP.core.collators.collator import Collator | |||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||
from fastNLP.core.dataset import DataSet as FDataSet | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler | |||
from ..utils import _match_param | |||
from ..utils import _match_param, HasLenGetitemType | |||
class _PaddleDataset(Dataset): | |||
@@ -256,10 +256,10 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||
use_shared_memory: bool = True, timeout: int = 0, | |||
worker_init_fn: Callable = None, persistent_workers=False, | |||
non_train_batch_size: int = None) \ | |||
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: | |||
-> Union[Dict[str, PaddleDataLoader], PaddleDataLoader]: | |||
""" | |||
``prepare_paddle_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``PaddleDataloader``对象, 详见 :class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 | |||
根据 ds_or_db 的类型 ``[DataSet, DataBundle,Sequence[Dataset], Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||
* 当 ds_or_db 为 ``DataSet``时,``prepare_paddle_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||
帮你实例化一个 ``PaddleDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 | |||
@@ -271,16 +271,12 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||
``PaddleDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_paddle_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | |||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: PaddleDataLoader`` 组成 | |||
``Dict[key, PaddleDataLoader]`` 的字典返回。 | |||
* 当 ds_or_db 为 ``Sequence[Dataset]`` 数据类型时, prepare_paddle_dataloader 会将 Sequence[0] 的数据集默认为 train 数据集对待, | |||
会将 batch_size 和 sampler 作为参数, 而 Sequence[1:] 数据集均视为非 train 数据集对待,使用 non_train_size 和 non_train_sampler 作为参数。 | |||
最终将所有实例化好的 ``PaddleDataLoader`` 组成 ``Sequence[PaddleDataLoader]`` 返回。 | |||
:param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。其取值只能为 ``[DataSet, DataBundle, | |||
Sequence[DataSet], Dict[str, DataSet]]``. | |||
Dict[str, DataSet]]``. | |||
* ds_or_db 为 :class:`~fastNLP.core.dataset.DataSet`,返回值为:class:`~fastNLP.core.dataloaders.PaddleDataLoader` | |||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, PaddleDataLoader]`` 的字典 | |||
* ds_or_db 为 ``Sequence[DataSet]`` 序列, 返回值为 ``Sequence[PaddleDataLoader]`` 的序列 | |||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值也为 ``Dict[str, PaddleDataLoader]`` 的字典 | |||
:param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list. | |||
@@ -321,14 +317,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||
""" | |||
from fastNLP.io.data_bundle import DataBundle | |||
if isinstance(ds_or_db, Dataset): | |||
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | |||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||
return dl | |||
elif isinstance(ds_or_db, DataBundle): | |||
if isinstance(ds_or_db, DataBundle): | |||
dl_bundle = {} | |||
for name, ds in ds_or_db.iter_datasets(): | |||
if 'train' in name: | |||
@@ -353,18 +343,6 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||
timeout=timeout, worker_init_fn=worker_init_fn, | |||
persistent_workers=persistent_workers) | |||
return dl_bundle | |||
elif isinstance(ds_or_db, Sequence): | |||
ds_seq = [] | |||
for idx, ds in enumerate(ds_or_db): | |||
if idx > 0: | |||
batch_size = non_train_batch_size if non_train_batch_size else batch_size | |||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | |||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||
ds_seq.append(dl) | |||
return ds_seq | |||
elif isinstance(ds_or_db, Dict): | |||
ds_dict = {} | |||
@@ -387,5 +365,13 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||
persistent_workers=persistent_workers) | |||
ds_dict[name] = dl | |||
return ds_dict | |||
elif isinstance(ds_or_db, HasLenGetitemType): | |||
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | |||
batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, | |||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||
return dl | |||
else: | |||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") | |||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!") |
@@ -3,7 +3,8 @@ __all__ = [ | |||
'prepare_torch_dataloader' | |||
] | |||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List | |||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List, Any | |||
from abc import ABC | |||
from copy import deepcopy | |||
from fastNLP.core.dataset import DataSet | |||
@@ -12,9 +13,10 @@ from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler | |||
from ..utils import _match_param | |||
from ..utils import HasLenGetitemType | |||
if _NEED_IMPORT_TORCH: | |||
from torch.utils.data import DataLoader, Sampler | |||
from torch.utils.data import DataLoader, Sampler, Dataset | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | |||
@@ -223,10 +225,10 @@ def prepare_torch_dataloader(ds_or_db, | |||
persistent_workers: bool = False, | |||
non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||
non_train_batch_size: int = None) \ | |||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: | |||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader]]: | |||
""" | |||
``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader``对象, 详见 :class:`~fastNLP.core.dataloaders.TorchDataLoader`。 | |||
根据 ds_or_db 的类型 ``[DataSet, DataBundle,Sequence[Dataset], Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||
根据 ds_or_db 的类型 ``[DataSet, DataBundle, Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||
* 当 ds_or_db 为 ``DataSet``时,``prepare_torch_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||
帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见:class:`~fastNLP.core.dataloaders.TorchDataLoader`。 | |||
@@ -238,16 +240,12 @@ def prepare_torch_dataloader(ds_or_db, | |||
``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | |||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: TorchDataLoader`` 组成 | |||
``Dict[key, TorchDataLoader]`` 的字典返回。 | |||
* 当 ds_or_db 为 ``Sequence[Dataset]`` 数据类型时, prepare_torch_dataloader 会将 Sequence[0] 的数据集默认为 train 数据集对待, | |||
会将 batch_size 和 sampler 作为参数, 而 Sequence[1:] 数据集均视为非 train 数据集对待,使用 non_train_size 和 non_train_sampler 作为参数。 | |||
最终将所有实例化好的 ``TorchDataLoader`` 组成 ``Sequence[TorchDataLoader]`` 返回。 | |||
:param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。其取值只能为 ``[DataSet, DataBundle, | |||
Sequence[DataSet], Dict[str, DataSet]]``. | |||
Dict[str, DataSet]]``. | |||
* ds_or_db 为 :class:`~fastNLP.core.dataset.DataSet`,返回值为:class:`~fastNLP.core.dataloaders.TorchDataLoader` | |||
* ds_or_db 为 :class:`~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||
* ds_or_db 为 ``Sequence[DataSet]`` 序列, 返回值为 ``Sequence[TorchDataLoader]`` 的序列 | |||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值也为 ``Dict[str, TorchDataLoader]`` 的字典 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||
@@ -284,17 +282,8 @@ def prepare_torch_dataloader(ds_or_db, | |||
""" | |||
from fastNLP.io import DataBundle | |||
if isinstance(ds_or_db, DataSet): | |||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | |||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
) | |||
return dl | |||
elif isinstance(ds_or_db, DataBundle): | |||
if isinstance(ds_or_db, DataBundle): | |||
dl_bundle = {} | |||
for name, ds in ds_or_db.iter_datasets(): | |||
if 'train' in name: | |||
@@ -320,23 +309,6 @@ def prepare_torch_dataloader(ds_or_db, | |||
) | |||
return dl_bundle | |||
elif isinstance(ds_or_db, Sequence): | |||
dl_bundle = [] | |||
for idx, ds in enumerate(ds_or_db): | |||
if idx > 0: | |||
batch_size = non_train_batch_size if non_train_batch_size else batch_size | |||
sampler = non_train_sampler if non_train_sampler else sampler | |||
dl_bundle.append( | |||
TorchDataLoader(dataset=ds, batch_size=batch_size, | |||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
) | |||
) | |||
return dl_bundle | |||
elif isinstance(ds_or_db, Mapping): | |||
dl_bundle = {} | |||
for name, ds in ds_or_db.items(): | |||
@@ -363,5 +335,16 @@ def prepare_torch_dataloader(ds_or_db, | |||
) | |||
return dl_bundle | |||
elif isinstance(ds_or_db, HasLenGetitemType): | |||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | |||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
) | |||
return dl | |||
else: | |||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") | |||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!") |
@@ -1,4 +1,5 @@ | |||
from typing import Callable | |||
from typing import Callable, Any, Union | |||
from abc import ABC | |||
import inspect | |||
import ast | |||
@@ -96,6 +97,20 @@ def _match_param(fun, call_fn:Callable, fn_name:str=None): | |||
logger.debug(f"Exception happens when match parameters for {fn_name}: {e}") | |||
return None | |||
class HasLenGetitemType(ABC):
    """
    Virtual base class that structurally matches any type implementing both
    ``__len__`` and ``__getitem__``.

    Membership is decided at ``isinstance``/``issubclass`` time through
    ``__subclasshook__`` — no explicit registration or inheritance needed.
    """

    @classmethod
    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
        # Only answer for this exact ABC; for any derived ABC fall back to
        # the default subclass machinery.
        if cls is not HasLenGetitemType:
            return NotImplemented
        return all(
            callable(getattr(subclass, attr, None))
            for attr in ('__getitem__', '__len__')
        )
if __name__ == '__main__': | |||
def demo(*args, **kwargs): | |||
pass | |||
@@ -5,9 +5,10 @@ from fastNLP.envs import _module_available | |||
if _module_available('datasets'): | |||
from datasets import Dataset as HfDataset | |||
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader | |||
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | |||
from fastNLP.core.dataset import DataSet as Fdataset | |||
from fastNLP.core.collators import Collator | |||
from fastNLP.io.data_bundle import DataBundle | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
from jittor.dataset import Dataset | |||
@@ -61,7 +62,6 @@ class TestJittor: | |||
for batch in jtl1: | |||
print(batch) | |||
def test_huggingface_datasets(self): | |||
dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | |||
jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True, shuffle=False) | |||
@@ -90,4 +90,37 @@ class TestJittor: | |||
dl = MyDataset() | |||
dl = dl.set_attrs(collate_batch=collate_bacth, batch_size=256) | |||
for batch in dl: | |||
print(batch) | |||
print(batch) | |||
def test_prepare_jittor_dataloader(self):
    """
    Smoke-test ``prepare_jittor_dataloader`` over every accepted input kind:
    a fastNLP ``DataSet``, a ``DataBundle``, a ``Dict[str, dataset]`` and a raw
    jittor ``Dataset``. Only the types of the returned objects are asserted.
    """
    # fastNLP DataSet input -> a single JittorDataLoader
    ds = Fdataset({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
    dl = prepare_jittor_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
    assert isinstance(dl, JittorDataLoader)

    # DataBundle input -> Dict[str, JittorDataLoader]
    ds1 = Fdataset({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
    dbl = DataBundle(datasets={'train': ds, 'val': ds1})
    dl_bundle = prepare_jittor_dataloader(dbl)
    assert isinstance(dl_bundle['train'], JittorDataLoader)
    assert isinstance(dl_bundle['val'], JittorDataLoader)

    # dict input -> Dict[str, JittorDataLoader]; 'train_1' contains 'train'
    # so it exercises the train-branch, 'val' the non-train branch
    ds_dict = {'train_1': ds, 'val': ds1}
    dl_dict = prepare_jittor_dataloader(ds_dict)
    assert isinstance(dl_dict['train_1'], JittorDataLoader)
    assert isinstance(dl_dict['val'], JittorDataLoader)

    # raw jittor Dataset input (accepted via the __len__/__getitem__ check)
    ds1 = MyDataset()  # NOTE: rebinds ds1 from the fastNLP DataSet above
    dl1 = prepare_jittor_dataloader(ds1, batch_size=8, shuffle=True, num_workers=2)
    assert isinstance(dl1, JittorDataLoader)

    ds2 = MyDataset()
    dbl1 = DataBundle(datasets={'train': ds1, 'val': ds2})
    dl_bundle1 = prepare_jittor_dataloader(dbl1)
    assert isinstance(dl_bundle1['train'], JittorDataLoader)
    assert isinstance(dl_bundle1['val'], JittorDataLoader)

    ds_dict1 = {'train_1': ds1, 'val': ds2}
    dl_dict1 = prepare_jittor_dataloader(ds_dict1)
    assert isinstance(dl_dict1['train_1'], JittorDataLoader)
    assert isinstance(dl_dict1['val'], JittorDataLoader)
@@ -1,8 +1,9 @@ | |||
import pytest | |||
import numpy as np | |||
from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader | |||
from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader, prepare_paddle_dataloader | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.io.data_bundle import DataBundle | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.collators import Collator | |||
@@ -90,4 +91,36 @@ class TestPaddle: | |||
ds = PaddleArgMaxDataset(100, 2) | |||
dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) | |||
for batch in dl: | |||
print(batch) | |||
print(batch) | |||
def test_prepare_paddle_dataloader(self):
    """
    Smoke-test ``prepare_paddle_dataloader`` over every accepted input kind:
    a fastNLP ``DataSet``, a ``DataBundle``, a ``Dict[str, dataset]`` and a raw
    paddle dataset. Only the types of the returned objects are asserted.
    """
    # fastNLP DataSet input -> a single PaddleDataLoader
    ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
    dl = prepare_paddle_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
    assert isinstance(dl, PaddleDataLoader)

    # DataBundle input -> Dict[str, PaddleDataLoader]
    ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
    dbl = DataBundle(datasets={'train': ds, 'val': ds1})
    dl_bundle = prepare_paddle_dataloader(dbl)
    assert isinstance(dl_bundle['train'], PaddleDataLoader)
    assert isinstance(dl_bundle['val'], PaddleDataLoader)

    # dict input -> Dict[str, PaddleDataLoader]; 'train_1' contains 'train'
    # so it exercises the train-branch, 'val' the non-train branch
    ds_dict = {'train_1': ds, 'val': ds1}
    dl_dict = prepare_paddle_dataloader(ds_dict)
    assert isinstance(dl_dict['train_1'], PaddleDataLoader)
    assert isinstance(dl_dict['val'], PaddleDataLoader)

    # raw paddle dataset input (accepted via the __len__/__getitem__ check)
    ds2 = RandomDataset()
    dl1 = prepare_paddle_dataloader(ds2, batch_size=8, shuffle=True, num_workers=2)
    assert isinstance(dl1, PaddleDataLoader)

    ds3 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
    # NOTE(review): this bundle mixes a paddle dataset with a fastNLP DataSet;
    # presumably intentional to cover both code paths at once — confirm.
    dbl1 = DataBundle(datasets={'train': ds2, 'val': ds3})
    dl_bundle1 = prepare_paddle_dataloader(dbl1)
    assert isinstance(dl_bundle1['train'], PaddleDataLoader)
    assert isinstance(dl_bundle1['val'], PaddleDataLoader)

    ds_dict1 = {'train_1': ds2, 'val': ds3}
    dl_dict1 = prepare_paddle_dataloader(ds_dict1)
    assert isinstance(dl_dict1['train_1'], PaddleDataLoader)
    assert isinstance(dl_dict1['val'], PaddleDataLoader)
@@ -96,6 +96,7 @@ class TestFdl: | |||
assert batch['y'] == [1, 0, 1] | |||
def test_prepare_torch_dataloader(self): | |||
# 测试 fastNLP 的 dataset | |||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) | |||
assert isinstance(dl, TorchDataLoader) | |||
@@ -111,10 +112,39 @@ class TestFdl: | |||
assert isinstance(dl_dict['train_1'], TorchDataLoader) | |||
assert isinstance(dl_dict['val'], TorchDataLoader) | |||
sequence = [ds, ds1] | |||
seq_ds = prepare_torch_dataloader(sequence) | |||
assert isinstance(seq_ds[0], TorchDataLoader) | |||
assert isinstance(seq_ds[1], TorchDataLoader) | |||
# 测试其他 dataset | |||
class _DataSet: | |||
def __init__(self): | |||
pass | |||
def __getitem__(self, item): | |||
return np.random.randn(5), [[1, 2], [2, 3, 4]] | |||
def __len__(self): | |||
return 10 | |||
def __getattribute__(self, item): | |||
return object.__getattribute__(self, item) | |||
ds2 = _DataSet() | |||
dl1 = prepare_torch_dataloader(ds2, batch_size=8, shuffle=True, num_workers=2) | |||
assert isinstance(dl1, TorchDataLoader) | |||
ds3 = _DataSet() | |||
dbl1 = DataBundle(datasets={'train': ds2, 'val': ds3}) | |||
dl_bundle1 = prepare_torch_dataloader(dbl1) | |||
assert isinstance(dl_bundle1['train'], TorchDataLoader) | |||
assert isinstance(dl_bundle1['val'], TorchDataLoader) | |||
ds_dict1 = {'train_1': ds2, 'val': ds3} | |||
dl_dict1 = prepare_torch_dataloader(ds_dict1) | |||
assert isinstance(dl_dict1['train_1'], TorchDataLoader) | |||
assert isinstance(dl_dict1['val'], TorchDataLoader) | |||
# sequence = [ds, ds1] | |||
# seq_ds = prepare_torch_dataloader(sequence) | |||
# assert isinstance(seq_ds[0], TorchDataLoader) | |||
# assert isinstance(seq_ds[1], TorchDataLoader) | |||
def test_get_backend(self): | |||
from fastNLP.core.collators import Collator | |||