diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py
index 00b73b51..d227a162 100644
--- a/tests/core/callbacks/test_checkpoint_callback_torch.py
+++ b/tests/core/callbacks/test_checkpoint_callback_torch.py
@@ -75,7 +75,7 @@ def model_and_optimizers(request):
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context(timeout=100)
 def test_model_checkpoint_callback_1(
     model_and_optimizers: TrainerParameters,
diff --git a/tests/core/collators/test_collator.py b/tests/core/collators/test_collator.py
index d00cbe05..7c099c54 100644
--- a/tests/core/collators/test_collator.py
+++ b/tests/core/collators/test_collator.py
@@ -11,12 +11,8 @@ from ...helpers.utils import Capturing
 def _assert_equal(d1, d2):
     try:
         if 'torch' in str(type(d1)):
-            if 'float64' in str(d2.dtype):
-                print(d2.dtype)
             assert (d1 == d2).all().item()
-        if 'oneflow' in str(type(d1)):
-            if 'float64' in str(d2.dtype):
-                print(d2.dtype)
+        elif 'oneflow' in str(type(d1)):
             assert (d1 == d2).all().item()
         else:
             assert all(d1 == d2)
diff --git a/tests/core/drivers/torch_driver/test_fsdp.py b/tests/core/drivers/torch_driver/test_fsdp.py
index 586a97ea..de291bfd 100644
--- a/tests/core/drivers/torch_driver/test_fsdp.py
+++ b/tests/core/drivers/torch_driver/test_fsdp.py
@@ -67,7 +67,7 @@ def model_and_optimizers(request):
     return trainer_params
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp requires torch version 1.12 or higher")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch version 1.12 or higher")
 @pytest.mark.torch
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator(
@@ -97,7 +97,7 @@ def test_trainer_torch_without_evaluator(
     dist.destroy_process_group()
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp requires torch version 1.12 or higher")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch version 1.12 or higher")
 @pytest.mark.torch
 @pytest.mark.parametrize("save_on_rank0", [True, False])
 @magic_argv_env_context(timeout=100)
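A note on the first hunk: the `driver,device` parametrization feeds each tuple to the test as one invocation, and switching the device list from `[4, 5]` to `[0, 1]` lets the multi-GPU case run on machines with only two GPUs. A minimal, self-contained sketch of the same pytest pattern (the test name and body below are hypothetical, not from the repo):

```python
import pytest

# Hypothetical sketch of the driver/device parametrization used above:
# each tuple produces one test invocation, so ("torch", [0, 1]) asks the
# driver for a two-GPU setup. GPU indices 0 and 1 are preferred over 4
# and 5 because they exist on far more machines.
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])])
def test_device_selection(driver, device):
    assert driver == "torch"
    assert device in ("cpu", [0, 1])
```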
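On the second hunk: besides dropping the `float64` debug prints, the change turns the second `if` into an `elif`. With two independent `if` statements, a torch tensor that matched the first branch would also fall into the `else` of the oneflow check, where `all(d1 == d2)` raises for tensors with more than one dimension. A minimal sketch of the fixed control flow (the oneflow branch is kept but not exercised here):

```python
import numpy as np
import torch

def _assert_equal(d1, d2):
    # With `elif`, a torch tensor handled by the first branch can no
    # longer fall through to the generic `else`, where `all(d1 == d2)`
    # is ambiguous for multi-dimensional tensors.
    if 'torch' in str(type(d1)):
        assert (d1 == d2).all().item()
    elif 'oneflow' in str(type(d1)):
        assert (d1 == d2).all().item()
    else:
        assert all(d1 == d2)

_assert_equal(torch.ones(2, 2), torch.ones(2, 2))  # torch branch
_assert_equal(np.array([1, 2]), np.array([1, 2]))  # generic fallback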
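On the third file (the skip messages are translated from Chinese above): `pytest.mark.skipif` treats every positional argument as a skip condition, and a positional string is evaluated as a Python expression, so passing the reason positionally produces an evaluation error instead of labelling the skip. The reason must go through the `reason=` keyword. A minimal sketch, with a hypothetical stand-in for the `_TORCH_GREATER_EQUAL_1_12` flag:

```python
import pytest

# Hypothetical stand-in for fastNLP's _TORCH_GREATER_EQUAL_1_12 flag.
_TORCH_GREATER_EQUAL_1_12 = False

# Wrong: a positional string is treated as another condition string for
# pytest to evaluate, not as the human-readable skip reason:
#   @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp requires torch >= 1.12")

# Correct: pass the reason as a keyword argument.
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12,
                    reason="fsdp requires torch >= 1.12")
def test_requires_recent_torch():
    assert True
```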