diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
index 3c37efcf..48feea0b 100644
--- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
@@ -26,20 +26,11 @@ class _JittorDataset(Dataset):
     def __init__(self, dataset) -> None:
         super(_JittorDataset, self).__init__()
         self.dataset = dataset
+        self.total_len = len(dataset)
 
     def __getitem__(self, item):
         return (item, self.dataset[item])
 
-    def __len__(self) -> int:
-        return len(self.dataset)
-
-    # def __getattr__(self, item):
-    #     # if jittor's Dataset lacks a method that the user's dataset does implement (via __getattribute__), the user can still call it
-    #     try:
-    #         self.dataset.__getattribute__(item)
-    #     except Exception as e:
-    #         raise e
-
 
 class JittorDataLoader:
     """
@@ -62,13 +53,17 @@ class JittorDataLoader:
         :param keep_numpy_array:
         :param endless:
         :param collate_fn: callable used to collate the fetched samples into a batch
-        :param as_numpy: whether the returned data is numpy arrays; otherwise torch.tensor
         """
         # TODO verify support for replacing the sampler (to be done later)
+        # set the inner dataset's batch size to 1
+        if isinstance(dataset, Dataset):
+            dataset.set_attrs(batch_size=1)
+
         # FastNLP DataSet, collate_fn not None
         if isinstance(dataset, FDataSet) and collate_fn is None:
             raise ValueError("When use FastNLP DataSet, collate_fn must be not None")
 
+        # wrap any dataset into a jittor-style dataset
         if not isinstance(dataset, _JittorDataset):
             self.dataset = _JittorDataset(dataset)
 
@@ -82,17 +77,13 @@ class JittorDataLoader:
             else:
                 raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
         elif isinstance(collate_fn, Callable):
-            if collate_fn is not collate_batch:
-                self.collate_fn = collate_fn
+            self.collate_fn = collate_fn
         else:
             self.collate_fn = collate_batch
 
         self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                                num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad,
                                keep_numpy_array=keep_numpy_array, endless=endless)
-        # set the inner dataset's batch size to 1
-        if isinstance(self.dataset.dataset, Dataset):
-            self.dataset.dataset.set_attrs(batch_size=1)
 
         self.cur_batch_indices = None
 
@@ -105,12 +96,10 @@ class JittorDataLoader:
             yield data
 
     def __len__(self):
-        if self.dataset.drop_last:
-            return len(self.dataset) // self.dataset.batch_size
-        return (len(self.dataset) - 1) // self.dataset.batch_size + 1
+        return len(self.dataset)
 
     def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
-                pad_fn: Callable = None) -> "JittorDataLoader":
+                pad_fn: Callable = None) -> Collator:
         """
         Use this method if the content of a particular field needs special adjustment.
 
@@ -129,14 +118,27 @@
         form; the output is used directly as the padded result.
         :return: the Collator itself
         """
-        if isinstance(self.collate_fn, Collator):
-            self.collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
-                                    backend=backend)
-            return self
+        collator = self._get_collator()
+        if isinstance(collator, Collator):
+            collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+            return collator
         else:
             raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
 
-    def set_ignore(self, *field_names) -> "JittorDataLoader":
+    def _get_collator(self):
+        """
+        If collate_fn is (or wraps) a Collator object, return that object; otherwise return None.
+
+        :return:
+        """
+        collator = None
+        if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
+            collator = self.collate_fn.__wrapped__
+        elif isinstance(self.collate_fn, Collator):
+            collator = self.collate_fn
+        return collator
+
+    def set_ignore(self, *field_names) -> Collator:
         """
         If some content should not be output, configure it here; the specified fields will be ignored in the batch output.
         Example::
@@ -147,9 +149,10 @@
         if __getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element of the sequence.
         :return: the Collator itself
         """
-        if isinstance(self.collate_fn, Collator):
-            self.collate_fn.set_ignore(*field_names)
-            return self
+        collator = self._get_collator()
+        if isinstance(collator, Collator):
+            collator.set_ignore(*field_names)
+            return collator
         else:
             raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
 
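With this change, `set_pad()` and `set_ignore()` return the underlying `Collator` rather than the `JittorDataLoader` itself, and `_get_collator()` also looks through a `__wrapped__` attribute so a wrapped collate function can still be configured. A minimal usage sketch of the new return value (the dataset contents and field names below are made up for illustration, not taken from the patch):

```python
from fastNLP.core.collators import Collator
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
from fastNLP.core.dataset import DataSet

# hypothetical two-field dataset with a variable-length "x" field
ds = DataSet({"x": [[1, 2], [2, 3, 4], [1]], "y": [0, 1, 0]})
jtl = JittorDataLoader(ds, batch_size=2, collate_fn="auto")

# set_pad()/set_ignore() now hand back the Collator, not the dataloader
collator = jtl.set_pad("x", pad_val=-1)
assert isinstance(collator, Collator)
jtl.set_ignore("y")  # "y" is dropped from the collated batches

for batch in jtl:
    print(batch["x"])  # "x" padded with -1 to the longest sample in the batch
```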
diff --git a/tests/core/dataloaders/jittor_dataloader/test_fdl.py b/tests/core/dataloaders/jittor_dataloader/test_fdl.py
index 2a834ee8..204863a5 100644
--- a/tests/core/dataloaders/jittor_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/jittor_dataloader/test_fdl.py
@@ -4,6 +4,7 @@ from datasets import Dataset as HfDataset
 from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
 from fastNLP.core.dataset import DataSet as Fdataset
+from fastNLP.core.collators import Collator
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
 
 if _NEED_IMPORT_JITTOR:
     from jittor.dataset import Dataset
@@ -53,9 +54,9 @@ class TestJittor:
         jtl.set_ignore("y")
         for batch in jtl:
             assert batch['x'].size() == (16, 4)
 
-        jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True, num_workers=2)
-
-
+        jtl1 = JittorDataLoader(dataset, batch_size=16, drop_last=True, num_workers=2)
+        for batch in jtl1:
+            print(batch)
 
     def test_huggingface_datasets(self):
@@ -79,4 +80,11 @@
         for idx, batch in enumerate(dataset):
             print(idx, batch.shape)
         for idx, batch in enumerate(dataset):
-            print(idx, batch.shape)
\ No newline at end of file
+            print(idx, batch.shape)
+
+    def test_jittor_get_backend(self):
+        collate_batch = Collator(backend='auto')
+        dl = MyDataset()
+        dl = dl.set_attrs(collate_batch=collate_batch, batch_size=256)
+        for batch in dl:
+            print(batch)
\ No newline at end of file
diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py
index 29489caa..229727e7 100644
--- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py
@@ -4,11 +4,12 @@
 import numpy as np
 from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.log import logger
+from fastNLP.core.collators import Collator
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
 
 if _NEED_IMPORT_PADDLE:
-    from paddle.io import Dataset
+    from paddle.io import Dataset, DataLoader
     import paddle
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Dataset
@@ -61,3 +62,10 @@ class TestPaddle:
         fdl1.set_ignore('label')
         for batch in fdl1:
             assert batch['image'].shape == [4, 10, 5]
+
+    def test_get_backend(self):
+        ds = RandomDataset()
+        collate_fn = Collator(backend='auto')
+        paddle_dl = DataLoader(ds, collate_fn=collate_fn)
+        for batch in paddle_dl:
+            print(batch)
\ No newline at end of file
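Both new tests above follow the same pattern: a `Collator(backend='auto')` is handed straight to a framework-native loader (jittor's `set_attrs(collate_batch=...)`, paddle's `DataLoader(collate_fn=...)`), so the output backend is inferred from the samples instead of being fixed up front. A hedged sketch of that pattern with paddle; the `ToyDataset` here is a stand-in for illustration, not the `RandomDataset` fixture from the test file:

```python
from paddle.io import Dataset, DataLoader
from fastNLP.core.collators import Collator

class ToyDataset(Dataset):  # stand-in dataset, not the test fixture
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # variable-length token ids plus a label; the Collator pads "words"
        return {"words": list(range(idx % 3 + 1)), "label": idx % 2}

collate_fn = Collator(backend="auto")  # backend resolved from the samples
dl = DataLoader(ToyDataset(), batch_size=4, collate_fn=collate_fn)
for batch in dl:
    print(batch)
```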
diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py
index f3d0136d..ff38b614 100644
--- a/tests/core/dataloaders/torch_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py
@@ -112,3 +112,19 @@ class TestFdl:
         seq_ds = prepare_torch_dataloader(sequence)
         assert isinstance(seq_ds[0], TorchDataLoader)
         assert isinstance(seq_ds[1], TorchDataLoader)
+
+    def test_get_backend(self):
+        from fastNLP.core.collators import Collator
+        from torch.utils.data import DataLoader, Dataset
+
+        class MyDataset(DataSet):
+            def __len__(self):
+                return 1000
+
+            def __getitem__(self, item):
+                return [[1, 0], [1], [1, 2, 4]], [1, 0]
+
+        collate_batch = Collator(backend='auto')
+        dl = DataLoader(MyDataset(), collate_fn=collate_batch)
+        for batch in dl:
+            print(batch)
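The torch test returns Sequence-typed samples (a pair of nested lists and a label list), so, per the `set_pad()`/`set_ignore()` docstrings above, the fields are addressed as `'_0'`, `'_1'`. A hedged sketch that extends the same idea by configuring padding on the first element before handing the Collator to a torch `DataLoader`; the `ToySet` class and the pad value are illustrative assumptions:

```python
from fastNLP.core.collators import Collator
from torch.utils.data import DataLoader

class ToySet:  # map-style stand-in mirroring the new torch test's samples
    def __len__(self):
        return 8

    def __getitem__(self, item):
        return [[1, 0], [1], [1, 2, 4]], [1, 0]

collator = Collator(backend="auto")
# samples are Sequences, so fields are addressed as '_0', '_1' (see the docstrings above)
collator.set_pad("_0", pad_val=0)
dl = DataLoader(ToySet(), batch_size=4, collate_fn=collator)
for batch in dl:
    print(batch)
```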