From da25fbee3f954d936ce1c3cb6e311a9be51af4fa Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Thu, 5 May 2022 19:24:50 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9fdl?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/dataloaders/paddle_dataloader/fdl.py | 79 ++++++++++++++----- .../core/dataloaders/torch_dataloader/fdl.py | 18 +++-- fastNLP/core/dataloaders/utils.py | 2 +- .../dataloaders/torch_dataloader/test_fdl.py | 49 ++++++++++-- 4 files changed, 110 insertions(+), 38 deletions(-) diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index 952759f7..b157dd68 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -63,13 +63,6 @@ class PaddleDataLoader(DataLoader): shuffle = False drop_last = False - super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, - return_list=return_list, batch_sampler=batch_sampler, - batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, - collate_fn=None, num_workers=num_workers, - use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, - timeout=timeout, worker_init_fn=worker_init_fn, - persistent_workers=persistent_workers) if isinstance(collate_fn, str): if collate_fn == 'auto': if isinstance(dataset.dataset, FDataSet): @@ -80,11 +73,15 @@ class PaddleDataLoader(DataLoader): else: raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") - elif isinstance(collate_fn, Callable): - if collate_fn is not default_collate_fn: - self._collate_fn = collate_fn - else: - self._collate_fn = default_collate_fn + + super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, + return_list=return_list, batch_sampler=batch_sampler, + batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, + collate_fn=collate_fn, num_workers=num_workers, 
+ use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, + timeout=timeout, worker_init_fn=worker_init_fn, + persistent_workers=persistent_workers) + # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True)) # if collate_fn is not None: # _collate_fn.add_collator(collate_fn) @@ -96,13 +93,13 @@ class PaddleDataLoader(DataLoader): # if len(self._collate_fn.get_collators()) == 0: # self._collate_fn.add_collator(default_collate_fn) # self._collate_fn = default_collate_fn - self.collate_fn = indice_collate_wrapper(self._collate_fn) + self.collate_fn = indice_collate_wrapper(self.collate_fn) for indices, data in super().__iter__(): self.cur_batch_indices = indices yield data - def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, - pad_fn:Callable=None) -> Collator: + def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, + pad_fn: Callable = None) -> Collator: """ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 @@ -121,12 +118,26 @@ class PaddleDataLoader(DataLoader): 形式,输出将被直接作为结果输出。 :return: 返回 Collator 自身 """ - if isinstance(self._collate_fn, Collator): - self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) - return self._collate_fn + collator = self._get_collator() + if isinstance(collator, Collator): + collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) + return collator else: raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") + def _get_collator(self): + """ + 如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None + + :return: + """ + collator = None + if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator): + collator = self.collate_fn.__wrapped__ + elif isinstance(self.collate_fn, Collator): + collator = self.collate_fn + return collator + def set_ignore(self, 
*field_names) -> Collator:
         """
         如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
@@ -138,9 +149,10 @@ class PaddleDataLoader(DataLoader):
             __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
         :return: 返回 Collator 自身
         """
-        if isinstance(self._collate_fn, Collator):
-            self._collate_fn.set_ignore(*field_names)
-            return self._collate_fn
+        collator = self._get_collator()
+        if isinstance(collator, Collator):
+            collator.set_ignore(*field_names)
+            return collator
         else:
             raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
 
@@ -163,6 +175,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                              worker_init_fn: Callable = None, persistent_workers=False,
                              non_train_batch_size: int = 16) \
         -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
+    from fastNLP.io.data_bundle import DataBundle
     if isinstance(ds_or_db, Dataset):
         dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
                               batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
@@ -170,6 +183,30 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                               use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                               timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
         return dl
+    elif isinstance(ds_or_db, DataBundle):
+        dl_bundle = {}
+        for name, ds in ds_or_db.iter_datasets():
+            if 'train' in name:
+                dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
+                                                   return_list=return_list,
+                                                   batch_sampler=batch_sampler, batch_size=train_batch_size,
+                                                   shuffle=shuffle,
+                                                   drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
+                                                   use_shared_memory=use_shared_memory,
+                                                   use_buffer_reader=use_buffer_reader,
+                                                   timeout=timeout, worker_init_fn=worker_init_fn,
+                                                   persistent_workers=persistent_workers)
+            else:
+                dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
+                                                   
return_list=return_list, + batch_sampler=batch_sampler, batch_size=non_train_batch_size, + shuffle=shuffle, + drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, + use_shared_memory=use_shared_memory, + use_buffer_reader=use_buffer_reader, + timeout=timeout, worker_init_fn=worker_init_fn, + persistent_workers=persistent_workers) + return dl_bundle elif isinstance(ds_or_db, Sequence): ds_seq = [] for ds in ds_or_db: diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 923f6415..b827e1ab 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -14,7 +14,6 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, if _NEED_IMPORT_TORCH: from torch.utils.data import DataLoader, Sampler - from torch.utils.data._utils.collate import default_collate else: from fastNLP.core.utils.dummy_class import DummyClass as DataLoader @@ -73,12 +72,15 @@ class TorchDataLoader(DataLoader): :param prefetch_factor: :param persistent_workers: """ + if isinstance(dataset, DataSet) and collate_fn is None: + raise ValueError("When use FastNLP DataSet, collate_fn must be not None") + if not isinstance(dataset, _FDataSet): dataset = _FDataSet(dataset) if sampler is None and batch_sampler is None: sampler = RandomSampler(dataset, shuffle=shuffle) - shuffle=False + shuffle = False if isinstance(collate_fn, str): if collate_fn == 'auto': if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset @@ -107,8 +109,8 @@ class TorchDataLoader(DataLoader): self.cur_batch_indices = indices yield data - def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, - pad_fn:Callable=None) -> Collator: + def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, + pad_fn: Callable = None) -> Collator: """ 如果需要对某个 field 
的内容进行特殊的调整,请使用这个函数。 @@ -174,12 +176,12 @@ class TorchDataLoader(DataLoader): return self.cur_batch_indices - def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]], batch_size: int = 16, - shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, + shuffle: bool = True, + sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, - num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, + num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto', pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable] = None, multiprocessing_context=None, generator=None, prefetch_factor: int = 2, @@ -220,7 +222,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping ) return dl - elif type(ds_or_db, DataBundle): + elif isinstance(ds_or_db, DataBundle): dl_bundle = {} for name, ds in ds_or_db.iter_datasets(): if 'train' in name: diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py index 0ee496c5..39ce5983 100644 --- a/fastNLP/core/dataloaders/utils.py +++ b/fastNLP/core/dataloaders/utils.py @@ -10,7 +10,7 @@ def indice_collate_wrapper(func): :param func: 需要修饰的函数 :return: """ - if func.__name__ == '_indice_collate_wrapper': # 如果已经被包裹过了 + if hasattr(func, '__name__') and func.__name__ == '_indice_collate_wrapper': # 如果已经被包裹过了 return func def _indice_collate_wrapper(tuple_data): # 这里不能使用 functools.wraps ,否则会检测不到 diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py index 52fe48ff..8aa12ab6 100644 --- a/tests/core/dataloaders/torch_dataloader/test_fdl.py +++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py @@ -3,6 +3,10 @@ import pytest from fastNLP.core.dataloaders.torch_dataloader import 
TorchDataLoader, prepare_torch_dataloader from fastNLP.core.dataset import DataSet from fastNLP.io.data_bundle import DataBundle +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch @pytest.mark.torch @@ -22,20 +26,16 @@ class TestFdl: fdl = TorchDataLoader(ds, batch_size=3) fdl.set_pad("x", -1) for batch in fdl: - print(batch) - # fdl.set_pad_val("x", val=-2) - # for batch in fdl: - # print(batch) + assert batch['x'].shape == torch.Size([3, 4]) def test_get_batch_indices(self): ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) fdl = TorchDataLoader(ds, batch_size=3, shuffle=True) for batch in fdl: - print(fdl.get_batch_indices()) + assert len(fdl.get_batch_indices()) == 3 def test_other_dataset(self): import numpy as np - class _DataSet: def __init__(self): @@ -55,8 +55,41 @@ class TestFdl: # dl.set_inputs('data', 'labels') # dl.set_pad_val('labels', val=None) for batch in dl: - print(batch) - print(dl.get_batch_indices()) + assert batch[0].shape == torch.Size([2, 5]) + assert batch[1].shape == torch.Size([2, 2, 3]) + + def test_default_collate_fn(self): + ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + with pytest.raises(ValueError): + fdl = TorchDataLoader(ds, batch_size=3, collate_fn=None) + import numpy as np + class _DataSet: + + def __init__(self): + pass + + def __getitem__(self, item): + return np.random.randn(5), [[1, 2], [2, 3, 4]] + + def __len__(self): + return 10 + + fdl = TorchDataLoader(_DataSet(), batch_size=3, collate_fn=None, drop_last=True) + for batch in fdl: + assert batch[0].shape == torch.Size([3, 5]) + + def test_my_collate_fn(self): + ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + def collate_fn(batch): + res = {'x': [], 'y': []} + for ins in batch: + res['x'].append(ins['x']) + res['y'].append(ins['y']) + return res + fdl = TorchDataLoader(ds, collate_fn=collate_fn, batch_size=3, drop_last=True) + 
for batch in fdl: + assert batch['x'] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]] + assert batch['y'] == [1, 0, 1] def test_prepare_torch_dataloader(self): ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})