From 1452aa8f6c54e2ad313d93687baf491b9cb37559 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Thu, 14 Apr 2022 13:50:53 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=20dist=20=E4=B8=BA?= =?UTF-8?q?=20None=20=E6=97=B6=E7=9A=84=20set=5Fdist=5Frepro=5Fdataloader?= =?UTF-8?q?=20=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/torch_driver/ddp.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 11a61dde..c673fe62 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -471,12 +471,11 @@ class TorchDDPDriver(TorchDriver): raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " "control.") else: - if isinstance(dist, ReproducibleBatchSampler): - dist = re_instantiate_sampler(dist) - return replace_batch_sampler(dataloader, dist) - if isinstance(dist, ReproducibleSampler): - dist = re_instantiate_sampler(dist) - return replace_sampler(dataloader, dist) + args = self.get_dataloader_args(dataloader) + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler)) + if isinstance(args.sampler, ReproducibleSampler): + return replace_sampler(dataloader, re_instantiate_sampler(args.sampler)) return dataloader # trainer elif dist == "dist":