Browse Source

small

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
9b13fd313a
2 changed files with 24 additions and 40 deletions
  1. +15
    -25
      tests/core/drivers/paddle_driver/test_fleet.py
  2. +9
    -15
      tests/core/drivers/paddle_driver/test_single_device.py

+ 15
- 25
tests/core/drivers/paddle_driver/test_fleet.py View File

@@ -135,7 +135,7 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
def test_with_dist_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现
此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler
@@ -154,7 +154,7 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
def test_with_dist_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现
此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler
@@ -182,7 +182,7 @@ class TestSetDistReproDataloader:
"""

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self):
def test_with_dist_none_reproducible_true(self):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现
当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错
@@ -195,8 +195,9 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
# @pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler
时的表现
@@ -224,7 +225,7 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle):
def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现
此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其
@@ -256,7 +257,7 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现
此时直接返回原来的 dataloader,不做任何处理。
@@ -274,7 +275,7 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler
的表现
@@ -296,7 +297,7 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现
@@ -322,7 +323,7 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle):
def test_with_dist_dist_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现
此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关
@@ -347,7 +348,7 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现
@@ -373,7 +374,7 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle):
def test_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler
的表现
@@ -399,7 +400,7 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle):
def test_with_dist_unrepeat_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现
此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关
@@ -444,18 +445,11 @@ class TestSetDistReproDataloader:
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

# 加载 num_consumed_samples_array,设置正确取出的 batch 数目
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
print("array: ", num_consumed_samples_array)

# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
# 重新改造 dataloader
new_loader = DataLoader(
dataset=replaced_loader.dataset,
@@ -474,11 +468,7 @@ class TestSetDistReproDataloader:
new_loader.batch_sampler.load_state_dict(sampler_states)
else:
batch_size = replaced_loader.batch_sampler.batch_size
num_consumed_samples = num_consumed_batches * batch_size
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples]
else:
sampler_states["num_consumed_samples"] = num_consumed_samples * num_replicas
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
# 重新构造 dataloader
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)


+ 9
- 15
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -348,7 +348,7 @@ class TestSingleDeviceFunction:
#
############################################################################

class TestSetDistReproDataloder:
class TestSetDistReproDataloader:
"""
专门测试 set_dist_repro_dataloader 函数的类
"""
@@ -357,7 +357,7 @@ class TestSetDistReproDataloder:
model = PaddleNormalModel_Classification_1(10, 32)
self.driver = PaddleSingleDriver(model, device="cpu")
def test_set_dist_repro_dataloader_with_reproducible_false(self):
def test_with_reproducible_false(self):
"""
测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现
当dist为字符串时,此时应该返回原来的 dataloader
@@ -368,7 +368,7 @@ class TestSetDistReproDataloder:
assert replaced_loader is dataloader

@pytest.mark.parametrize("shuffle", [True, False])
def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle):
def test_with_reproducible_true(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True),
@@ -393,7 +393,7 @@ class TestSetDistReproDataloder:
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
def test_with_dist_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler
@@ -409,7 +409,7 @@ class TestSetDistReproDataloder:
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
def test_with_dist_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现
应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler
@@ -428,7 +428,7 @@ class TestSetDistReproDataloder:
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle):
def test_with_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
应该返回新的 dataloader,且其余各项设置和原来相同
@@ -452,7 +452,7 @@ class TestSetDistReproDataloder:
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle):
def test_with_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
应该返回新的 dataloader,且其余各项设置和原来相同
@@ -497,10 +497,7 @@ class TestSetDistReproDataloder:
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# 重新改造 dataloader
new_loader = DataLoader(
dataset=replaced_loader.dataset,
@@ -514,10 +511,7 @@ class TestSetDistReproDataloder:
else:
batch_size = replaced_loader.batch_sampler.batch_size
num_consumed_samples = num_consumed_batches * batch_size
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples]
else:
sampler_states["num_consumed_samples"] = num_consumed_samples
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# 重新构造 dataloader
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)


Loading…
Cancel
Save