Browse Source

PaddleDataLoader 1、初始化时如果 batch_sampler 为 None 时重置 shuffle 等参数 2、测试例中删除了不必要的 test_v2

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
38e578f6b5
2 changed files with 5 additions and 9 deletions
  1. +5
    -1
      fastNLP/core/dataloaders/paddle_dataloader/fdl.py
  2. +0
    -8
      tests/core/dataloaders/paddle_dataloader/test_fdl.py

+ 5
- 1
fastNLP/core/dataloaders/paddle_dataloader/fdl.py View File

@@ -8,11 +8,12 @@ from typing import Callable, List, Optional, Union, Dict, Sequence
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.imports import _NEED_IMPORT_PADDLE


if _NEED_IMPORT_PADDLE: if _NEED_IMPORT_PADDLE:
from paddle.io import DataLoader, Dataset
from paddle.io import DataLoader, Dataset, Sampler
from paddle.fluid.dataloader.collate import default_collate_fn from paddle.fluid.dataloader.collate import default_collate_fn
else: else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset from fastNLP.core.utils.dummy_class import DummyClass as Dataset
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
from fastNLP.core.utils.dummy_class import DummyClass as Sampler


from fastNLP.core.collators.collator import Collator from fastNLP.core.collators.collator import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper from fastNLP.core.dataloaders.utils import indice_collate_wrapper
@@ -58,6 +59,9 @@ class PaddleDataLoader(DataLoader):
if batch_sampler is None: if batch_sampler is None:
batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle, batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
batch_size = 1
shuffle = False
drop_last = False


super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
return_list=return_list, batch_sampler=batch_sampler, return_list=return_list, batch_sampler=batch_sampler,


+ 0
- 8
tests/core/dataloaders/paddle_dataloader/test_fdl.py View File

@@ -58,11 +58,3 @@ class TestPaddle:
for batch in fdl1: for batch in fdl1:
assert batch['image'].shape == [4, 10, 5] assert batch['image'].shape == [4, 10, 5]
print(batch) print(batch)

def test_v2(self):
from fastNLP.core.collators import Collator
logger.setLevel("DEBUG")
data = [paddle.Tensor(np.random.random((10, 5)).astype('float32')), paddle.Tensor(np.random.random((10, 5)).astype('float32'))]
col = Collator(backend="jittor")
res = col(data)
print(res)

Loading…
Cancel
Save