diff --git a/tests/core/drivers/torch_driver/test_fsdp.py b/tests/core/drivers/torch_driver/test_fsdp.py index de291bfd..a05f005d 100644 --- a/tests/core/drivers/torch_driver/test_fsdp.py +++ b/tests/core/drivers/torch_driver/test_fsdp.py @@ -105,7 +105,7 @@ def test_model_checkpoint_callback_1( model_and_optimizers: TrainerParameters, save_on_rank0 ): - device = [6, 7] + device = [0, 1] for version in [0, 1]: # 需要在每一个循环开始重新初始化 model,是因为 fsdp 会将当前卡上的 model 删除,从而导致这个引用实际上引用到的是一个空模型; model_and_optimizers.model = TorchNormalModel_Classification_1( @@ -242,7 +242,7 @@ def test_model_checkpoint_callback_1( @pytest.mark.skip("现在 fsdp 还不支持断点重训;") @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch_fsdp", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @magic_argv_env_context(timeout=100) def test_trainer_checkpoint_callback_1( model_and_optimizers: TrainerParameters,