@@ -9,7 +9,6 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
from paddle.io import DataLoader, Dataset, Sampler | from paddle.io import DataLoader, Dataset, Sampler | ||||
from paddle.fluid.dataloader.collate import default_collate_fn | |||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | from fastNLP.core.utils.dummy_class import DummyClass as Dataset | ||||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | ||||
@@ -52,6 +51,9 @@ class PaddleDataLoader(DataLoader): | |||||
num_workers: int = 0, use_buffer_reader: bool = True, | num_workers: int = 0, use_buffer_reader: bool = True, | ||||
use_shared_memory: bool = True, timeout: int = 0, | use_shared_memory: bool = True, timeout: int = 0, | ||||
worker_init_fn: Callable = None, persistent_workers=False) -> None: | worker_init_fn: Callable = None, persistent_workers=False) -> None: | ||||
# A fastNLP DataSet requires an explicit collate_fn; reject None up front | |||||
if isinstance(dataset, FDataSet) and collate_fn is None: | |||||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | |||||
if not isinstance(dataset, _PaddleDataset): | if not isinstance(dataset, _PaddleDataset): | ||||
dataset = _PaddleDataset(dataset) | dataset = _PaddleDataset(dataset) | ||||
@@ -66,10 +68,10 @@ class PaddleDataLoader(DataLoader): | |||||
if isinstance(collate_fn, str): | if isinstance(collate_fn, str): | ||||
if collate_fn == 'auto': | if collate_fn == 'auto': | ||||
if isinstance(dataset.dataset, FDataSet): | if isinstance(dataset.dataset, FDataSet): | ||||
self._collate_fn = dataset.dataset.collator | |||||
self._collate_fn.set_backend(backend="paddle") | |||||
collate_fn = dataset.dataset.collator | |||||
collate_fn.set_backend(backend="paddle") | |||||
else: | else: | ||||
self._collate_fn = Collator(backend="paddle") | |||||
collate_fn = Collator(backend="paddle") | |||||
else: | else: | ||||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | ||||
@@ -6,14 +6,14 @@ from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
from paddle.io import Dataset, DataLoader | |||||
from paddle.io import Dataset | |||||
import paddle | import paddle | ||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | from fastNLP.core.utils.dummy_class import DummyClass as Dataset | ||||
class RandomDataset(Dataset): | class RandomDataset(Dataset): | ||||
def __getitem__(self, idx): | def __getitem__(self, idx): | ||||
@@ -33,15 +33,15 @@ class TestPaddle: | |||||
fdl = PaddleDataLoader(ds, batch_size=2) | fdl = PaddleDataLoader(ds, batch_size=2) | ||||
# fdl = DataLoader(ds, batch_size=2, shuffle=True) | # fdl = DataLoader(ds, batch_size=2, shuffle=True) | ||||
for batch in fdl: | for batch in fdl: | ||||
print(batch) | |||||
assert batch['image'].shape == [2, 10, 5] | |||||
assert batch['label'].shape == [2, 2, 4] | |||||
# print(fdl.get_batch_indices()) | # print(fdl.get_batch_indices()) | ||||
def test_fdl_batch_indices(self): | |||||
def test_fdl_fastnlp_dataset(self): | |||||
ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) | ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) | ||||
fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True) | fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True) | ||||
for batch in fdl: | for batch in fdl: | ||||
assert len(fdl.get_batch_indices()) == 4 | assert len(fdl.get_batch_indices()) == 4 | ||||
print(batch) | |||||
print(fdl.get_batch_indices()) | print(fdl.get_batch_indices()) | ||||
def test_set_inputs_and_set_pad_val(self): | def test_set_inputs_and_set_pad_val(self): | ||||
@@ -4,6 +4,7 @@ from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_t | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
from fastNLP.core import Trainer | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||