from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader  # TODO: this class has been modified; remember to update the tests accordingly;
from tests.helpers.datasets.normal_data import NormalIterator

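
# For context, a minimal sketch of the contract these tests assume of
# `_TruncatedDataLoader` (an illustration under assumed semantics, not
# fastNLP's actual implementation): wrap a dataloader, yield at most
# `num_batches` batches from it, and report that bound as the length.
class _TruncatedLoaderSketch:
    def __init__(self, dataloader, num_batches: int):
        self.dataloader = dataloader
        self.num_batches = num_batches  # assumed upper bound on yielded batches

    def __iter__(self):
        # Stop after exactly `num_batches` batches, even if more data remains.
        for idx, batch in enumerate(self.dataloader):
            if idx == self.num_batches:
                break
            yield batch

    def __len__(self):
        return self.num_batches
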
class Test_WrapDataLoader:

    def test_normal_generator(self):
        # Truncating a plain generator-style iterator should yield exactly
        # `num_batches` batches, regardless of how much data remains.
        all_sanity_batches = [4, 20, 100]
        for sanity_batches in all_sanity_batches:
            data = NormalIterator(num_of_data=1000)
            dataloader = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
            mark = 0
            for _data in dataloader:
                mark += 1
            assert mark == sanity_batches

    def test_torch_dataloader(self):
        from tests.helpers.datasets.torch_data import TorchNormalDataset
        from torch.utils.data import DataLoader

        # With a real torch DataLoader, the number of samples seen through the
        # truncated loader should be exactly `batch_size * num_batches`; every
        # combination below stops well before the final (possibly partial) batch.
        bses = [8, 16, 40]
        all_sanity_batches = [4, 7, 10]
        for bs in bses:
            for sanity_batches in all_sanity_batches:
                dataset = TorchNormalDataset(num_of_data=1000)
                dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
                dataloader = _TruncatedDataLoader(dataloader=dataloader, num_batches=sanity_batches)
                all_supposed_running_data_num = 0
                for _data in dataloader:
                    all_supposed_running_data_num += _data.shape[0]
                assert all_supposed_running_data_num == bs * sanity_batches

    def test_len(self):
        from tests.helpers.datasets.torch_data import TorchNormalDataset
        from torch.utils.data import DataLoader

        bses = [8, 16, 40]
        all_sanity_batches = [4, 7, 10]
        length = []
        for bs in bses:
            for sanity_batches in all_sanity_batches:
                dataset = TorchNormalDataset(num_of_data=1000)
                dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
                dataloader = _TruncatedDataLoader(dataloader=dataloader, num_batches=sanity_batches)
                length.append(len(dataloader))
        # len() of the truncated loader should equal `num_batches` independently
        # of batch size, i.e. `all_sanity_batches` repeated once per batch size.
        assert length == all_sanity_batches * len(bses)
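
# Usage sketch (hypothetical driver code, not part of the test suite): a
# trainer-style sanity check would wrap its real dataloader the same way the
# tests above do, running only a handful of batches before full training, e.g.
#
#     sanity_loader = _TruncatedDataLoader(dataloader=train_dataloader, num_batches=2)
#     for batch in sanity_loader:
#         run_sanity_step(batch)  # hypothetical helper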