@@ -123,17 +123,24 @@ class PaddleSingleDriver(PaddleDriver):
         if reproducible:
             if isinstance(args.sampler, paddle.io.RandomSampler):
-                # already random, so replace it directly
-                sampler = RandomSampler(args.sampler.data_source)
-                logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
+                if getattr(args.sampler, '_num_samples', None) is None \
+                        and getattr(args.sampler, 'replacements', False) is False \
+                        and getattr(args.sampler, 'generator', None) is None:
+                    # already random and not customized, so replace it directly.
+                    sampler = RandomSampler(args.sampler.data_source, shuffle=True)
+                    logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
                     return replace_sampler(dataloader, sampler)
+            elif isinstance(args.sampler, paddle.io.SequenceSampler):
+                # needs to be replaced with a non-shuffling sampler.
+                sampler = RandomSampler(args.sampler.data_source, shuffle=False)
+                logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.")
+                return replace_sampler(dataloader, sampler)
-            else:
-                batch_sampler = ReproduceBatchSampler(
-                    batch_sampler=args.batch_sampler,
-                    batch_size=args.batch_size,
-                    drop_last=args.drop_last
-                )
-                return replace_batch_sampler(dataloader, batch_sampler)
+            batch_sampler = ReproduceBatchSampler(
+                batch_sampler=args.batch_sampler,
+                batch_size=args.batch_size,
+                drop_last=args.drop_last
+            )
+            return replace_batch_sampler(dataloader, batch_sampler)
         else:
             return dataloader
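
For reference, a minimal sketch of how the new dispatch behaves. This is not part of the diff: `driver` is assumed to be an initialized PaddleSingleDriver and `dataset` a paddle Dataset, as in the tests further down, and the fastNLP import path is an assumption.

```python
# Sketch only: `driver` and `dataset` are assumed fixtures, mirroring the updated tests below.
from paddle.io import DataLoader
from fastNLP.core.samplers import RandomSampler  # fastNLP's reproducible sampler (assumed import path)

# shuffle=True gives a default paddle.io.RandomSampler (no custom num_samples,
# replacement or generator), so the new branch swaps in fastNLP's RandomSampler.
loader = DataLoader(dataset, batch_size=2, shuffle=True)
replaced = driver.set_dist_repro_dataloader(loader, dist="dist", reproducible=True)
assert isinstance(replaced.batch_sampler.sampler, RandomSampler)
assert replaced.batch_sampler.sampler.shuffle

# shuffle=False gives a paddle.io.SequenceSampler, now also replaced by fastNLP's
# RandomSampler (with shuffle=False) rather than falling back to ReproduceBatchSampler;
# only customized samplers still take the batch_sampler path.
loader = DataLoader(dataset, batch_size=2, shuffle=False)
replaced = driver.set_dist_repro_dataloader(loader, dist="dist", reproducible=True)
assert isinstance(replaced.batch_sampler.sampler, RandomSampler)
assert not replaced.batch_sampler.sampler.shuffle
```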
@@ -250,7 +250,7 @@ def test_trainer_output_from_new_proc(
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [1, 2])])
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])
 @pytest.mark.parametrize("cur_rank", [0])  # test each rank in turn: if the current process raises an error, are the other processes correctly killed?  , 1, 2, 3
 @magic_argv_env_context
 def test_trainer_on_exception(
@@ -386,22 +386,16 @@ class TestSetDistReproDataloader:
     def test_with_reproducible_true(self, shuffle):
         """
         Test the behaviour of set_dist_repro_dataloader when the `reproducible` parameter is True.
-        When dist is a string, a new dataloader should be returned; if the original sampler is paddle.io.RandomSampler (shuffle=True),
-        only the Sampler is replaced with RandomSampler, otherwise the batch_sampler is replaced with ReproduceBatchSampler.
+        When dist is a string, a new dataloader should be returned, with the sampler replaced by RandomSampler.
         """
         dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
         assert not (replaced_loader is dataloader)
-        if shuffle:
-            # the sampler is replaced in this case
-            assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler)
-            assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
-            assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
-        else:
-            # the batch_sampler is replaced in this case
-            assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
-            assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
+        assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler)
+        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
+        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
+        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         assert replaced_loader.drop_last == dataloader.drop_last
@@ -400,22 +400,19 @@ class TestSetDistReproDataloader:
     def test_with_reproducible_true(self, shuffle):
         """
         Test the behaviour of set_dist_repro_dataloader when the `reproducible` parameter is True.
-        When dist is a string, a new dataloader should be returned; if the original sampler is torch.utils.data.RandomSampler (shuffle=True),
-        only the Sampler is replaced with RandomSampler, otherwise the batch_sampler is replaced with ReproduceBatchSampler.
+        When dist is a string, a new dataloader should be returned, with the sampler replaced by RandomSampler;
+        TODO:
+        when the Sampler's parameters are not the defaults, the batch_sampler should be replaced instead.
         """
         dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
         assert not (replaced_loader is dataloader)
-        if shuffle:
-            # the sampler is replaced in this case
-            assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler)
-            assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
-            assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
-        else:
-            # the batch_sampler is replaced in this case
-            assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
-            assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
+        # the sampler is replaced
+        assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler)
+        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
+        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
+        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         assert replaced_loader.drop_last == dataloader.drop_last
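
A possible shape for the TODO above, inferred from the paddle-side change in the first hunk; the corresponding torch driver change is not part of this diff. `driver` and `dataset` stand for the existing test fixtures, and ReproduceBatchSampler/BatchSampler are the names already imported in this test module.

```python
# Sketch only: expected fallback when the original sampler has non-default arguments.
from torch.utils.data import DataLoader, RandomSampler as TorchRandomSampler

sampler = TorchRandomSampler(dataset, replacement=True, num_samples=10)  # customized sampler
dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)
replaced_loader = driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
# Expectation per the TODO: wrap the original batch_sampler instead of swapping the sampler.
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
```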