
Improve the test cases for paddle fleet set_dist_repro_dataloader

tags/v1.0.0alpha
x54-729 · 3 years ago · commit 6bfdb39c2f
2 changed files with 10 additions and 15 deletions
  1. tests/core/drivers/paddle_driver/test_fleet.py (+7, -12)
  2. tests/core/drivers/paddle_driver/test_single_device.py (+3, -3)
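Most of the change is in test_fleet.py, whose TestSetDistReproDataloader tests all follow the same shape: build a paddle DataLoader, ask the fleet driver for a distributed, reproducible replacement, then assert on the replaced loader before synchronising the ranks. A condensed sketch of that recurring pattern is below; the set_dist_repro_dataloader(dataloader, dist, reproducible) call and its argument values are an assumption based on the rest of the test file and are not part of this diff.

    # Hypothetical condensed test body (inside a TestSetDistReproDataloader method);
    # `driver` comes from generate_driver(...) and `self.dataset` from the test fixture.
    dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    replaced_loader = driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
    assert replaced_loader.batch_sampler.batch_size == 4
    # the replacement sampler should know its own rank and num_replicas
    self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
    dist.barrier()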

tests/core/drivers/paddle_driver/test_fleet.py (+7, -12)

@@ -1,4 +1,3 @@
-from dataclasses import replace
import pytest
import os

@@ -20,13 +19,14 @@ import paddle
import paddle.distributed as dist
from paddle.io import DataLoader, BatchSampler

-def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False):
+def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
    paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
    paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
    driver = PaddleFleetDriver(
        model=paddle_model,
        parallel_device=device,
        fp16=fp16,
+        output_from_new_proc=output_from_new_proc
    )
    driver.set_optimizers(paddle_opt)
    driver.setup()
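As a quick reference, the helper above could be invoked like this; the values are hypothetical, and only the generate_driver defined in the hunk is used:

    # Hypothetical usage: spin up a two-card fleet driver for a test.
    # "only_error" presumably keeps only the error output of the extra worker processes.
    driver = generate_driver(num_labels=10, feature_dimension=10,
                             device=[0, 1], fp16=False,
                             output_from_new_proc="only_error")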
@@ -292,7 +292,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == 4
        assert replaced_loader.drop_last == dataloader.drop_last
        self.check_distributed_sampler(replaced_loader.batch_sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    @magic_argv_env_context
@@ -319,7 +318,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == 4
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    @magic_argv_env_context
@@ -340,7 +338,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    """
@@ -372,7 +369,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == 4
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    @magic_argv_env_context
@@ -399,7 +395,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == 4
        assert replaced_loader.drop_last == dataloader.drop_last
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    @magic_argv_env_context
@@ -420,7 +415,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == 4
        assert replaced_loader.drop_last == dataloader.drop_last
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    def check_distributed_sampler(self, sampler):
@@ -437,12 +431,14 @@ class TestSetDistReproDataloader:
        Test whether set_dist_repro_dataloader gives the correct result in the multi-card case
        """
        # consume two batches first
+        num_replicas = len(self.device)
        num_consumed_batches = 2
        already_seen_idx = set()
        for idx, batch in enumerate(replaced_loader):
            if idx >= num_consumed_batches:
                break
            already_seen_idx.update(batch)
+        dist.barrier()
        if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
            sampler_states = replaced_loader.batch_sampler.state_dict()
        else:
@@ -450,6 +446,7 @@ class TestSetDistReproDataloader:

        # load num_consumed_samples_array and set the correct number of consumed batches
        num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
+        print("array: ", num_consumed_samples_array)

        # reload; the remaining samples should come out, and for PaddleNormalDataset they should form a range after sorting
        left_idxes = set()
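The num_consumed_samples_array popped above presumably stores the cumulative number of samples after each batch, which is what keeps resumption exact when batch sizes vary (e.g. with BucketedBatchSampler); the plain num_consumed_batches * batch_size fallback only works for fixed-size batches. A tiny illustration of that assumed layout:

    # Assumed semantics: index i -> samples consumed after i batches.
    # With bucketed batches of sizes 3 and 5, resuming after 2 batches
    # should restore 8 consumed samples, not 2 * batch_size.
    num_consumed_samples_array = [0, 3, 8]
    num_consumed_batches = 2
    assert num_consumed_samples_array[num_consumed_batches] == 8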
@@ -458,7 +455,7 @@ class TestSetDistReproDataloader:
            if num_consumed_samples_array is not None:
                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
            else:
-                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
+                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
            # rebuild the dataloader
            new_loader = DataLoader(
                dataset=replaced_loader.dataset,
@@ -481,7 +478,7 @@ class TestSetDistReproDataloader:
            if num_consumed_samples_array is not None:
                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples]
            else:
-                sampler_states["num_consumed_samples"] = num_consumed_samples
+                sampler_states["num_consumed_samples"] = num_consumed_samples * num_replicas
            # reconstruct the dataloader
            batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
            batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
@@ -494,10 +491,8 @@ class TestSetDistReproDataloader:
        for idx, batch in enumerate(new_loader):
            left_idxes.update(batch)

-        num_replicas = len(self.device)
        assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas
        assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas
-        assert False


############################################################################
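Taken together, the hunks above turn check_set_dist_repro_dataloader into a save-and-resume round trip: consume two batches, snapshot the sampler state, scale num_consumed_samples by the number of replicas (each card only sees its own shard, so the global count is num_replicas times larger), rebuild a DataLoader from that state, and check that the already seen and remaining indices cover this rank's share of the dataset without overlap. Below is a standalone sketch of that flow, starting from a replaced_loader produced by the driver and assuming the set_distributed(num_replicas, rank, pad) hook that fastNLP's reproducible samplers expose, which is not shown in this diff:

    # Sketch only; sizes are illustrative and `replaced_loader` comes from the driver.
    num_replicas, batch_size, num_consumed_batches = 2, 4, 2
    already_seen_idx = set()
    for idx, batch in enumerate(replaced_loader):
        if idx >= num_consumed_batches:
            break
        already_seen_idx.update(batch)
    states = replaced_loader.batch_sampler.sampler.state_dict()
    # each replica consumed num_consumed_batches * batch_size samples locally,
    # so globally num_replicas times as many samples are already used up
    states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
    batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=True, batch_size=batch_size)
    batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=True)
    batch_sampler.sampler.load_state_dict(states)
    batch_sampler.sampler.set_distributed(num_replicas=num_replicas, rank=dist.get_rank(), pad=True)
    left_idxes = set()
    for idx, batch in enumerate(DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)):
        left_idxes.update(batch)
    # seen and remaining indices are disjoint and jointly cover this rank's shard
    assert len(left_idxes) + len(already_seen_idx) == len(replaced_loader.dataset) / num_replicas
    assert len(left_idxes | already_seen_idx) == len(replaced_loader.dataset) / num_replicas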


tests/core/drivers/paddle_driver/test_single_device.py (+3, -3)

@@ -513,11 +513,11 @@ class TestSetDistReproDataloder:
            new_loader.batch_sampler.load_state_dict(sampler_states)
        else:
            batch_size = replaced_loader.batch_sampler.batch_size
-            num_consumed_batches = num_consumed_batches * batch_size
+            num_consumed_samples = num_consumed_batches * batch_size
            if num_consumed_samples_array is not None:
-                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
+                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples]
            else:
-                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
+                sampler_states["num_consumed_samples"] = num_consumed_samples
            # reconstruct the dataloader
            batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
            batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
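The test_single_device.py hunk fixes a double multiplication: the old code rebound num_consumed_batches to num_consumed_batches * batch_size and then multiplied by batch_size again when filling in num_consumed_samples. A quick worked example of the difference, with illustrative values:

    batch_size, num_consumed_batches = 4, 2
    # old code: num_consumed_batches was first rebound to 2 * 4 = 8,
    # so num_consumed_samples ended up as 8 * 4 = 32, four times too large
    # new code: the sample count is computed once and kept in its own variable
    num_consumed_samples = num_consumed_batches * batch_size
    assert num_consumed_samples == 8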

