From 26c80d620cf02d9dedb92707e77442f94267135b Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 15 Apr 2022 16:01:21 +0000 Subject: [PATCH 01/10] =?UTF-8?q?paddle=E5=8D=95=E5=8D=A1=E5=8A=A0?= =?UTF-8?q?=E8=BD=BDfp16=E7=9A=84=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/paddle_driver/paddle_driver.py | 36 ++++++++++++---- fastNLP/core/drivers/paddle_driver/utils.py | 7 ++-- .../paddle_driver/test_single_device.py | 42 ++++++++++++------- 3 files changed, 59 insertions(+), 26 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 3b8ad7d8..75e0352f 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -7,7 +7,7 @@ from dataclasses import dataclass import numpy as np -from .utils import _build_fp16_env, optimizer_state_to_device +from .utils import _build_fp16_env, optimizer_state_to_device, DummyGradScaler from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.core.drivers.driver import Driver from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device @@ -278,6 +278,12 @@ class PaddleDriver(Driver): logger.debug("Save optimizer state dict.") states["optimizers_state_dict"] = optimizers_state_dict + + # 4.保存fp16的状态 + if not isinstance(self.grad_scaler, DummyGradScaler): + grad_scaler_state_dict = self.grad_scaler.state_dict() + states['grad_scaler_state_dict'] = grad_scaler_state_dict + paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: @@ -285,7 +291,7 @@ class PaddleDriver(Driver): states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) # 1. 加载 optimizers 的状态; - optimizers_state_dict = states["optimizers_state_dict"] + optimizers_state_dict = states.pop("optimizers_state_dict") for i in range(len(self.optimizers)): optimizer: Optimizer = self.optimizers[i] optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"]) @@ -295,18 +301,32 @@ class PaddleDriver(Driver): if should_load_model: self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict) if only_state_dict: - logger.debug("Load model state dict.") + logger.debug("Load model state dict...") else: - logger.debug("Load model.") - - # 3. 恢复 sampler 的状态; + logger.debug("Load model...") + + # 3. 加载fp16的状态; + if "grad_scaler_state_dict" in states: + grad_scaler_state_dict = states.pop("grad_scaler_state_dict") + if isinstance(self.grad_scaler, DummyGradScaler): + self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False) + self.grad_scaler = _grad_scaler() + self.fp16 = True + self.grad_scaler.load_state_dict(grad_scaler_state_dict) + logger.debug("Load grad_scaler state dict...") + elif not isinstance(self.grad_scaler, DummyGradScaler): + logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " + f"the training process may be unstable.") + + # 4. 恢复 sampler 的状态; dataloader_args = self.get_dataloader_args(dataloader) if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): sampler = dataloader_args.batch_sampler elif isinstance(dataloader_args.sampler, ReproducibleSampler): sampler = dataloader_args.sampler elif self.is_distributed(): - raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") + raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " + "`ReproducibleSampler`.") else: sampler = RandomBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, @@ -316,7 +336,7 @@ class PaddleDriver(Driver): sampler.load_state_dict(states["sampler_states"]) states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) - # 4. 修改 trainer_state.batch_idx_in_epoch + # 5. 修改 trainer_state.batch_idx_in_epoch # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; if not isinstance(sampler, ReproducibleBatchSampler): if dataloader_args.drop_last: diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index feb5c3eb..48598a34 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -19,7 +19,7 @@ if _NEED_IMPORT_PADDLE: import paddle from paddle import nn from paddle.nn import Layer - from paddle.io import DataLoader, BatchSampler, Dataset + from paddle.io import DataLoader, BatchSampler from paddle.amp import auto_cast, GradScaler else: from fastNLP.core.utils.dummy_class import DummyClass as Layer @@ -140,8 +140,7 @@ class DummyGradScaler: def _build_fp16_env(dummy=False): if dummy: - auto_cast = ExitStack - GradScaler = DummyGradScaler + return ExitStack, DummyGradScaler else: if not paddle.device.is_compiled_with_cuda(): raise RuntimeError("No cuda") @@ -150,7 +149,7 @@ def _build_fp16_env(dummy=False): "NOTE: your device does NOT support faster training with fp16, " "please switch to FP32 which is likely to be faster" ) - return auto_cast, GradScaler + return auto_cast, GradScaler def find_free_ports(num): def __free_port(): diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 79527f39..12f52537 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -1,4 +1,3 @@ -from dataclasses import replace import os from re import S os.environ["FASTNLP_BACKEND"] = "paddle" @@ -536,13 +535,13 @@ class TestSetDistReproDataloder: # ############################################################################ -def generate_random_driver(features, labels): +def generate_random_driver(features, labels, fp16, device="cpu"): """ 生成driver """ model = PaddleNormalModel_Classification_1(labels, features) opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01) - driver = PaddleSingleDriver(model, device="cpu") + driver = PaddleSingleDriver(model, device=device, fp16=fp16) driver.set_optimizers(opt) driver.setup() @@ -584,21 +583,23 @@ def test_save_and_load_model(prepare_test_save_load, only_state_dict): synchronize_safe_rm(path + ".pdiparams.info") synchronize_safe_rm(path + ".pdmodel") -@pytest.mark.parametrize("only_state_dict", ([True, False])) -def test_save_and_load_with_randombatchsampler(only_state_dict): +# @pytest.mark.parametrize("only_state_dict", ([True, False])) +@pytest.mark.parametrize("only_state_dict", ([True])) +@pytest.mark.parametrize("fp16", ([True, False])) +def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): """ 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 """ try: path = "model.ckp" - - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) dataset = PaddleRandomMaxDataset(40, 10) dataloader = DataLoader( dataset=dataset, batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) ) + driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") + num_consumed_batches = 2 already_seen_x_set = set() @@ -633,8 +634,13 @@ def test_save_and_load_with_randombatchsampler(only_state_dict): assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 - # 3. 检查 model 的参数是否正确 - # 4. 检查 batch_idx + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx start_batch = load_states.pop('batch_idx_in_epoch') assert start_batch == 2 * num_consumed_batches left_x_batches = set() @@ -654,8 +660,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict): finally: synchronize_safe_rm(path) -@pytest.mark.parametrize("only_state_dict", ([True, False])) -def test_save_and_load_with_randomsampler(only_state_dict): +# @pytest.mark.parametrize("only_state_dict", ([True, False])) +# TODO 在有迭代且使用了paddle.jit.save的时候会引发段错误,注释掉任意一段都不会出错 +# 但无法在单独的文件中复现 +@pytest.mark.parametrize("only_state_dict", ([True])) +@pytest.mark.parametrize("fp16", ([True, False])) +def test_save_and_load_with_randomsampler(only_state_dict, fp16): """ 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 """ @@ -663,7 +673,7 @@ def test_save_and_load_with_randomsampler(only_state_dict): try: path = "model.ckp" - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) + driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") dataset = PaddleRandomMaxDataset(40, 10) batch_sampler = BatchSampler(dataset=dataset, batch_size=4) batch_sampler.sampler = RandomSampler(dataset, True) @@ -711,8 +721,12 @@ def test_save_and_load_with_randomsampler(only_state_dict): assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] - # 3. 检查 model 的参数是否正确 - # 4. 检查 batch_idx + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx start_batch = load_states.pop('batch_idx_in_epoch') assert start_batch == 2 * num_consumed_batches left_x_batches = set() From 32d8e2747267b39d857c060bad46a20bb9d7ead3 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 15 Apr 2022 16:10:10 +0000 Subject: [PATCH 02/10] small --- tests/core/drivers/paddle_driver/test_single_device.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 12f52537..92c55434 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -535,7 +535,7 @@ class TestSetDistReproDataloder: # ############################################################################ -def generate_random_driver(features, labels, fp16, device="cpu"): +def generate_random_driver(features, labels, fp16=False, device="cpu"): """ 生成driver """ @@ -549,8 +549,8 @@ def generate_random_driver(features, labels, fp16, device="cpu"): @pytest.fixture def prepare_test_save_load(): - dataset = PaddleRandomMaxDataset(320, 10) - dataloader = DataLoader(dataset, batch_size=32) + dataset = PaddleRandomMaxDataset(40, 10) + dataloader = DataLoader(dataset, batch_size=4) driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) return driver1, driver2, dataloader From de544707d97e8bb18797fa936ea093656c2604e0 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 05:40:23 +0000 Subject: [PATCH 03/10] =?UTF-8?q?paddle=20fleet=20set=5Fdist=5Frepro=5Fdat?= =?UTF-8?q?aloader=E7=9A=84=E6=B5=8B=E8=AF=95=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/paddle_driver/fleet.py | 4 +- .../core/drivers/paddle_driver/test_fleet.py | 360 +++++++++++++++++- 2 files changed, 347 insertions(+), 17 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 762b3114..30d3b701 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -1,6 +1,5 @@ import os import shutil -from functools import partial from typing import List, Union, Optional, Dict, Tuple, Callable from .paddle_driver import PaddleDriver @@ -38,7 +37,6 @@ if _NEED_IMPORT_PADDLE: from paddle import DataParallel import paddle.distributed.fleet as fleet import paddle.distributed as paddledist - from paddle.io import BatchSampler from paddle.optimizer import Optimizer from paddle.fluid.reader import _DatasetKind from paddle.fluid.dygraph import parallel_helper @@ -305,7 +303,7 @@ class PaddleFleetDriver(PaddleDriver): raise RuntimeError(f"There is no `{fn}` method in your model.") else: if hasattr(model, fn): - logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements " + logger.warning("Notice your model is a `DataParallel` model. And your model also implements " f"the `{fn}` method, which we can not call actually, we will" " call `forward` function instead of `train_step` and you should note that.") elif fn not in {"train_step", "evaluate_step"}: diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index de98f9c5..ea279292 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -12,28 +12,44 @@ from fastNLP.core.samplers import ( UnrepeatedSequentialSampler, ) from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 -from tests.helpers.datasets.paddle_data import PaddleNormalDataset +from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.utils import magic_argv_env_context +from fastNLP.core import synchronize_safe_rm import paddle import paddle.distributed as dist from paddle.io import DataLoader, BatchSampler -def generate_driver(num_labels, feature_dimension): +def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False): paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) driver = PaddleFleetDriver( model=paddle_model, - parallel_device=[0,1], + parallel_device=device, + fp16=fp16, ) driver.set_optimizers(paddle_opt) driver.setup() return driver +@magic_argv_env_context +def test_multi_drivers(): + """ + 测试使用了多个 PaddleFleetDriver 的情况。 + """ + driver1 = generate_driver(10, 10) + driver2 = generate_driver(20, 10) + + with pytest.raises(RuntimeError): + # 设备设置不同,应该报错 + driver3 = generate_driver(20, 3, device=[0,2]) + + dist.barrier() + ############################################################################ # -# 测试PaddleFleetDriver的一些函数 +# 测试 PaddleFleetDriver 的一些函数 # ############################################################################ @@ -106,10 +122,11 @@ class TestSetDistReproDataloader: @classmethod def setup_class(cls): - cls.driver = generate_driver(10, 10) + cls.device = [0, 1] + cls.driver = generate_driver(10, 10, device=cls.device) def setup_method(self): - self.dataset = PaddleNormalDataset(20) + self.dataset = PaddleNormalDataset(40) """ 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 @@ -121,6 +138,7 @@ class TestSetDistReproDataloader: def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 + 此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) @@ -130,6 +148,7 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert replaced_loader.batch_sampler is batch_sampler self.check_distributed_sampler(replaced_loader.batch_sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @@ -138,6 +157,7 @@ class TestSetDistReproDataloader: def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 + 此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) sampler = RandomSampler(self.dataset, shuffle=shuffle) @@ -150,6 +170,7 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler is sampler assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @@ -164,6 +185,7 @@ class TestSetDistReproDataloader: def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 + 当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) with pytest.raises(RuntimeError): @@ -178,6 +200,8 @@ class TestSetDistReproDataloader: """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler 时的表现 + 此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler + 和原 dataloader 相同 """ dataloader = DataLoader( self.dataset, @@ -194,6 +218,7 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert replaced_loader.batch_sampler.batch_size == 4 self.check_distributed_sampler(dataloader.batch_sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @@ -202,8 +227,10 @@ class TestSetDistReproDataloader: def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 + 此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 + batch_sampler.sampler 和原 dataloader 相同 """ - batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) batch_sampler.sampler = RandomSampler(self.dataset, shuffle) batch_sampler.sampler.set_distributed( num_replicas=self.driver.world_size, @@ -220,9 +247,11 @@ class TestSetDistReproDataloader: assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) - assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.drop_last == False self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + dist.barrier() @magic_argv_env_context @@ -230,6 +259,7 @@ class TestSetDistReproDataloader: def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 + 此时直接返回原来的 dataloader,不做任何处理。 """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) @@ -248,6 +278,7 @@ class TestSetDistReproDataloader: """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler 的表现 + 此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性 """ dataloader = DataLoader( dataset=self.dataset, @@ -261,6 +292,7 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -269,8 +301,10 @@ class TestSetDistReproDataloader: """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 的表现 + 此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关 + 的属性 """ - batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4, shuffle=shuffle) batch_sampler.sampler = RandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, @@ -282,9 +316,10 @@ class TestSetDistReproDataloader: assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) - assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -292,6 +327,8 @@ class TestSetDistReproDataloader: def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 + 此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 + 的属性 """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) @@ -302,6 +339,8 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.batch_sampler.sampler.shuffle == shuffle + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() """ @@ -315,8 +354,10 @@ class TestSetDistReproDataloader: """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 的表现 + 此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关 + 的属性 """ - batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) batch_sampler.sampler = RandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, @@ -328,9 +369,10 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler, BatchSampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) - assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -339,8 +381,9 @@ class TestSetDistReproDataloader: """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler 的表现 + 此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler """ - batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, @@ -353,9 +396,10 @@ class TestSetDistReproDataloader: assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) - assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -363,6 +407,8 @@ class TestSetDistReproDataloader: def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 + 此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 + 的属性 """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) @@ -374,6 +420,7 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() def check_distributed_sampler(self, sampler): @@ -385,3 +432,288 @@ class TestSetDistReproDataloader: if not isinstance(sampler, UnrepeatedSampler): assert sampler.pad == True + def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): + """ + 测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 + """ + # 迭代两个 batch + num_consumed_batches = 2 + already_seen_idx = set() + for idx, batch in enumerate(replaced_loader): + if idx >= num_consumed_batches: + break + already_seen_idx.update(batch) + if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): + sampler_states = replaced_loader.batch_sampler.state_dict() + else: + sampler_states = replaced_loader.batch_sampler.sampler.state_dict() + + # 加载 num_consumed_samples_array,设置正确取出的 batch 数目 + num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) + + # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range + left_idxes = set() + if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): + batch_size = replaced_loader.batch_sampler.batch_size + if num_consumed_samples_array is not None: + sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] + else: + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + # 重新改造 dataloader + new_loader = DataLoader( + dataset=replaced_loader.dataset, + batch_sampler=BucketedBatchSampler( + replaced_loader.dataset, + length=replaced_loader.dataset._data, + batch_size=batch_size, + shuffle=shuffle, + ) + ) + new_loader.batch_sampler.set_distributed( + num_replicas=self.driver.world_size, + rank=self.driver.global_rank, + pad=True + ) + new_loader.batch_sampler.load_state_dict(sampler_states) + else: + batch_size = replaced_loader.batch_sampler.batch_size + num_consumed_samples = num_consumed_batches * batch_size + if num_consumed_samples_array is not None: + sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples] + else: + sampler_states["num_consumed_samples"] = num_consumed_samples + # 重新构造 dataloader + batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) + batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) + batch_sampler.sampler.set_distributed( + num_replicas=self.driver.world_size, + rank=self.driver.global_rank + ) + new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) + new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + for idx, batch in enumerate(new_loader): + left_idxes.update(batch) + + num_replicas = len(self.device) + assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas + assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas + assert False + + +############################################################################ +# +# 测试 save 和 load 相关的功能 +# +############################################################################ +class TestSaveLoad: + """ + 测试多卡情况下 save 和 load 相关函数的表现 + """ + def setup_method(self): + self.dataset = PaddleRandomMaxDataset(20, 10) + self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) + + @magic_argv_env_context + @pytest.mark.parametrize("only_state_dict", ([True, False])) + def test_save_and_load_model(self, only_state_dict): + """ + 测试 save_model 和 load_model 函数 + """ + try: + path = "model" + dataloader = DataLoader(self.dataset, batch_size=2) + + if only_state_dict: + self.driver1.save_model(path, only_state_dict) + else: + self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))]) + + # 同步 + dist.barrier() + self.driver2.load_model(path, only_state_dict) + + for idx, batch in enumerate(dataloader): + batch = self.driver1.move_data_to_device(batch) + res1 = self.driver1.model( + batch, + fastnlp_fn=self.driver1.model._layers.model.evaluate_step, + # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = self.driver2.model( + batch, + fastnlp_fn=self.driver2.model._layers.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + + assert paddle.equal_all(res1["pred"], res2["pred"]) + finally: + if only_state_dict: + synchronize_safe_rm(path) + else: + synchronize_safe_rm(path + ".pdiparams") + synchronize_safe_rm(path + ".pdiparams.info") + synchronize_safe_rm(path + ".pdmodel") + + @magic_argv_env_context + @pytest.mark.parametrize("only_state_dict", ([True, False])) + @pytest.mark.parametrize("fp16", ([True, False])) + def test_save_and_load_with_randombatchsampler(self, only_state_dict, fp16): + return + """ + 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 + """ + + try: + path = "model.ckp" + + driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) + dataset = PaddleRandomMaxDataset(40, 10) + dataloader = DataLoader( + dataset=dataset, + batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) + ) + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"]) + already_seen_y_set.update(batch["y"]) + + sampler_states = dataloader.batch_sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + if only_state_dict: + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + else: + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + # 加载 + # 更改 batch_size + dataloader = DataLoader( + dataset=dataset, + batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) + ) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 batch_sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert replaced_loader.batch_sampler is dataloader.batch_sampler + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] + assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 + + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + for idx, batch in enumerate(replaced_loader): + + left_x_batches.update(batch["x"]) + left_y_batches.update(batch["y"]) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + assert paddle.equal_all(res1["pred"], res2["pred"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) + assert len(left_x_batches | already_seen_x_set) == len(dataset) + assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) + assert len(left_y_batches | already_seen_y_set) == len(dataset) + finally: + synchronize_safe_rm(path) + + @magic_argv_env_context + @pytest.mark.parametrize("only_state_dict", ([True, False])) + @pytest.mark.parametrize("fp16", ([True, False])) + def test_save_and_load_with_randomsampler(self, only_state_dict, fp16): + return + """ + 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 + """ + + try: + path = "model.ckp" + + driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) + dataset = PaddleRandomMaxDataset(40, 10) + batch_sampler = BatchSampler(dataset=dataset, batch_size=4) + batch_sampler.sampler = RandomSampler(dataset, True) + dataloader = DataLoader( + dataset, + batch_sampler=batch_sampler + ) + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"]) + already_seen_y_set.update(batch["y"]) + + sampler_states = dataloader.batch_sampler.sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + if only_state_dict: + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + else: + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + + # 加载 + # 更改 batch_size + batch_sampler = BatchSampler(dataset=dataset, batch_size=2) + batch_sampler.sampler = RandomSampler(dataset, True) + dataloader = DataLoader( + dataset, + batch_sampler=batch_sampler + ) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches + assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] + assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + for idx, batch in enumerate(replaced_loader): + + left_x_batches.update(batch["x"]) + left_y_batches.update(batch["y"]) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + assert paddle.equal_all(res1["pred"], res2["pred"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) + assert len(left_x_batches | already_seen_x_set) == len(dataset) + assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) + assert len(left_y_batches | already_seen_y_set) == len(dataset) + finally: + synchronize_safe_rm(path) From fcd27cfc3f88ec4a6154d12d3bb0d8bdb3c44ac5 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 05:50:53 +0000 Subject: [PATCH 04/10] =?UTF-8?q?=E6=B7=BB=E5=8A=A0FASTNLP=5FNO=5FSYNC?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E7=9A=84=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/drivers/paddle_driver/dist_utils.py | 22 +++++++++++++++++++ fastNLP/core/drivers/paddle_driver/fleet.py | 9 ++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/dist_utils.py b/fastNLP/core/drivers/paddle_driver/dist_utils.py index 3bfbbd4f..4d9ae5f0 100644 --- a/fastNLP/core/drivers/paddle_driver/dist_utils.py +++ b/fastNLP/core/drivers/paddle_driver/dist_utils.py @@ -1,4 +1,5 @@ import io +import os import pickle _pickler = pickle.Pickler _unpickler = pickle.Unpickler @@ -7,6 +8,7 @@ from typing import Any, List from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.envs.env import FASTNLP_NO_SYNC if _NEED_IMPORT_TORCH: import torch from torch import distributed as dist @@ -83,6 +85,14 @@ def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFA >>> output ['foo', 12, {1: 2}] """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + return [obj] + + if dist.get_rank() == dst: + object_gather_list = [None for _ in range(dist.get_world_size(group))] + else: + object_gather_list = None + if group is None: group = DEFAULT_TORCH_GROUP @@ -207,6 +217,9 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) :param group: :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + return [obj] + if group is None: group = DEFAULT_TORCH_GROUP if isinstance(obj, torch.Tensor): @@ -233,6 +246,12 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR :param group: :return: """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + if src == dist.get_rank(group): + return obj + else: + return None + if group is None: group = DEFAULT_TORCH_GROUP cur_rank = dist.get_rank(group) @@ -328,6 +347,9 @@ def all_gather_object(object_list, obj, group=None): >>> output ['foo', 12, {1: 2}] """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + return [obj] + if dist.distributed_c10d._rank_not_in_group(group): return if _TORCH_GREATER_EQUAL_1_8: diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index ad07da8b..c407ab9f 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -29,7 +29,7 @@ from fastNLP.core.samplers import ( re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler, ) -from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED +from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -234,7 +234,8 @@ class PaddleFleetDriver(PaddleDriver): self.global_rank = paddledist.get_rank() def barrier(self): - paddledist.barrier() + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 + paddledist.barrier() def configure_fleet(self): if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): @@ -451,6 +452,8 @@ class PaddleFleetDriver(PaddleDriver): 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 """ return + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC == 2 直接返回。 + return return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group) def all_gather(self, obj, group) -> List: @@ -477,4 +480,6 @@ class PaddleFleetDriver(PaddleDriver): :return: """ return + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC 表示不执行 + return [obj] return fastnlp_paddle_all_gather(obj, group=group) From 6bfdb39c2f3db859bb980e7a4b7a5685d855ba72 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 06:42:29 +0000 Subject: [PATCH 05/10] =?UTF-8?q?=E5=AE=8C=E5=96=84paddle=20fleet=20set=5F?= =?UTF-8?q?dist=5Frepro=5Fdataloader=E7=9A=84=E6=B5=8B=E8=AF=95=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/drivers/paddle_driver/test_fleet.py | 19 +++++++------------ .../paddle_driver/test_single_device.py | 6 +++--- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index 5fe52c54..125a1c43 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -1,4 +1,3 @@ -from dataclasses import replace import pytest import os @@ -20,13 +19,14 @@ import paddle import paddle.distributed as dist from paddle.io import DataLoader, BatchSampler -def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False): +def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) driver = PaddleFleetDriver( model=paddle_model, parallel_device=device, fp16=fp16, + output_from_new_proc=output_from_new_proc ) driver.set_optimizers(paddle_opt) driver.setup() @@ -292,7 +292,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -319,7 +318,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -340,7 +338,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() """ @@ -372,7 +369,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -399,7 +395,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @@ -420,7 +415,6 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() def check_distributed_sampler(self, sampler): @@ -437,12 +431,14 @@ class TestSetDistReproDataloader: 测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 """ # 迭代两个 batch + num_replicas = len(self.device) num_consumed_batches = 2 already_seen_idx = set() for idx, batch in enumerate(replaced_loader): if idx >= num_consumed_batches: break already_seen_idx.update(batch) + dist.barrier() if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() else: @@ -450,6 +446,7 @@ class TestSetDistReproDataloader: # 加载 num_consumed_samples_array,设置正确取出的 batch 数目 num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) + print("array: ", num_consumed_samples_array) # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range left_idxes = set() @@ -458,7 +455,7 @@ class TestSetDistReproDataloader: if num_consumed_samples_array is not None: sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] else: - sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas # 重新改造 dataloader new_loader = DataLoader( dataset=replaced_loader.dataset, @@ -481,7 +478,7 @@ class TestSetDistReproDataloader: if num_consumed_samples_array is not None: sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples] else: - sampler_states["num_consumed_samples"] = num_consumed_samples + sampler_states["num_consumed_samples"] = num_consumed_samples * num_replicas # 重新构造 dataloader batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) @@ -494,10 +491,8 @@ class TestSetDistReproDataloader: for idx, batch in enumerate(new_loader): left_idxes.update(batch) - num_replicas = len(self.device) assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas - assert False ############################################################################ diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 92c55434..1c9a8241 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -513,11 +513,11 @@ class TestSetDistReproDataloder: new_loader.batch_sampler.load_state_dict(sampler_states) else: batch_size = replaced_loader.batch_sampler.batch_size - num_consumed_batches = num_consumed_batches * batch_size + num_consumed_samples = num_consumed_batches * batch_size if num_consumed_samples_array is not None: - sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] + sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples] else: - sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + sampler_states["num_consumed_samples"] = num_consumed_samples # 重新构造 dataloader batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) From 3dbb3677f0d839f918001f528a8410f88c150401 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 08:39:07 +0000 Subject: [PATCH 06/10] =?UTF-8?q?=E5=BE=AE=E8=B0=83=20reproducible=20sampl?= =?UTF-8?q?er=20=E7=9A=84=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/samplers/reproducible_batch_sampler.py | 12 ++---------- fastNLP/core/samplers/reproducible_sampler.py | 10 ++-------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index be43bc74..171a784b 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -19,7 +19,7 @@ from abc import abstractmethod class ReproducibleBatchSampler: def __init__(self, **kwargs): - pass + self.num_replicas = 1 @abstractmethod def set_distributed(self, num_replicas, rank, pad=True): @@ -53,14 +53,6 @@ class ReproducibleBatchSampler: def batch_idx_in_epoch(self): raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") - @property - def num_replicas(self): - return self._num_replicas - - @num_replicas.setter - def num_replicas(self, value): - self._num_replicas = value - class RandomBatchSampler(ReproducibleBatchSampler): # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; @@ -322,7 +314,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): if len(batches[-1])==0: batches.pop(-1) - assert len(list(chain(*batches))) == self.num_left_samples + assert sum(map(len, batches)) == self.num_left_samples if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: batches = batches[:-1] diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index c3facbb9..c8425dc7 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -20,6 +20,8 @@ class ReproducibleSampler: 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 """ + def __init__(self, **kwargs): + self.num_replicas = 1 def set_distributed(self, num_replicas, rank, pad=True): raise NotImplementedError("Each specific sampler should implement its own `set_distributed` method.") @@ -47,14 +49,6 @@ class ReproducibleSampler: def set_epoch(self, epoch): pass - @property - def num_repliacs(self): - return self._num_replicas - - @num_repliacs.setter - def num_repliacs(self, value): - self._num_replicas = value - class RandomSampler(ReproducibleSampler): def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): From 9b13fd313a59b4cd3081c50b713c65ec36ef1596 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 08:39:38 +0000 Subject: [PATCH 07/10] small --- .../core/drivers/paddle_driver/test_fleet.py | 40 +++++++------------ .../paddle_driver/test_single_device.py | 24 +++++------ 2 files changed, 24 insertions(+), 40 deletions(-) diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index 125a1c43..ad471acd 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -135,7 +135,7 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): + def test_with_dist_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler @@ -154,7 +154,7 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle): + def test_with_dist_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler @@ -182,7 +182,7 @@ class TestSetDistReproDataloader: """ @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self): + def test_with_dist_none_reproducible_true(self): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 @@ -195,8 +195,9 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context + # @pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): + def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler 时的表现 @@ -224,7 +225,7 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle): + def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 @@ -256,7 +257,7 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): + def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 此时直接返回原来的 dataloader,不做任何处理。 @@ -274,7 +275,7 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): + def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler 的表现 @@ -296,7 +297,7 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): + def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 的表现 @@ -322,7 +323,7 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle): + def test_with_dist_dist_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 @@ -347,7 +348,7 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): + def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 的表现 @@ -373,7 +374,7 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): + def test_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler 的表现 @@ -399,7 +400,7 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle): + def test_with_dist_unrepeat_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 @@ -444,18 +445,11 @@ class TestSetDistReproDataloader: else: sampler_states = replaced_loader.batch_sampler.sampler.state_dict() - # 加载 num_consumed_samples_array,设置正确取出的 batch 数目 - num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) - print("array: ", num_consumed_samples_array) - # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range left_idxes = set() if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): batch_size = replaced_loader.batch_sampler.batch_size - if num_consumed_samples_array is not None: - sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] - else: - sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas # 重新改造 dataloader new_loader = DataLoader( dataset=replaced_loader.dataset, @@ -474,11 +468,7 @@ class TestSetDistReproDataloader: new_loader.batch_sampler.load_state_dict(sampler_states) else: batch_size = replaced_loader.batch_sampler.batch_size - num_consumed_samples = num_consumed_batches * batch_size - if num_consumed_samples_array is not None: - sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples] - else: - sampler_states["num_consumed_samples"] = num_consumed_samples * num_replicas + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas # 重新构造 dataloader batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 1c9a8241..080b6333 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -348,7 +348,7 @@ class TestSingleDeviceFunction: # ############################################################################ -class TestSetDistReproDataloder: +class TestSetDistReproDataloader: """ 专门测试 set_dist_repro_dataloader 函数的类 """ @@ -357,7 +357,7 @@ class TestSetDistReproDataloder: model = PaddleNormalModel_Classification_1(10, 32) self.driver = PaddleSingleDriver(model, device="cpu") - def test_set_dist_repro_dataloader_with_reproducible_false(self): + def test_with_reproducible_false(self): """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 当dist为字符串时,此时应该返回原来的 dataloader @@ -368,7 +368,7 @@ class TestSetDistReproDataloder: assert replaced_loader is dataloader @pytest.mark.parametrize("shuffle", [True, False]) - def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle): + def test_with_reproducible_true(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True), @@ -393,7 +393,7 @@ class TestSetDistReproDataloder: self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): + def test_with_dist_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler @@ -409,7 +409,7 @@ class TestSetDistReproDataloder: self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle): + def test_with_dist_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler @@ -428,7 +428,7 @@ class TestSetDistReproDataloder: self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle): + def test_with_dataloader_reproducible_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 应该返回新的 dataloader,且其余各项设置和原来相同 @@ -452,7 +452,7 @@ class TestSetDistReproDataloder: self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle): + def test_with_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 应该返回新的 dataloader,且其余各项设置和原来相同 @@ -497,10 +497,7 @@ class TestSetDistReproDataloder: left_idxes = set() if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): batch_size = replaced_loader.batch_sampler.batch_size - if num_consumed_samples_array is not None: - sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] - else: - sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size # 重新改造 dataloader new_loader = DataLoader( dataset=replaced_loader.dataset, @@ -514,10 +511,7 @@ class TestSetDistReproDataloder: else: batch_size = replaced_loader.batch_sampler.batch_size num_consumed_samples = num_consumed_batches * batch_size - if num_consumed_samples_array is not None: - sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples] - else: - sampler_states["num_consumed_samples"] = num_consumed_samples + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size # 重新构造 dataloader batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) From 77f6b63ba669e5844af7398892e2301e9f46a7e0 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 08:40:16 +0000 Subject: [PATCH 08/10] =?UTF-8?q?paddle=20save=E5=87=BD=E6=95=B0=E9=80=82?= =?UTF-8?q?=E5=BA=94=E6=96=B0=E7=9A=84sampler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/paddle_driver/paddle_driver.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 75e0352f..fe8bf404 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -247,18 +247,27 @@ class PaddleDriver(Driver): # 会造成多余实际消耗的问题。 num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) if num_consumed_samples_array is not None: - if isinstance(sampler, ReproducibleSampler): - # 如果是 sampler 的话,需要计算出实际的 sample 数目 - try: + if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 + if dataloader_args.batch_size is not None: num_consumed_batches = num_consumed_batches * dataloader_args.batch_size - except: # 有可能 batch_size 为 None,就只有损失精度了 + else: # 有可能 batch_size 为 None,就只有损失精度了 + logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + "it may cause missing some samples when reload.") num_consumed_batches = sampler_states['num_consumed_samples'] sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." - states['sampler_states'] = sampler_states + else: + if dataloader_args.batch_size is not None: + sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ + * num_consumed_batches + else: + logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + "it may cause missing some samples when reload.") else: raise RuntimeError( "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") + + states['sampler_states'] = sampler_states # 2. 保存模型的状态; if should_save_model: From cb01a661f1662a44995bf9988f58df17c32d1db6 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 15:46:57 +0000 Subject: [PATCH 09/10] =?UTF-8?q?BucketedBatchSampler=E7=9A=84batch=5Fid?= =?UTF-8?q?=5Fin=5Fepoch=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/samplers/reproducible_batch_sampler.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 171a784b..e8acc645 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -411,4 +411,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): self.old_num_replicas = states['num_replicas'] def set_epoch(self, epoch): - self.epoch = epoch \ No newline at end of file + self.epoch = epoch + + @property + def batch_idx_in_epoch(self): + if self.drop_last: + return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size + else: + return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \ + (len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size \ No newline at end of file From 514415e9d47d8a91841c089d1d4b89b4b65f30a4 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 15:47:18 +0000 Subject: [PATCH 10/10] =?UTF-8?q?=E5=AE=8C=E6=88=90=E4=BA=86paddle=20fleet?= =?UTF-8?q?=E7=9A=84save=20load=E5=87=BD=E6=95=B0=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/drivers/paddle_driver/test_fleet.py | 186 ++++++++++++------ 1 file changed, 127 insertions(+), 59 deletions(-) diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index ad471acd..76d1f793 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -1,5 +1,6 @@ import pytest import os +from pathlib import Path os.environ["FASTNLP_BACKEND"] = "paddle" from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver @@ -33,20 +34,6 @@ def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, out return driver -@magic_argv_env_context -def test_multi_drivers(): - """ - 测试使用了多个 PaddleFleetDriver 的情况。 - """ - driver1 = generate_driver(10, 10) - driver2 = generate_driver(20, 10) - - with pytest.raises(RuntimeError): - # 设备设置不同,应该报错 - driver3 = generate_driver(20, 3, device=[0,2]) - - dist.barrier() - ############################################################################ # # 测试 PaddleFleetDriver 的一些函数 @@ -62,6 +49,19 @@ class TestFleetDriverFunction: def setup_class(cls): cls.driver = generate_driver(10, 10) + @magic_argv_env_context + def test_multi_drivers(self): + """ + 测试使用了多个 PaddleFleetDriver 的情况。 + """ + driver2 = generate_driver(20, 10) + + with pytest.raises(RuntimeError): + # 设备设置不同,应该报错 + driver3 = generate_driver(20, 3, device=[0,2]) + + dist.barrier() + @magic_argv_env_context def test_move_data_to_device(self): """ @@ -494,9 +494,14 @@ class TestSaveLoad: """ 测试多卡情况下 save 和 load 相关函数的表现 """ + + @classmethod + def setup_class(cls): + # 不在这里 setup 的话会报错 + cls.driver = generate_driver(10, 10) + def setup_method(self): self.dataset = PaddleRandomMaxDataset(20, 10) - self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @@ -506,7 +511,9 @@ class TestSaveLoad: """ try: path = "model" + dataloader = DataLoader(self.dataset, batch_size=2) + self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) if only_state_dict: self.driver1.save_model(path, only_state_dict) @@ -545,20 +552,30 @@ class TestSaveLoad: @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False])) - def test_save_and_load_with_randombatchsampler(self, only_state_dict, fp16): - return + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16): """ 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 """ try: path = "model.ckp" + num_replicas = len(device) - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) - dataset = PaddleRandomMaxDataset(40, 10) + self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ + generate_driver(10, 10, device=device, fp16=False) dataloader = DataLoader( - dataset=dataset, - batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) + dataset=self.dataset, + batch_sampler=BucketedBatchSampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=4, + ) + ) + dataloader.batch_sampler.set_distributed( + num_replicas=self.driver1.world_size, + rank=self.driver1.global_rank, + pad=True ) num_consumed_batches = 2 @@ -570,19 +587,32 @@ class TestSaveLoad: already_seen_x_set.update(batch["x"]) already_seen_y_set.update(batch["y"]) + # 同步 + dist.barrier() + + # 保存状态 sampler_states = dataloader.batch_sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} if only_state_dict: - driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) else: - driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) # 加载 # 更改 batch_size dataloader = DataLoader( - dataset=dataset, - batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) + dataset=self.dataset, + batch_sampler=BucketedBatchSampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=4, + ) + ) + dataloader.batch_sampler.set_distributed( + num_replicas=self.driver2.world_size, + rank=self.driver2.global_rank, + pad=True ) - load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") # 1. 检查 optimizer 的状态 # TODO optimizer 的 state_dict 总是为空 @@ -590,13 +620,13 @@ class TestSaveLoad: # 2. 检查 batch_sampler 是否被正确地加载和替换 assert not (replaced_loader is dataloader) assert replaced_loader.batch_sampler is dataloader.batch_sampler - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) - assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] - assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert replaced_loader.batch_sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx @@ -608,22 +638,33 @@ class TestSaveLoad: left_x_batches.update(batch["x"]) left_y_batches.update(batch["y"]) - res1 = driver1.model.evaluate_step(**batch) - res2 = driver2.model.evaluate_step(**batch) + res1 = self.driver1.model( + batch, + fastnlp_fn=self.driver1.model._layers.model.evaluate_step, + # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = self.driver2.model( + batch, + fastnlp_fn=self.driver2.model._layers.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) assert paddle.equal_all(res1["pred"], res2["pred"]) - assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) - assert len(left_x_batches | already_seen_x_set) == len(dataset) - assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) - assert len(left_y_batches | already_seen_y_set) == len(dataset) + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas finally: synchronize_safe_rm(path) @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False])) - def test_save_and_load_with_randomsampler(self, only_state_dict, fp16): - return + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16): """ 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 """ @@ -631,12 +672,19 @@ class TestSaveLoad: try: path = "model.ckp" - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) - dataset = PaddleRandomMaxDataset(40, 10) - batch_sampler = BatchSampler(dataset=dataset, batch_size=4) - batch_sampler.sampler = RandomSampler(dataset, True) + num_replicas = len(device) + + self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) + self.driver2 = generate_driver(10, 10, device=device, fp16=False) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) + batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler.sampler.set_distributed( + num_replicas=self.driver1.world_size, + rank=self.driver1.global_rank, + pad=True + ) dataloader = DataLoader( - dataset, + self.dataset, batch_sampler=batch_sampler ) num_consumed_batches = 2 @@ -649,22 +697,30 @@ class TestSaveLoad: already_seen_x_set.update(batch["x"]) already_seen_y_set.update(batch["y"]) + # 同步 + dist.barrier() + + # 保存状态 sampler_states = dataloader.batch_sampler.sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} if only_state_dict: - driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) else: - driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) - + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) # 加载 # 更改 batch_size - batch_sampler = BatchSampler(dataset=dataset, batch_size=2) - batch_sampler.sampler = RandomSampler(dataset, True) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler.sampler.set_distributed( + num_replicas=self.driver2.world_size, + rank=self.driver2.global_rank, + pad=True + ) dataloader = DataLoader( - dataset, + self.dataset, batch_sampler=batch_sampler ) - load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") # 1. 检查 optimizer 的状态 @@ -675,12 +731,12 @@ class TestSaveLoad: assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] - assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx @@ -692,13 +748,25 @@ class TestSaveLoad: left_x_batches.update(batch["x"]) left_y_batches.update(batch["y"]) - res1 = driver1.model.evaluate_step(**batch) - res2 = driver2.model.evaluate_step(**batch) + res1 = self.driver1.model( + batch, + fastnlp_fn=self.driver1.model._layers.model.evaluate_step, + # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = self.driver2.model( + batch, + fastnlp_fn=self.driver2.model._layers.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) assert paddle.equal_all(res1["pred"], res2["pred"]) - assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) - assert len(left_x_batches | already_seen_x_set) == len(dataset) - assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) - assert len(left_y_batches | already_seen_y_set) == len(dataset) + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas + finally: - synchronize_safe_rm(path) + synchronize_safe_rm(path) \ No newline at end of file