|
@@ -8,11 +8,12 @@ from typing import Callable, List, Optional, Union, Dict, Sequence |
|
|
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE |
|
|
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE |
|
|
|
|
|
|
|
|
if _NEED_IMPORT_PADDLE: |
|
|
if _NEED_IMPORT_PADDLE: |
|
|
from paddle.io import DataLoader, Dataset |
|
|
|
|
|
|
|
|
from paddle.io import DataLoader, Dataset, Sampler |
|
|
from paddle.fluid.dataloader.collate import default_collate_fn |
|
|
from paddle.fluid.dataloader.collate import default_collate_fn |
|
|
else: |
|
|
else: |
|
|
from fastNLP.core.utils.dummy_class import DummyClass as Dataset |
|
|
from fastNLP.core.utils.dummy_class import DummyClass as Dataset |
|
|
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader |
|
|
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader |
|
|
|
|
|
from fastNLP.core.utils.dummy_class import DummyClass as Sampler |
|
|
|
|
|
|
|
|
from fastNLP.core.collators.collator import Collator |
|
|
from fastNLP.core.collators.collator import Collator |
|
|
from fastNLP.core.dataloaders.utils import indice_collate_wrapper |
|
|
from fastNLP.core.dataloaders.utils import indice_collate_wrapper |
|
@@ -58,6 +59,9 @@ class PaddleDataLoader(DataLoader): |
|
|
if batch_sampler is None: |
|
|
if batch_sampler is None: |
|
|
batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle, |
|
|
batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle, |
|
|
drop_last=drop_last) |
|
|
drop_last=drop_last) |
|
|
|
|
|
batch_size = 1 |
|
|
|
|
|
shuffle = False |
|
|
|
|
|
drop_last = False |
|
|
|
|
|
|
|
|
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, |
|
|
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, |
|
|
return_list=return_list, batch_sampler=batch_sampler, |
|
|
return_list=return_list, batch_sampler=batch_sampler, |
|
|