@@ -471,12 +471,11 @@ class TorchDDPDriver(TorchDriver):
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
                                    "control.")
             else:
-                if isinstance(dist, ReproducibleBatchSampler):
-                    dist = re_instantiate_sampler(dist)
-                    return replace_batch_sampler(dataloader, dist)
-                if isinstance(dist, ReproducibleSampler):
-                    dist = re_instantiate_sampler(dist)
-                    return replace_sampler(dataloader, dist)
+                args = self.get_dataloader_args(dataloader)
+                if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+                    return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler))
+                if isinstance(args.sampler, ReproducibleSampler):
+                    return replace_sampler(dataloader, re_instantiate_sampler(args.sampler))
                 return dataloader
         # trainer
         elif dist == "dist":
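Note on the hunk above: the removed lines were dead code. The enclosing branch is only taken when `dist` is None (it pairs with the `elif dist == "dist":` below), so the `isinstance(dist, ...)` checks could never succeed and the dataloader was returned unchanged even when it already carried a reproducible sampler. The added lines instead inspect the samplers the dataloader was actually built with and re-instantiate those. The sketch below illustrates that fixed branch; it is a hedged stand-in, not this repository's code: plain `is not None` checks replace the `ReproducibleBatchSampler`/`ReproducibleSampler` isinstance tests, `deepcopy` replaces `re_instantiate_sampler`, and rebuilding a stock `DataLoader` replaces `replace_batch_sampler`/`replace_sampler`.

# Hedged, self-contained sketch of the '+' lines of the hunk, using only
# stock PyTorch types. All sketch_* names are hypothetical stand-ins.
import copy

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset


def sketch_re_instantiate_sampler(sampler):
    # Assumed semantics of re_instantiate_sampler: hand back a fresh sampler
    # equivalent to the original, so the rebuilt dataloader shares no
    # iteration state with the old one; deepcopy approximates reconstructing
    # the sampler from its original init arguments.
    return copy.deepcopy(sampler)


def sketch_fixed_else_branch(dataloader: DataLoader) -> DataLoader:
    # DataLoader's own .batch_sampler / .sampler attributes stand in for
    # self.get_dataloader_args(dataloader); the real code tests
    # isinstance(..., ReproducibleBatchSampler / ReproducibleSampler).
    if dataloader.batch_sampler is not None:
        fresh = sketch_re_instantiate_sampler(dataloader.batch_sampler)
        return DataLoader(dataloader.dataset, batch_sampler=fresh)
    if dataloader.sampler is not None:
        fresh = sketch_re_instantiate_sampler(dataloader.sampler)
        return DataLoader(dataloader.dataset, sampler=fresh)
    return dataloader


ds = TensorDataset(torch.arange(8))
dl = DataLoader(ds, batch_size=2, sampler=SequentialSampler(ds))
new_dl = sketch_fixed_else_branch(dl)  # rebuilt with a fresh batch sampler
batches = list(new_dl)
assert len(batches) == 4  # 8 elements in batches of 2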