
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh_cc · 3 years ago · parent commit b0db3a5998
16 changed files with 780 additions and 521 deletions
  1. +2
    -2
      fastNLP/core/controllers/trainer.py
  2. +5
    -8
      fastNLP/core/drivers/jittor_driver/mpi.py
  3. +19
    -47
      fastNLP/core/drivers/jittor_driver/single_device.py
  4. +376
    -0
      fastNLP/core/drivers/paddle_driver/dist_utils.py
  5. +86
    -56
      fastNLP/core/drivers/paddle_driver/fleet.py
  6. +19
    -32
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  7. +48
    -91
      fastNLP/core/drivers/paddle_driver/single_device.py
  8. +5
    -75
      fastNLP/core/drivers/paddle_driver/utils.py
  9. +18
    -43
      fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py
  10. +0
    -1
      fastNLP/envs/set_backend.py
  11. +1
    -1
      fastNLP/modules/mix_modules/mix_module.py
  12. +27
    -53
      tests/core/controllers/test_trainer_paddle.py
  13. +39
    -28
      tests/core/drivers/paddle_driver/test_fleet.py
  14. +133
    -82
      tests/core/drivers/paddle_driver/test_single_device.py
  15. +1
    -1
      tests/helpers/callbacks/helper_callbacks.py
  16. +1
    -1
      tests/helpers/models/paddle_model.py

+ 2
- 2
fastNLP/core/controllers/trainer.py

@@ -219,10 +219,10 @@ class Trainer(TrainerEventTrigger):

""" 设置内部的 Evaluator """
        if metrics is None and evaluate_dataloaders is not None:
-            raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.")
+            raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.")

        if metrics is not None and evaluate_dataloaders is None:
-            raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.")
+            raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.")

self.evaluator = None
self.monitor = monitor
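The corrected messages track the `validate_*` → `evaluate_*` rename: `metrics` and `evaluate_dataloaders` must be supplied together or not at all. A standalone sketch of the same check (the helper name `check_evaluate_args` is illustrative, not part of fastNLP):

    def check_evaluate_args(metrics, evaluate_dataloaders):
        # mirrors the Trainer check in the hunk above
        if metrics is None and evaluate_dataloaders is not None:
            raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.")
        if metrics is not None and evaluate_dataloaders is None:
            raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.")

    check_evaluate_args(None, None)  # fine: evaluation is simply disabled
    try:
        check_evaluate_args({"acc": object()}, None)
    except ValueError as e:
        print(e)  # You have set 'metrics' but forget to set 'evaluate_dataloader'.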


+ 5
- 8
fastNLP/core/drivers/jittor_driver/mpi.py

@@ -1,5 +1,5 @@
import os
-from typing import Optional, Union
+from typing import Optional, Union, Callable, Dict, Tuple

from .jittor_driver import JittorDriver
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -61,14 +61,11 @@ class JittorMPIDriver(JittorDriver):
return self._data_device
return self.model_device

-    def train_step(self, batch):
-        return self._train_step(batch)
-
-    def validate_step(self, batch):
-        return self._validate_step(batch)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        pass

-    def test_step(self, batch):
-        return self._test_step(batch)
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        pass

def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
reproducible: bool = False, sampler_or_batch_sampler=None):


+ 19
- 47
fastNLP/core/drivers/jittor_driver/single_device.py

@@ -1,9 +1,11 @@
-from typing import Dict, Union
+from typing import Dict, Union, Tuple, Callable, Optional

from .jittor_driver import JittorDriver
from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
+from fastNLP.core.log import logger

if _NEED_IMPORT_JITTOR:
import jittor
@@ -27,42 +29,6 @@ class JittorSingleDriver(JittorDriver):
self.global_rank = 0
self.world_size = 1

if hasattr(self.model, "train_step"):
self._train_step = self.model.train_step
self._train_signature_fn = None
else:
self._train_step = self.model
model = self.unwrap_model()
self._train_signature_fn = model.execute

if hasattr(self.model, "evaluate_step"):
self._validate_step = self.model.evaluate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
self._validate_signature_fn = self.model.test_step
else:
self._validate_step = self.model
model = self.unwrap_model()
self._validate_signature_fn = model.execute

if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "evaluate_step"):
self._test_step = self.model.evaluate_step
self._test_signature_fn = self.model.evaluate_step
else:
self._test_step = self.model
model = self.unwrap_model()
self._test_signature_fn = model.execute

def train_step(self, batch) -> Dict:
if isinstance(batch, Dict):
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)

def step(self):
"""
        the step function of jittor optimizers can take the loss as an argument
@@ -80,18 +46,24 @@ class JittorSingleDriver(JittorDriver):
for optimizer in self.optimizers:
optimizer.zero_grad()

-    def validate_step(self, batch):
-        if isinstance(batch, Dict):
-            return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
-        else:
-            return self._validate_step(batch)
-
-    def test_step(self, batch):
-        if isinstance(batch, Dict):
-            return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
-        else:
-            return self._test_step(batch)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
+        else:
+            return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
+            return self.model, self.model.forward
+        else:
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

def unwrap_model(self):
return self.model
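This refactor replaces the per-mode `train_step`/`validate_step`/`test_step` wrappers with one `get_model_call_fn` lookup plus a generic `model_call`. A minimal sketch of the dispatch rule, runnable on its own (the names `ToyModel` and `resolve_step_fn` are illustrative, not from fastNLP):

    from typing import Callable, Optional, Tuple

    class ToyModel:
        def forward(self, x):
            return {"pred": x}

        def train_step(self, x):
            return {"loss": x}

    def resolve_step_fn(model, fn_name: str) -> Tuple[Callable, Optional[Callable]]:
        # Mirrors get_model_call_fn: prefer a user-defined step method;
        # otherwise fall back to the model itself with forward as signature_fn.
        if hasattr(model, fn_name):
            fn = getattr(model, fn_name)
            if not callable(fn):
                raise RuntimeError(f"The `{fn_name}` attribute is not `Callable`.")
            return fn, None
        if fn_name in {"train_step", "evaluate_step"}:
            return model, model.forward
        raise RuntimeError(f"There is no `{fn_name}` method in your {type(model)}.")

    model = ToyModel()
    fn, signature_fn = resolve_step_fn(model, "train_step")     # -> model.train_step, None
    fn, signature_fn = resolve_step_fn(model, "evaluate_step")  # -> model, model.forward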


+ 376
- 0
fastNLP/core/drivers/paddle_driver/dist_utils.py

@@ -0,0 +1,376 @@
import io
import pickle
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
from typing import Any, List

from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed as dist
if _TORCH_GREATER_EQUAL_1_8:
try:
from torch._C._distributed_c10d import ProcessGroupGloo
from torch._C._distributed_c10d import _ProcessGroupWrapper
except ImportError:
pass


from fastNLP.core.utils import apply_to_collection


def _validate_output_list_for_rank(my_rank, dst, gather_list):
if dst == my_rank:
if not gather_list:
raise ValueError(
"Argument ``gather_list`` must be specified on destination rank."
)
elif gather_list:
raise ValueError(
"Argument ``gather_list`` must NOT be specified "
"on non-destination ranks."
)


def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
"""
    Gather objects from the other ranks onto the dst rank.

Gathers picklable objects from the whole group in a single process.
Similar to :func:`gather`, but Python objects can be passed in. Note that the
object must be picklable in order to be gathered.

Args:
obj (Any): Input object. Must be picklable.
object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
should be correctly sized as the size of the group for this
collective and will contain the output. Must be ``None`` on non-dst
ranks. (default is ``None``)
dst (int, optional): Destination rank. (default is 0)
group: (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.

Returns:
None. On the ``dst`` rank, ``object_gather_list`` will contain the
output of the collective.

.. note:: Note that this API differs slightly from the gather collective
since it does not provide an async_op handle and thus will be a blocking
call.

.. note:: Note that this API is not supported when using the NCCL backend.

.. warning::
:func:`gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.gather_object(
gather_objects[dist.get_rank()],
output if dist.get_rank() == 0 else None,
dst=0
)
>>> # On rank 0
>>> output
['foo', 12, {1: 2}]
"""
if group is None:
group = DEFAULT_TORCH_GROUP

if dist.distributed_c10d._rank_not_in_group(group):
return

    # Ensure object_gather_list is specified appropriately.
my_rank = dist.get_rank()
_validate_output_list_for_rank(my_rank, dst, object_gather_list)
    # Move tensors to CPU so that unpickling does not land them on the sending GPU.
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
input_tensor, local_size = _object_to_tensor(obj)
group_backend = dist.get_backend(group)
current_device = torch.device("cpu")
is_nccl_backend = group_backend == dist.Backend.NCCL
if is_nccl_backend:
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device)
object_size_list = [
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
# Allgather tensor sizes. An all-gather is needed here despite this being a
# gather, since each rank needs to broadcast a tensor of the same (maximal)
# size.
dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
# Avoid populating output tensors if the result won't be gathered on this rank.
if my_rank == dst:
coalesced_output_tensor = torch.empty(
max_object_size * group_size, dtype=torch.uint8, device=current_device
)
# Output tensors are nonoverlapping views of coalesced_output_tensor
output_tensors = [
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
for i in range(group_size)
]
# All ranks call gather with equal-sized tensors.
dist.gather(
input_tensor,
gather_list=output_tensors if my_rank == dst else None,
dst=dst,
group=group,
)
if my_rank != dst:
return
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8) # type: ignore[call-overload]
tensor_size = object_size_list[i]
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
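Because `dist.gather` needs equal-sized tensors on every rank, the function above pads each serialized object to the global max size and truncates afterwards using the gathered sizes. The padding/truncation step in isolation (plain bytes, no process group; purely illustrative):

    payloads = [b"abc", b"defgh", b"ij"]          # serialized objects from 3 ranks
    max_size = max(len(p) for p in payloads)      # every rank must send equal-sized buffers

    padded = [p + b"\x00" * (max_size - len(p)) for p in payloads]
    sizes = [len(p) for p in payloads]            # gathered alongside the payloads

    recovered = [buf[:size] for buf, size in zip(padded, sizes)]
    assert recovered == payloads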


def _object_to_tensor(obj, device=None):
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
    # Otherwise, it will cause a 100X slowdown.
# See: https://github.com/pytorch/pytorch/issues/65696
byte_tensor = torch.ByteTensor(byte_storage)
local_size = torch.LongTensor([byte_tensor.numel()])
if device is not None:
byte_tensor = byte_tensor.to(device)
local_size = local_size.to(device)
return byte_tensor, local_size


def _tensor_to_object(tensor, tensor_size):
buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size]
return _unpickler(io.BytesIO(buf)).load()


def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
# src rank send to all other ranks
size = torch.LongTensor([0]).to(device)

if cur_rank == src:
world_size = dist.get_world_size(group=group)
tensor, size = _object_to_tensor(obj)
tensor = tensor.to(device)
size = size.to(device)

        # first synchronize the size of obj;
dist.broadcast(size, src, group=group)
for subrank in range(world_size):
if subrank != src:
dist.send(tensor=tensor, dst=subrank, group=group, tag=tag)
else:
dist.broadcast(size, src, group=group)
tensor = torch.ByteTensor([0] * size).to(device)
dist.recv(tensor=tensor, src=src, group=group, tag=tag)

return _tensor_to_object(tensor.cpu(), size)

def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List:
"""
    Perform an all_gather on data of any type through this interface. Non-tensor data is transferred by pickle
    serialization and deserialization.

example:
obj = {
'a': [1, 1],
'b': [[1, 2], [1, 2]],
'c': {
'd': [1, 2]
}
}
->
[
{'a': 1, 'b':[1, 2], 'c':{'d': 1}},
{'a': 1, 'b':[1, 2], 'c':{'d': 2}}
]

    :param obj: data of any structure; if it is a tensor, its shape must be identical on every device. Any
        non-tensor object is serialized directly before transfer.
    :param device: this parameter currently has no effect.
    :param group:
    :return: a list [obj0, obj1, ...] where obj_i is the obj on rank i.
"""
if group is None:
group = DEFAULT_TORCH_GROUP
if isinstance(obj, torch.Tensor):
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
dist.all_gather(objs, obj, group=group)
else:
objs = [None for _ in range(dist.get_world_size(group))]
        # Move tensors to CPU so that unpickling does not place them on the sending GPU
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
if _TORCH_GREATER_EQUAL_1_8:
dist.all_gather_object(objs, obj, group=group)
else:
objs = all_gather_object(objs, obj, group=group)
return objs


def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
"""
    Broadcast the obj object on src to all other ranks.

:param obj:
:param src:
:param device:
:param group:
:return:
"""
if group is None:
group = DEFAULT_TORCH_GROUP
cur_rank = dist.get_rank(group)
if cur_rank == src:
        # Move any tensors to CPU to make pickling easy; otherwise unpickling may land them on the sender's device
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
if _TORCH_GREATER_EQUAL_1_8:
if cur_rank!=src:
get_obj = [None]
dist.broadcast_object_list(get_obj, src=src, group=group)
return get_obj[0]
else:
dist.broadcast_object_list([obj], src=src, group=group)
return obj
if device is None:
device = torch.cuda.current_device()

if cur_rank == src:
tensor, size = _object_to_tensor(obj, device=device)
else:
size = torch.LongTensor([0]).to(device)

dist.broadcast(size, src=src, group=group)
if cur_rank != src:
tensor = torch.empty(
size.int().item(), # type: ignore[arg-type]
dtype=torch.uint8,
device=device
)
dist.broadcast(tensor, src=src, group=group)

return _tensor_to_object(tensor, tensor_size=size.item())


def _check_for_nccl_backend(group):
pg = group or dist.distributed_c10d._get_default_group()
# It is not expected for PG to be wrapped many times, but support it just
# in case
while isinstance(pg, _ProcessGroupWrapper):
pg = pg.wrapped_pg

return (
dist.is_nccl_available() and
isinstance(pg, dist.ProcessGroupNCCL)
)


def all_gather_object(object_list, obj, group=None):
"""
    Copied from PyTorch so that this also works on older PyTorch versions.

Gathers picklable objects from the whole group into a list. Similar to
:func:`all_gather`, but Python objects can be passed in. Note that the object
must be picklable in order to be gathered.

Args:
object_list (list[Any]): Output list. It should be correctly sized as the
size of the group for this collective and will contain the output.
    object (Any): Picklable Python object to be broadcast from current process.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.

Returns:
None. If the calling rank is part of this group, the output of the
collective will be populated into the input ``object_list``. If the
calling rank is not part of the group, the passed in ``object_list`` will
be unmodified.

.. note:: Note that this API differs slightly from the :func:`all_gather`
collective since it does not provide an ``async_op`` handle and thus
will be a blocking call.

    .. note:: For NCCL-based process groups, internal tensor representations
    of objects must be moved to the GPU device before communication takes
    place. In this case, the device used is given by
    ``torch.cuda.current_device()`` and it is the user's responsibility to
    ensure that this is set so that each rank has an individual GPU, via
    ``torch.cuda.set_device()``.

.. warning::
:func:`all_gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]
"""
if dist.distributed_c10d._rank_not_in_group(group):
return
if _TORCH_GREATER_EQUAL_1_8:
current_device = torch.device("cpu")
is_nccl_backend = _check_for_nccl_backend(group)
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device("cuda", torch.cuda.current_device())
else:
current_device = torch.cuda.current_device()

input_tensor, local_size = _object_to_tensor(obj, device=current_device)

# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
object_sizes_tensor = torch.zeros(
group_size, dtype=torch.long, device=current_device
)
object_size_list = [
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
# Allgather tensor sizes
dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
coalesced_output_tensor = torch.empty(
max_object_size * group_size, dtype=torch.uint8, device=current_device
)
# Output tensors are nonoverlapping views of coalesced_output_tensor
output_tensors = [
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
for i in range(group_size)
]
dist.all_gather(output_tensors, input_tensor, group=group)
# Deserialize outputs back to object.
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
if tensor.device != torch.device("cpu"):
tensor = tensor.cpu()
tensor_size = object_size_list[i]
object_list[i] = _tensor_to_object(tensor, tensor_size)
return object_list
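Both gather paths lean on `_object_to_tensor` / `_tensor_to_object` to turn arbitrary picklable objects into byte buffers and back. The same round trip with plain `pickle`, including the truncation that strips gather padding (a minimal sketch, no torch required):

    import io
    import pickle

    def object_to_bytes(obj) -> bytes:
        buf = io.BytesIO()
        pickle.Pickler(buf).dump(obj)
        return buf.getvalue()

    def bytes_to_object(data: bytes, size: int):
        # Mirrors _tensor_to_object: truncate to the real payload size first,
        # because gathered buffers are padded to the max size across ranks.
        return pickle.Unpickler(io.BytesIO(data[:size])).load()

    payload = {"a": [1, 2], "b": {"c": 3}}
    blob = object_to_bytes(payload)
    padded = blob + b"\x00" * 16  # padding as in the all-gather path
    assert bytes_to_object(padded, len(blob)) == payload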

+ 86
- 56
fastNLP/core/drivers/paddle_driver/fleet.py

@@ -1,13 +1,12 @@
import os
+import shutil
from functools import partial
-from typing import List, Union, Optional, Dict
+from typing import List, Union, Optional, Dict, Tuple, Callable

from .paddle_driver import PaddleDriver
from .fleet_launcher import FleetLauncher
from .utils import (
    _FleetWrappingModel,
-    ForwardState,
-    _MODE_PARAMETER,
    get_device_from_visible,
    reset_seed,
    replace_sampler,
@@ -47,8 +46,7 @@ if _NEED_IMPORT_PADDLE:
__all__ = [
"PaddleFleetDriver",
]
-# if os.path.exists(self.gloo_rendezvous_dir):
-#     shutil.rmtree(self.gloo_rendezvous_dir)

class PaddleFleetDriver(PaddleDriver):
def __init__(
self,
@@ -104,34 +102,6 @@ class PaddleFleetDriver(PaddleDriver):
        # we simply set model_device to None;
self._model_device = None

-        def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
-            if isinstance(batch, Dict) and not wo_auto_param_call:
-                return auto_param_call(step_fn, batch, signature_fn=signature_fn)
-            else:
-                return self._validate_step(batch)
-
-        model = model._layers
-        if hasattr(model, "train_step"):
-            logger.warning(
-                "Notice your model is a `paddle.DataParallel` model. And your "
-                "model also implements the `train_step` method, which we can not call actually, we will"
-                " call `forward` function instead of `train_step` and you should note that.")
-        self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
-
-        if hasattr(model, "evaluate_step"):
-            logger.warning(
-                "Notice your model is a `paddle.DataParallel` model. And your "
-                "model also implements the `evaluate_step` method, which we can not call actually, "
-                "we will call `forward` function instead of `evaluate_step` and you should note that.")
-        self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
-
-        if hasattr(model, "test_step"):
-            logger.warning(
-                "Notice your model is a `paddle.DataParallel` model. And your "
-                "model also implements the `test_step` method, which we can not call actually, we will"
-                " call `forward` function instead of `test_step` and you should note that.")
-        self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

        # when the `device` parameter is None and this one is not None, the corresponding data is moved to the specified device;
self._data_device = kwargs.get("data_device", None)
if self._data_device is not None:
@@ -150,8 +120,6 @@ class PaddleFleetDriver(PaddleDriver):

self.world_size = None
self.global_rank = 0
-        self._configured = False  # guard against repeated calls of configure_ddp()
-        self._has_setup = False  # guard against repeated calls of setup()

self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {})
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)
@@ -173,6 +141,9 @@ class PaddleFleetDriver(PaddleDriver):
os.makedirs(name=self.output_from_new_proc, exist_ok=True)
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)

+        self._has_setup = False  # set because the evaluator also runs setup(), which is clearly neither needed nor wanted;
+        self._has_fleetwrapped = False  # whether the incoming model has been wrapped by _FleetWrappingModel;

def setup(self):
"""
        The main process launches the other subprocesses, with the main process acting as rank 0
@@ -268,17 +239,17 @@ class PaddleFleetDriver(PaddleDriver):
dist.barrier()

    def configure_fleet(self):
-        if not self._configured and not isinstance(self.model, DataParallel):
+        if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
            self.model = DataParallel(
                _FleetWrappingModel(self.model),
                **self._fleet_kwargs
            )
+            self._has_fleetwrapped = True

-        self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
-        self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
-        self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)
-
-        self._configured = True
+    def on_exception(self):
+        if os.path.exists(self.gloo_rendezvous_dir):
+            shutil.rmtree(self.gloo_rendezvous_dir)
+        super().on_exception()

@property
def world_size(self) -> int:
@@ -310,14 +281,39 @@ class PaddleFleetDriver(PaddleDriver):
return self._data_device
return self.model_device

-    def train_step(self, batch):
-        return self._train_step(batch)
-
-    def validate_step(self, batch):
-        return self._validate_step(batch)
-
-    def test_step(self, batch):
-        return self._test_step(batch)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if self._has_fleetwrapped:
+            return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn,
+                              wo_auto_param_call=self.wo_auto_param_call)
+        else:
+            if isinstance(batch, Dict) and not self.wo_auto_param_call:
+                return auto_param_call(fn, batch, signature_fn=signature_fn)
+            else:
+                return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        model = self.unwrap_model()
+        if self._has_fleetwrapped:
+            if hasattr(model, fn):
+                fn = getattr(model, fn)
+                if not callable(fn):
+                    raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
+                return fn, None
+            elif fn in {"train_step", "evaluate_step"}:
+                return model, model.forward
+            else:
+                raise RuntimeError(f"There is no `{fn}` method in your model.")
+        else:
+            if hasattr(model, fn):
+                logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements "
+                               f"the `{fn}` method, which we can not call actually, we will"
+                               " call `forward` function instead of `train_step` and you should note that.")
+            elif fn not in {"train_step", "evaluate_step"}:
+                raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
+                                   "`DistributedDataParallel` model, which means that we will only call model.forward "
+                                   "function when we are in forward propagation.")
+
+            return self.model, model.forward

def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
reproducible: bool = False, sampler_or_batch_sampler=None):
@@ -406,14 +402,6 @@ class PaddleFleetDriver(PaddleDriver):
else:
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")

-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()

def is_global_zero(self):
return self.global_rank == 0

@@ -450,3 +438,45 @@ class PaddleFleetDriver(PaddleDriver):
if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
f"not {type(each_optimizer)}.")

def broadcast_object(self, obj, src:int=0, group=None, **kwargs):
"""
        Send the obj object (possibly a tensor, possibly any picklable object) from src to dst. Non-tensor
        objects are packed with pickle for transfer and loaded back at dst. Only meaningful in distributed drivers.

        :param obj: the obj, either a Tensor or nested data
        :param int src: global rank of the source.
        :param int dst: global rank of the target; may be several target ranks
        :param group: the group it belongs to
        :param kwargs:
        :return: if the current driver is not distributed, the input obj is returned directly. If the current rank
            is a receiver (its global rank is contained in dst), the received object is returned; the source rank
            returns what it sent; a rank that is neither sender nor receiver returns None.
"""
return
return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)

def all_gather(self, obj, group) -> List:
"""
        Send obj to every other rank; obj may be a Tensor or a nested object. Data that is not of a basic type
        is pickled before sending and unpickled after it is received.

example:
obj = {
'a': [1, 1],
'b': [[1, 2], [1, 2]],
'c': {
'd': [1, 2]
}
}
->
[
{'a': 1, 'b':[1, 2], 'c':{'d': 1}},
{'a': 1, 'b':[1, 2], 'c':{'d': 2}}
]

        :param obj: the object to transfer; it must keep the same structure on every rank.
:param group:
:return:
"""
return
return fastnlp_paddle_all_gather(obj, group=group)

+ 19
- 32
fastNLP/core/drivers/paddle_driver/paddle_driver.py

@@ -71,6 +71,14 @@ class PaddleDriver(Driver):
for optimizer in self.optimizers:
optimizer.clear_grad()

+    def backward(self, loss):
+        self.grad_scaler.scale(loss).backward()
+
+    def step(self):
+        for optimizer in self.optimizers:
+            self.grad_scaler.step(optimizer)
+            self.grad_scaler.update()

@staticmethod
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
r"""
@@ -115,28 +123,6 @@ class PaddleDriver(Driver):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
f"not {type(each_optimizer)}.")

-    def check_evaluator_mode(self, mode: str):
-        r"""
-        In the concrete drivers, the evaluate_step / test_step logic falls back to the other method when the
-        model does not implement the requested one; so when the user's evaluator evaluate_fn is 'validate' but
-        the passed-in model implements test_step instead of evaluate_step, we should warn the user about this
-        behavior;
-        """
-        model = self.unwrap_model()
-        if mode == "validate":
-            if not hasattr(model, "evaluate_step"):
-                if hasattr(model, "test_step"):
-                    logger.warning(
-                        "Your model does not have 'evaluate_step' method but has 'test_step' method, but you"
-                        "are using 'Evaluator.validate', we are going to use 'test_step' to substitute for"
-                        "'evaluate_step'.")
-
-        else:
-            if not hasattr(model, "test_step"):
-                if hasattr(model, "evaluate_step"):
-                    logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
-                                        "are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for"
-                                        "'test_step'.")

@staticmethod
def tensor_to_numeric(tensor, reduce=None):
r"""
@@ -258,20 +244,21 @@ class PaddleDriver(Driver):
if hasattr(sampler, "state_dict") and callable(sampler.state_dict):
sampler_states = sampler.state_dict()
            # If present, num_consumed_samples needs special handling, because the DataLoader prefetches:
            # using the sampler's num_consumed_samples directly would over-count what was actually consumed.
-            num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None)
+            num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
            if num_consumed_samples_array is not None:
-                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
-            else:
-                try:
-                    sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size
-                except:  # batch_size may be None, in which case we only lose some precision
-                    pass
-            assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report."
+                if isinstance(sampler, ReproducibleSampler):
+                    # for a sampler, the actual number of samples has to be computed
+                    try:
+                        num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
+                    except:  # batch_size may be None, in which case we only lose some precision
+                        num_consumed_batches = sampler_states['num_consumed_samples']
+                sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
+                assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+            states['sampler_states'] = sampler_states
            else:
                raise RuntimeError(
                    "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
-        states["sampler_states"] = sampler_states

        # 2. save the model state;
if should_save_model:
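With this change, `backward` and `step` live once on the base `PaddleDriver` and always route through `self.grad_scaler` (a `DummyGradScaler` when fp16 is off, so one code path serves both precisions). A rough sketch of the underlying Paddle AMP pattern, assuming a CUDA-enabled Paddle build whose `GradScaler` exposes `step`/`update` as the driver code relies on (model and optimizer are placeholders):

    import paddle

    model = paddle.nn.Linear(4, 2)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    x = paddle.randn([8, 4])
    with paddle.amp.auto_cast():
        loss = model(x).mean()

    scaler.scale(loss).backward()  # driver.backward(): scale the loss, then backprop
    scaler.step(optimizer)         # driver.step(): unscale and apply the update
    scaler.update()                # adjust the loss scale for the next iteration
    optimizer.clear_grad()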


+ 48
- 91
fastNLP/core/drivers/paddle_driver/single_device.py

@@ -1,5 +1,5 @@
import os
-from typing import Optional, Dict, Union
+from typing import Optional, Dict, Union, Callable, Tuple

from .paddle_driver import PaddleDriver
from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible
@@ -11,16 +11,19 @@ from fastNLP.core.utils import (
    get_paddle_device_id,
    paddle_move_data_to_device,
)
+from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import (
    ReproducibleBatchSampler,
    RandomBatchSampler,
    ReproducibleSampler,
+    RandomSampler,
    re_instantiate_sampler,
)
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
    import paddle
+    from paddle import DataParallel
    from paddle.fluid.reader import _DatasetKind

__all__ = [
@@ -28,109 +31,57 @@ __all__ = [
]

class PaddleSingleDriver(PaddleDriver):
-    def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs):
+    def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs):
+        if isinstance(model, DataParallel):
+            raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`")
+
+        cuda_visible_devices = os.environ.get(USER_CUDA_VISIBLE_DEVICES, None)
+        if cuda_visible_devices == "":
+            device = "cpu"
+            logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to"
+                        "use `cpu` instead of `gpu` device.")
+
        super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs)

        if device is None:
            raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.")

+        if device != "cpu":
+            if isinstance(device, int):
+                device_id = device
+            else:
+                device_id = get_paddle_device_id(device)
+            os.environ["CUDA_VISIBLE_DEVICES"] = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
        self.model_device = get_paddle_gpu_str(device)

self.local_rank = 0
self.global_rank = 0
self.world_size = 1

-        if isinstance(model, paddle.DataParallel):
-            # note that unwrap_model here calls the concrete subclass's method;
-            model = self.unwrap_model()
-            if hasattr(model, "train_step"):
-                logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
-                               "implements the `train_step` method, which we can not call actually, we will "
-                               " call `forward` function instead of `train_step` and you should note that.")
-            self._train_step = self.model
-            self._train_signature_fn = model.forward
-
-            if hasattr(model, "evaluate_step"):
-                logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
-                               "implements the `evaluate_step` method, which we can not call actually, we "
-                               "will call `forward` function instead of `evaluate_step` and you should note that.")
-            self._validate_step = self.model
-            self._validate_signature_fn = model.forward
-
-            if hasattr(model, "test_step"):
-                logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
-                               "implements the `test_step` method, which we can not call actually, we will "
-                               "call `forward` function instead of `test_step` and you should note that.")
-            self._test_step = self.model
-            self._test_signature_fn = model.forward
-        else:
-            if hasattr(self.model, "train_step"):
-                self._train_step = self.model.train_step
-                self._train_signature_fn = None
-            else:
-                self._train_step = self.model
-                # the input model is `DataParallel`; we need to make sure its signature_fn is correct;
-                model = self.unwrap_model()
-                self._train_signature_fn = model.forward
-
-            if hasattr(self.model, "evaluate_step"):
-                self._validate_step = self.model.evaluate_step
-                self._validate_signature_fn = None
-            elif hasattr(self.model, "test_step"):
-                self._validate_step = self.model.test_step
-                self._validate_signature_fn = self.model.test_step
-            else:
-                self._validate_step = self.model
-                model = self.unwrap_model()
-                self._validate_signature_fn = model.forward
-
-            if hasattr(self.model, "test_step"):
-                self._test_step = self.model.test_step
-                self._test_signature_fn = None
-            elif hasattr(self.model, "evaluate_step"):
-                self._test_step = self.model.evaluate_step
-                self._test_signature_fn = self.model.evaluate_step
-            else:
-                self._test_step = self.model
-                model = self.unwrap_model()
-                self._test_signature_fn = model.forward

    def setup(self):
        device = self.model_device
-        if device != "cpu":
-            device_id = get_paddle_device_id(device)
-            device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
-            os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
-            device = get_device_from_visible(device, output_type=str)
+        device = get_device_from_visible(device, output_type=str)
        paddle.device.set_device(device)
        self.model.to(device)

-    def train_step(self, batch) -> Dict:
-        # If batch is a Dict we do the parameter matching for it by default; otherwise it is passed straight
-        # into `train_step` and the user handles it themselves;
-        if isinstance(batch, Dict) and not self.wo_auto_param_call:
-            return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
-        else:
-            return self._train_step(batch)
-
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
-
-    def validate_step(self, batch) -> Dict:
-        if isinstance(batch, Dict) and not self.wo_auto_param_call:
-            return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
-        else:
-            return self._validate_step(batch)
-
-    def test_step(self, batch) -> Dict:
-        if isinstance(batch, Dict) and not self.wo_auto_param_call:
-            return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
-        else:
-            return self._test_step(batch)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
+        else:
+            return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
+            return self.model, self.model.forward
+        else:
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

def move_data_to_device(self, batch: 'paddle.Tensor'):
r"""
@@ -164,12 +115,18 @@ class PaddleSingleDriver(PaddleDriver):
return replace_sampler(dataloader, sampler)

        if reproducible:
-            batch_sampler = RandomBatchSampler(
-                batch_sampler=args.batch_sampler,
-                batch_size=args.batch_size,
-                drop_last=args.drop_last
-            )
-            return replace_batch_sampler(dataloader, batch_sampler)
+            if isinstance(args.sampler, paddle.io.RandomSampler):
+                # if it is already random, just replace it directly
+                sampler = RandomSampler(args.sampler.data_source)
+                logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
+                return replace_sampler(dataloader, sampler)
+            else:
+                batch_sampler = RandomBatchSampler(
+                    batch_sampler=args.batch_sampler,
+                    batch_size=args.batch_size,
+                    drop_last=args.drop_last
+                )
+                return replace_batch_sampler(dataloader, batch_sampler)
        else:
            return dataloader
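The reproducible branch now distinguishes a dataloader that already shuffles (its sampler is a paddle.io.RandomSampler, so only the sampler is swapped) from everything else (the whole batch_sampler gets wrapped). A rough illustration of the sampler-swap idea with plain paddle.io, independent of fastNLP's helpers (dataset and names are illustrative):

    import paddle
    from paddle.io import BatchSampler, DataLoader, Dataset, RandomSampler

    class ToyDataset(Dataset):
        def __len__(self):
            return 8
        def __getitem__(self, idx):
            return paddle.to_tensor([idx])

    dataset = ToyDataset()
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    # A "reproducible" replacement keeps the same data source but owns an
    # explicit sampler whose state could be checkpointed and restored.
    sampler = RandomSampler(dataset)
    batch_sampler = BatchSampler(sampler=sampler, batch_size=4)
    replaced = DataLoader(dataset, batch_sampler=batch_sampler)

    for batch in replaced:
        print(batch)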



+ 5
- 75
fastNLP/core/drivers/paddle_driver/utils.py

@@ -11,7 +11,6 @@ from typing import Dict, Optional, Union

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to
-from fastNLP.core.samplers import RandomSampler
from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.log import logger

@@ -87,8 +86,6 @@ class ForwardState(IntEnum):
TEST = 2
PREDICT = 3

_MODE_PARAMETER = "forward_state"

class _FleetWrappingModel(Layer):
"""
    Modeled on _DDPWrappingModel: paddle distributed training likewise needs the model wrapped in paddle.nn.DataParallel, adopting the same
@@ -98,83 +95,16 @@ class _FleetWrappingModel(Layer):
super(_FleetWrappingModel, self).__init__()
self.model = model

-        if isinstance(model, paddle.DataParallel):
-            model = model._layers
-            if hasattr(model, "train_step"):
-                logger.warning(
-                    "Notice your model is a `paddle.DataParallel` model. And your "
-                    "model also implements the `train_step` method, which we can not call actually, we will"
-                    " call `forward` function instead of `train_step` and you should note that.")
-            self._train_step = self.model
-            self._train_signature_fn = model.forward
-
-            if hasattr(model, "evaluate_step"):
-                logger.warning(
-                    "Notice your model is a `paddle.DataParallel` model. And your "
-                    "model also implements the `evaluate_step` method, which we can not call actually, "
-                    "we will call `forward` function instead of `evaluate_step` and you should note that.")
-            self._validate_step = self.model
-            self._validate_signature_fn = model.forward
-
-            if hasattr(model, "test_step"):
-                logger.warning(
-                    "Notice your model is a `paddle.DataParallel` model. And your "
-                    "model also implements the `test_step` method, which we can not call actually, we will"
-                    " call `forward` function instead of `test_step` and you should note that.")
-            self._test_step = self.model
-            self._test_signature_fn = model.forward
-        else:
-            if hasattr(model, "train_step"):
-                self._train_step = model.train_step
-                self._train_signature_fn = None
-            else:
-                self._train_step = model
-                self._train_signature_fn = model.forward
-
-            if hasattr(model, "evaluate_step"):
-                self._validate_step = model.validate_step
-                self._validate_signature_fn = None
-            elif hasattr(model, "test_step"):
-                self._validate_step = model.test_step
-                self._validate_signature_fn = None
-            else:
-                self._validate_step = model
-                self._validate_signature_fn = model.forward
-
-            if hasattr(model, "test_step"):
-                self._test_step = model.test_step
-                self._test_signature_fn = None
-            elif hasattr(model, "evaluate_step"):
-                self._test_step = model.validate_step
-                self._test_signature_fn = None
-            else:
-                self._test_step = model
-                self._test_signature_fn = model.forward

    def forward(self, batch, **kwargs) -> Dict:
-        forward_state = kwargs.pop(_MODE_PARAMETER)
+        fn = kwargs.pop("fastnlp_fn")
+        signature_fn = kwargs.pop("fastnlp_signature_fn")
+        wo_auto_param_call = kwargs.pop("wo_auto_param_call")

-        if forward_state == ForwardState.TRAIN:
-            if isinstance(batch, Dict) and not wo_auto_param_call:
-                return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
-            else:
-                return self._train_step(batch)
-        elif forward_state == ForwardState.VALIDATE:
-            if isinstance(batch, Dict) and not wo_auto_param_call:
-                return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
-            else:
-                return self._validate_step(batch)
-        elif forward_state == ForwardState.TEST:
-            if isinstance(batch, Dict) and not wo_auto_param_call:
-                return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
-            else:
-                return self._test_step(batch)
-        elif forward_state == ForwardState.PREDICT:
-            raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.")
-        else:
-            raise NotImplementedError("You should direct a concrete evaluate_fn.")
+        if isinstance(batch, Dict) and not wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
+        else:
+            return fn(batch)
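After this change the fleet driver passes the resolved step function through DataParallel.forward as keyword arguments instead of a mode enum. A toy sketch of the same wrapper protocol (class and function names are illustrative, not fastNLP's):

    from typing import Callable, Dict, Optional

    class WrappingModel:
        """Toy stand-in for _FleetWrappingModel: forward just dispatches to the injected fn."""
        def __init__(self, model):
            self.model = model

        def forward(self, batch, **kwargs) -> Dict:
            fn: Callable = kwargs.pop("fastnlp_fn")
            signature_fn: Optional[Callable] = kwargs.pop("fastnlp_signature_fn")
            wo_auto_param_call: bool = kwargs.pop("wo_auto_param_call")
            if isinstance(batch, dict) and not wo_auto_param_call:
                # fastNLP would match batch keys against fn's signature here (auto_param_call)
                return fn(**batch)
            return fn(batch)

    class ToyModel:
        def train_step(self, x=0):
            return {"loss": x}

    wrapped = WrappingModel(ToyModel())
    out = wrapped.forward({"x": 3}, fastnlp_fn=wrapped.model.train_step,
                          fastnlp_signature_fn=None, wo_auto_param_call=False)
    print(out)  # {'loss': 3}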

class DummyGradScaler:
"""


+ 18
- 43
fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py

@@ -1,6 +1,7 @@
-from typing import Optional, Dict, Union, Callable
+from typing import Optional, Dict, Union, Callable, Tuple

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
+from fastNLP.core.utils.utils import _get_fun_msg


if _NEED_IMPORT_PADDLE:
@@ -48,33 +49,6 @@ class TorchPaddleDriver(Driver):
elif self._data_device is not None:
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")

if hasattr(self.model, "train_step"):
self._train_step = self.model.train_step
self._train_signature_fn = None
else:
self._train_step = self.model
self._train_signature_fn = self.model.forward

if hasattr(self.model, "evaluate_step"):
self._validate_step = self.model.evaluate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
self._validate_signature_fn = self.model.forward
else:
self._validate_step = self.model
self._validate_signature_fn = self.model.forward

if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "evaluate_step"):
self._test_step = self.model.evaluate_step
self._test_signature_fn = self.model.forward
else:
self._test_step = self.model
self._test_signature_fn = self.model.forward

def setup(self):
if self.model_device is not None:
paddle.device.set_device(self.model_device.replace("cuda", "gpu"))
@@ -103,12 +77,6 @@ class TorchPaddleDriver(Driver):
f"'torch.optim.Optimizer' or 'paddle.optimizers.Optimizer' type, "
f"not {type(each_optimizer)}.")

-    def train_step(self, batch) -> Dict:
-        if isinstance(batch, Dict):
-            return auto_param_call(self._train_step, batch)
-        else:
-            return self._train_step(batch)

def step(self):
for optimizer in self.optimizers:
optimizer.step()
@@ -125,17 +93,24 @@ class TorchPaddleDriver(Driver):
else:
raise ValueError("Unknown optimizers type.")

-    def validate_step(self, batch):
-        if isinstance(batch, Dict):
-            return auto_param_call(self._validate_step, batch)
-        else:
-            return self._validate_step(batch)
-
-    def test_step(self, batch):
-        if isinstance(batch, Dict):
-            return auto_param_call(self._test_step, batch)
-        else:
-            return self._test_step(batch)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
+        else:
+            return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
+            return self.model, self.model.forward
+        else:
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

def predict_step(self, batch):
if isinstance(batch, Dict):


+ 0
- 1
fastNLP/envs/set_backend.py

@@ -5,7 +5,6 @@
import os
import json
import sys
-import subprocess
from collections import defaultdict




+ 1
- 1
fastNLP/modules/mix_modules/mix_module.py

@@ -85,7 +85,7 @@ class MixModule:
def test_step(self, batch):
raise NotImplementedError

-    def validate_step(self, batch):
+    def evaluate_step(self, batch):
raise NotImplementedError

def train(self):
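After the rename, user-defined MixModule subclasses should override `evaluate_step` instead of `validate_step`. A minimal sketch under the new naming (the subclass and its method bodies are illustrative, not from the repository, and assume `MixModule` needs no constructor arguments):

    from fastNLP.modules.mix_modules.mix_module import MixModule

    class MyMixModule(MixModule):
        def train_step(self, batch):
            return {"loss": batch.sum()}

        def evaluate_step(self, batch):  # formerly validate_step
            return {"pred": batch.argmax(-1)}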


+ 27
- 53
tests/core/controllers/test_trainer_paddle.py

@@ -1,13 +1,11 @@
import pytest
import os
os.environ["FASTNLP_BACKEND"] = "paddle"
-from typing import Any
from dataclasses import dataclass

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
-from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK

from paddle.optimizer import Adam
from paddle.io import DataLoader
@@ -19,40 +17,18 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordM
from tests.helpers.utils import magic_argv_env_context

@dataclass
-class MNISTTrainPaddleConfig:
+class TrainPaddleConfig:
    num_labels: int = 10
-    feature_dimension: int = 784
+    feature_dimension: int = 10

-    batch_size: int = 32
+    batch_size: int = 2
    shuffle: bool = True
-    validate_every = -5
+    evaluate_every = 2

    driver: str = "paddle"
    device = "gpu"

-@dataclass
-class MNISTTrainFleetConfig:
-    num_labels: int = 10
-    feature_dimension: int = 784
-
-    batch_size: int = 32
-    shuffle: bool = True
-    validate_every = -5
-
-@dataclass
-class TrainerParameters:
-    model: Any = None
-    optimizers: Any = None
-    train_dataloader: Any = None
-    validate_dataloaders: Any = None
-    input_mapping: Any = None
-    output_mapping: Any = None
-    metrics: Any = None
-
-@pytest.mark.parametrize("driver,device", [("paddle", "cpu")("paddle", 1)])
+@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
-@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
-                                        RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
+@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
+                                        RichCallback(5)]])
@magic_argv_env_context
def test_trainer_paddle(
driver,
@@ -60,38 +36,36 @@ def test_trainer_paddle(
callbacks,
n_epochs=2,
):
-    trainer_params = TrainerParameters()
-
-    trainer_params.model = PaddleNormalModel_Classification_1(
-        num_labels=MNISTTrainPaddleConfig.num_labels,
-        feature_dimension=MNISTTrainPaddleConfig.feature_dimension
+    model = PaddleNormalModel_Classification_1(
+        num_labels=TrainPaddleConfig.num_labels,
+        feature_dimension=TrainPaddleConfig.feature_dimension
    )
-    trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
+    optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
    train_dataloader = DataLoader(
-        dataset=PaddleRandomMaxDataset(6400, 10),
-        batch_size=MNISTTrainPaddleConfig.batch_size,
+        dataset=PaddleRandomMaxDataset(20, 10),
+        batch_size=TrainPaddleConfig.batch_size,
        shuffle=True
    )
    val_dataloader = DataLoader(
-        dataset=PaddleRandomMaxDataset(1000, 10),
-        batch_size=MNISTTrainPaddleConfig.batch_size,
+        dataset=PaddleRandomMaxDataset(20, 10),
+        batch_size=TrainPaddleConfig.batch_size,
        shuffle=True
    )
-    trainer_params.train_dataloader = train_dataloader
-    trainer_params.validate_dataloaders = val_dataloader
-    trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
-    trainer_params.metrics = {"acc": Accuracy(backend="paddle")}
+    train_dataloader = train_dataloader
+    evaluate_dataloaders = val_dataloader
+    evaluate_every = TrainPaddleConfig.evaluate_every
+    metrics = {"acc": Accuracy(backend="paddle")}
    trainer = Trainer(
-        model=trainer_params.model,
+        model=model,
        driver=driver,
        device=device,
-        optimizers=trainer_params.optimizers,
-        train_dataloader=trainer_params.train_dataloader,
-        validate_dataloaders=trainer_params.validate_dataloaders,
-        validate_every=trainer_params.validate_every,
-        input_mapping=trainer_params.input_mapping,
-        output_mapping=trainer_params.output_mapping,
-        metrics=trainer_params.metrics,
+        optimizers=optimizers,
+        train_dataloader=train_dataloader,
+        evaluate_dataloaders=evaluate_dataloaders,
+        evaluate_every=evaluate_every,
+        input_mapping=None,
+        output_mapping=None,
+        metrics=metrics,

n_epochs=n_epochs,
callbacks=callbacks,


+ 39
- 28
tests/core/drivers/paddle_driver/test_fleet.py

@@ -117,12 +117,13 @@ class TestSetDistReproDataloader:
"""

@magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is a BucketedBatchSampler
        """
-        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
-        batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
+        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
+        batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)

assert not (replaced_loader is dataloader)
@@ -133,12 +134,13 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_sampler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is a RandomSampler
        """
-        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
-        sampler = RandomSampler(self.dataset, shuffle=True)
+        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
+        sampler = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)

assert not (replaced_loader is dataloader)
@@ -171,14 +173,15 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the
        dataloader has a BucketedBatchSampler
        """
        dataloader = DataLoader(
            self.dataset,
-            batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4),
+            batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle),
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
@@ -195,12 +198,13 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the
        dataloader has a RandomSampler
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
-        batch_sampler.sampler = RandomSampler(self.dataset, True)
+        batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
@@ -222,11 +226,12 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the
        dataloader is the ordinary case
        """
-        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
+        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

assert replaced_loader is dataloader
@@ -238,14 +243,15 @@ class TestSetDistReproDataloader:
"""

@magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler is a
        ReproducibleBatchSampler
        """
        dataloader = DataLoader(
            dataset=self.dataset,
-            batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
+            batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

@@ -258,13 +264,14 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler.sampler
        is a ReproducibleSampler
        """
-        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
-        batch_sampler.sampler = RandomSampler(self.dataset, True)
+        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
+        batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -276,16 +283,17 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
-        assert replaced_loader.batch_sampler.sampler.shuffle == True
+        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'dist' and the dataloader is the ordinary case
        """
-        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
+        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
@@ -293,7 +301,7 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
-        assert replaced_loader.batch_sampler.sampler.shuffle == True
+        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
dist.barrier()

"""
@@ -302,13 +310,14 @@ class TestSetDistReproDataloader:
"""

@magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and
        dataloader.batch_sampler.sampler is a ReproducibleSampler
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
-        batch_sampler.sampler = RandomSampler(self.dataset, True)
+        batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -320,18 +329,19 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert replaced_loader.batch_sampler.batch_size == 2
-        assert replaced_loader.batch_sampler.sampler.shuffle == True
+        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and
        dataloader.batch_sampler.sampler is an UnrepeatedSampler
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
-        batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True)
+        batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -349,11 +359,12 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
-    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and the dataloader is the
        ordinary case
        """
-        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
+        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)
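Every case above now receives `shuffle` via `@pytest.mark.parametrize` instead of hard-coding `shuffle=True`, so each test runs once per flag value, and stacked parametrize decorators multiply out. A self-contained illustration of that cross-product behavior (the test name is made up):

    import pytest

    @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
    @pytest.mark.parametrize("shuffle", [True, False])
    def test_parametrize_cross_product(driver, device, shuffle):
        # pytest generates 2 (driver/device) x 2 (shuffle) = 4 test items
        assert driver == "paddle"
        assert shuffle in (True, False)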


+ 133
- 82
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -1,4 +1,5 @@
import os
os.environ["FASTNLP_BACKEND"] = "paddle"
import pytest
from pathlib import Path
@@ -56,34 +57,57 @@ def test_save_and_load_with_randombatchsampler(only_state_dict):
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
)
num_consumed_batches = 2

# TODO: iterate a few more times here once checkpoint-resume training is fully implemented
already_seen_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_set.update(batch)

sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
# Load
# Change the batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")

# 1. Check the optimizer state
# TODO: the optimizer's state_dict is always empty

# 2. Check that the batch_sampler is correctly loaded and replaced
replaced_loader = states["dataloader"]
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"]

# 3. Check that the model parameters are loaded correctly
for batch in dataloader:
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)

assert paddle.equal_all(res1["pred"], res2["pred"])

# 4. Check batch_idx
# TODO
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_batches = set()
for idx, batch in enumerate(replaced_loader):
left_batches.update(batch)

assert len(left_batches) + len(already_seen_set) == len(dataset)
assert len(left_batches | already_seen_set) == len(dataset)


finally:
synchronize_safe_rm(path)
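
The round trip above is the checkpoint-resume pattern these tests rely on: consume a fixed number of batches, save the sampler state together with num_consumed_batches, rebuild the loader from the loaded state, and check that the resumed batches are exactly the complement of those already seen. A condensed sketch; driver1, driver2, dataloader, dataset, path and only_state_dict are the fixture objects from the surrounding test:

already_seen_set = set()
for idx, batch in enumerate(dataloader):
    if idx >= num_consumed_batches:
        break
    already_seen_set.update(batch)

driver1.save(Path(path), {"num_consumed_batches": num_consumed_batches},
             dataloader, only_state_dict, should_save_model=True)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)

left_batches = set()
for batch in load_states.pop("dataloader"):
    left_batches.update(batch)

# Disjointness plus full coverage: the two sets partition the dataset.
assert len(left_batches) + len(already_seen_set) == len(dataset)
assert len(left_batches | already_seen_set) == len(dataset)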

@@ -104,21 +128,36 @@ def test_save_and_load_with_randomsampler(only_state_dict):
dataset,
batch_sampler=batch_sampler
)
num_consumed_batches = 2

# TODO: iterate a few more times here once checkpoint-resume training is fully implemented
already_seen_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_set.update(batch)

sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])

# Load
# Change the batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")

# 1. Check the optimizer state
# TODO: the optimizer's state_dict is always empty

# 2. Check that the sampler is correctly loaded and replaced
replaced_loader = states["dataloader"]
replaced_loader = load_states["dataloader"]

assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
@@ -129,60 +168,51 @@ def test_save_and_load_with_randomsampler(only_state_dict):

# 3. Check that the model parameters are loaded correctly
for batch in dataloader:
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)

assert paddle.equal_all(res1["pred"], res2["pred"])

# 4. Check batch_idx
# TODO
finally:
synchronize_safe_rm(path)

def test_save_and_load_state_dict(prepare_test_save_load):
"""
Test the save and load functions
TODO: the optimizer's state_dict is empty, so it is not tested for now
"""
try:
path = "dict"
driver1, driver2, dataloader = prepare_test_save_load

driver1.save_model(path)
driver2.load_model(path)

for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_batches = set()
for idx, batch in enumerate(replaced_loader):
left_batches.update(batch)

assert paddle.equal_all(res1["pred"], res2["pred"])
assert len(left_batches) + len(already_seen_set) == len(dataset)
assert len(left_batches | already_seen_set) == len(dataset)
finally:
synchronize_safe_rm(path)

def test_save_and_load_whole_model(prepare_test_save_load):
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_model(prepare_test_save_load, only_state_dict):
"""
Test the save and load functions
TODO: the optimizer's state_dict is empty, so it is not tested for now
Test the save_model and load_model functions
"""
try:
path = "model"
driver1, driver2, dataloader = prepare_test_save_load

driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))])
driver2.load_model(path, only_state_dict=False)
if only_state_dict:
driver1.save_model(path, only_state_dict)
else:
driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((32, 10))])
driver2.load_model(path, only_state_dict)

for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)

assert paddle.equal_all(res1["pred"], res2["pred"])
finally:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")

if only_state_dict:
synchronize_safe_rm(path)
else:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")

class TestSingleDeviceFunction:
"""
@@ -199,13 +229,7 @@ class TestSingleDeviceFunction:
Test that these functions can run
"""
res = self.driver.unwrap_model()

def test_check_evaluator_mode(self):
"""
These two functions have no return value and raise no exceptions; this only checks for import errors and other factors that would prevent them from running
"""
self.driver.check_evaluator_mode("validate")
self.driver.check_evaluator_mode("test")
assert res is self.driver.model

def test_is_distributed(self):
assert self.driver.is_distributed() == False
@@ -237,44 +261,55 @@ class TestSetDistReproDataloder:

assert replaced_loader is dataloader

def test_set_dist_repro_dataloader_with_reproducible_true(self):
@pytest.mark.parametrize("shuffle", [True, False])
def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when the `reproducible` parameter is True
When dist is a string, a new dataloader should be returned, with batch_sampler being RandomBatchSampler
When dist is a string, a new dataloader should be returned; if the original sampler is paddle.io.RandomSampler(shuffle=True),
only the sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with RandomBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
if shuffle:
# In this case the sampler is replaced
assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# In this case the batch_sampler is replaced
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last

# self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
"""
Test set_dist_repro_dataloader when the dist parameter is not a string but a ReproducibleBatchSampler
A new dataloader should be returned, with batch_sampler replaced by the sampler passed as dist
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler is dist

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dist_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
"""
Test set_dist_repro_dataloader when the dist parameter is not a string
A new dataloader should be returned, with batch_sampler.sampler replaced by the sampler passed as dist
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
dist = RandomSampler(self.dataset, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is dataloader)
@@ -284,16 +319,21 @@ class TestSetDistReproDataloder:
assert replaced_loader.batch_sampler.sampler is dist
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle):
"""
Test set_dist_repro_dataloader when the dataloader already supports checkpoint-resume training
A new dataloader should be returned, with all other settings the same as before
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
batch_sampler=RandomBatchSampler(
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
batch_size=4,
drop_last=False,
)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

@@ -303,15 +343,16 @@ class TestSetDistReproDataloder:
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle):
"""
Test set_dist_repro_dataloader when the dataloader already supports checkpoint-resume training
A new dataloader should be returned, with all other settings the same as before
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -323,11 +364,11 @@ class TestSetDistReproDataloder:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
"""
Check that set_dist_repro_dataloader behaves correctly on a single device
"""
@@ -346,9 +387,6 @@ class TestSetDistReproDataloder:
# Load num_consumed_samples_array and set the correct number of consumed batches
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)

import time
time.sleep(5)

# Reload; the remaining samples should be produced, and for PaddleNormalDataset the sorted result should form a range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
@@ -357,16 +395,29 @@ class TestSetDistReproDataloder:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
replaced_loader.batch_sampler.load_state_dict(sampler_states)
# Rebuild the dataloader
new_loader = DataLoader(
dataset=replaced_loader.dataset,
batch_sampler=RandomBatchSampler(
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
batch_size=batch_size,
drop_last=False,
)
)
new_loader.batch_sampler.load_state_dict(sampler_states)
else:
batch_size = replaced_loader.batch_sampler.batch_size
num_consumed_batches = num_consumed_batches * batch_size
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
replaced_loader.batch_sampler.sampler.set_epoch(0)
for idx, batch in enumerate(replaced_loader):
# Reconstruct the dataloader
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(new_loader):
left_idxes.update(batch)

assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
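
The two branches above derive the resume offset differently: when the sampler state carries a num_consumed_samples_array, it records the exact number of samples consumed after each batch (robust to a short trailing batch); otherwise the offset falls back to assuming every consumed batch was full-sized. A sketch of that choice, using the names from the test above:

if num_consumed_samples_array is not None:
    # Exact per-batch record of consumed samples.
    sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
    # Fallback that assumes every batch was full-sized.
    sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size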


+ 1
- 1
tests/helpers/callbacks/helper_callbacks.py View File

@@ -72,7 +72,7 @@ class RecordTrainerEventTriggerCallback(Callback):
print("on_train_end")

def on_train_epoch_begin(self, trainer):
if trainer.current_epoch_idx >= 1:
if trainer.cur_epoch_idx >= 1:
# Trigger on_exception
raise Exception
print("on_train_epoch_begin")


+ 1
- 1
tests/helpers/models/paddle_model.py View File

@@ -26,7 +26,7 @@ class PaddleNormalModel_Classification_1(paddle.nn.Layer):
x = self(x)
return {"loss": self.loss_fn(x, y)}

def validate_step(self, x, y):
def evaluate_step(self, x, y):

x = self(x)
return {"pred": x, "target": y.reshape((-1,))}

