Browse Source

test_fsdp device 统一为[0,1]

dev0.8.0
x54-729 2 years ago
parent
commit
9f60246749
1 changed file with 2 additions and 2 deletions
  1. +2
    -2
      tests/core/drivers/torch_driver/test_fsdp.py

+ 2
- 2
tests/core/drivers/torch_driver/test_fsdp.py View File

@@ -105,7 +105,7 @@ def test_model_checkpoint_callback_1(
     model_and_optimizers: TrainerParameters,
     save_on_rank0
 ):
-    device = [6, 7]
+    device = [0, 1]
     for version in [0, 1]:
         # 需要在每一个循环开始重新初始化 model,是因为 fsdp 会将当前卡上的 model 删除,从而导致这个引用实际上引用到的是一个空模型;
         model_and_optimizers.model = TorchNormalModel_Classification_1(
@@ -242,7 +242,7 @@ def test_model_checkpoint_callback_1(


 @pytest.mark.skip("现在 fsdp 还不支持断点重训;")
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch_fsdp", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context(timeout=100)
 def test_trainer_checkpoint_callback_1(
     model_and_optimizers: TrainerParameters,


Loading…
Cancel
Save