diff --git a/fastNLP/core/dataloaders/__init__.py b/fastNLP/core/dataloaders/__init__.py index 40dd7b1c..e9dc51b4 100644 --- a/fastNLP/core/dataloaders/__init__.py +++ b/fastNLP/core/dataloaders/__init__.py @@ -5,10 +5,12 @@ __all__ = [ 'JittorDataLoader', 'prepare_jittor_dataloader', 'prepare_paddle_dataloader', - 'prepare_torch_dataloader' + 'prepare_torch_dataloader', + 'indice_collate_wrapper' ] from .mix_dataloader import MixDataLoader from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader +from .utils import indice_collate_wrapper diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 507073a4..2345a9b9 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -12,7 +12,7 @@ if _NEED_IMPORT_JITTOR: from jittor.dataset import Dataset else: from fastNLP.core.dataset import DataSet as Dataset -from fastNLP.core.utils.jittor_utils import jittor_collate_wraps + from fastNLP.core.collators import Collator from fastNLP.core.dataloaders.utils import indice_collate_wrapper from fastNLP.core.dataset import DataSet as FDataSet diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py index a71dc50c..2305cebe 100644 --- a/fastNLP/core/dataloaders/utils.py +++ b/fastNLP/core/dataloaders/utils.py @@ -1,3 +1,8 @@ +__all__ = [ + "indice_collate_wrapper" +] + + def indice_collate_wrapper(func): """ 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 diff --git a/tests/core/dataloaders/jittor_dataloader/test_fdl.py b/tests/core/dataloaders/jittor_dataloader/test_fdl.py index b3124397..92b49c09 100644 --- a/tests/core/dataloaders/jittor_dataloader/test_fdl.py +++ b/tests/core/dataloaders/jittor_dataloader/test_fdl.py @@ -42,7 +42,6 @@ class TestJittor: jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) # jtl.set_pad_val('x', 'y') # jtl.set_input('x') - print(str(jittor.Var([0]))) for batch in jtl: print(batch) print(jtl.get_batch_indices())