
Fix some bugs in the tests

dev0.8.0
x54-729 · 2 years ago · parent commit babf4b2f19
3 changed files with 4 additions and 8 deletions:
  1. tests/core/callbacks/test_checkpoint_callback_torch.py (+1, -1)
  2. tests/core/collators/test_collator.py (+1, -5)
  3. tests/core/drivers/torch_driver/test_fsdp.py (+2, -2)

tests/core/callbacks/test_checkpoint_callback_torch.py (+1, -1)

@@ -75,7 +75,7 @@ def model_and_optimizers(request):


 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context(timeout=100)
 def test_model_checkpoint_callback_1(
     model_and_optimizers: TrainerParameters,
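
Note: `("torch", [4, 5])` pins the test to CUDA devices 4 and 5, so it can only run on a host with at least six GPUs; `[0, 1]` only needs two. A minimal sketch (not part of this commit; the guard and test name are illustrative assumptions) of making such a hardware requirement explicit instead of implicit:

import pytest
import torch

# Hypothetical guard, not from this commit: skip outright when the host
# has fewer GPUs than the parametrized device list requires.
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="test drives two CUDA devices")
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])
def test_runs_on_two_gpus(driver, device):
    # Every requested device index must actually exist on this machine.
    assert max(device) < torch.cuda.device_count()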


tests/core/collators/test_collator.py (+1, -5)

@@ -11,12 +11,8 @@ from ...helpers.utils import Capturing
 def _assert_equal(d1, d2):
     try:
         if 'torch' in str(type(d1)):
-            if 'float64' in str(d2.dtype):
-                print(d2.dtype)
             assert (d1 == d2).all().item()
-        if 'oneflow' in str(type(d1)):
-            if 'float64' in str(d2.dtype):
-                print(d2.dtype)
+        elif 'oneflow' in str(type(d1)):
             assert (d1 == d2).all().item()
         else:
             assert all(d1 == d2)
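
Note: besides deleting the debug prints, this fixes a classic `if`/`if`/`else` fall-through. With two independent `if` statements, a torch tensor passes the first branch and then, because the `'oneflow'` check is False, also falls into the `else`, where `all(d1 == d2)` iterates the comparison tensor and tries to coerce each row to bool. Changing the second `if` to `elif` makes the three branches mutually exclusive. A distilled sketch of the failure (assuming torch is installed; simplified from the real `_assert_equal`):

import torch

def buggy_assert_equal(d1, d2):
    if 'torch' in str(type(d1)):
        assert (d1 == d2).all().item()
    if 'oneflow' in str(type(d1)):   # independent if: False for torch...
        assert (d1 == d2).all().item()
    else:                            # ...so this branch still executes
        assert all(d1 == d2)

try:
    buggy_assert_equal(torch.ones(2, 2), torch.ones(2, 2))
except RuntimeError as err:
    # "Boolean value of Tensor with more than one element is ambiguous"
    print(err)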


tests/core/drivers/torch_driver/test_fsdp.py (+2, -2)

@@ -67,7 +67,7 @@ def model_and_optimizers(request):

     return trainer_params

-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上")
 @pytest.mark.torch
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator(
@@ -97,7 +97,7 @@ def test_trainer_torch_without_evaluator(
     dist.destroy_process_group()


-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上")
 @pytest.mark.torch
 @pytest.mark.parametrize("save_on_rank0", [True, False])
 @magic_argv_env_context(timeout=100)
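
Note: both hunks apply the same fix. `pytest.mark.skipif` treats every positional argument as a condition, and string conditions are evaluated as Python expressions at collection time, so passing the human-readable message (here roughly "fsdp requires torch 1.12 or above") positionally breaks collection. The message must go in the `reason=` keyword. A minimal sketch (the condition and test name are illustrative, not from the commit):

import pytest

# Broken (pre-fix) form, kept as a comment because it fails at collection:
# pytest would try to evaluate the message string as a Python condition.
#   @pytest.mark.skipif(False, "fsdp requires torch >= 1.12")

# Fixed form: the message is only ever passed as a keyword argument.
@pytest.mark.skipif(False, reason="fsdp requires torch >= 1.12")
def test_skip_reason_is_keyword():
    assert True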

