@@ -1,4 +1,3 @@
from dataclasses import replace
import pytest
import os
@@ -20,13 +19,14 @@ import paddle
import paddle.distributed as dist
from paddle.io import DataLoader, BatchSampler

-def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False):
+def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
    paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
    paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
    driver = PaddleFleetDriver(
        model=paddle_model,
        parallel_device=device,
        fp16=fp16,
+        output_from_new_proc=output_from_new_proc
    )
    driver.set_optimizers(paddle_opt)
    driver.setup()
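
# A minimal usage sketch with illustrative values (10 labels / 10 features are not
# taken from the tests, and the helper is assumed to return the driver it builds,
# which happens outside this hunk):
#     driver = generate_driver(num_labels=10, feature_dimension=10, device=[0, 1])
# Each spawned worker process builds its driver once this way and then reuses it.
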
@@ -292,7 +292,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == 4
        assert replaced_loader.drop_last == dataloader.drop_last
        self.check_distributed_sampler(replaced_loader.batch_sampler)
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    @magic_argv_env_context
@@ -319,7 +318,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == 4
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    @magic_argv_env_context
@@ -340,7 +338,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    """
@@ -372,7 +369,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == 4
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    @magic_argv_env_context
@@ -399,7 +395,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == 4
        assert replaced_loader.drop_last == dataloader.drop_last
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    @magic_argv_env_context
@@ -420,7 +415,6 @@ class TestSetDistReproDataloader:
        assert replaced_loader.batch_sampler.batch_size == 4
        assert replaced_loader.drop_last == dataloader.drop_last
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
        dist.barrier()

    def check_distributed_sampler(self, sampler):
@@ -437,12 +431,14 @@ class TestSetDistReproDataloader:
        Check that set_dist_repro_dataloader produces the correct result in the multi-GPU case.
        """
        # iterate over two batches
        num_replicas = len(self.device)
        num_consumed_batches = 2
        already_seen_idx = set()
        for idx, batch in enumerate(replaced_loader):
            if idx >= num_consumed_batches:
                break
            already_seen_idx.update(batch)
        dist.barrier()
        if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
            sampler_states = replaced_loader.batch_sampler.state_dict()
        else:
@@ -450,6 +446,7 @@ class TestSetDistReproDataloader:
        # load num_consumed_samples_array and set the correct number of consumed batches
        num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
        print("array: ", num_consumed_samples_array)
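        # when present, num_consumed_samples_array maps "number of batches fetched"
        # to the cumulative number of samples consumed, which is why it is indexed
        # with num_consumed_batches below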

        # reload: the remaining samples should be produced, and for PaddleNormalDataset the sorted indices should form a range
        left_idxes = set()
@@ -458,7 +455,7 @@ class TestSetDistReproDataloader:
        if num_consumed_samples_array is not None:
            sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
        else:
-            sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
+            sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
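        # when the sampler did not record the array, the consumed count is
        # reconstructed as batches * batch_size * number of replicas, i.e. every
        # rank is assumed to have drawn batch_size samples per step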
        # rebuild the dataloader
        new_loader = DataLoader(
            dataset=replaced_loader.dataset,
@@ -481,7 +478,7 @@ class TestSetDistReproDataloader:
        if num_consumed_samples_array is not None:
            sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples]
        else:
-            sampler_states["num_consumed_samples"] = num_consumed_samples
+            sampler_states["num_consumed_samples"] = num_consumed_samples * num_replicas
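        # same fallback as above, except that num_consumed_samples already counts
        # samples rather than batches, so only the replica factor is applied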
        # reconstruct the dataloader
        batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
        batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
@@ -494,10 +491,8 @@ class TestSetDistReproDataloader:
        for idx, batch in enumerate(new_loader):
            left_idxes.update(batch)

        num_replicas = len(self.device)
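        # taken together, the two assertions below say that the indices seen before
        # the restore and those produced afterwards are disjoint and jointly cover
        # this rank's shard of len(self.dataset) / num_replicas samples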
        assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas
        assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas
        assert False


############################################################################