From 9ed4df6f860cbc2026d952ac754bd3db509f246b Mon Sep 17 00:00:00 2001
From: MorningForest <2297662686@qq.com>
Date: Fri, 8 Apr 2022 21:36:12 +0800
Subject: [PATCH] Add test cases for dataloaders
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/core/dataloaders/__init__.py            |  0
 .../dataloaders/jittor_dataloader/test_fdl.py | 80 ++++++++++++++++
 .../dataloaders/paddle_dataloader/test_fdl.py | 53 ++++++++++
 .../dataloaders/torch_dataloader/test_fdl.py  | 96 +++++++++++++++++++
 4 files changed, 229 insertions(+)
 create mode 100644 tests/core/dataloaders/__init__.py
 create mode 100644 tests/core/dataloaders/jittor_dataloader/test_fdl.py
 create mode 100644 tests/core/dataloaders/paddle_dataloader/test_fdl.py
 create mode 100644 tests/core/dataloaders/torch_dataloader/test_fdl.py

diff --git a/tests/core/dataloaders/__init__.py b/tests/core/dataloaders/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/core/dataloaders/jittor_dataloader/test_fdl.py b/tests/core/dataloaders/jittor_dataloader/test_fdl.py
new file mode 100644
index 00000000..f2021923
--- /dev/null
+++ b/tests/core/dataloaders/jittor_dataloader/test_fdl.py
@@ -0,0 +1,80 @@
+import pytest
+from jittor.dataset import Dataset
+import jittor
+import numpy as np
+from datasets import Dataset as HfDataset
+from datasets import load_dataset
+
+from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
+from fastNLP.core.dataset import DataSet as Fdataset
+
+
+class MyDataset(Dataset):
+
+    def __init__(self, data_len=1000):
+        super(MyDataset, self).__init__()
+        self.data = [jittor.ones((3, 4)) for _ in range(data_len)]
+        self.set_attrs(total_len=data_len)
+        self.dataset_len = data_len
+
+    def __getitem__(self, item):
+        return self.data[item]
+        # return {'x': [[1, 0], [2, 0, 1]]}
+        # return np.random.randn(3, 10)
+
+    # def __len__(self):
+    #     return self.dataset_len
+
+
+class TestJittor:
+
+    def test_v1(self):
+        """
+        Test a jittor-style dataset with JittorDataLoader (fdl)
+
+        :return:
+        """
+        dataset = MyDataset()
+        jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4)
+        jtl.set_pad_val('x', 'y')
+        jtl.set_input('x')
+        for batch in jtl:
+            print(batch)
+            print(jtl.get_batch_indices())
+
+    def test_v2(self):
+        """
+        Test fastNLP's DataSet
+
+        :return:
+        """
+        dataset = Fdataset({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100})
+        jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True)
+        jtl.set_pad_val('x', val=-1)
+        jtl.set_input('x', 'y')
+        for batch in jtl:
+            assert batch['x'].size() == (16, 4)
+
+    def test_v3(self):
+        dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100})
+        jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True)
+        jtl.set_input('x', 'y')
+        for batch in jtl:
+            print(batch)
+
+    def test_v4(self):
+        dataset = MyDataset()
+        dl = JittorDataLoader(dataset, batch_size=4, num_workers=2)
+        print(len(dl))
+        for idx, batch in enumerate(dl):
+            print(batch.shape, idx)
+        for idx, batch in enumerate(dl):
+            print(batch.shape, idx)
+
+    def test_v5(self):
+        dataset = MyDataset()
+        dataset.set_attrs(batch_size=4, num_workers=2)
+        for idx, batch in enumerate(dataset):
+            print(idx, batch.shape)
+        for idx, batch in enumerate(dataset):
+            print(idx, batch.shape)
\ No newline at end of file
diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py
new file mode 100644
index 00000000..dbca394b
--- /dev/null
+++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py
@@ -0,0 +1,53 @@
+import unittest
+
+from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader
+from fastNLP.core.dataset import DataSet
+from paddle.io import Dataset, DataLoader
+import numpy as np
+import paddle
+
+
+class RandomDataset(Dataset):
+
+    def __getitem__(self, idx):
+        image = np.random.random((10, 5)).astype('float32')
+        return {'image': paddle.Tensor(image), 'label': [[0, 1], [1, 2, 3, 4]]}
+
+    def __len__(self):
+        return 10
+
+
+class TestPaddle(unittest.TestCase):
+
+    def test_init(self):
+        # ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})
+        ds = RandomDataset()
+        fdl = PaddleDataLoader(ds, batch_size=2)
+        # fdl = DataLoader(ds, batch_size=2, shuffle=True)
+        for batch in fdl:
+            print(batch)
+            # print(fdl.get_batch_indices())
+
+    def test_fdl_batch_indices(self):
+        ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})
+        fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
+        fdl.set_input("x", "y")
+        for batch in fdl:
+            assert len(fdl.get_batch_indices()) == 4
+            print(batch)
+            print(fdl.get_batch_indices())
+
+    def test_set_inputs_and_set_pad_val(self):
+        ds = RandomDataset()
+        fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True)
+        fdl.set_input('image', 'label')
+        fdl.set_pad_val('label', val=-1)
+        for batch in fdl:
+            assert batch['image'].shape == [2, 10, 5]
+            print(batch)
+        fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True)
+        fdl1.set_input('image', 'label')
+        fdl1.set_pad_val('image', val=None)
+        for batch in fdl1:
+            assert batch['image'].shape == [4, 10, 5]
+            print(batch)
diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py
new file mode 100644
index 00000000..0cd17ddd
--- /dev/null
+++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py
@@ -0,0 +1,96 @@
+import unittest
+
+from fastNLP.core.dataloaders.torch_dataloader import FDataLoader, prepare_dataloader
+from fastNLP.core.dataset import DataSet
+from fastNLP.io.data_bundle import DataBundle
+
+
+class TestFdl(unittest.TestCase):
+
+    def test_init_v1(self):
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        fdl = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True)
+        # for batch in fdl:
+        #     print(batch)
+        fdl1 = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True)
+        # for batch in fdl1:
+        #     print(batch)
+
+    def test_set_padding(self):
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        ds.set_pad_val("x", val=-1)
+        fdl = FDataLoader(ds, batch_size=3)
+        fdl.set_input("x", "y")
+        for batch in fdl:
+            print(batch)
+        fdl.set_pad_val("x", val=-2)
+        for batch in fdl:
+            print(batch)
+
+    def test_add_collator(self):
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+
+        def collate_fn(ins_list):
+            _dict = {"Y": []}
+            for ins in ins_list:
+                _dict["Y"].append(ins['y'])
+            return _dict
+
+        fdl = FDataLoader(ds, batch_size=3, as_numpy=True)
+        fdl.set_input("x", "y")
+        fdl.add_collator(collate_fn)
+        for batch in fdl:
+            print(batch)
+
+    def test_get_batch_indices(self):
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        fdl = FDataLoader(ds, batch_size=3, shuffle=True)
+        fdl.set_input("y", "x")
+        for batch in fdl:
+            print(fdl.get_batch_indices())
+
+    def test_other_dataset(self):
+        import numpy as np
+
+        class _DataSet:
+
+            def __init__(self):
+                pass
+
+            def __getitem__(self, item):
+                return np.random.randn(5), [[1, 2], [2, 3, 4]]
+
+            def __len__(self):
+                return 10
+
+            def __getattribute__(self, item):
+                return object.__getattribute__(self, item)
+
+        dataset = _DataSet()
+        dl = FDataLoader(dataset, batch_size=2, shuffle=True)
+        # dl.set_inputs('data', 'labels')
+        # dl.set_pad_val('labels', val=None)
+        for batch in dl:
+            print(batch)
+            print(dl.get_batch_indices())
+
+    def test_prepare_dataloader(self):
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        dl = prepare_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
+        assert isinstance(dl, FDataLoader)
+
+        ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        dbl = DataBundle(datasets={'train': ds, 'val': ds1})
+        dl_bundle = prepare_dataloader(dbl)
+        assert isinstance(dl_bundle['train'], FDataLoader)
+        assert isinstance(dl_bundle['val'], FDataLoader)
+
+        ds_dict = {'train_1': ds, 'val': ds1}
+        dl_dict = prepare_dataloader(ds_dict)
+        assert isinstance(dl_dict['train_1'], FDataLoader)
+        assert isinstance(dl_dict['val'], FDataLoader)
+
+        sequence = [ds, ds1]
+        seq_ds = prepare_dataloader(sequence)
+        assert isinstance(seq_ds[0], FDataLoader)
+        assert isinstance(seq_ds[1], FDataLoader)
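
Note (not part of the patch): test_prepare_dataloader above checks that prepare_dataloader dispatches on the type of its first argument, returning a single FDataLoader for a DataSet, a name-keyed collection for a DataBundle or a dict, and an indexable sequence for a list of datasets. The following is a minimal sketch of that dispatch for illustration only; it assumes DataBundle exposes its datasets as a dict attribute, and the helper name sketch_prepare_dataloader is hypothetical, not fastNLP's actual implementation.

    # Illustrative sketch only, not fastNLP's implementation.
    from fastNLP.core.dataloaders.torch_dataloader import FDataLoader
    from fastNLP.io.data_bundle import DataBundle

    def sketch_prepare_dataloader(ds_or_db, **kwargs):
        # DataBundle: one FDataLoader per named dataset
        # (assumes a `datasets` dict attribute on DataBundle)
        if isinstance(ds_or_db, DataBundle):
            return {name: FDataLoader(ds, **kwargs) for name, ds in ds_or_db.datasets.items()}
        # dict of datasets: keep the same keys
        if isinstance(ds_or_db, dict):
            return {name: FDataLoader(ds, **kwargs) for name, ds in ds_or_db.items()}
        # list/tuple of datasets: return the loaders in order
        if isinstance(ds_or_db, (list, tuple)):
            return [FDataLoader(ds, **kwargs) for ds in ds_or_db]
        # anything else (e.g. a single DataSet): wrap it directly
        return FDataLoader(ds_or_db, **kwargs)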