Browse Source

修改torch_dataloader/test_fdl.py的test_Version_111测试例的参数;删除一处不必要的temp标签

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
f183c59cce
2 changed files with 2 additions and 2 deletions
  1. +0
    -1
      tests/core/callbacks/test_load_best_model_callback_torch.py
  2. +2
    -1
      tests/core/dataloaders/torch_dataloader/test_fdl.py

+ 0
- 1
tests/core/callbacks/test_load_best_model_callback_torch.py View File

@@ -73,7 +73,6 @@ def model_and_optimizers(request):




@pytest.mark.torch @pytest.mark.torch
@pytest.mark.temp
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
@magic_argv_env_context @magic_argv_env_context
def test_load_best_model_callback( def test_load_best_model_callback(


+ 2
- 1
tests/core/dataloaders/torch_dataloader/test_fdl.py View File

@@ -147,13 +147,14 @@ class TestFdl:
assert 'Parameter:prefetch_factor' in out[0] assert 'Parameter:prefetch_factor' in out[0]


@recover_logger @recover_logger
@pytest.mark.temp
def test_version_111(self): def test_version_111(self):
if parse_version(torch.__version__) <= parse_version('1.7'): if parse_version(torch.__version__) <= parse_version('1.7'):
pytest.skip("Torch version smaller than 1.7") pytest.skip("Torch version smaller than 1.7")
logger.set_stdout() logger.set_stdout()
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
with Capturing() as out: with Capturing() as out:
dl = TorchDataLoader(ds, num_workers=2, prefetch_factor=3, shuffle=False)
dl = TorchDataLoader(ds, num_workers=0, prefetch_factor=2, generator=torch.Generator(), shuffle=False)
for idx, batch in enumerate(dl): for idx, batch in enumerate(dl):
assert len(batch['x'])==1 assert len(batch['x'])==1
assert batch['x'][0].tolist() == ds[idx]['x'] assert batch['x'][0].tolist() == ds[idx]['x']


Loading…
Cancel
Save