|
|
@@ -224,7 +224,6 @@ class TestSetDistReproDataloder: |
|
|
|
""" |
|
|
|
def setup_method(self):
    """
    Prepare the objects stored on the test instance before each test method.

    Sets ``self.dataset``, ``self.dataloader`` (batch size 2, shuffled) and
    ``self.driver`` (a ``PaddleSingleDriver`` created with ``device="cpu"``).
    """
    dataset = PaddleNormalDataset(20)
    self.dataset = dataset
    self.dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    # The model is only needed to construct the driver, so it is built inline.
    self.driver = PaddleSingleDriver(
        PaddleNormalModel_Classification_1(10, 32),
        device="cpu",
    )
|
|
|
|
|
|
@@ -233,55 +232,59 @@ class TestSetDistReproDataloder: |
|
|
|
测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 |
|
|
|
当dist为字符串时,此时应该返回原来的 dataloader |
|
|
|
""" |
|
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(self.dataloader, dist="dist", reproducible=False) |
|
|
|
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) |
|
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) |
|
|
|
|
|
|
|
assert replaced_loader is self.dataloader |
|
|
|
assert replaced_loader is dataloader |
|
|
|
|
|
|
|
def test_set_dist_repro_dataloader_with_reproducible_true(self):
    """
    Check ``set_dist_repro_dataloader`` when ``reproducible`` is True.

    With ``dist`` given as a string, a new dataloader should be returned and
    its ``batch_sampler`` should be a ``RandomBatchSampler``.
    """
    # Use a dataloader local to this test rather than one shared through the
    # instance, so the comparison target is unambiguous.
    dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)

    assert replaced_loader is not dataloader
    assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
    # The reproducible sampler is expected to wrap an ordinary BatchSampler.
    assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
    assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
    assert replaced_loader.drop_last == dataloader.drop_last

    # NOTE(review): the resume check is disabled for this case in the original
    # code; the reason is not visible here, so it stays disabled.
    # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
|
|
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
    """
    Check ``set_dist_repro_dataloader`` when ``dist`` is not a string but a
    ReproducibleBatchSampler.

    A new dataloader should be returned, with its ``batch_sampler`` replaced
    by the sampler passed as ``dist``.
    """
    dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
    dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

    assert replaced_loader is not dataloader
    assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
    # The very object passed as ``dist`` must be installed, not a copy of it.
    assert replaced_loader.batch_sampler is dist

    self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
|
|
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_sampler(self):
    """
    Check ``set_dist_repro_dataloader`` when ``dist`` is not a string.

    A new dataloader should be returned, with ``batch_sampler.sampler``
    replaced by the sampler passed as ``dist``.
    """
    dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
    dist = RandomSampler(self.dataset, shuffle=True)
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

    assert replaced_loader is not dataloader
    assert isinstance(replaced_loader.batch_sampler, BatchSampler)
    assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
    # A new batch sampler is expected; only its inner sampler is ``dist``.
    assert replaced_loader.batch_sampler is not dataloader.batch_sampler
    assert replaced_loader.batch_sampler.sampler is dist
    assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size

    self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
|
|
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self): |
|
|
|
""" |
|
|
@@ -295,11 +298,12 @@ class TestSetDistReproDataloder: |
|
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) |
|
|
|
|
|
|
|
assert not (replaced_loader is dataloader) |
|
|
|
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) |
|
|
|
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) |
|
|
|
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size |
|
|
|
assert replaced_loader.drop_last == dataloader.drop_last |
|
|
|
|
|
|
|
# self.check_set_dist_repro_dataloader(dataloader, replaced_loader) |
|
|
|
self.check_set_dist_repro_dataloader(dataloader, replaced_loader) |
|
|
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self): |
|
|
|
""" |
|
|
@@ -316,34 +320,52 @@ class TestSetDistReproDataloder: |
|
|
|
|
|
|
|
assert not (replaced_loader is dataloader) |
|
|
|
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) |
|
|
|
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) |
|
|
|
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) |
|
|
|
assert replaced_loader.batch_sampler.batch_size == 2 |
|
|
|
assert replaced_loader.batch_sampler.sampler.shuffle == True |
|
|
|
|
|
|
|
# self.check_set_dist_repro_dataloader(dataloader, replaced_loader) |
|
|
|
self.check_set_dist_repro_dataloader(dataloader, replaced_loader) |
|
|
|
|
|
|
|
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader): |
|
|
|
""" |
|
|
|
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 |
|
|
|
""" |
|
|
|
# 迭代两个 batch |
|
|
|
# 这里会发生 BatchSampler 里 yield 了多次但 dataloader 只取出一次的情况。 |
|
|
|
num_consumed_batches = 2 |
|
|
|
already_seen_idx = set() |
|
|
|
for idx, batch in replaced_loader: |
|
|
|
already_seen_idx.update(batch) |
|
|
|
if idx >= 1: |
|
|
|
for idx, batch in enumerate(replaced_loader): |
|
|
|
if idx >= num_consumed_batches: |
|
|
|
break |
|
|
|
already_seen_idx.update(batch) |
|
|
|
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): |
|
|
|
sampler_states = replaced_loader.batch_sampler.state_dict() |
|
|
|
else: |
|
|
|
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() |
|
|
|
print(sampler_states["data_idx"]) |
|
|
|
|
|
|
|
# 加载 num_consumed_samples_array,设置正确取出的 batch 数目 |
|
|
|
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) |
|
|
|
|
|
|
|
import time |
|
|
|
time.sleep(5) |
|
|
|
|
|
|
|
# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range |
|
|
|
left_idxes = set() |
|
|
|
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): |
|
|
|
batch_size = replaced_loader.batch_sampler.batch_size |
|
|
|
if num_consumed_samples_array is not None: |
|
|
|
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] |
|
|
|
else: |
|
|
|
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size |
|
|
|
replaced_loader.batch_sampler.load_state_dict(sampler_states) |
|
|
|
else: |
|
|
|
batch_size = replaced_loader.batch_sampler.batch_size |
|
|
|
if num_consumed_samples_array is not None: |
|
|
|
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] |
|
|
|
else: |
|
|
|
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size |
|
|
|
replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states) |
|
|
|
replaced_loader.batch_sampler.sampler.set_epoch(0) |
|
|
|
for idx, batch in enumerate(replaced_loader): |
|
|
|
left_idxes.update(batch) |
|
|
|
|
|
|
|