Browse Source

Fixed the logic of set_dist_repro_dataloader when dist is None (修复了 dist 为 None 时的 set_dist_repro_dataloader 的逻辑)

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
1452aa8f6c
1 changed files with 5 additions and 6 deletions
  1. +5
    -6
      fastNLP/core/drivers/torch_driver/ddp.py

+ 5
- 6
fastNLP/core/drivers/torch_driver/ddp.py View File

@@ -471,12 +471,11 @@ class TorchDDPDriver(TorchDriver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
"control.")
else:
if isinstance(dist, ReproducibleBatchSampler):
dist = re_instantiate_sampler(dist)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleSampler):
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler))
if isinstance(args.sampler, ReproducibleSampler):
return replace_sampler(dataloader, re_instantiate_sampler(args.sampler))
return dataloader
# trainer
elif dist == "dist":


Loading…
Cancel
Save