diff --git a/tests/core/callbacks/test_load_best_model_callback_torch.py b/tests/core/callbacks/test_load_best_model_callback_torch.py index c2be7f7a..07576599 100644 --- a/tests/core/callbacks/test_load_best_model_callback_torch.py +++ b/tests/core/callbacks/test_load_best_model_callback_torch.py @@ -73,7 +73,6 @@ def model_and_optimizers(request): @pytest.mark.torch -@pytest.mark.temp @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @magic_argv_env_context def test_load_best_model_callback( diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py index 6d20754a..8ed7441b 100644 --- a/tests/core/dataloaders/torch_dataloader/test_fdl.py +++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py @@ -147,13 +147,14 @@ class TestFdl: assert 'Parameter:prefetch_factor' in out[0] @recover_logger + @pytest.mark.temp def test_version_111(self): if parse_version(torch.__version__) <= parse_version('1.7'): pytest.skip("Torch version smaller than 1.7") logger.set_stdout() ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) with Capturing() as out: - dl = TorchDataLoader(ds, num_workers=2, prefetch_factor=3, shuffle=False) + dl = TorchDataLoader(ds, num_workers=0, prefetch_factor=2, generator=torch.Generator(), shuffle=False) for idx, batch in enumerate(dl): assert len(batch['x'])==1 assert batch['x'][0].tolist() == ds[idx]['x']