From f183c59cce4f701f7a1fabde03714fdcc33acd06 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Wed, 18 May 2022 12:33:57 +0000
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9torch=5Fdataloader/test=5Ffdl?=
 =?UTF-8?q?.py=E7=9A=84test=5FVersion=5F111=E6=B5=8B=E8=AF=95=E4=BE=8B?=
 =?UTF-8?q?=E7=9A=84=E5=8F=82=E6=95=B0=EF=BC=9B=E5=88=A0=E9=99=A4=E4=B8=80?=
 =?UTF-8?q?=E5=A4=84=E4=B8=8D=E5=BF=85=E8=A6=81=E7=9A=84temp=E6=A0=87?=
 =?UTF-8?q?=E7=AD=BE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/core/callbacks/test_load_best_model_callback_torch.py | 1 -
 tests/core/dataloaders/torch_dataloader/test_fdl.py         | 3 ++-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/core/callbacks/test_load_best_model_callback_torch.py b/tests/core/callbacks/test_load_best_model_callback_torch.py
index c2be7f7a..07576599 100644
--- a/tests/core/callbacks/test_load_best_model_callback_torch.py
+++ b/tests/core/callbacks/test_load_best_model_callback_torch.py
@@ -73,7 +73,6 @@ def model_and_optimizers(request):
 
 
 @pytest.mark.torch
-@pytest.mark.temp
 @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
 def test_load_best_model_callback(
diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py
index 6d20754a..8ed7441b 100644
--- a/tests/core/dataloaders/torch_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py
@@ -147,13 +147,14 @@ class TestFdl:
         assert 'Parameter:prefetch_factor' in out[0]
 
     @recover_logger
+    @pytest.mark.temp
     def test_version_111(self):
         if parse_version(torch.__version__) <= parse_version('1.7'):
             pytest.skip("Torch version smaller than 1.7")
         logger.set_stdout()
         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
         with Capturing() as out:
-            dl = TorchDataLoader(ds, num_workers=2, prefetch_factor=3, shuffle=False)
+            dl = TorchDataLoader(ds, num_workers=0, prefetch_factor=2, generator=torch.Generator(), shuffle=False)
             for idx, batch in enumerate(dl):
                 assert len(batch['x'])==1
                 assert batch['x'][0].tolist() == ds[idx]['x']