
Modify fdl test cases

tags/v1.0.0alpha
MorningForest, 3 years ago
commit af8fdf9080
4 changed files with 68 additions and 33 deletions

  1. fastNLP/core/dataloaders/jittor_dataloader/fdl.py  +31 -28
  2. tests/core/dataloaders/jittor_dataloader/test_fdl.py  +12 -4
  3. tests/core/dataloaders/paddle_dataloader/test_fdl.py  +9 -1
  4. tests/core/dataloaders/torch_dataloader/test_fdl.py  +16 -0

fastNLP/core/dataloaders/jittor_dataloader/fdl.py  +31 -28

@@ -26,20 +26,11 @@ class _JittorDataset(Dataset):
     def __init__(self, dataset) -> None:
         super(_JittorDataset, self).__init__()
         self.dataset = dataset
+        self.total_len = len(dataset)

     def __getitem__(self, item):
         return (item, self.dataset[item])

-    def __len__(self) -> int:
-        return len(self.dataset)
-
-    # def __getattr__(self, item):
-    #     # methods missing from jittor's Dataset but present on the user's dataset (which implements getattribute) could still be reached by the user this way
-    #     try:
-    #         self.dataset.__getattribute__(item)
-    #     except Exception as e:
-    #         raise e


 class JittorDataLoader:
     """
@@ -62,13 +53,17 @@ class JittorDataLoader:
         :param keep_numpy_array:
         :param endless:
         :param collate_fn: callable used to batch the fetched data
-        :param as_numpy: whether the returned data is numpy; otherwise torch.tensor
         """
         # TODO verify support for replacesampler (to be finished later)
+        # set the inner dataset's batch size to 1
+        if isinstance(dataset, Dataset):
+            dataset.set_attrs(batch_size=1)
+
         # FastNLP Datset, collate_fn not None
         if isinstance(dataset, FDataSet) and collate_fn is None:
             raise ValueError("When use FastNLP DataSet, collate_fn must be not None")

+        # convert any dataset into a jittor-style dataset
         if not isinstance(dataset, _JittorDataset):
             self.dataset = _JittorDataset(dataset)

@@ -82,17 +77,13 @@
             else:
                 raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
         elif isinstance(collate_fn, Callable):
-            if collate_fn is not collate_batch:
-                self.collate_fn = collate_fn
+            self.collate_fn = collate_fn
         else:
             self.collate_fn = collate_batch

         self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                                num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad,
                                keep_numpy_array=keep_numpy_array, endless=endless)
-        # set the inner dataset's batch size to 1
-        if isinstance(self.dataset.dataset, Dataset):
-            self.dataset.dataset.set_attrs(batch_size=1)

         self.cur_batch_indices = None

@@ -105,12 +96,10 @@ class JittorDataLoader:
             yield data

     def __len__(self):
-        if self.dataset.drop_last:
-            return len(self.dataset) // self.dataset.batch_size
-        return (len(self.dataset) - 1) // self.dataset.batch_size + 1
+        return len(self.dataset)

     def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
-                pad_fn: Callable = None) -> "JittorDataLoader":
+                pad_fn: Callable = None) -> Collator:
         """
         Use this function if the content of a certain field needs special adjustment.

@@ -129,14 +118,27 @@ class JittorDataLoader:
             form; the output will be used directly as the result.
         :return: the Collator itself
         """
-        if isinstance(self.collate_fn, Collator):
-            self.collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
-                                    backend=backend)
-            return self
+        collator = self._get_collator()
+        if isinstance(collator, Collator):
+            collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+            return collator
         else:
             raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")

-    def set_ignore(self, *field_names) -> "JittorDataLoader":
+    def _get_collator(self):
+        """
+        If collate_fn is a Collator object, return it; otherwise return None.
+
+        :return:
+        """
+        collator = None
+        if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
+            collator = self.collate_fn.__wrapped__
+        elif isinstance(self.collate_fn, Collator):
+            collator = self.collate_fn
+        return collator
+
+    def set_ignore(self, *field_names) -> Collator:
         """
         Fields that should not be part of the output can be set here; the specified fields will be ignored in the batch output.
         Example::
@@ -147,9 +149,10 @@ class JittorDataLoader:
             __getitem__ returns a Sequence, '_0' and '_1' can be used to refer to element 0 or 1 of the sequence.
         :return: the Collator itself
         """
-        if isinstance(self.collate_fn, Collator):
-            self.collate_fn.set_ignore(*field_names)
-            return self
+        collator = self._get_collator()
+        if isinstance(collator, Collator):
+            collator.set_ignore(*field_names)
+            return collator
         else:
             raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")




tests/core/dataloaders/jittor_dataloader/test_fdl.py  +12 -4

@@ -4,6 +4,7 @@ from datasets import Dataset as HfDataset

 from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
 from fastNLP.core.dataset import DataSet as Fdataset
+from fastNLP.core.collators import Collator
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
 if _NEED_IMPORT_JITTOR:
     from jittor.dataset import Dataset
@@ -53,9 +54,9 @@ class TestJittor:
         jtl.set_ignore("y")
         for batch in jtl:
             assert batch['x'].size() == (16, 4)
-        jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True, num_workers=2)
-
+        jtl1 = JittorDataLoader(dataset, batch_size=16, drop_last=True, num_workers=2)
+        for batch in jtl1:
+            print(batch)

     def test_huggingface_datasets(self):
@@ -79,4 +80,11 @@ class TestJittor:
         for idx, batch in enumerate(dataset):
             print(idx, batch.shape)
         for idx, batch in enumerate(dataset):
-            print(idx, batch.shape)
+            print(idx, batch.shape)
+
+    def test_jittor_get_backend(self):
+        collate_bacth = Collator(backend='auto')
+        dl = MyDataset()
+        dl = dl.set_attrs(collate_batch=collate_bacth, batch_size=256)
+        for batch in dl:
+            print(batch)

tests/core/dataloaders/paddle_dataloader/test_fdl.py  +9 -1

@@ -4,11 +4,12 @@ import numpy as np
 from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.log import logger
+from fastNLP.core.collators import Collator

 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

 if _NEED_IMPORT_PADDLE:
-    from paddle.io import Dataset
+    from paddle.io import Dataset, DataLoader
     import paddle
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Dataset
@@ -61,3 +62,10 @@ class TestPaddle:
         fdl1.set_ignore('label')
         for batch in fdl1:
             assert batch['image'].shape == [4, 10, 5]
+
+    def test_get_backend(self):
+        ds = RandomDataset()
+        collate_fn = Collator(backend='auto')
+        paddle_dl = DataLoader(ds, collate_fn=collate_fn)
+        for batch in paddle_dl:
+            print(batch)

tests/core/dataloaders/torch_dataloader/test_fdl.py  +16 -0

@@ -112,3 +112,19 @@ class TestFdl:
         seq_ds = prepare_torch_dataloader(sequence)
         assert isinstance(seq_ds[0], TorchDataLoader)
         assert isinstance(seq_ds[1], TorchDataLoader)
+
+    def test_get_backend(self):
+        from fastNLP.core.collators import Collator
+        from torch.utils.data import DataLoader, Dataset
+
+        class MyDatset(DataSet):
+            def __len__(self):
+                return 1000
+
+            def __getitem__(self, item):
+                return [[1, 0], [1], [1, 2, 4]], [1, 0]
+
+        collate_batch = Collator(backend='auto')
+        dl = DataLoader(MyDatset(), collate_fn=collate_batch)
+        for batch in dl:
+            print(batch)
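
The three new test cases share one pattern: a Collator(backend='auto') is handed directly to the framework's native data loader (jittor via set_attrs(collate_batch=...), paddle and torch via collate_fn=...), so the padding backend is inferred from the batches the collator sees. Below is a minimal standalone sketch of that idea without any framework DataLoader; the dict-shaped samples and field names are illustrative assumptions, not taken from these tests.

    from fastNLP.core.collators import Collator

    # a hand-built batch of two dict-shaped samples; 'x' is ragged, 'y' is a scalar
    batch = [
        {'x': [1, 0], 'y': 1},
        {'x': [1, 2, 4], 'y': 0},
    ]

    collator = Collator(backend='auto')
    padded = collator(batch)
    # expected: a dict keyed by 'x' and 'y', with 'x' padded to a common length;
    # the concrete array/tensor type depends on the backend the collator detects
    print(padded)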
