From 5d564a58cf96cf85e33e2b0c51547ffd11cd8929 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Tue, 3 May 2022 22:20:13 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9jittor=20fdl?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataloaders/jittor_dataloader/fdl.py | 1 + tests/core/dataloaders/jittor_dataloader/test_fdl.py | 1 + 2 files changed, 2 insertions(+) diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 3e9cf17a..787fcb69 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -91,6 +91,7 @@ class JittorDataLoader: self.dataset.dataset.set_attrs(batch_size=1) # 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数 # self._collate_fn = _collate_fn + self.cur_batch_indices = None def __iter__(self): # TODO 第一次迭代后不能设置collate_fn,设置是无效的 diff --git a/tests/core/dataloaders/jittor_dataloader/test_fdl.py b/tests/core/dataloaders/jittor_dataloader/test_fdl.py index 92b49c09..b3124397 100644 --- a/tests/core/dataloaders/jittor_dataloader/test_fdl.py +++ b/tests/core/dataloaders/jittor_dataloader/test_fdl.py @@ -42,6 +42,7 @@ class TestJittor: jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) # jtl.set_pad_val('x', 'y') # jtl.set_input('x') + print(str(jittor.Var([0]))) for batch in jtl: print(batch) print(jtl.get_batch_indices())