|
|
@@ -135,7 +135,7 @@ class TestSetDistReproDataloader: |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): |
|
|
|
def test_with_dist_batch_sampler(self, shuffle): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 |
|
|
|
此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler |
|
|
@@ -154,7 +154,7 @@ class TestSetDistReproDataloader: |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle): |
|
|
|
def test_with_dist_sampler(self, shuffle): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 |
|
|
|
此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler |
|
|
@@ -182,7 +182,7 @@ class TestSetDistReproDataloader: |
|
|
|
""" |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self): |
|
|
|
def test_with_dist_none_reproducible_true(self): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 |
|
|
|
当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 |
|
|
@@ -195,8 +195,9 @@ class TestSetDistReproDataloader: |
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
# @pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): |
|
|
|
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler |
|
|
|
时的表现 |
|
|
@@ -224,7 +225,7 @@ class TestSetDistReproDataloader: |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle): |
|
|
|
def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 |
|
|
|
此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 |
|
|
@@ -256,7 +257,7 @@ class TestSetDistReproDataloader: |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): |
|
|
|
def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 |
|
|
|
此时直接返回原来的 dataloader,不做任何处理。 |
|
|
@@ -274,7 +275,7 @@ class TestSetDistReproDataloader: |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): |
|
|
|
def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler |
|
|
|
的表现 |
|
|
@@ -296,7 +297,7 @@ class TestSetDistReproDataloader: |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): |
|
|
|
def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler |
|
|
|
的表现 |
|
|
@@ -322,7 +323,7 @@ class TestSetDistReproDataloader: |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle): |
|
|
|
def test_with_dist_dist_dataloader_normal(self, shuffle): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 |
|
|
|
此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 |
|
|
@@ -347,7 +348,7 @@ class TestSetDistReproDataloader: |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): |
|
|
|
def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler |
|
|
|
的表现 |
|
|
@@ -373,7 +374,7 @@ class TestSetDistReproDataloader: |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): |
|
|
|
def test_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler |
|
|
|
的表现 |
|
|
@@ -399,7 +400,7 @@ class TestSetDistReproDataloader: |
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle): |
|
|
|
def test_with_dist_unrepeat_dataloader_normal(self, shuffle): |
|
|
|
""" |
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 |
|
|
|
此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 |
|
|
@@ -444,18 +445,11 @@ class TestSetDistReproDataloader: |
|
|
|
else: |
|
|
|
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() |
|
|
|
|
|
|
|
# 加载 num_consumed_samples_array,设置正确取出的 batch 数目 |
|
|
|
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) |
|
|
|
print("array: ", num_consumed_samples_array) |
|
|
|
|
|
|
|
# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range |
|
|
|
left_idxes = set() |
|
|
|
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): |
|
|
|
batch_size = replaced_loader.batch_sampler.batch_size |
|
|
|
if num_consumed_samples_array is not None: |
|
|
|
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] |
|
|
|
else: |
|
|
|
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas |
|
|
|
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas |
|
|
|
# 重新改造 dataloader |
|
|
|
new_loader = DataLoader( |
|
|
|
dataset=replaced_loader.dataset, |
|
|
@@ -474,11 +468,7 @@ class TestSetDistReproDataloader: |
|
|
|
new_loader.batch_sampler.load_state_dict(sampler_states) |
|
|
|
else: |
|
|
|
batch_size = replaced_loader.batch_sampler.batch_size |
|
|
|
num_consumed_samples = num_consumed_batches * batch_size |
|
|
|
if num_consumed_samples_array is not None: |
|
|
|
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples] |
|
|
|
else: |
|
|
|
sampler_states["num_consumed_samples"] = num_consumed_samples * num_replicas |
|
|
|
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas |
|
|
|
# 重新构造 dataloader |
|
|
|
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) |
|
|
|
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) |
|
|
|