|
|
@@ -491,7 +491,8 @@ class PaddleFleetDriver(PaddleDriver): |
|
|
|
rank=self.global_rank |
|
|
|
) |
|
|
|
# TODO 这里暂时统一替换为 BatchSampler |
|
|
|
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) |
|
|
|
batch_sampler = BatchSampler(dataset=args.dataset, batch_size=args.batch_size, drop_last=False) |
|
|
|
batch_sampler.sampler = sampler |
|
|
|
return replace_batch_sampler(dataloader, batch_sampler) |
|
|
|
else: |
|
|
|
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") |
|
|
|