@@ -75,7 +75,7 @@ def model_and_optimizers(request): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@magic_argv_env_context(timeout=100) | @magic_argv_env_context(timeout=100) | ||||
def test_model_checkpoint_callback_1( | def test_model_checkpoint_callback_1( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
@@ -11,12 +11,8 @@ from ...helpers.utils import Capturing | |||||
def _assert_equal(d1, d2): | def _assert_equal(d1, d2): | ||||
try: | try: | ||||
if 'torch' in str(type(d1)): | if 'torch' in str(type(d1)): | ||||
if 'float64' in str(d2.dtype): | |||||
print(d2.dtype) | |||||
assert (d1 == d2).all().item() | assert (d1 == d2).all().item() | ||||
if 'oneflow' in str(type(d1)): | |||||
if 'float64' in str(d2.dtype): | |||||
print(d2.dtype) | |||||
elif 'oneflow' in str(type(d1)): | |||||
assert (d1 == d2).all().item() | assert (d1 == d2).all().item() | ||||
else: | else: | ||||
assert all(d1 == d2) | assert all(d1 == d2) | ||||
@@ -67,7 +67,7 @@ def model_and_optimizers(request): | |||||
return trainer_params | return trainer_params | ||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上") | |||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上") | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_without_evaluator( | def test_trainer_torch_without_evaluator( | ||||
@@ -97,7 +97,7 @@ def test_trainer_torch_without_evaluator( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上") | |||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上") | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("save_on_rank0", [True, False]) | @pytest.mark.parametrize("save_on_rank0", [True, False]) | ||||
@magic_argv_env_context(timeout=100) | @magic_argv_env_context(timeout=100) | ||||