Browse Source

修改jittor fdl

tags/v1.0.0alpha
MorningForest 3 years ago
parent
commit
d61bfe8e47
4 changed files with 9 additions and 3 deletions
  1. +3
    -1
      fastNLP/core/dataloaders/__init__.py
  2. +1
    -1
      fastNLP/core/dataloaders/jittor_dataloader/fdl.py
  3. +5
    -0
      fastNLP/core/dataloaders/utils.py
  4. +0
    -1
      tests/core/dataloaders/jittor_dataloader/test_fdl.py

+ 3
- 1
fastNLP/core/dataloaders/__init__.py View File

@@ -5,10 +5,12 @@ __all__ = [
'JittorDataLoader',
'prepare_jittor_dataloader',
'prepare_paddle_dataloader',
'prepare_torch_dataloader'
'prepare_torch_dataloader',
'indice_collate_wrapper'
]

from .mix_dataloader import MixDataLoader
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
from .utils import indice_collate_wrapper

+ 1
- 1
fastNLP/core/dataloaders/jittor_dataloader/fdl.py View File

@@ -12,7 +12,7 @@ if _NEED_IMPORT_JITTOR:
from jittor.dataset import Dataset
else:
from fastNLP.core.dataset import DataSet as Dataset
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
from fastNLP.core.collators import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet


+ 5
- 0
fastNLP/core/dataloaders/utils.py View File

@@ -1,3 +1,8 @@
__all__ = [
"indice_collate_wrapper"
]


def indice_collate_wrapper(func):
"""
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。


+ 0
- 1
tests/core/dataloaders/jittor_dataloader/test_fdl.py View File

@@ -42,7 +42,6 @@ class TestJittor:
jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4)
# jtl.set_pad_val('x', 'y')
# jtl.set_input('x')
print(str(jittor.Var([0])))
for batch in jtl:
print(batch)
print(jtl.get_batch_indices())


Loading…
Cancel
Save