Browse Source

修改jittor fdl

tags/v1.0.0alpha
MorningForest 3 years ago
parent
commit
5d564a58cf
2 changed files with 2 additions and 0 deletions
  1. +1
    -0
      fastNLP/core/dataloaders/jittor_dataloader/fdl.py
  2. +1
    -0
      tests/core/dataloaders/jittor_dataloader/test_fdl.py

+ 1
- 0
fastNLP/core/dataloaders/jittor_dataloader/fdl.py View File

@@ -91,6 +91,7 @@ class JittorDataLoader:
self.dataset.dataset.set_attrs(batch_size=1)
# 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数
# self._collate_fn = _collate_fn
self.cur_batch_indices = None

def __iter__(self):
# TODO 第一次迭代后不能设置collate_fn,设置是无效的


+ 1
- 0
tests/core/dataloaders/jittor_dataloader/test_fdl.py View File

@@ -42,6 +42,7 @@ class TestJittor:
jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4)
# jtl.set_pad_val('x', 'y')
# jtl.set_input('x')
print(str(jittor.Var([0])))
for batch in jtl:
print(batch)
print(jtl.get_batch_indices())


Loading…
Cancel
Save