From 6bfdb39c2f3db859bb980e7a4b7a5685d855ba72 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 06:42:29 +0000 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84paddle=20fleet=20set=5Fdist?= =?UTF-8?q?=5Frepro=5Fdataloader=E7=9A=84=E6=B5=8B=E8=AF=95=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/drivers/paddle_driver/test_fleet.py | 19 +++++++------------ .../paddle_driver/test_single_device.py | 6 +++--- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index 5fe52c54..125a1c43 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -1,4 +1,3 @@ -from dataclasses import replace import pytest import os @@ -20,13 +19,14 @@ import paddle import paddle.distributed as dist from paddle.io import DataLoader, BatchSampler -def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False): +def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) driver = PaddleFleetDriver( model=paddle_model, parallel_device=device, fp16=fp16, + output_from_new_proc=output_from_new_proc ) driver.set_optimizers(paddle_opt) driver.setup() @@ -292,7 +292,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -319,7 +318,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -340,7 +338,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() """ @@ -372,7 +369,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -399,7 +395,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -420,7 +415,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() def check_distributed_sampler(self, sampler): @@ -437,12 +431,14 @@ class TestSetDistReproDataloader: 测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 """ # 迭代两个 batch + num_replicas = len(self.device) num_consumed_batches = 2 already_seen_idx = set() for idx, batch in enumerate(replaced_loader): if idx >= num_consumed_batches: break already_seen_idx.update(batch) + dist.barrier() if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() else: @@ -450,6 +446,7 @@ class TestSetDistReproDataloader: # 加载 num_consumed_samples_array,设置正确取出的 batch 数目 num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) + print("array: ", num_consumed_samples_array) # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range left_idxes = set() @@ -458,7 +455,7 @@ class TestSetDistReproDataloader: if num_consumed_samples_array is not None: sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] else: - sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas # 重新改造 dataloader new_loader = DataLoader( dataset=replaced_loader.dataset, @@ -481,7 +478,7 @@ class TestSetDistReproDataloader: if num_consumed_samples_array is not None: sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples] else: - sampler_states["num_consumed_samples"] = num_consumed_samples + sampler_states["num_consumed_samples"] = num_consumed_samples * num_replicas # 重新构造 dataloader batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) @@ -494,10 +491,8 @@ class TestSetDistReproDataloader: for idx, batch in enumerate(new_loader): left_idxes.update(batch) - num_replicas = len(self.device) assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas - assert False ############################################################################ diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 92c55434..1c9a8241 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -513,11 +513,11 @@ class TestSetDistReproDataloder: new_loader.batch_sampler.load_state_dict(sampler_states) else: batch_size = replaced_loader.batch_sampler.batch_size - num_consumed_batches = num_consumed_batches * batch_size + num_consumed_samples = num_consumed_batches * batch_size if num_consumed_samples_array is not None: - sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] + sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples] else: - sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + sampler_states["num_consumed_samples"] = num_consumed_samples # 重新构造 dataloader batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)