@@ -63,13 +63,6 @@ class PaddleDataLoader(DataLoader): | |||||
shuffle = False | shuffle = False | ||||
drop_last = False | drop_last = False | ||||
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, | |||||
return_list=return_list, batch_sampler=batch_sampler, | |||||
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | |||||
collate_fn=None, num_workers=num_workers, | |||||
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, | |||||
persistent_workers=persistent_workers) | |||||
if isinstance(collate_fn, str): | if isinstance(collate_fn, str): | ||||
if collate_fn == 'auto': | if collate_fn == 'auto': | ||||
if isinstance(dataset.dataset, FDataSet): | if isinstance(dataset.dataset, FDataSet): | ||||
@@ -80,11 +73,15 @@ class PaddleDataLoader(DataLoader): | |||||
else: | else: | ||||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | ||||
elif isinstance(collate_fn, Callable): | |||||
if collate_fn is not default_collate_fn: | |||||
self._collate_fn = collate_fn | |||||
else: | |||||
self._collate_fn = default_collate_fn | |||||
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, | |||||
return_list=return_list, batch_sampler=batch_sampler, | |||||
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | |||||
collate_fn=collate_fn, num_workers=num_workers, | |||||
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, | |||||
persistent_workers=persistent_workers) | |||||
# _collate_fn = _MultiCollator(AutoCollator(as_numpy=True)) | # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True)) | ||||
# if collate_fn is not None: | # if collate_fn is not None: | ||||
# _collate_fn.add_collator(collate_fn) | # _collate_fn.add_collator(collate_fn) | ||||
@@ -96,13 +93,13 @@ class PaddleDataLoader(DataLoader): | |||||
# if len(self._collate_fn.get_collators()) == 0: | # if len(self._collate_fn.get_collators()) == 0: | ||||
# self._collate_fn.add_collator(default_collate_fn) | # self._collate_fn.add_collator(default_collate_fn) | ||||
# self._collate_fn = default_collate_fn | # self._collate_fn = default_collate_fn | ||||
self.collate_fn = indice_collate_wrapper(self._collate_fn) | |||||
self.collate_fn = indice_collate_wrapper(self.collate_fn) | |||||
for indices, data in super().__iter__(): | for indices, data in super().__iter__(): | ||||
self.cur_batch_indices = indices | self.cur_batch_indices = indices | ||||
yield data | yield data | ||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||||
pad_fn:Callable=None) -> Collator: | |||||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | |||||
pad_fn: Callable = None) -> Collator: | |||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
@@ -121,12 +118,26 @@ class PaddleDataLoader(DataLoader): | |||||
形式,输出将被直接作为结果输出。 | 形式,输出将被直接作为结果输出。 | ||||
:return: 返回 Collator 自身 | :return: 返回 Collator 自身 | ||||
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | |||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | |||||
return self._collate_fn | |||||
collator = self._get_collator() | |||||
if isinstance(collator, Collator): | |||||
collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | |||||
return collator | |||||
else: | else: | ||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | ||||
def _get_collator(self): | |||||
""" | |||||
如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None | |||||
:return: | |||||
""" | |||||
collator = None | |||||
if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator): | |||||
collator = self.collate_fn.__wrapped__ | |||||
elif isinstance(self.collate_fn, Collator): | |||||
collator = self.collate_fn | |||||
return collator | |||||
def set_ignore(self, *field_names) -> Collator: | def set_ignore(self, *field_names) -> Collator: | ||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | ||||
@@ -138,9 +149,10 @@ class PaddleDataLoader(DataLoader): | |||||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | ||||
:return: 返回 Collator 自身 | :return: 返回 Collator 自身 | ||||
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | |||||
self._collate_fn.set_ignore(*field_names) | |||||
return self._collate_fn | |||||
collator = self._get_collator() | |||||
if isinstance(collator, Collator): | |||||
collator.set_ignore(*field_names) | |||||
return collator | |||||
else: | else: | ||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | ||||
@@ -163,6 +175,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
worker_init_fn: Callable = None, persistent_workers=False, | worker_init_fn: Callable = None, persistent_workers=False, | ||||
non_train_batch_size: int = 16) \ | non_train_batch_size: int = 16) \ | ||||
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: | -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: | ||||
from fastNLP.io.data_bundle import DataBundle | |||||
if isinstance(ds_or_db, Dataset): | if isinstance(ds_or_db, Dataset): | ||||
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, | batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, | ||||
@@ -170,6 +183,30 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | ||||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | ||||
return dl | return dl | ||||
elif isinstance(ds_or_db, DataBundle): | |||||
dl_bundle = {} | |||||
for name, ds in ds_or_db.iter_datasets(): | |||||
if 'train' in name: | |||||
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, | |||||
return_list=return_list, | |||||
batch_sampler=batch_sampler, batch_size=train_batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||||
use_shared_memory=use_shared_memory, | |||||
use_buffer_reader=use_buffer_reader, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, | |||||
persistent_workers=persistent_workers) | |||||
else: | |||||
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, | |||||
return_list=return_list, | |||||
batch_sampler=batch_sampler, batch_size=non_train_batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||||
use_shared_memory=use_shared_memory, | |||||
use_buffer_reader=use_buffer_reader, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, | |||||
persistent_workers=persistent_workers) | |||||
return dl_bundle | |||||
elif isinstance(ds_or_db, Sequence): | elif isinstance(ds_or_db, Sequence): | ||||
ds_seq = [] | ds_seq = [] | ||||
for ds in ds_or_db: | for ds in ds_or_db: | ||||
@@ -14,7 +14,6 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
from torch.utils.data import DataLoader, Sampler | from torch.utils.data import DataLoader, Sampler | ||||
from torch.utils.data._utils.collate import default_collate | |||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | ||||
@@ -73,12 +72,15 @@ class TorchDataLoader(DataLoader): | |||||
:param prefetch_factor: | :param prefetch_factor: | ||||
:param persistent_workers: | :param persistent_workers: | ||||
""" | """ | ||||
if isinstance(dataset, DataSet) and collate_fn is None: | |||||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | |||||
if not isinstance(dataset, _FDataSet): | if not isinstance(dataset, _FDataSet): | ||||
dataset = _FDataSet(dataset) | dataset = _FDataSet(dataset) | ||||
if sampler is None and batch_sampler is None: | if sampler is None and batch_sampler is None: | ||||
sampler = RandomSampler(dataset, shuffle=shuffle) | sampler = RandomSampler(dataset, shuffle=shuffle) | ||||
shuffle=False | |||||
shuffle = False | |||||
if isinstance(collate_fn, str): | if isinstance(collate_fn, str): | ||||
if collate_fn == 'auto': | if collate_fn == 'auto': | ||||
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | ||||
@@ -107,8 +109,8 @@ class TorchDataLoader(DataLoader): | |||||
self.cur_batch_indices = indices | self.cur_batch_indices = indices | ||||
yield data | yield data | ||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||||
pad_fn:Callable=None) -> Collator: | |||||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | |||||
pad_fn: Callable = None) -> Collator: | |||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
@@ -174,12 +176,12 @@ class TorchDataLoader(DataLoader): | |||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]], | def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]], | ||||
batch_size: int = 16, | batch_size: int = 16, | ||||
shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||||
shuffle: bool = True, | |||||
sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | ||||
num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, | |||||
num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto', | |||||
pin_memory: bool = False, drop_last: bool = False, | pin_memory: bool = False, drop_last: bool = False, | ||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | timeout: float = 0, worker_init_fn: Optional[Callable] = None, | ||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | ||||
@@ -220,7 +222,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping | |||||
) | ) | ||||
return dl | return dl | ||||
elif type(ds_or_db, DataBundle): | |||||
elif isinstance(ds_or_db, DataBundle): | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
@@ -10,7 +10,7 @@ def indice_collate_wrapper(func): | |||||
:param func: 需要修饰的函数 | :param func: 需要修饰的函数 | ||||
:return: | :return: | ||||
""" | """ | ||||
if func.__name__ == '_indice_collate_wrapper': # 如果已经被包裹过了 | |||||
if hasattr(func, '__name__') and func.__name__ == '_indice_collate_wrapper': # 如果已经被包裹过了 | |||||
return func | return func | ||||
def _indice_collate_wrapper(tuple_data): # 这里不能使用 functools.wraps ,否则会检测不到 | def _indice_collate_wrapper(tuple_data): # 这里不能使用 functools.wraps ,否则会检测不到 | ||||
@@ -3,6 +3,10 @@ import pytest | |||||
from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader | from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@@ -22,20 +26,16 @@ class TestFdl: | |||||
fdl = TorchDataLoader(ds, batch_size=3) | fdl = TorchDataLoader(ds, batch_size=3) | ||||
fdl.set_pad("x", -1) | fdl.set_pad("x", -1) | ||||
for batch in fdl: | for batch in fdl: | ||||
print(batch) | |||||
# fdl.set_pad_val("x", val=-2) | |||||
# for batch in fdl: | |||||
# print(batch) | |||||
assert batch['x'].shape == torch.Size([3, 4]) | |||||
def test_get_batch_indices(self): | def test_get_batch_indices(self): | ||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ||||
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True) | fdl = TorchDataLoader(ds, batch_size=3, shuffle=True) | ||||
for batch in fdl: | for batch in fdl: | ||||
print(fdl.get_batch_indices()) | |||||
assert len(fdl.get_batch_indices()) == 3 | |||||
def test_other_dataset(self): | def test_other_dataset(self): | ||||
import numpy as np | import numpy as np | ||||
class _DataSet: | class _DataSet: | ||||
def __init__(self): | def __init__(self): | ||||
@@ -55,8 +55,41 @@ class TestFdl: | |||||
# dl.set_inputs('data', 'labels') | # dl.set_inputs('data', 'labels') | ||||
# dl.set_pad_val('labels', val=None) | # dl.set_pad_val('labels', val=None) | ||||
for batch in dl: | for batch in dl: | ||||
print(batch) | |||||
print(dl.get_batch_indices()) | |||||
assert batch[0].shape == torch.Size([2, 5]) | |||||
assert batch[1].shape == torch.Size([2, 2, 3]) | |||||
def test_default_collate_fn(self): | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
with pytest.raises(ValueError): | |||||
fdl = TorchDataLoader(ds, batch_size=3, collate_fn=None) | |||||
import numpy as np | |||||
class _DataSet: | |||||
def __init__(self): | |||||
pass | |||||
def __getitem__(self, item): | |||||
return np.random.randn(5), [[1, 2], [2, 3, 4]] | |||||
def __len__(self): | |||||
return 10 | |||||
fdl = TorchDataLoader(_DataSet(), batch_size=3, collate_fn=None, drop_last=True) | |||||
for batch in fdl: | |||||
assert batch[0].shape == torch.Size([3, 5]) | |||||
def test_my_collate_fn(self): | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
def collate_fn(batch): | |||||
res = {'x': [], 'y': []} | |||||
for ins in batch: | |||||
res['x'].append(ins['x']) | |||||
res['y'].append(ins['y']) | |||||
return res | |||||
fdl = TorchDataLoader(ds, collate_fn=collate_fn, batch_size=3, drop_last=True) | |||||
for batch in fdl: | |||||
assert batch['x'] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]] | |||||
assert batch['y'] == [1, 0, 1] | |||||
def test_prepare_torch_dataloader(self): | def test_prepare_torch_dataloader(self): | ||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ||||