@@ -123,17 +123,24 @@ class PaddleSingleDriver(PaddleDriver): | |||
if reproducible: | |||
if isinstance(args.sampler, paddle.io.RandomSampler): | |||
# 如果本来就是随机的,直接替换 | |||
sampler = RandomSampler(args.sampler.data_source) | |||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||
if getattr(args.sampler, '_num_samples', None) is None \ | |||
and getattr(args.sampler, 'replacements', False) is False \ | |||
and getattr(args.sampler, 'generator', None) is None: | |||
# 如果本来就是随机的,并且没有定制,直接替换掉。 | |||
sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||
return replace_sampler(dataloader, sampler) | |||
elif isinstance(args.sampler, paddle.io.SequenceSampler): | |||
# 需要替换为不要 shuffle 的。 | |||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||
logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||
return replace_sampler(dataloader, sampler) | |||
else: | |||
batch_sampler = ReproduceBatchSampler( | |||
batch_sampler=args.batch_sampler, | |||
batch_size=args.batch_size, | |||
drop_last=args.drop_last | |||
) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
batch_sampler = ReproduceBatchSampler( | |||
batch_sampler=args.batch_sampler, | |||
batch_size=args.batch_size, | |||
drop_last=args.drop_last | |||
) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
else: | |||
return dataloader | |||
@@ -250,7 +250,7 @@ def test_trainer_output_from_new_proc( | |||
@pytest.mark.torch | |||
@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | |||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) | |||
@pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 | |||
@magic_argv_env_context | |||
def test_trainer_on_exception( | |||
@@ -386,22 +386,16 @@ class TestSetDistReproDataloader: | |||
def test_with_reproducible_true(self, shuffle): | |||
""" | |||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | |||
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True), | |||
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler | |||
当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler | |||
""" | |||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | |||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | |||
assert not (replaced_loader is dataloader) | |||
if shuffle: | |||
# 此时会替换 sampler | |||
assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler) | |||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
else: | |||
# 此时会替换 batch_sampler | |||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | |||
assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler) | |||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||
assert replaced_loader.drop_last == dataloader.drop_last | |||
@@ -400,22 +400,19 @@ class TestSetDistReproDataloader: | |||
def test_with_reproducible_true(self, shuffle): | |||
""" | |||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | |||
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True), | |||
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler | |||
当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler; | |||
TODO: | |||
在 Sampler 的参数不是默认的情况下会替换 batch_sampler | |||
""" | |||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | |||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | |||
assert not (replaced_loader is dataloader) | |||
if shuffle: | |||
# 此时会替换 sampler | |||
assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler) | |||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
else: | |||
# 此时会替换 batch_sampler | |||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | |||
# 替换 sampler | |||
assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler) | |||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||
assert replaced_loader.drop_last == dataloader.drop_last | |||