diff --git a/docs/source/fastNLP.core.callbacks.fitlog_callback.rst b/docs/source/fastNLP.core.callbacks.fitlog_callback.rst new file mode 100644 index 00000000..020c3ff3 --- /dev/null +++ b/docs/source/fastNLP.core.callbacks.fitlog_callback.rst @@ -0,0 +1,7 @@ +fastNLP.core.callbacks.fitlog\_callback module +============================================== + +.. automodule:: fastNLP.core.callbacks.fitlog_callback + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.callbacks.rst b/docs/source/fastNLP.core.callbacks.rst index 0f3f93ac..89d85f52 100644 --- a/docs/source/fastNLP.core.callbacks.rst +++ b/docs/source/fastNLP.core.callbacks.rst @@ -25,6 +25,7 @@ Submodules fastNLP.core.callbacks.callback_manager fastNLP.core.callbacks.checkpoint_callback fastNLP.core.callbacks.early_stop_callback + fastNLP.core.callbacks.fitlog_callback fastNLP.core.callbacks.has_monitor_callback fastNLP.core.callbacks.load_best_model_callback fastNLP.core.callbacks.lr_scheduler_callback diff --git a/docs/source/fastNLP.modules.mix_modules.rst b/docs/source/fastNLP.modules.mix_modules.rst new file mode 100644 index 00000000..5351c55a --- /dev/null +++ b/docs/source/fastNLP.modules.mix_modules.rst @@ -0,0 +1,15 @@ +fastNLP.modules.mix\_modules package +==================================== + +.. automodule:: fastNLP.modules.mix_modules + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.modules.mix_modules.utils diff --git a/docs/source/fastNLP.modules.mix_modules.utils.rst b/docs/source/fastNLP.modules.mix_modules.utils.rst new file mode 100644 index 00000000..9dab336d --- /dev/null +++ b/docs/source/fastNLP.modules.mix_modules.utils.rst @@ -0,0 +1,7 @@ +fastNLP.modules.mix\_modules.utils module +========================================= + +.. automodule:: fastNLP.modules.mix_modules.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.rst b/docs/source/fastNLP.modules.rst new file mode 100644 index 00000000..fa1d95de --- /dev/null +++ b/docs/source/fastNLP.modules.rst @@ -0,0 +1,15 @@ +fastNLP.modules package +======================= + +.. automodule:: fastNLP.modules + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.modules.mix_modules diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst index 726eb9c6..89c8e058 100644 --- a/docs/source/fastNLP.rst +++ b/docs/source/fastNLP.rst @@ -15,3 +15,4 @@ Subpackages fastNLP.core fastNLP.envs fastNLP.io + fastNLP.modules diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py index c7bb9e79..d3764d4e 100644 --- a/fastNLP/core/utils/paddle_utils.py +++ b/fastNLP/core/utils/paddle_utils.py @@ -22,9 +22,9 @@ from .utils import apply_to_collection def _convert_data_device(device: Union[str, int]) -> str: """ - 用于转换 ``driver`` 的 ``data_device`` 的函数。如果用户设置了 ``FASTNLP_BACKEND=paddle``,那么 ``fastNLP`` 会将 + 用于转换 ``driver`` 的 ``data_device`` 的函数。如果用户设置了 ``FASTNLP_BACKEND=paddle``,那么 **fastNLP** 会将 可见的设备保存在 ``USER_CUDA_VISIBLE_DEVICES`` 中,并且将 ``CUDA_VISIBLE_DEVICES`` 设置为可见的第一张显卡;这是为 - 了顺利执行 ``paddle`` 的分布式训练而设置的。 + 了顺利执行 **paddle** 的分布式训练而设置的。 在这种情况下,单纯使用 ``driver.data_device`` 是无效的。比如在分布式训练中将设备设置为 ``[0,2,3]`` ,且用户设置了 ``CUDA_VISIBLE_DEVICES=3,4,5,6`` ,那么在 ``rank1``的进程中有:: @@ -127,7 +127,7 @@ def get_paddle_device_id(device: Union[str, int]) -> int: def paddle_move_data_to_device(batch: Any, device: Optional[Union[str, int]]) -> Any: r""" - 将 ``paddle`` 的数据集合传输到给定设备。只有 :class:`paddle.Tensor` 对象会被传输到设备中,其余保持不变。 + 将 **paddle** 的数据集合传输到给定设备。只有 :class:`paddle.Tensor` 对象会被传输到设备中,其余保持不变。 :param batch: 需要进行迁移的数据集合; :param device: 目标设备。可以是显卡设备的编号,或是``cpu``, ``gpu`` 或 ``gpu:x`` 格式的字符串;当这个参数 @@ -145,20 +145,20 @@ def paddle_move_data_to_device(batch: Any, device: Optional[Union[str, int]]) -> def is_in_paddle_dist() -> bool: """ - 判断是否处于 ``paddle`` 分布式的进程下,使用 ``PADDLE_RANK_IN_NODE`` 和 ``FLAGS_selected_gpus`` 判断。 + 判断是否处于 **paddle** 分布式的进程下,使用 ``PADDLE_RANK_IN_NODE`` 和 ``FLAGS_selected_gpus`` 判断。 """ return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ) def is_in_fnlp_paddle_dist() -> bool: """ - 判断是否处于 ``fastNLP`` 拉起的 ``paddle`` 分布式进程中 + 判断是否处于 **fastNLP** 拉起的 **paddle** 分布式进程中 """ return FASTNLP_DISTRIBUTED_CHECK in os.environ def is_in_paddle_launch_dist() -> bool: """ - 判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 ``paddle`` 分布式进程中 + 判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中 """ return FASTNLP_BACKEND_LAUNCH in os.environ \ No newline at end of file diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py index 02a30c26..53d4e281 100644 --- a/fastNLP/core/utils/rich_progress.py +++ b/fastNLP/core/utils/rich_progress.py @@ -1,5 +1,5 @@ """ -该文件用于为 ``fastNLP`` 提供一个统一的 ``progress bar`` 管理,通过共用一个``Task`` 对象, :class:`~fastNLP.core.Trainer` 中 +该文件用于为 **fastNLP** 提供一个统一的 ``progress bar`` 管理,通过共用一个``Task`` 对象, :class:`~fastNLP.core.Trainer` 中 的 ``progress bar`` 和 :class:`~fastNLP.core.Evaluator` 中的 ``progress bar`` 才能不冲突 """ import sys diff --git a/fastNLP/core/utils/torch_utils.py b/fastNLP/core/utils/torch_utils.py index 862ea20d..0cef2205 100644 --- a/fastNLP/core/utils/torch_utils.py +++ b/fastNLP/core/utils/torch_utils.py @@ -44,11 +44,11 @@ class TorchTransferableDataType(ABC): def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.device"]] = None, non_blocking: Optional[bool] = True) -> Any: r""" - 在 ``pytorch`` 中将数据集合 ``batch`` 传输到给定设备。任何定义方法 ``to(device)`` 的对象都将被移动并且集合中的所有其他对象将保持不变; + 在 **pytorch** 中将数据集合 ``batch`` 传输到给定设备。任何定义方法 ``to(device)`` 的对象都将被移动并且集合中的所有其他对象将保持不变; :param batch: 需要迁移的数据; :param device: 数据应当迁移到的设备;当该参数的值为 ``None`` 时则不执行任何操作; - :param non_blocking: ``pytorch`` 的数据迁移方法 ``to`` 的参数; + :param non_blocking: **pytorch** 的数据迁移方法 ``to`` 的参数; :return: 迁移到新设备上的数据集合; """ if device is None: diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 00da9ac1..4d8bbb5e 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -55,7 +55,7 @@ def get_fn_arg_names(fn: Callable) -> List[str]: def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: r""" - 该函数会根据输入函数的形参名从 ``*args`` (均为 ``dict`` 类型)中找到匹配的值进行调用,如果传入的数据与 ``fn`` 的形参不匹配,可以通过 + 该函数会根据输入函数的形参名从 ``*args`` (均为 **dict** 类型)中找到匹配的值进行调用,如果传入的数据与 ``fn`` 的形参不匹配,可以通过 ``mapping`` 参数进行转换。``mapping`` 参数中的一对 ``(key, value)`` 表示在 ``*args`` 中找到 ``key`` 对应的值,并将这个值传递给形参中名为 ``value`` 的参数。 @@ -259,21 +259,21 @@ def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict: def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any: r""" - 用来实现将输入的 ``batch`` 或者输出的 ``outputs`` 通过 ``mapping`` 将键值进行更换的功能; + 用来实现将输入的 **batch** 或者输出的 **outputs** 通过 ``mapping`` 将键值进行更换的功能; 该函数应用于 ``input_mapping`` 和 ``output_mapping``; * 对于 ``input_mapping``,该函数会在 :class:`~fastNLP.core.controllers.TrainBatchLoop` 中取完数据后立刻被调用; * 对于 ``output_mapping``,该函数会在 :class:`~fastNLP.core.Trainer` 的 :meth:`~fastNLP.core.Trainer.train_step` - 以及 :class:`~fastNLP.core.Evaluator` 的 :meth:`~fastNLP.core.Evaluator.train_step` 中得到结果后立刻被调用; + 以及 :class:`~fastNLP.core.Evaluator` 的 :meth:`~fastNLP.core.Evaluator.train_step` 中得到结果后立刻被调用; 转换的逻辑按优先级依次为: - 1. 如果 ``mapping`` 是一个函数,那么会直接返回 ``mapping(data)``; - 2. 如果 ``mapping`` 是一个 ``Dict``,那么 ``data`` 的类型只能为以下三种: ``[Dict, dataclass, Sequence]``; + 1. 如果 ``mapping`` 是一个函数,那么会直接返回 **mapping(data)**; + 2. 如果 ``mapping`` 是一个 **Dict**,那么 ``data`` 的类型只能为以下三种: ``[Dict, dataclass, Sequence]``; - * 如果 ``data`` 是 ``Dict``,那么该函数会将 ``data`` 的 ``key`` 替换为 ``mapping[key]``; - * 如果 ``data`` 是 ``dataclass``,那么该函数会先使用 :func:`dataclasses.asdict` 函数将其转换为 ``Dict``,然后进行转换; - * 如果 ``data`` 是 ``Sequence``,那么该函数会先将其转换成一个对应的字典:: + * 如果 ``data`` 是 **Dict**,那么该函数会将 ``data`` 的 ``key`` 替换为 **mapping[key]**; + * 如果 ``data`` 是 **dataclass**,那么该函数会先使用 :func:`dataclasses.asdict` 函数将其转换为 **Dict**,然后进行转换; + * 如果 ``data`` 是 **Sequence**,那么该函数会先将其转换成一个对应的字典:: { "_0": list[0], @@ -281,7 +281,7 @@ def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, ... } - 然后使用 ``mapping`` 对这个 ``Dict`` 进行转换,如果没有匹配上 ``mapping`` 中的 ``key`` 则保持 ``\'\_number\'`` 这个形式。 + 然后使用 ``mapping`` 对这个字典进行转换,如果没有匹配上 ``mapping`` 中的 ``key`` 则保持 ``'_number'`` 这个形式。 :param mapping: 用于转换的字典或者函数;当 ``mapping`` 是函数时,返回值必须为字典类型; :param data: 需要被转换的对象; @@ -459,7 +459,7 @@ def _is_iterable(value): def pretty_table_printer(dataset_or_ins) -> PrettyTable: r""" - 用于在 ``fastNLP`` 中展示数据的函数:: + 用于在 **fastNLP** 中展示数据的函数:: >>> ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) +-----------+-----------+-----------------+ diff --git a/fastNLP/modules/mix_modules/utils.py b/fastNLP/modules/mix_modules/utils.py index e69de29b..b19a5d53 100644 --- a/fastNLP/modules/mix_modules/utils.py +++ b/fastNLP/modules/mix_modules/utils.py @@ -0,0 +1,242 @@ +import warnings +from typing import Any, Optional, Union + +import numpy as np + +from fastNLP.core.utils import paddle_to, apply_to_collection +from fastNLP.core.log import logger +from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE + +if _NEED_IMPORT_PADDLE: + import paddle + +if _NEED_IMPORT_JITTOR: + import jittor + +if _NEED_IMPORT_TORCH: + import torch + +__all__ = [ + "paddle2torch", + "torch2paddle", + "jittor2torch", + "torch2jittor", +] + +def _paddle2torch(paddle_tensor: 'paddle.Tensor', device: Optional[Union[str, int]] = None, no_gradient: bool = None) -> 'torch.Tensor': + """ + 将 :class:`paddle.Tensor` 转换为 :class:`torch.Tensor` ,并且能够保留梯度进行反向传播 + + :param paddle_tensor: 要转换的 **paddle** 张量; + :param device: 是否将转换后的张量迁移到特定设备上,为 ``None``时,和输入的张量相同; + :param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致; + 为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度; + :return: 转换后的 **torch** 张量; + """ + no_gradient = paddle_tensor.stop_gradient if no_gradient is None else no_gradient + paddle_numpy = paddle_tensor.numpy() + if not np.issubdtype(paddle_numpy.dtype, np.inexact): + no_gradient = True + + if device is None: + if paddle_tensor.place.is_gpu_place(): + # paddlepaddle有两种Place,对应不同的device id获取方式 + if hasattr(paddle_tensor.place, "gpu_device_id"): + # paddle.fluid.core_avx.Place + # 在gpu环境下创建张量的话,张量的place是这一类型 + device = f"cuda:{paddle_tensor.place.gpu_device_id()}" + else: + # paddle.CUDAPlace + device = f"cuda:{paddle_tensor.place.get_device_id()}" + else: + # TODO: 可能需要支持xpu等设备 + device = "cpu" + + if not no_gradient: + # 保持梯度,并保持反向传播 + # torch.tensor会保留numpy数组的类型 + torch_tensor = torch.tensor(paddle_numpy, requires_grad=True, device=device) + hook = torch_tensor.register_hook( + lambda grad: paddle.autograd.backward(paddle_tensor, paddle.to_tensor(grad.cpu().numpy())) + ) + else: + # 不保留梯度 + torch_tensor = torch.tensor(paddle_numpy, requires_grad=False, device=device) + + return torch_tensor + + +def _torch2paddle(torch_tensor: 'torch.Tensor', device: str = None, no_gradient: bool = None) -> 'paddle.Tensor': + """ + 将 :class:`torch.Tensor` 转换为 :class:`paddle.Tensor`,并且能够保留梯度进行反向传播。 + + :param torch_tensor: 要转换的 **torch** 张量; + :param device: 是否将转换后的张量迁移到特定设备上,输入为 ``None`` 时,和输入的张量相同; + :param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致; + 为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度; + :return: 转换后的 **paddle** 张量; + """ + no_gradient = not torch_tensor.requires_grad if no_gradient is None else no_gradient + if device is None: + if torch_tensor.is_cuda: + device = f"gpu:{torch_tensor.device.index}" + else: + device = "cpu" + + if not no_gradient: + # 保持梯度并保持反向传播 + # paddle的stop_gradient和torch的requires_grad表现是相反的 + paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=False) + hook = paddle_tensor.register_hook( + lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy())) + ) + else: + paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=True) + + paddle_tensor = paddle_to(paddle_tensor, device) + + return paddle_tensor + + +def _jittor2torch(jittor_var: 'jittor.Var', device: Optional[Union[str, int]] = None, no_gradient: bool = None) -> 'torch.Tensor': + """ + 将 :class:`jittor.Var` 转换为 :class:`torch.Tensor` 。 + + :param jittor_var: 要转换的 **jittor** 变量; + :param device: 是否将转换后的张量迁移到特定设备上,输入为 ``None`` 时,根据 ``jittor.flags.use_cuda`` 决定; + :param no_gradient: 是否保留原张量的梯度。为``None``时,新的张量与输入张量保持一致; + 为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度; + :return: 转换后的 **torch** 张量; + """ + # TODO: warning:无法保留梯度 + # jittor的grad可以通过callback进行传递 + # 如果outputs有_grad键,可以实现求导 + no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient + if no_gradient == False: + warnings.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.") + jittor_numpy = jittor_var.numpy() + if not np.issubdtype(jittor_numpy.dtype, np.inexact): + no_gradient = True + + if device is None: + # jittor的设备分配是自动的 + # 根据use_cuda判断 + if jittor.flags.use_cuda: + device = "cuda:0" + else: + device = "cpu" + + torch_tensor = torch.tensor(jittor_numpy, requires_grad=not no_gradient, device=device) + + return torch_tensor + + +def _torch2jittor(torch_tensor: 'torch.Tensor', no_gradient: bool = None) -> 'jittor.Var': + """ + 将 :class:`torch.Tensor` 转换为 :class:`jittor.Var` 。 + + :param torch_tensor: 要转换的 **torch** 张量; + :param no_gradient: 是否保留原张量的梯度。为``None``时,新的张量与输入张量保持一致; + 为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度; + :return: 转换后的 **jittor** 变量; + """ + no_gradient = not torch_tensor.requires_grad if no_gradient is None else no_gradient + + if not no_gradient: + # 保持梯度并保持反向传播 + jittor_var = jittor.Var(torch_tensor.detach().numpy()) + jittor_var.requires_grad = True + hook = jittor_var.register_hook( + lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy())) + ) + else: + jittor_var = jittor.Var(torch_tensor.detach().numpy()) + jittor_var.requires_grad = False + + return jittor_var + + +def torch2paddle(batch: Any, device: str = None, no_gradient: bool = None) -> Any: + """ + 递归地将输入中包含的 :class:`torch.Tensor` 转换为 :class:`paddle.Tensor` 。 + + :param batch: 包含 :class:`torch.Tensor` 类型的数据集合 + :param device: 是否将转换后的张量迁移到特定设备上。为 ``None`` 时,和输入保持一致; + :param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致; + 为 ``True`` 时,不保留梯度;为 ``False`` 时,保留梯度; + :return: 转换后的数据; + """ + + return apply_to_collection( + batch, + dtype=torch.Tensor, + function=_torch2paddle, + device=device, + no_gradient=no_gradient, + ) + + +def paddle2torch(batch: Any, device: str = None, no_gradient: bool = None) -> Any: + """ + 递归地将输入中包含的 :class:`paddle.Tensor` 转换为 :class:`torch.Tensor` 。 + + :param batch: 包含 :class:`paddle.Tensor` 类型的数据集合; + :param device: 是否将转换后的张量迁移到特定设备上。为 ``None``时,和输入保持一致; + :param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致; + 为 ``True`` 时,不保留梯度;为 ``False`` 时,保留梯度; + :return: 转换后的数据; + """ + + return apply_to_collection( + batch, + dtype=paddle.Tensor, + function=_paddle2torch, + device=device, + no_gradient=no_gradient, + ) + + +def jittor2torch(batch: Any, device: str = None, no_gradient: bool = None) -> Any: + """ + 递归地将输入中包含的 :class:`jittor.Var` 转换为 :class:`torch.Tensor` 。 + + .. note:: + + 注意,由于 **pytorch** 和 **jittor** 之间的差异,从 :class:`jittor.Var` 转换 + 至 :class:`torch.Tensor` 的过程中无法保留原张量的梯度。 + + :param batch: 包含 :class:`jittor.Var` 类型的数据集合; + :param device: 是否将转换后的张量迁移到特定设备上。为 ``None``时,和输入保持一致; + :param no_gradient: 是否保留原张量的梯度,在这个函数中该参数无效。 + :return: 转换后的数据; + """ + + return apply_to_collection( + batch, + dtype=jittor.Var, + function=_jittor2torch, + device=device, + no_gradient=no_gradient, + ) + + +def torch2jittor(batch: Any, no_gradient: bool = None) -> Any: + """ + 递归地将输入中包含的 :class:`torch.Tensor` 转换为 :class:`jittor.Var` 。 + + .. note:: + + **jittor** 会自动为创建的变量分配设备。 + + :param batch: 包含 :class:`torch.Tensor` 类型的数据集合; + :param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致; + 为 ``True`` 时,不保留梯度;为 ``False`` 时,保留梯度; + :return: 转换后的数据; + """ + + return apply_to_collection( + batch, + dtype=torch.Tensor, + function=_torch2jittor, + no_gradient=no_gradient, + ) \ No newline at end of file