From ad6ada2487c492d3a12887aeb6664001639ba166 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 28 Jun 2022 23:25:12 +0800 Subject: [PATCH] small --- tests/core/drivers/torch_driver/test_ddp.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py index 46abd84c..11f47617 100644 --- a/tests/core/drivers/torch_driver/test_ddp.py +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -773,10 +773,7 @@ class TestSaveLoad: # 保存状态 sampler_states = dataloader.batch_sampler.sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} - if only_state_dict: - driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) - else: - driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) + driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) dist.barrier() # 等待save成功 # 加载 # 更改 batch_size