diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py
index c467b868..98ac44a0 100644
--- a/fastNLP/core/drivers/jittor_driver/mpi.py
+++ b/fastNLP/core/drivers/jittor_driver/mpi.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional, Union
+from typing import Optional, Union, Callable, Dict, Tuple
 
 from .jittor_driver import JittorDriver
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -61,14 +61,11 @@ class JittorMPIDriver(JittorDriver):
             return self._data_device
         return self.model_device
 
-    def train_step(self, batch):
-        return self._train_step(batch)
-
-    def validate_step(self, batch):
-        return self._validate_step(batch)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        pass
 
-    def test_step(self, batch):
-        return self._test_step(batch)
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        pass
 
     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py
index 84bdb28b..695e6ec9 100644
--- a/fastNLP/core/drivers/jittor_driver/single_device.py
+++ b/fastNLP/core/drivers/jittor_driver/single_device.py
@@ -1,9 +1,11 @@
-from typing import Dict, Union
+from typing import Dict, Union, Tuple, Callable, Optional
 
 from .jittor_driver import JittorDriver
 from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
+from fastNLP.core.log import logger
 
 if _NEED_IMPORT_JITTOR:
     import jittor
@@ -27,42 +29,6 @@ class JittorSingleDriver(JittorDriver):
         self.global_rank = 0
         self.world_size = 1
 
-        if hasattr(self.model, "train_step"):
-            self._train_step = self.model.train_step
-            self._train_signature_fn = None
-        else:
-            self._train_step = self.model
-            model = self.unwrap_model()
-            self._train_signature_fn = model.execute
-
-        if hasattr(self.model, "evaluate_step"):
-            self._validate_step = self.model.evaluate_step
-            self._validate_signature_fn = None
-        elif hasattr(self.model, "test_step"):
-            self._validate_step = self.model.test_step
-            self._validate_signature_fn = self.model.test_step
-        else:
-            self._validate_step = self.model
-            model = self.unwrap_model()
-            self._validate_signature_fn = model.execute
-
-        if hasattr(self.model, "test_step"):
-            self._test_step = self.model.test_step
-            self._test_signature_fn = None
-        elif hasattr(self.model, "evaluate_step"):
-            self._test_step = self.model.evaluate_step
-            self._test_signature_fn = self.model.evaluate_step
-        else:
-            self._test_step = self.model
-            model = self.unwrap_model()
-            self._test_signature_fn = model.execute
-
-    def train_step(self, batch) -> Dict:
-        if isinstance(batch, Dict):
-            return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
-        else:
-            return self._train_step(batch)
-
     def step(self):
         """
         The `step` function of jittor optimizers can take the loss as an argument
@@ -80,18 +46,24 @@ class JittorSingleDriver(JittorDriver):
         for optimizer in self.optimizers:
             optimizer.zero_grad()
 
-    def validate_step(self, batch):
-        if isinstance(batch, Dict):
-            return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
         else:
-            return self._validate_step(batch)
-
-    def test_step(self, batch):
-
-        if isinstance(batch, Dict):
-            return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
+            return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
+            return self.model, self.model.forward
         else:
-            return self._test_step(batch)
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
 
     def unwrap_model(self):
         return self.model
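The three per-mode entry points collapse into two driver hooks here: `get_model_call_fn` resolves once which function will serve a given step name, and `model_call` performs the actual call, with `auto_param_call` doing parameter matching for dict batches. A minimal sketch of how trainer-side code might drive the two hooks; the driver construction and the `model`/`dataloader` names are illustrative assumptions, not part of this patch:

    # Hypothetical trainer-side loop (not part of this patch).
    from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver

    driver = JittorSingleDriver(model, device="cpu")

    # Resolve the step function once per stage:
    # - model.train_step if the model defines it (signature_fn is then None),
    # - otherwise the model itself, with model.forward used for matching.
    step_fn, signature_fn = driver.get_model_call_fn("train_step")

    for batch in dataloader:
        # dict batches are matched via auto_param_call unless
        # wo_auto_param_call was set; other batches are passed through as-is.
        outputs = driver.model_call(batch, step_fn, signature_fn)
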
diff --git a/fastNLP/core/drivers/paddle_driver/dist_utils.py b/fastNLP/core/drivers/paddle_driver/dist_utils.py
new file mode 100644
index 00000000..3bfbbd4f
--- /dev/null
+++ b/fastNLP/core/drivers/paddle_driver/dist_utils.py
@@ -0,0 +1,376 @@
+import io
+import pickle
+_pickler = pickle.Pickler
+_unpickler = pickle.Unpickler
+from typing import Any, List
+
+from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
+from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch import distributed as dist
+    if _TORCH_GREATER_EQUAL_1_8:
+        try:
+            from torch._C._distributed_c10d import ProcessGroupGloo
+            from torch._C._distributed_c10d import _ProcessGroupWrapper
+        except ImportError:
+            pass
+
+
+from fastNLP.core.utils import apply_to_collection
+
+
+def _validate_output_list_for_rank(my_rank, dst, gather_list):
+    if dst == my_rank:
+        if not gather_list:
+            raise ValueError(
+                "Argument ``gather_list`` must be specified on destination rank."
+            )
+    elif gather_list:
+        raise ValueError(
+            "Argument ``gather_list`` must NOT be specified "
+            "on non-destination ranks."
+        )
+
+
+def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
+    """
+    Gather objects from the other ranks onto the ``dst`` rank.
+
+    Gathers picklable objects from the whole group in a single process.
+    Similar to :func:`gather`, but Python objects can be passed in. Note that the
+    object must be picklable in order to be gathered.
+
+    Args:
+        obj (Any): Input object. Must be picklable.
+        object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
+            should be correctly sized as the size of the group for this
+            collective and will contain the output. Must be ``None`` on non-dst
+            ranks. (default is ``None``)
+        dst (int, optional): Destination rank. (default is 0)
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+
+    Returns:
+        None. On the ``dst`` rank, ``object_gather_list`` will contain the
+        output of the collective.
+
+    .. note:: Note that this API differs slightly from the gather collective
+        since it does not provide an async_op handle and thus will be a blocking
+        call.
+
+    .. note:: Note that this API is not supported when using the NCCL backend.
+
+    .. warning::
+        :func:`gather_object` uses ``pickle`` module implicitly, which is
+        known to be insecure. It is possible to construct malicious pickle data
+        which will execute arbitrary code during unpickling. Only call this
+        function with data you trust.
+
+    Example::
+        >>> # Note: Process group initialization omitted on each rank.
+        >>> import torch.distributed as dist
+        >>> # Assumes world_size of 3.
+        >>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
+        >>> output = [None for _ in gather_objects]
+        >>> dist.gather_object(
+                gather_objects[dist.get_rank()],
+                output if dist.get_rank() == 0 else None,
+                dst=0
+            )
+        >>> # On rank 0
+        >>> output
+        ['foo', 12, {1: 2}]
+    """
+    if group is None:
+        group = DEFAULT_TORCH_GROUP
+
+    if dist.distributed_c10d._rank_not_in_group(group):
+        return
+
+    # Ensure object_gather_list is specified appropriately.
+    my_rank = dist.get_rank()
+    _validate_output_list_for_rank(my_rank, dst, object_gather_list)
+    # Move tensors to CPU first so that unpickling does not land on the sending GPU.
+    obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
+    input_tensor, local_size = _object_to_tensor(obj)
+    group_backend = dist.get_backend(group)
+    current_device = torch.device("cpu")
+    is_nccl_backend = group_backend == dist.Backend.NCCL
+    if is_nccl_backend:
+        current_device = torch.device('cuda', torch.cuda.current_device())
+        input_tensor = input_tensor.to(current_device)
+        local_size = local_size.to(current_device)
+    # Gather all local sizes. This is so that we can find the max size, and index
+    # until the correct size when deserializing the tensors.
+    group_size = dist.get_world_size(group=group)
+    object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device)
+    object_size_list = [
+        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
+    ]
+    # Allgather tensor sizes. An all-gather is needed here despite this being a
+    # gather, since each rank needs to broadcast a tensor of the same (maximal)
+    # size.
+    dist.all_gather(object_size_list, local_size, group=group)
+    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
+    # Resize tensor to max size across all ranks.
+    input_tensor.resize_(max_object_size)
+    # Avoid populating output tensors if the result won't be gathered on this rank.
+    if my_rank == dst:
+        coalesced_output_tensor = torch.empty(
+            max_object_size * group_size, dtype=torch.uint8, device=current_device
+        )
+        # Output tensors are nonoverlapping views of coalesced_output_tensor
+        output_tensors = [
+            coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+            for i in range(group_size)
+        ]
+    # All ranks call gather with equal-sized tensors.
+    dist.gather(
+        input_tensor,
+        gather_list=output_tensors if my_rank == dst else None,
+        dst=dst,
+        group=group,
+    )
+    if my_rank != dst:
+        return
+    for i, tensor in enumerate(output_tensors):
+        tensor = tensor.type(torch.uint8)  # type: ignore[call-overload]
+        tensor_size = object_size_list[i]
+        object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
+
+
+def _object_to_tensor(obj, device=None):
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage.from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
+    # Otherwise, it will cause 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage)
+    local_size = torch.LongTensor([byte_tensor.numel()])
+    if device is not None:
+        byte_tensor = byte_tensor.to(device)
+        local_size = local_size.to(device)
+    return byte_tensor, local_size
+
+
+def _tensor_to_object(tensor, tensor_size):
+    buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
+
+
+def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
+    # src rank send to all other ranks
+    size = torch.LongTensor([0]).to(device)
+
+    if cur_rank == src:
+        world_size = dist.get_world_size(group=group)
+        tensor, size = _object_to_tensor(obj)
+        tensor = tensor.to(device)
+        size = size.to(device)
+
+        # first, synchronize the size of obj;
+        dist.broadcast(size, src, group=group)
+        for subrank in range(world_size):
+            if subrank != src:
+                dist.send(tensor=tensor, dst=subrank, group=group, tag=tag)
+    else:
+        dist.broadcast(size, src, group=group)
+        tensor = torch.ByteTensor([0] * size).to(device)
+        dist.recv(tensor=tensor, src=src, group=group, tag=tag)
+
+    return _tensor_to_object(tensor.cpu(), size)
+
+
+def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) -> List:
+    """
+    A single interface through which data of any type can be all_gathered. Non-tensor data is transferred by
+    pickling it on the sender and unpickling it on the receivers.
+
+    example:
+        obj = {
+            'a': [1, 1],
+            'b': [[1, 2], [1, 2]],
+            'c': {
+                'd': [1, 2]
+            }
+        }
+        ->
+        [
+            {'a': 1, 'b':[1, 2], 'c':{'d': 1}},
+            {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+        ]
+
+    :param obj: data of any structure; if it is a tensor, its shape must be identical on every device. Any
+        non-tensor object is serialized before it is transferred.
+    :param device: this parameter is currently meaningless.
+    :param group:
+    :return: a list [obj0, obj1, ...], where obj_i is the obj from rank i.
+    """
+    if group is None:
+        group = DEFAULT_TORCH_GROUP
+    if isinstance(obj, torch.Tensor):
+        objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
+        dist.all_gather(objs, obj, group=group)
+    else:
+        objs = [None for _ in range(dist.get_world_size(group))]
+        # move tensors to CPU so that unpickling does not land on the sending GPU
+        obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
+        if _TORCH_GREATER_EQUAL_1_8:
+            dist.all_gather_object(objs, obj, group=group)
+        else:
+            objs = all_gather_object(objs, obj, group=group)
+    return objs
+
+
+def fastnlp_paddle_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
+    """
+    Broadcast the obj on src to the other ranks.
+
+    :param obj:
+    :param src:
+    :param device:
+    :param group:
+    :return:
+    """
+    if group is None:
+        group = DEFAULT_TORCH_GROUP
+    cur_rank = dist.get_rank(group)
+    if cur_rank == src:
+        # move all tensors to CPU to make pickling safe; otherwise unpickling may land on the sending device
+        obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
+    if _TORCH_GREATER_EQUAL_1_8:
+        if cur_rank != src:
+            get_obj = [None]
+            dist.broadcast_object_list(get_obj, src=src, group=group)
+            return get_obj[0]
+        else:
+            dist.broadcast_object_list([obj], src=src, group=group)
+            return obj
+    if device is None:
+        device = torch.cuda.current_device()
+
+    if cur_rank == src:
+        tensor, size = _object_to_tensor(obj, device=device)
+    else:
+        size = torch.LongTensor([0]).to(device)
+
+    dist.broadcast(size, src=src, group=group)
+    if cur_rank != src:
+        tensor = torch.empty(
+            size.int().item(),  # type: ignore[arg-type]
+            dtype=torch.uint8,
+            device=device
+        )
+    dist.broadcast(tensor, src=src, group=group)
+
+    return _tensor_to_object(tensor, tensor_size=size.item())
+
+
+def _check_for_nccl_backend(group):
+    pg = group or dist.distributed_c10d._get_default_group()
+    # It is not expected for PG to be wrapped many times, but support it just
+    # in case
+    while isinstance(pg, _ProcessGroupWrapper):
+        pg = pg.wrapped_pg
+
+    return (
+        dist.is_nccl_available() and
+        isinstance(pg, dist.ProcessGroupNCCL)
+    )
+
+
+def all_gather_object(object_list, obj, group=None):
+    """
+    Copied from pytorch so that lower pytorch versions are also supported.
+
+    Gathers picklable objects from the whole group into a list. Similar to
+    :func:`all_gather`, but Python objects can be passed in. Note that the object
+    must be picklable in order to be gathered.
+
+    Args:
+        object_list (list[Any]): Output list. It should be correctly sized as the
+            size of the group for this collective and will contain the output.
+        object (Any): Picklable Python object to be broadcast from current process.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+
+    Returns:
+        None. If the calling rank is part of this group, the output of the
+        collective will be populated into the input ``object_list``. If the
+        calling rank is not part of the group, the passed in ``object_list`` will
+        be unmodified.
+
+    .. note:: Note that this API differs slightly from the :func:`all_gather`
+        collective since it does not provide an ``async_op`` handle and thus
+        will be a blocking call.
+
+    .. note:: For NCCL-based process groups, internal tensor representations
+        of objects must be moved to the GPU device before communication takes
+        place. In this case, the device used is given by
+        ``torch.cuda.current_device()`` and it is the user's responsibility to
+        ensure that this is set so that each rank has an individual GPU, via
+        ``torch.cuda.set_device()``.
+
+    .. warning::
+        :func:`all_gather_object` uses ``pickle`` module implicitly, which is
+        known to be insecure. It is possible to construct malicious pickle data
+        which will execute arbitrary code during unpickling. Only call this
+        function with data you trust.
+
+    Example::
+        >>> # Note: Process group initialization omitted on each rank.
+        >>> import torch.distributed as dist
+        >>> # Assumes world_size of 3.
+        >>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
+        >>> output = [None for _ in gather_objects]
+        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
+        >>> output
+        ['foo', 12, {1: 2}]
+    """
+    if dist.distributed_c10d._rank_not_in_group(group):
+        return
+    if _TORCH_GREATER_EQUAL_1_8:
+        current_device = torch.device("cpu")
+        is_nccl_backend = _check_for_nccl_backend(group)
+        if is_nccl_backend:
+            # See note about using torch.cuda.current_device() here in docstring.
+            # We cannot simply use my_rank since rank == device is not necessarily
+            # true.
+            current_device = torch.device("cuda", torch.cuda.current_device())
+    else:
+        current_device = torch.cuda.current_device()
+
+    input_tensor, local_size = _object_to_tensor(obj, device=current_device)
+
+    # Gather all local sizes. This is so that we can find the max size, and index
+    # until the correct size when deserializing the tensors.
+    group_size = dist.get_world_size(group=group)
+    object_sizes_tensor = torch.zeros(
+        group_size, dtype=torch.long, device=current_device
+    )
+    object_size_list = [
+        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
+    ]
+    # Allgather tensor sizes
+    dist.all_gather(object_size_list, local_size, group=group)
+    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
+    # Resize tensor to max size across all ranks.
+    input_tensor.resize_(max_object_size)
+    coalesced_output_tensor = torch.empty(
+        max_object_size * group_size, dtype=torch.uint8, device=current_device
+    )
+    # Output tensors are nonoverlapping views of coalesced_output_tensor
+    output_tensors = [
+        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+        for i in range(group_size)
+    ]
+    dist.all_gather(output_tensors, input_tensor, group=group)
+    # Deserialize outputs back to object.
+    for i, tensor in enumerate(output_tensors):
+        tensor = tensor.type(torch.uint8)
+        if tensor.device != torch.device("cpu"):
+            tensor = tensor.cpu()
+        tensor_size = object_size_list[i]
+        object_list[i] = _tensor_to_object(tensor, tensor_size)
+    return object_list
\ No newline at end of file
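Every helper in this new file rides on the same transport trick implemented by `_object_to_tensor` and `_tensor_to_object`: pickle the object into a uint8 tensor, exchange sizes, then unpickle on the receiving side. A self-contained round trip that can run in a single process, with no process group; it assumes only torch and this new module are importable:

    import torch
    from fastNLP.core.drivers.paddle_driver.dist_utils import (
        _object_to_tensor, _tensor_to_object
    )

    obj = {"pred": torch.tensor([0, 1]), "meta": ["a", "b"]}
    byte_tensor, size = _object_to_tensor(obj)        # pickle -> uint8 tensor
    restored = _tensor_to_object(byte_tensor, size.item())  # bytes -> object
    assert restored["meta"] == ["a", "b"]
    assert restored["pred"].tolist() == [0, 1]

The collectives above only add the bookkeeping around this: all ranks agree on the maximal byte size, pad to it, and slice each received tensor back to its true size before unpickling.
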
diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index 1b29fd07..a083e42c 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -1,13 +1,12 @@
 import os
+import shutil
 from functools import partial
-from typing import List, Union, Optional, Dict
+from typing import List, Union, Optional, Dict, Tuple, Callable
 
 from .paddle_driver import PaddleDriver
 from .fleet_launcher import FleetLauncher
 from .utils import (
     _FleetWrappingModel,
-    ForwardState,
-    _MODE_PARAMETER,
     get_device_from_visible,
     reset_seed,
     replace_sampler,
@@ -47,8 +46,7 @@ if _NEED_IMPORT_PADDLE:
 __all__ = [
     "PaddleFleetDriver",
 ]
-# if os.path.exists(self.gloo_rendezvous_dir):
-#     shutil.rmtree(self.gloo_rendezvous_dir)
+
 class PaddleFleetDriver(PaddleDriver):
     def __init__(
             self,
@@ -104,34 +102,6 @@ class PaddleFleetDriver(PaddleDriver):
             # we simply set model_device to None;
             self._model_device = None
 
-            def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
-                if isinstance(batch, Dict) and not wo_auto_param_call:
-                    return auto_param_call(step_fn, batch, signature_fn=signature_fn)
-                else:
-                    return self._validate_step(batch)
-
-            model = model._layers
-            if hasattr(model, "train_step"):
-                logger.warning(
-                    "Notice your model is a `paddle.DataParallel` model. And your "
-                    "model also implements the `train_step` method, which we can not call actually, we will"
-                    " call `forward` function instead of `train_step` and you should note that.")
-                self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
-
-            if hasattr(model, "evaluate_step"):
-                logger.warning(
-                    "Notice your model is a `paddle.DataParallel` model. And your "
-                    "model also implements the `evaluate_step` method, which we can not call actually, "
-                    "we will call `forward` function instead of `evaluate_step` and you should note that.")
-                self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
-
-            if hasattr(model, "test_step"):
-                logger.warning(
-                    "Notice your model is a `paddle.DataParallel` model. And your "
-                    "model also implements the `test_step` method, which we can not call actually, we will"
-                    " call `forward` function instead of `test_step` and you should note that.")
-                self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
-
         # when the parameter `device` is None and this one is not, the data is moved to the specified device;
         self._data_device = kwargs.get("data_device", None)
         if self._data_device is not None:
@@ -150,8 +120,6 @@ class PaddleFleetDriver(PaddleDriver):
         self.world_size = None
         self.global_rank = 0
 
-        self._configured = False  # guard against repeated calls to configure_ddp()
-        self._has_setup = False  # guard against repeated calls to setup()
         self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {})
         check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)
@@ -173,6 +141,9 @@ class PaddleFleetDriver(PaddleDriver):
         os.makedirs(name=self.output_from_new_proc, exist_ok=True)
         self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)
 
+        self._has_setup = False  # set because the evaluator also calls setup(), which is neither needed nor wanted there;
+        self._has_fleetwrapped = False  # whether the passed-in model has been wrapped by _FleetWrappingModel;
+
     def setup(self):
         """
        Launch the other subprocesses from the main process, with the main process as rank 0
@@ -268,17 +239,17 @@ class PaddleFleetDriver(PaddleDriver):
             dist.barrier()
 
     def configure_fleet(self):
-        if not self._configured and not isinstance(self.model, DataParallel):
+        if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
             self.model = DataParallel(
                 _FleetWrappingModel(self.model),
                 **self._fleet_kwargs
             )
+            self._has_fleetwrapped = True
 
-            self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
-            self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
-            self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)
-
-            self._configured = True
+    def on_exception(self):
+        if os.path.exists(self.gloo_rendezvous_dir):
+            shutil.rmtree(self.gloo_rendezvous_dir)
+        super().on_exception()
 
     @property
     def world_size(self) -> int:
@@ -310,14 +281,39 @@ class PaddleFleetDriver(PaddleDriver):
             return self._data_device
         return self.model_device
 
-    def train_step(self, batch):
-        return self._train_step(batch)
-
-    def validate_step(self, batch):
-        return self._validate_step(batch)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if self._has_fleetwrapped:
+            return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn,
+                              wo_auto_param_call=self.wo_auto_param_call)
+        else:
+            if isinstance(batch, Dict) and not self.wo_auto_param_call:
+                return auto_param_call(fn, batch, signature_fn=signature_fn)
+            else:
+                return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        model = self.unwrap_model()
+        if self._has_fleetwrapped:
+            if hasattr(model, fn):
+                fn = getattr(model, fn)
+                if not callable(fn):
+                    raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
+                return fn, None
+            elif fn in {"train_step", "evaluate_step"}:
+                return model, model.forward
+            else:
+                raise RuntimeError(f"There is no `{fn}` method in your model.")
+        else:
+            if hasattr(model, fn):
+                logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also implements "
+                               f"the `{fn}` method, which we cannot actually call; we will call the `forward` "
+                               f"function instead of `{fn}`, and you should take note of that.")
+            elif fn not in {"train_step", "evaluate_step"}:
+                raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
+                                   "`paddle.DataParallel` model, which means that we will only call the model.forward "
+                                   "function in forward propagation.")
 
-    def test_step(self, batch):
-        return self._test_step(batch)
+            return self.model, model.forward
 
     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
@@ -406,14 +402,6 @@ class PaddleFleetDriver(PaddleDriver):
         else:
             raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
 
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
-
     def is_global_zero(self):
         return self.global_rank == 0
@@ -450,3 +438,45 @@ class PaddleFleetDriver(PaddleDriver):
             if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)):
                 raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
                                  f"not {type(each_optimizer)}.")
+
+    def broadcast_object(self, obj, src: int = 0, group=None, **kwargs):
+        """
+        Send obj (a tensor or a generic object) from the src rank to the other ranks. A non-tensor object is
+        packed with pickle for transfer and loaded back on the receiving side. Only meaningful in distributed
+        drivers.
+
+        :param obj: the object to send; may be a Tensor or nested data.
+        :param int src: the global rank of the source.
+        :param int dst: the global rank(s) of the target; may be several target ranks.
+        :param group: the group this operation belongs to.
+        :param kwargs:
+        :return: if the current driver is not distributed, the input obj is returned directly. A receiving rank
+            (one whose global rank is contained in dst) returns the received object; the source rank returns
+            what it sent; a rank that is neither sender nor receiver returns None.
+        """
+        return
+        return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)
+
+    def all_gather(self, obj, group) -> List:
+        """
+        Send obj to all other ranks and receive theirs; obj may be a Tensor or a nested object. Data that is
+        not a basic type is serialized via pickle and deserialized after it is received.
+
+        example:
+            obj = {
+                'a': [1, 1],
+                'b': [[1, 2], [1, 2]],
+                'c': {
+                    'd': [1, 2]
+                }
+            }
+            ->
+            [
+                {'a': 1, 'b':[1, 2], 'c':{'d': 1}},
+                {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+            ]
+
+        :param obj: the object to transfer; it should have the same structure on every rank.
+        :param group:
+        :return:
+        """
+        return
+        return fastnlp_paddle_all_gather(obj, group=group)
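For a wrapped model, `model_call` above cannot invoke the step method directly without bypassing `DataParallel`'s gradient-synchronization hooks, so the real step function is smuggled through `forward` as keyword arguments, which `_FleetWrappingModel` (see utils.py below) unpacks. Roughly, and assuming `driver` is an already set-up `PaddleFleetDriver` and `batch` an input dict, the call expands to:

    # What PaddleFleetDriver.model_call amounts to once the model is wrapped;
    # driver.model is the paddle.DataParallel(_FleetWrappingModel(...)) built
    # in configure_fleet().
    fn, signature_fn = driver.get_model_call_fn("train_step")
    outputs = driver.model(
        batch,
        fastnlp_fn=fn,                      # consumed by _FleetWrappingModel.forward
        fastnlp_signature_fn=signature_fn,
        wo_auto_param_call=driver.wo_auto_param_call,
    )
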
diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
index 977eaf2c..37a5e59e 100644
--- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py
+++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
@@ -71,6 +71,14 @@ class PaddleDriver(Driver):
         for optimizer in self.optimizers:
             optimizer.clear_grad()
 
+    def backward(self, loss):
+        self.grad_scaler.scale(loss).backward()
+
+    def step(self):
+        for optimizer in self.optimizers:
+            self.grad_scaler.step(optimizer)
+            self.grad_scaler.update()
+
     @staticmethod
     def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         r"""
@@ -115,28 +123,6 @@ class PaddleDriver(Driver):
                 raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
                                  f"not {type(each_optimizer)}.")
 
-    def check_evaluator_mode(self, mode: str):
-        r"""
-        In the concrete drivers, the logic of evaluate_step and test_step is to fall back to the other function
-        when the model does not implement one of them; so if the user's evaluator uses evaluate_fn "validate" but
-        the passed-in model implements test_step rather than evaluate_step, we should warn the user about this
-        behavior;
-        """
-        model = self.unwrap_model()
-        if mode == "validate":
-            if not hasattr(model, "evaluate_step"):
-                if hasattr(model, "test_step"):
-                    logger.warning(
-                        "Your model does not have 'evaluate_step' method but has 'test_step' method, but you"
-                        "are using 'Evaluator.validate', we are going to use 'test_step' to substitute for"
-                        "'evaluate_step'.")
-
-        else:
-            if not hasattr(model, "test_step"):
-                if hasattr(model, "evaluate_step"):
-                    logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
-                                        "are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for"
-                                        "'test_step'.")
-
     @staticmethod
     def tensor_to_numeric(tensor, reduce=None):
         r"""
@@ -268,10 +254,10 @@ class PaddleDriver(Driver):
                 except:  # batch_size may be None; we merely lose some precision
                     pass
                 assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report."
+                states["sampler_states"] = sampler_states
             else:
                 raise RuntimeError(
                     "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
-            states["sampler_states"] = sampler_states
 
         # 2. save the model state;
         if should_save_model:
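`backward` and `step` now live once on the shared `PaddleDriver` base class and always route through `self.grad_scaler`, which is a real `paddle.amp.GradScaler` when `fp16=True` and the no-op `DummyGradScaler` (defined in utils.py below) otherwise. The pattern is the standard scaled-loss loop; a standalone sketch in plain paddle, with a made-up model and data, and assuming a paddle 2.x version whose `GradScaler` provides the same `scale`/`step`/`update` calls the driver relies on:

    import paddle

    model = paddle.nn.Linear(10, 2)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=1e-4)
    scaler = paddle.amp.GradScaler()   # the driver holds DummyGradScaler when fp16 is off

    x = paddle.randn([4, 10])
    with paddle.amp.auto_cast():       # mixed-precision forward
        loss = model(x).mean()
    scaler.scale(loss).backward()      # what PaddleDriver.backward does
    scaler.step(optimizer)             # what PaddleDriver.step does, per optimizer
    scaler.update()
    optimizer.clear_grad()             # what PaddleDriver.zero_grad does
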
diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py
index f11cb49a..e47360ee 100644
--- a/fastNLP/core/drivers/paddle_driver/single_device.py
+++ b/fastNLP/core/drivers/paddle_driver/single_device.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional, Dict, Union
+from typing import Optional, Dict, Union, Callable, Tuple
 
 from .paddle_driver import PaddleDriver
 from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible
@@ -11,16 +11,19 @@
     get_paddle_device_id,
     paddle_move_data_to_device,
 )
+from fastNLP.core.utils.utils import _get_fun_msg
 from fastNLP.core.samplers import (
     ReproducibleBatchSampler,
     RandomBatchSampler,
     ReproducibleSampler,
+    RandomSampler,
     re_instantiate_sampler,
 )
 from fastNLP.core.log import logger
 
 if _NEED_IMPORT_PADDLE:
     import paddle
+    from paddle import DataParallel
     from paddle.fluid.reader import _DatasetKind
 
 __all__ = [
@@ -28,109 +31,57 @@
 ]
 
 class PaddleSingleDriver(PaddleDriver):
-    def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs):
+    def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs):
+        if isinstance(model, DataParallel):
+            raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`")
+
+        cuda_visible_devices = os.environ.get(USER_CUDA_VISIBLE_DEVICES, None)
+        if cuda_visible_devices == "":
+            device = "cpu"
+            logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in the system environment variables, so we "
+                        "will use `cpu` instead of `gpu`.")
+
         super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs)
 
         if device is None:
             raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.")
 
+        if device != "cpu":
+            if isinstance(device, int):
+                device_id = device
+            else:
+                device_id = get_paddle_device_id(device)
+            os.environ["CUDA_VISIBLE_DEVICES"] = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
         self.model_device = get_paddle_gpu_str(device)
 
         self.local_rank = 0
         self.global_rank = 0
         self.world_size = 1
 
-        if isinstance(model, paddle.DataParallel):
-            # note that unwrap_model here is the concrete subclass's method;
-            model = self.unwrap_model()
-            if hasattr(model, "train_step"):
-                logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
-                               "implements the `train_step` method, which we can not call actually, we will "
-                               " call `forward` function instead of `train_step` and you should note that.")
-            self._train_step = self.model
-            self._train_signature_fn = model.forward
-
-            if hasattr(model, "evaluate_step"):
-                logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
-                               "implements the `evaluate_step` method, which we can not call actually, we "
-                               "will call `forward` function instead of `evaluate_step` and you should note that.")
-            self._validate_step = self.model
-            self._validate_signature_fn = model.forward
-
-            if hasattr(model, "test_step"):
-                logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
-                               "implements the `test_step` method, which we can not call actually, we will "
-                               "call `forward` function instead of `test_step` and you should note that.")
-            self._test_step = self.model
-            self._test_signature_fn = model.forward
-        else:
-            if hasattr(self.model, "train_step"):
-                self._train_step = self.model.train_step
-                self._train_signature_fn = None
-            else:
-                self._train_step = self.model
-                # the input model is a `DataParallel`; we must make sure its signature_fn is correct;
-                model = self.unwrap_model()
-                self._train_signature_fn = model.forward
-
-            if hasattr(self.model, "evaluate_step"):
-                self._validate_step = self.model.evaluate_step
-                self._validate_signature_fn = None
-            elif hasattr(self.model, "test_step"):
-                self._validate_step = self.model.test_step
-                self._validate_signature_fn = self.model.test_step
-            else:
-                self._validate_step = self.model
-                model = self.unwrap_model()
-                self._validate_signature_fn = model.forward
-
-            if hasattr(self.model, "test_step"):
-                self._test_step = self.model.test_step
-                self._test_signature_fn = None
-            elif hasattr(self.model, "evaluate_step"):
-                self._test_step = self.model.evaluate_step
-                self._test_signature_fn = self.model.evaluate_step
-            else:
-                self._test_step = self.model
-                model = self.unwrap_model()
-                self._test_signature_fn = model.forward
-
     def setup(self):
         device = self.model_device
-        if device != "cpu":
-            device_id = get_paddle_device_id(device)
-            device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
-            os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
-            device = get_device_from_visible(device, output_type=str)
+        device = get_device_from_visible(device, output_type=str)
         paddle.device.set_device(device)
         self.model.to(device)
 
-    def train_step(self, batch) -> Dict:
-        # if batch is a Dict, we do parameter matching for it by default; otherwise it is passed straight to
-        # `train_step` for the user to handle;
-        if isinstance(batch, Dict) and not self.wo_auto_param_call:
-            return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
         else:
-            return self._train_step(batch)
-
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
-
-    def validate_step(self, batch) -> Dict:
-        if isinstance(batch, Dict) and not self.wo_auto_param_call:
-            return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
+            return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
+            return self.model, self.model.forward
         else:
-            return self._validate_step(batch)
-
-    def test_step(self, batch) -> Dict:
-        if isinstance(batch, Dict) and not self.wo_auto_param_call:
-            return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
-        else:
-            return self._test_step(batch)
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
 
     def move_data_to_device(self, batch: 'paddle.Tensor'):
         r"""
@@ -164,12 +115,18 @@
             return replace_sampler(dataloader, sampler)
 
         if reproducible:
-            batch_sampler = RandomBatchSampler(
-                batch_sampler=args.batch_sampler,
-                batch_size=args.batch_size,
-                drop_last=args.drop_last
-            )
-            return replace_batch_sampler(dataloader, batch_sampler)
+            if isinstance(args.sampler, paddle.io.RandomSampler):
+                # already random, so just swap in our sampler
+                sampler = RandomSampler(args.sampler.data_source)
+                logger.debug("Replace paddle RandomSampler with fastNLP RandomSampler.")
+                return replace_sampler(dataloader, sampler)
+            else:
+                batch_sampler = RandomBatchSampler(
+                    batch_sampler=args.batch_sampler,
+                    batch_size=args.batch_size,
+                    drop_last=args.drop_last
+                )
+                return replace_batch_sampler(dataloader, batch_sampler)
         else:
             return dataloader
diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py
index 2f74cc65..feb5c3eb 100644
--- a/fastNLP/core/drivers/paddle_driver/utils.py
+++ b/fastNLP/core/drivers/paddle_driver/utils.py
@@ -11,7 +11,6 @@
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
 from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to
-from fastNLP.core.samplers import RandomSampler
 from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES
 from fastNLP.core.log import logger
 
@@ -87,8 +86,6 @@ class ForwardState(IntEnum):
     TEST = 2
     PREDICT = 3
 
-_MODE_PARAMETER = "forward_state"
-
 class _FleetWrappingModel(Layer):
     """
     Modeled on _DDPWrappingModel: paddle distributed training also needs the model wrapped with
    paddle.nn.DataParallel, following the same approach
     """
     def __init__(self, model: 'nn.Layer'):
         super(_FleetWrappingModel, self).__init__()
         self.model = model
 
-        if isinstance(model, paddle.DataParallel):
-            model = model._layers
-            if hasattr(model, "train_step"):
-                logger.warning(
-                    "Notice your model is a `paddle.DataParallel` model. And your "
-                    "model also implements the `train_step` method, which we can not call actually, we will"
-                    " call `forward` function instead of `train_step` and you should note that.")
-            self._train_step = self.model
-            self._train_signature_fn = model.forward
-
-            if hasattr(model, "evaluate_step"):
-                logger.warning(
-                    "Notice your model is a `paddle.DataParallel` model. And your "
-                    "model also implements the `evaluate_step` method, which we can not call actually, "
-                    "we will call `forward` function instead of `evaluate_step` and you should note that.")
-            self._validate_step = self.model
-            self._validate_signature_fn = model.forward
-
-            if hasattr(model, "test_step"):
-                logger.warning(
-                    "Notice your model is a `paddle.DataParallel` model. And your "
-                    "model also implements the `test_step` method, which we can not call actually, we will"
-                    " call `forward` function instead of `test_step` and you should note that.")
-            self._test_step = self.model
-            self._test_signature_fn = model.forward
-        else:
-            if hasattr(model, "train_step"):
-                self._train_step = model.train_step
-                self._train_signature_fn = None
-            else:
-                self._train_step = model
-                self._train_signature_fn = model.forward
-
-            if hasattr(model, "evaluate_step"):
-                self._validate_step = model.validate_step
-                self._validate_signature_fn = None
-            elif hasattr(model, "test_step"):
-                self._validate_step = model.test_step
-                self._validate_signature_fn = None
-            else:
-                self._validate_step = model
-                self._validate_signature_fn = model.forward
-
-            if hasattr(model, "test_step"):
-                self._test_step = model.test_step
-                self._test_signature_fn = None
-            elif hasattr(model, "evaluate_step"):
-                self._test_step = model.validate_step
-                self._test_signature_fn = None
-            else:
-                self._test_step = model
-                self._test_signature_fn = model.forward
-
     def forward(self, batch, **kwargs) -> Dict:
-        forward_state = kwargs.pop(_MODE_PARAMETER)
+        fn = kwargs.pop("fastnlp_fn")
+        signature_fn = kwargs.pop("fastnlp_signature_fn")
         wo_auto_param_call = kwargs.pop("wo_auto_param_call")
 
-        if forward_state == ForwardState.TRAIN:
-            if isinstance(batch, Dict) and not wo_auto_param_call:
-                return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
-            else:
-                return self._train_step(batch)
-        elif forward_state == ForwardState.VALIDATE:
-            if isinstance(batch, Dict) and not wo_auto_param_call:
-                return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
-            else:
-                return self._validate_step(batch)
-        elif forward_state == ForwardState.TEST:
-            if isinstance(batch, Dict) and not wo_auto_param_call:
-                return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
-            else:
-                return self._test_step(batch)
-        elif forward_state == ForwardState.PREDICT:
-            raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.")
+        if isinstance(batch, Dict) and not wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
         else:
-            raise NotImplementedError("You should direct a concrete evaluate_fn.")
+            return fn(batch)
 
 class DummyGradScaler:
     """
diff --git a/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py b/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py
index 2f4526ac..20be8a37 100644
--- a/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py
+++ b/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py
@@ -1,6 +1,7 @@
-from typing import Optional, Dict, Union, Callable
+from typing import Optional, Dict, Union, Callable, Tuple
 
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
+from fastNLP.core.utils.utils import _get_fun_msg
 
 
 if _NEED_IMPORT_PADDLE:
@@ -48,33 +49,6 @@ class TorchPaddleDriver(Driver):
         elif self._data_device is not None:
             raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
 
-        if hasattr(self.model, "train_step"):
-            self._train_step = self.model.train_step
-            self._train_signature_fn = None
-        else:
-            self._train_step = self.model
-            self._train_signature_fn = self.model.forward
-
-        if hasattr(self.model, "evaluate_step"):
-            self._validate_step = self.model.evaluate_step
-            self._validate_signature_fn = None
-        elif hasattr(self.model, "test_step"):
-            self._validate_step = self.model.test_step
-            self._validate_signature_fn = self.model.forward
-        else:
-            self._validate_step = self.model
-            self._validate_signature_fn = self.model.forward
-
-        if hasattr(self.model, "test_step"):
-            self._test_step = self.model.test_step
-            self._test_signature_fn = None
-        elif hasattr(self.model, "evaluate_step"):
-            self._test_step = self.model.evaluate_step
-            self._test_signature_fn = self.model.forward
-        else:
-            self._test_step = self.model
-            self._test_signature_fn = self.model.forward
-
     def setup(self):
         if self.model_device is not None:
             paddle.device.set_device(self.model_device.replace("cuda", "gpu"))
@@ -103,12 +77,6 @@ class TorchPaddleDriver(Driver):
                                  f"'torch.optim.Optimizer' or 'paddle.optimizers.Optimizer' type, "
                                  f"not {type(each_optimizer)}.")
 
-    def train_step(self, batch) -> Dict:
-        if isinstance(batch, Dict):
-            return auto_param_call(self._train_step, batch)
-        else:
-            return self._train_step(batch)
-
     def step(self):
         for optimizer in self.optimizers:
             optimizer.step()
@@ -125,17 +93,24 @@ class TorchPaddleDriver(Driver):
         else:
             raise ValueError("Unknown optimizers type.")
 
-    def validate_step(self, batch):
-        if isinstance(batch, Dict):
-            return auto_param_call(self._validate_step, batch)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
         else:
-            return self._validate_step(batch)
-
-    def test_step(self, batch):
-        if isinstance(batch, Dict):
-            return auto_param_call(self._test_step, batch)
+            return fn(batch)
+
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
+            return self.model, self.model.forward
         else:
-            return self._test_step(batch)
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
 
     def predict_step(self, batch):
         if isinstance(batch, Dict):
diff --git a/fastNLP/modules/mix_modules/mix_module.py b/fastNLP/modules/mix_modules/mix_module.py
index 2ee26133..1c2bd9e1 100644
--- a/fastNLP/modules/mix_modules/mix_module.py
+++ b/fastNLP/modules/mix_modules/mix_module.py
@@ -85,7 +85,7 @@ class MixModule:
     def test_step(self, batch):
         raise NotImplementedError
 
-    def validate_step(self, batch):
+    def evaluate_step(self, batch):
         raise NotImplementedError
 
     def train(self):
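With `validate_step` renamed away, `get_model_call_fn` only ever looks for the method named by the caller (`train_step` or `evaluate_step`) before falling back to `forward`. A model written against the new convention might look like this; the class is illustrative and closely mirrors the test helper `PaddleNormalModel_Classification_1` updated further down:

    import paddle

    class MyModel(paddle.nn.Layer):
        def __init__(self, num_labels=10, feature_dimension=10):
            super().__init__()
            self.linear = paddle.nn.Linear(feature_dimension, num_labels)
            self.loss_fn = paddle.nn.CrossEntropyLoss()

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, y):
            # found by get_model_call_fn("train_step")
            return {"loss": self.loss_fn(self(x), y)}

        def evaluate_step(self, x, y):
            # the new name; drivers no longer look for `validate_step`
            return {"pred": self(x), "target": y}
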
diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py
index 69b16427..8a3ab2ce 100644
--- a/tests/core/controllers/test_trainer_paddle.py
+++ b/tests/core/controllers/test_trainer_paddle.py
@@ -1,13 +1,11 @@
 import pytest
 import os
 os.environ["FASTNLP_BACKEND"] = "paddle"
-from typing import Any
 from dataclasses import dataclass
 
 from fastNLP.core.controllers.trainer import Trainer
 from fastNLP.core.metrics.accuracy import Accuracy
 from fastNLP.core.callbacks.progress_callback import RichCallback
-from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
 
 from paddle.optimizer import Adam
 from paddle.io import DataLoader
@@ -19,40 +17,18 @@
 from tests.helpers.utils import magic_argv_env_context
 
 @dataclass
-class MNISTTrainPaddleConfig:
+class TrainPaddleConfig:
     num_labels: int = 10
-    feature_dimension: int = 784
+    feature_dimension: int = 10
 
-    batch_size: int = 32
+    batch_size: int = 2
     shuffle: bool = True
-    validate_every = -5
+    evaluate_every = 2
 
-    driver: str = "paddle"
-    device = "gpu"
-
-@dataclass
-class MNISTTrainFleetConfig:
-    num_labels: int = 10
-    feature_dimension: int = 784
-
-    batch_size: int = 32
-    shuffle: bool = True
-    validate_every = -5
-
-@dataclass
-class TrainerParameters:
-    model: Any = None
-    optimizers: Any = None
-    train_dataloader: Any = None
-    validate_dataloaders: Any = None
-    input_mapping: Any = None
-    output_mapping: Any = None
-    metrics: Any = None
-
-@pytest.mark.parametrize("driver,device", [("paddle", "cpu")("paddle", 1)])
+@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
 # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
-@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
-                                        RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
+@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
+                                        RichCallback(5)]])
 @magic_argv_env_context
 def test_trainer_paddle(
     driver,
@@ -60,38 +36,36 @@
     callbacks,
     n_epochs=2,
 ):
-    trainer_params = TrainerParameters()
-
-    trainer_params.model = PaddleNormalModel_Classification_1(
-        num_labels=MNISTTrainPaddleConfig.num_labels,
-        feature_dimension=MNISTTrainPaddleConfig.feature_dimension
+    model = PaddleNormalModel_Classification_1(
+        num_labels=TrainPaddleConfig.num_labels,
+        feature_dimension=TrainPaddleConfig.feature_dimension
     )
-    trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
+    optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
     train_dataloader = DataLoader(
-        dataset=PaddleRandomMaxDataset(6400, 10),
-        batch_size=MNISTTrainPaddleConfig.batch_size,
+        dataset=PaddleRandomMaxDataset(20, 10),
+        batch_size=TrainPaddleConfig.batch_size,
         shuffle=True
     )
     val_dataloader = DataLoader(
-        dataset=PaddleRandomMaxDataset(1000, 10),
-        batch_size=MNISTTrainPaddleConfig.batch_size,
+        dataset=PaddleRandomMaxDataset(20, 10),
+        batch_size=TrainPaddleConfig.batch_size,
         shuffle=True
    )
-    trainer_params.train_dataloader = train_dataloader
-    trainer_params.validate_dataloaders = val_dataloader
-    trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
-    trainer_params.metrics = {"acc": Accuracy(backend="paddle")}
+    evaluate_dataloaders = val_dataloader
+    evaluate_every = TrainPaddleConfig.evaluate_every
+    metrics = {"acc": Accuracy(backend="paddle")}
 
     trainer = Trainer(
-        model=trainer_params.model,
+        model=model,
         driver=driver,
         device=device,
-        optimizers=trainer_params.optimizers,
-        train_dataloader=trainer_params.train_dataloader,
-        validate_dataloaders=trainer_params.validate_dataloaders,
-        validate_every=trainer_params.validate_every,
-        input_mapping=trainer_params.input_mapping,
-        output_mapping=trainer_params.output_mapping,
-        metrics=trainer_params.metrics,
+        optimizers=optimizers,
+        train_dataloader=train_dataloader,
+        evaluate_dataloaders=evaluate_dataloaders,
+        evaluate_every=evaluate_every,
+        input_mapping=None,
+        output_mapping=None,
+        metrics=metrics,
 
         n_epochs=n_epochs,
         callbacks=callbacks,
diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py
index 9661c015..fd947c73 100644
--- a/tests/core/drivers/paddle_driver/test_single_device.py
+++ b/tests/core/drivers/paddle_driver/test_single_device.py
@@ -56,34 +56,57 @@ def test_save_and_load_with_randombatchsampler(only_state_dict):
             dataset=dataset,
             batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
         )
+        num_consumed_batches = 2  # TODO iterate a few batches here once checkpoint resumption is fully implemented
 
+        already_seen_set = set()
+        for idx, batch in enumerate(dataloader):
+            if idx >= num_consumed_batches:
+                break
+            already_seen_set.update(batch)
 
         sampler_states = dataloader.batch_sampler.state_dict()
+        save_states = {"num_consumed_batches": num_consumed_batches}
         if only_state_dict:
-            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
+            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
         else:
-            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
-        states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+
+        # reload
+        # change the batch_size
+        dataloader = DataLoader(
+            dataset=dataset,
+            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
+        )
+        load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+        replaced_loader = load_states.pop("dataloader")
 
         # 1. check the optimizer state
        # TODO the optimizer state_dict is always empty
 
         # 2. check that the batch_sampler is correctly loaded and replaced
-        replaced_loader = states["dataloader"]
         assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
         assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
         assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"]
 
         # 3. check that the model parameters are correctly loaded
         for batch in dataloader:
-            res1 = driver1.validate_step(batch)
-            res2 = driver2.validate_step(batch)
+            res1 = driver1.model.evaluate_step(**batch)
+            res2 = driver2.model.evaluate_step(**batch)
 
             assert paddle.equal_all(res1["pred"], res2["pred"])
 
         # 4. check batch_idx
-        # TODO
+        start_batch = load_states.pop('batch_idx_in_epoch')
+        assert start_batch == 2 * num_consumed_batches
+        left_batches = set()
+        for idx, batch in enumerate(replaced_loader):
+            left_batches.update(batch)
+
+        assert len(left_batches) + len(already_seen_set) == len(dataset)
+        assert len(left_batches | already_seen_set) == len(dataset)
+
     finally:
         synchronize_safe_rm(path)
@@ -104,21 +127,36 @@
             dataset,
             batch_sampler=batch_sampler
         )
+        num_consumed_batches = 2  # TODO iterate a few batches here once checkpoint resumption is fully implemented
+        already_seen_set = set()
+        for idx, batch in enumerate(dataloader):
+            if idx >= num_consumed_batches:
+                break
+            already_seen_set.update(batch)
 
         sampler_states = dataloader.batch_sampler.sampler.state_dict()
+        save_states = {"num_consumed_batches": num_consumed_batches}
         if only_state_dict:
-            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
+            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
         else:
-            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
-        states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+
+        # reload
+        # change the batch_size
+        dataloader = DataLoader(
+            dataset=dataset,
+            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
+        )
+        load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+        replaced_loader = load_states.pop("dataloader")
 
         # 1. check the optimizer state
         # TODO the optimizer state_dict is always empty
 
         # 2. check that the sampler is correctly loaded and replaced
-        replaced_loader = states["dataloader"]
         assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
         assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
@@ -129,60 +167,51 @@
 
         # 3. check that the model parameters are correctly loaded
         for batch in dataloader:
-            res1 = driver1.validate_step(batch)
-            res2 = driver2.validate_step(batch)
+            res1 = driver1.model.evaluate_step(**batch)
+            res2 = driver2.model.evaluate_step(**batch)
 
             assert paddle.equal_all(res1["pred"], res2["pred"])
 
         # 4. check batch_idx
-        # TODO
-    finally:
-        synchronize_safe_rm(path)
-
-def test_save_and_load_state_dict(prepare_test_save_load):
-    """
-    Test the save and load functions
-    TODO the optimizer state_dict is empty; not tested for now
-    """
-    try:
-        path = "dict"
-        driver1, driver2, dataloader = prepare_test_save_load
-
-        driver1.save_model(path)
-        driver2.load_model(path)
-
-        for batch in dataloader:
-            batch = driver1.move_data_to_device(batch)
-            res1 = driver1.validate_step(batch)
-            res2 = driver2.validate_step(batch)
+        start_batch = load_states.pop('batch_idx_in_epoch')
+        assert start_batch == 2 * num_consumed_batches
+        left_batches = set()
+        for idx, batch in enumerate(replaced_loader):
+            left_batches.update(batch)
 
-        assert paddle.equal_all(res1["pred"], res2["pred"])
+        assert len(left_batches) + len(already_seen_set) == len(dataset)
+        assert len(left_batches | already_seen_set) == len(dataset)
 
     finally:
         synchronize_safe_rm(path)
 
-def test_save_and_load_whole_model(prepare_test_save_load):
+@pytest.mark.parametrize("only_state_dict", ([True, False]))
+def test_save_and_load_model(prepare_test_save_load, only_state_dict):
     """
-    Test the save and load functions
-    TODO the optimizer state_dict is empty; not tested for now
+    Test the save_model and load_model functions
     """
     try:
         path = "model"
         driver1, driver2, dataloader = prepare_test_save_load
 
-        driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))])
-        driver2.load_model(path, only_state_dict=False)
+        if only_state_dict:
+            driver1.save_model(path, only_state_dict)
+        else:
+            driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((32, 10))])
+        driver2.load_model(path, only_state_dict)
 
         for batch in dataloader:
             batch = driver1.move_data_to_device(batch)
-            res1 = driver1.validate_step(batch)
-            res2 = driver2.validate_step(batch)
+            res1 = driver1.model.evaluate_step(**batch)
+            res2 = driver2.model.evaluate_step(**batch)
 
             assert paddle.equal_all(res1["pred"], res2["pred"])
     finally:
-        synchronize_safe_rm(path + ".pdiparams")
-        synchronize_safe_rm(path + ".pdiparams.info")
-        synchronize_safe_rm(path + ".pdmodel")
-
+        if only_state_dict:
+            synchronize_safe_rm(path)
+        else:
+            synchronize_safe_rm(path + ".pdiparams")
+            synchronize_safe_rm(path + ".pdiparams.info")
+            synchronize_safe_rm(path + ".pdmodel")
 
 class TestSingleDeviceFunction:
     """
@@ -199,13 +228,7 @@
        test that it runs
         """
         res = self.driver.unwrap_model()
-
-    def test_check_evaluator_mode(self):
-        """
-        these two functions neither return values nor raise; this only checks for import errors and similar problems
-        """
-        self.driver.check_evaluator_mode("validate")
-        self.driver.check_evaluator_mode("test")
+        assert res is self.driver.model
 
     def test_is_distributed(self):
         assert self.driver.is_distributed() == False
@@ -237,21 +260,30 @@ class TestSetDistReproDataloder:
 
         assert replaced_loader is dataloader
 
-    def test_set_dist_repro_dataloader_with_reproducible_true(self):
+    @pytest.mark.parametrize("shuffle", [True, False])
+    def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle):
         """
         Test the behavior of set_dist_repro_dataloader when the parameter `reproducible` is True
-        when dist is a string, a new dataloader should be returned, with RandomBatchSampler as its batch_sampler
+        when dist is a string, a new dataloader should be returned; if the original sampler is paddle.io.RandomSampler(shuffle=True),
+        only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with RandomBatchSampler
         """
-        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
 
         assert not (replaced_loader is dataloader)
-        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
-        assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
+        if shuffle:
+            # in this case the sampler is replaced
+            assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler)
+            assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
+            assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
+        else:
+            # in this case the batch_sampler is replaced
+            assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
+            assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         assert replaced_loader.drop_last == dataloader.drop_last
 
-        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
 
     def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
         """
diff --git a/tests/helpers/callbacks/helper_callbacks.py b/tests/helpers/callbacks/helper_callbacks.py
index 751d59f2..c3a9d4da 100644
--- a/tests/helpers/callbacks/helper_callbacks.py
+++ b/tests/helpers/callbacks/helper_callbacks.py
@@ -72,7 +72,7 @@ class RecordTrainerEventTriggerCallback(Callback):
         print("on_train_end")
 
     def on_train_epoch_begin(self, trainer):
-        if trainer.current_epoch_idx >= 1:
+        if trainer.cur_epoch_idx >= 1:
             # trigger on_exception;
             raise Exception
         print("on_train_epoch_begin")
diff --git a/tests/helpers/models/paddle_model.py b/tests/helpers/models/paddle_model.py
index a830b1ff..efa8c0ce 100644
--- a/tests/helpers/models/paddle_model.py
+++ b/tests/helpers/models/paddle_model.py
@@ -26,7 +26,7 @@ class PaddleNormalModel_Classification_1(paddle.nn.Layer):
         x = self(x)
         return {"loss": self.loss_fn(x, y)}
 
-    def validate_step(self, x, y):
+    def evaluate_step(self, x, y):
         x = self(x)
         return {"pred": x, "target": y.reshape((-1,))}
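Taken together, the reworked tests pin down the resumption contract: consume some batches, save with `num_consumed_batches`, load into a freshly built dataloader (even with a different batch size), and the restored sampler must yield exactly the samples not yet seen. The skeleton of that contract, reusing the same fixtures (`driver1`, `driver2`, `dataloader`, `dataset`, `path`) the tests above assume:

    # Sketch of the checkpoint-resumption contract exercised by the tests.
    num_consumed_batches = 2
    already_seen = set()
    for idx, batch in enumerate(dataloader):
        if idx >= num_consumed_batches:
            break
        already_seen.update(batch)

    driver1.save(Path(path), {"num_consumed_batches": num_consumed_batches},
                 dataloader, only_state_dict=True, should_save_model=True)

    load_states = driver2.load(Path(path), dataloader, only_state_dict=True,
                               should_load_model=True)
    replaced_loader = load_states.pop("dataloader")  # sampler state restored

    left = set()
    for batch in replaced_loader:
        left.update(batch)

    # the two passes cover the dataset exactly once, with no overlap
    assert len(already_seen) + len(left) == len(dataset)
    assert len(already_seen | left) == len(dataset)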