Browse Source

small

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
ad6ada2487
1 changed files with 1 additions and 4 deletions
  1. +1
    -4
      tests/core/drivers/torch_driver/test_ddp.py

+ 1
- 4
tests/core/drivers/torch_driver/test_ddp.py View File

@@ -773,10 +773,7 @@ class TestSaveLoad:
# 保存状态 # 保存状态
sampler_states = dataloader.batch_sampler.sampler.state_dict() sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches} save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
dist.barrier() # 等待save成功 dist.barrier() # 等待save成功
# 加载 # 加载
# 更改 batch_size # 更改 batch_size


Loading…
Cancel
Save