From 9f6024674969342db02bfb0d58751a1479d34925 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 20 Sep 2022 19:30:52 +0800 Subject: [PATCH] =?UTF-8?q?test=5Ffsdp=20device=20=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E4=B8=BA[0,1]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/drivers/torch_driver/test_fsdp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/drivers/torch_driver/test_fsdp.py b/tests/core/drivers/torch_driver/test_fsdp.py index de291bfd..a05f005d 100644 --- a/tests/core/drivers/torch_driver/test_fsdp.py +++ b/tests/core/drivers/torch_driver/test_fsdp.py @@ -105,7 +105,7 @@ def test_model_checkpoint_callback_1( model_and_optimizers: TrainerParameters, save_on_rank0 ): - device = [6, 7] + device = [0, 1] for version in [0, 1]: # 需要在每一个循环开始重新初始化 model,是因为 fsdp 会将当前卡上的 model 删除,从而导致这个引用实际上引用到的是一个空模型; model_and_optimizers.model = TorchNormalModel_Classification_1( @@ -242,7 +242,7 @@ def test_model_checkpoint_callback_1( @pytest.mark.skip("现在 fsdp 还不支持断点重训;") @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch_fsdp", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @magic_argv_env_context(timeout=100) def test_trainer_checkpoint_callback_1( model_and_optimizers: TrainerParameters,