Browse Source

修改fdl测试用例

tags/v1.0.0alpha
MorningForest 2 years ago
parent
commit
cdf3474e2e
3 changed files with 12 additions and 9 deletions
  1. +6
    -4
      fastNLP/core/dataloaders/paddle_dataloader/fdl.py
  2. +5
    -5
      tests/core/dataloaders/paddle_dataloader/test_fdl.py
  3. +1
    -0
      tests/core/dataloaders/torch_dataloader/test_fdl.py

+ 6
- 4
fastNLP/core/dataloaders/paddle_dataloader/fdl.py View File

@@ -9,7 +9,6 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
from paddle.io import DataLoader, Dataset, Sampler
from paddle.fluid.dataloader.collate import default_collate_fn
else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
@@ -52,6 +51,9 @@ class PaddleDataLoader(DataLoader):
num_workers: int = 0, use_buffer_reader: bool = True,
use_shared_memory: bool = True, timeout: int = 0,
worker_init_fn: Callable = None, persistent_workers=False) -> None:
# FastNLP Datset, collate_fn not None
if isinstance(dataset, FDataSet) and collate_fn is None:
raise ValueError("When use FastNLP DataSet, collate_fn must be not None")

if not isinstance(dataset, _PaddleDataset):
dataset = _PaddleDataset(dataset)
@@ -66,10 +68,10 @@ class PaddleDataLoader(DataLoader):
if isinstance(collate_fn, str):
if collate_fn == 'auto':
if isinstance(dataset.dataset, FDataSet):
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="paddle")
collate_fn = dataset.dataset.collator
collate_fn.set_backend(backend="paddle")
else:
self._collate_fn = Collator(backend="paddle")
collate_fn = Collator(backend="paddle")

else:
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")


+ 5
- 5
tests/core/dataloaders/paddle_dataloader/test_fdl.py View File

@@ -6,14 +6,14 @@ from fastNLP.core.dataset import DataSet
from fastNLP.core.log import logger

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
from paddle.io import Dataset, DataLoader
from paddle.io import Dataset
import paddle
else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset



class RandomDataset(Dataset):

def __getitem__(self, idx):
@@ -33,15 +33,15 @@ class TestPaddle:
fdl = PaddleDataLoader(ds, batch_size=2)
# fdl = DataLoader(ds, batch_size=2, shuffle=True)
for batch in fdl:
print(batch)
assert batch['image'].shape == [2, 10, 5]
assert batch['label'].shape == [2, 2, 4]
# print(fdl.get_batch_indices())

def test_fdl_batch_indices(self):
def test_fdl_fastnlp_dataset(self):
ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})
fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
for batch in fdl:
assert len(fdl.get_batch_indices()) == 4
print(batch)
print(fdl.get_batch_indices())

def test_set_inputs_and_set_pad_val(self):


+ 1
- 0
tests/core/dataloaders/torch_dataloader/test_fdl.py View File

@@ -4,6 +4,7 @@ from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_t
from fastNLP.core.dataset import DataSet
from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core import Trainer

if _NEED_IMPORT_TORCH:
import torch


Loading…
Cancel
Save