diff --git a/fastNLP/core/drivers/paddle_driver/dist_utils.py b/fastNLP/core/drivers/paddle_driver/dist_utils.py index 3bfbbd4f..4d9ae5f0 100644 --- a/fastNLP/core/drivers/paddle_driver/dist_utils.py +++ b/fastNLP/core/drivers/paddle_driver/dist_utils.py @@ -1,4 +1,5 @@ import io +import os import pickle _pickler = pickle.Pickler _unpickler = pickle.Unpickler @@ -7,6 +8,7 @@ from typing import Any, List from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.envs.env import FASTNLP_NO_SYNC if _NEED_IMPORT_TORCH: import torch from torch import distributed as dist @@ -83,6 +85,14 @@ def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFA >>> output ['foo', 12, {1: 2}] """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + return [obj] + + if dist.get_rank() == dst: + object_gather_list = [None for _ in range(dist.get_world_size(group))] + else: + object_gather_list = None + if group is None: group = DEFAULT_TORCH_GROUP @@ -207,6 +217,9 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) :param group: :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + return [obj] + if group is None: group = DEFAULT_TORCH_GROUP if isinstance(obj, torch.Tensor): @@ -233,6 +246,12 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR :param group: :return: """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + if src == dist.get_rank(group): + return obj + else: + return None + if group is None: group = DEFAULT_TORCH_GROUP cur_rank = dist.get_rank(group) @@ -328,6 +347,9 @@ def all_gather_object(object_list, obj, group=None): >>> output ['foo', 12, {1: 2}] """ + if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: + return [obj] + if dist.distributed_c10d._rank_not_in_group(group): return if _TORCH_GREATER_EQUAL_1_8: diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 6e2c85ee..c407ab9f 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -1,6 +1,5 @@ import os import shutil -from functools import partial from typing import List, Union, Optional, Dict, Tuple, Callable from .paddle_driver import PaddleDriver @@ -30,7 +29,7 @@ from fastNLP.core.samplers import ( re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler, ) -from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED +from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -38,7 +37,6 @@ if _NEED_IMPORT_PADDLE: from paddle import DataParallel import paddle.distributed.fleet as fleet import paddle.distributed as paddledist - from paddle.io import BatchSampler from paddle.optimizer import Optimizer from paddle.fluid.reader import _DatasetKind from paddle.fluid.dygraph import parallel_helper @@ -236,7 +234,8 @@ class PaddleFleetDriver(PaddleDriver): self.global_rank = paddledist.get_rank() def barrier(self): - paddledist.barrier() + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 + paddledist.barrier() def configure_fleet(self): if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): @@ -305,7 +304,7 @@ class PaddleFleetDriver(PaddleDriver): raise RuntimeError(f"There is no `{fn}` method in your model.") else: if hasattr(model, fn): - logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements " + logger.warning("Notice your model is a `DataParallel` model. And your model also implements " f"the `{fn}` method, which we can not call actually, we will" " call `forward` function instead of `train_step` and you should note that.") elif fn not in {"train_step", "evaluate_step"}: @@ -453,6 +452,8 @@ class PaddleFleetDriver(PaddleDriver): 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 """ return + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC == 2 直接返回。 + return return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group) def all_gather(self, obj, group) -> List: @@ -479,4 +480,6 @@ class PaddleFleetDriver(PaddleDriver): :return: """ return + if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC 表示不执行 + return [obj] return fastnlp_paddle_all_gather(obj, group=group) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 3b8ad7d8..fe8bf404 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -7,7 +7,7 @@ from dataclasses import dataclass import numpy as np -from .utils import _build_fp16_env, optimizer_state_to_device +from .utils import _build_fp16_env, optimizer_state_to_device, DummyGradScaler from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.core.drivers.driver import Driver from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device @@ -247,18 +247,27 @@ class PaddleDriver(Driver): # 会造成多余实际消耗的问题。 num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) if num_consumed_samples_array is not None: - if isinstance(sampler, ReproducibleSampler): - # 如果是 sampler 的话,需要计算出实际的 sample 数目 - try: + if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 + if dataloader_args.batch_size is not None: num_consumed_batches = num_consumed_batches * dataloader_args.batch_size - except: # 有可能 batch_size 为 None,就只有损失精度了 + else: # 有可能 batch_size 为 None,就只有损失精度了 + logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + "it may cause missing some samples when reload.") num_consumed_batches = sampler_states['num_consumed_samples'] sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." - states['sampler_states'] = sampler_states + else: + if dataloader_args.batch_size is not None: + sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ + * num_consumed_batches + else: + logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + "it may cause missing some samples when reload.") else: raise RuntimeError( "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") + + states['sampler_states'] = sampler_states # 2. 保存模型的状态; if should_save_model: @@ -278,6 +287,12 @@ class PaddleDriver(Driver): logger.debug("Save optimizer state dict.") states["optimizers_state_dict"] = optimizers_state_dict + + # 4.保存fp16的状态 + if not isinstance(self.grad_scaler, DummyGradScaler): + grad_scaler_state_dict = self.grad_scaler.state_dict() + states['grad_scaler_state_dict'] = grad_scaler_state_dict + paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: @@ -285,7 +300,7 @@ class PaddleDriver(Driver): states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) # 1. 加载 optimizers 的状态; - optimizers_state_dict = states["optimizers_state_dict"] + optimizers_state_dict = states.pop("optimizers_state_dict") for i in range(len(self.optimizers)): optimizer: Optimizer = self.optimizers[i] optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"]) @@ -295,18 +310,32 @@ class PaddleDriver(Driver): if should_load_model: self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict) if only_state_dict: - logger.debug("Load model state dict.") + logger.debug("Load model state dict...") else: - logger.debug("Load model.") - - # 3. 恢复 sampler 的状态; + logger.debug("Load model...") + + # 3. 加载fp16的状态; + if "grad_scaler_state_dict" in states: + grad_scaler_state_dict = states.pop("grad_scaler_state_dict") + if isinstance(self.grad_scaler, DummyGradScaler): + self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False) + self.grad_scaler = _grad_scaler() + self.fp16 = True + self.grad_scaler.load_state_dict(grad_scaler_state_dict) + logger.debug("Load grad_scaler state dict...") + elif not isinstance(self.grad_scaler, DummyGradScaler): + logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " + f"the training process may be unstable.") + + # 4. 恢复 sampler 的状态; dataloader_args = self.get_dataloader_args(dataloader) if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): sampler = dataloader_args.batch_sampler elif isinstance(dataloader_args.sampler, ReproducibleSampler): sampler = dataloader_args.sampler elif self.is_distributed(): - raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") + raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " + "`ReproducibleSampler`.") else: sampler = RandomBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, @@ -316,7 +345,7 @@ class PaddleDriver(Driver): sampler.load_state_dict(states["sampler_states"]) states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) - # 4. 修改 trainer_state.batch_idx_in_epoch + # 5. 修改 trainer_state.batch_idx_in_epoch # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; if not isinstance(sampler, ReproducibleBatchSampler): if dataloader_args.drop_last: diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index feb5c3eb..48598a34 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -19,7 +19,7 @@ if _NEED_IMPORT_PADDLE: import paddle from paddle import nn from paddle.nn import Layer - from paddle.io import DataLoader, BatchSampler, Dataset + from paddle.io import DataLoader, BatchSampler from paddle.amp import auto_cast, GradScaler else: from fastNLP.core.utils.dummy_class import DummyClass as Layer @@ -140,8 +140,7 @@ class DummyGradScaler: def _build_fp16_env(dummy=False): if dummy: - auto_cast = ExitStack - GradScaler = DummyGradScaler + return ExitStack, DummyGradScaler else: if not paddle.device.is_compiled_with_cuda(): raise RuntimeError("No cuda") @@ -150,7 +149,7 @@ def _build_fp16_env(dummy=False): "NOTE: your device does NOT support faster training with fp16, " "please switch to FP32 which is likely to be faster" ) - return auto_cast, GradScaler + return auto_cast, GradScaler def find_free_ports(num): def __free_port(): diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index be43bc74..e8acc645 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -19,7 +19,7 @@ from abc import abstractmethod class ReproducibleBatchSampler: def __init__(self, **kwargs): - pass + self.num_replicas = 1 @abstractmethod def set_distributed(self, num_replicas, rank, pad=True): @@ -53,14 +53,6 @@ class ReproducibleBatchSampler: def batch_idx_in_epoch(self): raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") - @property - def num_replicas(self): - return self._num_replicas - - @num_replicas.setter - def num_replicas(self, value): - self._num_replicas = value - class RandomBatchSampler(ReproducibleBatchSampler): # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; @@ -322,7 +314,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): if len(batches[-1])==0: batches.pop(-1) - assert len(list(chain(*batches))) == self.num_left_samples + assert sum(map(len, batches)) == self.num_left_samples if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: batches = batches[:-1] @@ -419,4 +411,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): self.old_num_replicas = states['num_replicas'] def set_epoch(self, epoch): - self.epoch = epoch \ No newline at end of file + self.epoch = epoch + + @property + def batch_idx_in_epoch(self): + if self.drop_last: + return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size + else: + return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \ + (len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size \ No newline at end of file diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index c3facbb9..c8425dc7 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -20,6 +20,8 @@ class ReproducibleSampler: 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 """ + def __init__(self, **kwargs): + self.num_replicas = 1 def set_distributed(self, num_replicas, rank, pad=True): raise NotImplementedError("Each specific sampler should implement its own `set_distributed` method.") @@ -47,14 +49,6 @@ class ReproducibleSampler: def set_epoch(self, epoch): pass - @property - def num_repliacs(self): - return self._num_replicas - - @num_repliacs.setter - def num_repliacs(self, value): - self._num_replicas = value - class RandomSampler(ReproducibleSampler): def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index c775a3a2..76d1f793 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -1,6 +1,6 @@ -from dataclasses import replace import pytest import os +from pathlib import Path os.environ["FASTNLP_BACKEND"] = "paddle" from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver @@ -12,19 +12,22 @@ from fastNLP.core.samplers import ( UnrepeatedSequentialSampler, ) from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 -from tests.helpers.datasets.paddle_data import PaddleNormalDataset +from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.utils import magic_argv_env_context +from fastNLP.core import synchronize_safe_rm import paddle import paddle.distributed as dist from paddle.io import DataLoader, BatchSampler -def generate_driver(num_labels, feature_dimension): +def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) driver = PaddleFleetDriver( model=paddle_model, - parallel_device=[0,1], + parallel_device=device, + fp16=fp16, + output_from_new_proc=output_from_new_proc ) driver.set_optimizers(paddle_opt) driver.setup() @@ -33,7 +36,7 @@ def generate_driver(num_labels, feature_dimension): ############################################################################ # -# 测试PaddleFleetDriver的一些函数 +# 测试 PaddleFleetDriver 的一些函数 # ############################################################################ @@ -46,6 +49,19 @@ class TestFleetDriverFunction: def setup_class(cls): cls.driver = generate_driver(10, 10) + @magic_argv_env_context + def test_multi_drivers(self): + """ + 测试使用了多个 PaddleFleetDriver 的情况。 + """ + driver2 = generate_driver(20, 10) + + with pytest.raises(RuntimeError): + # 设备设置不同,应该报错 + driver3 = generate_driver(20, 3, device=[0,2]) + + dist.barrier() + @magic_argv_env_context def test_move_data_to_device(self): """ @@ -106,10 +122,11 @@ class TestSetDistReproDataloader: @classmethod def setup_class(cls): - cls.driver = generate_driver(10, 10) + cls.device = [0, 1] + cls.driver = generate_driver(10, 10, device=cls.device) def setup_method(self): - self.dataset = PaddleNormalDataset(20) + self.dataset = PaddleNormalDataset(40) """ 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 @@ -118,9 +135,10 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): + def test_with_dist_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 + 此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) @@ -130,14 +148,16 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert replaced_loader.batch_sampler is batch_sampler self.check_distributed_sampler(replaced_loader.batch_sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle): + def test_with_dist_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 + 此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) sampler = RandomSampler(self.dataset, shuffle=shuffle) @@ -150,6 +170,7 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler is sampler assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @@ -161,9 +182,10 @@ class TestSetDistReproDataloader: """ @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self): + def test_with_dist_none_reproducible_true(self): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 + 当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) with pytest.raises(RuntimeError): @@ -173,11 +195,14 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context + # @pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): + def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler 时的表现 + 此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler + 和原 dataloader 相同 """ dataloader = DataLoader( self.dataset, @@ -194,16 +219,19 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert replaced_loader.batch_sampler.batch_size == 4 self.check_distributed_sampler(dataloader.batch_sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) dist.barrier() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle): + def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 + 此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 + batch_sampler.sampler 和原 dataloader 相同 """ - batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) batch_sampler.sampler = RandomSampler(self.dataset, shuffle) batch_sampler.sampler.set_distributed( num_replicas=self.driver.world_size, @@ -220,16 +248,19 @@ class TestSetDistReproDataloader: assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) - assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.drop_last == False self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + dist.barrier() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): + def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 + 此时直接返回原来的 dataloader,不做任何处理。 """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) @@ -244,10 +275,11 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): + def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler 的表现 + 此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性 """ dataloader = DataLoader( dataset=self.dataset, @@ -265,12 +297,14 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): + def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 的表现 + 此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关 + 的属性 """ - batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4, shuffle=shuffle) batch_sampler.sampler = RandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, @@ -282,16 +316,18 @@ class TestSetDistReproDataloader: assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) - assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle): + def test_with_dist_dist_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 + 此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 + 的属性 """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) @@ -302,6 +338,7 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.batch_sampler.sampler.shuffle == shuffle + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() """ @@ -311,12 +348,14 @@ class TestSetDistReproDataloader: @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): + def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 的表现 + 此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关 + 的属性 """ - batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) batch_sampler.sampler = RandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, @@ -328,19 +367,20 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler, BatchSampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) - assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): + def test_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler 的表现 + 此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler """ - batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, @@ -353,16 +393,18 @@ class TestSetDistReproDataloader: assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) - assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle): + def test_with_dist_unrepeat_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 + 此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 + 的属性 """ dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) @@ -385,3 +427,346 @@ class TestSetDistReproDataloader: if not isinstance(sampler, UnrepeatedSampler): assert sampler.pad == True + def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): + """ + 测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 + """ + # 迭代两个 batch + num_replicas = len(self.device) + num_consumed_batches = 2 + already_seen_idx = set() + for idx, batch in enumerate(replaced_loader): + if idx >= num_consumed_batches: + break + already_seen_idx.update(batch) + dist.barrier() + if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): + sampler_states = replaced_loader.batch_sampler.state_dict() + else: + sampler_states = replaced_loader.batch_sampler.sampler.state_dict() + + # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range + left_idxes = set() + if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas + # 重新改造 dataloader + new_loader = DataLoader( + dataset=replaced_loader.dataset, + batch_sampler=BucketedBatchSampler( + replaced_loader.dataset, + length=replaced_loader.dataset._data, + batch_size=batch_size, + shuffle=shuffle, + ) + ) + new_loader.batch_sampler.set_distributed( + num_replicas=self.driver.world_size, + rank=self.driver.global_rank, + pad=True + ) + new_loader.batch_sampler.load_state_dict(sampler_states) + else: + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas + # 重新构造 dataloader + batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) + batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) + batch_sampler.sampler.set_distributed( + num_replicas=self.driver.world_size, + rank=self.driver.global_rank + ) + new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) + new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + for idx, batch in enumerate(new_loader): + left_idxes.update(batch) + + assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas + assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas + + +############################################################################ +# +# 测试 save 和 load 相关的功能 +# +############################################################################ +class TestSaveLoad: + """ + 测试多卡情况下 save 和 load 相关函数的表现 + """ + + @classmethod + def setup_class(cls): + # 不在这里 setup 的话会报错 + cls.driver = generate_driver(10, 10) + + def setup_method(self): + self.dataset = PaddleRandomMaxDataset(20, 10) + + @magic_argv_env_context + @pytest.mark.parametrize("only_state_dict", ([True, False])) + def test_save_and_load_model(self, only_state_dict): + """ + 测试 save_model 和 load_model 函数 + """ + try: + path = "model" + + dataloader = DataLoader(self.dataset, batch_size=2) + self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) + + if only_state_dict: + self.driver1.save_model(path, only_state_dict) + else: + self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))]) + + # 同步 + dist.barrier() + self.driver2.load_model(path, only_state_dict) + + for idx, batch in enumerate(dataloader): + batch = self.driver1.move_data_to_device(batch) + res1 = self.driver1.model( + batch, + fastnlp_fn=self.driver1.model._layers.model.evaluate_step, + # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = self.driver2.model( + batch, + fastnlp_fn=self.driver2.model._layers.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + + assert paddle.equal_all(res1["pred"], res2["pred"]) + finally: + if only_state_dict: + synchronize_safe_rm(path) + else: + synchronize_safe_rm(path + ".pdiparams") + synchronize_safe_rm(path + ".pdiparams.info") + synchronize_safe_rm(path + ".pdmodel") + + @magic_argv_env_context + @pytest.mark.parametrize("only_state_dict", ([True, False])) + @pytest.mark.parametrize("fp16", ([True, False])) + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 + """ + + try: + path = "model.ckp" + num_replicas = len(device) + + self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ + generate_driver(10, 10, device=device, fp16=False) + dataloader = DataLoader( + dataset=self.dataset, + batch_sampler=BucketedBatchSampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=4, + ) + ) + dataloader.batch_sampler.set_distributed( + num_replicas=self.driver1.world_size, + rank=self.driver1.global_rank, + pad=True + ) + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"]) + already_seen_y_set.update(batch["y"]) + + # 同步 + dist.barrier() + + # 保存状态 + sampler_states = dataloader.batch_sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + if only_state_dict: + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + else: + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + # 加载 + # 更改 batch_size + dataloader = DataLoader( + dataset=self.dataset, + batch_sampler=BucketedBatchSampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=4, + ) + ) + dataloader.batch_sampler.set_distributed( + num_replicas=self.driver2.world_size, + rank=self.driver2.global_rank, + pad=True + ) + load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 batch_sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert replaced_loader.batch_sampler is dataloader.batch_sampler + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert replaced_loader.batch_sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas + + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + for idx, batch in enumerate(replaced_loader): + + left_x_batches.update(batch["x"]) + left_y_batches.update(batch["y"]) + res1 = self.driver1.model( + batch, + fastnlp_fn=self.driver1.model._layers.model.evaluate_step, + # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = self.driver2.model( + batch, + fastnlp_fn=self.driver2.model._layers.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + assert paddle.equal_all(res1["pred"], res2["pred"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas + finally: + synchronize_safe_rm(path) + + @magic_argv_env_context + @pytest.mark.parametrize("only_state_dict", ([True, False])) + @pytest.mark.parametrize("fp16", ([True, False])) + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 + """ + + try: + path = "model.ckp" + + num_replicas = len(device) + + self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) + self.driver2 = generate_driver(10, 10, device=device, fp16=False) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4) + batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler.sampler.set_distributed( + num_replicas=self.driver1.world_size, + rank=self.driver1.global_rank, + pad=True + ) + dataloader = DataLoader( + self.dataset, + batch_sampler=batch_sampler + ) + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"]) + already_seen_y_set.update(batch["y"]) + + # 同步 + dist.barrier() + + # 保存状态 + sampler_states = dataloader.batch_sampler.sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + if only_state_dict: + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + else: + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + # 加载 + # 更改 batch_size + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler.sampler.set_distributed( + num_replicas=self.driver2.world_size, + rank=self.driver2.global_rank, + pad=True + ) + dataloader = DataLoader( + self.dataset, + batch_sampler=batch_sampler + ) + load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas + assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] + assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + for idx, batch in enumerate(replaced_loader): + + left_x_batches.update(batch["x"]) + left_y_batches.update(batch["y"]) + res1 = self.driver1.model( + batch, + fastnlp_fn=self.driver1.model._layers.model.evaluate_step, + # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = self.driver2.model( + batch, + fastnlp_fn=self.driver2.model._layers.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + assert paddle.equal_all(res1["pred"], res2["pred"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas + + finally: + synchronize_safe_rm(path) \ No newline at end of file diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 0c8e4256..c80bd609 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -1,4 +1,3 @@ -from dataclasses import replace import os from re import S os.environ["FASTNLP_BACKEND"] = "paddle" @@ -349,7 +348,7 @@ class TestSingleDeviceFunction: # ############################################################################ -class TestSetDistReproDataloder: +class TestSetDistReproDataloader: """ 专门测试 set_dist_repro_dataloader 函数的类 """ @@ -358,7 +357,7 @@ class TestSetDistReproDataloder: model = PaddleNormalModel_Classification_1(10, 32) self.driver = PaddleSingleDriver(model, device="cpu") - def test_set_dist_repro_dataloader_with_reproducible_false(self): + def test_with_reproducible_false(self): """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 当dist为字符串时,此时应该返回原来的 dataloader @@ -369,7 +368,7 @@ class TestSetDistReproDataloder: assert replaced_loader is dataloader @pytest.mark.parametrize("shuffle", [True, False]) - def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle): + def test_with_reproducible_true(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True), @@ -394,7 +393,7 @@ class TestSetDistReproDataloder: self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): + def test_with_dist_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler @@ -410,7 +409,7 @@ class TestSetDistReproDataloder: self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle): + def test_with_dist_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler @@ -429,7 +428,7 @@ class TestSetDistReproDataloder: self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle): + def test_with_dataloader_reproducible_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 应该返回新的 dataloader,且其余各项设置和原来相同 @@ -453,7 +452,7 @@ class TestSetDistReproDataloder: self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @pytest.mark.parametrize("shuffle", ([True, False])) - def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle): + def test_with_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 应该返回新的 dataloader,且其余各项设置和原来相同 @@ -498,10 +497,7 @@ class TestSetDistReproDataloder: left_idxes = set() if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): batch_size = replaced_loader.batch_sampler.batch_size - if num_consumed_samples_array is not None: - sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] - else: - sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size # 重新改造 dataloader new_loader = DataLoader( dataset=replaced_loader.dataset, @@ -514,11 +510,8 @@ class TestSetDistReproDataloder: new_loader.batch_sampler.load_state_dict(sampler_states) else: batch_size = replaced_loader.batch_sampler.batch_size - num_consumed_batches = num_consumed_batches * batch_size - if num_consumed_samples_array is not None: - sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] - else: - sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + num_consumed_samples = num_consumed_batches * batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size # 重新构造 dataloader batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) @@ -536,13 +529,13 @@ class TestSetDistReproDataloder: # ############################################################################ -def generate_random_driver(features, labels): +def generate_random_driver(features, labels, fp16=False, device="cpu"): """ 生成driver """ model = PaddleNormalModel_Classification_1(labels, features) opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01) - driver = PaddleSingleDriver(model, device="cpu") + driver = PaddleSingleDriver(model, device=device, fp16=fp16) driver.set_optimizers(opt) driver.setup() @@ -550,8 +543,8 @@ def generate_random_driver(features, labels): @pytest.fixture def prepare_test_save_load(): - dataset = PaddleRandomMaxDataset(320, 10) - dataloader = DataLoader(dataset, batch_size=32) + dataset = PaddleRandomMaxDataset(40, 10) + dataloader = DataLoader(dataset, batch_size=4) driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) return driver1, driver2, dataloader @@ -584,21 +577,23 @@ def test_save_and_load_model(prepare_test_save_load, only_state_dict): rank_zero_rm(path + ".pdiparams.info") rank_zero_rm(path + ".pdmodel") -@pytest.mark.parametrize("only_state_dict", ([True, False])) -def test_save_and_load_with_randombatchsampler(only_state_dict): +# @pytest.mark.parametrize("only_state_dict", ([True, False])) +@pytest.mark.parametrize("only_state_dict", ([True])) +@pytest.mark.parametrize("fp16", ([True, False])) +def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): """ 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 """ try: path = "model.ckp" - - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) dataset = PaddleRandomMaxDataset(40, 10) dataloader = DataLoader( dataset=dataset, batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) ) + driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") + num_consumed_batches = 2 already_seen_x_set = set() @@ -633,8 +628,13 @@ def test_save_and_load_with_randombatchsampler(only_state_dict): assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 - # 3. 检查 model 的参数是否正确 - # 4. 检查 batch_idx + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx start_batch = load_states.pop('batch_idx_in_epoch') assert start_batch == 2 * num_consumed_batches left_x_batches = set() @@ -654,8 +654,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict): finally: rank_zero_rm(path) -@pytest.mark.parametrize("only_state_dict", ([True, False])) -def test_save_and_load_with_randomsampler(only_state_dict): +# @pytest.mark.parametrize("only_state_dict", ([True, False])) +# TODO 在有迭代且使用了paddle.jit.save的时候会引发段错误,注释掉任意一段都不会出错 +# 但无法在单独的文件中复现 +@pytest.mark.parametrize("only_state_dict", ([True])) +@pytest.mark.parametrize("fp16", ([True, False])) +def test_save_and_load_with_randomsampler(only_state_dict, fp16): """ 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 """ @@ -663,7 +667,7 @@ def test_save_and_load_with_randomsampler(only_state_dict): try: path = "model.ckp" - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) + driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") dataset = PaddleRandomMaxDataset(40, 10) batch_sampler = BatchSampler(dataset=dataset, batch_size=4) batch_sampler.sampler = RandomSampler(dataset, True) @@ -711,8 +715,12 @@ def test_save_and_load_with_randomsampler(only_state_dict): assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] - # 3. 检查 model 的参数是否正确 - # 4. 检查 batch_idx + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx start_batch = load_states.pop('batch_idx_in_epoch') assert start_batch == 2 * num_consumed_batches left_x_batches = set()