Browse Source

修改fdl(Modify fdl: rework collate_fn handling in the paddle/torch dataloaders)

tags/v1.0.0alpha
MorningForest 2 years ago
parent
commit
da25fbee3f
4 changed files with 110 additions and 38 deletions
  1. +58
    -21
      fastNLP/core/dataloaders/paddle_dataloader/fdl.py
  2. +10
    -8
      fastNLP/core/dataloaders/torch_dataloader/fdl.py
  3. +1
    -1
      fastNLP/core/dataloaders/utils.py
  4. +41
    -8
      tests/core/dataloaders/torch_dataloader/test_fdl.py

+ 58
- 21
fastNLP/core/dataloaders/paddle_dataloader/fdl.py View File

@@ -63,13 +63,6 @@ class PaddleDataLoader(DataLoader):
shuffle = False shuffle = False
drop_last = False drop_last = False


super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
return_list=return_list, batch_sampler=batch_sampler,
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
collate_fn=None, num_workers=num_workers,
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
timeout=timeout, worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers)
if isinstance(collate_fn, str): if isinstance(collate_fn, str):
if collate_fn == 'auto': if collate_fn == 'auto':
if isinstance(dataset.dataset, FDataSet): if isinstance(dataset.dataset, FDataSet):
@@ -80,11 +73,15 @@ class PaddleDataLoader(DataLoader):


else: else:
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
elif isinstance(collate_fn, Callable):
if collate_fn is not default_collate_fn:
self._collate_fn = collate_fn
else:
self._collate_fn = default_collate_fn

super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
return_list=return_list, batch_sampler=batch_sampler,
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
collate_fn=collate_fn, num_workers=num_workers,
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
timeout=timeout, worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers)

# _collate_fn = _MultiCollator(AutoCollator(as_numpy=True)) # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True))
# if collate_fn is not None: # if collate_fn is not None:
# _collate_fn.add_collator(collate_fn) # _collate_fn.add_collator(collate_fn)
@@ -96,13 +93,13 @@ class PaddleDataLoader(DataLoader):
# if len(self._collate_fn.get_collators()) == 0: # if len(self._collate_fn.get_collators()) == 0:
# self._collate_fn.add_collator(default_collate_fn) # self._collate_fn.add_collator(default_collate_fn)
# self._collate_fn = default_collate_fn # self._collate_fn = default_collate_fn
self.collate_fn = indice_collate_wrapper(self._collate_fn)
self.collate_fn = indice_collate_wrapper(self.collate_fn)
for indices, data in super().__iter__(): for indices, data in super().__iter__():
self.cur_batch_indices = indices self.cur_batch_indices = indices
yield data yield data


def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> Collator:
""" """
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。


@@ -121,12 +118,26 @@ class PaddleDataLoader(DataLoader):
形式,输出将被直接作为结果输出。 形式,输出将被直接作为结果输出。
:return: 返回 Collator 自身 :return: 返回 Collator 自身
""" """
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
collator = self._get_collator()
if isinstance(collator, Collator):
collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return collator
else: else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")


def _get_collator(self):
    """
    Return the :class:`Collator` behind ``self.collate_fn``, if any.

    ``collate_fn`` may be the Collator itself, or a wrapper (e.g. produced by
    ``indice_collate_wrapper``) exposing the Collator via ``__wrapped__``.

    :return: the underlying ``Collator`` object, or ``None`` if ``collate_fn``
        neither is nor wraps one.
    """
    wrapped = getattr(self.collate_fn, '__wrapped__', None)
    if isinstance(wrapped, Collator):
        return wrapped
    if isinstance(self.collate_fn, Collator):
        return self.collate_fn
    return None

def set_ignore(self, *field_names) -> Collator: def set_ignore(self, *field_names) -> Collator:
""" """
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
@@ -138,9 +149,10 @@ class PaddleDataLoader(DataLoader):
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
:return: 返回 Collator 自身 :return: 返回 Collator 自身
""" """
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self._collate_fn
collator = self._get_collator()
if isinstance(collator, Collator):
collator.set_ignore(*field_names)
return collator
else: else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")


@@ -163,6 +175,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
worker_init_fn: Callable = None, persistent_workers=False, worker_init_fn: Callable = None, persistent_workers=False,
non_train_batch_size: int = 16) \ non_train_batch_size: int = 16) \
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
from fastNLP.io.data_bundle import DataBundle
if isinstance(ds_or_db, Dataset): if isinstance(ds_or_db, Dataset):
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
@@ -170,6 +183,30 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
return dl return dl
elif isinstance(ds_or_db, DataBundle):
dl_bundle = {}
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places,
return_list=return_list,
batch_sampler=batch_sampler, batch_size=train_batch_size,
shuffle=shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory,
use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers)
else:
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places,
return_list=return_list,
batch_sampler=batch_sampler, batch_size=non_train_batch_size,
shuffle=shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory,
use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers)
return dl_bundle
elif isinstance(ds_or_db, Sequence): elif isinstance(ds_or_db, Sequence):
ds_seq = [] ds_seq = []
for ds in ds_or_db: for ds in ds_or_db:


+ 10
- 8
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -14,7 +14,6 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler,


if _NEED_IMPORT_TORCH: if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader, Sampler from torch.utils.data import DataLoader, Sampler
from torch.utils.data._utils.collate import default_collate
else: else:
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader from fastNLP.core.utils.dummy_class import DummyClass as DataLoader


@@ -73,12 +72,15 @@ class TorchDataLoader(DataLoader):
:param prefetch_factor: :param prefetch_factor:
:param persistent_workers: :param persistent_workers:
""" """
if isinstance(dataset, DataSet) and collate_fn is None:
raise ValueError("When use FastNLP DataSet, collate_fn must be not None")

if not isinstance(dataset, _FDataSet): if not isinstance(dataset, _FDataSet):
dataset = _FDataSet(dataset) dataset = _FDataSet(dataset)


if sampler is None and batch_sampler is None: if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle) sampler = RandomSampler(dataset, shuffle=shuffle)
shuffle=False
shuffle = False
if isinstance(collate_fn, str): if isinstance(collate_fn, str):
if collate_fn == 'auto': if collate_fn == 'auto':
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset
@@ -107,8 +109,8 @@ class TorchDataLoader(DataLoader):
self.cur_batch_indices = indices self.cur_batch_indices = indices
yield data yield data


def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> Collator:
""" """
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。


@@ -174,12 +176,12 @@ class TorchDataLoader(DataLoader):
return self.cur_batch_indices return self.cur_batch_indices





def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]], def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 16, batch_size: int = 16,
shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
shuffle: bool = True,
sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto',
pin_memory: bool = False, drop_last: bool = False, pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None, timeout: float = 0, worker_init_fn: Optional[Callable] = None,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
@@ -220,7 +222,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping
) )
return dl return dl


elif type(ds_or_db, DataBundle):
elif isinstance(ds_or_db, DataBundle):
dl_bundle = {} dl_bundle = {}
for name, ds in ds_or_db.iter_datasets(): for name, ds in ds_or_db.iter_datasets():
if 'train' in name: if 'train' in name:


+ 1
- 1
fastNLP/core/dataloaders/utils.py View File

@@ -10,7 +10,7 @@ def indice_collate_wrapper(func):
:param func: 需要修饰的函数 :param func: 需要修饰的函数
:return: :return:
""" """
if func.__name__ == '_indice_collate_wrapper': # 如果已经被包裹过了
if hasattr(func, '__name__') and func.__name__ == '_indice_collate_wrapper': # 如果已经被包裹过了
return func return func


def _indice_collate_wrapper(tuple_data): # 这里不能使用 functools.wraps ,否则会检测不到 def _indice_collate_wrapper(tuple_data): # 这里不能使用 functools.wraps ,否则会检测不到


+ 41
- 8
tests/core/dataloaders/torch_dataloader/test_fdl.py View File

@@ -3,6 +3,10 @@ import pytest
from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.io.data_bundle import DataBundle from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
import torch




@pytest.mark.torch @pytest.mark.torch
@@ -22,20 +26,16 @@ class TestFdl:
fdl = TorchDataLoader(ds, batch_size=3) fdl = TorchDataLoader(ds, batch_size=3)
fdl.set_pad("x", -1) fdl.set_pad("x", -1)
for batch in fdl: for batch in fdl:
print(batch)
# fdl.set_pad_val("x", val=-2)
# for batch in fdl:
# print(batch)
assert batch['x'].shape == torch.Size([3, 4])


def test_get_batch_indices(self): def test_get_batch_indices(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True) fdl = TorchDataLoader(ds, batch_size=3, shuffle=True)
for batch in fdl: for batch in fdl:
print(fdl.get_batch_indices())
assert len(fdl.get_batch_indices()) == 3


def test_other_dataset(self): def test_other_dataset(self):
import numpy as np import numpy as np

class _DataSet: class _DataSet:


def __init__(self): def __init__(self):
@@ -55,8 +55,41 @@ class TestFdl:
# dl.set_inputs('data', 'labels') # dl.set_inputs('data', 'labels')
# dl.set_pad_val('labels', val=None) # dl.set_pad_val('labels', val=None)
for batch in dl: for batch in dl:
print(batch)
print(dl.get_batch_indices())
assert batch[0].shape == torch.Size([2, 5])
assert batch[1].shape == torch.Size([2, 2, 3])

def test_default_collate_fn(self):
    """collate_fn=None must be rejected for fastNLP DataSet but work for plain datasets."""
    ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
    # A fastNLP DataSet requires an explicit collate_fn.
    with pytest.raises(ValueError):
        TorchDataLoader(ds, batch_size=3, collate_fn=None)

    import numpy as np

    class _PlainDataSet:
        """Minimal map-style dataset that is not a fastNLP DataSet."""

        def __getitem__(self, item):
            return np.random.randn(5), [[1, 2], [2, 3, 4]]

        def __len__(self):
            return 10

    # Plain datasets fall back to torch's default_collate.
    dl = TorchDataLoader(_PlainDataSet(), batch_size=3, collate_fn=None, drop_last=True)
    for batch in dl:
        assert batch[0].shape == torch.Size([3, 5])

def test_my_collate_fn(self):
    """A user-supplied collate_fn is used verbatim (no padding applied)."""
    ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})

    def gather(instances):
        # Collect each field into a parallel list instead of padding/stacking.
        out = {'x': [], 'y': []}
        for instance in instances:
            out['x'].append(instance['x'])
            out['y'].append(instance['y'])
        return out

    dl = TorchDataLoader(ds, collate_fn=gather, batch_size=3, drop_last=True)
    for batch in dl:
        assert batch['x'] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]]
        assert batch['y'] == [1, 0, 1]


def test_prepare_torch_dataloader(self): def test_prepare_torch_dataloader(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})


Loading…
Cancel
Save