|
@@ -773,10 +773,7 @@ class TestSaveLoad: |
|
|
# 保存状态 |
|
|
# 保存状态 |
|
|
sampler_states = dataloader.batch_sampler.sampler.state_dict() |
|
|
sampler_states = dataloader.batch_sampler.sampler.state_dict() |
|
|
save_states = {"num_consumed_batches": num_consumed_batches} |
|
|
save_states = {"num_consumed_batches": num_consumed_batches} |
|
|
if only_state_dict: |
|
|
|
|
|
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) |
|
|
|
|
|
else: |
|
|
|
|
|
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) |
|
|
|
|
|
|
|
|
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) |
|
|
dist.barrier() # 等待save成功 |
|
|
dist.barrier() # 等待save成功 |
|
|
# 加载 |
|
|
# 加载 |
|
|
# 更改 batch_size |
|
|
# 更改 batch_size |
|
|