diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index fa99be22..952759f7 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -8,11 +8,12 @@ from typing import Callable, List, Optional, Union, Dict, Sequence from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: - from paddle.io import DataLoader, Dataset + from paddle.io import DataLoader, Dataset, Sampler from paddle.fluid.dataloader.collate import default_collate_fn else: from fastNLP.core.utils.dummy_class import DummyClass as Dataset from fastNLP.core.utils.dummy_class import DummyClass as DataLoader + from fastNLP.core.utils.dummy_class import DummyClass as Sampler from fastNLP.core.collators.collator import Collator from fastNLP.core.dataloaders.utils import indice_collate_wrapper @@ -58,6 +59,9 @@ class PaddleDataLoader(DataLoader): if batch_sampler is None: batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) + batch_size = 1 + shuffle = False + drop_last = False super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, return_list=return_list, batch_sampler=batch_sampler, diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py index d8ba521b..6632ad17 100644 --- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py +++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py @@ -58,11 +58,3 @@ class TestPaddle: for batch in fdl1: assert batch['image'].shape == [4, 10, 5] print(batch) - - def test_v2(self): - from fastNLP.core.collators import Collator - logger.setLevel("DEBUG") - data = [paddle.Tensor(np.random.random((10, 5)).astype('float32')), paddle.Tensor(np.random.random((10, 5)).astype('float32'))] - col = Collator(backend="jittor") - res = col(data) - print(res) \ No newline at end of file