|
|
@@ -9,7 +9,6 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE |
|
|
|
|
|
|
|
if _NEED_IMPORT_PADDLE: |
|
|
|
from paddle.io import DataLoader, Dataset, Sampler |
|
|
|
from paddle.fluid.dataloader.collate import default_collate_fn |
|
|
|
else: |
|
|
|
from fastNLP.core.utils.dummy_class import DummyClass as Dataset |
|
|
|
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader |
|
|
@@ -52,6 +51,9 @@ class PaddleDataLoader(DataLoader): |
|
|
|
num_workers: int = 0, use_buffer_reader: bool = True, |
|
|
|
use_shared_memory: bool = True, timeout: int = 0, |
|
|
|
worker_init_fn: Callable = None, persistent_workers=False) -> None: |
|
|
|
# A fastNLP DataSet requires an explicit collate_fn; None is rejected below. |
|
|
|
if isinstance(dataset, FDataSet) and collate_fn is None: |
|
|
|
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") |
|
|
|
|
|
|
|
if not isinstance(dataset, _PaddleDataset): |
|
|
|
dataset = _PaddleDataset(dataset) |
|
|
@@ -66,10 +68,10 @@ class PaddleDataLoader(DataLoader): |
|
|
|
if isinstance(collate_fn, str): |
|
|
|
if collate_fn == 'auto': |
|
|
|
if isinstance(dataset.dataset, FDataSet): |
|
|
|
self._collate_fn = dataset.dataset.collator |
|
|
|
self._collate_fn.set_backend(backend="paddle") |
|
|
|
collate_fn = dataset.dataset.collator |
|
|
|
collate_fn.set_backend(backend="paddle") |
|
|
|
else: |
|
|
|
self._collate_fn = Collator(backend="paddle") |
|
|
|
collate_fn = Collator(backend="paddle") |
|
|
|
|
|
|
|
else: |
|
|
|
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") |
|
|
|