
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh_cc, 3 years ago
parent commit 368c17fe73

8 changed files with 540 additions and 100 deletions
  1. fastNLP/core/drivers/paddle_driver/dist_utils.py  (+22, -0)
  2. fastNLP/core/drivers/paddle_driver/fleet.py  (+8, -5)
  3. fastNLP/core/drivers/paddle_driver/paddle_driver.py  (+42, -13)
  4. fastNLP/core/drivers/paddle_driver/utils.py  (+3, -4)
  5. fastNLP/core/samplers/reproducible_batch_sampler.py  (+11, -11)
  6. fastNLP/core/samplers/reproducible_sampler.py  (+2, -8)
  7. tests/core/drivers/paddle_driver/test_fleet.py  (+412, -27)
  8. tests/core/drivers/paddle_driver/test_single_device.py  (+40, -32)

fastNLP/core/drivers/paddle_driver/dist_utils.py  (+22, -0)

@@ -1,4 +1,5 @@
 import io
+import os
 import pickle
 _pickler = pickle.Pickler
 _unpickler = pickle.Unpickler
@@ -7,6 +8,7 @@ from typing import Any, List
 from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
 from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.env import FASTNLP_NO_SYNC
 if _NEED_IMPORT_TORCH:
     import torch
     from torch import distributed as dist
@@ -83,6 +85,14 @@ def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFA
         >>> output
         ['foo', 12, {1: 2}]
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
+    if dist.get_rank() == dst:
+        object_gather_list = [None for _ in range(dist.get_world_size(group))]
+    else:
+        object_gather_list = None
+
     if group is None:
         group = DEFAULT_TORCH_GROUP
@@ -207,6 +217,9 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP)
     :param group:
     :return: the result is [obj0, obj1, ...], where obj_i is the obj gathered from rank i.
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
     if group is None:
         group = DEFAULT_TORCH_GROUP
     if isinstance(obj, torch.Tensor):
@@ -233,6 +246,12 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR
     :param group:
     :return:
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        if src == dist.get_rank(group):
+            return obj
+        else:
+            return None
+
     if group is None:
         group = DEFAULT_TORCH_GROUP
     cur_rank = dist.get_rank(group)
@@ -328,6 +347,9 @@ def all_gather_object(object_list, obj, group=None):
         >>> output
         ['foo', 12, {1: 2}]
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
     if dist.distributed_c10d._rank_not_in_group(group):
         return
     if _TORCH_GREATER_EQUAL_1_8:
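Every hunk in this file adds the same guard: when the FASTNLP_NO_SYNC environment variable is set to 2, the collective call is skipped and the local object is returned in the shape the caller expects. A minimal sketch of that pattern, using only this file's own import; `no_sync_or` and `collective` are hypothetical names:

    import os
    from fastNLP.envs.env import FASTNLP_NO_SYNC

    def no_sync_or(obj, collective):
        # FASTNLP_NO_SYNC == 2 disables all collective communication; wrap the
        # local object in a list so callers still receive a per-rank list.
        if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
            return [obj]
        return collective(obj)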


fastNLP/core/drivers/paddle_driver/fleet.py  (+8, -5)

@@ -1,6 +1,5 @@
 import os
 import shutil
-from functools import partial
 from typing import List, Union, Optional, Dict, Tuple, Callable

 from .paddle_driver import PaddleDriver
@@ -30,7 +29,7 @@ from fastNLP.core.samplers import (
     re_instantiate_sampler,
     conversion_between_reproducible_and_unrepeated_sampler,
 )
-from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
+from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
 from fastNLP.core.log import logger

 if _NEED_IMPORT_PADDLE:
@@ -38,7 +37,6 @@ if _NEED_IMPORT_PADDLE:
     from paddle import DataParallel
     import paddle.distributed.fleet as fleet
     import paddle.distributed as paddledist
-    from paddle.io import BatchSampler
     from paddle.optimizer import Optimizer
     from paddle.fluid.reader import _DatasetKind
     from paddle.fluid.dygraph import parallel_helper
@@ -236,7 +234,8 @@ class PaddleFleetDriver(PaddleDriver):
         self.global_rank = paddledist.get_rank()

     def barrier(self):
-        paddledist.barrier()
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1:  # only actually execute when FASTNLP_NO_SYNC is less than 1
+            paddledist.barrier()

     def configure_fleet(self):
         if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
@@ -305,7 +304,7 @@
             raise RuntimeError(f"There is no `{fn}` method in your model.")
         else:
             if hasattr(model, fn):
-                logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements "
+                logger.warning("Notice your model is a `DataParallel` model. And your model also implements "
                                f"the `{fn}` method, which we can not call actually, we will"
                                " call `forward` function instead of `train_step` and you should note that.")
             elif fn not in {"train_step", "evaluate_step"}:
@@ -453,6 +452,8 @@
            the received parameter; on the source side, the content that was sent; if neither sender nor receiver, None.
         """
         return
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # return immediately when FASTNLP_NO_SYNC == 2
+            return
         return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)

     def all_gather(self, obj, group) -> List:
@@ -479,4 +480,6 @@
         :return:
         """
         return
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # FASTNLP_NO_SYNC == 2 means do not execute
+            return [obj]
         return fastnlp_paddle_all_gather(obj, group=group)
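The gate added to barrier() can be exercised directly. A usage sketch, assuming `driver` is an already set-up PaddleFleetDriver:

    import os

    os.environ["FASTNLP_NO_SYNC"] = "0"  # 0 < 1, so barrier() really synchronizes
    driver.barrier()
    os.environ["FASTNLP_NO_SYNC"] = "1"  # 1 is not < 1, so barrier() becomes a no-op
    driver.barrier()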

fastNLP/core/drivers/paddle_driver/paddle_driver.py  (+42, -13)

@@ -7,7 +7,7 @@ from dataclasses import dataclass

 import numpy as np

-from .utils import _build_fp16_env, optimizer_state_to_device
+from .utils import _build_fp16_env, optimizer_state_to_device, DummyGradScaler
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
 from fastNLP.core.drivers.driver import Driver
 from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device
@@ -247,18 +247,27 @@
             # would cause the count to exceed what was actually consumed.
             num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
             if num_consumed_samples_array is not None:
-                if isinstance(sampler, ReproducibleSampler):
-                    # for a sampler, the actual number of samples has to be computed
-                    try:
+                if isinstance(sampler, ReproducibleSampler):  # for a sampler, batch_size has to be taken into account
+                    if dataloader_args.batch_size is not None:
                         num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                    except:  # batch_size may be None, in which case some precision is simply lost
+                    else:  # batch_size may be None, in which case some precision is simply lost
+                        logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                       "it may cause missing some samples when reload.")
                         num_consumed_batches = sampler_states['num_consumed_samples']
                 sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
                 assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
-            states['sampler_states'] = sampler_states
+            else:
+                if dataloader_args.batch_size is not None:
+                    sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+                                                             * num_consumed_batches
+                else:
+                    logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                   "it may cause missing some samples when reload.")
         else:
             raise RuntimeError(
                 "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
+        states['sampler_states'] = sampler_states

         # 2. save the model's state;
         if should_save_model:
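A worked example of the bookkeeping in the new else-branch, with hypothetical numbers:

    num_consumed_batches = 2   # batches already iterated in this epoch
    batch_size = 4             # dataloader_args.batch_size
    num_replicas = 2           # sampler.num_replicas, the data-parallel world size

    # every rank draws batch_size samples per batch, so globally:
    num_consumed_samples = num_replicas * batch_size * num_consumed_batches
    assert num_consumed_samples == 16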
@@ -278,6 +287,12 @@

         logger.debug("Save optimizer state dict.")
         states["optimizers_state_dict"] = optimizers_state_dict
+
+        # 4. save the fp16 state
+        if not isinstance(self.grad_scaler, DummyGradScaler):
+            grad_scaler_state_dict = self.grad_scaler.state_dict()
+            states['grad_scaler_state_dict'] = grad_scaler_state_dict
+
         paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))

     def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
@@ -285,7 +300,7 @@
         states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))

         # 1. load the optimizers' state;
-        optimizers_state_dict = states["optimizers_state_dict"]
+        optimizers_state_dict = states.pop("optimizers_state_dict")
         for i in range(len(self.optimizers)):
             optimizer: Optimizer = self.optimizers[i]
             optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"])
@@ -295,18 +310,32 @@
         if should_load_model:
             self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict)
             if only_state_dict:
-                logger.debug("Load model state dict.")
+                logger.debug("Load model state dict...")
             else:
-                logger.debug("Load model.")
-
-        # 3. restore the sampler's state;
+                logger.debug("Load model...")
+
+        # 3. load the fp16 state;
+        if "grad_scaler_state_dict" in states:
+            grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
+            if isinstance(self.grad_scaler, DummyGradScaler):
+                self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
+                self.grad_scaler = _grad_scaler()
+                self.fp16 = True
+            self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+            logger.debug("Load grad_scaler state dict...")
+        elif not isinstance(self.grad_scaler, DummyGradScaler):
+            logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
+                           f"the training process may be unstable.")
+
+        # 4. restore the sampler's state;
         dataloader_args = self.get_dataloader_args(dataloader)
         if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
             sampler = dataloader_args.batch_sampler
         elif isinstance(dataloader_args.sampler, ReproducibleSampler):
             sampler = dataloader_args.sampler
         elif self.is_distributed():
-            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
+            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
+                               "`ReproducibleSampler`.")
         else:
             sampler = RandomBatchSampler(
                 batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
@@ -316,7 +345,7 @@
         sampler.load_state_dict(states["sampler_states"])
         states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

-        # 4. adjust trainer_state.batch_idx_in_epoch
+        # 5. adjust trainer_state.batch_idx_in_epoch
         # the sampler here is one like RandomSampler, not a batch_sampler;
         if not isinstance(sampler, ReproducibleBatchSampler):
             if dataloader_args.drop_last:
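Step 4 in save() and step 3 in load() round-trip the AMP loss-scaling state. A minimal sketch of the same round trip, assuming a CUDA build of paddle (state_dict()/load_state_dict() are exactly the GradScaler calls used above):

    import paddle

    scaler = paddle.amp.GradScaler()                           # save side
    states = {'grad_scaler_state_dict': scaler.state_dict()}

    restored = paddle.amp.GradScaler()                         # load side
    restored.load_state_dict(states.pop('grad_scaler_state_dict'))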


fastNLP/core/drivers/paddle_driver/utils.py  (+3, -4)

@@ -19,7 +19,7 @@ if _NEED_IMPORT_PADDLE:
     import paddle
     from paddle import nn
     from paddle.nn import Layer
-    from paddle.io import DataLoader, BatchSampler, Dataset
+    from paddle.io import DataLoader, BatchSampler
     from paddle.amp import auto_cast, GradScaler
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Layer
@@ -140,8 +140,7 @@ class DummyGradScaler:

 def _build_fp16_env(dummy=False):
     if dummy:
-        auto_cast = ExitStack
-        GradScaler = DummyGradScaler
+        return ExitStack, DummyGradScaler
     else:
         if not paddle.device.is_compiled_with_cuda():
             raise RuntimeError("No cuda")
@@ -150,7 +149,7 @@
                 "NOTE: your device does NOT support faster training with fp16, "
                 "please switch to FP32 which is likely to be faster"
             )
-    return auto_cast, GradScaler
+        return auto_cast, GradScaler

 def find_free_ports(num):
     def __free_port():
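Why ExitStack works as the dummy autocast: an empty ExitStack() is a context manager that does nothing, so `with auto_cast():` stays valid whether fp16 is real or disabled. A usage sketch of the dummy path, relying only on this module's own _build_fp16_env and DummyGradScaler:

    auto_cast, scaler_cls = _build_fp16_env(dummy=True)  # -> (ExitStack, DummyGradScaler)
    scaler = scaler_cls()
    with auto_cast():      # instantiating ExitStack() yields a no-op context
        pass               # the forward pass would run here, unscaled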


fastNLP/core/samplers/reproducible_batch_sampler.py  (+11, -11)

@@ -19,7 +19,7 @@ from abc import abstractmethod

 class ReproducibleBatchSampler:
     def __init__(self, **kwargs):
-        pass
+        self.num_replicas = 1

     @abstractmethod
     def set_distributed(self, num_replicas, rank, pad=True):
@@ -53,14 +53,6 @@
     def batch_idx_in_epoch(self):
         raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")

-    @property
-    def num_replicas(self):
-        return self._num_replicas
-
-    @num_replicas.setter
-    def num_replicas(self, value):
-        self._num_replicas = value
-

 class RandomBatchSampler(ReproducibleBatchSampler):
     # the values of these two parameters should be fetched via the driver's get_dataloader_args function;
@@ -322,7 +314,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         if len(batches[-1])==0:
             batches.pop(-1)

-        assert len(list(chain(*batches))) == self.num_left_samples
+        assert sum(map(len, batches)) == self.num_left_samples

         if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
             batches = batches[:-1]
@@ -419,4 +411,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         self.old_num_replicas = states['num_replicas']

     def set_epoch(self, epoch):
-        self.epoch = epoch
+        self.epoch = epoch
+
+    @property
+    def batch_idx_in_epoch(self):
+        if self.drop_last:
+            return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size
+        else:
+            return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \
+                   (len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
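A worked check of the new property with hypothetical numbers (10 samples, batch_size 4, 4 samples already consumed, drop_last=False):

    dataset_len, batch_size, consumed = 10, 4, 4
    total = (dataset_len + batch_size - 1) // batch_size                 # 3 batches per epoch
    remaining = (dataset_len - consumed + batch_size - 1) // batch_size  # 2 batches still to come
    assert total - remaining == 1                                        # exactly 1 batch already consumed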

fastNLP/core/samplers/reproducible_sampler.py  (+2, -8)

@@ -20,6 +20,8 @@ class ReproducibleSampler:
     or a batch_sampler; note that no attribute initialized in init may begin with an underscore, and every attribute not set in init must begin with an underscore.

     """
+    def __init__(self, **kwargs):
+        self.num_replicas = 1

     def set_distributed(self, num_replicas, rank, pad=True):
         raise NotImplementedError("Each specific sampler should implement its own `set_distributed` method.")
@@ -47,14 +49,6 @@
     def set_epoch(self, epoch):
         pass

-    @property
-    def num_repliacs(self):
-        return self._num_replicas
-
-    @num_repliacs.setter
-    def num_repliacs(self, value):
-        self._num_replicas = value
-

 class RandomSampler(ReproducibleSampler):
     def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):


tests/core/drivers/paddle_driver/test_fleet.py  (+412, -27)

@@ -1,6 +1,6 @@
-from dataclasses import replace
 import pytest
 import os
+from pathlib import Path

 os.environ["FASTNLP_BACKEND"] = "paddle"
 from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
@@ -12,19 +12,22 @@ from fastNLP.core.samplers import (
     UnrepeatedSequentialSampler,
 )
 from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
-from tests.helpers.datasets.paddle_data import PaddleNormalDataset
+from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
 from tests.helpers.utils import magic_argv_env_context
+from fastNLP.core import synchronize_safe_rm

 import paddle
 import paddle.distributed as dist
 from paddle.io import DataLoader, BatchSampler

-def generate_driver(num_labels, feature_dimension):
+def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
     paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
     paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
     driver = PaddleFleetDriver(
         model=paddle_model,
-        parallel_device=[0,1],
+        parallel_device=device,
+        fp16=fp16,
+        output_from_new_proc=output_from_new_proc
     )
     driver.set_optimizers(paddle_opt)
     driver.setup()
@@ -33,7 +36,7 @@ def generate_driver(num_labels, feature_dimension):

 ############################################################################
 #
-# 测试PaddleFleetDriver的一些函数
+# 测试 PaddleFleetDriver 的一些函数
 #
 ############################################################################

@@ -46,6 +49,19 @@ class TestFleetDriverFunction:
     def setup_class(cls):
         cls.driver = generate_driver(10, 10)

+    @magic_argv_env_context
+    def test_multi_drivers(self):
+        """
+        Test the case where several PaddleFleetDriver instances are used.
+        """
+        driver2 = generate_driver(20, 10)
+
+        with pytest.raises(RuntimeError):
+            # a different device setting should raise an error
+            driver3 = generate_driver(20, 3, device=[0,2])
+
+        dist.barrier()
+
     @magic_argv_env_context
     def test_move_data_to_device(self):
         """
@@ -106,10 +122,11 @@ class TestSetDistReproDataloader:

     @classmethod
     def setup_class(cls):
-        cls.driver = generate_driver(10, 10)
+        cls.device = [0, 1]
+        cls.driver = generate_driver(10, 10, device=cls.device)

     def setup_method(self):
-        self.dataset = PaddleNormalDataset(20)
+        self.dataset = PaddleNormalDataset(40)

     """
     Cases where the `dist` argument passed in is a concrete ReproducibleSampler or ReproducibleBatchSampler
@@ -118,9 +135,10 @@ class TestSetDistReproDataloader:

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
+    def test_with_dist_batch_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when dist is a BucketedBatchSampler.
+        The batch_sampler should be replaced by the BucketedBatchSampler given by dist.
         """
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
         batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
@@ -130,14 +148,16 @@ class TestSetDistReproDataloader:
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
         assert replaced_loader.batch_sampler is batch_sampler
         self.check_distributed_sampler(replaced_loader.batch_sampler)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
         dist.barrier()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
+    def test_with_dist_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when dist is a RandomSampler.
+        The batch_sampler.sampler should be replaced by the RandomSampler given by dist.
         """
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
         sampler = RandomSampler(self.dataset, shuffle=shuffle)
@@ -150,6 +170,7 @@
         assert replaced_loader.batch_sampler.sampler is sampler
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

         dist.barrier()
@@ -161,9 +182,10 @@
     """

     @magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self):
+    def test_with_dist_none_reproducible_true(self):
         """
         Test the behavior of set_dist_repro_dataloader when dist is None and reproducible is True.
+        When the user has initialized the distributed environment outside the driver, fastNLP does not support checkpoint retraining, and an error should be raised.
         """
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
         with pytest.raises(RuntimeError):
@@ -173,11 +195,14 @@
         dist.barrier()

     @magic_argv_env_context
+    # @pytest.mark.parametrize("shuffle", ([True, False]))
     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
+    def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader has a BucketedBatchSampler.
+        The incoming dataloader's batch_sampler should already have had set_distributed called on it; a new dataloader is produced
+        whose batch_sampler is the same as the original dataloader's.
         """
         dataloader = DataLoader(
             self.dataset,
@@ -194,16 +219,19 @@
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
         assert replaced_loader.batch_sampler.batch_size == 4
         self.check_distributed_sampler(dataloader.batch_sampler)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

         dist.barrier()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle):
+    def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader has a RandomSampler.
+        The incoming dataloader's batch_sampler.sampler should already have had set_distributed called on it; a new dataloader is
+        produced whose batch_sampler.sampler is the same as the original dataloader's.
         """
-        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
+        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
         batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
         batch_sampler.sampler.set_distributed(
             num_replicas=self.driver.world_size,
@@ -220,16 +248,19 @@
         assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
         assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
         assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
-        assert replaced_loader.batch_sampler.batch_size == 2
+        assert replaced_loader.batch_sampler.batch_size == 4
         assert replaced_loader.batch_sampler.drop_last == False
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
         dist.barrier()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
+    def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader is an ordinary one.
+        The original dataloader is returned directly, without any processing.
         """
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
@@ -244,10 +275,11 @@

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
+    def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler is a ReproducibleBatchSampler.
+        A new dataloader should be returned whose batch_sampler is the same as the original dataloader's, with the distributed attributes set correctly.
         """
         dataloader = DataLoader(
             dataset=self.dataset,
@@ -265,12 +297,14 @@

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
+    def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler.sampler is a ReproducibleSampler.
+        A new dataloader should be returned whose batch_sampler.sampler is the same as the original dataloader's, with the distributed
+        attributes set correctly.
         """
-        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
+        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4, shuffle=shuffle)
         batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
         dataloader = DataLoader(
             self.dataset,
@@ -282,16 +316,18 @@
         assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
         assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
         assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
-        assert replaced_loader.batch_sampler.batch_size == 2
+        assert replaced_loader.batch_sampler.batch_size == 4
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle):
+    def test_with_dist_dist_dataloader_normal(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when dist is 'dist' and the dataloader is an ordinary one.
+        A new dataloader should be returned with its batch_sampler.sampler replaced by a RandomSampler, and the distributed
+        attributes set correctly.
         """
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
@@ -302,6 +338,7 @@
         assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
+        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()

     """
@@ -311,12 +348,14 @@

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
+    def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler is a ReproducibleSampler.
+        A new dataloader should be returned with the original sampler replaced by an UnrepeatedRandomSampler, and the distributed
+        attributes set correctly.
         """
-        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
+        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
         batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
         dataloader = DataLoader(
             self.dataset,
@@ -328,19 +367,20 @@
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
         assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
         assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
-        assert replaced_loader.batch_sampler.batch_size == 2
+        assert replaced_loader.batch_sampler.batch_size == 4
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle):
+    def test_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler is an UnrepeatedSampler.
+        A new dataloader should be returned, with the original sampler re-instantiated.
         """
-        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
+        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
         batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle)
         dataloader = DataLoader(
             self.dataset,
@@ -353,16 +393,18 @@
         assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
         assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
         assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
-        assert replaced_loader.batch_sampler.batch_size == 2
+        assert replaced_loader.batch_sampler.batch_size == 4
         assert replaced_loader.drop_last == dataloader.drop_last
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle):
+    def test_with_dist_unrepeat_dataloader_normal(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and the dataloader is an ordinary one.
+        A new dataloader should be returned with the sampler replaced by an UnrepeatedSequentialSampler, and the distributed
+        attributes set correctly.
         """
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
@@ -385,3 +427,346 @@
         if not isinstance(sampler, UnrepeatedSampler):
             assert sampler.pad == True
+
+    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
+        """
+        Check that set_dist_repro_dataloader behaves correctly in the multi-card case.
+        """
+        # iterate two batches
+        num_replicas = len(self.device)
+        num_consumed_batches = 2
+        already_seen_idx = set()
+        for idx, batch in enumerate(replaced_loader):
+            if idx >= num_consumed_batches:
+                break
+            already_seen_idx.update(batch)
+        dist.barrier()
+        if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
+            sampler_states = replaced_loader.batch_sampler.state_dict()
+        else:
+            sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
+
+        # after reloading, the remaining content should be produced; for PaddleNormalDataset the indices, sorted, should form a range
+        left_idxes = set()
+        if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
+            batch_size = replaced_loader.batch_sampler.batch_size
+            sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
+            # rebuild the dataloader
+            new_loader = DataLoader(
+                dataset=replaced_loader.dataset,
+                batch_sampler=BucketedBatchSampler(
+                    replaced_loader.dataset,
+                    length=replaced_loader.dataset._data,
+                    batch_size=batch_size,
+                    shuffle=shuffle,
+                )
+            )
+            new_loader.batch_sampler.set_distributed(
+                num_replicas=self.driver.world_size,
+                rank=self.driver.global_rank,
+                pad=True
+            )
+            new_loader.batch_sampler.load_state_dict(sampler_states)
+        else:
+            batch_size = replaced_loader.batch_sampler.batch_size
+            sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
+            # rebuild the dataloader
+            batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
+            batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
+            batch_sampler.sampler.set_distributed(
+                num_replicas=self.driver.world_size,
+                rank=self.driver.global_rank
+            )
+            new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
+            new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
+        for idx, batch in enumerate(new_loader):
+            left_idxes.update(batch)
+
+        assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas
+        assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas
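The two closing assertions encode the resume invariant: indices seen before the checkpoint and indices produced after reloading are disjoint and together cover exactly one rank's share of the dataset. With hypothetical numbers (40 samples, 2 replicas, batch_size 4, 2 consumed batches):

    per_rank = 40 // 2          # 20 samples per rank after padding
    seen = set(range(8))        # 2 batches * batch_size 4 already consumed
    left = set(range(8, 20))    # what the reloaded sampler should still yield
    assert len(seen) + len(left) == per_rank
    assert len(seen | left) == per_rank  # disjoint and complete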


+############################################################################
+#
+# test save/load related functionality
+#
+############################################################################
+class TestSaveLoad:
+    """
+    Test the behavior of the save/load related functions in the multi-card case.
+    """
+
+    @classmethod
+    def setup_class(cls):
+        # setting up anywhere other than here raises an error
+        cls.driver = generate_driver(10, 10)

+    def setup_method(self):
+        self.dataset = PaddleRandomMaxDataset(20, 10)
+
+    @magic_argv_env_context
+    @pytest.mark.parametrize("only_state_dict", ([True, False]))
+    def test_save_and_load_model(self, only_state_dict):
+        """
+        Test the save_model and load_model functions.
+        """
+        try:
+            path = "model"
+
+            dataloader = DataLoader(self.dataset, batch_size=2)
+            self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)
+
+            if only_state_dict:
+                self.driver1.save_model(path, only_state_dict)
+            else:
+                self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))])
+
+            # synchronize
+            dist.barrier()
+            self.driver2.load_model(path, only_state_dict)
+
+            for idx, batch in enumerate(dataloader):
+                batch = self.driver1.move_data_to_device(batch)
+                res1 = self.driver1.model(
+                    batch,
+                    fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
+                    # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model
+                    fastnlp_signature_fn=None,
+                    wo_auto_param_call=False,
+                )
+                res2 = self.driver2.model(
+                    batch,
+                    fastnlp_fn=self.driver2.model._layers.model.evaluate_step,
+                    fastnlp_signature_fn=None,
+                    wo_auto_param_call=False,
+                )
+
+                assert paddle.equal_all(res1["pred"], res2["pred"])
+        finally:
+            if only_state_dict:
+                synchronize_safe_rm(path)
+            else:
+                synchronize_safe_rm(path + ".pdiparams")
+                synchronize_safe_rm(path + ".pdiparams.info")
+                synchronize_safe_rm(path + ".pdmodel")

+    @magic_argv_env_context
+    @pytest.mark.parametrize("only_state_dict", ([True, False]))
+    @pytest.mark.parametrize("fp16", ([True, False]))
+    @pytest.mark.parametrize("device", ([[0,1]]))
+    def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16):
+        """
+        Test the save and load functions, mainly the case where the dataloader's sampler has been replaced.
+        """
+
+        try:
+            path = "model.ckp"
+            num_replicas = len(device)
+
+            self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
+                                         generate_driver(10, 10, device=device, fp16=False)
+            dataloader = DataLoader(
+                dataset=self.dataset,
+                batch_sampler=BucketedBatchSampler(
+                    self.dataset,
+                    length=[10 for i in range(len(self.dataset))],
+                    batch_size=4,
+                )
+            )
+            dataloader.batch_sampler.set_distributed(
+                num_replicas=self.driver1.world_size,
+                rank=self.driver1.global_rank,
+                pad=True
+            )
+            num_consumed_batches = 2
+
+            already_seen_x_set = set()
+            already_seen_y_set = set()
+            for idx, batch in enumerate(dataloader):
+                if idx >= num_consumed_batches:
+                    break
+                already_seen_x_set.update(batch["x"])
+                already_seen_y_set.update(batch["y"])
+
+            # synchronize
+            dist.barrier()
+
+            # save the state
+            sampler_states = dataloader.batch_sampler.state_dict()
+            save_states = {"num_consumed_batches": num_consumed_batches}
+            if only_state_dict:
+                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+            else:
+                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+            # load
+            # change batch_size
+            dataloader = DataLoader(
+                dataset=self.dataset,
+                batch_sampler=BucketedBatchSampler(
+                    self.dataset,
+                    length=[10 for i in range(len(self.dataset))],
+                    batch_size=4,
+                )
+            )
+            dataloader.batch_sampler.set_distributed(
+                num_replicas=self.driver2.world_size,
+                rank=self.driver2.global_rank,
+                pad=True
+            )
+            load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            replaced_loader = load_states.pop("dataloader")
+            # 1. check the optimizer's state
+            # TODO the optimizer's state_dict is always empty
+
+            # 2. check that the batch_sampler was correctly loaded and replaced
+            assert not (replaced_loader is dataloader)
+            assert replaced_loader.batch_sampler is dataloader.batch_sampler
+            assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
+            assert replaced_loader.batch_sampler.seed == sampler_states["seed"]
+            assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas
+
+            # 3. check that the fp16 state was loaded
+            if fp16:
+                assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
+
+            # 4. check that the model parameters are correct
+            # 5. check batch_idx
+            start_batch = load_states.pop('batch_idx_in_epoch')
+            assert start_batch == 2 * num_consumed_batches
+            left_x_batches = set()
+            left_y_batches = set()
+            for idx, batch in enumerate(replaced_loader):
+
+                left_x_batches.update(batch["x"])
+                left_y_batches.update(batch["y"])
+                res1 = self.driver1.model(
+                    batch,
+                    fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
+                    # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model
+                    fastnlp_signature_fn=None,
+                    wo_auto_param_call=False,
+                )
+                res2 = self.driver2.model(
+                    batch,
+                    fastnlp_fn=self.driver2.model._layers.model.evaluate_step,
+                    fastnlp_signature_fn=None,
+                    wo_auto_param_call=False,
+                )
+                assert paddle.equal_all(res1["pred"], res2["pred"])
+
+            assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas
+            assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas
+            assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
+            assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
+        finally:
+            synchronize_safe_rm(path)

+    @magic_argv_env_context
+    @pytest.mark.parametrize("only_state_dict", ([True, False]))
+    @pytest.mark.parametrize("fp16", ([True, False]))
+    @pytest.mark.parametrize("device", ([[0,1]]))
+    def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16):
+        """
+        Test the save and load functions, mainly the case where the dataloader's batch_sampler has been replaced.
+        """
+
+        try:
+            path = "model.ckp"
+
+            num_replicas = len(device)
+
+            self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
+            self.driver2 = generate_driver(10, 10, device=device, fp16=False)
+            batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
+            batch_sampler.sampler = RandomSampler(self.dataset, True)
+            batch_sampler.sampler.set_distributed(
+                num_replicas=self.driver1.world_size,
+                rank=self.driver1.global_rank,
+                pad=True
+            )
+            dataloader = DataLoader(
+                self.dataset,
+                batch_sampler=batch_sampler
+            )
+            num_consumed_batches = 2
+
+            already_seen_x_set = set()
+            already_seen_y_set = set()
+            for idx, batch in enumerate(dataloader):
+                if idx >= num_consumed_batches:
+                    break
+                already_seen_x_set.update(batch["x"])
+                already_seen_y_set.update(batch["y"])
+
+            # synchronize
+            dist.barrier()
+
+            # save the state
+            sampler_states = dataloader.batch_sampler.sampler.state_dict()
+            save_states = {"num_consumed_batches": num_consumed_batches}
+            if only_state_dict:
+                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+            else:
+                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+            # load
+            # change batch_size
+            batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
+            batch_sampler.sampler = RandomSampler(self.dataset, True)
+            batch_sampler.sampler.set_distributed(
+                num_replicas=self.driver2.world_size,
+                rank=self.driver2.global_rank,
+                pad=True
+            )
+            dataloader = DataLoader(
+                self.dataset,
+                batch_sampler=batch_sampler
+            )
+            load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            replaced_loader = load_states.pop("dataloader")
+
+            # 1. check the optimizer's state
+            # TODO the optimizer's state_dict is always empty
+
+            # 2. check that the sampler was correctly loaded and replaced
+            assert not (replaced_loader is dataloader)
+            assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
+            assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
+            assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
+            assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas
+            assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
+            assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
+            # 3. check that the fp16 state was loaded
+            if fp16:
+                assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
+
+            # 4. check that the model parameters are correct
+            # 5. check batch_idx
+            start_batch = load_states.pop('batch_idx_in_epoch')
+            assert start_batch == 2 * num_consumed_batches
+            left_x_batches = set()
+            left_y_batches = set()
+            for idx, batch in enumerate(replaced_loader):
+
+                left_x_batches.update(batch["x"])
+                left_y_batches.update(batch["y"])
+                res1 = self.driver1.model(
+                    batch,
+                    fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
+                    # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model
+                    fastnlp_signature_fn=None,
+                    wo_auto_param_call=False,
+                )
+                res2 = self.driver2.model(
+                    batch,
+                    fastnlp_fn=self.driver2.model._layers.model.evaluate_step,
+                    fastnlp_signature_fn=None,
+                    wo_auto_param_call=False,
+                )
+                assert paddle.equal_all(res1["pred"], res2["pred"])
+
+            assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas
+            assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas
+            assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
+            assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
+
+        finally:
+            synchronize_safe_rm(path)

tests/core/drivers/paddle_driver/test_single_device.py  (+40, -32)

@@ -1,4 +1,3 @@
-from dataclasses import replace
 import os
 from re import S
 os.environ["FASTNLP_BACKEND"] = "paddle"
@@ -349,7 +348,7 @@
 #
 ############################################################################

-class TestSetDistReproDataloder:
+class TestSetDistReproDataloader:
     """
     A class dedicated to testing the set_dist_repro_dataloader function
     """
@@ -358,7 +357,7 @@ class TestSetDistReproDataloder:
         model = PaddleNormalModel_Classification_1(10, 32)
         self.driver = PaddleSingleDriver(model, device="cpu")

-    def test_set_dist_repro_dataloader_with_reproducible_false(self):
+    def test_with_reproducible_false(self):
         """
         Test the behavior of set_dist_repro_dataloader when the `reproducible` argument is False.
         When dist is a string, the original dataloader should be returned.
         """
@@ -369,7 +368,7 @@ class TestSetDistReproDataloder:
         assert replaced_loader is dataloader

     @pytest.mark.parametrize("shuffle", [True, False])
-    def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle):
+    def test_with_reproducible_true(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when the `reproducible` argument is True.
         When dist is a string, a new dataloader should be returned; and if the original sampler is paddle.io.RandomSampler(shuffle=True),
@@ -394,7 +393,7 @@ class TestSetDistReproDataloder:
         self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
+    def test_with_dist_batch_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when the dist argument is not a string and dist is a ReproducibleBatchSampler.
         A new dataloader should be returned, with batch_sampler replaced by the sampler given by dist.
         """
@@ -410,7 +409,7 @@ class TestSetDistReproDataloder:
         self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
+    def test_with_dist_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when the dist argument is not a string.
         A new dataloader should be returned, with batch_sampler.sampler replaced by the sampler given by dist.
         """
@@ -429,7 +428,7 @@ class TestSetDistReproDataloder:
         self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle):
+    def test_with_dataloader_reproducible_batch_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when the dataloader already supports checkpoint retraining.
         A new dataloader should be returned, with all other settings unchanged.
         """
@@ -453,7 +452,7 @@ class TestSetDistReproDataloder:
         self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

     @pytest.mark.parametrize("shuffle", ([True, False]))
-    def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle):
+    def test_with_dataloader_reproducible_sampler(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when the dataloader already supports checkpoint retraining.
         A new dataloader should be returned, with all other settings unchanged.
         """
@@ -498,10 +497,7 @@ class TestSetDistReproDataloder:
         left_idxes = set()
         if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
             batch_size = replaced_loader.batch_sampler.batch_size
-            if num_consumed_samples_array is not None:
-                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
-            else:
-                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
+            sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
             # rebuild the dataloader
             new_loader = DataLoader(
                 dataset=replaced_loader.dataset,
@@ -514,11 +510,8 @@ class TestSetDistReproDataloder:
             new_loader.batch_sampler.load_state_dict(sampler_states)
         else:
             batch_size = replaced_loader.batch_sampler.batch_size
-            num_consumed_batches = num_consumed_batches * batch_size
-            if num_consumed_samples_array is not None:
-                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
-            else:
-                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
+            num_consumed_samples = num_consumed_batches * batch_size
+            sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
             # rebuild the dataloader
             batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
             batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
@@ -536,13 +529,13 @@ class TestSetDistReproDataloder:
 #
 ############################################################################

-def generate_random_driver(features, labels):
+def generate_random_driver(features, labels, fp16=False, device="cpu"):
     """
     Generate a driver.
     """
     model = PaddleNormalModel_Classification_1(labels, features)
     opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01)
-    driver = PaddleSingleDriver(model, device="cpu")
+    driver = PaddleSingleDriver(model, device=device, fp16=fp16)
     driver.set_optimizers(opt)
     driver.setup()


@@ -550,8 +543,8 @@ def generate_random_driver(features, labels):

 @pytest.fixture
 def prepare_test_save_load():
-    dataset = PaddleRandomMaxDataset(320, 10)
-    dataloader = DataLoader(dataset, batch_size=32)
+    dataset = PaddleRandomMaxDataset(40, 10)
+    dataloader = DataLoader(dataset, batch_size=4)
     driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
     return driver1, driver2, dataloader


@@ -584,21 +577,23 @@ def test_save_and_load_model(prepare_test_save_load, only_state_dict):
         rank_zero_rm(path + ".pdiparams.info")
         rank_zero_rm(path + ".pdmodel")

-@pytest.mark.parametrize("only_state_dict", ([True, False]))
-def test_save_and_load_with_randombatchsampler(only_state_dict):
+# @pytest.mark.parametrize("only_state_dict", ([True, False]))
+@pytest.mark.parametrize("only_state_dict", ([True]))
+@pytest.mark.parametrize("fp16", ([True, False]))
+def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
     """
     Test the save and load functions, mainly the case where the dataloader's sampler has been replaced.
     """

     try:
         path = "model.ckp"

-        driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
         dataset = PaddleRandomMaxDataset(40, 10)
         dataloader = DataLoader(
             dataset=dataset,
             batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
         )
+        driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")

         num_consumed_batches = 2

        already_seen_x_set = set()
@@ -633,8 +628,13 @@ def test_save_and_load_with_randombatchsampler(only_state_dict):
         assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
         assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4

-        # 3. check that the model parameters are correct
-        # 4. check batch_idx
+        # 3. check that the fp16 state was loaded
+        if fp16:
+            assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
+
+        # 4. check that the model parameters are correct
+        # 5. check batch_idx
         start_batch = load_states.pop('batch_idx_in_epoch')
         assert start_batch == 2 * num_consumed_batches
         left_x_batches = set()
@@ -654,8 +654,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict):
     finally:
         rank_zero_rm(path)

-@pytest.mark.parametrize("only_state_dict", ([True, False]))
-def test_save_and_load_with_randomsampler(only_state_dict):
+# @pytest.mark.parametrize("only_state_dict", ([True, False]))
+# TODO iterating together with paddle.jit.save triggers a segmentation fault; commenting out either part avoids the error,
+# but it cannot be reproduced in a standalone file
+@pytest.mark.parametrize("only_state_dict", ([True]))
+@pytest.mark.parametrize("fp16", ([True, False]))
+def test_save_and_load_with_randomsampler(only_state_dict, fp16):
     """
     Test the save and load functions, mainly the case where the dataloader's batch_sampler has been replaced.
     """
@@ -663,7 +667,7 @@ def test_save_and_load_with_randomsampler(only_state_dict):
     try:
         path = "model.ckp"

-        driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
+        driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")
         dataset = PaddleRandomMaxDataset(40, 10)
         batch_sampler = BatchSampler(dataset=dataset, batch_size=4)
         batch_sampler.sampler = RandomSampler(dataset, True)
@@ -711,8 +715,12 @@ def test_save_and_load_with_randomsampler(only_state_dict):
         assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
         assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]

-        # 3. check that the model parameters are correct
-        # 4. check batch_idx
+        # 3. check that the fp16 state was loaded
+        if fp16:
+            assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
+
+        # 4. check that the model parameters are correct
+        # 5. check batch_idx
         start_batch = load_states.pop('batch_idx_in_epoch')
         assert start_batch == 2 * num_consumed_batches
         left_x_batches = set()

