
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh_cc 3 years ago
commit 368c17fe73
8 changed files with 540 additions and 100 deletions:

  1. fastNLP/core/drivers/paddle_driver/dist_utils.py (+22, -0)
  2. fastNLP/core/drivers/paddle_driver/fleet.py (+8, -5)
  3. fastNLP/core/drivers/paddle_driver/paddle_driver.py (+42, -13)
  4. fastNLP/core/drivers/paddle_driver/utils.py (+3, -4)
  5. fastNLP/core/samplers/reproducible_batch_sampler.py (+11, -11)
  6. fastNLP/core/samplers/reproducible_sampler.py (+2, -8)
  7. tests/core/drivers/paddle_driver/test_fleet.py (+412, -27)
  8. tests/core/drivers/paddle_driver/test_single_device.py (+40, -32)

fastNLP/core/drivers/paddle_driver/dist_utils.py (+22, -0)

@@ -1,4 +1,5 @@
import io
import os
import pickle
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
@@ -7,6 +8,7 @@ from typing import Any, List
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.envs.env import FASTNLP_NO_SYNC
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed as dist
@@ -83,6 +85,14 @@ def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFA
>>> output
['foo', 12, {1: 2}]
"""
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]

if dist.get_rank() == dst:
object_gather_list = [None for _ in range(dist.get_world_size(group))]
else:
object_gather_list = None

if group is None:
group = DEFAULT_TORCH_GROUP

@@ -207,6 +217,9 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP)
:param group:
:return: the result is [obj0, obj1, ...], where obj_i is the obj from rank i.
"""
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]

if group is None:
group = DEFAULT_TORCH_GROUP
if isinstance(obj, torch.Tensor):
@@ -233,6 +246,12 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR
:param group:
:return:
"""
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
if src == dist.get_rank(group):
return obj
else:
return None

if group is None:
group = DEFAULT_TORCH_GROUP
cur_rank = dist.get_rank(group)
@@ -328,6 +347,9 @@ def all_gather_object(object_list, obj, group=None):
>>> output
['foo', 12, {1: 2}]
"""
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]

if dist.distributed_c10d._rank_not_in_group(group):
return
if _TORCH_GREATER_EQUAL_1_8:
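
The guard added across these hunks is uniform: every collective helper first consults the FASTNLP_NO_SYNC environment variable and returns early when it equals 2, behaving as if the world size were 1. A minimal sketch of the pattern, with an illustrative function name not taken from the diff:

import os

def no_sync_guarded_gather(obj):
    # When FASTNLP_NO_SYNC == 2, skip the collective entirely and
    # return the local object as a single-element list.
    if int(os.environ.get("FASTNLP_NO_SYNC", "0")) == 2:
        return [obj]
    # ... otherwise fall through to the real gather call
    raise NotImplementedError("real dist.all_gather_object call elided in this sketch")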


fastNLP/core/drivers/paddle_driver/fleet.py (+8, -5)

@@ -1,6 +1,5 @@
import os
import shutil
from functools import partial
from typing import List, Union, Optional, Dict, Tuple, Callable

from .paddle_driver import PaddleDriver
@@ -30,7 +29,7 @@ from fastNLP.core.samplers import (
re_instantiate_sampler,
conversion_between_reproducible_and_unrepeated_sampler,
)
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
@@ -38,7 +37,6 @@ if _NEED_IMPORT_PADDLE:
from paddle import DataParallel
import paddle.distributed.fleet as fleet
import paddle.distributed as paddledist
from paddle.io import BatchSampler
from paddle.optimizer import Optimizer
from paddle.fluid.reader import _DatasetKind
from paddle.fluid.dygraph import parallel_helper
@@ -236,7 +234,8 @@ class PaddleFleetDriver(PaddleDriver):
self.global_rank = paddledist.get_rank()

def barrier(self):
paddledist.barrier()
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # the barrier is only actually executed when FASTNLP_NO_SYNC is less than 1
paddledist.barrier()

def configure_fleet(self):
if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
@@ -305,7 +304,7 @@ class PaddleFleetDriver(PaddleDriver):
raise RuntimeError(f"There is no `{fn}` method in your model.")
else:
if hasattr(model, fn):
logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements "
logger.warning("Notice your model is a `DataParallel` model. And your model also implements "
f"the `{fn}` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
elif fn not in {"train_step", "evaluate_step"}:
@@ -453,6 +452,8 @@ class PaddleFleetDriver(PaddleDriver):
the received object; the source rank returns what it broadcast; a rank that is neither sender nor receiver returns None.
"""
return
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # return directly when FASTNLP_NO_SYNC == 2.
return
return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)

def all_gather(self, obj, group) -> List:
@@ -479,4 +480,6 @@ class PaddleFleetDriver(PaddleDriver):
:return:
"""
return
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # when FASTNLP_NO_SYNC == 2, skip the gather and return the local object
return [obj]
return fastnlp_paddle_all_gather(obj, group=group)
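
Combined with the dist_utils changes, FASTNLP_NO_SYNC now acts as a three-level switch: below 1 all synchronization runs as usual, at 1 barriers are skipped, and at 2 broadcast and all_gather are skipped as well. A small sketch of the barrier gate, assuming a stand-in barrier_fn rather than paddledist.barrier:

import os

def barrier_if_sync_enabled(barrier_fn):
    # Mirrors PaddleFleetDriver.barrier above: the barrier runs only
    # when FASTNLP_NO_SYNC is below 1; values 1 and 2 disable it.
    if int(os.environ.get("FASTNLP_NO_SYNC", 0)) < 1:
        barrier_fn()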

fastNLP/core/drivers/paddle_driver/paddle_driver.py (+42, -13)

@@ -7,7 +7,7 @@ from dataclasses import dataclass

import numpy as np

from .utils import _build_fp16_env, optimizer_state_to_device
from .utils import _build_fp16_env, optimizer_state_to_device, DummyGradScaler
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device
@@ -247,18 +247,27 @@ class PaddleDriver(Driver):
# it would otherwise count more consumption than actually happened.
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
if num_consumed_samples_array is not None:
if isinstance(sampler, ReproducibleSampler):
# for a sampler, the actual number of samples needs to be computed
try:
if isinstance(sampler, ReproducibleSampler): # for a sampler, the batch_size must be taken into account.
if dataloader_args.batch_size is not None:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
except: # batch_size may be None; in that case some precision is lost
else: # batch_size may be None; in that case some precision is lost
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
states['sampler_states'] = sampler_states
else:
if dataloader_args.batch_size is not None:
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
* num_consumed_batches
else:
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
else:
raise RuntimeError(
"The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
states['sampler_states'] = sampler_states

# 2. save the model state;
if should_save_model:
@@ -278,6 +287,12 @@ class PaddleDriver(Driver):

logger.debug("Save optimizer state dict.")
states["optimizers_state_dict"] = optimizers_state_dict

# 4. save the fp16 state
if not isinstance(self.grad_scaler, DummyGradScaler):
grad_scaler_state_dict = self.grad_scaler.state_dict()
states['grad_scaler_state_dict'] = grad_scaler_state_dict

paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))

def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
@@ -285,7 +300,7 @@ class PaddleDriver(Driver):
states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))

# 1. load the optimizers' state;
optimizers_state_dict = states["optimizers_state_dict"]
optimizers_state_dict = states.pop("optimizers_state_dict")
for i in range(len(self.optimizers)):
optimizer: Optimizer = self.optimizers[i]
optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"])
@@ -295,18 +310,32 @@ class PaddleDriver(Driver):
if should_load_model:
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict)
if only_state_dict:
logger.debug("Load model state dict.")
logger.debug("Load model state dict...")
else:
logger.debug("Load model.")

# 3. restore the sampler state;
logger.debug("Load model...")

# 3. load the fp16 state;
if "grad_scaler_state_dict" in states:
grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
if isinstance(self.grad_scaler, DummyGradScaler):
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
self.grad_scaler = _grad_scaler()
self.fp16 = True
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
logger.debug("Load grad_scaler state dict...")
elif not isinstance(self.grad_scaler, DummyGradScaler):
logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
f"the training process may be unstable.")

# 4. restore the sampler state;
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
sampler = dataloader_args.sampler
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
@@ -316,7 +345,7 @@ class PaddleDriver(Driver):
sampler.load_state_dict(states["sampler_states"])
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

# 4. adjust trainer_state.batch_idx_in_epoch
# 5. adjust trainer_state.batch_idx_in_epoch
# the sampler here is a RandomSampler-like sampler, not a batch_sampler;
if not isinstance(sampler, ReproducibleBatchSampler):
if dataloader_args.drop_last:
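
For the ReproducibleSampler branch added to `save` above, the bookkeeping reduces to one product: the globally consumed sample count is num_replicas × batch_size × num_consumed_batches. A standalone check of that arithmetic (names illustrative):

def consumed_samples(num_replicas, batch_size, num_consumed_batches):
    # e.g. 2 replicas * batch_size 4 * 2 consumed batches -> 16 samples
    return num_replicas * batch_size * num_consumed_batches

assert consumed_samples(2, 4, 2) == 16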


fastNLP/core/drivers/paddle_driver/utils.py (+3, -4)

@@ -19,7 +19,7 @@ if _NEED_IMPORT_PADDLE:
import paddle
from paddle import nn
from paddle.nn import Layer
from paddle.io import DataLoader, BatchSampler, Dataset
from paddle.io import DataLoader, BatchSampler
from paddle.amp import auto_cast, GradScaler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Layer
@@ -140,8 +140,7 @@ class DummyGradScaler:

def _build_fp16_env(dummy=False):
if dummy:
auto_cast = ExitStack
GradScaler = DummyGradScaler
return ExitStack, DummyGradScaler
else:
if not paddle.device.is_compiled_with_cuda():
raise RuntimeError("No cuda")
@@ -150,7 +149,7 @@ def _build_fp16_env(dummy=False):
"NOTE: your device does NOT support faster training with fp16, "
"please switch to FP32 which is likely to be faster"
)
return auto_cast, GradScaler
return auto_cast, GradScaler

def find_free_ports(num):
def __free_port():
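
The _build_fp16_env change is small but worth spelling out: in dummy mode it now returns ExitStack and DummyGradScaler directly instead of rebinding local names first; ExitStack serves as a no-op stand-in for paddle.amp.auto_cast. A reduced sketch, with the scaler trimmed to three methods for brevity:

from contextlib import ExitStack

class DummyGradScalerSketch:
    # No-op scaler surface: leaves the loss untouched and steps the
    # optimizer directly.
    def scale(self, loss):
        return loss
    def step(self, optimizer):
        optimizer.step()
    def update(self):
        pass

def build_fp16_env(dummy=False):
    if dummy:
        return ExitStack, DummyGradScalerSketch
    raise NotImplementedError("the real paddle.amp path is elided in this sketch")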


fastNLP/core/samplers/reproducible_batch_sampler.py (+11, -11)

@@ -19,7 +19,7 @@ from abc import abstractmethod

class ReproducibleBatchSampler:
def __init__(self, **kwargs):
pass
self.num_replicas = 1

@abstractmethod
def set_distributed(self, num_replicas, rank, pad=True):
@@ -53,14 +53,6 @@ class ReproducibleBatchSampler:
def batch_idx_in_epoch(self):
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")

@property
def num_replicas(self):
return self._num_replicas

@num_replicas.setter
def num_replicas(self, value):
self._num_replicas = value


class RandomBatchSampler(ReproducibleBatchSampler):
# the values of these two parameters should be obtained from the driver's get_dataloader_args function;
@@ -322,7 +314,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
if len(batches[-1])==0:
batches.pop(-1)

assert len(list(chain(*batches))) == self.num_left_samples
assert sum(map(len, batches)) == self.num_left_samples

if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
batches = batches[:-1]
@@ -419,4 +411,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
self.old_num_replicas = states['num_replicas']

def set_epoch(self, epoch):
self.epoch = epoch
self.epoch = epoch

@property
def batch_idx_in_epoch(self):
if self.drop_last:
return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size
else:
return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \
(len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
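
The new batch_idx_in_epoch property counts consumed batches as total batches minus the batches still producible from the remaining samples, using ceiling division when drop_last is False. A standalone check of the arithmetic, with values chosen for illustration:

def batch_idx_in_epoch(dataset_len, batch_size, num_consumed_samples, drop_last):
    if drop_last:
        return dataset_len // batch_size - (dataset_len - num_consumed_samples) // batch_size
    return (dataset_len + batch_size - 1) // batch_size \
        - (dataset_len - num_consumed_samples + batch_size - 1) // batch_size

# 10 samples, batch_size 4, 4 samples consumed:
assert batch_idx_in_epoch(10, 4, 4, drop_last=False) == 1  # 3 total batches, 2 left
assert batch_idx_in_epoch(10, 4, 4, drop_last=True) == 1   # 2 total batches, 1 left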

fastNLP/core/samplers/reproducible_sampler.py (+2, -8)

@@ -20,6 +20,8 @@ class ReproducibleSampler:
or batch_sampler; note that none of the variables initialized in __init__ may begin with an underscore, and every variable not set in __init__ must begin with an underscore.

"""
def __init__(self, **kwargs):
self.num_replicas = 1

def set_distributed(self, num_replicas, rank, pad=True):
raise NotImplementedError("Each specific sampler should implement its own `set_distributed` method.")
@@ -47,14 +49,6 @@ class ReproducibleSampler:
def set_epoch(self, epoch):
pass

@property
def num_repliacs(self):
return self._num_replicas

@num_repliacs.setter
def num_repliacs(self, value):
self._num_replicas = value


class RandomSampler(ReproducibleSampler):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
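
Both sampler base classes get the same simplification: the property pair (misspelled num_repliacs in the sampler case) is deleted in favor of a plain attribute initialized to 1 in __init__, which set_distributed implementations can then overwrite. In sketch form:

class ReproducibleSamplerSketch:
    def __init__(self, **kwargs):
        # Plain attribute instead of a property; 1 means "not distributed".
        self.num_replicas = 1

    def set_distributed(self, num_replicas, rank, pad=True):
        # Concrete samplers overwrite num_replicas here (illustrative body;
        # the real base-class method raises NotImplementedError above).
        self.num_replicas = num_replicas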


tests/core/drivers/paddle_driver/test_fleet.py (+412, -27)

@@ -1,6 +1,6 @@
from dataclasses import replace
import pytest
import os
from pathlib import Path

os.environ["FASTNLP_BACKEND"] = "paddle"
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
@@ -12,19 +12,22 @@ from fastNLP.core.samplers import (
UnrepeatedSequentialSampler,
)
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm

import paddle
import paddle.distributed as dist
from paddle.io import DataLoader, BatchSampler

def generate_driver(num_labels, feature_dimension):
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=[0,1],
parallel_device=device,
fp16=fp16,
output_from_new_proc=output_from_new_proc
)
driver.set_optimizers(paddle_opt)
driver.setup()
@@ -33,7 +36,7 @@ def generate_driver(num_labels, feature_dimension):

############################################################################
#
# Test some functions of PaddleFleetDriver
#
############################################################################

@@ -46,6 +49,19 @@ class TestFleetDriverFunction:
def setup_class(cls):
cls.driver = generate_driver(10, 10)

@magic_argv_env_context
def test_multi_drivers(self):
"""
Test the case where multiple PaddleFleetDriver instances are used.
"""
driver2 = generate_driver(20, 10)

with pytest.raises(RuntimeError):
# different device settings should raise an error
driver3 = generate_driver(20, 3, device=[0,2])

dist.barrier()

@magic_argv_env_context
def test_move_data_to_device(self):
"""
@@ -106,10 +122,11 @@ class TestSetDistReproDataloader:

@classmethod
def setup_class(cls):
cls.driver = generate_driver(10, 10)
cls.device = [0, 1]
cls.driver = generate_driver(10, 10, device=cls.device)

def setup_method(self):
self.dataset = PaddleNormalDataset(20)
self.dataset = PaddleNormalDataset(40)

"""
Cases where the `dist` argument passed in is a concrete ReproducibleSampler or ReproducibleBatchSampler
@@ -118,9 +135,10 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
def test_with_dist_batch_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is a BucketedBatchSampler.
In this case the batch_sampler should be replaced by the BucketedBatchSampler given by dist.
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
@@ -130,14 +148,16 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler is batch_sampler
self.check_distributed_sampler(replaced_loader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
def test_with_dist_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is a RandomSampler.
In this case batch_sampler.sampler should be replaced by the RandomSampler given by dist.
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
sampler = RandomSampler(self.dataset, shuffle=shuffle)
@@ -150,6 +170,7 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler is sampler
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

dist.barrier()
@@ -161,9 +182,10 @@ class TestSetDistReproDataloader:
"""

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self):
def test_with_dist_none_reproducible_true(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is None and reproducible is True.
When the user has initialized the distributed environment outside the driver, fastNLP does not support checkpoint retraining, so an error should be raised.
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
with pytest.raises(RuntimeError):
@@ -173,11 +195,14 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
# @pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the
dataloader has a BucketedBatchSampler.
The batch_sampler of the passed-in dataloader should already have had set_distributed applied; a new
dataloader is produced whose batch_sampler is the same as the original dataloader's.
"""
dataloader = DataLoader(
self.dataset,
@@ -194,16 +219,19 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler.batch_size == 4
self.check_distributed_sampler(dataloader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle):
def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader has a RandomSampler.
The batch_sampler.sampler of the passed-in dataloader should already have had set_distributed applied; a new
dataloader is produced whose batch_sampler.sampler is the same as the original dataloader's.
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
@@ -220,16 +248,19 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.drop_last == False
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader is an ordinary one.
The original dataloader should be returned directly, with no processing.
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
@@ -244,10 +275,11 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler is a
ReproducibleBatchSampler.
A new dataloader should be returned whose batch_sampler is the same as the original dataloader's, with the distributed attributes set correctly.
"""
dataloader = DataLoader(
dataset=self.dataset,
@@ -265,12 +297,14 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler.sampler is a
ReproducibleSampler.
A new dataloader should be returned whose batch_sampler.sampler is the same as the original dataloader's,
with the distributed attributes set correctly.
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4, shuffle=shuffle)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
@@ -282,16 +316,18 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle):
def test_with_dist_dist_dataloader_normal(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'dist' and the dataloader is an ordinary one.
A new dataloader should be returned, with its batch_sampler.sampler replaced by a RandomSampler and the
distributed attributes set correctly.
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
@@ -302,6 +338,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

"""
@@ -311,12 +348,14 @@ class TestSetDistReproDataloader:

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler
is a ReproducibleSampler.
A new dataloader should be returned, with the original Sampler replaced by an UnrepeatedRandomSampler and the
distributed attributes set correctly.
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
@@ -328,19 +367,20 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle):
def test_with_dist_unrepeat_dataloader_unrepeated_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler
is an UnrepeatedSampler.
A new dataloader should be returned, with the original Sampler re-instantiated.
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
@@ -353,16 +393,18 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle):
def test_with_dist_unrepeat_dataloader_normal(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and the dataloader is an ordinary one.
A new dataloader should be returned, with the sampler replaced by an UnrepeatedSequentialSampler and the
distributed attributes set correctly.
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
@@ -385,3 +427,346 @@ class TestSetDistReproDataloader:
if not isinstance(sampler, UnrepeatedSampler):
assert sampler.pad == True

def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
"""
Check that set_dist_repro_dataloader produces correct results in the multi-card setting.
"""
# iterate over two batches
num_replicas = len(self.device)
num_consumed_batches = 2
already_seen_idx = set()
for idx, batch in enumerate(replaced_loader):
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
dist.barrier()
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

# after reloading, the remaining samples should be produced; for PaddleNormalDataset the indices, once sorted, should form a contiguous range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
# rebuild the dataloader
new_loader = DataLoader(
dataset=replaced_loader.dataset,
batch_sampler=BucketedBatchSampler(
replaced_loader.dataset,
length=replaced_loader.dataset._data,
batch_size=batch_size,
shuffle=shuffle,
)
)
new_loader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank,
pad=True
)
new_loader.batch_sampler.load_state_dict(sampler_states)
else:
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
# reconstruct the dataloader
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
)
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(new_loader):
left_idxes.update(batch)

assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas
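
The two closing asserts encode the invariant this helper checks: the consumed and remaining index sets must partition one rank's share of the data, with no overlap and no gaps. In sketch form, with the sizes used by the test (40 samples, 2 replicas) and illustrative index sets:

already_seen_idx = set(range(0, 8))   # e.g. 2 batches of 4 on this rank
left_idxes = set(range(8, 20))        # remainder of this rank's 20-sample share
dataset_len, num_replicas = 40, 2
assert len(left_idxes) + len(already_seen_idx) == dataset_len / num_replicas
assert len(left_idxes | already_seen_idx) == dataset_len / num_replicas  # disjoint union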


############################################################################
#
# Test save- and load-related functionality
#
############################################################################
class TestSaveLoad:
"""
Test the behavior of the save- and load-related functions in the multi-card setting.
"""

@classmethod
def setup_class(cls):
# an error occurs if setup is not done here
cls.driver = generate_driver(10, 10)

def setup_method(self):
self.dataset = PaddleRandomMaxDataset(20, 10)

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_model(self, only_state_dict):
"""
Test the save_model and load_model functions.
"""
try:
path = "model"

dataloader = DataLoader(self.dataset, batch_size=2)
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)

if only_state_dict:
self.driver1.save_model(path, only_state_dict)
else:
self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))])

# synchronize
dist.barrier()
self.driver2.load_model(path, only_state_dict)

for idx, batch in enumerate(dataloader):
batch = self.driver1.move_data_to_device(batch)
res1 = self.driver1.model(
batch,
fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
# Driver.model -> DataParallel._layers -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
batch,
fastnlp_fn=self.driver2.model._layers.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)

assert paddle.equal_all(res1["pred"], res2["pred"])
finally:
if only_state_dict:
synchronize_safe_rm(path)
else:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
@pytest.mark.parametrize("device", ([[0,1]]))
def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16):
"""
Test the save and load functions, mainly the case where the dataloader's sampler has been replaced.
"""

try:
path = "model.ckp"
num_replicas = len(device)

self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
generate_driver(10, 10, device=device, fp16=False)
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(
self.dataset,
length=[10 for i in range(len(self.dataset))],
batch_size=4,
)
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
pad=True
)
num_consumed_batches = 2

already_seen_x_set = set()
already_seen_y_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])

# synchronize
dist.barrier()

# save the state
sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
# load
# change the batch_size
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(
self.dataset,
length=[10 for i in range(len(self.dataset))],
batch_size=4,
)
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
pad=True
)
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
# 1. check the optimizer state
# TODO the optimizer's state_dict is always empty

# 2. check that the batch_sampler was correctly loaded and replaced
assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler.seed == sampler_states["seed"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas

# 3. check that the fp16 state was loaded
if fp16:
assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)

# 4. check that the model parameters are correct
# 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
for idx, batch in enumerate(replaced_loader):

left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = self.driver1.model(
batch,
fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
# Driver.model -> DataParallel._layers -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
batch,
fastnlp_fn=self.driver2.model._layers.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
assert paddle.equal_all(res1["pred"], res2["pred"])

assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
finally:
synchronize_safe_rm(path)

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
@pytest.mark.parametrize("device", ([[0,1]]))
def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16):
"""
Test the save and load functions, mainly the case where the dataloader's batch_sampler has been replaced.
"""

try:
path = "model.ckp"

num_replicas = len(device)

self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
self.driver2 = generate_driver(10, 10, device=device, fp16=False)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
pad=True
)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
)
num_consumed_batches = 2

already_seen_x_set = set()
already_seen_y_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])

# synchronize
dist.barrier()

# save the state
sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
# load
# change the batch_size
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
pad=True
)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
)
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")

# 1. check the optimizer state
# TODO the optimizer's state_dict is always empty

# 2. check that the sampler was correctly loaded and replaced
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
# 3. check that the fp16 state was loaded
if fp16:
assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)

# 4. check that the model parameters are correct
# 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
for idx, batch in enumerate(replaced_loader):

left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = self.driver1.model(
batch,
fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
# Driver.model -> DataParallel._layers -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
batch,
fastnlp_fn=self.driver2.model._layers.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
assert paddle.equal_all(res1["pred"], res2["pred"])

assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas

finally:
synchronize_safe_rm(path)

tests/core/drivers/paddle_driver/test_single_device.py (+40, -32)

@@ -1,4 +1,3 @@
from dataclasses import replace
import os
from re import S
os.environ["FASTNLP_BACKEND"] = "paddle"
@@ -349,7 +348,7 @@ class TestSingleDeviceFunction:
#
############################################################################

class TestSetDistReproDataloder:
class TestSetDistReproDataloader:
"""
A class dedicated to testing the set_dist_repro_dataloader function.
"""
@@ -358,7 +357,7 @@ class TestSetDistReproDataloder:
model = PaddleNormalModel_Classification_1(10, 32)
self.driver = PaddleSingleDriver(model, device="cpu")
def test_set_dist_repro_dataloader_with_reproducible_false(self):
def test_with_reproducible_false(self):
"""
Test the behavior of set_dist_repro_dataloader when the `reproducible` argument is False.
When dist is a string, the original dataloader should be returned.
@@ -369,7 +368,7 @@ class TestSetDistReproDataloder:
assert replaced_loader is dataloader

@pytest.mark.parametrize("shuffle", [True, False])
def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle):
def test_with_reproducible_true(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when the `reproducible` argument is True.
When dist is a string, a new dataloader should be returned, and if the original sampler is paddle.io.RandomSampler(shuffle=True),
@@ -394,7 +393,7 @@ class TestSetDistReproDataloder:
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
def test_with_dist_batch_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when the dist argument is not a string and dist is a ReproducibleBatchSampler.
A new dataloader should be returned, with the batch_sampler replaced by the sampler given by dist.
@@ -410,7 +409,7 @@ class TestSetDistReproDataloder:
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
def test_with_dist_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when the dist argument is not a string.
A new dataloader should be returned, with batch_sampler.sampler replaced by the sampler given by dist.
@@ -429,7 +428,7 @@ class TestSetDistReproDataloder:
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle):
def test_with_dataloader_reproducible_batch_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when the dataloader already supports checkpoint retraining.
A new dataloader should be returned, with all other settings identical to the original.
@@ -453,7 +452,7 @@ class TestSetDistReproDataloder:
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle):
def test_with_dataloader_reproducible_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when the dataloader already supports checkpoint retraining.
A new dataloader should be returned, with all other settings identical to the original.
@@ -498,10 +497,7 @@ class TestSetDistReproDataloder:
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# rebuild the dataloader
new_loader = DataLoader(
dataset=replaced_loader.dataset,
@@ -514,11 +510,8 @@ class TestSetDistReproDataloder:
new_loader.batch_sampler.load_state_dict(sampler_states)
else:
batch_size = replaced_loader.batch_sampler.batch_size
num_consumed_batches = num_consumed_batches * batch_size
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
num_consumed_samples = num_consumed_batches * batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# reconstruct the dataloader
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
@@ -536,13 +529,13 @@ class TestSetDistReproDataloder:
#
############################################################################

def generate_random_driver(features, labels):
def generate_random_driver(features, labels, fp16=False, device="cpu"):
"""
Generate a driver.
"""
model = PaddleNormalModel_Classification_1(labels, features)
opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01)
driver = PaddleSingleDriver(model, device="cpu")
driver = PaddleSingleDriver(model, device=device, fp16=fp16)
driver.set_optimizers(opt)
driver.setup()

@@ -550,8 +543,8 @@ def generate_random_driver(features, labels):

@pytest.fixture
def prepare_test_save_load():
dataset = PaddleRandomMaxDataset(320, 10)
dataloader = DataLoader(dataset, batch_size=32)
dataset = PaddleRandomMaxDataset(40, 10)
dataloader = DataLoader(dataset, batch_size=4)
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
return driver1, driver2, dataloader

@@ -584,21 +577,23 @@ def test_save_and_load_model(prepare_test_save_load, only_state_dict):
rank_zero_rm(path + ".pdiparams.info")
rank_zero_rm(path + ".pdmodel")

@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randombatchsampler(only_state_dict):
# @pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("only_state_dict", ([True]))
@pytest.mark.parametrize("fp16", ([True, False]))
def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
"""
Test the save and load functions, mainly the case where the dataloader's sampler has been replaced.
"""

try:
path = "model.ckp"

driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
dataset = PaddleRandomMaxDataset(40, 10)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
)
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")

num_consumed_batches = 2

already_seen_x_set = set()
@@ -633,8 +628,13 @@ def test_save_and_load_with_randombatchsampler(only_state_dict):
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4

# 3. check that the model parameters are correct
# 4. check batch_idx
# 3. check that the fp16 state was loaded
if fp16:
assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)


# 4. check that the model parameters are correct
# 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
@@ -654,8 +654,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict):
finally:
rank_zero_rm(path)

@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randomsampler(only_state_dict):
# @pytest.mark.parametrize("only_state_dict", ([True, False]))
# TODO when iteration is combined with paddle.jit.save, a segmentation fault is triggered; commenting out either part avoids the error,
# but this cannot be reproduced in a standalone file
@pytest.mark.parametrize("only_state_dict", ([True]))
@pytest.mark.parametrize("fp16", ([True, False]))
def test_save_and_load_with_randomsampler(only_state_dict, fp16):
"""
Test the save and load functions, mainly the case where the dataloader's batch_sampler has been replaced.
"""
@@ -663,7 +667,7 @@ def test_save_and_load_with_randomsampler(only_state_dict):
try:
path = "model.ckp"

driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")
dataset = PaddleRandomMaxDataset(40, 10)
batch_sampler = BatchSampler(dataset=dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(dataset, True)
@@ -711,8 +715,12 @@ def test_save_and_load_with_randomsampler(only_state_dict):
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]

# 3. check that the model parameters are correct
# 4. check batch_idx
# 3. check that the fp16 state was loaded
if fp16:
assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)

# 4. check that the model parameters are correct
# 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()

