From cdf3474e2e8763024df1961c7dbe9477f50eee94 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Thu, 5 May 2022 20:23:15 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9fdl=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataloaders/paddle_dataloader/fdl.py | 10 ++++++---- tests/core/dataloaders/paddle_dataloader/test_fdl.py | 10 +++++----- tests/core/dataloaders/torch_dataloader/test_fdl.py | 1 + 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index b157dd68..b063644e 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -9,7 +9,6 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: from paddle.io import DataLoader, Dataset, Sampler - from paddle.fluid.dataloader.collate import default_collate_fn else: from fastNLP.core.utils.dummy_class import DummyClass as Dataset from fastNLP.core.utils.dummy_class import DummyClass as DataLoader @@ -52,6 +51,9 @@ class PaddleDataLoader(DataLoader): num_workers: int = 0, use_buffer_reader: bool = True, use_shared_memory: bool = True, timeout: int = 0, worker_init_fn: Callable = None, persistent_workers=False) -> None: + # A FastNLP DataSet requires collate_fn to be not None + if isinstance(dataset, FDataSet) and collate_fn is None: + raise ValueError("When use FastNLP DataSet, collate_fn must be not None") if not isinstance(dataset, _PaddleDataset): dataset = _PaddleDataset(dataset) @@ -66,10 +68,10 @@ class PaddleDataLoader(DataLoader): if isinstance(collate_fn, str): if collate_fn == 'auto': if isinstance(dataset.dataset, FDataSet): - self._collate_fn = dataset.dataset.collator - self._collate_fn.set_backend(backend="paddle") + collate_fn = dataset.dataset.collator + 
collate_fn.set_backend(backend="paddle") else: - self._collate_fn = Collator(backend="paddle") + collate_fn = Collator(backend="paddle") else: raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py index 6632ad17..08f71cac 100644 --- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py +++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py @@ -6,14 +6,14 @@ from fastNLP.core.dataset import DataSet from fastNLP.core.log import logger from fastNLP.envs.imports import _NEED_IMPORT_PADDLE + if _NEED_IMPORT_PADDLE: - from paddle.io import Dataset, DataLoader + from paddle.io import Dataset import paddle else: from fastNLP.core.utils.dummy_class import DummyClass as Dataset - class RandomDataset(Dataset): def __getitem__(self, idx): @@ -33,15 +33,15 @@ class TestPaddle: fdl = PaddleDataLoader(ds, batch_size=2) # fdl = DataLoader(ds, batch_size=2, shuffle=True) for batch in fdl: - print(batch) + assert batch['image'].shape == [2, 10, 5] + assert batch['label'].shape == [2, 2, 4] # print(fdl.get_batch_indices()) - def test_fdl_batch_indices(self): + def test_fdl_fastnlp_dataset(self): ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True) for batch in fdl: assert len(fdl.get_batch_indices()) == 4 - print(batch) print(fdl.get_batch_indices()) def test_set_inputs_and_set_pad_val(self): diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py index 8aa12ab6..f3d0136d 100644 --- a/tests/core/dataloaders/torch_dataloader/test_fdl.py +++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py @@ -4,6 +4,7 @@ from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_t from fastNLP.core.dataset import DataSet from fastNLP.io.data_bundle import DataBundle from fastNLP.envs.imports 
import _NEED_IMPORT_TORCH +from fastNLP.core import Trainer if _NEED_IMPORT_TORCH: import torch