|
|
@@ -105,7 +105,7 @@ def test_model_checkpoint_callback_1( |
|
|
|
model_and_optimizers: TrainerParameters, |
|
|
|
save_on_rank0 |
|
|
|
): |
|
|
|
device = [6, 7] |
|
|
|
device = [0, 1] |
|
|
|
for version in [0, 1]: |
|
|
|
# 需要在每一个循环开始重新初始化 model,是因为 fsdp 会将当前卡上的 model 删除,从而导致这个引用实际上引用到的是一个空模型; |
|
|
|
model_and_optimizers.model = TorchNormalModel_Classification_1( |
|
|
@@ -242,7 +242,7 @@ def test_model_checkpoint_callback_1( |
|
|
|
|
|
|
|
@pytest.mark.skip("现在 fsdp 还不支持断点重训;") |
|
|
|
@pytest.mark.torch |
|
|
|
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) |
|
|
|
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) |
|
|
|
@magic_argv_env_context(timeout=100) |
|
|
|
def test_trainer_checkpoint_callback_1( |
|
|
|
model_and_optimizers: TrainerParameters, |
|
|
|