diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 9d362938..ba404814 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -123,17 +123,24 @@ class PaddleSingleDriver(PaddleDriver): if reproducible: if isinstance(args.sampler, paddle.io.RandomSampler): - # 如果本来就是随机的,直接替换 - sampler = RandomSampler(args.sampler.data_source) - logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") + if getattr(args.sampler, '_num_samples', None) is None \ + and getattr(args.sampler, 'replacements', False) is False \ + and getattr(args.sampler, 'generator', None) is None: + # 如果本来就是随机的,并且没有定制,直接替换掉。 + sampler = RandomSampler(args.sampler.data_source, shuffle=True) + logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") + return replace_sampler(dataloader, sampler) + elif isinstance(args.sampler, paddle.io.SequenceSampler): + # 需要替换为不要 shuffle 的。 + sampler = RandomSampler(args.sampler.data_source, shuffle=False) + logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") return replace_sampler(dataloader, sampler) - else: - batch_sampler = ReproduceBatchSampler( - batch_sampler=args.batch_sampler, - batch_size=args.batch_size, - drop_last=args.drop_last - ) - return replace_batch_sampler(dataloader, batch_sampler) + batch_sampler = ReproduceBatchSampler( + batch_sampler=args.batch_sampler, + batch_size=args.batch_size, + drop_last=args.drop_last + ) + return replace_batch_sampler(dataloader, batch_sampler) else: return dataloader diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 5b794459..ad7bf97d 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -250,7 +250,7 @@ def test_trainer_output_from_new_proc( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) +@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) @pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 @magic_argv_env_context def test_trainer_on_exception( diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index e7d6707a..67ea1b42 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -386,22 +386,16 @@ class TestSetDistReproDataloader: def test_with_reproducible_true(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 - 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True), - 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler + 当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler """ dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) assert not (replaced_loader is dataloader) - if shuffle: - # 此时会替换 sampler - assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler) - assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) - assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) - else: - # 此时会替换 batch_sampler - assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) - assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index 1fbc9d82..51555918 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -400,22 +400,19 @@ class TestSetDistReproDataloader: def test_with_reproducible_true(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 - 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True), - 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler + 当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler; + TODO: + 在 Sampler 的参数不是默认的情况下会替换 batch_sampler """ dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) assert not (replaced_loader is dataloader) - if shuffle: - # 此时会替换 sampler - assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler) - assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) - assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) - else: - # 此时会替换 batch_sampler - assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) - assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) + # 替换 sampler + assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last