Browse Source

paddle 和测试例跟进 set_dist_repro_dataloader 函数;修改test_trainer_wo_evaluator_torch.py的bug

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
c4af9f21c6
4 changed files with 31 additions and 33 deletions
  1. +17
    -10
      fastNLP/core/drivers/paddle_driver/single_device.py
  2. +1
    -1
      tests/core/controllers/test_trainer_wo_evaluator_torch.py
  3. +5
    -11
      tests/core/drivers/paddle_driver/test_single_device.py
  4. +8
    -11
      tests/core/drivers/torch_driver/test_single_device.py

+ 17
- 10
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -123,17 +123,24 @@ class PaddleSingleDriver(PaddleDriver):

if reproducible:
if isinstance(args.sampler, paddle.io.RandomSampler):
# 如果本来就是随机的,直接替换
sampler = RandomSampler(args.sampler.data_source)
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
if getattr(args.sampler, '_num_samples', None) is None \
and getattr(args.sampler, 'replacements', False) is False \
and getattr(args.sampler, 'generator', None) is None:
# 如果本来就是随机的,并且没有定制,直接替换掉。
sampler = RandomSampler(args.sampler.data_source, shuffle=True)
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
elif isinstance(args.sampler, paddle.io.SequenceSampler):
# 需要替换为不要 shuffle 的。
sampler = RandomSampler(args.sampler.data_source, shuffle=False)
logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
else:
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
else:
return dataloader



+ 1
- 1
tests/core/controllers/test_trainer_wo_evaluator_torch.py View File

@@ -250,7 +250,7 @@ def test_trainer_output_from_new_proc(


@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", [1, 2])])
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])
@pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3
@magic_argv_env_context
def test_trainer_on_exception(


+ 5
- 11
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -386,22 +386,16 @@ class TestSetDistReproDataloader:
def test_with_reproducible_true(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True),
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler
当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)

assert not (replaced_loader is dataloader)
if shuffle:
# 此时会替换 sampler
assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# 此时会替换 batch_sampler
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last



+ 8
- 11
tests/core/drivers/torch_driver/test_single_device.py View File

@@ -400,22 +400,19 @@ class TestSetDistReproDataloader:
def test_with_reproducible_true(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True),
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler
当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler;
TODO:
在 Sampler 的参数不是默认的情况下会替换 batch_sampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)

assert not (replaced_loader is dataloader)
if shuffle:
# 此时会替换 sampler
assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# 此时会替换 batch_sampler
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
# 替换 sampler
assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last



Loading…
Cancel
Save