
test_utils.py

from functools import reduce

import pytest

# TODO: the wrapped class has been modified; remember to update this test accordingly
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
from tests.helpers.datasets.normal_data import NormalSampler


class Test_WrapDataLoader:

    def test_normal_generator(self):
        # Wrapping a plain sampler: iteration should stop after exactly
        # `sanity_batches` batches.
        all_sanity_batches = [4, 20, 100]
        for sanity_batches in all_sanity_batches:
            data = NormalSampler(num_of_data=1000)
            wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
            dataloader = iter(wrapper)
            mark = 0
            while True:
                try:
                    _data = next(dataloader)
                except StopIteration:
                    break
                mark += 1
            assert mark == sanity_batches

    @pytest.mark.torch
    def test_torch_dataloader(self):
        # Wrapping a torch DataLoader: the total number of samples seen
        # should equal batch_size * num_batches.
        from tests.helpers.datasets.torch_data import TorchNormalDataset
        from torch.utils.data import DataLoader

        bses = [8, 16, 40]
        all_sanity_batches = [4, 7, 10]
        for bs in bses:
            for sanity_batches in all_sanity_batches:
                dataset = TorchNormalDataset(num_of_data=1000)
                dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
                wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
                dataloader = iter(wrapper)
                all_supposed_running_data_num = 0
                while True:
                    try:
                        _data = next(dataloader)
                    except StopIteration:
                        break
                    all_supposed_running_data_num += _data.shape[0]
                assert all_supposed_running_data_num == bs * sanity_batches

    @pytest.mark.torch
    def test_len(self):
        # len() of the wrapper should report the truncated number of
        # batches, independent of the underlying batch size.
        from tests.helpers.datasets.torch_data import TorchNormalDataset
        from torch.utils.data import DataLoader

        bses = [8, 16, 40]
        all_sanity_batches = [4, 7, 10]
        length = []
        for bs in bses:
            for sanity_batches in all_sanity_batches:
                dataset = TorchNormalDataset(num_of_data=1000)
                dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
                wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
                length.append(len(wrapper))
        # Expected: all_sanity_batches repeated once per batch size,
        # i.e. [4, 7, 10] * len(bses).
        assert length == reduce(lambda x, y: x + y, [all_sanity_batches for _ in range(len(bses))])
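
For orientation, here is a minimal sketch of the behavior these tests assume of _TruncatedDataLoader: it yields at most num_batches batches from the wrapped dataloader and reports that count via len(). The class name TruncatedDataLoaderSketch and its body below are illustrative assumptions, not fastNLP's actual implementation.

    class TruncatedDataLoaderSketch:
        """Hypothetical stand-in for _TruncatedDataLoader (illustration only)."""

        def __init__(self, dataloader, num_batches):
            self.dataloader = dataloader
            self.num_batches = num_batches

        def __iter__(self):
            # Stop after num_batches batches, even if the wrapped
            # dataloader could keep going (assumed behavior).
            for i, batch in enumerate(self.dataloader):
                if i >= self.num_batches:
                    break
                yield batch

        def __len__(self):
            # test_len asserts len(wrapper) == num_batches for every
            # batch size tried, so the sketch reports num_batches directly.
            return self.num_batches

The torch-dependent cases are tagged @pytest.mark.torch, so they can be selected with pytest -m torch or excluded with pytest -m "not torch" (assuming the marker is registered in the project's pytest configuration).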