|
|
@@ -147,13 +147,14 @@ class TestFdl: |
|
|
|
assert 'Parameter:prefetch_factor' in out[0] |
|
|
|
|
|
|
|
@recover_logger |
|
|
|
@pytest.mark.temp |
|
|
|
def test_version_111(self): |
|
|
|
if parse_version(torch.__version__) <= parse_version('1.7'): |
|
|
|
pytest.skip("Torch version smaller than 1.7") |
|
|
|
logger.set_stdout() |
|
|
|
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) |
|
|
|
with Capturing() as out: |
|
|
|
dl = TorchDataLoader(ds, num_workers=2, prefetch_factor=3, shuffle=False) |
|
|
|
dl = TorchDataLoader(ds, num_workers=0, prefetch_factor=2, generator=torch.Generator(), shuffle=False) |
|
|
|
for idx, batch in enumerate(dl): |
|
|
|
assert len(batch['x'])==1 |
|
|
|
assert batch['x'][0].tolist() == ds[idx]['x'] |
|
|
|