@@ -5,10 +5,12 @@ __all__ = [ | |||||
'JittorDataLoader', | 'JittorDataLoader', | ||||
'prepare_jittor_dataloader', | 'prepare_jittor_dataloader', | ||||
'prepare_paddle_dataloader', | 'prepare_paddle_dataloader', | ||||
'prepare_torch_dataloader' | |||||
'prepare_torch_dataloader', | |||||
'indice_collate_wrapper' | |||||
] | ] | ||||
from .mix_dataloader import MixDataLoader | from .mix_dataloader import MixDataLoader | ||||
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | ||||
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader | from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader | ||||
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader | from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader | ||||
from .utils import indice_collate_wrapper |
@@ -12,7 +12,7 @@ if _NEED_IMPORT_JITTOR: | |||||
from jittor.dataset import Dataset | from jittor.dataset import Dataset | ||||
else: | else: | ||||
from fastNLP.core.dataset import DataSet as Dataset | from fastNLP.core.dataset import DataSet as Dataset | ||||
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps | |||||
from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | from fastNLP.core.dataloaders.utils import indice_collate_wrapper | ||||
from fastNLP.core.dataset import DataSet as FDataSet | from fastNLP.core.dataset import DataSet as FDataSet | ||||
@@ -1,3 +1,8 @@ | |||||
__all__ = [ | |||||
"indice_collate_wrapper" | |||||
] | |||||
def indice_collate_wrapper(func): | def indice_collate_wrapper(func): | ||||
""" | """ | ||||
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | ||||
@@ -42,7 +42,6 @@ class TestJittor: | |||||
jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) | jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) | ||||
# jtl.set_pad_val('x', 'y') | # jtl.set_pad_val('x', 'y') | ||||
# jtl.set_input('x') | # jtl.set_input('x') | ||||
print(str(jittor.Var([0]))) | |||||
for batch in jtl: | for batch in jtl: | ||||
print(batch) | print(batch) | ||||
print(jtl.get_batch_indices()) | print(jtl.get_batch_indices()) | ||||