|
@@ -117,12 +117,13 @@ class TestSetDistReproDataloader: |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
@magic_argv_env_context |
|
|
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): |
|
|
""" |
|
|
""" |
|
|
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 |
|
|
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 |
|
|
""" |
|
|
""" |
|
|
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) |
|
|
|
|
|
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) |
|
|
|
|
|
|
|
|
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) |
|
|
|
|
|
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) |
|
|
|
|
|
|
|
|
assert not (replaced_loader is dataloader) |
|
|
assert not (replaced_loader is dataloader) |
|
@@ -133,12 +134,13 @@ class TestSetDistReproDataloader: |
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
@magic_argv_env_context |
|
|
def test_set_dist_repro_dataloader_with_dist_sampler(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle): |
|
|
""" |
|
|
""" |
|
|
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 |
|
|
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 |
|
|
""" |
|
|
""" |
|
|
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) |
|
|
|
|
|
sampler = RandomSampler(self.dataset, shuffle=True) |
|
|
|
|
|
|
|
|
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) |
|
|
|
|
|
sampler = RandomSampler(self.dataset, shuffle=shuffle) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) |
|
|
|
|
|
|
|
|
assert not (replaced_loader is dataloader) |
|
|
assert not (replaced_loader is dataloader) |
|
@@ -171,14 +173,15 @@ class TestSetDistReproDataloader: |
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
@magic_argv_env_context |
|
|
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): |
|
|
""" |
|
|
""" |
|
|
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler |
|
|
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler |
|
|
时的表现 |
|
|
时的表现 |
|
|
""" |
|
|
""" |
|
|
dataloader = DataLoader( |
|
|
dataloader = DataLoader( |
|
|
self.dataset, |
|
|
self.dataset, |
|
|
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4), |
|
|
|
|
|
|
|
|
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle), |
|
|
) |
|
|
) |
|
|
dataloader.batch_sampler.set_distributed( |
|
|
dataloader.batch_sampler.set_distributed( |
|
|
num_replicas=self.driver.world_size, |
|
|
num_replicas=self.driver.world_size, |
|
@@ -195,12 +198,13 @@ class TestSetDistReproDataloader: |
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
@magic_argv_env_context |
|
|
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle): |
|
|
""" |
|
|
""" |
|
|
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 |
|
|
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 |
|
|
""" |
|
|
""" |
|
|
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) |
|
|
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) |
|
|
batch_sampler.sampler = RandomSampler(self.dataset, True) |
|
|
|
|
|
|
|
|
batch_sampler.sampler = RandomSampler(self.dataset, shuffle) |
|
|
batch_sampler.sampler.set_distributed( |
|
|
batch_sampler.sampler.set_distributed( |
|
|
num_replicas=self.driver.world_size, |
|
|
num_replicas=self.driver.world_size, |
|
|
rank=self.driver.global_rank |
|
|
rank=self.driver.global_rank |
|
@@ -222,11 +226,12 @@ class TestSetDistReproDataloader: |
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
@magic_argv_env_context |
|
|
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): |
|
|
""" |
|
|
""" |
|
|
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 |
|
|
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 |
|
|
""" |
|
|
""" |
|
|
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) |
|
|
|
|
|
|
|
|
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) |
|
|
|
|
|
|
|
|
assert replaced_loader is dataloader |
|
|
assert replaced_loader is dataloader |
|
@@ -238,14 +243,15 @@ class TestSetDistReproDataloader: |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
@magic_argv_env_context |
|
|
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): |
|
|
""" |
|
|
""" |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler |
|
|
的表现 |
|
|
的表现 |
|
|
""" |
|
|
""" |
|
|
dataloader = DataLoader( |
|
|
dataloader = DataLoader( |
|
|
dataset=self.dataset, |
|
|
dataset=self.dataset, |
|
|
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) |
|
|
|
|
|
|
|
|
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) |
|
|
) |
|
|
) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) |
|
|
|
|
|
|
|
@@ -258,13 +264,14 @@ class TestSetDistReproDataloader: |
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
@magic_argv_env_context |
|
|
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): |
|
|
""" |
|
|
""" |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler |
|
|
的表现 |
|
|
的表现 |
|
|
""" |
|
|
""" |
|
|
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) |
|
|
|
|
|
batch_sampler.sampler = RandomSampler(self.dataset, True) |
|
|
|
|
|
|
|
|
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle) |
|
|
|
|
|
batch_sampler.sampler = RandomSampler(self.dataset, shuffle) |
|
|
dataloader = DataLoader( |
|
|
dataloader = DataLoader( |
|
|
self.dataset, |
|
|
self.dataset, |
|
|
batch_sampler=batch_sampler |
|
|
batch_sampler=batch_sampler |
|
@@ -276,16 +283,17 @@ class TestSetDistReproDataloader: |
|
|
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) |
|
|
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) |
|
|
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) |
|
|
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) |
|
|
assert replaced_loader.batch_sampler.batch_size == 2 |
|
|
assert replaced_loader.batch_sampler.batch_size == 2 |
|
|
assert replaced_loader.batch_sampler.sampler.shuffle == True |
|
|
|
|
|
|
|
|
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle |
|
|
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) |
|
|
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) |
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
@magic_argv_env_context |
|
|
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle): |
|
|
""" |
|
|
""" |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 |
|
|
""" |
|
|
""" |
|
|
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) |
|
|
|
|
|
|
|
|
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) |
|
|
|
|
|
|
|
|
assert not (replaced_loader is dataloader) |
|
|
assert not (replaced_loader is dataloader) |
|
@@ -293,7 +301,7 @@ class TestSetDistReproDataloader: |
|
|
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) |
|
|
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) |
|
|
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) |
|
|
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) |
|
|
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size |
|
|
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size |
|
|
assert replaced_loader.batch_sampler.sampler.shuffle == True |
|
|
|
|
|
|
|
|
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle |
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
@@ -302,13 +310,14 @@ class TestSetDistReproDataloader: |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
@magic_argv_env_context |
|
|
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): |
|
|
""" |
|
|
""" |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler |
|
|
的表现 |
|
|
的表现 |
|
|
""" |
|
|
""" |
|
|
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) |
|
|
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) |
|
|
batch_sampler.sampler = RandomSampler(self.dataset, True) |
|
|
|
|
|
|
|
|
batch_sampler.sampler = RandomSampler(self.dataset, shuffle) |
|
|
dataloader = DataLoader( |
|
|
dataloader = DataLoader( |
|
|
self.dataset, |
|
|
self.dataset, |
|
|
batch_sampler=batch_sampler |
|
|
batch_sampler=batch_sampler |
|
@@ -320,18 +329,19 @@ class TestSetDistReproDataloader: |
|
|
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) |
|
|
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) |
|
|
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) |
|
|
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) |
|
|
assert replaced_loader.batch_sampler.batch_size == 2 |
|
|
assert replaced_loader.batch_sampler.batch_size == 2 |
|
|
assert replaced_loader.batch_sampler.sampler.shuffle == True |
|
|
|
|
|
|
|
|
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle |
|
|
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) |
|
|
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) |
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
@magic_argv_env_context |
|
|
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): |
|
|
""" |
|
|
""" |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler |
|
|
的表现 |
|
|
的表现 |
|
|
""" |
|
|
""" |
|
|
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) |
|
|
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) |
|
|
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True) |
|
|
|
|
|
|
|
|
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle) |
|
|
dataloader = DataLoader( |
|
|
dataloader = DataLoader( |
|
|
self.dataset, |
|
|
self.dataset, |
|
|
batch_sampler=batch_sampler |
|
|
batch_sampler=batch_sampler |
|
@@ -349,11 +359,12 @@ class TestSetDistReproDataloader: |
|
|
dist.barrier() |
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
@magic_argv_env_context |
|
|
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self): |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", ([True, False])) |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle): |
|
|
""" |
|
|
""" |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 |
|
|
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 |
|
|
""" |
|
|
""" |
|
|
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) |
|
|
|
|
|
|
|
|
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) |
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) |
|
|
|
|
|
|
|
|
assert not (replaced_loader is dataloader) |
|
|
assert not (replaced_loader is dataloader) |
|
|